WillHeld committed
Commit 259eb63
1 Parent(s): 39645d5

Update modeling_diva.py

Files changed (1):
  modeling_diva.py  +3 -3
modeling_diva.py CHANGED
@@ -179,7 +179,7 @@ class DiVAModel(PreTrainedModel):
         return outputs
 
     def generate(
-        self, audio, prompt, do_sample=False, logits_processor=None, max_new_tokens=128
+        self, audio, text_prompt, do_sample=False, logits_processor=None, max_new_tokens=128
     ):
         inputs = self.processor(audio, return_tensors="pt", sampling_rate=16_000)
         input_features = inputs.input_features.to(self.speech_encoder_device)
@@ -191,9 +191,9 @@ class DiVAModel(PreTrainedModel):
             output_device=self.llama_decoder.model.embed_tokens.weight.device,
         ).squeeze()
 
-        if prompt != None and prompt != "":
+        if text_prompt != None and text_prompt != "":
             user_prompt_text = torch.tensor(
-                self.tokenizer(prompt, add_special_tokens=False)["input_ids"],
+                self.tokenizer(text_prompt, add_special_tokens=False)["input_ids"],
                 device=self.pre_user_suffix.device,
            )
             prefix = torch.cat(
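
The only functional change in this commit is the rename of the `prompt` argument of `DiVAModel.generate` to `text_prompt`. Below is a minimal usage sketch of the new signature; the checkpoint id and the audio file name are assumptions for illustration and are not part of this commit, and the model must be loaded with `trust_remote_code=True` so that this custom `modeling_diva.py` is used.

import librosa
from transformers import AutoModel

# Assumed checkpoint id; any repo that ships this custom DiVAModel works the same way.
model = AutoModel.from_pretrained("WillHeld/DiVA-llama-3-v0-8b", trust_remote_code=True)

# generate() runs the processor with sampling_rate=16_000, so load audio at 16 kHz.
speech, _ = librosa.load("question.wav", sr=16_000)

# Keyword renamed in this commit: text_prompt instead of prompt.
response = model.generate(speech, text_prompt="Answer the question in the audio.")
print(response)

Callers that passed the audio and prompt positionally are unaffected; only code using the old `prompt=` keyword needs to switch to `text_prompt=`.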