关注

大模型入门-大模型优化方法1


5.1 深入浅出 SwiGLU:大模型时代的新一代激活函数

在探究 LLaMA 等前沿大语言模型时,我们经常会遇到一个核心概念:SwiGLU 。它究竟是什么?为什么大模型纷纷抛弃了传统的 ReLU,转而投入它的怀抱?

1. 传统 Transformer 的痛点

在最初的 Transformer 架构中,前馈神经网络(FFN)层通常包含两个线性变换,并在中间插入一个非线性激活函数 。传统模型最爱用的是 ReLU

FFN(x,W1,W2,b1,b2)=ReLU(xW1+b1)W2+b2FFN(x,W_{1},W_{2},b_{1},b_{2})=ReLU(xW_{1}+b_{1})W_{2}+b_{2}FFN(x,W1,W2,b1,b2)=ReLU(xW1+b1)W2+b2

为了追求极致的计算效率,后来的许多实验中甚至去掉了偏置项(bias),简化为:

FFN(x,W1,W2)=ReLU(xW1)W2FFN(x,W_{1},W_{2})=ReLU(xW_{1})W_{2}FFN(x,W1,W2)=ReLU(xW1)W2

ReLU 的局限性:
ReLU 非常“生硬”。当输入小于 0 时,它直接一刀切为 0。这种截断虽然计算极快,但在深层语言模型中,往往会丢失一些微妙的负值信息梯度,导致部分神经元永远无法被激活(即“神经元死亡”问题)。

2. 激活函数的进化之路:从 Swish 到 SwiGLU

为了让网络变得更加平滑且具表现力,学术界开启了对激活函数的持续改造。

第一阶段:Swish 函数的诞生

研究人员提出了 Swish 激活函数。它的巧妙之处在于引入了 Sigmoid 函数 σ(x)\sigma(x)σ(x) ,其定义如下:

Swish1=x⋅σ(x)Swish_{1}=x\cdot\sigma(x)Swish1=xσ(x)

将其套入 FFN 中,就变成了:

FFNSwish(x,W1,W2)=Swish1(xW1)W2FFN_{Swish}(x,W_{1},W_{2})=Swish_{1}(xW_{1})W_{2}FFNSwish(x,W1,W2)=Swish1(xW1)W2

Swish 的一大特色是它在 x<0x < 0x<0 附近有一个平滑的过渡,允许少量的负值通过。这种平滑的非线性特性让模型在处理微小负面信号时有了更细腻的感知能力。

为了让您更直观地理解这些数学公式底层的曲线差异,您可以操作下方的可视化面板。观察当 X 取负值时,Swish 曲线独有的“平滑过渡”特性:

第二阶段:终极形态 SwiGLU

在 Swish 的基础上,门控线性单元(GLU, Gated Linear Unit)的思想被引入。GLU 的核心是设计一个“双通道”:一路走激活函数,一路走线性变换,然后将它们相乘。

结合了 Swish 激活函数的 GLU,就是大名鼎鼎的 SwiGLU。相比早期的设计,它引入了更多的权重矩阵 ,其完整公式为:

SwiGLU(x,W,V,b,c)=Swish1(xW+b)⊗(xV+c)SwiGLU(x,W,V,b,c)=Swish_{1}(xW+b)\otimes(xV+c)SwiGLU(x,W,V,b,c)=Swish1(xW+b)(xV+c)

注:公式中的 ⊗\otimes 代表元素级别的乘法(Element-wise multiplication)。

通过这种精巧的门控乘法机制,SwiGLU 能够更加动态地控制信息的流通,这一优化由 LLaMA 提出并验证 ,显著提升了大模型在复杂推理任务上的表现。


3. 代码实战:用 NumPy 手写 SwiGLU

理论看完,我们来看看它在代码中是如何运转的。以下是剥离了复杂深度学习框架、纯用 NumPy 实现的整合版 SwiGLU 核心逻辑。这份代码极其适合入门者逐行理解其内部的计算步骤 :

import numpy as np

def swiglu(x, W, V, b=None, c=None):
    """
    [cite_start]整合版 SwiGLU 激活函数,包含所有内部计算步骤 [cite: 3957]

    参数:
    [cite_start]x -- 输入张量,形状为 (..., input_dim) [cite: 3962]
    [cite_start]W -- 第一个权重矩阵,形状为 (input_dim, hidden_dim) [cite: 3952]
    [cite_start]V -- 第二个权重矩阵,形状为 (input_dim, hidden_dim) [cite: 3953]
    [cite_start]b -- 第一个偏置项,形状为 (hidden_dim,),可选 [cite: 3954]
    [cite_start]c -- 第二个偏置项,形状为 (hidden_dim,),可选 [cite: 3987]
    """
    # 第一路计算:处理 xW + b
    [cite_start]xW = np.dot(x, W) [cite: 3991]
    if b is not None:
        [cite_start]xW += b [cite: 3993]

    # 第二路计算(门控信号):处理 xV + c
    [cite_start]xV = np.dot(x, V) [cite: 3995]
    if c is not None:
        [cite_start]xV += c [cite: 3997]

    # [cite_start]核心激活步骤:内部计算 Sigmoid 和 Swish1 [cite: 3998]
    [cite_start]sigmoid_xW = 1 / (1 + np.exp(-xW)) # Sigmoid 的数学计算 [cite: 3998]
    [cite_start]swish1_xW = xW * sigmoid_xW        # 组合生成 Swish1 的结果 [cite: 3999, 4002]

    # [cite_start]元素级 (element-wise) 乘法,将激活后的第一路与门控的第二路相乘 [cite: 4000]
    [cite_start]return swish1_xW * xV [cite: 4001]

