Feature Extraction
Transformers
Safetensors
diva
custom_code
DiVA-llama-3-v0-8b / modeling_diva.py
Helw150
Add Batch Support
49f38f9
raw
history blame
13.2 kB
import copy
import json
import os
from typing import Optional, Union
import librosa
import numpy as np
import torch
import torch.nn.functional as F
from datasets import Audio
from safetensors.torch import load, load_model
from torch import nn
from .configuring_diva import DiVAConfig
from transformers import (
AutoProcessor,
AutoTokenizer,
LlamaForCausalLM,
PreTrainedModel,
WhisperForConditionalGeneration,
)
class WhisperConnector(nn.Module):
def __init__(
self,
):
super().__init__()
self.decoder = None
self.projection = nn.Linear(1280, 4096)
self.query_tokens = nn.Parameter(torch.randn(448, 1280))
def forward(self, x, output_device="cuda:1"):
bsz = x.shape[0]
query_tokens = self.query_tokens[None, :, :].expand(bsz, -1, -1)
virt_whisper_tokens = self.decoder(
inputs_embeds=query_tokens, encoder_hidden_states=x
)
if self.projection.weight.shape[-1] == 5120:
virtual_tokens = self.projection(virt_whisper_tokens[0].reshape(112, 5120))
else:
virtual_tokens = self.projection(virt_whisper_tokens[0])
return virtual_tokens.to(output_device)
class DiVAModel(PreTrainedModel):
config_class = DiVAConfig
def __init__(
self, via_path=None, config_dict={}, device_map=None, speech_encoder_device=None
):
super().__init__(DiVAConfig.from_dict(config_dict))
if speech_encoder_device is None:
speech_encoder_device = "cuda:0"
whisper = WhisperForConditionalGeneration.from_pretrained(
"openai/whisper-large-v3"
)
connector = WhisperConnector()
connector.decoder = copy.deepcopy(whisper.model.decoder)
if via_path is not None:
with open(via_path, "rb") as f:
sd = load(f.read())
with torch.no_grad():
connector.query_tokens = nn.Parameter(sd["query_tokens"])
connector.projection.weight = nn.Parameter(sd["projection.weight"].T)
connector.projection.bias = nn.Parameter(sd["projection.bias"])
wsd = {
key.replace("connector.", ""): sd[key]
for key in sd
if key.startswith("connector.")
}
connector.decoder.load_state_dict(wsd)
if device_map == None:
num_layers = 32
num_gpus = 2
device_map = dict(
**{"model.embed_tokens": 1, "model.norm": 1, "lm_head": 2},
**{
"model.layers." + str(i): 1 + (i // (num_layers // num_gpus))
for i in range(num_layers)
},
)
self.connector = connector.to(speech_encoder_device)
self.whisper_encoder = whisper.model.encoder.to(speech_encoder_device)
self.llama_decoder = LlamaForCausalLM.from_pretrained(
"meta-llama/Meta-Llama-3-8B-Instruct",
device_map=device_map,
torch_dtype=torch.float16,
)
self.processor = AutoProcessor.from_pretrained("openai/whisper-large-v3")
self.tokenizer = AutoTokenizer.from_pretrained("WillHeld/via-llama")
self.prefix = torch.tensor([128000, 128006, 882, 128007, 271]).to(
self.llama_decoder.model.embed_tokens.weight.device
)
self.pre_user_suffix = torch.tensor(
self.tokenizer.encode(
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
)
).to(self.llama_decoder.model.embed_tokens.weight.device)
self.final_header = torch.tensor([128009, 128006, 78191, 128007, 271]).to(
self.llama_decoder.model.embed_tokens.weight.device
)
self.speech_encoder_device = speech_encoder_device
def can_generate(cls):
return False
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
*model_args,
config=None,
cache_dir=None,
**kwargs,
):
if os.path.isdir(pretrained_model_name_or_path):
via_path = (
pretrained_model_name_or_path + "/model-00001-of-00004.safetensors"
)
config_path = pretrained_model_name_or_path + "/config.json"
else:
# Loading from huggingface repo
from huggingface_hub import hf_hub_download
hf_hub_download(
repo_id=pretrained_model_name_or_path,
filename="model-00001-of-00004.safetensors",
token=kwargs.get("token", None),
local_dir=os.path.dirname(__file__),
)
hf_hub_download(
repo_id=pretrained_model_name_or_path,
filename="config.json",
token=kwargs.get("token", None),
local_dir=os.path.dirname(__file__),
)
via_path = os.path.dirname(__file__) + "/model-00001-of-00004.safetensors"
config_path = os.path.dirname(__file__) + "/config.json"
with open(config_path, "r") as f:
config_dict = json.loads(f.read())
return cls(
via_path,
config_dict,
kwargs["device_map"] if "device_map" in kwargs else "auto",
(
kwargs["speech_encoder_device"]
if "speech_encoder_device" in kwargs
else None
),
)
def forward(self, audio, prefix_text_tokens, suffix_text_tokens):
inputs = self.processor(audio, return_tensors="pt", sampling_rate=16_000)
input_features = inputs.input_features.to(self.speech_encoder_device)
hidden_states = self.whisper_encoder(input_features=input_features)[
"last_hidden_state"
]
virt_tokens = self.connector(
hidden_states,
output_device=self.llama_decoder.model.embed_tokens.weight.device,
).squeeze()
prefix_embed = self.llama_decoder.model.embed_tokens(prefix_text_tokens)
suffix_embed = self.llama_decoder.model.embed_tokens(suffix_text_tokens)
inputs_embeds = torch.cat(
[prefix_embed, virt_tokens, suffix_embed], axis=0
).unsqueeze(0)
outputs = self.llama_decoder(
inputs_embeds=inputs_embeds.to(
self.llama_decoder.model.embed_tokens.weight.device
).half(),
return_dict=True,
output_hidden_states=True,
past_key_values=past_key_values,
)
return outputs
@torch.no_grad()
def generate(
self,
audio,
text_prompt=None,
do_sample=False,
logits_processor=None,
max_new_tokens=128,
):
inputs = self.processor(audio, return_tensors="pt", sampling_rate=16_000)
input_features = inputs.input_features.to(self.speech_encoder_device)
hidden_states = self.whisper_encoder(input_features=input_features)[
"last_hidden_state"
]
virt_tokens = self.connector(
hidden_states,
output_device=self.llama_decoder.model.embed_tokens.weight.device,
)
bsz = virt_tokens.shape[0]
if text_prompt != None and text_prompt != "":
user_prompt_text = torch.tensor(
self.tokenizer(
text_prompt,
add_special_tokens=False,
padding=True,
padding_side="right",
)["input_ids"],
device=self.pre_user_suffix.device,
)
prefix = torch.cat(
[
self.pre_user_suffix.expand(
bsz,
-1,
),
user_prompt_text,
self.prefix.expand(
bsz,
-1,
),
],
axis=1,
)
else:
prefix = self.prefix
prefix_embed = self.llama_decoder.model.embed_tokens(prefix).expand(bsz, -1, -1)
suffix = self.final_header
suffix_embed = self.llama_decoder.model.embed_tokens(suffix).expand(bsz, -1, -1)
inputs_embeds = torch.cat([prefix_embed, virt_tokens, suffix_embed], axis=1)
outs = [[] for i in range(bsz)]
complete = [False] * bsz
outputs = None
greedy = 1
i = 0
while not all(complete) and len(outs[0]) < max_new_tokens:
past_key_values = outputs.past_key_values if outputs else None
outputs = self.llama_decoder(
inputs_embeds=inputs_embeds.to(
self.llama_decoder.model.embed_tokens.weight.device
).half(),
return_dict=True,
output_hidden_states=True,
past_key_values=past_key_values,
)
next_token_logits = outputs.logits[:, -1, :]
if logits_processor:
local_outs = torch.tensor(outs) if outs != [] else suffix
local_outs = local_outs.reshape(1, -1)
next_token_logits = logits_processor(
local_outs,
next_token_logits.reshape(1, -1),
)
next_token_logits = next_token_logits.flatten()
if do_sample:
logits = next_token_logits / temperature
probs = F.softmax(logits, dim=-1)
greedy = torch.multinomial(probs, num_samples=1)[0]
else:
greedy = next_token_logits.argmax(dim=-1)
for token_index, out in enumerate(greedy.flatten().tolist()):
outs[token_index].append(out)
if out == 128009:
complete[token_index] = True
next_embed = self.llama_decoder.model.embed_tokens(greedy.reshape(-1, 1))
inputs_embeds = next_embed
return self.tokenizer.batch_decode(outs, skip_special_tokens=True)
def generate_stream(
self,
audio,
text_prompt,
do_sample=False,
logits_processor=None,
max_new_tokens=128,
):
inputs = self.processor(audio, return_tensors="pt", sampling_rate=16_000)
input_features = inputs.input_features.to(self.whisper_encoder.device)
hidden_states = self.whisper_encoder(input_features=input_features)[
"last_hidden_state"
]
virt_tokens = self.connector(
hidden_states,
output_device=self.llama_decoder.model.embed_tokens.weight.device,
).squeeze()
if text_prompt != None and text_prompt != "":
user_prompt_text = torch.tensor(
self.tokenizer(text_prompt, add_special_tokens=False)["input_ids"],
device=self.pre_user_suffix.device,
)
prefix = torch.cat(
[self.pre_user_suffix, user_prompt_text, self.prefix], axis=0
)
else:
prefix = self.prefix
prefix_embed = self.llama_decoder.model.embed_tokens(prefix)
suffix = self.final_header
suffix_embed = self.llama_decoder.model.embed_tokens(suffix)
inputs_embeds = torch.cat(
[prefix_embed, virt_tokens, suffix_embed], axis=0
).unsqueeze(0)
outs = []
outputs = None
greedy = 1
i = 0
while greedy != 128009 and len(outs) < max_new_tokens:
past_key_values = outputs.past_key_values if outputs else None
outputs = self.llama_decoder(
inputs_embeds=inputs_embeds.to(
self.llama_decoder.model.embed_tokens.weight.device
).half(),
return_dict=True,
output_hidden_states=True,
past_key_values=past_key_values,
)
next_token_logits = outputs.logits[-1, -1, :]
if logits_processor:
local_outs = torch.tensor(outs) if outs != [] else suffix
local_outs = local_outs.reshape(1, -1)
next_token_logits = logits_processor(
local_outs,
next_token_logits.reshape(1, -1),
)
next_token_logits = next_token_logits.flatten()
if do_sample:
logits = next_token_logits / temperature
probs = F.softmax(logits, dim=-1)
greedy = torch.multinomial(probs, num_samples=1)[0]
else:
greedy = next_token_logits.argmax()
outs.append(greedy)
next_embed = self.llama_decoder.model.embed_tokens(greedy.reshape(1, 1))
inputs_embeds = next_embed
yield self.tokenizer.decode(outs, skip_special_tokens=True).replace(
"<|eot_id|>", ""
)
return self.tokenizer.decode(outs, skip_special_tokens=True).replace(
"<|eot_id|>", ""
)