跳到内容

transformers

安装

您需要安装 transformerdatasetstorch 库才能在 Outlines 中使用这些模型,或者您也可以

pip install "outlines[transformers]"

Outlines 提供了与 transformers 库中因果模型的 torch 实现集成。您可以通过传递模型的名称来初始化模型

from outlines import models

model = models.transformers("microsoft/Phi-3-mini-4k-instruct", device="cuda")

如果您需要更精细的控制,您也可以单独初始化模型和分词器

from transformers import AutoModelForCausalLM, AutoTokenizer
from outlines import models

llm = AutoModelForCausalLM.from_pretrained("gpt2", output_attentions=True)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = models.Transformers(llm, tokenizer)

使用 Logits 处理器

使用 HuggingFace Transformers 和 Outlines 结构化生成有两种方式

  1. 使用 Outlines 生成包装器 outlines.models.transformers
  2. OutlinesLogitsProcessortransformers.AutoModelForCausalLM 一起使用

Outlines 支持多种 logits 处理器用于结构化生成。在此示例中,我们将使用 RegexLogitsProcessor,它保证生成的文本匹配指定的模式。

使用 outlines.models.transformers

import outlines

time_regex_pattern = r"(0?[1-9]|1[0-2]):[0-5]\d\s?(am|pm)?"

model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct", device="cuda")
generator = outlines.generate.regex(model, time_regex_pattern)

output = generator("The the best time to visit a dentist is at ")
print(output)
# 2:30 pm

使用通过 transformers 库初始化的模型

import outlines
import transformers


model_uri = "microsoft/Phi-3-mini-4k-instruct"

outlines_tokenizer = outlines.models.TransformerTokenizer(
    transformers.AutoTokenizer.from_pretrained(model_uri)
)
phone_number_logits_processor = outlines.processors.RegexLogitsProcessor(
    "\\+?[1-9][0-9]{7,14}",  # phone number pattern
    outlines_tokenizer,
)

generator = transformers.pipeline('text-generation', model=model_uri)

output = generator(
    "Jenny gave me her number it's ",
    logits_processor=transformers.LogitsProcessorList([phone_number_logits_processor])
)
print(output)
# [{'generated_text': "Jenny gave me her number it's 2125550182"}]
# not quite 8675309 what we expected, but it is a valid phone number

替代模型类

outlines.models.transformers 默认使用 transformers.AutoModelForCausalLM,这是大多数标准大型语言模型(包括 Llama 3、Mistral、Phi-3 等)的合适类。

但是,通过传递相应的类,也可以使用具有独特行为的其他变体。

Mamba

Mamba 是一种 Transformers 的替代方案,它采用内存高效、线性时间解码。

要将 Mamba 与 outlines 一起使用,您必须首先安装必要的依赖项

pip install causal-conv1d>=1.2.0 mamba-ssm torch transformers

然后您可以通过以下方式创建一个 Mamba-2 Outlines 模型

import outlines

model = outlines.models.mamba("state-spaces/mamba-2.8b-hf")

或明确使用

import outlines
from transformers import MambaForCausalLM

model = outlines.models.transformers(
    "state-spaces/mamba-2.8b-hf",
    model_class=MambaForCausalLM
)

阅读 transformers 的文档以获取更多信息。

编码器-解码器模型

您可以将像 T5 和 BART 这样的编码器-解码器 (seq2seq) 模型与 Outlines 一起使用。

但在模型选择时请谨慎,某些模型(如 t5-base)不包含某些字符({),尝试进行结构化生成时可能会出错。

T5 示例

import outlines
from transformers import AutoModelForSeq2SeqLM

model_pile_t5 = outlines.models.transformers(
    model_name="EleutherAI/pile-t5-large",
    model_class=AutoModelForSeq2SeqLM,
)

BART 示例

model_bart = outlines.models.transformers(
    model_name="facebook/bart-large",
    model_class=AutoModelForSeq2SeqLM,
)