MeYourHint commited on
Commit
2012098
1 Parent(s): 9629d26
models/mask_transformer/transformer.py CHANGED
@@ -179,9 +179,10 @@ class MaskTransformer(nn.Module):
179
  clip_model, clip_preprocess = clip.load(clip_version, device='cpu',
180
  jit=False) # Must set jit=False for training
181
  # Cannot run on cpu
182
- clip.model.convert_weights(
183
- clip_model) # Actually this line is unnecessary since clip by default already on float16
184
- # Date 0707: It's necessary, only unecessary when load directly to gpu. Disable if need to run on cpu
 
185
 
186
  # Freeze CLIP weights
187
  clip_model.eval()
@@ -731,9 +732,10 @@ class ResidualTransformer(nn.Module):
731
  clip_model, clip_preprocess = clip.load(clip_version, device='cpu',
732
  jit=False) # Must set jit=False for training
733
  # Cannot run on cpu
734
- clip.model.convert_weights(
735
- clip_model) # Actually this line is unnecessary since clip by default already on float16
736
- # Date 0707: It's necessary, only unecessary when load directly to gpu. Disable if need to run on cpu
 
737
 
738
  # Freeze CLIP weights
739
  clip_model.eval()
 
179
  clip_model, clip_preprocess = clip.load(clip_version, device='cpu',
180
  jit=False) # Must set jit=False for training
181
  # Cannot run on cpu
182
+ if str(self.evice) != "cpu":
183
+ clip.model.convert_weights(
184
+ clip_model) # Actually this line is unnecessary since clip by default already on float16
185
+ # Date 0707: It's necessary, only unecessary when load directly to gpu. Disable if need to run on cpu
186
 
187
  # Freeze CLIP weights
188
  clip_model.eval()
 
732
  clip_model, clip_preprocess = clip.load(clip_version, device='cpu',
733
  jit=False) # Must set jit=False for training
734
  # Cannot run on cpu
735
+ if str(self.evice) != "cpu":
736
+ clip.model.convert_weights(
737
+ clip_model) # Actually this line is unnecessary since clip by default already on float16
738
+ # Date 0707: It's necessary, only unecessary when load directly to gpu. Disable if need to run on cpu
739
 
740
  # Freeze CLIP weights
741
  clip_model.eval()
options/base_option.py CHANGED
@@ -12,7 +12,7 @@ class BaseOptions():
12
 
13
  self.parser.add_argument('--vq_name', type=str, default="rvq_nq1_dc512_nc512", help='Name of the rvq model.')
14
 
15
- self.parser.add_argument("--gpu_id", type=int, default=0, help='GPU id')
16
  self.parser.add_argument('--dataset_name', type=str, default='t2m', help='Dataset Name, {t2m} for humanml3d, {kit} for kit-ml')
17
  self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here.')
18
 
 
12
 
13
  self.parser.add_argument('--vq_name', type=str, default="rvq_nq1_dc512_nc512", help='Name of the rvq model.')
14
 
15
+ self.parser.add_argument("--gpu_id", type=int, default=-1, help='GPU id')
16
  self.parser.add_argument('--dataset_name', type=str, default='t2m', help='Dataset Name, {t2m} for humanml3d, {kit} for kit-ml')
17
  self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here.')
18