1. 首页 > 头条 > 行业热门

GQA LLM 一文详解MHA MQA原理

前言

本文回顾一下MHA、GQA、MQA,详细解读下MHA、GQA、MQA这三种常见注意力机制的原理。

图1 MHA、GQA、MQA一览

self-attention

self-attention

在自注意力机制中,输入通常是一个统一的输入矩阵,而这个矩阵后续会通过乘以不同的权重矩阵来转换成三个不同的向量集合:查询向量Q、键向量K和值向量V。这三组向量是通过线性变换方式生成:

1.查询向量 (Q): Q=XW

2.键向量 (K): K=XW

3.值向量 (V): V=XW

W,W和W是 可学习的权重矩阵 ,分别对应于查询、键和值。这些矩阵的维度取决于模型的设计,通常它们的输出维度(列数) 是预先定义的,以满足特定的模型架构要求。 在Transformer模型中,使用不同的权重矩阵W,W和W来分别生成查询向量Q、键向量K和值向量V的 目的是为了允许模型在不同的表示空间中学习和抽取特征 。这样做增加了模型的灵活性和表达能力,允许模型分别优化用于匹配(Q 和K)和用于输出信息合成(V)的表示。

在自注意力和多头注意力机制中,使用 作为缩放因子进行缩放操作是为了防止在计算点积时由于维度较高导致的数值稳定性问题。这里的d是键向量的维度。 如果不进行缩放,当d较大时,点积的结果可能会变得非常大,这会导致在应用softmax函数时产生的梯度非常小。 因为softmax函数是通过指数函数计算的,大的输入值会使得部分输出接近于1,而其他接近于0,从而导致梯度消失,这会在反向传播过程中造成梯度非常小,使得学习变得非常缓慢。

通过点积结果除以 ,可以调整这些值的范围,使得它们不会太大。这样,softmax的输入在一个合适的范围内, 有助于避免极端的指数运算结果,从而保持数值稳定性和更有效的梯度流 。这个操作确保了即使在d很大的情况下, 注意力机制也能稳定并有效地学习。

代码实现

import torchimport torch.nn as nnimport torch.nn.functional as Fclass SelfAttention(nn.Module):def __init__(self, seq_length):super(SelfAttention, self).__init__()self.input_size = seq_length# 定义三个权重矩阵:Wq、Wk、Wvself.Wq = nn.Linear(seq_length, seq_length)# 线性变换self.Wk = nn.Linear(seq_length, seq_length)self.Wv = nn.Linear(seq_length, seq_length)def forward(self, input):# 计算Q,K,V 三个矩阵q = self.Wq(input)k = self.Wk(input)v = self.Wv(input)# 计算QK^T,即向量之间的相关度attention_scores = torch.matmul(q, k.transpose(-1, -2)) / torch.sqrt(torch.tensor(float(self.input_size)))# 计算向量权重,softmax归一化attention_weight = F.softmax(attention_scores, dim=-1)# 计算输出output = torch.matmul(attention_weight, v)return outputx = torch.randn(2, 3, 4)Self_Attention = SelfAttention(4)# 传入输入向量的维度output = Self_Attention(x)print(output.shape)

MHA(多头注意力)

Transformer 编码器块内的缩放点积注意力机制和多头注意力机制

MHA计算过程

代码实现

import torchimport torch.nn as nnclass MultiHeadAttention(nn.Module):def __init__(self, embed_dim, num_heads):super(MultiHeadAttention, self).__init__()self.num_heads = num_headsself.head_dim = embed_dim // num_headsself.wq = nn.Linear(embed_dim, embed_dim)self.wk = nn.Linear(embed_dim, embed_dim)self.wv = nn.Linear(embed_dim, embed_dim)self.wo = nn.Linear(embed_dim, embed_dim)def mh_split(self, hidden):batch_size = hidden.shape[0]x = hidden.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)return xdef forward(self, hidden_states, mask=None):batch_size = hidden_states.size(0)# 线性变换q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)# 多头切分q, k, v = self.mh_split(q), self.mh_split(k), self.mh_split(v)# 注意力计算scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))if mask is not None:scores = scores.masked_fill(mask == 0, float('-inf'))attention = torch.softmax(scores, dim=-1)output = torch.matmul(attention, v)# 拼接多头output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)# 线性变换output = self.wo(output)return outputx = torch.rand(2, 3, 36)print(x)output = MultiHeadAttention(36, 6)y = output(x)print(y.shape)

