Feature Extraction
Transformers
Safetensors
diva
custom_code
DiVA-llama-3-v0-8b / modeling_diva.py
WillHeld's picture
Update modeling_diva.py
39645d5 verified
raw
history blame
No virus
9.41 kB
import copy
import json
import os
from typing import Optional, Union
import librosa
import numpy as np
import torch
import torch.nn.functional as F
from datasets import Audio
from safetensors.torch import load, load_model
from torch import nn
from .configuring_diva import DiVAConfig
from transformers import (
AutoProcessor,
AutoTokenizer,
LlamaForCausalLM,
PreTrainedModel,
WhisperForConditionalGeneration,
)
class WhisperConnector(nn.Module):
    """Turn speech-encoder states into a fixed number of projected embeddings.

    The module owns 448 learnable query vectors of width 1280 and a linear
    projection from 1280 to 4096 features. ``decoder`` starts out as ``None``
    and has to be assigned by the caller before ``forward`` is invoked.
    """

    def __init__(
        self,
    ):
        super().__init__()
        # Placeholder; the owner of this module attaches the real decoder.
        self.decoder = None
        self.projection = nn.Linear(1280, 4096)
        self.query_tokens = nn.Parameter(torch.randn(448, 1280))

    def forward(self, x, output_device="cuda:1"):
        """Run the queries through the decoder against ``x`` and project them.

        Args:
            x: Encoder states; the leading dimension is used as the batch size
                and the tensor is passed to the decoder as
                ``encoder_hidden_states``.
            output_device: Device the projected result is moved to.

        Returns:
            The projection of the first element of the decoder output, moved
            to ``output_device``.
        """
        batch_size = x.shape[0]
        # Share the same learned queries across every item in the batch.
        queries = self.query_tokens.unsqueeze(0).expand(batch_size, -1, -1)
        decoded = self.decoder(inputs_embeds=queries, encoder_hidden_states=x)[0]
        # When the projection expects 5120 input features, the decoder output
        # is regrouped into 112 rows of 5120 values before projecting.
        if self.projection.weight.shape[-1] == 5120:
            decoded = decoded.reshape(112, 5120)
        return self.projection(decoded).to(output_device)
