四十三、Scaling Laws for Neural Language Models[2020]

  1. language 为人工智能的研究提供了一个天然的领域,因为绝大多数的 reasoning task 都可以用语言有效地表达和评估,而且世界上的文本通过 generative modeling 为无监督学习提供了大量的数据。最近,深度学习在 language modeling 方面取得了快速进展, SOTA 模型在许多特定任务上的表现接近人类水平。

    人们可能预期语言建模的性能取决于模型结构、神经模型的规模、用于训练模型的算力、以及用于训练模型的数据。在论文 《Scaling Laws for Neural Language Models》 中,作者将实证研究 language modeling loss 对所有这些因素的依赖性,重点是 Transformer 架构。通过在语言任务上的性能的上限和下限,使作者能够研究 scale 中超过七个数量级的趋势。

    在整个过程中,作者将观察到性能的精确的 power-law scaling ,其中性能作为训练时间、上下文长度、数据集大小、模型大小、compute budget 的函数。

  2. 相关工作:

    • 幂律(power-law )可以从各种来源产生。在密度估计(density estimation ) 模型和随机森林模型中,随模型规模和数据集大小有关的 power-law scaling 可能与我们的结果有关。这些模型表明,幂律指数可能有一个非常粗糙的解释,即数据中相关特征数量的倒数。

    • 最近的工作 《Deep learning scaling is predictable,empirically》《Beyond human-level accuracy: Computational challenges in deep learning》 也研究了模型大小和数据大小之间的 scaling ;他们的工作可能是文献中最接近我们的工作。然而,请注意, 《Deep learning scaling is predictable,empirically》 发现数据集大小与模型大小的超线性 scaling ,而我们发现的是亚线性 scaling

      • 我们关于计算量的 optimal allocation 的发现和 《One epoch is all you need》 之间有一些相似之处,包括 power-law learning curve

      • EfficientNet 似乎也服从于准确率和模型大小之间的近似的幂律关系。

      • 最近的工作 《A constructive prediction of the generalization error across scales》研究了各种数据集的数据集大小和模型大小的 scaling ,并符合与我们类似的分析方法。

    • EfficientNet 主张以指数方式 scale 模型的深度和宽度,以获得图像模型的最佳性能,导致宽度作为深度的函数出现 power-law scaling 。我们发现,对于语言模型来说,这个幂律在 scaling up 时应该是大致为 1 (即, width/depth 的比例应该保持固定)。但更重要的是,我们发现,与语言模型的 overall scale 相比,精确的架构超参数并不重要。

      • 《Residual networks behave like ensembles of relatively shallow networks》 中,有人认为 deep 的模型可以作为较浅的模型的 ensemble ,这有可能解释这一发现。

      • 早期的工作 《Wide residual networks》 比较了宽度和深度,发现 wide ResNet 在图像分类上可以胜过 deep ResNet

      • 一些研究固定了每个数据样本的计算量(计算量往往与模型参数的数量成正比,固定这个值意味着固定模型大小),而我们研究的是与模型大小和 training computationscaling

    • 许多研究(《High-dimensional dynamics of generalization error inneural networks》《Reconciling modern machine learning and the bias-variance trade-off》对高度过参数化的模型(highly overparameterized model)的泛化进行了研究,发现当模型规模达到数据集规模时,会出现 "jamming transition"《Scaling description of generalization with number of parameters in deep learning》)(这可能需要超出典型实践许多数量级的 training ,尤其是不使用早停)。我们没有观察到这样的 transition ,并发现所需的训练数据在模型大小上呈亚线性 scaling 。模型规模的扩展,特别是在 large width 上的扩展,可能为思考我们的一些 scaling 关系提供了一个有用的框架。

      我们关于 optimization 的结果(如学习曲线的形状),很可能可以用一个 noisy 的二次模型来解释,它可以在现实环境中提供相当准确的预测(《Which algorithmic choices matter at which batch sizes? insights from a noisy quadratic model》)。定量评估这种联系将需要海森谱( Hessian spectrum) 的特性。

这篇论文直接看结论部分即可,剩余的大多数都是实验报告。

