Remember Lightning, the framework that looks like a lightweight, Keras-style PyTorch? It has finally hit version 1.0.0, adding a host of new features and improvements across metrics, optimization, logging, data flow, checkpointing, and more.
Blog post: https://medium.com/pytorch/pytorch-lightning-1-0-from-0-600k-80fc65e2fab0
GitHub: https://github.com/PyTorchLightning/pytorch-lightning
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def __init__(self):
        ...
        self.train_acc = pl.metrics.Accuracy()
        self.valid_acc = pl.metrics.Accuracy()

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        ...
        self.train_acc(logits, y)
        # log step metric
        self.log('train_acc_step', self.train_acc)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        ...
        self.valid_acc(logits, y)
        # logs epoch metrics
        self.log('valid_acc', self.valid_acc)
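The same metrics are also exposed as plain functions in pytorch_lightning.metrics.functional, which is handy outside a LightningModule. A minimal sketch, assuming the predictions are already class labels:

import torch
from pytorch_lightning.metrics.functional import accuracy

preds = torch.tensor([0, 1, 1, 0])
target = torch.tensor([0, 1, 0, 0])
print(accuracy(preds, target))  # 3 of 4 predictions match -> 0.75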
import torch
from pytorch_lightning.metrics import Metric

class MyAccuracy(Metric):
    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        # _input_format is a user-defined helper, e.g. argmax logits into labels
        preds, target = self._input_format(preds, target)
        assert preds.shape == target.shape
        self.correct += torch.sum(preds == target)
        self.total += target.numel()

    def compute(self):
        return self.correct.float() / self.total
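Calling a Metric updates its internal state, and compute() aggregates over every update so far (syncing across processes when needed). A minimal sketch of that accumulation, using the built-in Accuracy:

import torch
import pytorch_lightning as pl

acc = pl.metrics.Accuracy()
acc(torch.tensor([0, 1, 1]), torch.tensor([0, 1, 0]))  # batch 1: 2/3 correct
acc(torch.tensor([1, 1]), torch.tensor([1, 1]))        # batch 2: 2/2 correct
print(acc.compute())                                   # 4/5 = 0.8 over both batches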
def training_step(self, batch, batch_idx):
    # with automatic optimization, just return the loss;
    # Lightning handles backward(), step() and zero_grad()
    loss = self.encoder(batch[0])
    return loss
trainer = Trainer(automatic_optimization=False)
def training_step(self, batch, batch_idx, opt_idx):
    (opt_a, opt_b, opt_c) = self.optimizers()

    loss_a = self.generator(batch[0])

    # use this instead of loss.backward so we can automate half
    # precision, etc...
    self.manual_backward(loss_a, opt_a, retain_graph=True)
    self.manual_backward(loss_a, opt_a)
    opt_a.step()
    opt_a.zero_grad()

    loss_b = self.discriminator(batch[0])
    self.manual_backward(loss_b, opt_b)
    ...
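The optimizers unpacked above still come from configure_optimizers(). A rough sketch returning three optimizers to match the example; the sub-modules and learning rates here are assumptions, not part of the original post:

def configure_optimizers(self):
    # hypothetical: one optimizer per sub-module, matching (opt_a, opt_b, opt_c)
    opt_a = torch.optim.Adam(self.generator.parameters(), lr=1e-4)
    opt_b = torch.optim.Adam(self.discriminator.parameters(), lr=1e-4)
    opt_c = torch.optim.Adam(self.encoder.parameters(), lr=1e-4)
    return opt_a, opt_b, opt_c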
def training_step(self, batch, batch_idx):
    # x can be any scalar value or metric computed in this step
    self.log('my_metric', x)
def training_step(self, batch, batch_idx):
    # on_step/on_epoch control when the value is logged,
    # prog_bar/logger control where it is shown
    self.log('my_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
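Several values can be logged at once through log_dict, part of the same logging API; a minimal sketch, with the logged names as placeholders:

def training_step(self, batch, batch_idx):
    ...
    # hypothetical values computed earlier in the step
    values = {'loss': loss, 'acc': acc}
    self.log_dict(values)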
x_step
x_step_end
x_epoch_end

where x is one of training, validation, or test.
outs = []
for batch in data:
    out = training_step(batch)
    outs.append(out)
training_epoch_end(outs)
def training_step(self, batch, batch_idx):
    prediction = ...
    return {'loss': loss, 'preds': prediction}

def training_epoch_end(self, training_step_outputs):
    for out in training_step_outputs:
        prediction = out['preds']
        # do something with these
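The matching x_step_end hook exists for multi-GPU backends such as dp or ddp2, where each training_step only sees a sub-batch on one device. A rough sketch of that pattern; F is assumed to be torch.nn.functional, and the dictionary keys are placeholders:

def training_step(self, batch, batch_idx):
    x, y = batch
    # each GPU processes only its slice of the batch
    return {'logits': self(x), 'y': y}

def training_step_end(self, outputs):
    # outputs gathers results from all GPUs; compute the loss on the full batch
    return F.cross_entropy(outputs['logits'], outputs['y'])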
1. Compute any metric or other quantity you want to monitor, such as the validation loss;
2. Log it with the log() method, under a key such as val_loss;
3. Initialize the ModelCheckpoint callback and set monitor to that key;
4. Pass the callback to the checkpoint_callback Trainer flag.
import torch.nn.functional as F
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

class LitAutoEncoder(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.backbone(x)
        # 1. calculate loss
        loss = F.cross_entropy(y_hat, y)
        # 2. log `val_loss`
        self.log('val_loss', loss)

# 3. Init ModelCheckpoint callback, monitoring 'val_loss'
checkpoint_callback = ModelCheckpoint(monitor='val_loss')

# 4. Pass your callback to checkpoint_callback trainer flag
trainer = Trainer(checkpoint_callback=checkpoint_callback)
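After training, the callback records where the best checkpoint was written, and the model can be restored from it. A minimal sketch, assuming a model instance and datamodule named model and dm:

trainer.fit(model, datamodule=dm)
print(checkpoint_callback.best_model_path)  # path to the best 'val_loss' checkpoint

# restore the best weights later
model = LitAutoEncoder.load_from_checkpoint(checkpoint_callback.best_model_path)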