0%

ChatGPT使用的Transfomer模型

作为一个一直对AI技术很感兴趣的软件开发工程师,早在深度学习开始火起来的15、16年,我也开始了相关技术的学习。当时还组织了公司内部同样有兴趣的同学一起研究,最终的成果汇集成几次社区中的分享以及几篇学习文章(见这里)。

从去年OpenAI发布ChatGPT以来,AI的能力再次惊艳了世人。在这样的一个时间节点,重新去学习相关技术显得很有必要。

ChatGPT的内容很多,我计划采用一个系列,多篇文章来分享学习我自己学习过程中的一些理解。本系列文章,我将站在一个普通开发人员的角度展开,希望对想了解ChatGPT技术原理的普通开发者们有帮助。

ChatGPT本身就具备很丰富的知识,所以ChatGPT自身实际上就是一个很好的学习渠道,我也将借助ChatGPT来学习ChatGPT。

这是此系列的第三篇,ChatGPT使用的Transfomer模型。

上一篇文章我们聊到了ChatGPT使用的技术概览。了解了其最核心的模型结构是Transformer结构,本文来聊一聊Transformer模型。

介绍

Transformer的网络结构最早是Google在2017年的时候提出的,论文名称是《Attention Is All You Need》。从论文名称也能看出,Transformer结构强调了注意力机制在网络结构中的表示和应用。

当时这篇论文面世时,不少研究人员还认为标题有点夸大了注意力机制的作用。现在来看,似乎还真有注意力机制一统天下的势头。

下面我们将一起来揭开这个网络结构的面纱。

原始的Transfomer模型

原始的Transformer的整体结构比较复杂,以下是来自论文中的截图。

Model Architecture

可以看到,Transformer网络的主要由编码器和解码器组成。虽然看起来复杂,但实际上,编码器和解码器都是由多个相同的层堆叠而成,并且编码器和解码器结构也很相似。

编码器(Encoder)每一层内结构为:

  • 输入嵌入(Input Embedding):将输入序列中的每个单词或符号转换为连续的向量表示。
  • 位置编码(Positional Encoding):为输入序列中的每个位置添加一个表示位置信息的向量。
  • 多头自注意力(Multi-Head Self-Attention):通过对输入序列中的每个位置进行自注意力计算,从全局上理解输入序列间的关系和重要性。
  • 前馈神经网络(Feed-Forward Neural Network):在每个位置上应用一个全连接前馈神经网络,以对自注意力输出进行进一步的非线性变换。
  • 残差连接(Residual Connections)和层归一化(Layer Normalization):在每个子层之间应用残差连接和层归一化,以帮助梯度流动和减少训练中的梯度消失问题。

解码器(Decoder)每一层内结构为:

  • 编码器-解码器注意力(Encoder-Decoder Attention):除了自注意力,解码器还对编码器的输出进行注意力计算,以利用编码器对输入序列的理解。
  • 解码器自注意力(Decoder Self-Attention):类似于编码器的自注意力,但在解码器中应用于当前位置以前的输出。
  • 前馈神经网络:与编码器中的前馈神经网络相同。
  • 残差连接和层归一化:与编码器中的残差连接和层归一化相同。

通过堆叠多个编码器和解码器层,Transformer可以具备强大的能力。注意力机制还允许Transformer网络模型自动学习输入序列中的各个单词的依赖关系,并且可以通过并行计算来加速计算过程。

这里有一篇博客详细的介绍了每一个结构内部的实现机制。推荐大家阅读以了解细节。

如果希望阅读完整的代码,Transformer的完整代码在Google的TensorFlow框架和Meta的PyTorch框架中均有实现。TensorFlow的代码入库在这里,不过其代码风格偏函数式风格,并不是很容易理解。PyTorch中的代码相对更容易理解,有兴趣阅读代码的可以看这里,只需要阅读其forward函数即可了解到整个网络的结构。

ChatGPT中的Transfomer模型

