跳到内容

使用 VLM 提取收据数据

设置

您需要安装依赖项

pip install outlines torch==2.4.0 transformers accelerate pillow rich

导入库

加载所有必需的库

# LLM stuff
import outlines
import torch
from transformers import AutoProcessor
from pydantic import BaseModel, Field
from typing import Literal, Optional, List

# Image stuff
from PIL import Image
import requests

# Rich for pretty printing
from rich import print

选择模型

此示例已使用 mistral-community/pixtral-12b (HF 链接) 和 Qwen/Qwen2-VL-7B-Instruct (HF 链接) 进行测试。

我们推荐使用 Qwen-2-VL,因为我们发现它比 Pixtral 更准确。

如果您想使用 Qwen-2-VL,可以按以下步骤操作

# To use Qwen-2-VL:
from transformers import Qwen2VLForConditionalGeneration
model_name = "Qwen/Qwen2-VL-7B-Instruct"
model_class = Qwen2VLForConditionalGeneration

如果您想使用 Pixtral,可以按以下步骤操作

# To use Pixtral:
from transformers import LlavaForConditionalGeneration
model_name="mistral-community/pixtral-12b"
model_class=LlavaForConditionalGeneration

加载模型

将模型加载到内存中

model = outlines.models.transformers_vision(
    model_name,
    model_class=model_class,
    model_kwargs={
        "device_map": "auto",
        "torch_dtype": torch.bfloat16,
    },
    processor_kwargs={
        "device": "cuda", # set to "cpu" if you don't have a GPU
    },
)

图像处理

图像可能很大。在 GPU 资源紧张的环境中,您可能需要将图像调整到较小尺寸。

这是一个帮助函数,可以做到这一点

def load_and_resize_image(image_path, max_size=1024):
    """
    Load and resize an image while maintaining aspect ratio

    Args:
        image_path: Path to the image file
        max_size: Maximum dimension (width or height) of the output image

    Returns:
        PIL Image: Resized image
    """
    image = Image.open(image_path)

    # Get current dimensions
    width, height = image.size

    # Calculate scaling factor
    scale = min(max_size / width, max_size / height)

    # Only resize if image is larger than max_size
    if scale < 1:
        new_width = int(width * scale)
        new_height = int(height * scale)
        image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

    return image

您可以通过更改 max_size 参数来改变图像的分辨率。较小的 max_size 会使图像更模糊,但处理会更快并需要更少内存。

加载图像

加载图像并调整大小。我们提供了一张 Trader Joe's 收据的示例图片,但您可以使用任何您喜欢的图片。

图像看起来像这样

Trader Joe's receipt

# Path to the image
image_path = "https://raw.githubusercontent.com/dottxt-ai/outlines/refs/heads/main/docs/cookbook/images/trader-joes-receipt.jpg"

# Download the image
response = requests.get(image_path)
with open("receipt.png", "wb") as f:
    f.write(response.content)

# Load + resize the image
image = load_and_resize_image("receipt.png")

定义输出结构

我们将定义一个 Pydantic 模型来描述我们希望从图像中提取的数据。

在我们的案例中,我们希望提取以下信息

  • 商店名称
  • 商店地址
  • 商店电话
  • 商品列表,包括名称、数量、单价和总价
  • 税金
  • 总计
  • 日期
  • 付款方式

大多数字段是可选的,因为并非所有收据都包含所有信息。

class Item(BaseModel):
    name: str
    quantity: Optional[int]
    price_per_unit: Optional[float]
    total_price: Optional[float]

class ReceiptSummary(BaseModel):
    store_name: str
    store_address: str
    store_number: Optional[int]
    items: List[Item]
    tax: Optional[float]
    total: Optional[float]
    # Date is in the format YYYY-MM-DD. We can apply a regex pattern to ensure it's formatted correctly.
    date: Optional[str] = Field(pattern=r'\d{4}-\d{2}-\d{2}', description="Date in the format YYYY-MM-DD")
    payment_method: Literal["cash", "credit", "debit", "check", "other"]

准备提示

