杜伟 陈萍作者

无性能损失,不用更改代码,Lightning 1.1版本发布,切分训练新功能节省50%以上内存

 1.0.0 版本推出不到两个月的时间,grid.ai CEO、纽约大学博士 William Falcon 创建的 PyTorch Lightning 于近日宣布推出 1.1 版本。新版本新增了 sharded training 功能,在多 GPU 上训练深度学习(DL)模型时可以节省 50% 以上的内存,并且没有性能损失,也不需要更改代码。

与 Facebook Research 的 FairScale 团队一道,PyTorch Lightning 团队在 1.1 版本中推出了 Sharded Training beta 版。在下面的博客文章中,PyTorch Lightning 团队的研究工程师 Sean Narenthiran 展示了只需要在 Lightning 训练器中增加单一的 flag,则在多 GPU 上训练 DL 模型时就会实现内存的大幅度降低。

此外,作者还介绍了如何使用 NeMo 来预训练 Transformer LM,并实现 55% 的内存提升,以及训练其他 PyTorch Lightning 赋能模型时进一步的内存降低。除了给出使用 NeMO Transformer LM 时 NLP 中的结果之外,作者还分别展示了语音识别中使用 DeepSpeech 2 以及计算机视觉中训练 SwAV ResNet 和 iGPT 的结果。

PyTorch Lightning 团队正努力增添新的模型并行化技术并保证鲁棒性,并且与 FairScale 团队展开合作提升所有 PyTorch Lightning 研究中的模型扩展性能。

更多使用技巧参考:https://pytorch-lightning.readthedocs.io/en/stable/multi_gpu.html#multi-gpu-training

更大的模型,更高的准确率

语言建模趋向于更大的预训练模型,这种模型在下游任务中表现得更好。OpenAI 的 GPT-3 就是一个很好的例子,该模型有 1750 亿个参数,在训练时需要大量的计算与优化技巧。

比较了语言模型参数随时间变化的曲线,GPT-3 继续在规模上超越。(图源:Microsoft)

训练大型模型时,内存很宝贵。当扩展模型大小时,GPU 的内存会被耗尽,而这又限制了训练模型的大小。这使得团队不断尝试更智能的内存管理技术。

Lightning 的 Sharded Training

传统分布式训练 VS Sharded Training。参数(P)在 GPU 之间拆分,以减少每个 GPU 的内存开销。Sharded Training 则拆分了优化器状态和梯度。

受微软 Zero Redundancy Optimizer (ZeRO) 的启发,Sharded Training 可以减少在多个 GPU 上训练大型模型所需的内存,训练过程中在 GPU 之间「切分」模型。Sharding 包括将参数分割到不同的设备上,减少每个设备所需的内存。特别地,优化器状态和梯度可以独立于模型进行切分,并且可以减少所有架构所需的内存。

Sharded Training 是在 FairScale 基础上构建的,与 PyTorch 兼容并得到优化。FairScale 是一个 PyTorch 扩展库,用于高性能以及大规模训练模型和数据并行。除了切分技术之外,它具有层间和层内并行性以及跨多个 GPU 和主机拆分模型。

通过在 GPU 上进行智能梯度和优化器状态 sharding,可以分别将内存成本(基于微软论文《ZeRO: Memory Optimizations Toward Training Trillion Parameter Models》的数据)降低大约 4 倍和 8 倍。这有利于所有模型,在所有模型架构以及训练过程中提供较低的内存使用率。需要注意的是,由于节点之间所需通信量的增加以及缺乏并行性,「naive implementations」导致运行速度急剧下降。

通过与 FairScale 的紧密合作,现在可以在所有 lightning 模块上实现 55% 以上的内存减少,只需通过一个单一的 flag,这意味着更大的机型可以适应内存有限的多张 GPU。

在不更改代码的情况下启用 Sharded Training

为了展示在 Lightning 中使用 Sharded Training 有多简单,使用 NVIDIA 的一个流行库 NeMo 来训练 Lightning 支持的对话 AI 模型。使用 NeMo 中提供的 vanilla Transformer LM 模型,有 12 亿个参数,该模型对训练内存要求很高。在训练大型语言模型时,内存是提高模型大小或提升 GPU 饱和度的宝贵资源。此外使用 WikiText 数据集训练模型。

首先下载数据集并使用 NVIDIA NeMo 提供的处理脚本进行提取,然后在 NeMo 中找到预配置文件定义模型配置,修改数据输入指向自定义数据集。为了进行基准测试,还构建了一个简单的基于单词的词汇表。

import osfrom omegaconf import OmegaConf# Build a simple word based vocabulary for benchmarking purposeswith open('wikitext-2/train.txt') as f:    vocab = set(f.read().split())with open('vocab.txt', 'w') as f:    f.write('\n'.join(vocab))# Define the model configuration using the preset configuration file found within NeMoconfig_path = "./examples/nlp/language_modeling/conf/transformer_lm_config.yaml"config = OmegaConf.load(config_path)config.model.language_model.vocab_file = 'vocab.txt'config.model.train_ds.file_name = os.path.join('wikitext-2/train.txt')config.model.validation_ds.file_name = os.path.join('wikitext-2/valid.txt')

在设置模型参数之后,用户只需要将 Sharded 插件 flag 传递给支持 Sharded Traing 的训练器就可以了。用户还可以通过增加 GPU 数量和启用本地混合精度(native mixed precision)来实现内存和速度的进一步提升。分区优化器和 GPU 之间的通信可以在后台自动处理。

import pytorch_lightning as plfrom nemo.collections import nlp as nemo_nlp# Set model parameters (roughly 1.2 billion parameters)config.model.train_ds.batch_size = 8  # Reduce batch size for training large modelconfig.model.language_model.hidden_size = 3072config.model.language_model.inner_size = 3072config.model.language_model.num_layers = 22# Use 8 GPUs, and enable Mixed Precision + Sharded Trainingtrainer = pl.Trainer(    gpus=8,    precision=16,    max_epochs=50,    accelerator='ddp',    plugins='ddp_sharded')model = nemo_nlp.models.TransformerLMModel(cfg=config.model, trainer=trainer)

下面介绍了使用 Lightning 内置 Sharding 与普通 GPU 扩展时每台设备的内存提升情况,每台设备的内存分配保持不变。不仅如此,Lightning 团队还给出了 SwAW、DeepSpeech 2 和 iGPT 等其他 PyTorch Lightning 支持模型的测试结果。

结果表明,每个 GPU 上最高节省内存 15GiB,从而可以增加模型能力。例如,在硬件一样的情况下,训练 Transformer LM 时模型参数量可以从 12 亿增至 20 亿。

使用 8 个 A100s 时训练 Transformer LM、SwAV Wide ResNet、DeepSpeech2 和 iGPT 时的平均峰值内存比较。

随着 GPU 之间通信的优化,与标准分布式加速器相比,节点内性能的扩展效果更好。请注意,随着向很多节点的扩展,内存降低的效果开始减弱,这是因为其他因素成为了瓶颈。但是,Sharded training 依然带来良好的 throughout 扩展。

在 8 个具有相同超参数和批大小的 A100s 上的平均 Epoch time 比较,越低越好。

博客地址:https://seannaren.medium.com/introducing-pytorch-lightning-sharded-train-sota-models-with-half-the-memory-7bcc8b4484f

Powered by Froala Editor

工程PyTorch LightningPyTorch
暂无评论
暂无评论~