akhaliq HF staff commited on
Commit
8a407e4
1 Parent(s): a23d1b7
Files changed (1) hide show
  1. app.py +11 -1
app.py CHANGED
@@ -12,10 +12,20 @@ from huggingface_hub import snapshot_download
12
  import os
13
 
14
  # Download and load the model
15
- model_path = 'pyramid_flow_model'
16
  if not os.path.exists(model_path):
17
  snapshot_download("rain1011/pyramid-flow-sd3", local_dir=model_path, local_dir_use_symlinks=False, repo_type='model')
18
 
 
 
 
 
 
 
 
 
 
 
19
  torch.cuda.set_device(0)
20
  model_dtype, torch_dtype = 'bf16', torch.bfloat16
21
 
 
12
  import os
13
 
14
  # Download and load the model
15
+ model_path = os.path.join(os.getcwd(), 'pyramid_flow_model')
16
  if not os.path.exists(model_path):
17
  snapshot_download("rain1011/pyramid-flow-sd3", local_dir=model_path, local_dir_use_symlinks=False, repo_type='model')
18
 
19
+ # Verify model files exist
20
+ dit_path = os.path.join(model_path, 'diffusion_transformer_768p')
21
+ vae_path = os.path.join(model_path, 'causal_video_vae')
22
+
23
+ if not os.path.exists(os.path.join(dit_path, 'diffusion_pytorch_model.safetensors')):
24
+ raise FileNotFoundError(f"DiT model file not found in {dit_path}")
25
+
26
+ if not os.path.exists(os.path.join(vae_path, 'diffusion_pytorch_model.safetensors')):
27
+ raise FileNotFoundError(f"VAE model file not found in {vae_path}")
28
+
29
  torch.cuda.set_device(0)
30
  model_dtype, torch_dtype = 'bf16', torch.bfloat16
31