"""Global state and decision helpers for Pyramid Attention Broadcast (PAB).

A single module-level ``PABManager`` (installed with ``set_pab_manager``) decides,
call by call, whether a cross / spatial / temporal attention output or an MLP
output should be "broadcast" (presumably: reused from an earlier step instead of
recomputed -- the reuse itself happens in the callers, not in this module).
The ``if_broadcast_*`` wrappers report "do not broadcast" when no manager is
installed; ``save_mlp_output`` / ``get_mlp_output`` require one.
"""

from videosys.utils.logging import logger

# Process-wide manager; stays None until set_pab_manager() is called.
PAB_MANAGER = None


class PABConfig:
    """Settings (plus MLP output caches) for Pyramid Attention Broadcast.

    For each attention kind (cross / spatial / temporal):

    * ``*_broadcast`` switches broadcasting for that kind on or off.
    * ``*_threshold`` is indexed as ``(low, high)``; broadcasting is only allowed
      when ``low < timestep < high`` (both bounds exclusive).
    * ``*_range`` is the period of the per-call counter: a call whose counter is
      a multiple of the range is never broadcast.

    ``steps`` is the modulus the attention counters wrap around at.

    MLP broadcasting is controlled by ``mlp_broadcast`` and two dicts (spatial /
    temporal) shaped like
    ``{timestep: {"block": <container of block indices>, "skip_count": <int-like>}}``.
    """

    def __init__(
        self,
        steps: int,
        cross_broadcast: bool = False,
        cross_threshold: list = None,
        cross_range: int = None,
        spatial_broadcast: bool = False,
        spatial_threshold: list = None,
        spatial_range: int = None,
        temporal_broadcast: bool = False,
        temporal_threshold: list = None,
        temporal_range: int = None,
        mlp_broadcast: bool = False,
        mlp_spatial_broadcast_config: dict = None,
        mlp_temporal_broadcast_config: dict = None,
    ):
        self.steps = steps

        self.cross_broadcast = cross_broadcast
        self.cross_threshold = cross_threshold
        self.cross_range = cross_range

        self.spatial_broadcast = spatial_broadcast
        self.spatial_threshold = spatial_threshold
        self.spatial_range = spatial_range

        self.temporal_broadcast = temporal_broadcast
        self.temporal_threshold = temporal_threshold
        self.temporal_range = temporal_range

        self.mlp_broadcast = mlp_broadcast
        self.mlp_spatial_broadcast_config = mlp_spatial_broadcast_config
        self.mlp_temporal_broadcast_config = mlp_temporal_broadcast_config

        # Caches of MLP outputs keyed by (timestep, block_idx). They are filled by
        # PABManager.save_skip_output and drained by PABManager.get_mlp_output.
        self.mlp_temporal_outputs = {}
        self.mlp_spatial_outputs = {}