我们将使用 AutoProcessor 将图像和文本提示转换为模型可以理解的格式。实际上,这段代码会将用户、系统、助手和图像 token 添加到提示中。

# Set up the content you want to send to the model
messages = [
    {
        "role": "user",
        "content": [
            {
                # The image is provided as a PIL Image object
                "type": "image",
                "image": image,
            },
            {
                "type": "text",
                "text": f"""You are an expert at extracting information from receipts.
                Please extract the information from the receipt. Be as detailed as possible --
                missing or misreporting information is a crime.

                Return the information in the following JSON schema:
                {ReceiptSummary.model_json_schema()}
            """},
        ],
    }
]

# Convert the messages to the final prompt
processor = AutoProcessor.from_pretrained(model_name)
prompt = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

如果您好奇,发送给模型的最终提示(大致)看起来像这样

<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
<|vision_start|><|image_pad|><|vision_end|>
You are an expert at extracting information from receipts.
Please extract the information from the receipt. Be as detailed as
possible -- missing or misreporting information is a crime.

Return the information in the following JSON schema:

<JSON SCHEMA OMITTED>
<|im_end|>
<|im_start|>assistant

运行模型

# Prepare a function to process receipts
receipt_summary_generator = outlines.generate.json(
    model,
    ReceiptSummary,

    # Greedy sampling is a good idea for numeric
    # data extraction -- no randomness.
    sampler=outlines.samplers.greedy()
)

# Generate the receipt summary
result = receipt_summary_generator(prompt, [image])
print(result)

输出

输出应如下所示

ReceiptSummary(
    store_name="Trader Joe's",
    store_address='401 Bay Street, San Francisco, CA 94133',
    store_number=0,
    items=[
        Item(name='BANANA EACH', quantity=7, price_per_unit=0.23, total_price=1.61),
        Item(name='BAREBELLS CHOCOLATE DOUG', quantity=1, price_per_unit=2.29, total_price=2.29),
        Item(name='BAREBELLS CREAMY CRISP', quantity=1, price_per_unit=2.29, total_price=2.29),
        Item(name='BAREBELLS CHOCOLATE DOUG', quantity=1, price_per_unit=2.29, total_price=2.29),
        Item(name='BAREBELLS CARAMEL CASHEW', quantity=2, price_per_unit=2.29, total_price=4.58),
        Item(name='BAREBELLS CREAMY CRISP', quantity=1, price_per_unit=2.29, total_price=2.29),
        Item(name='SPINDRIFT ORANGE MANGO 8', quantity=1, price_per_unit=7.49, total_price=7.49),
        Item(name='Bottle Deposit', quantity=8, price_per_unit=0.05, total_price=0.4),
        Item(name='MILK ORGANIC GALLON WHOL', quantity=1, price_per_unit=6.79, total_price=6.79),
        Item(name='CLASSIC GREEK SALAD', quantity=1, price_per_unit=3.49, total_price=3.49),
        Item(name='COBB SALAD', quantity=1, price_per_unit=5.99, total_price=5.99),
        Item(name='PEPPER BELL RED XL EACH', quantity=1, price_per_unit=1.29, total_price=1.29),
        Item(name='BAG FEE.', quantity=1, price_per_unit=0.25, total_price=0.25),
        Item(name='BAG FEE.', quantity=1, price_per_unit=0.25, total_price=0.25)
    ],
    tax=0.68,
    total=41.98,
    date='2023-11-04',
    payment_method='debit',

)

瞧!您已成功使用 LLM 从收据中提取信息。

附赠:吐槽用户的收据

您可以通过在 ReceiptSummary 模型末尾添加一个 roast 字段来吐槽用户的收据。

class ReceiptSummary(BaseModel):
    ...
    roast: str

这将为您提供一个如下所示的结果

ReceiptSummary(
    ...
    roast="You must be a fan of Trader Joe's because you bought enough
    items to fill a small grocery bag and still had to pay for a bag fee.
    Maybe you should start using reusable bags to save some money and the
    environment."
)

Qwen 不是特别幽默,但值得一试。