43.1 背景和模型

  1. 符号:

    • L :交叉熵损失。通常情况下,它将是一个上下文中所有 token 的交叉熵损失的平均值,但在某些情况下,我们会报告上下文中特定位置处的 token 的损失。

    • N:模型参数的数量,不包括所有 vocabulary embeddingpositional embedding

    • C6NBS:是对 non-embeddingtraining 总计算量的估计,其中 Bbatch sizeStraining steps 数量(即 parameter updates )。我们引入 PF-days 为单位,其中一个 PF-days 表示 1015×24×3600=8.64×1019 次浮点运算。

    • D :数据集的大小,单位为 token

    • Bcritcritical batch size,在后续内容中定义和讨论。使用 critical batch size 的训练提供了训练时间和计算效率之间的大致上的最佳 trade-off

    • Cmin:对达到给定 loss 的最小 non-embedding compute 的估计。这是在模型以远小于 critical batch sizebatch size 进行训练时将使用的 training compute

      较小的 batch size 可以实现较小的计算量、但是需要较大的训练时间。

    • Smin:达到给定 loss 所需的最小 training steps 数的估计。这也是如果模型在batch size 远大于 critical batch size 的情况下所使用的 training steps 数量。

      较大的 batch size 可以实现较小的训练时间、但是需要较大的计算量。

    • αX:对 loss 缩放的幂率函数的指数,其中 L(X)(1/X)αXX 可以为 N,D,C,S,B,Cmin 中的任意一种。

    注,这里的 loss 函数是 test loss,而不是 train loss

  2. 我们在 WebText2 上训练语言模型,这是 WebText数据集的扩展版本,使用 byte-pair encoding: BPE 进行 tokenizevocabulary sizenvocab=50257 。我们优化 autoregressive log-likelihood (即交叉熵损失),在 1024-token 的上下文中取平均,这也是我们的主要性能指标。我们在 WebText2 测试集、以及其它一些测试集上记录 loss 。我们主要训练 decoder-only Transformer 模型,尽管我们也训练 LSTM 模型和 Universal Transformer 以进行比较。

    模型的性能评估指标是测试集上的交叉熵损失。

a. Transformer 的参数 scaling 和计算量 scaling
  1. 我们使用超参数 nlayer (层数)、dmodelresidual stream 的维度)、dff (中间 feed-forward layer 的维度)、dattnattention output 的维度)和 nheads (每层 attention head 的数量)对Transformer 架构进行参数化。我们在输入上下文中包括 nctxtoken ,除另有说明外否则默认采用 nctx=1024

    我们用 N 来表示模型大小,我们把它定义为 non-embedding parameters 的数量(推导过程参考Table 1 ):

    (1)N2dmodel×nlayer×(2dattn+dff)=12nlayerdmodel2 with the standard dattn=dff/4=dmodel

    其中,我们排除了 bias 项和其他次要的项。我们的模型在 embedding 矩阵中也有 nvocab×dmodel 个参数,并使用 nctx×dmodel 个参数进行 positional embedding ,但我们在讨论 "模型大小" N 时不包括这些参数。我们将看到,这产生了明显更干净的 scaling law

    评估 Transformer 的前向传播大致上包括 Cforward2N+2nlayer×nctx×dmodel 个乘加操作(add-multiply operation ),其中系数 2 来自于矩阵乘法中使用的 multiply-accumulate operation 。下表中包含了更详细的每个操作的参数数量和计算次数。

  2. 对于 dmodel>nctx/12 的上下文和模型,每个 tokencontext-dependent 的计算成本是总计算量的一个相对较小的部分。由于我们主要研究 dmodelnctx/12 的模型,我们的 training compute 的估计中不包括 context-dependent 项。考虑到反向传播(大约是前向传播计算量的两倍),我们为每个 training token 将估计的 non-embedding compute 定义 C6N 个浮点运算。

    对于 GPT-3 175Bdmodel=12288nlayer=96dattn=dmodeldff=4×dmodel 。因此只需要 nctx<147456 即可满足上述条件。