class PABManager:
    """Makes the per-call broadcast decisions described by a ``PABConfig``."""

    def __init__(self, config: PABConfig):
        self.config: PABConfig = config

        init_prompt = f"Init Pyramid Attention Broadcast. steps: {config.steps}."
        init_prompt += f" spatial broadcast: {config.spatial_broadcast}, spatial range: {config.spatial_range}, spatial threshold: {config.spatial_threshold}."
        init_prompt += f" temporal broadcast: {config.temporal_broadcast}, temporal range: {config.temporal_range}, temporal_threshold: {config.temporal_threshold}."
        init_prompt += f" cross broadcast: {config.cross_broadcast}, cross range: {config.cross_range}, cross threshold: {config.cross_threshold}."
        init_prompt += f" mlp broadcast: {config.mlp_broadcast}."
        logger.info(init_prompt)

    def _if_broadcast(self, enabled, threshold, broadcast_range, timestep, count):
        """Shared decision rule for cross / temporal / spatial attention.

        Returns ``(flag, next_count)``. ``flag`` is True only when broadcasting
        is enabled, ``timestep`` is not None, ``count`` is not a multiple of
        ``broadcast_range``, and ``threshold[0] < timestep < threshold[1]``.
        The counter always advances by one, wrapping at ``config.steps``.
        """
        # Short-circuit order matters: range/threshold are only touched when
        # broadcasting is enabled and a timestep was supplied.
        flag = bool(
            enabled
            and timestep is not None
            and count % broadcast_range != 0
            and threshold[0] < timestep < threshold[1]
        )
        return flag, (count + 1) % self.config.steps

    def if_broadcast_cross(self, timestep: int, count: int):
        """Return ``(flag, next_count)`` for cross attention."""
        return self._if_broadcast(
            self.config.cross_broadcast,
            self.config.cross_threshold,
            self.config.cross_range,
            timestep,
            count,
        )

    def if_broadcast_temporal(self, timestep: int, count: int):
        """Return ``(flag, next_count)`` for temporal attention."""
        return self._if_broadcast(
            self.config.temporal_broadcast,
            self.config.temporal_threshold,
            self.config.temporal_range,
            timestep,
            count,
        )

    def if_broadcast_spatial(self, timestep: int, count: int, block_idx: int):
        """Return ``(flag, next_count)`` for spatial attention.

        ``block_idx`` is accepted for interface compatibility but is not used.
        """
        return self._if_broadcast(
            self.config.spatial_broadcast,
            self.config.spatial_threshold,
            self.config.spatial_range,
            timestep,
            count,
        )

    @staticmethod
    def _is_t_in_skip_config(all_timesteps, timestep, config):
        """Find the MLP skip window (if any) that contains ``timestep``.

        Each key of ``config`` that appears in ``all_timesteps`` opens a window
        of ``1 + skip_count`` consecutive entries of ``all_timesteps`` starting
        at that key (truncated at the end of the list).

        Returns:
            ``(True, [window_start, window_end])`` for the first window
            containing ``timestep``, otherwise ``(False, None)``.
        """
        if not config:
            return False, None
        for key in config:
            if key not in all_timesteps:
                continue
            index = all_timesteps.index(key)
            window = all_timesteps[index : index + 1 + int(config[key]["skip_count"])]
            if timestep in window:
                # Use the (possibly truncated) slice's last element rather than
                # indexing index + skip_count, which can run past the list end.
                return True, [window[0], window[-1]]
        return False, None

    def if_skip_mlp(self, timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
        """Decide whether this MLP call is skipped, or must save its output.

        Args:
            timestep: current timestep; None disables skipping.
            count: caller-held MLP counter.
            block_idx: index of the calling block.
            all_timesteps: list of every timestep in the run (``.index`` is used).
            is_temporal: choose the temporal rather than the spatial MLP config.

        Returns:
            ``(flag, count, next_flag, skip_range)``:

            * ``flag`` True means use the cached output for ``skip_range``.
            * ``next_flag`` True means this timestep opens a skip window, so the
              caller should save its output.
            * ``skip_range`` is ``[start, end]`` of the matching window, or None.
        """
        if not self.config.mlp_broadcast:
            # Hand the counter back untouched so the caller's state is preserved.
            return False, count, False, None

        if is_temporal:
            cur_config = self.config.mlp_temporal_broadcast_config
        else:
            cur_config = self.config.mlp_spatial_broadcast_config
        cur_config = cur_config or {}

        is_t_in_skip_config, skip_range = self._is_t_in_skip_config(all_timesteps, timestep, cur_config)

        flag = False
        next_flag = False
        if timestep is not None and timestep in cur_config and block_idx in cur_config[timestep]["block"]:
            # Start of a skip window: compute normally, but flag the output for saving.
            next_flag = True
            count = count + 1
        elif timestep is not None and is_t_in_skip_config and block_idx in cur_config[skip_range[0]]["block"]:
            # Inside a window opened at skip_range[0]: reuse the cached output.
            flag = True
            count = 0
        return flag, count, next_flag, skip_range

    def save_skip_output(self, timestep, block_idx, ff_output, is_temporal=False):
        """Cache ``ff_output`` under ``(timestep, block_idx)``."""
        if is_temporal:
            self.config.mlp_temporal_outputs[(timestep, block_idx)] = ff_output
        else:
            self.config.mlp_spatial_outputs[(timestep, block_idx)] = ff_output

    def get_mlp_output(self, skip_range, timestep, block_idx, is_temporal=False):
        """Return the cached MLP output for the window starting at ``skip_range[0]``.

        The cache entry is removed once ``timestep`` reaches ``skip_range[-1]``.

        Raises:
            ValueError: if nothing was cached for this window and block.
        """
        outputs = self.config.mlp_temporal_outputs if is_temporal else self.config.mlp_spatial_outputs
        cache_key = (skip_range[0], block_idx)
        skip_output = outputs.get(cache_key, None) if outputs is not None else None

        if skip_output is None:
            raise ValueError(
                f"No stored MLP output found | t {timestep} |[{skip_range[0]}, {skip_range[-1]}] | block {block_idx}"
            )

        if timestep == skip_range[-1]:
            # Last timestep of the window: free the cached output.
            del outputs[cache_key]
        return skip_output

    def get_spatial_mlp_outputs(self):
        """Return the spatial MLP output cache dict."""
        return self.config.mlp_spatial_outputs

    def get_temporal_mlp_outputs(self):
        """Return the temporal MLP output cache dict."""
        return self.config.mlp_temporal_outputs


def set_pab_manager(config: PABConfig):
    """Install a new global ``PABManager`` built from ``config``."""
    global PAB_MANAGER
    PAB_MANAGER = PABManager(config)


def enable_pab():
    """Return truthy when a manager exists and any attention broadcast is on.

    NOTE(review): ``mlp_broadcast`` is not consulted here, so an MLP-only
    configuration makes ``if_broadcast_mlp`` report "no skip".
    """
    if PAB_MANAGER is None:
        return False
    return (
        PAB_MANAGER.config.cross_broadcast
        or PAB_MANAGER.config.spatial_broadcast
        or PAB_MANAGER.config.temporal_broadcast
    )


def update_steps(steps: int):
    """Change the counter modulus of the installed manager, if any."""
    if PAB_MANAGER is not None:
        PAB_MANAGER.config.steps = steps


def if_broadcast_cross(timestep: int, count: int):
    """Return ``(flag, next_count)``; ``(False, count)`` when PAB is disabled."""
    if not enable_pab():
        return False, count
    return PAB_MANAGER.if_broadcast_cross(timestep, count)


def if_broadcast_temporal(timestep: int, count: int):
    """Return ``(flag, next_count)``; ``(False, count)`` when PAB is disabled."""
    if not enable_pab():
        return False, count
    return PAB_MANAGER.if_broadcast_temporal(timestep, count)


def if_broadcast_spatial(timestep: int, count: int, block_idx: int):
    """Return ``(flag, next_count)``; ``(False, count)`` when PAB is disabled."""
    if not enable_pab():
        return False, count
    return PAB_MANAGER.if_broadcast_spatial(timestep, count, block_idx)


def if_broadcast_mlp(timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
    """Return ``(flag, count, next_flag, skip_range)`` for an MLP call.

    When PAB is disabled this returns the same 4-tuple shape as the enabled
    path, so callers can always unpack four values.
    """
    if not enable_pab():
        return False, count, False, None
    return PAB_MANAGER.if_skip_mlp(timestep, count, block_idx, all_timesteps, is_temporal)


def save_mlp_output(timestep: int, block_idx: int, ff_output, is_temporal=False):
    """Cache an MLP output on the installed manager (one must be installed)."""
    return PAB_MANAGER.save_skip_output(timestep, block_idx, ff_output, is_temporal)


def get_mlp_output(skip_range, timestep, block_idx: int, is_temporal=False):
    """Fetch a cached MLP output from the installed manager (one must be installed)."""
    return PAB_MANAGER.get_mlp_output(skip_range, timestep, block_idx, is_temporal)