通过以上代码,我们可以清晰地看到,SwiGLU 实际上是把输入分成了两条支路分别进行线性映射,其中一条经过平滑的 Swish 激活后,再与另一条执行乘法操作。这种复杂的非线性交互过滤,正是它赋予前沿大模型强大生命力的奥秘所在。


5.2 RMSNorm:大模型时代的归一化新贵

在深入研究大模型(如 LLaMA)的架构时,细心的读者会发现:Transformer 结构图中的一些关键组件发生了变化。
其中一个极其显著的变化就是:归一化层不仅位置从后(Post-LN)移到了前(Pre-LN),而且传统的 Layer Normalization (LayerNorm) 被替换为了 RMSNorm (Root Mean Square Normalization)。

为什么要做这种替换?RMSNorm 到底有什么魔力?

1. 核心公式对比:LayerNorm vs. RMSNorm

要想搞懂 RMSNorm,我们得先看看老前辈 LayerNorm 是怎么工作的。

传统 LayerNorm

LayerNorm 的核心是去中心化(减去均值)加上缩放(除以方差)

  1. 计算标准差: σ=1n∑i=1n(xi−μ)2\sigma = \sqrt{\frac{1}{n}\sum_{i=1}^{n}(x_{i}-\mu)^{2}}σ=n1i=1n(xiμ)2 (注:μ\muμ 为样本均值)
  2. 应用归一化: yi=xi−E(x)σ∗γy_{i} = \frac{x_{i}-E(x)}{\sigma} * \gammayi=σxiE(x)γ

新晋 RMSNorm

RMSNorm 提出了一个非常大胆的简化:直接砍掉减去均值(去中心化)的操作,只保留缩放操作!

  1. 计算均方根 (RMS): RMS=1n∑i=1nxi2RMS = \sqrt{\frac{1}{n}\sum_{i=1}^{n}x_{i}^{2}}RMS=n1i=1nxi2
  2. 应用归一化: yi=xiRMS∗γy_{i} = \frac{x_{i}}{RMS} * \gammayi=RMSxiγ

一句话总结:RMSNorm 就是强行把均值假设为 0 的 LayerNorm。


2. 为什么要砍掉均值?(核心优势)

大家可能会疑惑,砍掉均值计算不会导致模型性能下降吗?

论文和大量实验给出了明确的答案:在绝大多数自然语言处理任务中,去中心化(减去均值)对最终性能的贡献微乎其微。

如上图所示的验证集 BLEU 分数对比:

  • 效果基本持平: LayerNorm 的得分是 22.6,而 RMSNorm 的得分是 22.4。两者在测试集上的表现并没有明显的差异。
  • 速度大幅提升: 这是 RMSNorm 最大的卖点。因为移除了计算均值和减去均值的步骤,RMSNorm 的计算效率大幅提升。实验数据显示,RMS-Norm 的计算效率提高了 32%(训练耗时从 LayerNorm 的 665s 缩短到了 RMSNorm 的 501s)。

在大模型动辄成百上千层的极深网络中,每一层节约一点点计算量,叠加起来就是极其可观的算力和时间节省。

为了让你直观感受这种计算上的简化,你可以通过下方的交互工具,输入一组数据,看看这两种归一化方法的计算过程和耗时差异:


3. Pytorch 代码实战

RMSNorm 的逻辑非常简洁,我们用十几行 Pytorch 代码就能手写一个完美的 RMSNorm 模块。这对于理解其底层原理极有帮助:

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # gamma (γ) 缩放参数,初始化为全是 1 的张量,模型在训练中会学习它
        self.weight = nn.Parameter(torch.ones(hidden_size)) 
        # 防止除以 0 的极小值
        self.eps = eps

    def forward(self, x):
        # 1. 计算平方的均值 (RMS 核心步骤)
        # x.pow(2) 对每个元素求平方,.mean(-1) 求最后一个维度的平均值
        mean_square = x.pow(2).mean(-1, keepdim=True)
        
        # 2. 归一化并应用缩放
        # x 除以均方根,然后再乘以学习到的权重参数
        return self.weight * x / torch.sqrt(mean_square + self.eps)

在这里插入图片描述

转载自CSDN-专业IT技术社区

原文链接:https://blog.csdn.net/qq_32146369/article/details/161335557

评论

赞0

评论列表

微信小程序
QQ小程序

关于作者

点赞数:0
关注数:0
粉丝:0
文章:0
关注标签:0
加入于:--