Skip to content

Checkpoint 机制及使用介绍

在深度学习的实际场景中,模型训练一般都会持续一段时间,这对分布式训练任务的稳定性和效率提出了更高的要求。 而且,在实际训练的过程中,异常中断会导致训练过程中的模型状态丢失,需要重新开始训练, 这不仅浪费了时间和资源,这在 LLM 训练中尤为明显,而且也会影响模型的训练效果。

能够在训练过程中保存模型的状态,以便在训练过程中出现异常时能够恢复模型状态,变得至关重要。 Checkpoint 就是目前主流的解决方案,本文将介绍 Checkpoint 机制的基本概念和在 PyTorch 和 TensorFlow 中的使用方法。

什么是 Checkpoint?

Checkpoint 是在模型训练过程中保存模型状态的机制。通过定期保存 Checkpoint,可以在以下情况下恢复模型:

  • 训练过程中断(如系统崩溃或手动中断)
  • 需要在某个训练阶段进行评估
  • 希望在不同的实验中复用模型

PyTorch

在 PyTorch 中,torch.savetorch.load 是用于保存和加载模型的基本函数。

PyTorch 保存 Checkpoint

在 PyTorch 中,通常使用 state_dict 保存模型的参数。以下是一个简单的示例:

import torch
import torch.nn as nn

# 假设我们有一个简单的神经网络
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

# 初始化模型和优化器
model = SimpleModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 训练模型...
# 保存 Checkpoint
checkpoint_path = 'model_checkpoint.pth'
torch.save({
    'epoch': 10,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': 0.02,
}, checkpoint_path)

PyTorch 恢复 Checkpoint

加载模型时,需要恢复模型参数和优化器状态,并继续训练或推理:

# 恢复 Checkpoint
checkpoint = torch.load('model_checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

# 继续训练或推理...
  • model_state_dict: 模型参数
  • optimizer_state_dict: 优化器状态
  • epoch: 当前训练轮数
  • loss: 损失值
  • learning_rate: 学习率
  • best_accuracy: 最佳准确率

TensorFlow

TensorFlow 提供了 tf.train.Checkpoint 类来管理模型和优化器的保存和恢复。

TensorFlow 保存 Checkpoint

以下是一个在 TensorFlow 中保存 Checkpoint 的示例:

import tensorflow as tf

# 假设我们有一个简单的模型
model = tf.keras.Sequential([
    tf.keras.layers.Dense(2, input_shape=(10,))
])
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

# 定义 Checkpoint
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
checkpoint_dir = './checkpoints'
checkpoint_prefix = f'{checkpoint_dir}/ckpt'

# 训练模型...
# 保存 Checkpoint
checkpoint.save(file_prefix=checkpoint_prefix)

Note

使用 AI Lab 的用户,可以直接将高性能存储挂载为 Checkpoint 目录,以提高 Checkpoint 保存和恢复的速度。

TensorFlow 恢复 Checkpoint

加载 Checkpoint 并恢复模型和优化器状态:

# 恢复 Checkpoint
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
checkpoint.restore(latest_checkpoint)

# 继续训练或推理...

TensorFlow 在分布式训练的 Checkpoint 管理

TensorFlow 在分布式训练中管理 Checkpoint 的主要方法如下:

  • 使用 tf.train.Checkpointtf.train.CheckpointManager

    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    manager = tf.train.CheckpointManager(checkpoint, directory='/tmp/model', max_to_keep=3)
    
  • 在分布式策略中保存 Checkpoint

    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        manager = tf.train.CheckpointManager(checkpoint, directory='/tmp/model', max_to_keep=3)
    
  • 只在主节点 (chief worker) 保存 Checkpoint

    if strategy.cluster_resolver.task_type == 'chief':
        manager.save()
    
  • 使用 MultiWorkerMirroredStrategy 时的特殊处理

    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    with strategy.scope():
        # 模型定义
        ...
        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        manager = tf.train.CheckpointManager(checkpoint, '/tmp/model', max_to_keep=3)
    
    def _chief_worker(task_type, task_id):
        return task_type is None or task_type == 'chief' or (task_type == 'worker' and task_id == 0)
    
    if _chief_worker(strategy.cluster_resolver.task_type, strategy.cluster_resolver.task_id):
        manager.save()
    
  • 使用分布式文件系统

    确保所有工作节点都能访问到同一个 Checkpoint 目录,通常使用分布式文件系统如 HDFS 或 GCS。

  • 异步保存

    使用 tf.keras.callbacks.ModelCheckpoint 并设置 save_freq 参数可以在训练过程中异步保存 Checkpoint。

  • Checkpoint 恢复

    status = checkpoint.restore(manager.latest_checkpoint)
    status.assert_consumed()  # (1)!
    
    1. 确保所有变量都被恢复
  • 性能优化

    • 使用 tf.train.experimental.enable_mixed_precision_graph_rewrite() 启用混合精度训练
    • 调整保存频率,避免过于频繁的 I/O 操作
    • 考虑使用 tf.saved_model.save() 保存整个模型,而不仅仅是权重

注意事项

  1. 定期保存:根据训练时间和资源消耗,决定合适的保存频率。如每个 epoch 或每隔一定的训练步数。

  2. 保存多个 Checkpoint:保留最新的几个 Checkpoint 以防止文件损坏或不适用的情况。

  3. 记录元数据:在 Checkpoint 中保存额外的信息,如 epoch 数、损失值等,以便更好地恢复训练状态。

  4. 使用版本控制:保存不同实验的 Checkpoint,便于对比和复用。

  5. 验证和测试:在训练的不同阶段使用 Checkpoint 进行验证和测试,确保模型性能和稳定性。

结论

Checkpoint 机制在深度学习训练中起到了关键作用。通过合理使用 PyTorch 和 TensorFlow 中的 Checkpoint 功能, 可以有效提高训练的可靠性和效率。希望本文所述的方法和最佳实践能帮助你更好地管理深度学习模型的训练过程。