NCCL集合通信算子DEMO及性能测试


以下代码用于测试NCCL集合通信算子的性能及正确性。

一、复现代码

tee ccl_benchmark.py <<-'EOF'
import os
import torch
import argparse
import torch.distributed as dist
from torch.distributed import ReduceOp
from datetime import datetime
import time
import argparse
import numpy as np
# Device type string; main() combines it with LOCAL_RANK to build the torch.device.
dev_type="cuda"
 
class Timer:
    """Context manager that times the enclosed block between two barriers.

    ``dist.barrier()`` is called on entry and again on exit, so the measured
    interval spans from one barrier to the next.  The elapsed time, in
    microseconds, is appended to the list supplied by the caller.
    """

    def __init__(self, duration):
        # Caller-owned list that receives one elapsed-time sample per use.
        self.duration = duration

    def __enter__(self):
        dist.barrier()
        # perf_counter() is a monotonic, high-resolution clock, so a sample
        # cannot be skewed by wall-clock adjustments during the measurement.
        self.beg = time.perf_counter() * 1e6
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        dist.barrier()
        self.end = time.perf_counter() * 1e6
        self.duration.append(self.end - self.beg)
 
# Registry of benchmark callables keyed by function name; main() looks up
# the --op command-line argument here.
op_mapping = {}


class ccl_benchmark:
    """Decorator that registers a benchmark function in ``op_mapping``.

    The undecorated function is stored in ``op_mapping`` under its
    ``__name__``.  The decorator instance itself stays callable and forwards
    every call to the wrapped function unchanged.
    """

    def __init__(self, func):
        import functools  # local import keeps this block self-contained

        op_mapping[func.__name__] = func
        self.func = func
        # Copy __name__, __doc__, etc. onto the wrapper so the decorated
        # object still identifies as the function it wraps.
        functools.update_wrapper(self, func)

    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs)
         
