Posted by Hao Liang's Blog on Monday, January 1, 0001

分布式训练模式

本章涵盖 区分传统模型训练和分布式训练 使用参数服务器(parameter servers)构建无法在单台机器上运行的模型 使用集合通信模式(collective communication pattern)提升分布式模型训练性能 分布式模型训练过程中的故障处理

上一章介绍了一些实用的模式,这些模式可以融入数据摄取过程,这通常在分布式机器学习系统的起始阶段,负责监控传入的数据并执行必要的预处理步骤,为模型训练做准备。

分布式训练是数据摄取过程之后的下一步,它是区分分布式机器学习系统和其他分布式系统的关键所在,也就是分布式机器学习系统中最关键的部分。

系统设计需要具有可扩展性和可靠性,以处理不同规模和不同复杂程度的数据集和模型。 一些大型且复杂的模型无法在单台机器上运行,而一些中等大小、足以在单台机器上运行的模型,其分布式训练的计算性能却很难提高。

当我们遇到性能瓶颈和意外故障时,知道该如何应对也很重要。 部分数据集可能已损坏或无法成功用于模型训练,或者分布式训练所依赖的分布式集群可能因天气条件、人为误操作等原因出现网络不稳定甚至断开的情况。

在本章中,我将探讨分布式训练过程中涉及的一些挑战,并介绍一些在行业中广泛采用的模式。 第 3.2 节讨论了训练大型机器学习模型的挑战,这些模型为新上传 YouTube 视频中的主题做样本标记,但模型无法在单台机器上运行;因此本节还展示了如何使用参数服务器模式解决这个问题。 第 3.3 节展示了如何使用集合通信模式来加速小型模型的分布式训练,避免了参数服务器和工作节点之间不必要的通信开销。 最后一节讨论了由于数据集损坏、网络不稳定和工作节点抢占等原因导致的分布式机器学习系统的一些不稳定性问题,以及解决这些问题的方法。