b. 训练过程
  1. 除非另有说明,我们用 Adam 优化器训练模型,其中训练了固定的 2.5×105 步, batch size512 个序列,每个序列包含 1024 tokens

    • 由于内存限制,我们最大的模型(超过 1B 个参数)是用 Adafactor 训练的。

    • 我们试验了各种学习率和 schedules 。我们发现,收敛的结果在很大程度上与 learning rate schedule 无关。除非另有说明,我们的数据中包括的所有 training runs 都使用了同一个 learning rate schedule ,即:3000 步的线性预热,然后是余弦衰减学习率到零。

  2. 我们用各种学习率和 schedule 进行了实验。下图显示了一个小型语言模型的一系列 schedules 和测试结果。我们的结论是,只要 total summed learning rate 足够大,而且 schedule 包括一个预热期、以及最后衰减到接近零的学习率,那么 learning rate schedule 的选择基本上是不重要的。schedules 之间的方差似乎是统计学上的噪音,并为不同 training runs 之间的方差的 scale 提供一个粗略的衡量标准。在较大的模型上的实验表明,对于不同的模型大小,不同的随机种子之间的 final test loss 的方差的幅度是大致不变的。

    我们发现,较大的模型需要一个较小的学习率来防止发散,而较小的模型可以容忍较大的学习率。为了实现这一点,在大多数 runs 中使用了以下经验法则:

    (2)LR(N)0.003239±0.0001395×log(N)

    我们期望这个公式可以被改进。

    可能存在对 network width 的依赖性,可能是由 initialization scale 设定的。对于 N>1010 个参数,该公式也会被打破。尽管如此,我们发现它对我们所考虑的模型有足够的作用。

c. 数据集
  1. 我们在 GPT2中描述的 WebText 数据集的扩展版本上训练我们的模型。最初的 WebText 数据集是对 Reddit201712 月的外链(包含 Reddit 用户的至少三个点赞)的网络爬取。在第二个版本中(即,WebText2 ),我们增加了 20181 月至 10 月期间的 Reddit 外链,也是至少有 3 个点赞的。karma 阈值(即,点赞的数量阈值)作为一种启发式方法,用于判断人们是否认为该链接有趣或有用。新链接的文本是用 Newspaper3k python library 提取的。

    总的来说,该数据集由 20.3M 篇文档组成,包含 96GB 的文本和 1.62×1010 个单词。然后我们应用 GPT2 中描述的可逆的 tokenizer ,产生 2.29×1010token 。我们保留了其中的 6.6×108token 作为测试集,并且我们还在类似的 Books CorpusCommon CrawlEnglish Wikipedia 、以及公开的 Internet Books 集合中进行测试。

43.2 实验结果和 Basic Power Laws

  1. 为了刻画 language model scaling ,我们训练了各种各样的模型,其中改变了一些因子,包括:

    • 模型大小:从 7681.5Bnon-embedding 参数。

    • 数据集大小:从 22M230Btoken

    • 模型shape:包括 depthwidthattention head、以及 feed-forward dimension

    • 上下文长度:大多数 runs 设置为 1024,但我们也尝试使用较短的上下文。

    • batch size: 大多数 runs 设置为 219,但我们也改变 batch size 从而测量 critical batch size

43.2.1 Transformer Shape

  1. 当我们固定 non-embedding 参数总数 N 时, Transformer 的性能对 shape 参数 nlayer,nheads,dff 的依赖非常弱。

    为了确定这些结果,我们用固定的模型大小来训练模型,同时改变一个超参数。

    • 这对于 nheads 的情况是最简单的:改变 nheads 同时改变每个 head 的维度,使得 dmodel 保持不变。

      右图中,每条曲线对应于固定的 dmodel 同时改变 nheads

    • 当改变 nlayer 时,我们同时改变dmodel ,同时保持 N12nlayerdmodel2 固定。

      中间图中,每条曲线对应于固定的参数量同时改变 nlayerdmodel

    • 同样地,为了在固定的模型大小下改变 dff,我们也同时改变 dmodel 参数,这是由 Table 1 中的参数数量所要求的。

      左图中,每条曲线对应于固定的参数量同时改变 dffdmodel

    如果 deeper Transformer 有效地表现为 shallower modelensembles,那么就会导致 nlayer 的独立性,正如对 ResNet 的建议(《Residual networks behave like ensembles of relatively shallow networks》)。结果显示在下图中。

    下图中,不同的模型形状,得到的 test loss 差异很小,差异基本上都在 10% 以内。

43.2.2 non-embedding 参数数量 N