ChatGPT中的Transformer模型与原始的Transformer模型有一些差异。主要区别是将Transformer中的Encoder-Decoder双模块设计简化为只有一个Decoder模块。其实也可以认为是只有一个Encoder模块,因为Encoder和Decoder模块本来就很相似。这里之所为大家认为是Decoder,是因为Transformer和ChatGPT的Decoder是自循环的,因为Decoder会根据前一部分的文本生成下一个单词。

在这个单模块中,Self-Attention被替换为了Masked Self-Attention。

Masked Self-Attention在计算时,会将当前输入文本中不存在的部分给遮蔽掉,只对已知的文本信息进行计算。遮蔽其实只是在训练阶段有效,因为训练阶段的输入文本是已知的所有文本。遮蔽掉当前单词的后续单词就可以让模型在无法获取后面单词的信息,使得这一场景与预测阶段的一致。

作为一个程序员,如果不能从代码的粒度去理解,始终会觉得理解不够透彻。下面我们结合代码来详细了解一下Transformer的计算过程。

在这里,我将用来做LLAMA模型的代码实现作为参考,与大家一起结合代码进行分析。LLAMA模型是Meta的研究团队开发的一个与ChatGPT类似的模型,其核心模型结构与ChatGPT的模型是一致的。

LLAMA的实现代码非常短,很适合拿来作为学习材料。完整的代码在这里。下面将结合代码与Transformer的原理进行分析。

文本生成逻辑

LLAMA生成文本的代码入口在这里,下面说明一下代码中关键的行为:

(为了说明代码主要的功能,以下代码仅截取了关键的代码行,并进行了注释,以便大家更容易阅读。)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# 整个程序的入口函数
def main(...):
# 调用下面的load函数,创建一个LLaMA对象,用于生成文本
generator = load(...)
# 调用LLaMA对象,根据传入的文本,以及最大生成长度、温度、单词概率选择
results = generator.generate(prompts, max_gen_len=256, temperature=temperature, top_p=top_p)
# 打印结果
for result in results:
print(result)
print("\n==================================\n")

def load(...) -> LLaMA:
# 加载保存的模型参数
checkpoint = torch.load(ckpt_path, map_location="cpu")
model_args: ModelArgs = ModelArgs(...)
# 初始化一个Tokenizer对象,此Tokenizer其实也是一个机器学习模型,用于将文本切分为单词,并将单词编码为整型数值
tokenizer = Tokenizer(model_path=tokenizer_path)
# 初始化核心的Transformer模型
model = Transformer(model_args)
# 构造LLaMA的文本生成器对象,并返回
generator = LLaMA(model, tokenizer)
return generator