MHA 能够理解输入不同部分之间的关系。然而,这种复杂性是有代价的——对内存带宽的需求很大,尤其是在解码器推理期间。主要问题的关键在于内存开销。 在自回归模型中,每个解码步骤都需要加载解码器权重以及所有注意键和值。这个过程不仅计算量大,而且内存带宽也大。随着模型规模的扩大,这种开销也会增加,使得扩展变得越来越艰巨。

因此,多查询注意 (MQA) 应运而生,成为缓解这一瓶颈的解决方案。其理念简单而有效: 使用多个查询头,但只使用一个键和值头。这种方法显著减少了内存负载,提高了推理速度。

MQA(多查询注意力)

图2 MHA和MQA的差别

MQA是MHA的一种变体,也是用于自回归解码的一种注意力机制。,图1、图2很形象的描绘了MHA和MQA的对比,与MHA 不同的是, MQA 让所有的Head之间共享同样的一份 K 和 V 矩阵(意味K和V的计算唯一),只让 Q 保留了原始多头的性质 (每个Head存在不同的转换),从而大大减少 K 和 V 矩阵的参数量以及KV Cache的显存占用,以此来达到提升推理速度,但是会带来精度上的损失。MQA被大量应用于LLM中,如ChatGLM2。

左 - 多头注意力,中 - 多查询注意力,右 - 将现有的 MHA 检查点转换为 MQA

如何将现有的预训练多头注意力模型转换为多查询注意力模型 (MQA)? 从现有的多头模型创建多查询注意力模型涉及两个步骤:模型结构的转换和随后的预训练。

代码实现

import torchimport torch.nn as nnclass MultiQuerySelfAttention(nn.Module):def __init__(self, embed_dim, num_heads):super(MultiQuerySelfAttention, self).__init__()self.num_heads = num_headsself.head_dim = embed_dim // num_headsself.wq = nn.Linear(embed_dim, embed_dim)# MHA# self.wk = nn.Linear(embed_dim, embed_dim)# self.wv = nn.Linear(embed_dim, embed_dim)# MQAself.wk = nn.Linear(embed_dim, self.head_dim)self.wv = nn.Linear(embed_dim, self.head_dim)self.wo = nn.Linear(embed_dim, embed_dim)def q_h_split(self, hidden, head_num=None):batch_size, seq_len = hidden.size()[:2]# q拆分多头if head_num == None:x = hidden.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)return xelse:# 这是MQA: 需要拆分k和v,这里面的head_num =1 的# 最终返回维度(batch_size, 1, seq_len, head_dim)return hidden.view(batch_size, seq_len, head_num, self.head_dim).transpose(1, 2)def forward(self, hidden_states, mask=None):batch_size = hidden_states.size(0)# 线性变换q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)# 多头切分# 这是MHA的# q, k ,v= self.split(q), self.split(k), self.split(v)# 这是MQA的q, k, v = self.q_h_split(q), self.q_h_split(k, 1), self.q_h_split(v, 1)# 注意力计算scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))print("scores:", scores.shape)if mask is not None:scores = scores.masked_fill(mask == 0, float('-inf'))attention = torch.softmax(scores, dim=-1)output = torch.matmul(attention, v)# 多头合并output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)# 线性变换output = self.wo(output)return outputx = torch.rand(3, 12, 512)atten = MultiQuerySelfAttention(512, 8)y = atten(x)print(y.shape)

GQA(分组查询注意力)

