import torch
import torch.distributed as dist

from .utils import is_dist_avail_and_initialized, get_rank

SEQ_PARALLEL_GROUP = None
SEQ_PARALLEL_SIZE = None
SEQ_PARALLEL_PROC_NUM = None  # number of processes participating in sequence parallelism

SYNC_INPUT_GROUP = None
SYNC_INPUT_SIZE = None


def is_sequence_parallel_initialized():
    return SEQ_PARALLEL_GROUP is not None


def init_sequence_parallel_group(args):
    """Split the first `sp_proc_num` ranks into contiguous sequence-parallel groups of size `sp_group_size`."""
    global SEQ_PARALLEL_GROUP
    global SEQ_PARALLEL_SIZE
    global SEQ_PARALLEL_PROC_NUM

    assert SEQ_PARALLEL_GROUP is None, "sequence parallel group is already initialized"
    assert is_dist_avail_and_initialized(), "PyTorch distributed must be initialized"

    SEQ_PARALLEL_SIZE = args.sp_group_size
    print(f"Setting the Sequence Parallel Size {SEQ_PARALLEL_SIZE}")

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    if args.sp_proc_num == -1:
        SEQ_PARALLEL_PROC_NUM = world_size
    else:
        SEQ_PARALLEL_PROC_NUM = args.sp_proc_num

    assert SEQ_PARALLEL_PROC_NUM % SEQ_PARALLEL_SIZE == 0, \
        "The number of sequence-parallel processes must be divisible by the group size"

    # new_group() is a collective: every rank must call it for every group, in the
    # same order, so do not break out of the loop after finding the local group.
    for i in range(0, SEQ_PARALLEL_PROC_NUM, SEQ_PARALLEL_SIZE):
        ranks = list(range(i, i + SEQ_PARALLEL_SIZE))
        group = dist.new_group(ranks)
        if rank in ranks:
            SEQ_PARALLEL_GROUP = group


def init_sync_input_group(args):
    """Group ranks so that each block of `max_frames` consecutive processes shares the same input."""
    global SYNC_INPUT_GROUP
    global SYNC_INPUT_SIZE

    assert SYNC_INPUT_GROUP is None, "sync input group is already initialized"
    assert is_dist_avail_and_initialized(), "PyTorch distributed must be initialized"

    SYNC_INPUT_SIZE = args.max_frames

    rank = dist.get_rank()
    world_size = dist.get_world_size()
    assert world_size % SYNC_INPUT_SIZE == 0, "The world size must be divisible by the sync input group size"

    # As above, every rank must participate in every new_group() call.
    for i in range(0, world_size, SYNC_INPUT_SIZE):
        ranks = list(range(i, i + SYNC_INPUT_SIZE))
        group = dist.new_group(ranks)
        if rank in ranks:
            SYNC_INPUT_GROUP = group


def get_sequence_parallel_group():
    assert SEQ_PARALLEL_GROUP is not None, "sequence parallel group is not initialized"
    return SEQ_PARALLEL_GROUP


def get_sync_input_group():
    return SYNC_INPUT_GROUP


def get_sequence_parallel_world_size():
    assert SEQ_PARALLEL_SIZE is not None, "sequence parallel size is not initialized"
    return SEQ_PARALLEL_SIZE


def get_sequence_parallel_rank():
    # Rank of this process within its sequence-parallel group.
    assert SEQ_PARALLEL_SIZE is not None, "sequence parallel size is not initialized"
    rank = get_rank()
    cp_rank = rank % SEQ_PARALLEL_SIZE
    return cp_rank


def get_sequence_parallel_group_rank():
    # Index of the sequence-parallel group this process belongs to.
    assert SEQ_PARALLEL_SIZE is not None, "sequence parallel size is not initialized"
    rank = get_rank()
    cp_group_rank = rank // SEQ_PARALLEL_SIZE
    return cp_group_rank


def get_sequence_parallel_proc_num():
    return SEQ_PARALLEL_PROC_NUM
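

# ---------------------------------------------------------------------------
# Usage sketch (assumption, not part of this module's original code): a minimal
# example of wiring these helpers up, assuming `args` exposes `sp_group_size`,
# `sp_proc_num`, and `max_frames` exactly as consumed above, and that the job
# is launched with torchrun so the default process group can be created from
# environment variables. The concrete values below are hypothetical. Run it as
# a module (python -m / torchrun with module mode) so the relative import of
# `.utils` resolves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    # Hypothetical configuration: 2-way sequence parallelism across all ranks,
    # with input-sync groups of the same size.
    args = SimpleNamespace(sp_group_size=2, sp_proc_num=-1, max_frames=2)

    dist.init_process_group(backend="nccl")
    init_sequence_parallel_group(args)
    init_sync_input_group(args)

    print(
        f"global rank {dist.get_rank()} -> "
        f"sp rank {get_sequence_parallel_rank()} "
        f"in sp group {get_sequence_parallel_group_rank()}"
    )

    dist.destroy_process_group()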