跳过内容

分类

本教程展示了如何使用 Outlines 实现多标签分类。我们将使用该库的两项功能:generate.choicegenerate.json

像往常一样,我们从初始化模型开始。由于我们的 GPU 资源有限,我们将使用 Mistal-7B-v0.1 的量化版本。

我们将使用以下提示模板

import outlines

model = outlines.models.transformers("TheBloke/Mistral-7B-OpenOrca-AWQ", device="cuda")

Outlines 提供了一个快捷方式来实现多标签分类,即使用 outlines.generate.choice 函数初始化生成器。Outlines 默认使用多项式采样,这里我们将使用贪婪采样器来获取概率最高的标签。

from outlines import Template


customer_support = Template.from_string(
    """You are an experienced customer success manager.

    Given a request from a client, you need to determine when the
    request is urgent using the label "URGENT" or when it can wait
    a little with the label "STANDARD".

    # Examples

    Request: "How are you?"
    Label: STANDARD

    Request: "I need this fixed immediately!"
    Label: URGENT

    # TASK

    Request: {{ request }}
    Label: """
)

使用 JSON 结构化生成

Outlines 支持批量请求,因此我们将向模型传递两个请求。

from outlines.samplers import greedy

generator = outlines.generate.choice(model, ["URGENT", "STANDARD"], sampler=greedy())
现在我们可以让模型对请求进行分类了。

requests = [
    "My hair is one fire! Please help me!!!",
    "Just wanted to say hi"
]

prompts = [customer_support(request) for request in requests]

现在,您可能很着急,不想等到模型完成全部生成。毕竟,您只需要看到响应的第一个字母就知道请求是紧急还是标准。您可以改为流式传输响应。

labels = generator(prompts)
print(labels)
# ['URGENT', 'STANDARD']

另一种(稍显复杂)进行多标签分类的方法是使用 Outlines 的 JSON 结构化生成功能。我们首先需要定义包含标签的 Pydantic schema。

tokens = generator.stream(prompts)
labels = ["URGENT" if "U" in token else "STANDARD" for token in next(tokens)]
print(labels)
# ['URGENT', 'STANDARD']

命名实体提取

然后我们可以通过传递刚刚定义的 Pydantic 模型来使用 generate.json,并调用生成器。

from enum import Enum
from pydantic import BaseModel


class Label(str, Enum):
    urgent = "URGENT"
    standard = "STANDARD"


class Classification(BaseModel):
    label: Label

2025-02-25 2023-12-22 GitHub

generator = outlines.generate.json(model, Classification, sampler=greedy())
labels = generator(prompts)
print(labels)
# [Classification(label=<Label.urgent: 'URGENT'>), Classification(label=<Label.standard: 'STANDARD'>)]