a. 实验结果
  1. 在下图中,我们展示了各种模型的性能,从 shape (nlayer,dmodel)=(2,128) 的小模型、到十亿参数的大模型,其中大模型的 shape(6,4288)(207,768) 不等。在这里,我们已经在完整的 WebText2 数据集上训练到接近收敛,并且观察到没有过拟合(除了非常大的模型,因为对于非常大的模型则可能过拟合)。

    Figure 1 右图所示,我们发现关于 non-embedding 参数数量 N 的一个稳定的趋势,即:

    (3)L(N)(NcN)αN

    要观察这些趋势,关键是要研究模型性能与 N 的关系。如果我们改用总的参数数量(包括 embedding 参数),趋势就有些模糊了(如下左图所示)。这表明,embedding matrix 可以做得更小而不影响性能,这在最近的工作中已经看到了(《Albert: A lite bert for self-supervised learning of language representations》)。

    ALBertembedding matrix 分解为两个低维矩阵的乘积,这等价于降低了 embedding matrix 的大小。

  2. 尽管这些模型是在 WebText2 数据集上训练的,但它们在其他各种数据集上的 test loss 也是 N 的幂律,指数几乎相同,如下左图所示。

    左图、右图的详细说明参考 “数据分布之间的泛化” 章节。

b. 与 LSTM 和 Universal Transformer的比较
  1. 在下图中,我们比较了 LSTMTransformer 的性能与 non-embedding 参数数量 N 的关系。LSTM 是用与 Transformer 相同的数据集和上下文长度来训练的。从右图中我们可以看出,LSTM 对于上下文位置中头部出现的 token 表现得和 Transformer 一样好,但是对于上下文位置中尾部出现的 token 则无法与 Transformer 的表现相媲美。

    红色曲线为 LSTM,蓝色曲线为 Transformer

  2. 下图中给出了模型性能与上下文位置之间的幂律关系,其中,除了第一个 token 之外(最上面的一条曲线),都显示了随着模型大小的增加而稳定地改善,表明快速识别模式的能力有所提高。

    我们还包括用很小的 nctx=8 上下文来训练的模型,以便与我们的更长的上下文模型进行比较(即,nctx=1024)。即使是在 nctx=8 的情况下训练的规模不大的模型也能在非常早期的 token 上超越我们最大的 nctx=1024 模型。这也表明,在大的上下文下训练的更大的模型应该可以有进一步的改进。

    因为对于 nctr=1024 的模型,它看到的上下文信息更多,理论上应该表现更好。

  3. 在固定模型大小的情况下, loss scale 似乎与上下文中的位置 T 成幂律关系,如下图所示。这可能是语言中潜在的 power-law correlation 的结果,或者是模型结构和 optimization 的一个更普遍的特征。它为在更大的上下文中训练的潜在好处(或潜在不足)提供了一些建议。不仅较大的模型在 T=1024 时收敛到较好的性能,而且在 early tokens 时也改善得更快,这表明较大的模型在检测 less contextual information 的模式时更有效率。

    在右边的图中,我们显示了对于一个固定的模型, per-token 性能是如何作为 training steps 的函数而变化的。early tokens 在训练过程中更快地被学好,而末尾的 tokens 在训练的后期才能训练好。

    左图:每条曲线对应一个模型(不同模型大小)的 per-token test loss ,在模型训练结束后。

    右图:每条曲线对应一个 token index 的学习曲线,在模型训练过程中,固定的模型。

  4. 我们还在下图中比较了标准 Transformerrecurrent Transformer《Universal transformers》)的性能。 recurrent Transformer 模型复用参数,因此表现得略好(右图),但代价是每个参数的额外计算量。

    左图把 reuse 的参数也认为是全新的,因此参数规模会变大。

c. 数据分布之间的泛化
  1. 我们还在一组额外的文本数据分布上测试了我们的模型。下图显示了这些数据集上的 test loss 与模型大小的关系。在所有情况下,模型都只在 WebText2 数据集上训练过。

    • 从左图中我们看到,在这些其他数据分布上的 loss 随着模型规模的增加而平滑地改善,直接地平行于 WebText2test loss 曲线。

    • 从右图中我们发现,泛化性几乎完全取决于 in-distribution validation loss ,而不取决于训练的持续时间、或接近于收敛的程度。

      虚线表示单个大型模型在它训练过程中,所得到的 test loss;圆点表示很多收敛的模型对应的 test loss

  2. 我们还观察到对模型深度没有依赖性(在固定模型大小的条件下)。

