所有文章 > 学习各类API > Transformers Generate 功能介绍
Transformers Generate 功能介绍

Transformers Generate 功能介绍

介绍

最近的大语言模型,很多都是基于Huggingface Transformers开发库进行训练和公开的。在Transformers开发库中进行文本生成时,会使用到一个Generate的文本生成函数。这个Generate函数都具体做了什么? 这个函数与模型训练时的forward函数有什么区别,为什么文本的生成不能直接使用 forward 函数来实现呢?
本章内容将为你解释 Huggingface Transformers 中的 generate 文本生成功能。

本章将解答以下几个问题:

  • 为什么需要generate功能?
  • Generate的实现代码在什么位置?
  • Generate函数的输入的是什么?
  • 如何处理 Decoder-only(比如GPT系列)和 encoder-decoder (完整Transformer)网络的generate?
  • Auto-regressive 的迭代循环是如何处理的?
  • 函数返回的是什么样的数据? 如何修改返回数据?

通过对上面问题的理解和学习,你将会进一步了解Transformers的代码架构,从而可以尝试对核心代码进行个性化修改。

事前说明:
以下内容为个人学习的总结,所有内容仅供参考使用。
这是一个在高速发展和进化的领域,由于个人能力和见识所限,因此无法保证介绍的内容能在其他环境和配置下正确运行。

为什么需要 Generate?

一般情况下,AI模型构建的主要包含训练部分(Training)和推理部分(Inference,生成,解码,识别,判定)。
模型训练部分通过设计网络结构,训练方法,以及调整训练参数来得到高精度的模型参数。推理部分根据不同的任务,也常被称为预测部分,或者生成部分,其主要工作就是使用训练好的模型参数,对测试数据或者部署中的实际数据进行识别或者文本生成等处理。所以,这两部分都是模型的很重要的部分。

Transformer的模型训练时,有一个forward函数,是在模型训练时针对模型的输入来产生输出,从让来计算loss,来更新网络的参数。既然有这么一个生成的函数了,为什么transformers中还有专门的generate函数来负责生成呢?

这里面主要有两个原因:

  1. 模型训练与生成的差异: 一般情况下,在分类任务上,Forward函数与decoding或者prediction时的处理是一致的。但是在 LM 的训练过程中,训练时的forward函数中通常并不采用真正的递归的方式进行逐步处理。
    在理论上,LM模型通过输入的一组 token,来预测下一个token。 然后再将新的预测出的 token 与前面的 token 序列进行链接,用来继续预测下一个 token。 在实际训练中,上面的自回归的训练方式,训练的效率非常的低,无法更好的利用GPU的并行处理能力。 因此实际训练中,作为模型输入的 token 并不通过模型的预测来获取,而是直接使用训练数据中的文本。也可以理解为,前面的预测是完全正确的。
  2. LM Decoding方法的复杂性
    相比分类任务,LM这种自递归的生成任务的解码通常都比较复杂。比如在 LM 或者ASR任务中,解码方法有很多,比如 greedy_search(), contrastive_search(), sample(), beam_search(), beam_sample(), group_beam_search(), and constrained_beam_search() 等等。通过使用更加复杂的解码的方法,通常可以得到更好的效果。比如 beam_search的结果比greedy_search的结果通常会更好。但是在训练过程中,采用方法更加类似于greedy search的方法。因此,设计更好的解码算法也是一个很重要的研究方向。

以上两个原因导致了在 LLM 具体实现中,会经常可以见到非常复杂的 generate 功能用来生成文本。

Transformer 模型

除了LLM的文本生成,基于LM架构的语音识别模型,也采用类似处理方式。其中OpenAI的Whisper模型就是采用了transformer的架构,因此其本质上更像是一个语言模型。(注意:whisper与常见的CTC-based的语音识别模型是不同的)

比如,whisper 模型进行语音识别时,Encoder部分用于提取声学特征,Decoder部分用于文本的生成(语言学特征)。

处理过程类似于:

  • 声学特征: 声学特征的提取是输入一个30秒的语音片段,Encoder部分会将此输入编码为一个定长的向量。
  • 文本生成: 将声学特征和前面生成的文本输入到 Decoder 网络。 Decoder网络就和传统的 LM 一样,努力的预测下一个token的输出。
  • Loss计算: 通过计算预测的token与正确token之间的差异来计算参数更新所需要的loss。