虽然MQA方式大幅减小了参数数量,但是,带来推理加速的同时会造成模型性能损失,且在训练过程使得模型变得不稳定( 复杂度的降低可能会导致质量下降和训练不稳定 ),因此在此基础上提出了GQA,它将Query进行分组,每个组内共享一组Key、Value。(GQA在LLaMA-2 和 Mistral7B得到应用)

GQA 的数学原理

分组:在 GQA 中,传统多头模型中的查询头 (Q) 被分成 G 组。每组分配一个键 (K) 和值 (V) 头。此配置表示为 GQA-G,其中 G 表示组数。

GQA 的特殊情况

对每个组中原始头部的键和值投影矩阵进行均值池化,以将MHA模型转换为 GQA 模型。此技术对组中每个头部的投影矩阵进行平均,从而为该组生成单个键和值投影。

通过 利用 GQA,该模型在 MHA 质量和 MQA 速度之间保持平衡 。由于键值对较少,内存带宽和数据加载需求被最小化。G 的选择代表了一种权衡:更多的组(更接近 MHA)可带来更高的质量但性能较慢,而更少的组(接近 MQA)可提高速度但有牺牲质量的风险。此外,随着模型规模的扩大,GQA 允许内存带宽和模型容量按比例减少,与模型规模相对应。相比之下,对于更大的模型,在 MQA 中减少到单个键和值头可能会过于严重。

代码实现

import torchimport torch.nn as nnclass GroupedQueryAttention(nn.Module):def __init__(self, embed_dim, num_heads):super(GroupedQueryAttention, self).__init__()self.num_heads = num_headsself.head_dim = embed_dim // num_headsself.wq = nn.Linear(embed_dim, embed_dim)# 这是MHA的# self.wk = nn.Linear(embed_dim, embed_dim)# self.wv = nn.Linear(embed_dim, embed_dim)# 这是MQA的# self.wk = nn.Linear(embed_dim, self.head_dim)# self.wv = nn.Linear(embed_dim, self.head_dim)# 这是GQA的self.group_num = 4# 这是4个组self.wk = nn.Linear(embed_dim, self.group_num * self.head_dim)self.wv = nn.Linear(embed_dim, self.group_num * self.head_dim)self.wo = nn.Linear(embed_dim, embed_dim)def split(self, hidden, group_num=None):batch_size, seq_len = hidden.size()[:2]# q需要拆分多头if group_num == None:x = hidden.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)return xelse:# 这是kv需要拆分的多头x = hidden.view(batch_size, seq_len, group_num, self.head_dim).transpose(1, 2)x = x[:, :, None, :, :].expand(batch_size, group_num, self.num_heads // group_num, seq_len,self.head_dim).reshape(batch_size, self.num_heads, seq_len, self.head_dim)return xdef forward(self, hidden_states, mask=None):batch_size = hidden_states.size(0)# 线性变换q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)# 多头切分# 这是MHA的# q, k ,v= self.split(q), self.split(k), self.split(v)# 这是MQA的# q, k ,v= self.split(q), self.split(k, 1), self.split(v, 1)# 这是GQA的q, k, v = self.split(q), self.split(k, self.group_num), self.split(v, self.group_num)# 注意力计算scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))print("scores:", scores.shape)if mask is not None:scores = scores.masked_fill(mask == 0, float('-inf'))attention = torch.softmax(scores, dim=-1)output = torch.matmul(attention, v)# 合并多头output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)# 线性变换output = self.wo(output)return outputx = torch.ones(3, 12, 512)atten = GroupedQueryAttention(512, 8)y = atten(x)print(y.shape)

参考文献

原文链接:​ ​​ ​

本网站的文章部分内容可能来源于网络和网友发布,仅供大家学习与参考,如有侵权,请联系站长进行删除处理,不代表本网站立场,转载者并注明出处:https://jmbhsh.com/xingyeremen/32649.html

联系我们

QQ号:***

微信号:***

工作日:9:30-18:30,节假日休息