@ccl_benchmark
def all_gather(shape, device, rank, world_size, iters=5):
    """Benchmark ``dist.all_gather`` and verify the gathered data.

    Every rank contributes a ``(shape[0] // world_size, shape[1])`` int64
    tensor filled with ``100 + rank``.  The tensors received from all ranks
    are concatenated along dim 0 and compared with the expected values.

    Returns:
        ``(duration, passed)``: the list of per-iteration timings recorded
        by ``Timer`` and a boolean tensor that is true when the output
        matches the expectation.
    """
    chunk_shape = (shape[0] // world_size, shape[1])
    timings = []

    local_chunk = torch.ones(chunk_shape, dtype=torch.int64) * (100 + rank)
    local_chunk = local_chunk.to(device)

    # One receive buffer per rank.
    received = []
    for _ in range(world_size):
        received.append(torch.zeros(chunk_shape, dtype=torch.int64).to(device))

    for _ in range(iters):
        with Timer(timings):
            dist.all_gather(received, local_chunk)

    gathered = torch.cat(received, dim=0)
    expected = torch.cat(
        [torch.ones(chunk_shape, dtype=torch.int64) * (100 + src) for src in range(world_size)],
        dim=0,
    )
    return timings, (gathered.cpu() == expected).all()
    
@ccl_benchmark
def scatter(shape, device, rank, world_size, iters=5):
    """Benchmark ``dist.scatter`` from rank 0 and verify the received chunk.

    A list of ``world_size`` int64 chunks of shape
    ``(shape[0] // world_size, shape[1])`` is built, chunk ``i`` being filled
    with ``i``.  Rank 0 passes the list as ``scatter_list``; each rank then
    checks that its output tensor is filled with its own rank number.

    Returns:
        ``(duration, passed)``: the list of per-iteration timings recorded
        by ``Timer`` and a boolean tensor that is true when the output
        matches the expectation.
    """
    chunk_shape = (shape[0] // world_size, shape[1])
    timings = []

    received = torch.zeros(chunk_shape, dtype=torch.int64).to(device)
    chunks = [
        (torch.ones(chunk_shape, dtype=torch.int64) * dst).to(device)
        for dst in range(world_size)
    ]

    # Only the source rank hands the chunk list to dist.scatter.
    extra_kwargs = {"scatter_list": chunks} if rank == 0 else {}

    for _ in range(iters):
        with Timer(timings):
            dist.scatter(received, src=0, **extra_kwargs)

    expected = torch.ones(chunk_shape, dtype=torch.int64) * rank
    return timings, (received.cpu() == expected).all()
   
@ccl_benchmark
def gather(shape, device, rank, world_size, iters=5):
    """Benchmark ``dist.gather`` onto rank 0 and verify the result there.

    Every rank contributes a ``(shape[0] // world_size, shape[1])`` int64
    tensor filled with ``100 + rank``.  Only rank 0 passes a ``gather_list``
    and only rank 0 checks the data, after concatenating the received
    tensors along dim 0.

    Returns:
        ``(duration, passed)``: the list of per-iteration timings recorded
        by ``Timer`` and the check result.  ``passed`` is a boolean tensor
        on rank 0 and the constant ``True`` on every other rank.
    """
    chunk_shape = (shape[0] // world_size, shape[1])
    timings = []

    local_chunk = torch.ones(chunk_shape, dtype=torch.int64) * (100 + rank)
    local_chunk = local_chunk.to(device)
    slots = [
        torch.zeros(chunk_shape, dtype=torch.int64).to(device)
        for _ in range(world_size)
    ]

    # Only the destination rank supplies the receive buffers.
    extra_kwargs = {"gather_list": slots} if rank == 0 else {}

    for _ in range(iters):
        with Timer(timings):
            dist.gather(local_chunk, dst=0, **extra_kwargs)

    if rank != 0:
        # Ranks other than the destination have nothing to verify.
        return timings, True

    gathered = torch.cat(slots, dim=0)
    expected = torch.cat(
        [torch.ones(chunk_shape, dtype=torch.int64) * (100 + src) for src in range(world_size)],
        dim=0,
    )
    return timings, (gathered.cpu() == expected).all()
  
@ccl_benchmark
def reduce(shape, device, rank, world_size, iters=5):
    """Benchmark ``dist.reduce`` (SUM onto rank 0) and verify the result.

    Every rank contributes a ``(shape[0], shape[1])`` int64 tensor filled
    with ``100 + rank``.  After the last iteration rank 0 checks that its
    tensor holds the sum of all contributions; other ranks do not verify.

    Returns:
        ``(duration, passed)``: the list of per-iteration timings recorded
        by ``Timer`` and the check result.  ``passed`` is a boolean tensor
        on rank 0 and the constant ``True`` on every other rank.

    Raises:
        ValueError: if ``iters`` is less than 1.  Without any iteration
            there is no tensor to verify.
    """
    if iters < 1:
        raise ValueError("iters must be >= 1, got {}".format(iters))

    duration = []
    for _ in range(iters):
        # The reduction overwrites input_tensor, so it is rebuilt for every
        # iteration.
        input_tensor = (torch.ones((shape[0], shape[1]), dtype=torch.int64) * (100 + rank)).to(device)
        with Timer(duration):
            dist.reduce(input_tensor, dst=0, op=dist.ReduceOp.SUM)

    ret = True
    if rank == 0:
        # Sum of the per-rank fill values 100 + i over all ranks.
        expected_value = sum(100 + i for i in range(world_size))
        gt = torch.ones((shape[0], shape[1]), dtype=torch.int64) * expected_value
        ret = (input_tensor.cpu() == gt).all()
    return duration, ret
         
@ccl_benchmark
def broadcast(shape, device, rank, world_size, iters=5):
    """Benchmark ``dist.broadcast`` from rank 0 and verify the result.

    Every rank starts each iteration with a ``(shape[0], shape[1])`` int64
    tensor filled with ``100 + rank``.  After the broadcast every rank is
    expected to hold rank 0's value, 100.

    Returns:
        ``(duration, passed)``: the list of per-iteration timings recorded
        by ``Timer`` and a boolean tensor that is true when the tensor
        matches the expectation.

    Raises:
        ValueError: if ``iters`` is less than 1.  Without any iteration
            there is no tensor to verify.
    """
    if iters < 1:
        raise ValueError("iters must be >= 1, got {}".format(iters))

    duration = []
    for _ in range(iters):
        # Rebuilt every iteration because the broadcast overwrites it on the
        # receiving ranks.
        input_tensor = (torch.ones((shape[0], shape[1]), dtype=torch.int64) * (100 + rank)).to(device)
        with Timer(duration):
            dist.broadcast(input_tensor, src=0)

    gt = torch.ones((shape[0], shape[1]), dtype=torch.int64) * (100 + 0)
    ret = (input_tensor.cpu() == gt).all()
    return duration, ret
  
@ccl_benchmark
def p2p(shape, device, rank, world_size, iters=5):
    """Benchmark a ``dist.send``/``dist.recv`` relay along the rank chain.

    Every rank starts each iteration with a ``(shape[0], shape[1])`` int64
    tensor filled with ``100 + rank``.  Each rank except 0 first receives
    into that tensor from ``rank - 1``; each rank except the last then sends
    the tensor on to ``rank + 1``.  Every rank is therefore expected to end
    up holding rank 0's value, 100.

    Returns:
        ``(duration, passed)``: the list of per-iteration timings recorded
        by ``Timer`` and a boolean tensor that is true when the tensor
        matches the expectation.

    Raises:
        ValueError: if ``iters`` is less than 1.  Without any iteration
            there is no tensor to verify.
    """
    if iters < 1:
        raise ValueError("iters must be >= 1, got {}".format(iters))

    duration = []
    for _ in range(iters):
        input_tensor = (torch.ones((shape[0], shape[1]), dtype=torch.int64) * (100 + rank)).to(device)
        with Timer(duration):
            # Receive before sending so the data travels 0 -> 1 -> ... -> last.
            if rank != 0:
                dist.recv(input_tensor, rank - 1)
            if rank != world_size - 1:
                dist.send(input_tensor, dst=rank + 1)

    gt = torch.ones((shape[0], shape[1]), dtype=torch.int64) * (100 + 0)
    ret = (input_tensor.cpu() == gt).all()
    return duration, ret
 
@ccl_benchmark
def all_reduce(shape, device, rank, world_size, iters=5):
    """Benchmark ``dist.all_reduce`` (SUM) and verify the result.

    Every rank contributes a ``(shape[0], shape[1])`` int64 tensor filled
    with ``100 + rank``.  After the last iteration every rank checks that
    its tensor holds the sum of all contributions.

    Returns:
        ``(duration, passed)``: the list of per-iteration timings recorded
        by ``Timer`` and a boolean tensor that is true when the tensor
        matches the expectation.

    Raises:
        ValueError: if ``iters`` is less than 1.  Without any iteration
            there is no tensor to verify.
    """
    if iters < 1:
        raise ValueError("iters must be >= 1, got {}".format(iters))

    duration = []
    for _ in range(iters):
        # The reduction overwrites input_tensor, so it is rebuilt for every
        # iteration.
        input_tensor = (torch.ones((shape[0], shape[1]), dtype=torch.int64) * (100 + rank)).to(device)
        with Timer(duration):
            dist.all_reduce(input_tensor, op=dist.ReduceOp.SUM)

    # Sum of the per-rank fill values 100 + i over all ranks.
    expected_value = sum(100 + i for i in range(world_size))
    gt = torch.ones((shape[0], shape[1]), dtype=torch.int64) * expected_value
    ret = (input_tensor.cpu() == gt).all()
    return duration, ret
 
@ccl_benchmark
def reduce_scatter(shape, device, rank, world_size, iters=5):
    """Benchmark ``dist.reduce_scatter`` (SUM) and verify the result.

    Every rank supplies ``world_size`` int64 chunks of shape
    ``(shape[0] // world_size, shape[1])``; chunk ``chunk_id`` is filled with
    ``100 * rank + chunk_id``.  Each rank then checks that its output equals
    the sum, over all source ranks ``rk``, of ``100 * rk + rank``.

    Returns:
        ``(duration, passed)``: the list of per-iteration timings recorded
        by ``Timer`` and a boolean tensor that is true when the output
        matches the expectation.
    """
    chunk_shape = (shape[0] // world_size, shape[1])
    timings = []

    result = torch.zeros(chunk_shape, dtype=torch.int64).to(device)
    contributions = []
    for chunk_id in range(world_size):
        chunk = torch.ones(chunk_shape, dtype=torch.int64) * (100 * rank) + chunk_id
        contributions.append(chunk.to(device))

    for _ in range(iters):
        with Timer(timings):
            dist.reduce_scatter(result, input_list=contributions, op=dist.ReduceOp.SUM)

    # This rank receives chunk index `rank` from every source rank.
    expected_value = sum(100 * src + rank for src in range(world_size))
    expected = torch.ones(chunk_shape, dtype=torch.int64) * expected_value
    return timings, (result.cpu() == expected).all()
     
def main():
    """Parse the command line, run the selected benchmark and print its result.

    Arguments are parsed and validated before the NCCL process group is
    created, so ``--help`` or an invalid ``--iters`` exits without starting
    any communication.  If ``--op`` names a registered benchmark it is run on
    this rank's CUDA device (selected via the ``LOCAL_RANK`` environment
    variable) and one result line is printed per rank.  Otherwise rank 0
    prints the list of available ops.
    """
    parser = argparse.ArgumentParser(description='test')
    parser.add_argument('--shape', type=str, default="(1024,8192)",
                        help='Tensor shape as a Python tuple expression, e.g. "(1024,8192)".')
    parser.add_argument('--iters', type=int, default=5,
                        help='Number of timed iterations (must be >= 1).')
    parser.add_argument('--op', type=str, default="",
                        help='Benchmark to run; one of: {}.'.format(", ".join(sorted(op_mapping))))
    args = parser.parse_args()
    if args.iters < 1:
        # Several benchmarks verify the tensor produced by the last
        # iteration, so at least one iteration is required.
        parser.error("--iters must be >= 1")

    dist.init_process_group(backend='nccl')

    if not torch.distributed.is_initialized():
        return

    if args.op in op_mapping:
        torch.manual_seed(1)
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
        local_rank = int(os.environ['LOCAL_RANK'])
        torch.cuda.set_device(local_rank)
        device = torch.device(dev_type, local_rank)
        # NOTE(review): eval() executes arbitrary Python taken from the
        # command line.  It is kept so expression-style shapes keep working;
        # consider ast.literal_eval if only literal tuples must be supported.
        shape = eval(args.shape)
        duration, passed = op_mapping[args.op](shape, device, rank, world_size, args.iters)
        # Stagger the output by rank so the per-rank lines do not interleave.
        time.sleep(0.1 * rank)
        # Only the second half of the samples is averaged; the first half is
        # treated as warm-up.
        print("rank:{} op:{} shape:{} iters:{} mean(us):{:.3f} passed:{}".format(rank,args.op,shape,args.iters,np.mean(duration[len(duration)//2:]),passed))
    elif torch.distributed.get_rank() == 0:
        # Report the problem instead of silently doing nothing.
        print("unknown or missing --op {!r}; available ops: {}".format(args.op, ", ".join(sorted(op_mapping))))

    dist.destroy_process_group()


if __name__ == '__main__':
    main()
         
EOF
 
# NCCL settings used for this run. NOTE(review): "ens8" is presumably the
# network interface name on the original test machine; adjust it for your host.
export NCCL_DEBUG=error NCCL_SOCKET_IFNAME=ens8 NCCL_IB_DISABLE=1

# Run every benchmark in turn on one node with 4 processes.
for op in all_gather scatter gather reduce broadcast p2p all_reduce reduce_scatter; do
    torchrun -m --nnodes=1 --nproc_per_node=4 ccl_benchmark --op="$op" --shape="(1024,4096)" --iters=5
done
04-13 06:21