
跟大牛学LLM训练和使用技巧
最近的大语言模型,很多都是基于Huggingface Transformers开发库进行训练和公开的。在Transformers开发库中进行文本生成时,会使用到一个Generate的文本生成函数。这个Generate函数都具体做了什么? 这个函数与模型训练时的forward函数有什么区别,为什么文本的生成不能直接使用 forward 函数来实现呢?
本章内容将为你解释 Huggingface Transformers 中的 generate 文本生成功能。
本章将解答以下几个问题:
通过对上面问题的理解和学习,你将会进一步了解Transformers的代码架构,从而可以尝试对核心代码进行个性化修改。
事前说明:
以下内容为个人学习的总结,所有内容仅供参考使用。
这是一个在高速发展和进化的领域,由于个人能力和见识所限,因此无法保证介绍的内容能在其他环境和配置下正确运行。
一般情况下,AI模型构建的主要包含训练部分(Training)和推理部分(Inference,生成,解码,识别,判定)。
模型训练部分通过设计网络结构,训练方法,以及调整训练参数来得到高精度的模型参数。推理部分根据不同的任务,也常被称为预测部分,或者生成部分,其主要工作就是使用训练好的模型参数,对测试数据或者部署中的实际数据进行识别或者文本生成等处理。所以,这两部分都是模型的很重要的部分。
Transformer的模型训练时,有一个forward函数,是在模型训练时针对模型的输入来产生输出,从让来计算loss,来更新网络的参数。既然有这么一个生成的函数了,为什么transformers中还有专门的generate函数来负责生成呢?
这里面主要有两个原因:
以上两个原因导致了在 LLM 具体实现中,会经常可以见到非常复杂的 generate 功能用来生成文本。
除了LLM的文本生成,基于LM架构的语音识别模型,也采用类似处理方式。其中OpenAI的Whisper模型就是采用了transformer的架构,因此其本质上更像是一个语言模型。(注意:whisper与常见的CTC-based的语音识别模型是不同的)
比如,whisper 模型进行语音识别时,Encoder部分用于提取声学特征,Decoder部分用于文本的生成(语言学特征)。
处理过程类似于:
与上面提到的一样,训练过程中,直接使用了文本的 label 数据来作为Decoder的输入,在forward函数中并不实际存在一个递归的预测过程。 但是在实际使用模型进行文本产生或者语音识别时,是必需要进行auto-regressive的操作来逐步产生文本。这个产生文本的过程,就被封装在generate函数中。
下面,我们来通过 transformers 中的代码,来看看generate是如何实现的。
在Transformers中,具体模型的代码可以在 transformers/models 目录下找到,比如 OpenAI 的 whisper 模型的主要实现在 transformers/models/whisper/modeling_whisper.py文件中。
Transformers的代码可以在GitHub上找到:https://github.com/huggingface/transformers
其代码的结构如下:
docs
scripts
utils
examples : 使用方法,参考案例
...
src/transformers (Transformer相关的代码)
data: 数据处理
models : 模型的实现代码,比如BERT, GPT,Whisper模型,都在此目录下实现
generation : 文本生成相关代码
...
从上面的代码中,examples中提供了模型的使用方法的参考例子。
我们的今天介绍的主要内容都在 src/transformers 目录下,其中 models 目录下,是基于transformer的各种模型的实现代码,Generation 包含通用的文本产生的实现代码。
我们以Whisper 模型为例来详细介绍一下代码的结构和调用关系。下面我们以v4.29.1的版本为例进行介绍。
首先,whisper模型的代码位于:src/transformers/models/whisper 目录下。其主要功能都封装在 modeling_whisper.py 文件中。
调用入口:WhisperForConditionalGeneration类
此python文件中包含多个类,继承的关系比较复杂,它们之间的主要调用关系如下(以greedy search为例):
WhisperForConditionalGeneration (L1312) : 调用入口类
forward() (L1359) --> 细节在:WhisperModel,WhisperEncoder,WhisperDecoder 类中实现
generate() (L1455) --> 细节在: generation/utils.py#L1146 中实现
greedy_search(): L2164 --> 调用 search 函数来做实际的处理,比如自回归处理
Forward函数:
forward函数位于 class transformers.WhisperModel 类中,代码位置请参考:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L1215
def forward():
# Encoder将输入的语音信号,编码为声学信息,也就是 encoder_outputs
encoder_outputs = self.encoder(
input_features,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Decoder 的主要输入为 decoder_input_ids (对应文本) 和 encoder_outputs (对应声学信息,在翻译任务中,对应着源语言)
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_outputs[0],
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
通过上面,我们可以看到有 Encoder 部分和 Decoder 部分,分别对应声学特征的提取和文本产生部分。
在 WhisperForConditionalGeneration 类中,也有一个forward函数,是对上面forward函数的封装。
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L1359
其中self.encoder的实现代码位于 class WhisperEncoder(WhisperPreTrainedModel) 类中:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L735
其中self.decoder的实现代码位于 class WhisperDecoder(WhisperPreTrainedModel) 类中:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L881
Generate函数
Generate函数入口:位于 class WhisperForConditionalGeneration(WhisperPreTrainedModel) 类中:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L1455
此处只是调用的入口,具体的实现代码位于 class GenerationMixin 类中:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/generation/utils.py#L1146
def generate() L1146
其中generate函数使用的greedy_search的实现位于:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/generation/utils.py#L2164
下面,我们来进一步了解 generate 的实现代码,来看看如何对此代码进行修改。
入口代码
Generate函数的入口位于: WhisperForConditionalGeneration类中的 def generate 函数
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L1455
代码的概要如下,从代码中可以看到,这个函数主要是进行了一些参数设置,具体的实现是调用了父类中的对应函数来执行的。
def generate()
# 参数设置部分
# 调用部分(此处调用了父类中的generate实现)
return super().generate(
inputs,
generation_config,
logits_processor,
stopping_criteria,
prefix_allowed_tokens_fn,
synced_gpus,
**kwargs,
)
然后,我们可以逐级向上搜索其父类,可以看到
到此为止,我们就可以看到,具体的实现都在 GenerationMixin 类中。
Generate函数实现细节
下面,我们来看一下 GenerationMixin类中的 generate 函数的实现细节。
代码位置:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/generation/utils.py#L1146
其代码概要如下:
def generate(): L1146
# 根据解码方式的不同,此函数中有最多14步的处理步骤,我们以greedy search为例
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
# 2. Set generation parameters if not already defined
# 3. Define model inputs
# 4. Define other model kwargs
# 5. Prepare `input_ids` which will be used for auto-regressive generation
# 6. Prepare `max_length` depending on other stopping criteria.
# 7. determine generation mode
# 8. prepare distribution pre_processing samplers
# 9. prepare stopping criteria
# 10. go into different generation modes
# 11. run greedy search (L1515)
def greedy_search(): L2164
# 初始化,设置
# 循环处理
while True: # L2317
# prepare model inputs (下面函数的具体实现位于: modeling_whisper.py#L1627)
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# forward pass to get next token
# 这里是调用了 WhisperForConditionalGeneration 中的forward函数。这是因为 PyTorch 的 nn.Module 基类定义了一个 __call__ 方法,当你调用模型实例(即 self)时,它会自动调用这个 __call__ 方法,而这个 __call__ 方法又会调用 forward 方法。
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
# 得到下一个token的logits
next_token_logits = outputs.logits[:, -1, :]
# pre-process distribution 得到其score
next_tokens_scores = logits_processor(input_ids, next_token_logits)
# argmax :使用argmax 获取对应的 tokens
next_tokens = torch.argmax(next_tokens_scores, dim=-1)
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
# 判断是否结束search:
# if eos_token was found in one sentence, set sentence to finished
# stop if we exceed the maximum length
通过上面的代码概要,我们就可以知道generate函数进行了很多的设置以后,会调用 greedy_search() 函数来进行文本产生的实际处理。
到此为止,我们就已经对整个的代码结构了解了。
下面我们通过几个问题,来回顾一下对代码的理解。
针对上面的问题7,如果要对generate或者其他部分进行修改,建议在 models/whisper的目录下对父类函数进行重构。
比如,如果要对greedy_search功能进行调整来实现一些独特的功能时,可以在modeling_whisper.py中重构 greedy_search(),具体做法可以是:
函数在子类中被重新实现之后,调用时,将会优先调用新重构的函数。这样既实现了自己独特的功能,还不影响其他的模型的运行。
文章转载自: Transformers Generate 功能介绍