Feature Extraction · Transformers · Safetensors · diva · custom_code
Will Held committed
Commit b7ca827 · 1 parent: 259eb63

Add Stream

Files changed (1):
  1. modeling_diva.py +65 -0
modeling_diva.py CHANGED
@@ -243,3 +243,68 @@ class DiVAModel(PreTrainedModel):
         return self.tokenizer.decode(outs, skip_special_tokens=True).replace(
             "<|eot_id|>", ""
         )
+
+    def generate_stream(
+        self,
+        audio,
+        text_prompt,
+        do_sample=False,
+        logits_processor=None,
+        max_new_tokens=128,
+        temperature=1.0,
+    ):
+        # Encode the raw audio with Whisper, then project the hidden states
+        # into the Llama embedding space via the trained connector.
+        inputs = self.processor(audio, return_tensors="pt", sampling_rate=16_000)
+        input_features = inputs.input_features.to(self.speech_encoder_device)
+        hidden_states = self.whisper_encoder(input_features=input_features)[
+            "last_hidden_state"
+        ]
+        virt_tokens = self.connector(
+            hidden_states,
+            output_device=self.llama_decoder.model.embed_tokens.weight.device,
+        ).squeeze()
+
+        if text_prompt is not None and text_prompt != "":
+            user_prompt_text = torch.tensor(
+                self.tokenizer(text_prompt, add_special_tokens=False)["input_ids"],
+                device=self.pre_user_suffix.device,
+            )
+            prefix = torch.cat(
+                [self.pre_user_suffix, user_prompt_text, self.prefix], dim=0
+            )
+        else:
+            prefix = self.prefix
+        prefix_embed = self.llama_decoder.model.embed_tokens(prefix)
+        suffix = self.final_header
+        suffix_embed = self.llama_decoder.model.embed_tokens(suffix)
+        inputs_embeds = torch.cat(
+            [prefix_embed, virt_tokens, suffix_embed], dim=0
+        ).unsqueeze(0)
+        outs = []
+        outputs = None
+        greedy = 1
+        # 128009 is the Llama 3 <|eot_id|> token id; stop once it is generated.
+        while greedy != 128009 and len(outs) < max_new_tokens:
+            past_key_values = outputs.past_key_values if outputs else None
+            outputs = self.llama_decoder(
+                inputs_embeds=inputs_embeds.to(
+                    self.llama_decoder.model.embed_tokens.weight.device
+                ).half(),
+                return_dict=True,
+                output_hidden_states=True,
+                past_key_values=past_key_values,
+            )
+            next_token_logits = outputs.logits[-1, -1, :]
+
+            if logits_processor:
+                local_outs = torch.tensor(outs) if outs != [] else suffix
+                local_outs = local_outs.reshape(1, -1)
+                next_token_logits = logits_processor(
+                    local_outs,
+                    next_token_logits.reshape(1, -1),
+                )
+                next_token_logits = next_token_logits.flatten()
+            if do_sample:
+                logits = next_token_logits / temperature
+                probs = F.softmax(logits, dim=-1)
+                greedy = torch.multinomial(probs, num_samples=1)[0]
+            else:
+                greedy = next_token_logits.argmax()
+            outs.append(greedy)
+            # Feed back only the newly sampled token; past_key_values carries
+            # the rest of the context between steps.
+            next_embed = self.llama_decoder.model.embed_tokens(greedy.reshape(1, 1))
+            inputs_embeds = next_embed
+            yield self.tokenizer.decode(outs).replace("<|eot_id|>", "")
+        return self.tokenizer.decode(outs).replace("<|eot_id|>", "")
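
A minimal sketch of how the new streaming entry point might be consumed. The checkpoint id, audio file, and librosa loading here are illustrative assumptions, not part of this commit:

    # Illustrative only: the checkpoint id and audio loading below are
    # assumptions, not taken from this commit.
    import librosa
    from transformers import AutoModel

    model = AutoModel.from_pretrained(
        "WillHeld/DiVA-llama-3-v0-8b", trust_remote_code=True
    )
    speech, _ = librosa.load("question.wav", sr=16_000)  # 16 kHz mono

    # generate_stream yields the cumulative decoded string after each new
    # token, so print only the suffix that was added since the last yield.
    previous = ""
    for partial in model.generate_stream(speech, "Answer the question."):
        print(partial[len(previous):], end="", flush=True)
        previous = partial
    print()

Since each yield re-decodes the full `outs` list, consumers always see a consistent prefix; `logits_processor` can be any `(input_ids, scores)` callable, such as a transformers `LogitsProcessorList` instance.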