class LLaMA:
def generate(...) -> List[str]:
# 将输入的文本编码为数值
prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
# 用上面的文本数值编码创建一个适合模型输入的矩阵(不超过模型能支持的最大长度),长度太短的文本用pad_id填充
tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cuda().long()
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t).long()
# 根据配置的文本生成长度,迭代生成文本,一次生成一个单词
for cur_pos in range(start_pos, total_len):
# 调用模型进行计算
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if temperature > 0:
# 如果有传入温度参数,将输出结果根据温度放大,并选择累计概率达到top-p概率的那些结果
# 关于温度和top-p参数的详细解释见下文
probs = torch.softmax(logits / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
else:
# 如果没有传入,则直接选择概率最大的那个单词
next_token = torch.argmax(logits, dim=-1)
# 将得到的单词加回原来的文本,继续生成下一个单词
tokens[:, cur_pos] = next_token
# 通过单词编码器将生成的数值型文本反编码为可读的文本,并返回
decoded = []
for i, t in enumerate(tokens.tolist()):
decoded.append(self.tokenizer.decode(t))
return decoded

这就是入口代码的主要逻辑。下面我们分析一下涉及到的几个核心子步骤。

词嵌入

词嵌入是将文本编码为数值的过程。LLaMA在进行词嵌入时,选择了sentencepiece库来实现。

SentencePiece 是一个开源的文本处理库,用于处理和生成分词模型。它的主要作用是将文本分割成子词(subwords)或标记(tokens),以便用于各种自然语言处理任务,例如机器翻译、文本分类、命名实体识别等。

SentencePiece 提供了基于不同分割算法的分词方法,包括未经训练的模型和基于训练数据的模型。它支持的分割算法包括 BPE(Byte-Pair Encoding)、Unigram 等。使用 SentencePiece,可以根据具体任务和需求创建自定义的分词模型。

通过使用 SentencePiece 库,可以实现以下功能:

  • 文本分词:将文本分割成子词或标记,提供更细粒度的语言处理单元。
  • 词汇表生成:根据训练数据生成词汇表,用于构建词汇表索引或编码器-解码器模型。
  • 子词编码:将文本转换为子词序列,以便在模型中进行处理和表示。
  • 子词解码:将子词序列转换回原始文本,用于生成文本或进行后处理。

SentencePiece 被广泛应用于各种自然语言处理任务和模型,特别是在跨语言和非常规语言处理方面具有很大的灵活性和适应性。它的灵活性使得可以根据不同语言、文本类型和任务的需求,定制化地构建分词模型,从而提高模型性能和效果。

下面是Tokenizer的核心代码分析:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class Tokenizer:
def __init__(self, model_path: str):
# 根据模型文件创建SentencePieceProcessor对象
self.sp_model = SentencePieceProcessor(model_file=model_path)
# 保存词汇表大小及一些关键的ID,如开始、结束符、填充符等
self.n_words: int = self.sp_model.vocab_size()
self.bos_id: int = self.sp_model.bos_id()
self.eos_id: int = self.sp_model.eos_id()
self.pad_id: int = self.sp_model.pad_id()

def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
# 将文本编码为数值序列,并根据参数添加开始、结束符
t = self.sp_model.encode(s)
if bos:
t = [self.bos_id] + t
if eos:
t = t + [self.eos_id]
return t

def decode(self, t: List[int]) -> str:
# 将文本数值序列反编码为可读文本
return self.sp_model.decode(t)

温度参数

生成文本时,有两个重要的参数:温度(temperature)和top-p。它们是如何产生作用的呢?

温度参数(temperature)在生成文本过程中起到控制多样性的作用。较高的温度值会增加生成文本时的随机性,使得模型更加倾向于选择概率较小的标记,从而产生更多样化的输出。

具体来说,温度参数会影响 softmax 操作中的指数运算。在 softmax 函数中,通过将 logits 值进行指数运算并归一化,将其转换为概率分布。温度参数的作用是调整指数运算的敏感度。较高的温度值会使指数运算的结果更加平滑,减小了各个标记之间的概率差异,降低了概率较大的标记相对于概率较小的标记的优势。这样,在生成过程中,模型更有可能选择概率较小的标记,从而产生更多样化的输出。

举个例子,假设有一个具有三个候选标记的生成任务,对应的 logits[1.0, 2.0, 3.0]。当温度参数为较低的值(例如1.0)时,通过 softmax 运算后,对应的概率分布为 [0.09, 0.24, 0.67]。可以看到,概率较大的标记 3 相对于其他标记有明显优势,模型更有可能选择标记 3。而当温度参数为较高的值(例如2.0)时,通过 softmax 运算后,对应的概率分布为 [0.19, 0.31, 0.51]。可以看到,概率差异缩小,标记 3 相对于其他标记的优势减小,模型更容易在标记之间进行随机选择。

因此,通过调整温度参数,可以在生成文本时控制多样性。较高的温度值可以增加生成文本的随机性,产生更多样化的输出;而较低的温度值可以增加生成文本的准确性,更倾向于选择概率较大的标记。根据具体的任务需求和应用场景,可以选择合适的温度值来平衡准确性和多样性之间的权衡。

top-p参数

top-p参数用于控制生成文本时的文本选择范围。

实现时,首先,计算 softmax 操作后的概率分布。然后,按照概率从高到低的顺序对概率进行排序。接下来,按照累积概率的方式逐个考虑排名靠前的标记,直到累积概率超过 top-p 的阈值。此时,只有排名靠前的文本才会被保留在选择范围内,其他排名较低的文本会被舍弃。

换句话说,top-p 参数通过动态地确定生成时所需的标记范围,使得生成的结果更加多样化且避免选择概率极低的标记。这种方式比传统的 top-k 采样更加灵活,因为 top-p 参数不依赖于固定的 k 值,而是根据概率分布动态地确定需要保留的标记数量。

下面函数是相关的实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def sample_top_p(probs, p):
# 对概率进行排序并记录排序后的索引
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
# 计算累积概率
probs_sum = torch.cumsum(probs_sort, dim=-1)
# 创建一个布尔掩码,用于确定哪些概率需要保留
mask = probs_sum - probs_sort > p
# 将超过 top-p 阈值的概率置为 0
probs_sort[mask] = 0.0
# 将概率归一化,使其和为 1
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
# 从归一化的概率分布中进行多项式分布采样,得到下一个标记
next_token = torch.multinomial(probs_sort, num_samples=1)
# 使用排序后的索引获取对应的下一个标记
next_token = torch.gather(probs_idx, -1, next_token)
return next_token

函数接受两个参数:probs 是经过 softmax 操作得到的概率分布,ptop-p 参数,用于确定保留的概率范围。

根据top-p进行结果选择的逻辑如下:

  • 对概率 probs 进行排序,并记录排序后的索引,使得概率从高到低排列。
  • 计算概率的累积和。
  • 创建一个布尔掩码,用于确定哪些概率需要保留。如果累积概率超过了 top-p 阈值,则对应的概率置为 0。
  • 将概率归一化,使其和为 1,以便进行多项式分布采样。
  • 使用多项式采样方法从归一化的概率分布中选取一个下一个标记。
  • 使用排序后的索引 probs_idx 获取对应的下一个标记。
  • 返回选取的下一个标记 next_token
  • 这段代码实现了根据 top-p 参数选择结果的逻辑,确保生成的结果在给定的概率范围内,并增加生成文本的多样性。

模型结构

在生成文本的主流程中,构造了Transformer模型进行下一个单词的预测,下面分析一下Transformer模型的结构。

下面的代码需要有一些PyTorch构建神经网络模型的基础知识。对于不了解相关知识的同学,以下是一些要点:

  • PyTorch抽象了一个Module类用于构建基本的模型构造块
  • 在构建模型构造块时,需要继承Module类并实现其forward方法将输入变换为输出
  • 在构建模型构造块时,需要在类的初始化方法__init__中初始化用到的子构造块
  • 在构建模型构造块时,一般不需要关注参数更新的部分,PyTorch提供了自动计算梯度(参数的偏导数)的机制
  • PyTorch提供了很多内置的模块或函数,如full triu matmul silu等,帮助我们更快的复用标准构造块

Transformer相关的完整代码在这里,下面分析一下关键的实现。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class Transformer(nn.Module):
def __init__(self, params: ModelArgs):
...
# 构造文本嵌入,用于将文本转化为向量表示
self.tok_embeddings = ParallelEmbedding(params.vocab_size, params.dim, ...)
# 根据模型参数指定的Transformer层数,创建核心网络结构
self.layers = torch.nn.ModuleList()
for layer_id in range(params.n_layers):
self.layers.append(TransformerBlock(layer_id, params))
# RMSNorm归一化算子,见下文解释
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
# 对输出的文本进行最终的线性变换的算子
self.output = ColumnParallelLinear(params.dim, params.vocab_size, bias=False, ...)
# 预计算频率的复数表示,见下文旋转嵌入部分的分析
self.freqs_cis = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len * 2)