d. 性能与数据集大小和计算量的关系
  1. 我们在下图中显示了 test loss 作为数据集大小 D(以 token 数量为单位)和训练计算量 C 的函数的经验趋势。

    • 对于 D 上的趋势,我们在 WebText2 数据集的固定子集上训练了一个 (nlayer,nembd)=(36,1280) 的模型。一旦 test loss 不再减少,我们就停止训练。我们看到,所产生的 test loss 可以用数据集大小 D 的简单的幂律来拟合:

      (4)L(D)(DcD)αD

      Figure 1 中间的图所示。

    • 训练过程中使用的 non-embedding 计算总量可以估计为 C=6NBS ,其中 Bbatch sizeSparameter updates 的数量,系数 6 代表了同时包含前向传播和反向传播。因此,对于一个给定的 C 值,我们可以扫描所有具有不同 N 的模型,找到在 step 数量 S=C6BS 上具有最佳性能的模型。

      请注意,在这些结果中,batch size B 对所有模型都是固定的,这意味着这些经验结果不是真正的最优。我们将在后面的章节中使用调整后的 Cmin 来说明这一点,以产生更清晰的趋势。

      这个结果如Figure 1 左图的加粗黑线所示。它可以用如下的方程来拟合:

      (5)L(C)(CcC)αC

      此外,在左图中还包括每个模型的学习曲线,从而展示每个模型何时达到最优。

  2. 我们将在后面更仔细地研究计算量的最佳分配问题。数据强烈表明,sample efficiency 随着模型的大小增加而提高,如右图所示,对于给定的 test loss,模型越大所需要的样本数越少。

    • 左图:每条曲线表示在给定的 test loss (不同曲线采用不同的值)的条件下,最短的训练时间(由 minimum steps 衡量)和模型大小的关系。通常而言,模型参数越大,训练时间越短。

    • 右图:每条曲线表示在给定的 test loss (不同曲线采用不同的值)的条件下,最少的训练过程中看过的样本( E=B×S)和模型大小的关系。通常而言,模型参数越大,所需要的样本数越少。

43.3 Infinite Data Limit 和过拟合

  1. 在前面的内容中,我们发现了针对语言建模性能的一些 basic scaling laws 。这里,我们将研究在具有 Dtokens 的数据集上训练的大小为 N 的模型,并同时改变 ND 时模型的性能。我们将从经验上证明最佳的 test loss 符合公式 L(N,D)scaling law 。这为我们需要多少数据来训练规模不断扩大的模型、并同时控制过拟合提供了指导。

43.3.1 L(N, D) 公式

  1. 我们选择参数化 L(N,D) 为:

    (6)L(N,D)=[(NcN)αN/αD+DcD]αD

    其中,这个公式基于三个原则:

    • 词表大小或 tokenization 的变化有望通过一个整体因子来 rescale 损失函数。 L(N,D) 的参数化(以及所有 loss 的建模)自然必须考虑到这种 rescaling

    • 固定 D 并选择 N,那么 overall loss 应该接近 L(D) 。相反地,固定 N 并选择 D ,那么 overall loss 应该接近 L(N)

    • L(N,D) 应该在 D= 处是 analytic 的,因此它在 1/D 处具有整数幂次的级数展开(series expansion)。对这一原则的理论支持明显弱于前两者。

      即,对于无限的训练数据,模型应该能够收敛。

    我们选择的 L(N,D) 满足第一个要求,因为我们可以通过改变词表来 rescale Nc,Dc 。这也暗示了 Nc,Dc 的具体取值没有根本意义。

    因为当 test loss 停止改善时我们提前结束训练,并且我们以相同方式优化所有模型,所以我们期望较大的模型应该总是比较小的模型表现得更好。但是对于固定的、有限的 D ,我们也不期望任何模型能够接近最佳的 possible loss 。类似地,具有固定大小的模型将是容量有限的。这些考虑激发了我们的第二个原则。注意,关于在 D 无穷大时的 L(N) 、以及在 N 无穷大时的 L(D) 的知识,完全决定了 L(N,D) 中的参数。

    第三个原则更加是推测性的。一个简单而普遍的原因是,在非常大的 D 时,人们可能会过拟合到 scale 1/D。过拟合应该与数据集的方差或信噪比有关(《High-dimensional dynamics of generalization error in neural networks》),这与 1/D 成比例。这一预期应该适用于任何平滑的损失函数,因为我们希望能够将 loss 扩展到 D的极限。然而,这种观点假设 1/D correction 在其他方差来源中占主导地位,例如有限的 batch sizeefficacy of optimization 的其他限制。没有经验的证实,我们对它的适用性没有信心。

    我们的第三个原则解释了公式 L(N,D)ND 的作用之间的不对称。非常相似的对称表达式也是可能的,但是它们不会有 1/D 的整数幂次展开式,并且需要引入一个额外的参数。

    在任何情况下,我们都会看到,我们对于 L(N,D) 的公式很好地拟合了数据,这是我们的 L(N,D) 假设方程的最重要的理由。

