# deepspeed-model-memory-usage/src/parallelism_utils.py
import torch


def get_precision_fac(precision: str):
    if precision == "mixed":
        return 2
    elif precision == "single":
        return 4
    else:
        raise ValueError("Precision must be either 'mixed' or 'single'")


def get_params_fac(model_dtype: str):
    if model_dtype == "float16":
        return 2
    elif model_dtype == "float32":
        return 4
    else:
        raise ValueError("Model dtype must be either 'float16' or 'float32'")


####################### Zero Redundancy Optimizer (ZeRO) #######################
VARIANCE_FACTOR = 4
MOMENTUM_FACTOR = 4
OPTIMIZER_FACTOR = VARIANCE_FACTOR + MOMENTUM_FACTOR # Adam optimizer
FP32_GRADS_FACTOR = 4
FP32_PARAM_FACTOR = 4
MASTER_PARAMS_FACTOR = FP32_PARAM_FACTOR
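
# Rough per-parameter accounting implied by these constants (a sketch, assuming
# mixed precision with Adam, as in the ZeRO paper's K = 12 bytes/param):
#   fp32 master params (4) + momentum (4) + variance (4) = 12 bytes of optimizer
#   state per parameter, on top of the precision_fac-byte model copy (2 bytes for
#   fp16) and the FP32_GRADS_FACTOR-byte fp32 gradients (4 bytes) that the
#   estimators below add before partitioning across GPUs.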


def estimate_zero1_model_states_mem_needs(total_params,
                                          num_gpus_per_node=1,
                                          num_nodes=1,
                                          cpu_offload=True,
                                          additional_buffer_factor=1.5,
                                          precision="mixed",
                                          model_dtype="float16",
                                          ):
    total_gpus = num_nodes * num_gpus_per_node
    precision_fac = get_precision_fac(precision)
    params_fac = get_params_fac(model_dtype)

    if cpu_offload:
        gpu_mem = (precision_fac * total_params)  # + (grads_fac * total_params)
        cpu_mem = total_params * max(params_fac * total_gpus, (MASTER_PARAMS_FACTOR + OPTIMIZER_FACTOR + FP32_GRADS_FACTOR)) * additional_buffer_factor
    else:
        if precision == "mixed":
            gpu_mem = (precision_fac * total_params) + (FP32_GRADS_FACTOR * total_params) + int((OPTIMIZER_FACTOR + FP32_PARAM_FACTOR) * total_params / total_gpus)
        else:
            gpu_mem = (precision_fac * total_params) + (FP32_GRADS_FACTOR * total_params) + int(OPTIMIZER_FACTOR * total_params / total_gpus)
        cpu_mem = total_params * params_fac * num_gpus_per_node * additional_buffer_factor

    return int(cpu_mem), int(gpu_mem)
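
# Illustrative call (hypothetical numbers, not from the original module): for a
# 1e9-parameter model on one node with 8 GPUs, mixed precision, and no CPU
# offload, the branch above gives roughly
#   2*1e9 + 4*1e9 + (8 + 4)*1e9/8 = 7.5e9 bytes of GPU memory per rank:
#
#   cpu_mem, gpu_mem = estimate_zero1_model_states_mem_needs(
#       total_params=1_000_000_000, num_gpus_per_node=8, cpu_offload=False)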


def estimate_zero2_model_states_mem_needs(total_params,
                                          num_gpus_per_node=1,
                                          num_nodes=1,
                                          cpu_offload=True,
                                          additional_buffer_factor=1.5,
                                          precision="mixed",
                                          model_dtype="float16",
                                          ):
    total_gpus = num_nodes * num_gpus_per_node
    precision_fac = get_precision_fac(precision)
    params_fac = get_params_fac(model_dtype)

    if cpu_offload:
        gpu_mem = precision_fac * total_params  # Negligible memory usage for partitioned gradients
        cpu_mem = total_params * max(params_fac * total_gpus, (MASTER_PARAMS_FACTOR + OPTIMIZER_FACTOR + FP32_GRADS_FACTOR)) * additional_buffer_factor
    else:
        if precision == "mixed":
            gpu_mem = precision_fac * total_params + int((FP32_GRADS_FACTOR + OPTIMIZER_FACTOR + FP32_PARAM_FACTOR) * total_params / total_gpus)
        else:
            gpu_mem = precision_fac * total_params + int((FP32_GRADS_FACTOR + OPTIMIZER_FACTOR) * total_params / total_gpus)
        cpu_mem = params_fac * total_params * num_gpus_per_node * additional_buffer_factor

    return int(cpu_mem), int(gpu_mem)
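
# For the same hypothetical setup as in the ZeRO-1 example (1e9 params, 8 GPUs,
# mixed precision, no offload), ZeRO-2 also partitions the fp32 gradients, so
# the formula above yields 2*1e9 + (4 + 8 + 4)*1e9/8 = 4e9 bytes per GPU
# instead of 7.5e9:
#
#   cpu_mem, gpu_mem = estimate_zero2_model_states_mem_needs(
#       total_params=1_000_000_000, num_gpus_per_node=8, cpu_offload=False)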


def estimate_zero3_model_states_mem_needs(total_params,
                                          largest_layer_params,
                                          num_gpus_per_node=1,
                                          num_nodes=1,
                                          cpu_offload=True,
                                          cpu_offload_params=True,
                                          zero_init=True,
                                          additional_buffer_factor=1.5,
                                          precision="mixed",
                                          model_dtype="float16",
                                          ):
    total_gpus = num_nodes * num_gpus_per_node
    gpus_factor = 1 / num_nodes
    precision_fac = get_precision_fac(precision)
    params_fac = get_params_fac(model_dtype)
    grads_fac = precision_fac

    largest_layer_memory = (grads_fac + precision_fac) * largest_layer_params

    if cpu_offload:
        if cpu_offload_params:
            gpu_mem = largest_layer_memory
            if zero_init:
                cpu_mem = total_params * (MASTER_PARAMS_FACTOR + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR + params_fac) * gpus_factor * additional_buffer_factor
            else:
                cpu_mem = total_params * max(params_fac * num_gpus_per_node, (MASTER_PARAMS_FACTOR + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR + params_fac) * gpus_factor) * additional_buffer_factor
        else:
            gpu_mem = max(
                largest_layer_memory,
                # No need to hold gradients here: ZeRO-Offload can transfer each gradient to
                # CPU memory individually or in small groups immediately after it is computed.
                int(precision_fac * total_params / total_gpus)
            )
            if zero_init:
                cpu_mem = total_params * (MASTER_PARAMS_FACTOR + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR) * gpus_factor * additional_buffer_factor
            else:
                cpu_mem = total_params * max(params_fac * num_gpus_per_node, (MASTER_PARAMS_FACTOR + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR) * gpus_factor) * additional_buffer_factor
    else:
        if precision == "mixed":
            gpu_mem = max(
                int((precision_fac + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR + FP32_PARAM_FACTOR) * largest_layer_params),
                int((precision_fac + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR + FP32_PARAM_FACTOR) * total_params / total_gpus)
            )
        else:
            gpu_mem = max(
                int((precision_fac + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR) * largest_layer_params),
                int((precision_fac + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR) * total_params / total_gpus)
            )
        if zero_init:
            cpu_mem = largest_layer_params * params_fac * num_gpus_per_node * additional_buffer_factor
        else:
            cpu_mem = total_params * params_fac * num_gpus_per_node * additional_buffer_factor

    return int(cpu_mem), int(gpu_mem), largest_layer_memory
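

# Minimal usage sketch (hypothetical model sizes, not part of the original
# module): compares the three ZeRO stages for a ~1.5B-parameter model whose
# largest layer holds ~100M parameters, on one node with 4 GPUs and no offload.
if __name__ == "__main__":
    total_params = 1_500_000_000
    largest_layer_params = 100_000_000

    cpu1, gpu1 = estimate_zero1_model_states_mem_needs(
        total_params, num_gpus_per_node=4, num_nodes=1, cpu_offload=False)
    cpu2, gpu2 = estimate_zero2_model_states_mem_needs(
        total_params, num_gpus_per_node=4, num_nodes=1, cpu_offload=False)
    cpu3, gpu3, largest_layer_mem = estimate_zero3_model_states_mem_needs(
        total_params, largest_layer_params, num_gpus_per_node=4, num_nodes=1,
        cpu_offload=False, cpu_offload_params=False)

    gib = 2**30
    print(f"ZeRO-1: {gpu1 / gib:.1f} GiB GPU, {cpu1 / gib:.1f} GiB CPU")
    print(f"ZeRO-2: {gpu2 / gib:.1f} GiB GPU, {cpu2 / gib:.1f} GiB CPU")
    print(f"ZeRO-3: {gpu3 / gib:.1f} GiB GPU, {cpu3 / gib:.1f} GiB CPU "
          f"(largest layer: {largest_layer_mem / gib:.1f} GiB)")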