GPU并行训练

为什么要使用多GPU并行训练

简单来说,有两种原因:第一种是模型在一块GPU上放不下,两块或多块GPU上就能运行完整的模型(如早期的AlexNet)。第二种是多块GPU并行计算可以达到加速训练的效果。想要成为“炼丹大师“,多GPU并行训练是不可或缺的技能。

常见的多GPU训练方法:

1.模型并行方式:如果模型特别大,GPU显存不够,无法将一个显存放在GPU上,需要把网络的不同模块放在不同GPU上,这样可以训练比较大的网络。(下图左半部分)

2.数据并行方式:将整个模型放在一块GPU里,再复制到每一块GPU上,同时进行正向传播和反向误差传播。相当于加大了batch_size。(下图右半部分)

GPU 设置

nn.DataParalleltorch.nn.parallel.DistributedDataParallel`都是 PyTorch 用来在多个 GPU 上并行训练模型的工具, 但它们在实现上有一些重要的区别。

  1. nn.DataParallel
  • 单机多卡并行nn.DataParallel适用于在单台机器的多个 GPU 上进行模型的并行训练。
  • 模型复制: 它会在多个 GPU 上复制同一个模型, 每个模型处理部分数据(mini-batch) ,然后将梯度汇总更新模型的参数
  • 简单易用: 使用简单, 只需在**模型外层套用nn.DataParallel(model)**即可。 PyTorch 会自动处理数据切分和梯度汇总的过程。

建议使用 DistributedDataParallel, 而不是这个类, 进行多 GPU 训练, 即使只有一个 节点。 请参阅: 使用 nn.parallel.DistributedDataParallel 而不是多处理或 nn.DataParallel 和 分布式数据并行 .

如果要增加 GPU 占用率, 提高 batch_size

  1. torch.nn.parallel.DistributedDataParallel
    PyTorch 的分布式训练主要通过torch.distributed包来实现, 该包提供了多进程多GPU 训练模型的能力。 以下是分布式训练的关键概念和组件:

  2. 进程组(Process Group) :
    进程组是一组可以相互通信的进程。 在分布式训练中, 每个进程都是一个独立的 Python进程, 通常运行在不同的机器或机器的不同 GPU 上。

  3. 初始化进程组:
    在 开 始 分 布 式 训 练 之 前 , 需 要 初 始 化 进 程 组 。 这 通 常 通 过 调 用torch.distributed.init_process_group()函数完成, 该函数需要指定后端(如”nccl”、 “gloo”)和初始化方法。

  4. 分布式数据并行(Distributed Data Parallel, DDP) :
    DDP 是一种在多个进程和多个 GPU 上并行训练模型的方法。 每个进程拥有模型的一个副本, 并且只处理数据的一个子集。 通过减少每个进程所需的数据量, DDP 可以有效地扩展到大量的 GPU。

  5. DistributedDataParallel类:

    使用torch.nn.parallel.DistributedDataParallel类可以很容易地将现有的nn.Module包装起来, 使其能够在 DDP 中使用。 该类负责在不同的进程间同步模型的梯度。

  6. 环境变量:
    使用环境变量如MASTER_ADDRMASTER_PORT可以指定用于初始化进程组的主机
    地址和端口。

  7. 分布式采样器:
    为了确保每个进程只处理数据的一个独立子集, 通常需要使用分布式采样器, 如torch.utils.data.distributed.DistributedSampler

  8. 梯度累积:
    在资源有限的情况下, 可以通过梯度累积技术来模拟更大批量的训练, 即使不能在单个迭代中放入足够多的样本。

  9. 通信后端:
    PyTorch 支持多种通信后端, 如”nccl”、 “gloo”和”mpi”, 它们在不同的硬件和网络配置下提供了不同的性能和特性。

  10. 调试和优化:
    分布式训练可能会引入额外的复杂性和潜在的错误源。 PyTorch 提供了工具和最佳实践来帮助调试和优化分布式训练的性能。

    https://pytorch.org/docs/stable/distributed.html#

针对分布式训练
字节跳动(抖音)
https://github.com/bytedance/byteps?tab=readme-ov-file
阿里巴巴
https://github.com/alibaba/EasyParallelLibrary?tab=readme-ov-file

Llama- Factory 多机多卡并行训练

https://llamafactory.readthedocs.io/zh-cn/latest/advanced/distributed.html