与上面提到的一样,训练过程中,直接使用了文本的 label 数据来作为Decoder的输入,在forward函数中并不实际存在一个递归的预测过程。 但是在实际使用模型进行文本产生或者语音识别时,是必需要进行auto-regressive的操作来逐步产生文本。这个产生文本的过程,就被封装在generate函数中。

实战:Transformers 代码分析

下面,我们来通过 transformers 中的代码,来看看generate是如何实现的。
在Transformers中,具体模型的代码可以在 transformers/models 目录下找到,比如 OpenAI 的 whisper 模型的主要实现在 transformers/models/whisper/modeling_whisper.py文件中。

代码结构

Transformers的代码可以在GitHub上找到:https://github.com/huggingface/transformers
其代码的结构如下:

docs
scripts
utils
examples : 使用方法,参考案例
...
src/transformers (Transformer相关的代码)
data: 数据处理
models : 模型的实现代码,比如BERT, GPT,Whisper模型,都在此目录下实现
generation : 文本生成相关代码

...

从上面的代码中,examples中提供了模型的使用方法的参考例子。
我们的今天介绍的主要内容都在 src/transformers 目录下,其中 models 目录下,是基于transformer的各种模型的实现代码,Generation 包含通用的文本产生的实现代码。

模型 models/whisper

我们以Whisper 模型为例来详细介绍一下代码的结构和调用关系。下面我们以v4.29.1的版本为例进行介绍。
首先,whisper模型的代码位于:src/transformers/models/whisper 目录下。其主要功能都封装在 modeling_whisper.py 文件中。

调用入口:WhisperForConditionalGeneration类
此python文件中包含多个类,继承的关系比较复杂,它们之间的主要调用关系如下(以greedy search为例):

WhisperForConditionalGeneration (L1312) : 调用入口类
forward() (L1359) --> 细节在:WhisperModel,WhisperEncoder,WhisperDecoder 类中实现
generate() (L1455) --> 细节在: generation/utils.py#L1146 中实现
greedy_search(): L2164 --> 调用 search 函数来做实际的处理,比如自回归处理

Forward函数:
forward函数位于 class transformers.WhisperModel 类中,代码位置请参考:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L1215

def forward():
# Encoder将输入的语音信号,编码为声学信息,也就是 encoder_outputs
encoder_outputs = self.encoder(
input_features,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Decoder 的主要输入为 decoder_input_ids (对应文本) 和 encoder_outputs (对应声学信息,在翻译任务中,对应着源语言)
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_outputs[0],
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

通过上面,我们可以看到有 Encoder 部分和 Decoder 部分,分别对应声学特征的提取和文本产生部分。

在 WhisperForConditionalGeneration 类中,也有一个forward函数,是对上面forward函数的封装。
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L1359

其中self.encoder的实现代码位于 class WhisperEncoder(WhisperPreTrainedModel) 类中:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L735

其中self.decoder的实现代码位于 class WhisperDecoder(WhisperPreTrainedModel) 类中:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L881

Generate函数
Generate函数入口:位于 class WhisperForConditionalGeneration(WhisperPreTrainedModel) 类中:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L1455

此处只是调用的入口,具体的实现代码位于 class GenerationMixin 类中:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/generation/utils.py#L1146
def generate() L1146

其中generate函数使用的greedy_search的实现位于:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/generation/utils.py#L2164

Generate 代码分析

下面,我们来进一步了解 generate 的实现代码,来看看如何对此代码进行修改。

入口代码
Generate函数的入口位于: WhisperForConditionalGeneration类中的 def generate 函数
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L1455

代码的概要如下,从代码中可以看到,这个函数主要是进行了一些参数设置,具体的实现是调用了父类中的对应函数来执行的。

def generate()
# 参数设置部分

# 调用部分(此处调用了父类中的generate实现)
return super().generate(
inputs,
generation_config,
logits_processor,
stopping_criteria,
prefix_allowed_tokens_fn,
synced_gpus,
**kwargs,
)

然后,我们可以逐级向上搜索其父类,可以看到

到此为止,我们就可以看到,具体的实现都在 GenerationMixin 类中。

Generate函数实现细节

下面,我们来看一下 GenerationMixin类中的 generate 函数的实现细节。
代码位置:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/generation/utils.py#L1146

其代码概要如下:

def generate(): L1146
# 根据解码方式的不同,此函数中有最多14步的处理步骤,我们以greedy search为例
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
# 2. Set generation parameters if not already defined
# 3. Define model inputs
# 4. Define other model kwargs
# 5. Prepare `input_ids` which will be used for auto-regressive generation
# 6. Prepare `max_length` depending on other stopping criteria.
# 7. determine generation mode
# 8. prepare distribution pre_processing samplers
# 9. prepare stopping criteria
# 10. go into different generation modes
# 11. run greedy search (L1515)

def greedy_search(): L2164
# 初始化,设置

# 循环处理
while True: # L2317
# prepare model inputs (下面函数的具体实现位于: modeling_whisper.py#L1627)
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

# forward pass to get next token
# 这里是调用了 WhisperForConditionalGeneration 中的forward函数。这是因为 PyTorch 的 nn.Module 基类定义了一个 __call__ 方法,当你调用模型实例(即 self)时,它会自动调用这个 __call__ 方法,而这个 __call__ 方法又会调用 forward 方法。
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)

