Pytorch多GPU并行
在PyTorch中,多GPU并行确实可以分为数据并行(Data Parallelism)和模型并行(Model Parallelism),这两种并行策略针对不同的并行化需求
数据并行(Data Parallelism): 数据并行是最常见的并行策略,它通过将数据集分割成多个小batch,然后在多个GPU上并行处理这些部分。每个GPU上运行相同的模型副本,独立地计算前向传播和反向传播,最后将结果汇总。在PyTorch中,这可以通过torch.nn.DataParallel
或torch.nn.parallel.DistributedDataParallel
(DDP)来实现
模型并行(Model Parallelism): 模型并行是将模型的不同部分分配到不同的GPU上。这种方法适用于模型太大,单个GPU无法容纳整个模型的情况。模型的不同层或模块被放置在不同的GPU上,每个GPU负责计算模型的一部分。这种方法需要更复杂的通信机制来协调不同GPU之间的计算。在PyTorch中,模型并行可以通过自定义的通信逻辑来实现,或者使用专门的库如fairseq
数据并行又主要分为三种方式:
DP(Data Parallel):在单个机器上的多个GPU上并行训练模型,这种方法适用于模型大小适中,单个GPU可以容纳整个模型的情况。
DDP(Distributed Data Parallel):用于在多台机器上进行数据并行训练的高级API。这种方法适用于需要大规模分布式训练的场景,如大型模型或大数据集
FSDP(Fully Sharded Data Parallel): Fully Sharded Data Parallel(FSDP)是Facebook AI Research (FAIR) 提出的一种新的并行策略,旨在解决DDP在大规模分布式训练中的通信瓶颈问题。FSDP通过将模型的权重分成多个部分(shards),并在不同的GPU上进行训练,从而减少了通信开销
平时用的对多的就是DP和DDP方式,同时,在Pytorch官方网站上已经推荐使用DDP去全面取代DP方式 ,即使在单个机器上,DDP运行效率仍然比DP要高。下面简单介绍一下DP与DDP的使用:
数据并行(DP):
DP最大的好处就是简单,平时在小数据集和小模型实验的时候方便,默认情况下,nn.DataParallel
会使用所有可用的GPU,用法如下:
1 torch.nn.DataParallel(module, device_ids=None , output_device=None , dim=0 )
参数解释:
device_ids
(可选):一个整数列表,指定了要使用的 GPU 设备的 ID。如果为 None
(默认值),则使用所有可用的 GPU
output_device
(可选):指定输出结果应该被发送到哪个设备。如果为 None
(默认值),则使用 device_ids
中的第一个设备
dim
(可选):指定在哪个维度上进行数据并行。默认为 0
,即在批次维度上进行并行
返回值:
返回一个 DataParallel
对象,它封装了传入的 module
。这个对象可以像普通模块一样使用,但它会在多个 GPU 上并行执行模型的前向和后向传播。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 import os os.environ["CUDA_VISIBLE_DEVICES" ] = "0,1,2,3" import torch model = CreateModel(*args)if torch.cuda.device_count() > 1 : print ("Let's us " , torch.cuda.device_count(), "GPUs" ) model = nn.DataParallel(model) torch.save(model.module)
注意:os.environ['CUDA_VISIBLE_DEVICES']
必须在import [torch]
之前,否则设置是无法生效的!
分布式数据并行(DDP):
DDP通信算法是一个名为环形通信算法,他不需要等待每个GPU全部都计算完成,他没有主卡的概念,这使得计算过程中没有"bubble",更专业的内容查看李沐的动手深度学习
先看几个概念:
world:代表全部计算设备,world_size就是全部GPU的数量
node:物理节点,就是一台机器(一个服务器),节点内部可以有多个GPU(一台机器有多卡)
rank & local_rank:用于表示进程的序号,用于进程间通信。每一个进程对应了一个rank,rank=0的进程就是master进程
例如有两台服务器,第一台服务器有4卡,第二台服务器有3卡,那么world_size=5 node=1,2 rank=[0,1,2,3,4,5,6] local_rank=[0,1,2,3],[0,1,2]
DDP的用法如下:
1.首先进行DDP初始化:
torch.distributed.init_process_group
是分布式训练中用于初始化进程组的函数。这个函数设置了分布式环境,允许不同的进程之间进行通信
1 torch.distributed.init_process_group(backend, init_method=None , timeout=datetime.timedelta(seconds=1800 ), world_size=None , rank=None )
参数解释:
backend
:指定分布式通信的后端,如 'nccl'
(用于多GPU环境)、'gloo'
(用于CPU或单GPU环境)或 'mpi'
init_method
(可选):指定初始化方法的URL或文件路径。默认为 None
,表示使用环境变量 MASTER_ADDR
和 MASTER_PORT
来初始化
timeout
(可选):设置初始化过程的超时时间,默认为1800秒(30分钟)
world_size
(可选):参与分布式训练的总进程数。默认为 -1
,表示从环境变量中自动获取
rank
(可选):当前进程的排名。默认为 -1
,表示从环境变量中自动获取
1 2 3 import torch.distributed as dist dist.init_process_group(backend='nccl' , init_method='env://' , world_size=num_processes, rank=process_rank)
2.准备数据dataloader和sampler,需要在DDP初始化之后进行:
DistributedSampler
是一个特殊的采样器(Sampler),它用于确保在多个进程(通常是多个 GPU 或多个节点)之间均匀且不重叠地分配数据集。这样,每个进程只处理数据集的一个子集,从而实现数据的并行处理,因此将原先的dataloader
换为DistributedSampler
即可,注意:batch_size指的是每个进程下的batch_size
1 train_sampler = torch.utils.data.distributed.DistributedSampler(my_trainset)
3.构造model模型:
1 2 3 4 5 6 local_rank = dist.get_rank() torch.cuda.set_device(local_rank) model = model.to(local_rank)
4.如果需要Load模型,则要在构造DDP模型之前,且只需要在master上加载就行了:
“master” 通常指的是负责初始化分布式环境和协调其他进程的进程。在分布式训练中,所有进程(或称为 “workers”)需要协同工作,而 “master” 进程则扮演着启动和配置这些进程的角色,注意是主进程不是主机,没有主机
1 2 if dist.get_rank() == 0 and ckpt_path is not None : model.load_state_dict(torch.load(ckpt_path))
5.构造DDP model 模型:
1 model = DDP(model, device_ids=[local_rank], output_device=local_rank)
6.要在构造DDP model之后,才能用model初始化optimizer:
1 optimizer = torch.optim.SGD(model.parameters(), lr=0.001 )
7.loss 函数也要转到指定设备:
1 loss_func = nn.CrossEntropyLoss().to(local_rank)
8.网络训练:
设置DDP sampler的epoch,DistributedSampler
需要这个来指定shuffle方式,通过维持各个进程之间的相同随机数种子使不同进程能获得同样的shuffle效果。
1 trainloader.sampler.set_epoch(epoch)
9.保存模型:
保存模型的时候,和DP模式一样,有一个需要注意的点:保存的是model.module
而不是model
。因为model其实是DDP model,参数是被model=DDP(model)
包起来的。并且只需要在进程0上保存一次就行了,避免多次保存重复的东西
1 2 if dist.get_rank() == 0 : torch.save(model.module.state_dict(), "%d.ckpt" % epoch)
DDP代码示例:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 import argparsefrom tqdm import tqdmimport torchimport torchvisionimport torch.nn as nnimport torch.nn.functional as Fimport torch.distributed as distfrom torch.nn.parallel import DistributedDataParallel as DDPclass ToyModel (nn.Module): def __init__ (self ): super (ToyModel, self).__init__() self.conv = nn.Conv2d(3 , 6 , 5 ) def forward (self, x ): return self.conv(x)def get_dataset (): transform = torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.5 , 0.5 , 0.5 ), (0.5 , 0.5 , 0.5 )) ]) my_trainset = torchvision.datasets.CIFAR10(root='./data' , train=True , download=True , transform=transform) train_sampler = torch.utils.data.distributed.DistributedSampler(my_trainset) trainloader = torch.utils.data.DataLoader(my_trainset, batch_size=16 , num_workers=2 , sampler=train_sampler) return trainloader parser = argparse.ArgumentParser() parser.add_argument("--local_rank" , default=-1 , type =int ) FLAGS = parser.parse_args() local_rank = FLAGS.local_rank torch.cuda.set_device(local_rank) dist.init_process_group(backend='nccl' ) trainloader = get_dataset() model = ToyModel().to(local_rank) ckpt_path = None if dist.get_rank() == 0 and ckpt_path is not None : model.load_state_dict(torch.load(ckpt_path)) model = DDP(model, device_ids=[local_rank], output_device=local_rank) optimizer = torch.optim.SGD(model.parameters(), lr=0.001 ) loss_func = nn.CrossEntropyLoss().to(local_rank) model.train() iterator = tqdm(range (100 ))for epoch in iterator: trainloader.sampler.set_epoch(epoch) for data, label in trainloader: data, label = data.to(local_rank), label.to(local_rank) optimizer.zero_grad() prediction = model(data) loss = loss_func(prediction, label) loss.backward() iterator.desc = "loss = %0.3f" % loss optimizer.step() if dist.get_rank() == 0 : torch.save(model.module.state_dict(), "%d.ckpt" % epoch)