class DiVAModel(PreTrainedModel):
    """Audio-conditioned wrapper around a Llama causal language model.

    Audio is turned into features by the Whisper processor, encoded by the
    Whisper encoder, converted into "virtual token" embeddings by a
    ``WhisperConnector`` and then spliced between text-token embeddings that
    are fed to the Llama decoder through ``inputs_embeds``.
    """

    config_class = DiVAConfig

    def __init__(
        self, via_path=None, config_dict=None, device_map=None, speech_encoder_device=None
    ):
        """Build the speech encoder, connector, Llama decoder and tokenizers.

        Args:
            via_path: Optional path to a safetensors file holding the connector
                weights (``query_tokens``, ``projection.*`` and the
                ``connector.``-prefixed decoder weights). When ``None`` the
                connector keeps its freshly initialised / Whisper-copied
                weights.
            config_dict: Dictionary passed to ``DiVAConfig.from_dict``.
                ``None`` is treated as an empty dictionary.
            device_map: Device map forwarded to
                ``LlamaForCausalLM.from_pretrained``. When ``None`` a map is
                built that puts the first half of the 32 layers on device
                index 1 and the second half (plus ``lm_head``) on index 2.
            speech_encoder_device: Device for the Whisper encoder and the
                connector; defaults to ``"cuda:0"``.
        """
        # Use a fresh dict per call instead of a shared mutable default.
        if config_dict is None:
            config_dict = {}
        super().__init__(DiVAConfig.from_dict(config_dict))

        if speech_encoder_device is None:
            speech_encoder_device = "cuda:0"

        whisper = WhisperForConditionalGeneration.from_pretrained(
            "openai/whisper-large-v3"
        )
        connector = WhisperConnector()
        # The connector gets its own copy of the Whisper decoder so loading
        # connector weights below does not touch the ``whisper`` object.
        connector.decoder = copy.deepcopy(whisper.model.decoder)

        if via_path is not None:
            with open(via_path, "rb") as f:
                sd = load(f.read())
            with torch.no_grad():
                connector.query_tokens = nn.Parameter(sd["query_tokens"])
                # The stored projection weight is transposed before use.
                connector.projection.weight = nn.Parameter(sd["projection.weight"].T)
                connector.projection.bias = nn.Parameter(sd["projection.bias"])
                wsd = {
                    key.replace("connector.", ""): sd[key]
                    for key in sd
                    if key.startswith("connector.")
                }
                connector.decoder.load_state_dict(wsd)

        if device_map is None:
            num_layers = 32
            num_gpus = 2
            layers_per_gpu = num_layers // num_gpus
            device_map = dict(
                **{"model.embed_tokens": 1, "model.norm": 1, "lm_head": 2},
                **{
                    "model.layers." + str(i): 1 + (i // layers_per_gpu)
                    for i in range(num_layers)
                },
            )

        self.connector = connector.to(speech_encoder_device)
        self.whisper_encoder = whisper.model.encoder.to(speech_encoder_device)
        self.llama_decoder = LlamaForCausalLM.from_pretrained(
            "meta-llama/Meta-Llama-3-8B-Instruct",
            device_map=device_map,
            torch_dtype=torch.float16,
        )
        self.processor = AutoProcessor.from_pretrained("openai/whisper-large-v3")
        self.tokenizer = AutoTokenizer.from_pretrained("WillHeld/via-llama")

        # All prompt-fragment tensors live next to the embedding table so they
        # can be embedded without an extra device transfer.
        embed_device = self.llama_decoder.model.embed_tokens.weight.device
        self.prefix = torch.tensor([128000, 128006, 882, 128007, 271]).to(
            embed_device
        )
        self.pre_user_suffix = torch.tensor(
            self.tokenizer.encode(
                "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
            )
        ).to(embed_device)
        self.final_header = torch.tensor([128009, 128006, 78191, 128007, 271]).to(
            embed_device
        )
        self.speech_encoder_device = speech_encoder_device

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config=None,
        cache_dir=None,
        **kwargs,
    ):
        """Create a model from a local directory or a Hugging Face Hub repo.

        A local directory must contain ``model-00001-of-00004.safetensors`` and
        ``config.json``. Otherwise both files are downloaded from the Hub into
        the directory containing this module.

        Args:
            pretrained_model_name_or_path: Local directory or Hub repo id.
            *model_args: Accepted for signature compatibility; unused.
            config: Accepted for signature compatibility; unused.
            cache_dir: Accepted for signature compatibility; unused.
            **kwargs: ``device_map`` (default ``"auto"``),
                ``speech_encoder_device`` (default ``None``) and ``token``
                (Hub access token) are read; other keys are ignored.

        Returns:
            A new instance of ``cls``.
        """
        weights_name = "model-00001-of-00004.safetensors"
        config_name = "config.json"

        if os.path.isdir(pretrained_model_name_or_path):
            # os.path.join also accepts os.PathLike inputs, unlike ``+``.
            via_path = os.path.join(pretrained_model_name_or_path, weights_name)
            config_path = os.path.join(pretrained_model_name_or_path, config_name)
        else:
            # Loading from huggingface repo
            from huggingface_hub import hf_hub_download

            module_dir = os.path.dirname(__file__)
            token = kwargs.get("token", None)
            via_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename=weights_name,
                token=token,
                local_dir=module_dir,
            )
            config_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename=config_name,
                token=token,
                local_dir=module_dir,
            )

        with open(config_path, "r") as f:
            config_dict = json.load(f)

        return cls(
            via_path,
            config_dict,
            kwargs.get("device_map", "auto"),
            kwargs.get("speech_encoder_device", None),
        )

    def _encode_audio(self, audio):
        """Return the connector's virtual-token embeddings for ``audio``."""
        inputs = self.processor(audio, return_tensors="pt", sampling_rate=16_000)
        input_features = inputs.input_features.to(self.speech_encoder_device)
        hidden_states = self.whisper_encoder(input_features=input_features)[
            "last_hidden_state"
        ]
        # squeeze() drops size-1 dimensions so the result can be concatenated
        # with un-batched text embeddings along axis 0.
        return self.connector(
            hidden_states,
            output_device=self.llama_decoder.model.embed_tokens.weight.device,
        ).squeeze()

    def forward(
        self, audio, prefix_text_tokens, suffix_text_tokens, past_key_values=None
    ):
        """Run the decoder on ``prefix + audio virtual tokens + suffix``.

        Args:
            audio: Audio input handed to the Whisper processor with a sampling
                rate of 16 kHz.
            prefix_text_tokens: Token ids embedded before the audio tokens.
            suffix_text_tokens: Token ids embedded after the audio tokens.
            past_key_values: Optional cache forwarded to the Llama decoder.
                Previously this name was referenced without being defined,
                which made every call fail with ``NameError``.

        Returns:
            The Llama decoder output (``return_dict=True``,
            ``output_hidden_states=True``).
        """
        virt_tokens = self._encode_audio(audio)
        embed_tokens = self.llama_decoder.model.embed_tokens
        prefix_embed = embed_tokens(prefix_text_tokens)
        suffix_embed = embed_tokens(suffix_text_tokens)
        inputs_embeds = torch.cat(
            [prefix_embed, virt_tokens, suffix_embed], axis=0
        ).unsqueeze(0)
        outputs = self.llama_decoder(
            inputs_embeds=inputs_embeds.to(embed_tokens.weight.device).half(),
            return_dict=True,
            output_hidden_states=True,
            past_key_values=past_key_values,
        )
        return outputs

    @torch.no_grad()
    def generate(
        self,
        audio,
        prompt,
        do_sample=False,
        logits_processor=None,
        max_new_tokens=128,
        temperature=1.0,
    ):
        """Generate a text response to ``audio``, optionally with a system prompt.

        Decoding is token-by-token and stops when token id 128009 is produced
        or ``max_new_tokens`` tokens have been generated. Gradients are
        disabled because only the decoded string is returned.

        Args:
            audio: Audio input handed to the Whisper processor with a sampling
                rate of 16 kHz.
            prompt: Optional text placed in a system header before the user
                header; ``None`` or ``""`` means no system prompt.
            do_sample: Sample from the softmax distribution instead of taking
                the argmax.
            logits_processor: Optional callable ``(input_ids, scores) ->
                scores`` applied to the next-token logits at every step.
            max_new_tokens: Upper bound on the number of generated tokens.
            temperature: Divisor applied to the logits when ``do_sample`` is
                true. Previously this name was undefined, so sampling raised
                ``NameError``.

        Returns:
            The decoded text with special tokens skipped and any literal
            ``"<|eot_id|>"`` removed.
        """
        eot_token_id = 128009
        embed_tokens = self.llama_decoder.model.embed_tokens
        virt_tokens = self._encode_audio(audio)

        if prompt:
            user_prompt_text = torch.tensor(
                self.tokenizer(prompt, add_special_tokens=False)["input_ids"],
                device=self.pre_user_suffix.device,
            )
            prefix = torch.cat(
                [self.pre_user_suffix, user_prompt_text, self.prefix], axis=0
            )
        else:
            prefix = self.prefix

        prefix_embed = embed_tokens(prefix)
        suffix = self.final_header
        suffix_embed = embed_tokens(suffix)
        inputs_embeds = torch.cat(
            [prefix_embed, virt_tokens, suffix_embed], axis=0
        ).unsqueeze(0)

        outs = []
        outputs = None
        greedy = 1
        while greedy != eot_token_id and len(outs) < max_new_tokens:
            # After the first step only the newest token is fed; the rest of
            # the context comes from the cached keys/values.
            past_key_values = outputs.past_key_values if outputs else None
            outputs = self.llama_decoder(
                inputs_embeds=inputs_embeds.to(embed_tokens.weight.device).half(),
                return_dict=True,
                output_hidden_states=True,
                past_key_values=past_key_values,
            )
            next_token_logits = outputs.logits[-1, -1, :]

            if logits_processor:
                # Before any token is generated the suffix header stands in as
                # the history. torch.stack keeps the generated ids as tensors,
                # and the explicit move puts them on the logits' device.
                local_outs = torch.stack(outs) if outs else suffix
                local_outs = local_outs.reshape(1, -1).to(next_token_logits.device)
                next_token_logits = logits_processor(
                    local_outs,
                    next_token_logits.reshape(1, -1),
                )
                next_token_logits = next_token_logits.flatten()

            if do_sample:
                logits = next_token_logits / temperature
                probs = F.softmax(logits, dim=-1)
                greedy = torch.multinomial(probs, num_samples=1)[0]
            else:
                greedy = next_token_logits.argmax()

            outs.append(greedy)
            inputs_embeds = embed_tokens(greedy.reshape(1, 1))

        return self.tokenizer.decode(outs, skip_special_tokens=True).replace(
            "<|eot_id|>", ""
        )