def forward(self, tokens: torch.Tensor, start_pos: int):
# 对传入的文本取嵌入向量
h = self.tok_embeddings(tokens)
# 预计算旋转嵌入的旋转频率。可以减少在每个前向传播步骤中的重复计算,提高模型的运行效率。
# 这里用到了一些复数计算技巧,详见下文注意力机制部分分析
self.freqs_cis = self.freqs_cis.to(h.device)
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
# 在传入的文本长度大于1时,构造一个上三角矩阵作为掩码,用于遮盖未生成的字词部分
mask = None
if seqlen > 1:
mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
# 调用每一个Transfomer分层进行计算
for layer in self.layers:
h = layer(h, start_pos, freqs_cis, mask)
# 归一化最后的结果
h = self.norm(h)
# 取计算出来的最后一个词,并进行最后的线性变换后作为输出返回
output = self.output(h[:, -1, :]) # only compute last logits
return output.float()

上述代码用到的核心结构TransformerBlock代码分析如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class TransformerBlock(nn.Module):
def __init__(self, layer_id: int, args: ModelArgs):
super().__init__()
# 构造注意力部分结构,见下文注意力机制部分
self.attention = Attention(args)
# 构造前馈神经网络部分结构,见下文前馈神经网络部分
self.feed_forward = FeedForward(dim=args.dim, hidden_dim=4 * args.dim, ...)
# 注意力部分用到的RMSNorm归一化算子,见下文解释
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
# 前馈神经网络部分用到的RMSNorm归一化算子,见下文解释
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
# 将输入归一化,并计算注意力,然后加上x以形成残差结构
h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask)
# 将上述结果进行归一化,并计算前馈神经网络部分,然后加上h以形成残差结构
out = h + self.feed_forward.forward(self.ffn_norm(h))
return out