43.3.2 拟合结果

  1. 我们以 10%dropout rate 来正则化我们的所有模型,然后跟踪 test loss ,一旦 test loss 不再下降时我们就停止训练。实验结果如 Figure 9 左图所示,它包含了针对公式 L(N,D) 中的四个参数 αN,αD,Nc,Dd 的拟合。拟合的参数如下表所示:

    Figure 9 左图:每条曲线代表一个 data size,给出了在该 data size 的条件下,test lossN 的关系。

    Figure 9 右图:L(N,D)L(N,D=)1NαNαD/D 之间的拟合曲线。

    我们获得了极好的拟合,除了如下的 runs:数据集已经减少了 1024 倍到大约 2×107 tokens 。对于如此小的数据集,一个 epoch 仅包含 40 次参数更新。也许这样一个微小的数据集代表了语言建模的一个不同的区域,因为过拟合在训练的早期就发生了 (Figure 16 右图的第一条 Test Loss 曲线)。还要注意,这些参数与 basic power laws 中获得的参数略有不同,因为这里我们拟合的是完整的 L(N,D) 而不仅仅是 L(N,)L(,D)

43.3.3 D 和 N 之间的亚线性关系

  1. 为了绘制 infinite data limit 的边界,我们可以直接研究过拟合的程度。对于除了最大的模型之外的所有模型,当用完整的 22B tokenWebText2 数据集训练时,我们没有看到过拟合的迹象,因此我们可以将其作为 D= 的代表。因此,我们可以通过如下定义来比较有限的 Dinfinite data limit

    (7)δL(N,D)=L(N,D)L(N,)1

    并把它作为 N,D 的函数来研究。事实上,我们从经验上看到, δL 仅依赖于 ND 的特定组合,如 Figure 9 右图所示。这是根据等式 L(N,D)scaling law 得出的,即:

    (8)δL(1+(NNc)αNαD×DcD)αD1

    根据 L(N)(NcN)αN ,我们可以得到:

    (9)L(N,D)L(N,)[(NcN)αN/αD+DcD]αD((NcN)αN/αD)αD=(1+(NNc)αNαD×DcD)αD

    注意,在大的 D 时,该公式也有 1/D 的幂级数展开。

    L(N,) 表示数据集无限大时的 test loss,它是 L(N,D) 的下限,代表没有过拟合时的 test loss 。因此:

    • δL 一定大于等于 0 。如 Figure 9 右图所示。

    • δL 代表了由于数据量不足导致的过拟合。

    对于不同的随机数种子,我们估计 loss 的方差大约为 0.02 ,这意味着,为了避免过拟合,我们需要:

    (10)D(5×103)×N0.74

    推导过程:δL0.02 ,即可得到该方程。

    利用这种关系,小于 109 个参数的模型可以在 22B tokenWebText2 数据集上被训练从而具有最小的过拟合,但是我们最大的模型将会遇到一些轻微的过拟合。更一般地,这种关系表明数据集大小可以关于模型大小成亚线性增长,同时避免过拟合。然而,请注意,这通常并不代表 compute-efficient training

    我们还应该强调,在改变数据集和模型大小时,我们没有 optimized regularization (例如,dropout rate )。

43.4 关于模型大小和训练时间的 scaling laws

  1. 在本节中,我们将证明一个简单的 scaling law 可以很好地描述 loss 作为模型大小 N 和训练时间的函数。

    • 首先,我们将解释如何使用 《An empirical model of large-batch training》 的结果来定义 universal training step Smin ,这说明了我们的大多数模型尚未在最佳 batch size 下训练的事实。

    • 然后,我们将证明我们可以使用公式 L(N,S) 拟合 loss 与模型大小和训练时间。

    • 稍后,我们将使用这些结果来预测模型大小和训练时间之间关于训练计算量的最佳分配,然后确认该预测。