# 得到下一个token的logits
next_token_logits = outputs.logits[:, -1, :]

# pre-process distribution 得到其score
next_tokens_scores = logits_processor(input_ids, next_token_logits)

# argmax :使用argmax 获取对应的 tokens
next_tokens = torch.argmax(next_tokens_scores, dim=-1)

# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

# 判断是否结束search:
# if eos_token was found in one sentence, set sentence to finished
# stop if we exceed the maximum length

通过上面的代码概要,我们就可以知道generate函数进行了很多的设置以后,会调用 greedy_search() 函数来进行文本产生的实际处理。
到此为止,我们就已经对整个的代码结构了解了。

下面我们通过几个问题,来回顾一下对代码的理解。

  • 问题1: Generate函数的入口在什么位置?
  • 问题2: Generate函数的具体实现在什么位置?
  • 问题3: 在处理 Decoder-only 和 encoder-decoder 网络时有什么差异 ?
    Generate同时可以支持Deocder-only和encoder-decoder(完整的transformer)。
    Transformers可以理解为一个LM,Encoder编码后的向量只是Decoder的一个输入。
  • 问题4:Generate 函数的输入的是什么?
    最主要的输入为 inputs,在作为LM时,可以为 prompt tokens,若为 None,则初始化为 bos_token_id。
    如果是 decoder-only,则input为input_ids的格式, 如果为encoder-decoder,则支持input_ids和feature等多种格式。
  • 问题5:Auto-regressive 的迭代循环是如何处理的? 在什么位置?
    迭代循环可以在 greedy_search 函数中找到。
  • 问题6:函数返回的数据格式是什么? 如何添加自己额外的数据?
    返回的格式,根据网络结构不同也有不同的格式,可以支持 GreedySearchEncoderDecoderOutput 和 GreedySearchDecoderOnlyOutput。
    也可以输出任意的数据,比如 L2421 只返回了 input_ids。
    如果要添加额外的输出,可以修改上面的类 GreedySearchEncoderDecoderOutput 和 GreedySearchDecoderOnlyOutput,也可以在 L2421 添加输出的内容。
  • 问题7: 如何对Generate进行修改?在那个脚本中修改?
    在src/transformers/generation/utils.py中修改,还是在src/transformers/models/whisper/modeling_whisper.py中修改呢?

代码修改建议

针对上面的问题7,如果要对generate或者其他部分进行修改,建议在 models/whisper的目录下对父类函数进行重构。
比如,如果要对greedy_search功能进行调整来实现一些独特的功能时,可以在modeling_whisper.py中重构 greedy_search(),具体做法可以是:

  1. 将 utils.py 中的 greedy_search 函数拷贝到 modeling_whisper.py 文件中。
  2. 需要import 一些必要的库文件。(具体的库,可以根据运行时的错误提示确定)
  3. 在greedy_search函数中进行修改,来实现想要的功能。

函数在子类中被重新实现之后,调用时,将会优先调用新重构的函数。这样既实现了自己独特的功能,还不影响其他的模型的运行。

参考文献

  1. 【基本概念】https://huggingface.co/blog/how-to-generate
  2. https://huggingface.co/docs/transformers/main_classes/text_generation
  3. https://huggingface.co/docs/transformers/internal/generation_utils

文章转载自: Transformers Generate 功能介绍

#你可能也喜欢这些API文章!