注意力机制

Transformer 中的注意力机制(Attention Mechanism)是核心组成部分之一,它在模型中用于捕捉输入序列中的相关信息,并为每个位置分配权重。

注意力的意思就是让模型关注在重要的地方,权重比较高的位置将得到更多的关注。如何实现?通过在每个位置上计算一个加权和就可以了!

Transformer 中使用的是自注意力机制(Self-Attention),即将输入序列中的每个位置视为查询(query)、键(key)和值(value)。通过计算查询与键的相似度得到权重分布,然后将权重与值进行加权求和得到每个位置的输出。

下面是 Transformer 中自注意力机制的主要步骤:

  • 对输入序列进行线性变换,分别得到查询(Q)、键(K)和值(V)。
  • 计算查询与键的相似度分数,通常使用点积或其他函数(如缩放点积)计算相似度。
  • 对相似度分数进行归一化处理,通过 softmax 函数将分数转换为注意力权重。
  • 将权重与值进行加权求和,得到加权和作为该位置的输出。
  • 将每个位置的输出进行线性变换,得到最终的自注意力输出。

自注意力机制的优势在于它能够捕捉输入序列中的长距离依赖关系,并且能够对不同位置之间的相关性进行灵活的建模。通过自注意力机制,Transformer 可以同时考虑输入序列中所有位置的信息,而无需像循环神经网络那样依次处理序列。

在 Transformer 中,注意力机制通常通过多头注意力(Multi-Head Attention)来进行扩展,即使用多组不同的查询、键和值进行注意力计算,并将它们的输出进行拼接和线性变换,以增加模型的表达能力和学习能力。

总结起来,注意力机制是 Transformer 模型中重要的组成部分,它通过计算查询与键的相似度来为每个位置分配权重,并将权重与值进行加权求和得到输出。它能够捕捉输入序列中的相关信息,提升模型的表达能力和学习能力。

TransformerBlock代码使用到的核心的Attention模块就是注意力机制的实现。这个模块的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
# 构造注意力查询(Q)、键(K)和值(V)所需要的线性变换算子
# 这里直接用一个变换算子支持了多头的场景,因为每个头实际上计算方式是完全一样的,只是参数不同
self.wq = ColumnParallelLinear(args.dim, args.n_heads * self.head_dim, ...)
self.wk = ColumnParallelLinear(args.dim, args.n_heads * self.head_dim, ...)
self.wv = ColumnParallelLinear(args.dim, args.n_heads * self.head_dim, ...)
# 构造对最终输出进行线性变换的算子
self.wo = RowParallelLinear(args.n_heads * self.head_dim, args.dim, ...)

def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
# 对输入序列进行线性变换,分别得到查询(Q)、键(K)和值(V)。
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
# 对查询和键应用旋转嵌入(Rotary Embedding)操作
# 旋转嵌入是一种在注意力机制中引入周期性信息的技术,有助于模型捕捉序列的顺序关系
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