43.4.1 临界 Batch size

  1. 《An empirical model of large-batch training》 中提出了一个关于训练的 batch size 依赖性的简单的经验理论。该论文认为,对于 training,有一个临界 batch size Bcrit

    • 对于 B<Bcritbatch size B 可以增加,而 compute-efficiency 的下降非常小。

    • 而对于 B>BcritB 的增加导致收益递减。

    该论文还认为, gradient noise scaleBcrit 提供了一个简单的预测,并且 gradient noise scale 也不直接依赖于模型大小,而是依赖于已经获得的 loss value 。这些结果可用于预测训练时间和计算量将如何随 batch size 而变化。为了尽可能有效地平衡训练时间和计算量,最好使用 batch sizeBBcrit 进行训练。在 BBcrit 的训练最小化了 training steps 数量,而在 BBcrit 的训练最小化了计算量的使用。

    即:大的 batch size 可以节省训练时间、但是浪费了计算量;小的 batch size 可以节省计算量、但是浪费了计算时间。

    更具体地,可以证明对于各种各样的神经网络任务, 当训练到 loss L 的任何固定值时,training steps 数量 Sprocessed 数据样本数量 E=BS 满足简单的关系:

    (11)(SSmin1)×(EEmin1)=1

    其中:Smin 为达到 L 所需的最小 training steps 数量,而 Emin 是必须处理的数据样本的最小数量。

    注意:Smin 并不是在 Emin 处取得的,即:EminBSmin ,除了在 B=Bcrit

    我们在下图中展示了针对 Transformer 的这种关系(即,(SSmin1)×(EEmin1)=1 )。左图、右图分别给出了两种不同参数规模下的曲线,横轴为 S 、纵轴为 E 。每条曲线代表一个固定的 test loss

    这个关系定义了临界的 batch size

    (12)Bcrit(L)=EminSmin

    它是目标损失值的函数。在临界 batch size 的训练可以实现大致最优的 time/compute tradeoff ,需要 S=2Smintraining steps 、以及处理 E=2Emin 的数据样本。

  2. 在下图中,我们绘制了两个不同模型的临界 batch sizegradient noise scale ,作为 training loss 的函数。我们看到,Bcrit(L) 与模型大小无关,仅取决于 loss L 。因此,《An empirical model of large-batch training》prediction 继续适用于 Transformer 语言模型。临界 batch sizeloss 上符合幂律:

    (13)Bcrit(L)BL1/αB

    其中:B2×108αB0.21

    为什么选择这样的拟合公式?因为随着 loss 接近其最小值 Lmingradient noise scale 预计会发散,并且我们期望Bcrit 跟踪该 noise scale 。因此随着 L 越来越小(接近 Lmin ),我们希望 Bcrit 越来越大。即, L0Bcrit 趋向于正无穷。

    注意:下图中,横轴的右侧为零。

  3. 如果我们在 BBcrit(L) 处训练,则大小为 N 的模型训练到 loss L 所需的 training steps 数量的临界值定义为:

    (14)Smin=S1+Bcrit(L)/B,min steps at BBcrit

    推导过程:根据定义 Bcrit(L)=EminSminE=BS ,以及拟合公式 (SSmin1)×(EEmin1)=1 ,有:

    (15)1=(SSmin1)×(EEmin1)=(SSmin1)×(BSBcritSmin1)=BBcrit(SSmin)2(1+BBcrit)SSmin+1Smin=S1+Bcrit(L)/B

    推导过程中并未要求 BBcrit(L) ,这个约束条件从何而来?论文并未说明。读者猜测可能是由于 (SSmin1)×(EEmin1)=1 需要满足的条件而引起的。

    该式子的物理意义:

    • 根据 S=Smin×(1+Bcrit/B),则给定 batch size BBcrit ,我们可以预测需要多少个 training step S 模型能够收敛。

    • 另一方面,任意选择一个 batch size BBcrit ,根据它的收敛时的 training step S ,我们能够统计得到 Smin

    如果我们在 BBcrit(L) 处训练,这也定义了用大小为 N 的模型训练到 loss L 所需的计算量的临界值: