import torch
from torch.cuda.amp import GradScaler, autocast
from torch.distributed.optim import ZeroRedundancyOptimizer
from transformers import get_cosine_schedule_with_warmup

# build_model, zero_embedding_gradient, laion_loader, and pile_loader are
# placeholders defined elsewhere in the training code.
ddp_model = build_model(...)
optimizer = ZeroRedundancyOptimizer(...)
lr_scheduler = get_cosine_schedule_with_warmup(...)
scaler = GradScaler()

for batch_laion, batch_pile in zip(laion_loader, pile_loader):
    # Forward/backward on an image-text (LAION) batch under mixed precision.
    with autocast():
        loss_laion = ddp_model(batch_laion)
    scaler.scale(loss_laion).backward()

    # Forward/backward on a text-only (Pile) batch; its gradients accumulate
    # on top of the LAION gradients before a single optimizer step.
    with autocast():
        loss_pile = ddp_model(batch_pile)
    scaler.scale(loss_pile).backward()

    # Zero the embedding gradients, then unscale so that clipping operates
    # on true (fp32) gradient magnitudes rather than scaled ones.
    zero_embedding_gradient()
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(ddp_model.parameters(), 1.0)

    # One optimizer step per paired batch, then advance the LR schedule
    # and reset gradients for the next pair.
    scaler.step(optimizer)
    scaler.update()
    lr_scheduler.step()
    optimizer.zero_grad()
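The elided constructor arguments above depend on the rest of the training code. As a rough sketch of how they might be filled in, assuming a torchrun launch, a hypothetical MyMultimodalModel class, and illustrative (not original) hyperparameter values:

import os
import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import get_cosine_schedule_with_warmup

dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun
torch.cuda.set_device(local_rank)

model = MyMultimodalModel().cuda()  # hypothetical model class
ddp_model = DDP(model, device_ids=[local_rank])

# ZeroRedundancyOptimizer shards optimizer state across ranks (ZeRO stage 1);
# it wraps a standard optimizer class such as AdamW.
optimizer = ZeroRedundancyOptimizer(
    ddp_model.parameters(),
    optimizer_class=torch.optim.AdamW,
    lr=1e-4,  # illustrative
)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=1_000,      # illustrative
    num_training_steps=100_000,  # illustrative
)

Sharding the optimizer state this way divides the per-GPU memory of Adam-style optimizers roughly by the world size, at the cost of an extra collective when step() is called.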
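zero_embedding_gradient is likewise left undefined. One plausible reading, common when a pretrained language model's vocabulary is extended with new special tokens, is to zero the gradient rows of the pretrained embeddings so that only the newly added rows are updated. A minimal sketch under that assumption (the get_input_embeddings accessor and num_new_tokens are assumptions, not part of the original):

def zero_embedding_gradient(ddp_model, num_new_tokens):
    # DDP wraps the underlying model in .module; get_input_embeddings() is
    # the HuggingFace accessor for the token-embedding layer (assumed here).
    embedding = ddp_model.module.get_input_embeddings()
    if embedding.weight.grad is not None:
        # New token rows are appended at the end of the embedding matrix,
        # so everything before them belongs to the frozen pretrained vocab.
        embedding.weight.grad[:-num_new_tokens].zero_()

Called where it appears in the loop, after both backward passes and before scaler.unscale_, the zeroed rows contribute nothing to the clipping norm, and scaler.step leaves the pretrained embedding rows unchanged.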