# 更新缓存中的键(K)和值(V),将当前位置的键和值存储在缓存中以供后续的注意力计算使用。
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

# 从缓存中获取用于注意力计算的键(K)和值(V),包括当前位置之前的所有位置。
keys = self.cache_k[:bsz, : start_pos + seqlen]
values = self.cache_v[:bsz, : start_pos + seqlen]

# 对查询、键和值进行维度转置,以便进行矩阵乘法操作。
xq = xq.transpose(1, 2)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)
# 计算查询和键之间的相似度得分,通过矩阵乘法计算得到,同时除以头的维度的平方根来进行缩放,以控制相似度的范围。
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
if mask is not None:
# 如果存在掩码(mask),则将其加到相似度得分上,以屏蔽无效位置的影响。
scores = scores + mask # (bs, n_local_heads, slen, cache_len + slen)

# 对相似度得分进行 softmax 操作,将其转换为注意力权重,使得权重在每个位置的分布总和为 1。
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
# 根据注意力权重对值进行加权求和,得到最终的注意力输出。
output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim)
...
# 对注意力输出进行线性变换,得到最终的注意力机制的输出。
return self.wo(output)

旋转嵌入

旋转嵌入(Rotary Embedding)是一种在注意力机制中引入周期性信息的技术,用于增强模型对序列中的顺序关系的建模能力。它通过将输入的查询(Q)和键(K)进行旋转操作,以捕捉序列中位置之间的相对角度。

在注意力机制中,查询和键是通过点积运算来计算相似度得分的,而点积运算本质上是计算两个向量的内积。通过旋转嵌入,可以将原始的查询和键进行旋转操作,将它们的信息编码成一个复数的表示形式,从而引入角度信息。

旋转嵌入的具体操作如下:

  • 首先,将查询和键的维度分为实部和虚部两部分。
  • 然后,使用三角函数(sin 和 cos)计算旋转角度的正弦和余弦值。
  • 将原始的实部和虚部分别与正弦和余弦值相乘,得到旋转后的实部和虚部。
  • 最后,将旋转后的实部和虚部重新组合成查询和键的表示。

通过旋转嵌入,查询和键之间的点积运算相当于在复数域中进行了旋转操作,这样可以更好地处理序列中的相对位置关系。旋转嵌入的使用可以提升模型对序列中长距离依赖的建模能力,并有助于捕捉序列中的顺序信息。

需要注意的是,旋转嵌入只应用于查询和键,而值(V)保持不变。这是因为在注意力机制中,查询和键的作用是计算相似度得分,而值则用于根据得分对序列进行加权求和。旋转嵌入的引入主要是为了增强相似度计算的准确性,而对值的处理不需要引入旋转操作。

对应的代码为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
# 预计算旋转嵌入的旋转频率。可以减少在每个前向传播步骤中的重复计算,提高模型的运行效率。

# 计算旋转嵌入的频率 freqs。
# 1. 首先,生成一个从 0 到 dim 的整数序列,并取其中的偶数索引。
# 2. 然后,将这些索引转换为浮点数,并将其除以 dim 后取倒数。
# 这样可以生成一个频率递减的序列,用于控制旋转嵌入的旋转速度。
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
# 创建一个长度为 end 的序列 t,其中的值从 0 到 end-1。
t = torch.arange(end, device=freqs.device) # type: ignore
# 使用 torch.outer 函数计算旋转频率的复数形式。
# 将 t 与 freqs 进行外积,得到一个形状为 [end, dim // 2] 的张量,其中每个元素是一个复数,表示旋转频率。
freqs = torch.outer(t, freqs).float() # type: ignore
# 使用 torch.polar 函数将复数频率转换为极坐标形式,得到一个复数张量 freqs_cis。
# 该函数接受一个表示模长的张量(这里是全为1的张量)和一个表示相位的张量(这里是 freqs),并返回复数形式的张量。
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
# freqs_cis为旋转嵌入的旋转频率
# 将输入的查询张量和键张量进行形状变换:
# 1. 首先将其转换为浮点类型
# 2. 然后将最后两个维度重塑为两倍大小的维度,以便处理复数形式的旋转嵌入。
# 结果是一个形状为[batch_size, sequence_length, embedding_dim//2, 2]的复数张量。
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
# 将旋转频率进行形状调整,使其与查询张量的形状相匹配。
# 调整后的形状是根据查询张量形状的最后两个维度进行的,其他维度保持不变。
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
freqs_cis = freqs_cis.view(*shape)
# 将查询张量和键张量与旋转频率进行逐元素相乘。这相当于在复数域中将查询和键进行旋转操作。
# 并将旋转后的张量重新转换为实数形式,通过取实部得到最终的旋转嵌入结果。
# 将每个复数值展平为一个实数值。结果是形状为 [batch_size, sequence_length, embedding_dim] 的张量。
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)

