
What is multi-task learning

Multi-task learning trains one model on several related tasks at the same time, sharing parameters (typically a common backbone network) across the tasks. Auxiliary tasks can make a model more robust on the test set and out of domain, and jointly trained related tasks can each benefit from the shared representations.

Overview of the PALM multi-task learning framework
Component layer: PALM provides six decoupled components for implementing NLP tasks. Each component ships a rich set of predefined classes and one base class: the predefined classes target typical NLP tasks, while the base class helps users customize the component for their own needs.

Trainer layer: builds the computation graph from the selected components and runs training and inference. This layer covers the training strategy, model saving and loading, evaluation, and prediction. One trainer handles exactly one task.

High-level trainer layer: implements complex training and inference strategies such as multi-task learning, either adding auxiliary tasks to train a more robust NLP model (improving its test-set and out-of-domain performance), or jointly training several related tasks so that each of them reaches higher performance.
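At the multi-task level, training typically interleaves mini-batches drawn from the different tasks in proportion to a per-task weight (the mix_ratio parameter seen in the example further below). The following is a minimal, illustrative sketch of such proportional task sampling; it is not PALM's actual scheduler, and the task names and weights are made up for the example.

```python
import random

def sample_task(tasks, rng):
    """Pick the next task to draw a batch from, with probability
    proportional to each task's mix_ratio weight."""
    names = [name for name, _ in tasks]
    weights = [w for _, w in tasks]
    r = rng.random() * sum(weights)
    acc = 0.0
    for name, w in zip(names, weights):
        acc += w
        if r < acc:
            return name
    return names[-1]  # guard against floating-point edge cases

# Two tasks with equal weight, as in the slot/intent example below
tasks = [("slot", 1.0), ("intent", 1.0)]
rng = random.Random(0)
schedule = [sample_task(tasks, rng) for _ in range(10)]
print(schedule)
```

With equal weights, batches from both tasks are drawn about equally often over a long run, which matches the interleaved per-task steps visible in the training log later in this document.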



How to use PALM?
1. Install PALM
```shell
pip install paddlepalm
# or install from source:
# git clone https://github.com/PaddlePaddle/PALM.git
```
```python
>>> from paddlepalm import downloader
>>> downloader.ls('pretrain')
Available pretrain items:
  => RoBERTa-zh-base
  => RoBERTa-zh-large
  => ERNIE-v2-en-base
  => ERNIE-v2-en-large
  => XLNet-cased-base
  => XLNet-cased-large
  => ERNIE-v1-zh-base
  => ERNIE-v1-zh-base-max-len-512
  => BERT-en-uncased-large-whole-word-masking
  => BERT-en-cased-large-whole-word-masking
  => BERT-en-uncased-base
  => BERT-en-uncased-large
  => BERT-en-cased-base
  => BERT-en-cased-large
  => BERT-multilingual-uncased-base
  => BERT-multilingual-cased-base
  => BERT-zh-base
>>> downloader.download('pretrain', 'BERT-en-uncased-base', './pretrain_models')
```
2. Write your code following the example below
```python
import paddlepalm as palm

# Create the dataset readers / preprocessors
seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, seed=random_seed)
cls_reader = palm.reader.ClassifyReader(vocab_path, max_seqlen, seed=random_seed)

# Load the training data
seq_label_reader.load_data(train_slot, file_format='tsv', num_epochs=None, batch_size=batch_size)
cls_reader.load_data(train_intent, batch_size=batch_size, num_epochs=None)

# Create the backbone network that extracts text features
ernie = palm.backbone.ERNIE.from_config(config)

# Register the readers with the ERNIE backbone
seq_label_reader.register_with(ernie)
cls_reader.register_with(ernie)

# Create the per-task output heads
seq_label_head = palm.head.SequenceLabel(num_classes, input_dim, dropout_prob)
cls_head = palm.head.Classify(num_classes_intent, input_dim, dropout_prob)

# Create the per-task trainers and the multi-task trainer
trainer_seq_label = palm.Trainer("slot", mix_ratio=1.0)
trainer_cls = palm.Trainer("intent", mix_ratio=1.0)
trainer = palm.MultiHeadTrainer([trainer_seq_label, trainer_cls])

# Build the forward graph from the backbone and the task heads
loss1 = trainer_cls.build_forward(ernie, cls_head)
loss2 = trainer_seq_label.build_forward(ernie, seq_label_head)
loss_var = trainer.build_forward()

# Enable a warmup schedule for better fine-tuning results
n_steps = seq_label_reader.num_examples * 1.5 * num_epochs
warmup_steps = int(0.1 * n_steps)
sched = palm.lr_sched.TriangularSchedualer(warmup_steps, n_steps)

# Build the optimizer
adam = palm.optimizer.Adam(loss_var, lr, sched)

# Build the backward graph
trainer.build_backward(optimizer=adam, weight_decay=weight_decay)

# Hand the prepared readers and data to the trainer
trainer.fit_readers_with_mixratio([seq_label_reader, cls_reader], "slot", num_epochs)

# Load the pretrained model
trainer.load_pretrain('./pretrain/ERNIE-v2-en-base')

# Save checkpoints during training
trainer.set_saver(save_path='./outputs/', save_steps=300)

# Start training
trainer.train(print_steps=10)
```
During training you should see interleaved log lines from the two tasks:

```
global step: 5, slot: step 3/309 (epoch 0), loss: 68.965, speed: 0.58 steps/s
global step: 10, intent: step 3/311 (epoch 0), loss: 3.407, speed: 8.76 steps/s
global step: 15, slot: step 12/309 (epoch 0), loss: 54.611, speed: 1.21 steps/s
global step: 20, intent: step 7/311 (epoch 0), loss: 3.487, speed: 10.28 steps/s
```
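The TriangularSchedualer used in the example warms the learning rate up over the first 10% of steps and then decays it back down. The following is a standalone sketch of that warmup-then-linear-decay shape under the usual assumptions (linear ramp up to the base rate, then linear decay to zero); it is not PALM's internal implementation.

```python
def triangular_lr(step, warmup_steps, total_steps, base_lr):
    """Triangular schedule: lr rises linearly from 0 to base_lr over
    warmup_steps, then decays linearly back to 0 at total_steps."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    return base_lr * max(0.0, (total_steps - step) / (total_steps - warmup_steps))

# e.g. 1000 total steps with 10% warmup, as in the snippet above
total_steps, warmup_steps, base_lr = 1000, 100, 2e-5
print(triangular_lr(50, warmup_steps, total_steps, base_lr))   # mid-warmup: half of base_lr
print(triangular_lr(100, warmup_steps, total_steps, base_lr))  # peak: base_lr
```

Note that the example computes `n_steps` from the slot reader only; since `num_epochs=None` makes both readers cycle indefinitely, the schedule length is what actually ends training.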
More examples
Beyond the example above, PaddlePALM has also been used to reproduce D-Net, the champion solution of the EMNLP 2019 MRQA competition. With PaddlePALM, adding a Masked Language Model objective and a passage-ranking auxiliary task to a machine reading comprehension model becomes very easy.