zhiweili commited on
Commit
ec4df84
1 Parent(s): fd86234

add missing file

Browse files
checkpoints/selfie_multiclass_256x256.tflite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c6748b1253a99067ef71f7e26ca71096cd449baefa8f101900ea23016507e0e0
3
+ size 16371837
config.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ml_collections import config_dict
2
+ import yaml
3
+ from diffusers.schedulers import (
4
+ DDIMScheduler,
5
+ EulerAncestralDiscreteScheduler,
6
+ EulerDiscreteScheduler,
7
+ DDPMScheduler,
8
+ )
9
+ from inversion_utils import (
10
+ deterministic_ddim_step,
11
+ deterministic_ddpm_step,
12
+ deterministic_euler_step,
13
+ deterministic_non_ancestral_euler_step,
14
+ )
15
+
16
+ BREAKDOWNS = ["x_t_c_hat", "x_t_hat_c", "no_breakdown", "x_t_hat_c_with_zeros"]
17
+ SCHEDULERS = ["ddpm", "ddim", "euler", "euler_non_ancestral"]
18
+ MODELS = [
19
+ "stabilityai/sdxl-turbo",
20
+ "stabilityai/stable-diffusion-xl-base-1.0",
21
+ "CompVis/stable-diffusion-v1-4",
22
+ ]
23
+
24
+ def get_num_steps_actual(cfg):
25
+ return (
26
+ cfg.num_steps_inversion
27
+ - cfg.step_start
28
+ + (1 if cfg.clean_step_timestep > 0 else 0)
29
+ if cfg.timesteps is None
30
+ else len(cfg.timesteps) + (1 if cfg.clean_step_timestep > 0 else 0)
31
+ )
32
+
33
+
34
+ def get_config(args):
35
+ if args.config_from_file and args.config_from_file != "":
36
+ with open(args.config_from_file, "r") as f:
37
+ cfg = config_dict.ConfigDict(yaml.safe_load(f))
38
+
39
+ num_steps_actual = get_num_steps_actual(cfg)
40
+
41
+ else:
42
+ cfg = config_dict.ConfigDict()
43
+
44
+ cfg.seed = 2
45
+ cfg.self_r = 0.5
46
+ cfg.cross_r = 0.9
47
+ cfg.eta = 1
48
+ cfg.scheduler_type = SCHEDULERS[0]
49
+
50
+ cfg.num_steps_inversion = 50 # timesteps: 999, 799, 599, 399, 199
51
+ cfg.step_start = 20
52
+ cfg.timesteps = None
53
+ cfg.noise_timesteps = None
54
+ num_steps_actual = get_num_steps_actual(cfg)
55
+ cfg.ws1 = [2] * num_steps_actual
56
+ cfg.ws2 = [1] * num_steps_actual
57
+ cfg.real_cfg_scale = 0
58
+ cfg.real_cfg_scale_save = 0
59
+ cfg.breakdown = BREAKDOWNS[1]
60
+ cfg.noise_shift_delta = 1
61
+ cfg.max_norm_zs = [-1] * (num_steps_actual - 1) + [15.5]
62
+
63
+ cfg.clean_step_timestep = 0
64
+
65
+ cfg.model = MODELS[1]
66
+
67
+ if cfg.scheduler_type == "ddim":
68
+ cfg.scheduler_class = DDIMScheduler
69
+ cfg.step_function = deterministic_ddim_step
70
+ elif cfg.scheduler_type == "ddpm":
71
+ cfg.scheduler_class = DDPMScheduler
72
+ cfg.step_function = deterministic_ddpm_step
73
+ elif cfg.scheduler_type == "euler":
74
+ cfg.scheduler_class = EulerAncestralDiscreteScheduler
75
+ cfg.step_function = deterministic_euler_step
76
+ elif cfg.scheduler_type == "euler_non_ancestral":
77
+ cfg.scheduler_class = EulerDiscreteScheduler
78
+ cfg.step_function = deterministic_non_ancestral_euler_step
79
+ else:
80
+ raise ValueError(f"Unknown scheduler type: {cfg.scheduler_type}")
81
+
82
+ with cfg.ignore_type():
83
+ if isinstance(cfg.max_norm_zs, (int, float)):
84
+ cfg.max_norm_zs = [cfg.max_norm_zs] * num_steps_actual
85
+
86
+ if isinstance(cfg.ws1, (int, float)):
87
+ cfg.ws1 = [cfg.ws1] * num_steps_actual
88
+
89
+ if isinstance(cfg.ws2, (int, float)):
90
+ cfg.ws2 = [cfg.ws2] * num_steps_actual
91
+
92
+ if not hasattr(cfg, "update_eta"):
93
+ cfg.update_eta = False
94
+
95
+ if not hasattr(cfg, "save_timesteps"):
96
+ cfg.save_timesteps = None
97
+
98
+ if not hasattr(cfg, "scheduler_timesteps"):
99
+ cfg.scheduler_timesteps = None
100
+
101
+ assert (
102
+ cfg.scheduler_type == "ddpm" or cfg.timesteps is None
103
+ ), "timesteps must be None for ddim/euler"
104
+
105
+ cfg.max_norm_zs = [-1] * (num_steps_actual - 1) + [15.5]
106
+ assert (
107
+ len(cfg.max_norm_zs) == num_steps_actual
108
+ ), f"len(cfg.max_norm_zs) ({len(cfg.max_norm_zs)}) != num_steps_actual ({num_steps_actual})"
109
+
110
+ assert (
111
+ len(cfg.ws1) == num_steps_actual
112
+ ), f"len(cfg.ws1) ({len(cfg.ws1)}) != num_steps_actual ({num_steps_actual})"
113
+
114
+ assert (
115
+ len(cfg.ws2) == num_steps_actual
116
+ ), f"len(cfg.ws2) ({len(cfg.ws2)}) != num_steps_actual ({num_steps_actual})"
117
+
118
+ assert cfg.noise_timesteps is None or len(cfg.noise_timesteps) == (
119
+ num_steps_actual - (1 if cfg.clean_step_timestep > 0 else 0)
120
+ ), f"len(cfg.noise_timesteps) ({len(cfg.noise_timesteps)}) != num_steps_actual ({num_steps_actual})"
121
+
122
+ assert cfg.save_timesteps is None or len(cfg.save_timesteps) == (
123
+ num_steps_actual - (1 if cfg.clean_step_timestep > 0 else 0)
124
+ ), f"len(cfg.save_timesteps) ({len(cfg.save_timesteps)}) != num_steps_actual ({num_steps_actual})"
125
+
126
+ return cfg
127
+
128
+
129
+ def get_config_name(config, args):
130
+ if args.folder_name is not None and args.folder_name != "":
131
+ return args.folder_name
132
+ timesteps_str = (
133
+ f"step_start {config.step_start}"
134
+ if config.timesteps is None
135
+ else f"timesteps {config.timesteps}"
136
+ )
137
+ return f"""\
138
+ ws1 {config.ws1[0]} ws2 {config.ws2[0]} real_cfg_scale {config.real_cfg_scale} {timesteps_str} \
139
+ real_cfg_scale_save {config.real_cfg_scale_save} seed {config.seed} max_norm_zs {config.max_norm_zs[-1]} noise_shift_delta {config.noise_shift_delta} \
140
+ scheduler_type {config.scheduler_type} fp16 {args.fp16}\
141
+ """
inversion_utils.py ADDED
@@ -0,0 +1,794 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import PIL
4
+
5
+ from typing import List, Optional, Union
6
+ from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
7
+ from PIL import Image
8
+ from diffusers.utils import logging
9
+
10
+ VECTOR_DATA_FOLDER = "vector_data"
11
+ VECTOR_DATA_DICT = "vector_data"
12
+
13
+ logger = logging.get_logger(__name__)
14
+
15
+ def get_ddpm_inversion_scheduler(
16
+ scheduler,
17
+ step_function,
18
+ config,
19
+ timesteps,
20
+ save_timesteps,
21
+ latents,
22
+ x_ts,
23
+ x_ts_c_hat,
24
+ save_intermediate_results,
25
+ pipe,
26
+ x_0,
27
+ v1s_images,
28
+ v2s_images,
29
+ deltas_images,
30
+ v1_x0s,
31
+ v2_x0s,
32
+ deltas_x0s,
33
+ folder_name,
34
+ image_name,
35
+ time_measure_n,
36
+ ):
37
+ def step(
38
+ model_output: torch.FloatTensor,
39
+ timestep: int,
40
+ sample: torch.FloatTensor,
41
+ eta: float = 0.0,
42
+ use_clipped_model_output: bool = False,
43
+ generator=None,
44
+ variance_noise: Optional[torch.FloatTensor] = None,
45
+ return_dict: bool = True,
46
+ ):
47
+ # if scheduler.is_save:
48
+ # start = timer()
49
+ res_inv = step_save_latents(
50
+ scheduler,
51
+ model_output[:1, :, :, :],
52
+ timestep,
53
+ sample[:1, :, :, :],
54
+ eta,
55
+ use_clipped_model_output,
56
+ generator,
57
+ variance_noise,
58
+ return_dict,
59
+ )
60
+ # end = timer()
61
+ # print(f"Run Time Inv: {end - start}")
62
+
63
+ res_inf = step_use_latents(
64
+ scheduler,
65
+ model_output[1:, :, :, :],
66
+ timestep,
67
+ sample[1:, :, :, :],
68
+ eta,
69
+ use_clipped_model_output,
70
+ generator,
71
+ variance_noise,
72
+ return_dict,
73
+ )
74
+ # res = res_inv
75
+ res = (torch.cat((res_inv[0], res_inf[0]), dim=0),)
76
+ return res
77
+ # return res
78
+
79
+ scheduler.step_function = step_function
80
+ scheduler.is_save = True
81
+ scheduler._timesteps = timesteps
82
+ scheduler._save_timesteps = save_timesteps if save_timesteps else timesteps
83
+ scheduler._config = config
84
+ scheduler.latents = latents
85
+ scheduler.x_ts = x_ts
86
+ scheduler.x_ts_c_hat = x_ts_c_hat
87
+ scheduler.step = step
88
+ scheduler.save_intermediate_results = save_intermediate_results
89
+ scheduler.pipe = pipe
90
+ scheduler.v1s_images = v1s_images
91
+ scheduler.v2s_images = v2s_images
92
+ scheduler.deltas_images = deltas_images
93
+ scheduler.v1_x0s = v1_x0s
94
+ scheduler.v2_x0s = v2_x0s
95
+ scheduler.deltas_x0s = deltas_x0s
96
+ scheduler.clean_step_run = False
97
+ scheduler.x_0s = create_xts(
98
+ config.noise_shift_delta,
99
+ config.noise_timesteps,
100
+ config.clean_step_timestep,
101
+ None,
102
+ pipe.scheduler,
103
+ timesteps,
104
+ x_0,
105
+ no_add_noise=True,
106
+ )
107
+ scheduler.folder_name = folder_name
108
+ scheduler.image_name = image_name
109
+ scheduler.p_to_p = False
110
+ scheduler.p_to_p_replace = False
111
+ scheduler.time_measure_n = time_measure_n
112
+ return scheduler
113
+
114
+ def step_save_latents(
115
+ self,
116
+ model_output: torch.FloatTensor,
117
+ timestep: int,
118
+ sample: torch.FloatTensor,
119
+ eta: float = 0.0,
120
+ use_clipped_model_output: bool = False,
121
+ generator=None,
122
+ variance_noise: Optional[torch.FloatTensor] = None,
123
+ return_dict: bool = True,
124
+ ):
125
+ # print(self._save_timesteps)
126
+ # timestep_index = map_timpstep_to_index[timestep]
127
+ # timestep_index = ((self._save_timesteps == timestep).nonzero(as_tuple=True)[0]).item()
128
+ timestep_index = self._save_timesteps.index(timestep) if not self.clean_step_run else -1
129
+ next_timestep_index = timestep_index + 1 if not self.clean_step_run else -1
130
+ u_hat_t = self.step_function(
131
+ model_output=model_output,
132
+ timestep=timestep,
133
+ sample=sample,
134
+ eta=eta,
135
+ use_clipped_model_output=use_clipped_model_output,
136
+ generator=generator,
137
+ variance_noise=variance_noise,
138
+ return_dict=False,
139
+ scheduler=self,
140
+ )
141
+
142
+ x_t_minus_1 = self.x_ts[next_timestep_index]
143
+ self.x_ts_c_hat.append(u_hat_t)
144
+
145
+ z_t = x_t_minus_1 - u_hat_t
146
+ self.latents.append(z_t)
147
+ z_t, _ = normalize(z_t, timestep_index, self._config.max_norm_zs)
148
+
149
+ x_t_minus_1_predicted = u_hat_t + z_t
150
+
151
+ if not return_dict:
152
+ return (x_t_minus_1_predicted,)
153
+
154
+ return DDIMSchedulerOutput(prev_sample=x_t_minus_1, pred_original_sample=None)
155
+
156
+ def step_use_latents(
157
+ self,
158
+ model_output: torch.FloatTensor,
159
+ timestep: int,
160
+ sample: torch.FloatTensor,
161
+ eta: float = 0.0,
162
+ use_clipped_model_output: bool = False,
163
+ generator=None,
164
+ variance_noise: Optional[torch.FloatTensor] = None,
165
+ return_dict: bool = True,
166
+ ):
167
+ # timestep_index = ((self._save_timesteps == timestep).nonzero(as_tuple=True)[0]).item()
168
+ timestep_index = self._timesteps.index(timestep) if not self.clean_step_run else -1
169
+ next_timestep_index = (
170
+ timestep_index + 1 if not self.clean_step_run else -1
171
+ )
172
+ z_t = self.latents[next_timestep_index] # + 1 because latents[0] is X_T
173
+
174
+ _, normalize_coefficient = normalize(
175
+ z_t[0] if self._config.breakdown == "x_t_hat_c_with_zeros" else z_t,
176
+ timestep_index,
177
+ self._config.max_norm_zs,
178
+ )
179
+
180
+ if normalize_coefficient == 0:
181
+ eta = 0
182
+
183
+ # eta = normalize_coefficient
184
+
185
+ x_t_hat_c_hat = self.step_function(
186
+ model_output=model_output,
187
+ timestep=timestep,
188
+ sample=sample,
189
+ eta=eta,
190
+ use_clipped_model_output=use_clipped_model_output,
191
+ generator=generator,
192
+ variance_noise=variance_noise,
193
+ return_dict=False,
194
+ scheduler=self,
195
+ )
196
+
197
+ w1 = self._config.ws1[timestep_index]
198
+ w2 = self._config.ws2[timestep_index]
199
+
200
+ x_t_minus_1_exact = self.x_ts[next_timestep_index]
201
+ x_t_minus_1_exact = x_t_minus_1_exact.expand_as(x_t_hat_c_hat)
202
+
203
+ x_t_c_hat: torch.Tensor = self.x_ts_c_hat[next_timestep_index]
204
+ if self._config.breakdown == "x_t_c_hat":
205
+ raise NotImplementedError("breakdown x_t_c_hat not implemented yet")
206
+
207
+ # x_t_c_hat = x_t_c_hat.expand_as(x_t_hat_c_hat)
208
+ x_t_c = x_t_c_hat[0].expand_as(x_t_hat_c_hat)
209
+
210
+ # if self._config.breakdown == "x_t_c_hat":
211
+ # v1 = x_t_hat_c_hat - x_t_c_hat
212
+ # v2 = x_t_c_hat - x_t_c
213
+ if (
214
+ self._config.breakdown == "x_t_hat_c"
215
+ or self._config.breakdown == "x_t_hat_c_with_zeros"
216
+ ):
217
+ zero_index_reconstruction = 1 if not self.time_measure_n else 0
218
+ edit_prompts_num = (
219
+ (model_output.size(0) - zero_index_reconstruction) // 3
220
+ if self._config.breakdown == "x_t_hat_c_with_zeros" and not self.p_to_p
221
+ else (model_output.size(0) - zero_index_reconstruction) // 2
222
+ )
223
+ x_t_hat_c_indices = (zero_index_reconstruction, edit_prompts_num + zero_index_reconstruction)
224
+ edit_images_indices = (
225
+ edit_prompts_num + zero_index_reconstruction,
226
+ (
227
+ model_output.size(0)
228
+ if self._config.breakdown == "x_t_hat_c"
229
+ else zero_index_reconstruction + 2 * edit_prompts_num
230
+ ),
231
+ )
232
+ x_t_hat_c = torch.zeros_like(x_t_hat_c_hat)
233
+ x_t_hat_c[edit_images_indices[0] : edit_images_indices[1]] = x_t_hat_c_hat[
234
+ x_t_hat_c_indices[0] : x_t_hat_c_indices[1]
235
+ ]
236
+ v1 = x_t_hat_c_hat - x_t_hat_c
237
+ v2 = x_t_hat_c - normalize_coefficient * x_t_c
238
+ if self._config.breakdown == "x_t_hat_c_with_zeros" and not self.p_to_p:
239
+ path = os.path.join(
240
+ self.folder_name,
241
+ VECTOR_DATA_FOLDER,
242
+ self.image_name,
243
+ )
244
+ if not hasattr(self, VECTOR_DATA_DICT):
245
+ os.makedirs(path, exist_ok=True)
246
+ self.vector_data = dict()
247
+
248
+ x_t_0 = x_t_c_hat[1]
249
+ empty_prompt_indices = (1 + 2 * edit_prompts_num, 1 + 3 * edit_prompts_num)
250
+ x_t_hat_0 = x_t_hat_c_hat[empty_prompt_indices[0] : empty_prompt_indices[1]]
251
+
252
+ self.vector_data[timestep.item()] = dict()
253
+ self.vector_data[timestep.item()]["x_t_hat_c"] = x_t_hat_c[
254
+ edit_images_indices[0] : edit_images_indices[1]
255
+ ]
256
+ self.vector_data[timestep.item()]["x_t_hat_0"] = x_t_hat_0
257
+ self.vector_data[timestep.item()]["x_t_c"] = x_t_c[0].expand_as(x_t_hat_0)
258
+ self.vector_data[timestep.item()]["x_t_0"] = x_t_0.expand_as(x_t_hat_0)
259
+ self.vector_data[timestep.item()]["x_t_hat_c_hat"] = x_t_hat_c_hat[
260
+ edit_images_indices[0] : edit_images_indices[1]
261
+ ]
262
+ self.vector_data[timestep.item()]["x_t_minus_1_noisy"] = x_t_minus_1_exact[
263
+ 0
264
+ ].expand_as(x_t_hat_0)
265
+ self.vector_data[timestep.item()]["x_t_minus_1_clean"] = self.x_0s[
266
+ next_timestep_index
267
+ ].expand_as(x_t_hat_0)
268
+
269
+ else: # no breakdown
270
+ v1 = x_t_hat_c_hat - normalize_coefficient * x_t_c
271
+ v2 = 0
272
+
273
+ if self.save_intermediate_results and not self.p_to_p:
274
+ delta = v1 + v2
275
+ v1_plus_x0 = self.x_0s[next_timestep_index] + v1
276
+ v2_plus_x0 = self.x_0s[next_timestep_index] + v2
277
+ delta_plus_x0 = self.x_0s[next_timestep_index] + delta
278
+
279
+ v1_images = decode_latents(v1, self.pipe)
280
+ self.v1s_images.append(v1_images)
281
+ v2_images = (
282
+ decode_latents(v2, self.pipe)
283
+ if self._config.breakdown != "no_breakdown"
284
+ else [PIL.Image.new("RGB", (1, 1))]
285
+ )
286
+ self.v2s_images.append(v2_images)
287
+ delta_images = decode_latents(delta, self.pipe)
288
+ self.deltas_images.append(delta_images)
289
+ v1_plus_x0_images = decode_latents(v1_plus_x0, self.pipe)
290
+ self.v1_x0s.append(v1_plus_x0_images)
291
+ v2_plus_x0_images = (
292
+ decode_latents(v2_plus_x0, self.pipe)
293
+ if self._config.breakdown != "no_breakdown"
294
+ else [PIL.Image.new("RGB", (1, 1))]
295
+ )
296
+ self.v2_x0s.append(v2_plus_x0_images)
297
+ delta_plus_x0_images = decode_latents(delta_plus_x0, self.pipe)
298
+ self.deltas_x0s.append(delta_plus_x0_images)
299
+
300
+ # print(f"v1 norm: {torch.norm(v1, dim=0).mean()}")
301
+ # if self._config.breakdown != "no_breakdown":
302
+ # print(f"v2 norm: {torch.norm(v2, dim=0).mean()}")
303
+ # print(f"v sum norm: {torch.norm(v1 + v2, dim=0).mean()}")
304
+
305
+ x_t_minus_1 = normalize_coefficient * x_t_minus_1_exact + w1 * v1 + w2 * v2
306
+
307
+ if (
308
+ self._config.breakdown == "x_t_hat_c"
309
+ or self._config.breakdown == "x_t_hat_c_with_zeros"
310
+ ):
311
+ x_t_minus_1[x_t_hat_c_indices[0] : x_t_hat_c_indices[1]] = x_t_minus_1[
312
+ edit_images_indices[0] : edit_images_indices[1]
313
+ ] # update x_t_hat_c to be x_t_hat_c_hat
314
+ if self._config.breakdown == "x_t_hat_c_with_zeros" and not self.p_to_p:
315
+ x_t_minus_1[empty_prompt_indices[0] : empty_prompt_indices[1]] = (
316
+ x_t_minus_1[edit_images_indices[0] : edit_images_indices[1]]
317
+ )
318
+ self.vector_data[timestep.item()]["x_t_minus_1_edited"] = x_t_minus_1[
319
+ edit_images_indices[0] : edit_images_indices[1]
320
+ ]
321
+ if timestep == self._timesteps[-1]:
322
+ torch.save(
323
+ self.vector_data,
324
+ os.path.join(
325
+ path,
326
+ f"{VECTOR_DATA_DICT}.pt",
327
+ ),
328
+ )
329
+ # p_to_p_force_perfect_reconstruction
330
+ if not self.time_measure_n:
331
+ x_t_minus_1[0] = x_t_minus_1_exact[0]
332
+
333
+ if not return_dict:
334
+ return (x_t_minus_1,)
335
+
336
+ return DDIMSchedulerOutput(
337
+ prev_sample=x_t_minus_1,
338
+ pred_original_sample=None,
339
+ )
340
+
341
+ def create_xts(
342
+ noise_shift_delta,
343
+ noise_timesteps,
344
+ clean_step_timestep,
345
+ generator,
346
+ scheduler,
347
+ timesteps,
348
+ x_0,
349
+ no_add_noise=False,
350
+ ):
351
+ if noise_timesteps is None:
352
+ noising_delta = noise_shift_delta * (timesteps[0] - timesteps[1])
353
+ noise_timesteps = [timestep - int(noising_delta) for timestep in timesteps]
354
+
355
+ first_x_0_idx = len(noise_timesteps)
356
+ for i in range(len(noise_timesteps)):
357
+ if noise_timesteps[i] <= 0:
358
+ first_x_0_idx = i
359
+ break
360
+
361
+ noise_timesteps = noise_timesteps[:first_x_0_idx]
362
+
363
+ x_0_expanded = x_0.expand(len(noise_timesteps), -1, -1, -1)
364
+ noise = (
365
+ torch.randn(x_0_expanded.size(), generator=generator, device="cpu").to(
366
+ x_0.device
367
+ )
368
+ if not no_add_noise
369
+ else torch.zeros_like(x_0_expanded)
370
+ )
371
+ x_ts = scheduler.add_noise(
372
+ x_0_expanded,
373
+ noise,
374
+ torch.IntTensor(noise_timesteps),
375
+ )
376
+ x_ts = [t.unsqueeze(dim=0) for t in list(x_ts)]
377
+ x_ts += [x_0] * (len(timesteps) - first_x_0_idx)
378
+ x_ts += [x_0]
379
+ if clean_step_timestep > 0:
380
+ x_ts += [x_0]
381
+ return x_ts
382
+
383
+ def normalize(
384
+ z_t,
385
+ i,
386
+ max_norm_zs,
387
+ ):
388
+ max_norm = max_norm_zs[i]
389
+ if max_norm < 0:
390
+ return z_t, 1
391
+
392
+ norm = torch.norm(z_t)
393
+ if norm < max_norm:
394
+ return z_t, 1
395
+
396
+ coeff = max_norm / norm
397
+ z_t = z_t * coeff
398
+ return z_t, coeff
399
+
400
+ def decode_latents(latent, pipe):
401
+ latent_img = pipe.vae.decode(
402
+ latent / pipe.vae.config.scaling_factor, return_dict=False
403
+ )[0]
404
+ return pipe.image_processor.postprocess(latent_img, output_type="pil")
405
+
406
+ def deterministic_ddim_step(
407
+ model_output: torch.FloatTensor,
408
+ timestep: int,
409
+ sample: torch.FloatTensor,
410
+ eta: float = 0.0,
411
+ use_clipped_model_output: bool = False,
412
+ generator=None,
413
+ variance_noise: Optional[torch.FloatTensor] = None,
414
+ return_dict: bool = True,
415
+ scheduler=None,
416
+ ):
417
+
418
+ if scheduler.num_inference_steps is None:
419
+ raise ValueError(
420
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
421
+ )
422
+
423
+ prev_timestep = (
424
+ timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
425
+ )
426
+
427
+ # 2. compute alphas, betas
428
+ alpha_prod_t = scheduler.alphas_cumprod[timestep]
429
+ alpha_prod_t_prev = (
430
+ scheduler.alphas_cumprod[prev_timestep]
431
+ if prev_timestep >= 0
432
+ else scheduler.final_alpha_cumprod
433
+ )
434
+
435
+ beta_prod_t = 1 - alpha_prod_t
436
+
437
+ if scheduler.config.prediction_type == "epsilon":
438
+ pred_original_sample = (
439
+ sample - beta_prod_t ** (0.5) * model_output
440
+ ) / alpha_prod_t ** (0.5)
441
+ pred_epsilon = model_output
442
+ elif scheduler.config.prediction_type == "sample":
443
+ pred_original_sample = model_output
444
+ pred_epsilon = (
445
+ sample - alpha_prod_t ** (0.5) * pred_original_sample
446
+ ) / beta_prod_t ** (0.5)
447
+ elif scheduler.config.prediction_type == "v_prediction":
448
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (
449
+ beta_prod_t**0.5
450
+ ) * model_output
451
+ pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
452
+ else:
453
+ raise ValueError(
454
+ f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample`, or"
455
+ " `v_prediction`"
456
+ )
457
+
458
+ # 4. Clip or threshold "predicted x_0"
459
+ if scheduler.config.thresholding:
460
+ pred_original_sample = scheduler._threshold_sample(pred_original_sample)
461
+ elif scheduler.config.clip_sample:
462
+ pred_original_sample = pred_original_sample.clamp(
463
+ -scheduler.config.clip_sample_range,
464
+ scheduler.config.clip_sample_range,
465
+ )
466
+
467
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
468
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
469
+ variance = scheduler._get_variance(timestep, prev_timestep)
470
+ std_dev_t = eta * variance ** (0.5)
471
+
472
+ if use_clipped_model_output:
473
+ # the pred_epsilon is always re-derived from the clipped x_0 in Glide
474
+ pred_epsilon = (
475
+ sample - alpha_prod_t ** (0.5) * pred_original_sample
476
+ ) / beta_prod_t ** (0.5)
477
+
478
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
479
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (
480
+ 0.5
481
+ ) * pred_epsilon
482
+
483
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
484
+ prev_sample = (
485
+ alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
486
+ )
487
+ return prev_sample
488
+
489
+
490
+ def deterministic_euler_step(
491
+ model_output: torch.FloatTensor,
492
+ timestep: Union[float, torch.FloatTensor],
493
+ sample: torch.FloatTensor,
494
+ eta,
495
+ use_clipped_model_output,
496
+ generator,
497
+ variance_noise,
498
+ return_dict,
499
+ scheduler,
500
+ ):
501
+ """
502
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
503
+ process from the learned model outputs (most often the predicted noise).
504
+
505
+ Args:
506
+ model_output (`torch.FloatTensor`):
507
+ The direct output from learned diffusion model.
508
+ timestep (`float`):
509
+ The current discrete timestep in the diffusion chain.
510
+ sample (`torch.FloatTensor`):
511
+ A current instance of a sample created by the diffusion process.
512
+ generator (`torch.Generator`, *optional*):
513
+ A random number generator.
514
+ return_dict (`bool`):
515
+ Whether or not to return a
516
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
517
+
518
+ Returns:
519
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
520
+ If return_dict is `True`,
521
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
522
+ otherwise a tuple is returned where the first element is the sample tensor.
523
+
524
+ """
525
+
526
+ if (
527
+ isinstance(timestep, int)
528
+ or isinstance(timestep, torch.IntTensor)
529
+ or isinstance(timestep, torch.LongTensor)
530
+ ):
531
+ raise ValueError(
532
+ (
533
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
534
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
535
+ " one of the `scheduler.timesteps` as a timestep."
536
+ ),
537
+ )
538
+
539
+ if scheduler.step_index is None:
540
+ scheduler._init_step_index(timestep)
541
+
542
+ sigma = scheduler.sigmas[scheduler.step_index]
543
+
544
+ # Upcast to avoid precision issues when computing prev_sample
545
+ sample = sample.to(torch.float32)
546
+
547
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
548
+ if scheduler.config.prediction_type == "epsilon":
549
+ pred_original_sample = sample - sigma * model_output
550
+ elif scheduler.config.prediction_type == "v_prediction":
551
+ # * c_out + input * c_skip
552
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
553
+ sample / (sigma**2 + 1)
554
+ )
555
+ elif scheduler.config.prediction_type == "sample":
556
+ raise NotImplementedError("prediction_type not implemented yet: sample")
557
+ else:
558
+ raise ValueError(
559
+ f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
560
+ )
561
+
562
+ sigma_from = scheduler.sigmas[scheduler.step_index]
563
+ sigma_to = scheduler.sigmas[scheduler.step_index + 1]
564
+ sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
565
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
566
+
567
+ # 2. Convert to an ODE derivative
568
+ derivative = (sample - pred_original_sample) / sigma
569
+
570
+ dt = sigma_down - sigma
571
+
572
+ prev_sample = sample + derivative * dt
573
+
574
+ # Cast sample back to model compatible dtype
575
+ prev_sample = prev_sample.to(model_output.dtype)
576
+
577
+ # upon completion increase step index by one
578
+ scheduler._step_index += 1
579
+
580
+ return prev_sample
581
+
582
+
583
+ def deterministic_non_ancestral_euler_step(
584
+ model_output: torch.FloatTensor,
585
+ timestep: Union[float, torch.FloatTensor],
586
+ sample: torch.FloatTensor,
587
+ eta: float = 0.0,
588
+ use_clipped_model_output: bool = False,
589
+ s_churn: float = 0.0,
590
+ s_tmin: float = 0.0,
591
+ s_tmax: float = float("inf"),
592
+ s_noise: float = 1.0,
593
+ generator: Optional[torch.Generator] = None,
594
+ variance_noise: Optional[torch.FloatTensor] = None,
595
+ return_dict: bool = True,
596
+ scheduler=None,
597
+ ):
598
+ """
599
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
600
+ process from the learned model outputs (most often the predicted noise).
601
+
602
+ Args:
603
+ model_output (`torch.FloatTensor`):
604
+ The direct output from learned diffusion model.
605
+ timestep (`float`):
606
+ The current discrete timestep in the diffusion chain.
607
+ sample (`torch.FloatTensor`):
608
+ A current instance of a sample created by the diffusion process.
609
+ s_churn (`float`):
610
+ s_tmin (`float`):
611
+ s_tmax (`float`):
612
+ s_noise (`float`, defaults to 1.0):
613
+ Scaling factor for noise added to the sample.
614
+ generator (`torch.Generator`, *optional*):
615
+ A random number generator.
616
+ return_dict (`bool`):
617
+ Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
618
+ tuple.
619
+
620
+ Returns:
621
+ [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
622
+ If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
623
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
624
+ """
625
+
626
+ if (
627
+ isinstance(timestep, int)
628
+ or isinstance(timestep, torch.IntTensor)
629
+ or isinstance(timestep, torch.LongTensor)
630
+ ):
631
+ raise ValueError(
632
+ (
633
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
634
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
635
+ " one of the `scheduler.timesteps` as a timestep."
636
+ ),
637
+ )
638
+
639
+ if not scheduler.is_scale_input_called:
640
+ logger.warning(
641
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
642
+ "See `StableDiffusionPipeline` for a usage example."
643
+ )
644
+
645
+ if scheduler.step_index is None:
646
+ scheduler._init_step_index(timestep)
647
+
648
+ # Upcast to avoid precision issues when computing prev_sample
649
+ sample = sample.to(torch.float32)
650
+
651
+ sigma = scheduler.sigmas[scheduler.step_index]
652
+
653
+ gamma = (
654
+ min(s_churn / (len(scheduler.sigmas) - 1), 2**0.5 - 1)
655
+ if s_tmin <= sigma <= s_tmax
656
+ else 0.0
657
+ )
658
+
659
+ sigma_hat = sigma * (gamma + 1)
660
+
661
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
662
+ # NOTE: "original_sample" should not be an expected prediction_type but is left in for
663
+ # backwards compatibility
664
+ if (
665
+ scheduler.config.prediction_type == "original_sample"
666
+ or scheduler.config.prediction_type == "sample"
667
+ ):
668
+ pred_original_sample = model_output
669
+ elif scheduler.config.prediction_type == "epsilon":
670
+ pred_original_sample = sample - sigma_hat * model_output
671
+ elif scheduler.config.prediction_type == "v_prediction":
672
+ # denoised = model_output * c_out + input * c_skip
673
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
674
+ sample / (sigma**2 + 1)
675
+ )
676
+ else:
677
+ raise ValueError(
678
+ f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
679
+ )
680
+
681
+ # 2. Convert to an ODE derivative
682
+ derivative = (sample - pred_original_sample) / sigma_hat
683
+
684
+ dt = scheduler.sigmas[scheduler.step_index + 1] - sigma_hat
685
+
686
+ prev_sample = sample + derivative * dt
687
+
688
+ # Cast sample back to model compatible dtype
689
+ prev_sample = prev_sample.to(model_output.dtype)
690
+
691
+ # upon completion increase step index by one
692
+ scheduler._step_index += 1
693
+
694
+ return prev_sample
695
+
696
+
697
+ def deterministic_ddpm_step(
698
+ model_output: torch.FloatTensor,
699
+ timestep: Union[float, torch.FloatTensor],
700
+ sample: torch.FloatTensor,
701
+ eta,
702
+ use_clipped_model_output,
703
+ generator,
704
+ variance_noise,
705
+ return_dict,
706
+ scheduler,
707
+ ):
708
+ """
709
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
710
+ process from the learned model outputs (most often the predicted noise).
711
+
712
+ Args:
713
+ model_output (`torch.FloatTensor`):
714
+ The direct output from learned diffusion model.
715
+ timestep (`float`):
716
+ The current discrete timestep in the diffusion chain.
717
+ sample (`torch.FloatTensor`):
718
+ A current instance of a sample created by the diffusion process.
719
+ generator (`torch.Generator`, *optional*):
720
+ A random number generator.
721
+ return_dict (`bool`, *optional*, defaults to `True`):
722
+ Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`.
723
+
724
+ Returns:
725
+ [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`:
726
+ If return_dict is `True`, [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] is returned, otherwise a
727
+ tuple is returned where the first element is the sample tensor.
728
+
729
+ """
730
+ t = timestep
731
+
732
+ prev_t = scheduler.previous_timestep(t)
733
+
734
+ if model_output.shape[1] == sample.shape[1] * 2 and scheduler.variance_type in [
735
+ "learned",
736
+ "learned_range",
737
+ ]:
738
+ model_output, predicted_variance = torch.split(
739
+ model_output, sample.shape[1], dim=1
740
+ )
741
+ else:
742
+ predicted_variance = None
743
+
744
+ # 1. compute alphas, betas
745
+ alpha_prod_t = scheduler.alphas_cumprod[t]
746
+ alpha_prod_t_prev = (
747
+ scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else scheduler.one
748
+ )
749
+ beta_prod_t = 1 - alpha_prod_t
750
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
751
+ current_alpha_t = alpha_prod_t / alpha_prod_t_prev
752
+ current_beta_t = 1 - current_alpha_t
753
+
754
+ # 2. compute predicted original sample from predicted noise also called
755
+ # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
756
+ if scheduler.config.prediction_type == "epsilon":
757
+ pred_original_sample = (
758
+ sample - beta_prod_t ** (0.5) * model_output
759
+ ) / alpha_prod_t ** (0.5)
760
+ elif scheduler.config.prediction_type == "sample":
761
+ pred_original_sample = model_output
762
+ elif scheduler.config.prediction_type == "v_prediction":
763
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (
764
+ beta_prod_t**0.5
765
+ ) * model_output
766
+ else:
767
+ raise ValueError(
768
+ f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample` or"
769
+ " `v_prediction` for the DDPMScheduler."
770
+ )
771
+
772
+ # 3. Clip or threshold "predicted x_0"
773
+ if scheduler.config.thresholding:
774
+ pred_original_sample = scheduler._threshold_sample(pred_original_sample)
775
+ elif scheduler.config.clip_sample:
776
+ pred_original_sample = pred_original_sample.clamp(
777
+ -scheduler.config.clip_sample_range, scheduler.config.clip_sample_range
778
+ )
779
+
780
+ # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
781
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
782
+ pred_original_sample_coeff = (
783
+ alpha_prod_t_prev ** (0.5) * current_beta_t
784
+ ) / beta_prod_t
785
+ current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
786
+
787
+ # 5. Compute predicted previous sample µ_t
788
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
789
+ pred_prev_sample = (
790
+ pred_original_sample_coeff * pred_original_sample
791
+ + current_sample_coeff * sample
792
+ )
793
+
794
+ return pred_prev_sample
run_configs/noise_shift_3_steps.yaml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ breakdown: "x_t_hat_c"
2
+ cross_r: 0.9
3
+ eta_reconstruct: 1
4
+ eta_retrieve: 1
5
+ max_norm_zs: [-1, -1, 15.5]
6
+ model: "stabilityai/sdxl-turbo"
7
+ noise_shift_delta: 1
8
+ noise_timesteps: [599, 299, 0]
9
+ timesteps: [799, 499, 199]
10
+ num_steps_inversion: 5
11
+ step_start: 1
12
+ real_cfg_scale: 0
13
+ real_cfg_scale_save: 0
14
+ scheduler_type: "ddpm"
15
+ seed: 2
16
+ self_r: 0.5
17
+ ws1: 1.5
18
+ ws2: 1
19
+ clean_step_timestep: 0
run_configs/noise_shift_guidance_1_5.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ breakdown: "x_t_hat_c"
2
+ cross_r: 0.9
3
+ eta: 1
4
+ max_norm_zs: [-1, -1, -1, 15.5]
5
+ model: ""
6
+ noise_shift_delta: 1
7
+ noise_timesteps: null
8
+ num_steps_inversion: 20
9
+ step_start: 5
10
+ real_cfg_scale: 0
11
+ real_cfg_scale_save: 0
12
+ scheduler_type: "ddpm"
13
+ seed: 2
14
+ self_r: 0.5
15
+ timesteps: null
16
+ ws1: 1.5
17
+ ws2: 1
18
+ clean_step_timestep: 0