旋转嵌入部分代码略显复杂,并且用到了一些数学计算技巧。如果大家在此理解有困难,也可以忽略它,只需要明白旋转嵌入是为了计算注意力中的查询和键的相似度即可。

如果我们阅读GPT2的代码,可以发现并没有使用旋转嵌入,只是简单的做了矩阵乘法。这是LLAMA引入的一个模型优化方式。

前馈神经网络

整个前馈神经网络的结构为:

  • 将输入进行线性变换并输入激活函数
  • 将输入进行另一个线性变换并与上述结果相乘
  • 将相乘后的结果再次经过线性变换得到最终的输出

对应的代码为TransformerBlock代码使用的FeedForward模块代码,如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class FeedForward(nn.Module):
def __init__(self, dim: int, hidden_dim: int, multiple_of: int):
# 首先根据隐藏层维度的要求进行调整。将隐藏层维度的值设置为输入维度的 2/3,并将其转换为整数。
# 然后,使用 multiple_of 对隐藏层维度进行取整,确保隐藏层维度是 multiple_of 的倍数。
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

# 定义三个线性变换操作符 self.w1、self.w2 和 self.w3,分别用于前馈神经网络的第一层、第二层和第三层。
self.w1 = ColumnParallelLinear(dim, hidden_dim, ...)
self.w2 = RowParallelLinear(hidden_dim, dim, ...)
self.w3 = ColumnParallelLinear(dim, hidden_dim, ...)

def forward(self, x):
# 1. 将输入进行线性变换并输入激活函数
# 2. 将输入进行另一个线性变换并与上述结果相乘
# 3. 将相乘后的结果再次经过线性变换得到最终的输出
return self.w2(F.silu(self.w1(x)) * self.w3(x))

RMSNorm

上述代码中多次用到了RMSNorm归一化,这是什么技术呢?

其实,RMSNorm(Root Mean Square Normalization)是一种归一化技术,用于在神经网络中对输入进行标准化处理。它旨在增强网络的鲁棒性和稳定性,并有助于减轻输入数据中的噪声和变化对模型的影响。

RMSNorm 的核心思想是基于输入的均方根(RMS)进行标准化。它通过计算输入张量沿指定维度的均方根,并将每个元素除以该均方根值来进行归一化。这种归一化方法相比于传统的均值和方差归一化(例如 Batch Normalization)更加简单和直观。

其代码实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
# eps 参数是一个小的常数,用于避免分母为零的情况,确保数值稳定性。
self.eps = eps
# dim 参数表示输入张量的维度,即要在哪个维度上计算均方根并进行归一化。
# weight 是一个可学习的权重参数,用于缩放标准化后的输入。
self.weight = nn.Parameter(torch.ones(dim))

def _norm(self, x):
# 计算输入张量的均方根,并将每个元素除以均方根值。
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

def forward(self, x):
# 调用 _norm 方法对输入张量进行标准化处理,并将标准化后的结果与权重参数相乘,以进一步缩放和调整输出。
output = self._norm(x.float()).type_as(x)
return output * self.weight

掩码

掩码部分也有一些技巧,下面来看看它是如何实现的。