3.1 分布式训练的基本概念 分布式训练是采用已经通过数据摄取处理的数据(在第 2 章中讨论)来初始化机器学习模型,然后在分布式环境(例如多个节点)中使用处理后的数据训练模型的过程。 这个过程很容易与传统机器学习模型训练过程相混淆,传统的模型训练过程发生在单节点环境中,其中数据集和机器学习模型对象位于同一台机器(例如笔记本电脑)上。 相比之下,分布式模型训练通常发生在一组可以同时工作的机器中,以大大加快训练过程。 另外,在传统模型训练中,数据集通常位于单台机器的本地磁盘上,而在分布式模型训练中,则使用远程分布式数据库来存储数据集,或者将数据集分区存储在多台机器的磁盘上。 如果模型不够小,无法运行在单台机器上,就不能用传统的单机方式进行训练。 从网络基础设施的角度来看,分布式训练通常更倾向于使用 InfiniBand (https://wiki.archlinux.org/title/InfiniBand) 或远程直接内存访问 (RDMA;https://www.geeksforgeeks.org/remote-direct-memory-access-rdma/) 网络。表 3-1 提供了这些训练方法之间的对比。

表 3-1 传统(非分布式)机器学习模型训练与分布式训练的比较 Traditional model training Distributed model training Computational resources Laptop or single remote server Cluster of machines Dataset location Local disk on a single laptop or machine Remote distributed database or partitions on disks of multiple machines Network infrastructure Local hosts InfiniBand or RDMA Model size Small enough to fit on a single machine Medium to large 传统模型训练 | 分布式模型训练 计算资源 | 笔记本电脑或单个远程服务器 | 机器集群 数据集位置 | 单台笔记本电脑或机器上的本地磁盘 | 远程分布式数据库或多台机器磁盘上的分区 网络基础设施 | 本地主机 | InfiniBand 或 RDMA 模型大小 | 足够小,能够单台机器运行 | 中等到大型

传统模型训练 分布式模型训练 计算资源 笔记本电脑或单个远程服务器 机器集群 数据集位置 单个笔记本电脑或机器上的本地磁盘 远程分布式数据库或多台机器磁盘上的分区

InfiniBand 和 RDMA InfiniBand 是一种用于高性能计算的计算机网络通信标准。 它具有高吞吐量和低延迟的特点,适用于计算机或存储系统之间和内部的数据互连,这通常是分布式训练所必需的。 RDMA(Remote Direct Memory Access,远程直接内存访问)提供从多台机器之间内存的直接访问,而无需任何一方操作系统的介入。 该通信标准旨在支持高吞吐量、低延迟的网络通信,这有助于多机之间通信频繁的分布式训练。

3.2 参数服务器模式:800 万样本的实体标记 YouTube 视频 假设我们有一个名为 YouTube-8M(http://research.google.com/youtube8m;图 3-1)的数据集,它包含数百万个 YouTube 视频 ID,并且带有来自超过 3,800 个视觉实体类别(例如:食物、汽车和音乐)的机器生成的高质量标注。 我们希望训练一个机器学习模型来标记它未见过的 YouTube 视频的主题。

这个数据集包含了粗粒度(coarse)和细粒度(fine-grained)的实体类别。 粗粒度实体类别是指非专业领域的人员在研究现有样本后可以识别的实体,而细粒度实体类别可以被专业领域的人员所识别,他们有区分极其相似实体能力的。 这些实体类别已经通过 3 位评估人员根据评估指南做出判断和手动验证,以确保它们在视觉上可被识别。 每个实体类别最少有 200 个相应的视频样本,平均下来有 3552 个样本。 在评估人员验证评估视频实体类别时,他们根据每个实体类别的视觉可识别性,使用 1 到 5 的离散等级来量化评估,其中 1 级代表一个普通人可以轻易识别的实体类别(图3.2)。

在 YouTube-8M 提供的在线数据集浏览页面中 (http://research.google.com/youtube8m/explore.html ),实体类别列表显示在左侧,并且每个实体类别下的视频样本数量显示在实体类别名称旁边(图 3.3)。

在数据集浏览页面中,实体分类按它所包含的视频数量从多到少排序。 在图 3-3 中,最受欢迎的 3 个实体分类分别是游戏(Game)、电子游戏(Video game)和交通工具(Vehicle),训练样本数量从 415,890 到 788,288 个不等。 最不受欢迎的实体(图中未显示)是圆柱体(Cylinder)和砂浆(Mortar),分别有 123 个和 127 个视频样本。

图 3-1 托管 YouTube-8M 数据集的网站,其中包含来自 3,800 多个视觉实体类别的数百万个 YouTube 视频(来源:Sudheendra Vijayanarasimhan 等人,遵循 Nonexclusive License 1.0 许可协议)

图 3-2 评估人员需要关注的问题和评估指南,以便评估人员验证 YouTube 视频中的实体类别,并评估每个实体类别的视觉可识别性(来源:Sudheendra Vijayanarasimhan 等人,遵循 Nonexclusive License 1.0 许可协议)

The description of an entity that would give the reader a sense of how the entity looks like. 对实体的描述,让读者了解该实体的模样。 Raters are given the question to assess how specific and visually recognizable each entity is. 评估者面临的问题是评估每个实体的具体性和视觉可识别性。 How difficult is it to identify this entity in images of videos (without audio, titles, comments, etc)? 在视频图像(没有音频、标题、评论等)中识别这个实体有多困难?

图 3-3 YouTube-8M 网站提供的在线数据集浏览页面,按视频数量对实体类别进行排序(来源:Sudheendra Vijayanarasimhan 等人,遵循 Nonexclusive License 1.0 许可协议)

Different entities are listed here and they are ordered by the number of videos in each entity. For example, the entity Games is the most popular entity in this dataset. 此处列出了不同的实体,并按每个实体中的视频数量排序。例如,Games 是该数据集中最热门的实体。

3.2.1 问题 基于这个数据集,我们希望训练一个机器学习模型来标记新的 YouTube 视频的主题。 对于一个简单的数据集和机器学习模型来说,这项任务可能非常简单,但对于 You-Tube-8M 数据集来说,情况并非如此。 该数据集附带了从数十亿视频帧和音频片段中预计算处理的视听特征,因此我们不必自己计算和获取这些特征——这通常需要很长时间并且需要大量的计算资源。

尽管可以在单个 GPU 上用不到一天的时间就可以在这个数据集上训练出一个强大的基线模型(baseline model),但数据集的规模和多样性需要我们对视听模型做进一步深入探索,因此这可能需要花费数周的训练时间。 那么是否有办法更高效地训练这个潜在的大型模型呢?

3.2.2 解决方案 首先,让我们看一下 YouTube-8M 网站上使用数据浏览器的一些实体,看看实体之间是否存在任何关系。 例如,这些实体是否无关,或者它们在内容上是否有一定程度的重叠? 经过一番探索后,我们将对模型进行必要的调整以考虑这些关系。 图 3-4 显示了属于 Pet 实体的 YouTube 视频列表。 第一排第三个视频中,一个孩子正在和一只狗玩耍。

首先,让我们使用 YouTube-8M 网站上的数据浏览页面查看各个实体类别,并探索这些实体类别之间是否存在某种关系。 例如,这些实体类别之间是否不相关?还是在内容上有一定程度的相关性? 经过一番探索后,我们对模型进行必要的调整,以考虑这些相关性带来的影响。 图 3-4 显示了属于宠物(Pet)实体类别的 YouTube 视频列表。 在第一排的第三个视频中,一个孩子正在和一只狗玩耍。

The entity Pet has been selected here. 此处已选择 Pet 实体类别

图 3-4 属于宠物实体类别的样本视频(来源:Sudheendra Vijayanarasimhan 等人,遵循 Nonexclusive License 1.0 许可协议)

让我们看一下类似的实体类别。 图 3-5 显示了属于动物(Animal)实体类别的 YouTube 视频列表,视频中我们可以看到鱼、马和熊猫等动物。 有趣的是,在第五行的第三个视频中,一只猫正在用吸尘器做清洁。 人们可能会猜测该视频也属于宠物实体类别,因为如果猫被人类收养,它就成为了宠物。

图 3-5 属于动物实体类别的样本视频(来源:Sudheendra Vijayanarasimhan 等人。遵循 Nonexclusive License 1.0 许可协议)

Here we’ve selected the Animal entity to see a list of videos with animals. 在这里,我们选择了 Animal 实体来查看动物视频列表。

如果我们想基于这个数据集构建机器学习模型,在直接将模型拟合数据集之前,我们可能需要执行一些额外的特征工程步骤。 我们可以将这两个实体类别(动物和宠物)的视听特征结合成一个派生特征(因为它们提供了相似的信息并且有相关性重叠),以此来提高模型的性能。 如果我们继续探索实体中现有视听特征的组合,或执行大量特征工程步骤,我们可能就无法在单个 GPU 上在不到一天的时间内在该数据集上训练机器学习模型了。

如果我们使用的是深度学习模型,而不是需要大量特征工程和数据集探索的传统机器学习模型,那么模型本身就会学习特征之间的潜在关系,例如:相似实体类别之间的视听特征。 模型架构中的每层神经网络都由权重(weights)和偏差(biases)向量组成,它们共同代表了一个经过训练的神经网络层,随着模型从数据集中收集到更多信息,该神经网络层会在训练迭代中进行更新。

如果我们仅使用 3,862 个实体类别中的 10 个,我们可以构建一个 LeNet 模型(图 3-6),将新的 YouTube 视频分类到 10 个选定实体分类中的 1 个。 总言之,LeNet 模型由一个卷积编码器组成,包含 2 个卷积层(convolutional layers),以及一个由 3 个全连接层(fully connected layers)组成的密集块(dense block)。 为简单起见,我们假设视频中的每一帧都是 28 × 28 像素的图像,它将被各种卷积层和池化层处理,这些层负责学习视听特征和实体类别之间的特征映射关系。

LeNet 的历史 LeNet (https://en.wikipedia.org/wiki/LeNet) 是最早发布的卷积神经网络 (CNN;https://en.wikipedia.org/wiki/Convolutional_neural_network) 之一,因其在计算机视觉任务上的卓越性能表现而引起广泛关注。 它是由 AT&T 贝尔实验室研究员 Yann LeCun 提出的,用于识别图像中的手写数字。 经过十年的研发,LeCun 于 1989 年发表了第一项通过反向传播成功训练 CNN 的研究成果。 当时,LeNet 取得了与支持向量机(SVM,support vector machines,一种有监督机器学习算法中的主导方法)性能相匹配的出色结果。

事实上,那些学习到的特征映射包含了与模型相关的参数。 这些参数是用作该模型层表示权重和偏差的数值向量。 对于每次训练迭代,模型将 YouTube 视频中的每一帧作为特征,计算损失函数,然后更新这些模型参数进一步优化模型,使得特征与实体类别之间的关系可以被更紧密地建模。

图 3-6 LeNet 模型架构,该模型可用于将新的YouTube视频分类到10个选定的实体类别中的一个。 (来源:Aston Zhang 等人。遵循 Creative Commons Attribution-ShareAlike 4.0 International Public License 许可协议)

The original image is processed by various convolution and pooling layers that learns the underlying feature mapping. 原始图像经过各种卷积和池化层处理,学习底层特征映射。 The original image that represents a single frame of the YouTube video. 代表 YouTube 视频中每一帧的原始图像。 28x28 image 28x28 图像 6@28x28 C1 feature map 6@28x28 C1 特征映射 6@14x14 S2 feature map 6@14x14 S2 特征映射 16@10x10 C3 feature map 16@10x10 C3 特征映射 16@5x5 S4 feature map 16@5x5 S4 特征映射 Convolution 卷积层 Pooling 池化层 Dense 全连接层

不幸的是,这个训练过程相对缓慢,因为它涉及到更新模型不同层中的所有参数。目前有两种潜在的解决方案来加快训练过程。

我们来看看第一种方法。 先做一个假设,当后面讨论到更好的方法时我们会取消它。 假设模型不是特别大,我们可以使用现有资源完整地容纳整个模型,而不会遇到内存溢出或磁盘错误的问题。

在这种情况下,我们可以使用一台专用的服务器来存储所有 LeNet 模型参数,并使用多台机器来分担计算工作。图 3.7 显示了对应的架构图。

每个工作节点处理数据集的一部分来计算梯度,然后将结果发送到这台专用的服务器上以更新 LeNet 模型参数。 由于工作节点使用独立的计算资源,因此它们可以异步执行繁重的计算任务,而无需相互通信。 因此,如果忽略节点之间的消息传递的开销,我们仅通过引入额外的工作节点,就可以实现 3 倍的加速。

这种负责存储和更新模型参数的专用单一服务器称为参数服务器(parameter server)。 通过引入参数服务器模式,我们设计了一个更高效的分布式机器学习训练系统。

接下来是现实的挑战。 深度学习模型通常会变得越来越复杂,因为我们可以在基线模型之上添加额外的层和自定义结构。 由于这些附加层中存在大量模型参数,这通常会占用大量的磁盘空间。 并且训练需要大量计算资源来满足内存占用的要求。如果模型很大,以至于我们无法将其所有参数放在单个参数服务器上该怎么办呢?

图 3-7 使用单台参数服务器的机器学习训练组件

Single server 单台服务器 LeNet Model LeNet 模型 Worker node 工作节点

第二种方法可以解决这个问题。 我们可以引入多个参数服务器,然后将模型划分为多个分区,每个参数服务器负责存储和更新模型的一部分分区。 每个工作节点负责处理一部分的数据集,然后更新对应模型分区的参数。

图 3-8 展示了一个使用多个参数服务器的架构图。 该图与图 3-7 的不同在于,图 3-7 中单个服务器存储了所有 LeNet 模型参数,并使用多个工作节点分担计算工作。 而在多参数服务器的架构中,每个工作节点获取数据集中的部分子集数据,执行每个神经网络层所需的计算,然后将计算出的梯度发送到其中一个参数服务器中,更新它所存储的模型分区。 由于所有工作节点都以异步方式进行计算,因此每个工作节点用于计算梯度的模型分区可能不是最新的。 为了保证每个工作节点正在使用的模型分区或每个参数服务器存储的模型分区都是最新的,我们必须不断地在工作节点之间拉取和推送数据来更新模型。

图 3-8 使用多个参数服务器的机器学习训练组件

Parameter server 参数服务器 Push updates 推送更新 Pull updates 拉取更新 Worker node 工作节点

借助参数服务器,我们可以有效地解决构建机器学习模型面临的挑战,从而让模型能够标记新的 YouTube 视频主题。 图 3-9 展示了未用于模型训练的新 YouTube 视频列表,它们被经过训练的机器学习模型标记为飞机(Aircraft)主题。 即使因为模型太大而无法容纳在一台机器上,我们也可以进行模型训练。 值得注意的是,虽然参数服务器模式对这种情况很帮助,但它是专门为训练参数众多的模型而设计的。

图 3.9 未用于模型训练的新 YouTube 视频列表,标有飞机主题(来源:Sudheendra Vijayanarasimhan 等人,遵循 Nonexclusive License 1.0 许可协议)

3.2.3 讨论 上一节介绍了参数服务器模式,并展示了如何使用它来解决 YouTube-8M 视频识别应用中可能遇到的挑战。 尽管当模型太大而无法容纳在一台机器上时,参数服务器模式非常有帮助,且它看起来是解决问题的最直接的方法,但在实际应用中,我们仍然需要考虑一些其它因素使分布式训练系统更高效地运行。

机器学习研究人员和 DevOps 工程师需要努力找到参数服务器和工作节点数量之间的最佳比例,以适应不同的机器学习应用场景。 从工作节点向参数服务器发送计算后的梯度数据的通信成本、拉取和推送更新最新模型分区的成本都非常高。 如果我们发现模型变得越来越大、向系统中添加了过多的参数服务器,系统最终将花费大量时间在节点间通信上,而在神经网络层间进行计算所花费的时间很少。

3.3 节将更详细地讨论这些问题。 本节介绍了一种解决这些问题的模式,这样工程师就不再需要花时间为不同类型的模型调整工作节点和参数服务器的性能。

3.2.4 练习 1 如果我们想在一台笔记本电脑上使用多个 CPU 或 GPU 进行模型训练,这个过程可以被认为是分布式训练吗? 2 增加工作节点或参数服务器的数量会产生什么结果? 3 我们应该为参数服务器分配哪些类型的计算资源(例如:CPU、GPU、内存或磁盘)?我们应该分配多少这些类型的资源?

3.3 集合通信模式 第 3.2.2 节介绍了参数服务器模式,当模型太大而无法容纳在一台机器中时,该模式会派上用场,例如我们需要构建一个模型来标记 800 万个 YouTube 视频的实体类别。 尽管我们可以使用参数服务器来处理具有大量参数的复杂模型,但将这种模式纳入高效分布式训练系统的设计中并不容易。

第 3.2.3 节指出,机器学习研究人员和 DevOps 工程师常常难以确定参数服务器与工作节点数量之间的最佳比例。 假设我们的机器学习系统中有 3 个参数服务器和 3 个工作节点,如图 3-10 所示。 这 3 个工作节点都异步地执行密集计算,然后将计算好的梯度发送到参数服务器以更新不同的模型分区。

图 3-10 由 3 个参数服务器和 3 个工作节点组成的分布式模型训练组件

Push updates 推送更新 Pull updates 拉取更新 Worker node 工作节点

实际上,工作节点和参数服务器并不是一一对应的,特别是当工作节点的数量与参数服务器的数量不同时。 换句话说,多个工作节点可以向一个参数服务器发送更新。 现在假设两个工作节点同时完成了梯度计算,并且他们都想要更新存储在同一参数服务器上的模型参数(如图 3-11 所示)。

图 3-11 两个工作节点完成了梯度计算,并希望同时向第一个参数服务器推送更新

Push updates at time t0 t0 时刻推送更新 Push updates 推送更新 Which one of them can be accepted? 其中哪一个可以被接收? Worker node 工作节点

结果,这两个工作节点相互阻塞,都无法将梯度发送到参数服务器。 也就是说,同一个参数服务器无法同时接受来自两个工作节点的梯度。

3.3.1 问题:当参数服务器成为瓶颈时提高性能 在这种情况下,只有两个工作节点在向同一个参数服务器发送梯度时相互阻塞,这让梯度数据的及时更新变得困难,我们需要一种策略来解决这种阻塞问题。 现实情况下,在集成了参数服务器的分布式训练系统中,不可避免会有多个工作节点同时发送梯度,这引起的通信阻塞问题需要得到解决。

当工作节点与参数服务器的数量比例不理想时,例如,大量工作节点同时向同一个参数服务器发送梯度时,问题变得更加严重。 最终,不同工作节点或参数服务器之间的通信阻塞成为了瓶颈,我们有办法能避免这个问题吗?

3.3.2 解决方案 在这种情况下,两个工作节点需要相互协商谁先执行下一个步骤,然后轮流将计算好的梯度发送到特定的参数服务器。 此外,当一个工作节点完成向参数服务器发送梯度并更新模型参数后,参数服务器开始将更新后的模型分区发送回该工作节点。 因此,工作节点拥有了最新的模型,可以在接收到新数据时进行微调。 如果与此同时,另一个工作节点也向该参数服务器发送计算好的梯度,如图 3-12 所示,则又会发生新的通信阻塞,工作节点之间需要再次进行协商。

但这想要完成这一次协商并不简单,因为尝试发送梯度的工作节点在计算梯度时可能没有使用最新的模型进行计算。 当模型之间的差异很小时,这种情况还勉强能接受,但这最终可能会导致模型的统计性能出现巨大差异。

如果每个参数服务器存储的模型分区不均匀,例如:第一个参数服务器存储了 2/3 的模型参数,如图 3-13 所示,使用这种旧模型分区计算出的梯度将对最终训练的模型产生巨大影响。 在这种情况下,我们希望让这个工作节点丢弃掉计算好的梯度,并让其他工作节点将新的梯度发送到参数服务器中。

现在又出现了另一个挑战。 如果旧模型分区计算出的梯度是根据整个训练数据中的较大部分计算出来的,并且可能需要很长时间才能使用最新的模型分区重新计算它们(如图 3-14 所示)时,该怎么办呢?在这种情况下,我们可能希望保留这些梯度,以免浪费太多时间重新计算。

图 3-12 一个工作节点正在拉取更新的同时,另一个工作节点正在向同一个参数服务器推送更新

Push updates at time t0 t0 时刻推送更新 Pull updates at time t1 t1 时刻拉取更新 Push updates at time t1 t1 时刻推送更新 Push updates 推送更新 Model is from time t0 (outdated). 模型来自 t0 时刻(已过时)。 Worker node 工作节点

图 3-13 一个不平衡的模型分区示例,其中第一个参数服务器包含整个模型参数集的 2/3

PS 1 with ⅔ of model parameters 包含 ⅔ 模型参数的 PS 1 Push updates at time t0 t0 时刻推送更新 Pull updates at time t1 t1 时刻拉取更新 Push updates at time t1 t1 时刻推送更新 Push updates 推送更新 Model is from time t0 (outdated). 模型来自 t0 时刻(已过时)。 Worker node 工作节点

现实情况是,在带有参数服务器的分布式机器学习系统中,我们可能会遇到许多无法完全解决的挑战和问题。 当遇到这些问题时,我们必须考虑协调和权衡的方法。 随着工作节点和参数服务器数量的增加,在工作节点和参数服务器之间拉取和推送模型参数所需的协调和通信成本变得非常重要。 系统最终将花费大量时间在节点之间进行通信,而在神经网络层之间进行计算的时间却很少。

图 3-14 第二个工作节点试图更新从一半训练数据中计算出的梯度

Push updates at time t0 t0 时刻推送更新 Pull updates at time t1 t1 时刻拉取更新 Push updates at time t1 t1 时刻推送更新 Push updates 推送更新 Model is from time t0 (outdated). 模型来自 t0 时刻(已过时)。 Worker node 工作节点 ¼ training data 1/4 训练数据

尽管我们可能在将不同的参数服务器与工作节点的比例和计算资源应用到我们的系统中时,有许多丰富的经验,但将系统调整到一个完美的状态会非常耗费时间。 在某些情况下,某些工作节点或参数在训练期间发生故障,又或者网络不稳定时,在节点之间用推送和拉取更新进行通信时会出现问题。 也就是说,由于我们缺乏专业知识或时间来处理底层分布式基础设施,参数服务器模式在某些场景下可能并不适用。

那么有没有解决这个问题的替代方案呢? 参数服务器模式可能是大型模型为数不多的选择之一,但为了简化且方便演示,我们假设模型大小是固定的。 并且整个模型足够小,可以容纳在一台机器上。也就意味着每台机器都有足够的磁盘空间来存储整个模型。

考虑到这一假设,如果我们想提高分布式训练的性能,那么参数服务器的替代方案是什么呢?如果没有参数服务器,只有工作节点,每个节点都存储了整套模型参数的副本,如图 3-15 所示。

图 3-15 只包含工作节点的分布式模型训练组件,每个工作节点都存储整套模型参数的副本,并使用数据分区来计算梯度。

Each of these workers contains a copy of the entire set of model parameters and consumes partitions of data to calculate the gradients. 每个工作节点都包含完整的模型参数副本,并使用数据分区来计算梯度。 Data partitions 数据分区 Consumes data partition 消费数据分区 Worker 1 工作节点 1

这种情况下我们应该如何进行模型训练呢?回想一下,每个工作节点都会消费一部分数据并计算梯度,这些梯度用于更新存储在该节点上的模型参数。 当所有节点都完成梯度计算后,我们需要聚合所有梯度,并确保每个节点的整套模型参数都根据聚合后的梯度进行更新。 这样每个节点都存储了一份相同的、更新后的模型副本。那么我们怎样聚合所有梯度呢?

我们已经熟悉将梯度从一个节点发送到另一个节点的过程,一般来说,该过程没有其他节点的参与,称为点对点通信(Point-to-Point Communication)(图 3-16)。

图 3-16 两个节点之间点对点通信传输数据的示例,这其中没有其他节点参与。

Process 1 进程 1 Process 2 进程 2 Data transfers between the two processes 2 个进程之间的数据传输

在这种情况下,点对点通信的效率有些低下。 只有工作节点参与,我们需要对所有工作节点的结果进行某种形式的聚合。 幸运的是,我们可以使用另一种通信方式——集合通信(Collective Communication)。 集合通信允许在一组进程间相互通信,该组由所有进程的子集组成。图 3-17 展示了一个进程与由其他 3 个进程组成的一组进程之间的集合通信。 在这种情况下,每个工作节点都计算好梯度,并希望将它们发送到其余的工作节点中,使得所有工作节点都能获得其余工作节点的计算结果。

图 3-17 1 个进程与其他 3 个进程所组成的组之间的集合通信示例

Process 1 进程 1 Process 2 进程 2 Process 3 进程 3 Process 4 进程 4 Data transfer 数据传输 Group 进程组

我们通常需要对工作节点所接收到的梯度执行某种聚合操作,然后将聚合结果发送给其他所有工作节点。 这种聚合操作称为 reduce,它涉及将大量数字聚合成少量数字。 其作用包括了求出一组数字的总和、最大值、最小值或平均值。在我们的例子中,它的作用是从所有工作节点接收梯度。

图 3-18 展示了一个 reduce 操作,进程组中的每个进程中的向量 v0、v1 和 v2 通过 reduce 操作与第一个进程合并

当使用分布式方式降低梯度时,我们将下降后梯度发送给所有工作节点,以便它们能够同步并以相同的方式更新模型参数,确保他们拥有完全相同的模型。 这种广播操作称为 broadcast,常用于集合通信。 图 3-19 展示了向进程组中的每个进程发送数据的 broadcast 操作

这里我们将 reduce 和 broadcast 操作的组合称为 allreduce,它根据指定的 reduce 函数对结果进行约简,然后将约简后的结果分发到所有进程中。在我们的例子中,结果被分发给所有工作节点,这样每个工作节点上存储的模型完全相同且是最新的(如图 3-20 所示)。 当我们完成一轮 allreduce 操作后,我们继续下一轮操作:将新数据提供给更新后的模型,计算梯度,然后再次执行 allreduce 操作,收集来自工作节点的所有梯度来更新模型。

图 3-18 以求和作为 reduce 函数的 reduce 操作示例

Process 1 进程 1 Process 2 进程 2 Process 3 进程 3 Process 4 进程 4 Data transfer 数据传输 Group 进程组

图 3-19 向进程组中的每个进程发送数据的 broadcast 操作示例

Process 1 进程 1 Process 2 进程 2 Process 3 进程 3 Process 4 进程 4 Data transfer 数据传输 Group 进程组

现在我们成功地使用了集合通信模式,该模式利用底层网络基础设施来执行 allreduce 操作,用于在多个工作节点之间传递梯度,使我们能够以分布式方式训练中等规模的机器学习模型。 这样我们就不需要使用到参数服务器了,也就没有了参数服务器和工作节点之间的通信开销。 集合通信模式在机器学习系统和分布式并行计算系统中非常有用,它所具有的并发特性应用于并行计算,而 broadcast 和 reduce 等通信原语对于节点间的通信至关重要。我们将在第 9.2.2 节中应用此模式。

图 3-20 allreduce 操作的示例,该操作约简了组内每个进程产生的结果,然后将结果发送到组内的每个进程中

Process 2 进程 2 Process 3 进程 3 Process 4 进程 4 Data transfer 数据传输 Group 进程组

3.3.3 讨论 当我们构建的模型规模不大时,集合通信模式是参数服务器的一个很好的替代方案。 这样一来,参数服务器和工作节点之间就没有了通信开销,也就不再需要花费大量的精力来调整工作节点和参数服务器之间的比例。换句话说,我们可以轻松地通过增加工作节点的数量来加快模型训练过程,而不必担心性能衰减。

还有一个潜在的问题值得一提。 在我们通过 allreduce 操作引入集合通信模式后,每个工作节点将需要与其他所有工作节点进行通信,如果工作节点数量变大,可能会拖慢整个训练过程。 事实上,集合通信依赖于网络基础设施上的通信,而我们在 allreduce 操作中还没有充分利用到这方面的优势。

好消息是我们可以使用更好的集合通信算法来更高效地更新模型。 例如:使用 ring-allreduce 算法。这个过程与 allreduce 操作类似,但数据是环形传输的,没有 reduce 操作。 集群中有 N 个节点,其中每个工作节点仅需要与它相邻的两个节点通信 2 * (N – 1) 次,就能完成所有模型参数的更新。 换句话说,该算法是带宽最优的;如果聚合的梯度足够大,它将能充分地利用底层网络基础设施的优势。 参数服务器模式和集合通信模式都能够使分布式训练变得可扩展和高效。 然而,在实践中,任何工作节点或参数服务器都可能因为资源不足而无法启动,或者在分布式训练过程中出现故障。 第 3.4 节将介绍一些能够应对这些异常情况的模式,使得整个分布式训练过程更加可靠。

3.3.4 练习 1 通信阻塞是否只发生在工作节点之间? 2 工作节点更新各自模型参数的过程是异步的还是同步的? 3 使用哪些集合通信操作的组合能够表示一个 allreduce 操作?

3.4 弹性与容错模式 参数服务器和集合通信模式使我们能够扩展分布式模型训练过程。参数服务器对于处理无法容纳在单台机器上的大型模型非常有用;大型模型可以分区并存储在多个参数服务器上,而各个工作节点可以执行繁重的计算并异步更新各个模型分区的参数。 而当我们在使用参数服务器时发现通信开销过大时,可以使用集合通信模式来加快中等规模模型的训练过程。

假设我们的分布式训练组件设计合理,能够高效地训练机器学习模型;并且能够使用参数服务器和集合通信等模式来满足不同类型模型的需求。 值得一提的是,分布式模型训练是一项需要长期运行的任务,通常会持续数小时、数天甚至数周。 与所有其他类型的软件和系统一样,由于模型训练是一个长期运行的过程,随时可能被内部或外部的干预所影响。 以下是一些分布式模型训练系统中经常出现的被干预的示例:  部分数据集已损坏,无法用于正常的模型训练。  分布式训练模型所依赖的集群可能因为天气状况或人为错误而出现网络不稳定或断线的情况。  部分参数服务器或工作节点被抢占,它们所依赖的计算资源被重新分配给了具有更高优先级的任务和节点。

3.4.1 问题:使用有限的计算资源进行训练时的故障处理 当系统出现不符合预期的异常时,如果不采取措施加以解决,问题就会开始累积。 在上一节的第一个示例中,所有工作节点都使用相同的逻辑来消费数据,当它们的训练代码无法处理的损坏数据时,任务失败。 在第二个示例中,当网络变得不稳定时,参数服务器和工作节点之间的通信将被挂起,直到网络恢复。 在第三个示例中,当参数服务器或工作节点被抢占时,整个训练过程被迫中断,导致不可恢复的故障。 在这些情况下我们应该如何让分布式训练系统恢复呢?我们有办法预防这些意外故障吗?

3.4.2 解决方案 先看第一种情况。假设训练过程遇到了一批损坏的数据。 在图 3-21 中,YouTube-8M 数据集中的一些视频在从原始数据源下载后被第三方视频编辑软件意外修改。 第一个工作节点尝试读取这些数据来提供给模型进行训练。此时发现,之前初始化的机器学习模型无法处理被编辑过的、不兼容的视频数据。

图 3-21 工作节点无法成功使用被编辑过的一批新训练数据

Push updates 推送更新 Pull updates 拉取更新 Model is from time t0 (outdated). 模型来自 t0 时刻(已过时)。 Worker node 工作节点 Exception: Unable to read the data 异常:无法读取数据 Video data being edited 被修改过的视频数据 ⅓ training data(corrupted) 1/3 训练数据(损坏) ⅓ training data 1/3 训练数据

当这种情况发生时,训练过程会意外失败,因为现有代码不包含处理编辑过或损坏的数据集的逻辑。 因此我们需要修改分布式模型训练逻辑来处理这种情况,然后从头开始重新训练模型。

现在让我们重新开始分布式训练过程,看看一切是否正常。 我们可以跳过损坏的数据批次,继续使用后续批次的剩余数据来训练机器学习模型。

不巧,在使用一半的数据对模型训练了数小时后,我们发现新批次数据的消耗速度比以前慢了很多。 经过一番排查并与 DevOps 团队沟通后,我们发现由于我们的一个数据中心迎来了暴风雨,网络变得极不稳定(前面提到的第二种情况)。 如果我们的数据集保存在远程服务器上,而没有下载到本地,如图 3-22 所示,训练过程将阻塞在等待与远程数据库成功连接的状态。 在等待期间,我们应该对当前训练的模型参数进行检查点存档(checkpoint)并暂停训练。 等到网络稳定后,就可以轻松地恢复训练。

图 3-22 工作节点在从远程数据库获取数据时遇到网络不稳定的情况

Push updates 推送更新 Pull updates 拉取更新 Model is from time t0 (outdated). 模型来自 t0 时刻(已过时)。 Worker node 工作节点 Exception: Failed to connect to the database 异常:连接数据库失败 Unstable network 网络不稳定 ⅓ training data(in remote database) 1/3 训练数据(远程数据库中)

网络不稳定是否还有其他影响呢?我们还忽略了一个事实:我们还依赖网络在工作节点和参数服务器节点之间进行通信,以发送计算出的梯度并更新模型参数。 但如果采用集合通信模式,整个训练的过程是同步的。 也就是说,一个工作节点会阻塞集群中其他节点的通信。我们需要从所有节点中获取到梯度,才能够将结果聚合来更新模型参数。 如果有一个工作节点通信变慢,经过连锁反应最终会阻塞整个任务的训练。

在图 3-23 中,同一进程组中的 3 个工作进程正在执行 allreduce 操作。 由于集群的网络不稳定,其中两个进程的通信变慢。 结果,两个通信慢的进程没有及时接收到数据(用问号表示),整个 allreduce 操作被阻塞,直到它们成功接收到了所有数据。

图 3-23 由于网络不稳定阻塞了整个 allreduce 训练过程

Process 2 进程 2 Process 3 进程 3 Process 4 进程 4 Data transfer 数据传输 Group 进程组 Slow communication 通信缓慢

我们是否可以采取一些措施,使整个训练过程不会受到个别节点网络性能下降的影响? 在这种情况下,首先我们可以考虑忽略这两个网络连接不稳定的工作进程,然后跳过这一次 allreduce 操作。 考虑到集合通信模式的特点,剩余的工作节点仍然拥有完全相同的模型副本,因此我们可以通过重建一个由剩余工作节点组成的新工作节点组,再次执行 allreduce 操作来继续训练。

这种方法还可以处理某些工作节点被抢占的情况(例如:它们的计算资源被重新分配给更高优先级的任务和节点)。 当这些工作节点被抢占时,我们重新构建工作节点组,然后执行 allreduce 操作。 这种方法能够避免在发生意外故障时从头开始训练模型,浪费了大量资源。 相反,我们可以从之前因故障发生暂停的地方开始,使用已经分配了计算资源的工作节点继续训练。 如果有额外的资源,我们可以添加工作节点,然后重建工作节点组以更高效地进行训练。 现在,我们可以轻松地扩缩分布式训练系统的规模,使整个系统在资源上具有弹性。 许多其他分布式系统也采用了相同的理念,来确保所系统的可靠性和可扩展性。

3.4.3 讨论 我们成功地实现了在集合通信模式下分布式训练的故障恢复,避免了工作节点计算资源的浪费。 如果我们的分布式训练使用的是参数服务器,而不是仅有工作节点的集合通信模式会怎么样呢?

回想一下,当使用参数服务器时,每个参数服务器存储了包含一部分模型参数的模型分区。 如果我们需要移除一些工作节点或参数服务器,例如当某个参数服务器上的网络不稳定导致某些节点通信失败、阻塞时,又或者工作节点的计算资源被抢占时,我们需要对失败节点中的模型分区进行检查点存档,然后将模型分区重新分配给剩余健康的参数服务器。

实际上,这里还存在许多挑战。 例如:我们如何对模型分区进行检查点存档,以及应该将它们保存在哪里?我们应该多久进行一次检查点存档,以确保它们的数据是最新的?

3.4.4 练习 1 为防止发生故障,在检查点中最需要保存的东西是什么? 2 当我们移除了那些因阻塞、无法恢复而没有来得及对模型做检查点存档的工作节点后,假设我们使用的是集合通信模式,我们应该从哪里获取最新的模型?

3.5 习题答案 第3.2.4节 1 不可以,因为训练发生在单台笔记本电脑上。 2 系统最终会花费大量时间在节点间通信上,而在神经网络层之间的计算上花费的时间很少。 3 我们需要更多的磁盘空间供参数服务器存储大型模型分区,但不需要过多的 CPU/GPU/内存资源,因为参数服务器不涉及大量运算。

第3.3.4节 1 不是,它们也出现在工作节点和参数服务器之间。 2 异步地。 3 reduce 和 broadcast 操作。

第3.4.4节 1 保存最新的模型参数 2 在集合通信模式下,剩余的工作节点仍然拥有相同的模型副本,我们可以用它来继续训练。

总结  考虑到数据集的大小和位置、模型的大小、计算资源和底层网络基础设施等因素,分布式模型训练不同于传统的模型训练。  我们可以使用参数服务器来构建大型且复杂的模型,在每个服务器上存储模型参数的分区。  如果工作节点和参数服务器之间的通信出现瓶颈,我们可以切换到集合通信模式来提高中小型模型的分布式模型训练性能。  分布式模型训练过程中如果发生意外故障,我们可以采取多种方法来避免计算资源的浪费。