John6666 commited on
Commit
3a6bc2d
1 Parent(s): 1d22096

Upload 2 files

Browse files
Files changed (1) hide show
  1. mod.py +5 -2
mod.py CHANGED
@@ -71,15 +71,17 @@ def get_repo_safetensors(repo_id: str):
71
  # Initialize the base model
72
  base_model = models[0]
73
  pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
74
-
75
 
76
  def change_base_model(repo_id: str, progress=gr.Progress(track_tqdm=True)):
77
  global pipe
 
78
  try:
79
- if not is_repo_name(repo_id) or not is_repo_exists(repo_id): return
80
  progress(0, f"Loading model: {repo_id}")
81
  clear_cache()
82
  pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
 
83
  progress(1, f"Model loaded: {repo_id}")
84
  except Exception as e:
85
  print(e)
@@ -135,6 +137,7 @@ def fuse_loras(pipe, lorajson: list[dict]):
135
  #pipe.unload_lora_weights()
136
 
137
 
 
138
  fuse_loras.zerogpu = True
139
 
140
 
 
71
  # Initialize the base model
72
  base_model = models[0]
73
  pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
74
+ last_model = models[0]
75
 
76
  def change_base_model(repo_id: str, progress=gr.Progress(track_tqdm=True)):
77
  global pipe
78
+ global last_model
79
  try:
80
+ if repo_id == last_model or not is_repo_name(repo_id) or not is_repo_exists(repo_id): return
81
  progress(0, f"Loading model: {repo_id}")
82
  clear_cache()
83
  pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
84
+ last_model = repo_id
85
  progress(1, f"Model loaded: {repo_id}")
86
  except Exception as e:
87
  print(e)
 
137
  #pipe.unload_lora_weights()
138
 
139
 
140
+ change_base_model.zerogpu = True
141
  fuse_loras.zerogpu = True
142
 
143