在Transformer的前向计算时,会计算一个掩码矩阵。然后,在计算注意力时,使用此掩码来遮蔽掉无效位置。对应的代码片段如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class Transformer(nn.Module):
def forward(self, tokens: torch.Tensor, start_pos: int):
...
# 在传入的文本长度大于1时,构造一个上三角矩阵作为掩码,用于遮盖未生成的字词部分
mask = None
if seqlen > 1:
mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)

class Attention(nn.Module):
def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
...
if mask is not None:
# 如果存在掩码(mask),则将其加到相似度得分上,以屏蔽无效位置的影响。
scores = scores + mask # (bs, n_local_heads, slen, cache_len + slen)

在生成掩码时,上述代码生成了一个上三角掩码,以屏蔽未来位置的注意力。

在计算注意力分数时,通过将未来位置的分数设置为负无穷,可以使模型在自回归任务中只依赖于当前及之前的信息。这样可以确保模型在生成序列时不会看到未来位置的信息,保持了模型的自回归性质。

生成掩码的方式如下:

  • 首先,创建一个名为 mask 的变量,并将其初始化为 None。这意味着在开始时没有生成掩码。
  • 如果 seqlen 大于 1,表示当前处理的序列长度大于 1,存在需要屏蔽的位置。
  • 创建一个形状为 (1, 1, seqlen, seqlen) 的张量 mask,并将所有元素的值设为负无穷(float("-inf"))。这里使用 float("-inf") 是为了在计算注意力分数时将被掩盖的位置的注意力分数设为负无穷大,从而在softmax操作后将其值近似为0。
  • 使用 torch.triu() 函数将 mask 张量的下三角部分(包括对角线)设为负无穷。这是通过设置 diagonal 参数为 start_pos + 1 来实现的,表示从对角线位置 start_pos + 1 开始屏蔽。这样,注意力机制在计算时将只关注当前位置及之前的位置,而忽略之后的位置。
  • 最后,将 mask 张量的数据类型转换为输入张量 h 的数据类型,并将其赋值给 mask 变量。

在代码中,scores 与 mask 相加,实际上是将 mask 中的非负数值添加到 scores 对应位置的元素上。通过这样的操作,可以将特定位置的注意力分数调整为一个较小的值,从而有效地屏蔽或降低模型对该位置的关注度。

总结

到这里,我们就分析完了整个LLAMA模型的代码。需要注意的是,这里的代码只是LLAMA模型在生成文本时(即预测时)要执行的代码。LLAMA在训练阶段会有更多的技巧,也会涉及更多的代码。可惜Meta并没有公布相关的训练代码。

在分析代码时,我们有意忽略了模型并行处理的部分代码,这些是一些并行优化的机制,对于我们理解模型帮助不大。但如果我们希望将这个模型创建为一个服务,从而为大规模的用户服务时,并行处理部分就比较关键了。

在代码分析过程中,我借助了ChatGPT辅助进行理解,并引用了部分ChatGPT生成的内容,当然,也修正了ChatGPT回复中的一些明显的错误。在这个过程中,ChatGPT可以帮助提供足够多的详细的信息,我也深刻的体会到ChatGPT对于代码和我提出的问题的准确理解。可以说,ChatGPT很大程度上帮助我提升了代码分析的效率和学习的效率。

自ChatGPT发布以来,很多人认为这是一个人类走向通用人工智能的突破,也有一些人认为它其实没什么本质的改进。有很多人对自己的职业发展产生了很深的焦虑感,也有很多人感觉触碰到了科幻世界中的未来,还有很多人觉得又是一个可以好好捞一把的机会。

也许每个人都有必要去了解一下机器学习技术的原理,这样才能形成对它的理性的认知。

ChatGPT的内容很多,我计划采用一个系列,多篇文章来分享学习我自己学习过程中的一些理解。本系列文章,我将站在一个普通开发人员的角度展开,希望对想了解ChatGPT技术原理的普通开发者们有帮助。

这是此系列的第三篇,ChatGPT使用的Transfomer模型。

参考

欢迎关注我的其它发布渠道