chenjoya committed on
Commit
7d1b5a5
1 Parent(s): 98f88b8

Upload 9 files

models/__init__.py ADDED
@@ -0,0 +1,10 @@
+from transformers import HfArgumentParser
+
+from .arguments_live import LiveTrainingArguments, get_args_class
+from .live_llama import build_live_llama as build_model_and_tokenizer
+from .modeling_live import fast_greedy_generate
+
+def parse_args() -> LiveTrainingArguments:
+    args, = HfArgumentParser(LiveTrainingArguments).parse_args_into_dataclasses()
+    args, = HfArgumentParser(get_args_class(args.live_version)).parse_args_into_dataclasses()
+    return args
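
The package `__init__` above only re-exports the entry points. As a hedged sketch (this driver script is not part of the commit; it simply mirrors the `__main__` block of `modeling_live_llama.py`), the exports would typically be wired together like this:

```python
# Assumed usage sketch, not included in this upload.
from models import parse_args, build_model_and_tokenizer

args = parse_args()  # parse LiveTrainingArguments, then re-parse with the live_version-specific class
model, tokenizer = build_model_and_tokenizer(is_training=True, **args.to_dict())
print(model.config, tokenizer)
```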
models/arguments_live.py ADDED
@@ -0,0 +1,54 @@
+from dataclasses import dataclass, field
+from transformers import TrainingArguments
+
+@dataclass
+class LiveTrainingArguments(TrainingArguments):
+    live_version: str = 'live1+'
+    system_prompt: str = (
+        "A multimodal AI assistant is helping users with some activities."
+        " Below is their conversation, interleaved with the list of video frames received by the assistant."
+    )
+    train_datasets: list[str] = None
+    eval_datasets: list[str] = None
+    stream_loss_weight: float = 1.0
+    llm_pretrained: str = 'meta-llama/Meta-Llama-3-8B-Instruct'
+    vision_pretrained: str = 'google/siglip-large-patch16-384'
+    lora_modules: str = "model.*(q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj)|lm_head$"
+    lora_r: int = 128
+    lora_alpha: int = 256
+    finetune_modules: list[str] = field(default_factory=lambda: ['connector'])
+    frame_fps: int = 2 # for training. inference can be 10
+    frame_token_cls: bool = None
+    frame_token_pooled: list[int] = None
+    frame_resolution: int = 384
+    frame_token_interval: str = None
+    frame_token_interval_threshold: float = 0.0
+    augmentation: bool = False
+    attn_implementation: str = 'flash_attention_2'
+    output_dir: str = 'outputs/debug'
+
+@dataclass
+class LiveOneTrainingArguments(LiveTrainingArguments):
+    live_version: str = 'live1'
+    frame_token_cls: bool = True
+    frame_num_tokens: int = 1
+    frame_token_interval: str = ''
+    embed_mark: str = '2fps_384_1'
+    max_num_frames: int = 7200 # 1h, 2fps, 7200 frames
+
+@dataclass
+class LiveOnePlusTrainingArguments(LiveTrainingArguments):
+    live_version: str = 'live1+'
+    frame_token_cls: bool = True
+    frame_token_pooled: list[int] = field(default_factory=lambda: [3,3])
+    frame_num_tokens: int = 10 # 1+3x3
+    embed_mark: str = '2fps_384_1+3x3'
+    frame_token_interval: str = ','
+    max_num_frames: int = 1200 # 10min, 2fps, 1200 frames
+
+def get_args_class(live_version: str):
+    if live_version == 'live1':
+        return LiveOneTrainingArguments
+    elif live_version == 'live1+':
+        return LiveOnePlusTrainingArguments
+    raise NotImplementedError
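
For reference, a small illustration (not part of the commit; it assumes the repository root is on `PYTHONPATH`) of how the two argument classes differ in their per-frame token budget, using only the defaults defined above:

```python
# Illustration only: inspect the frame-token budget of each live version.
from models.arguments_live import get_args_class

for version in ('live1', 'live1+'):
    args = get_args_class(version)()
    print(version, args.frame_num_tokens, args.max_num_frames)
# live1  -> 1 token per frame (CLS only), up to 7200 frames (1h at 2fps)
# live1+ -> 10 tokens per frame (CLS + 3x3 pooled), up to 1200 frames (10min at 2fps)
```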
models/configuration_live.py ADDED
@@ -0,0 +1,21 @@
+
+from transformers import PretrainedConfig
+
+class LiveConfigMixin(PretrainedConfig):
+    def __init__(self, *, vision_pretrained: str = None,
+        frame_resolution: int = None, frame_token_cls: bool = None, frame_token_pooled: list[int] = None, frame_num_tokens: int = None,
+        v_placeholder: str = '<v>', frame_token_interval: str = None, v_placeholder_id: int = None, frame_token_interval_id: int = None,
+        stream_loss_weight: float = 1.0, vision_hidden_size=1024, **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.vision_pretrained = vision_pretrained
+        self.frame_resolution = frame_resolution
+        self.frame_token_cls = frame_token_cls
+        self.frame_token_pooled = frame_token_pooled
+        self.frame_num_tokens = frame_num_tokens
+        self.vision_hidden_size = vision_hidden_size
+        self.stream_loss_weight = stream_loss_weight
+        self.v_placeholder = v_placeholder
+        self.frame_token_interval = frame_token_interval
+        self.v_placeholder_id = v_placeholder_id
+        self.frame_token_interval_id = frame_token_interval_id
models/live_llama/__init__.py ADDED
@@ -0,0 +1,2 @@
+from .configuration_live_llama import LiveLlamaConfig
+from .modeling_live_llama import LiveLlamaForCausalLM, build_live_llama
models/live_llama/configuration_live_llama.py ADDED
@@ -0,0 +1,7 @@
+
+from transformers import LlamaConfig
+
+from ..configuration_live import LiveConfigMixin
+
+class LiveLlamaConfig(LlamaConfig, LiveConfigMixin):
+    pass
models/live_llama/modeling_live_llama.py ADDED
@@ -0,0 +1,154 @@
+import torch
+from torch import nn
+from transformers import LlamaForCausalLM, Cache
+from transformers.activations import GELUActivation
+from transformers.utils import logging
+
+from .configuration_live_llama import LiveLlamaConfig
+from ..modeling_live import build_live, LiveMixin
+
+logger = logging.get_logger(__name__)
+
+class LiveLlamaForCausalLM(LlamaForCausalLM, LiveMixin):
+    config_class = LiveLlamaConfig
+    _keys_to_ignore_on_load_missing = ['vision_encoder', 'connector']
+
+    def __init__(self, config: LiveLlamaConfig):
+        super().__init__(config)
+        self.connector = torch.nn.Sequential(
+            torch.nn.Linear(config.vision_hidden_size, config.hidden_size, bias=True),
+            GELUActivation(config.hidden_size),
+            torch.nn.Linear(config.hidden_size, config.hidden_size, bias=True),
+        )
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        frames: torch.FloatTensor = None,
+        attention_mask: torch.Tensor = None,
+        position_ids: torch.LongTensor = None,
+        past_key_values: list[torch.FloatTensor] = None,
+        inputs_embeds: torch.FloatTensor = None,
+        labels: torch.LongTensor = None,
+        use_cache: bool = None,
+        output_attentions: bool = None,
+        output_hidden_states: bool = None,
+        return_dict: bool = None,
+        cache_position: torch.LongTensor = None,
+        **kwargs,
+    ):
+        if inputs_embeds is None:
+            inputs_embeds = self.joint_embed(input_ids, frames)
+        outputs = super().forward(
+            attention_mask = attention_mask,
+            position_ids = position_ids,
+            past_key_values = past_key_values,
+            inputs_embeds = inputs_embeds,
+            # labels: handled below with the stream loss weighting
+            use_cache = use_cache,
+            output_attentions = output_attentions,
+            output_hidden_states = output_hidden_states,
+            return_dict = return_dict,
+            cache_position=cache_position,
+        )
+
+        loss = None
+        if labels is not None:
+            logits = outputs[0]
+            v_mask = input_ids.flatten(0, 1) == self.config.v_placeholder_id
+            weight = v_mask * self.config.stream_loss_weight + ~v_mask
+            loss = nn.functional.cross_entropy(logits.flatten(0, 1), labels.flatten(), reduction='none') * weight
+            loss = loss.sum() / (labels >= 0).sum()
+
+        if not return_dict:
+            return (loss,) + outputs[1:] if loss is not None else outputs
+
+        outputs.loss = loss
+        return outputs
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        cache_position=None,
+        use_cache=True,
+        **kwargs,
+    ):
+        past_length = 0
+        if past_key_values is not None:
+            if isinstance(past_key_values, Cache):
+                past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
+                max_cache_length = (
+                    torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
+                    if past_key_values.get_max_length() is not None
+                    else None
+                )
+                cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
+            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
+            else:
+                cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
+
+            # Keep only the unprocessed tokens:
+            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
+            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
+            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+            # input_ids based on the past_length.
+            elif past_length < input_ids.shape[1]:
+                input_ids = input_ids[:, past_length:]
+            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, past_length:] # NOTE
+
+        # NOTE
+        if inputs_embeds is not None and past_length < inputs_embeds.size(1):
+            model_inputs = {"inputs_embeds": inputs_embeds[:, past_length:]}
+        else:
+            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+            # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
+            # TODO: use `next_tokens` directly instead.
+            model_inputs = {"input_ids": input_ids.contiguous()}
+
+        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
+        if cache_position is None:
+            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
+        elif use_cache:
+            cache_position = cache_position[-input_length:]
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,  # length of the new inputs, counted from past_length
+                "cache_position": cache_position,  # positions that have not been cached yet
+                "past_key_values": past_key_values,
+                "use_cache": use_cache,
+                "attention_mask": attention_mask,  # length of cache + new inputs
+            }
+        )
+        return model_inputs
+
+def build_live_llama(**kwargs):
+    return build_live(config_class=LiveLlamaConfig, model_class=LiveLlamaForCausalLM, **kwargs)
+
+if __name__ == '__main__':
+    from ..arguments_live import LiveOnePlusTrainingArguments
+    print(LiveOnePlusTrainingArguments().to_dict())
+    model, tokenizer = build_live_llama(is_training=True, **LiveOnePlusTrainingArguments().to_dict())
+    print(model.config, tokenizer)
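
To make the weighted loss in `forward` above concrete, here is a small self-contained illustration of the arithmetic. The vocabulary size of 8, placeholder id of 5, and stream weight of 2.0 are made-up values for the example, not defaults of this model:

```python
# Toy reproduction of the stream-loss weighting used in LiveLlamaForCausalLM.forward.
import torch
from torch import nn

logits = torch.randn(1, 4, 8)                 # (batch, seq, vocab)
input_ids = torch.tensor([[5, 3, 5, 2]])      # token 5 stands in for the <v> placeholder
labels = torch.tensor([[5, 3, -100, 2]])      # -100 positions contribute zero loss
v_mask = input_ids.flatten(0, 1) == 5
weight = v_mask * 2.0 + ~v_mask               # frame tokens get the stream_loss_weight, text tokens get 1
loss = nn.functional.cross_entropy(logits.flatten(0, 1), labels.flatten(), reduction='none') * weight
loss = loss.sum() / (labels >= 0).sum()       # normalize by the number of supervised tokens
print(loss)
```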
models/modeling_live.py ADDED
@@ -0,0 +1,222 @@
+import torch, os
+from peft import LoraConfig, get_peft_model, PeftModel
+from transformers import AutoModelForCausalLM, Cache
+from transformers.utils import logging
+
+from .tokenization_live import build_live_tokenizer_and_update_config
+from .vision_live import build_live_vision
+
+logger = logging.get_logger(__name__)
+
+class LiveMixin(AutoModelForCausalLM):
+    def set_vision_inside(self):
+        logger.warning_once("!!! Setting the vision encoder inside the model is only recommended for in-the-wild inference. "
+            "Please don't call this for efficient training & evaluation. Instead, pre-extract the visual features.")
+        self.vision_encoder, self.vision_encode = build_live_vision(self.config)
+
+    def unset_vision_inside(self):
+        del self.vision_encoder
+        del self.vision_encode
+
+    def visual_embed(self, frames: torch.Tensor):
+        if hasattr(self, 'vision_encode'):
+            with torch.cuda.amp.autocast():
+                frames = self.vision_encode(self.vision_encoder, frames)
+            frames = frames.to(self.dtype)
+        frames = self.connector(frames)
+        return frames.view(-1, frames.shape[-1])
+
+    def joint_embed(
+        self,
+        input_ids: torch.Tensor = None,
+        frames: torch.Tensor = None,
+    ):
+        if frames is None:
+            return self.get_input_embeddings()(input_ids)
+        if input_ids is None:
+            return self.visual_embed(frames)
+        inputs_embeds = self.get_input_embeddings()(input_ids.clamp(max=self.vocab_size-1))
+        v_mask = input_ids == self.config.v_placeholder_id
+        if v_mask.any():
+            inputs_embeds[v_mask] = self.visual_embed(frames)
+        return inputs_embeds
+
+    @torch.no_grad()
+    def stream_evaluate(
+        self,
+        input_ids: torch.LongTensor,
+        labels: torch.LongTensor,
+        frames: torch.ByteTensor,
+        ignore_token_id: int = -100,
+        frame_token_interval_threshold: float = 0.0,
+        **kwargs
+    ):
+        # 0. evaluation only supports batch_size = 1
+        assert input_ids.size(0) == labels.size(0) == 1
+        input_id, label = input_ids[0], labels[0]
+        device = input_id.device
+        zero = torch.tensor(0, dtype=torch.int, device=device)
+        one = torch.tensor(1, dtype=torch.int, device=device)
+
+        # 1. prepare multi-turn start and stop
+        turn_stops = ((input_id == self.config.eos_token_id).nonzero() + 1)[:,0].tolist()
+        turn_starts = [0] + turn_stops[:-1]
+        num_turns = len(turn_starts)
+
+        # 2. forward the full input_ids and labels, get tokenwise logits and losses
+        outputs = self.forward(input_ids=input_ids, frames=frames, return_dict=True, use_cache=True)
+        logit, past_key_values = outputs.logits[0], outputs.past_key_values
+
+        # 3. compute metrics for each turn
+        v_placeholder_id = self.config.v_placeholder_id
+        use_interval = self.config.frame_token_interval_id is not None
+        frame_token_interval_id = self.config.frame_token_interval_id if use_interval else self.config.eos_token_id
+        frame_num_tokens = self.config.frame_token_cls
+        if self.config.frame_token_pooled:
+            frame_num_tokens += self.config.frame_token_pooled[0] * self.config.frame_token_pooled[1]
+        past_num_frames = 0
+        lm_ppls, frame_diffs, fluencies, lm_correctness = [], [], [], []
+        for r, (turn_start, turn_stop) in enumerate(zip(turn_starts, turn_stops)):
+            ## 3.1. there are only two losses (the stream loss on frame tokens and the lm loss); prepare the corresponding masks
+            turn_label = label[turn_start:turn_stop]
+            turn_learn_mask = turn_label != ignore_token_id
+            if not turn_learn_mask.any():
+                continue
+            turn_logit = logit[turn_start:turn_stop]
+            turn_input_id = input_id[turn_start:turn_stop]
+            turn_v_mask = turn_input_id == v_placeholder_id
+            turn_num_frames = turn_v_mask.sum() // frame_num_tokens
+            turn_stream_mask = turn_v_mask & turn_learn_mask
+            turn_lm_mask = turn_learn_mask & ~turn_stream_mask
+
+            ## 3.2 ppl, offline metric
+            if turn_lm_mask.any():
+                turn_lm_masked_logit, turn_lm_masked_label = turn_logit[turn_lm_mask], turn_label[turn_lm_mask]
+                lm_ppl = torch.nn.functional.cross_entropy(turn_lm_masked_logit, turn_lm_masked_label).exp()
+                lm_ppls.append(lm_ppl)
+                turn_lm_masked_wrong_mask = turn_lm_masked_logit.argmax(dim=-1) != turn_lm_masked_label
+                if turn_lm_masked_wrong_mask.any():
+                    num_lm_correct_tokens = turn_lm_masked_wrong_mask.nonzero()[0,0]
+                else:
+                    num_lm_correct_tokens = (~turn_lm_masked_wrong_mask).sum()
+                lm_correctness.append(num_lm_correct_tokens / turn_lm_masked_label.numel())
+
+            ## 3.3. frame_diff (will be cast to time_diff in compute_metrics)
+            if turn_stream_mask.any():
+                ## 3.3.1: reply before (at) turn_num_frames
+                turn_score = turn_logit.softmax(dim=-1)
+                turn_stream_masked_score = turn_score[turn_stream_mask]
+                if frame_token_interval_threshold > 0:
+                    lower_threshold_mask = turn_stream_masked_score[:, frame_token_interval_id] < frame_token_interval_threshold
+                    turn_stream_masked_score[lower_threshold_mask] = 0
+                turn_stream_masked_pred_mask = turn_stream_masked_score.argmax(dim=-1) != frame_token_interval_id
+                if turn_stream_masked_pred_mask.any():
+                    frame_diff = turn_stream_mask.sum() - turn_stream_masked_pred_mask.nonzero()[0,0] - 1
+                else:
+                    ## 3.3.2: the most complex part: reply after turn_num_frames. we assume the 'Assistant: ...' prefix does not exist
+                    turn_last_stream_idx = turn_stream_mask.nonzero()[-1,0]
+                    past_key_values_before_assistant = self.trim_past_key_values(past_key_values, 0, turn_start + turn_last_stream_idx + 1)
+                    if r == num_turns - 1: # no future frame. we assume the model should receive a signal when streaming ends (e.g. a close button).
+                        frame_diff = zero
+                    else:
+                        next_turn_num_frames = (input_id[turn_starts[r+1]:turn_stops[r+1]] == v_placeholder_id).sum() // frame_num_tokens
+                        to_append_num_frames = min(next_turn_num_frames, turn_num_frames - 1) # avoid bias: current turn as center, equal left/right sides
+                        if to_append_num_frames == 0:
+                            frame_diff = zero
+                        else:
+                            to_append_frames = frames[past_num_frames+turn_num_frames:past_num_frames+turn_num_frames+to_append_num_frames]
+                            frame_placeholder = [v_placeholder_id] * frame_num_tokens
+                            if use_interval:
+                                frame_placeholder = [frame_token_interval_id] + frame_placeholder
+                            to_append_input_id = torch.tensor(frame_placeholder * to_append_num_frames, dtype=torch.long, device=device)
+                            to_append_logit = self.forward(
+                                input_ids=to_append_input_id[None],
+                                past_key_values=past_key_values_before_assistant,
+                                frames=to_append_frames,
+                                return_dict=True, use_cache=True
+                            ).logits[0]
+                            # we only use the last idx of each frame
+                            idxs = torch.arange(len(frame_placeholder)-1, len(to_append_input_id), len(frame_placeholder), device=device)
+                            to_append_score = to_append_logit[idxs].softmax(dim=-1)
+                            if frame_token_interval_threshold > 0:
+                                lower_threshold_mask = to_append_score[:, frame_token_interval_id] < frame_token_interval_threshold
+                                to_append_score[lower_threshold_mask] = 0
+                            to_append_score_pred_mask = to_append_score.argmax(dim=-1) != frame_token_interval_id
+                            if to_append_score_pred_mask.any():
+                                frame_diff = -(to_append_score_pred_mask.nonzero()[0,0] + 1)
+                            else:
+                                frame_diff = -to_append_num_frames
+                frame_diffs.append(frame_diff.abs())
+
+            ## 3.4 fluency
+            if turn_lm_mask.any() and turn_stream_mask.any():
+                num_learn_v_tokens = turn_stream_mask.sum()
+                num_learn_valid_tokens = turn_lm_masked_label.numel() + num_learn_v_tokens
+                if frame_diff == 0:
+                    fluency = (num_learn_v_tokens + num_lm_correct_tokens) / num_learn_valid_tokens
+                elif frame_diff > 0:
+                    fluency = (num_learn_v_tokens - frame_diff) / num_learn_valid_tokens
+                else:
+                    fluency = (num_learn_v_tokens - 1) / num_learn_valid_tokens
+                fluencies.append(fluency)
+            ## 3.5 next turn
+            past_num_frames += turn_num_frames
+        lm_ppl = torch.stack(lm_ppls).mean() if lm_ppls else one
+        frame_diff = torch.stack(frame_diffs).float().mean() if frame_diffs else zero
+        fluency = torch.stack(fluencies).float().mean() if fluencies else one
+        lm_correctness = torch.stack(lm_correctness).float().mean() if lm_correctness else one
+        return torch.stack([lm_ppl, frame_diff, fluency, lm_correctness])
+
+    def trim_past_key_values(self, past_key_values, start, stop):
+        return [[past_keys[:,:,start:stop], past_values[:,:,start:stop]] for past_keys, past_values in past_key_values]
+
+def fast_greedy_generate(*, model: LiveMixin, inputs_embeds: torch.Tensor, past_key_values: Cache, eos_token_id: int, inplace_output_ids: torch.Tensor):
+    for i in range(inplace_output_ids.size(1)):
+        outputs = model(inputs_embeds=inputs_embeds, past_key_values=past_key_values, use_cache=True)
+        past_key_values = outputs.past_key_values
+        new_token_id = outputs.logits[:, -1:].argmax(dim=-1)
+        inplace_output_ids[:, i] = new_token_id
+        if new_token_id == eos_token_id:
+            break
+        inputs_embeds = model.get_input_embeddings()(new_token_id)
+    return inplace_output_ids[:, :i+1], past_key_values
+
+def build_live(
+    *,
+    is_training: bool,
+    config_class: type,
+    model_class: type,
+    llm_pretrained: str = None,
+    finetune_modules: list[str] = None,
+    lora_modules: str = None,
+    lora_r: int = None,
+    lora_alpha: int = None,
+    set_vision_inside: bool = False,
+    resume_from_checkpoint: str = '',
+    attn_implementation: str = 'flash_attention_2',
+    torch_dtype: str | torch.dtype = 'auto',
+    **kwargs
+):
+    model = model_class.from_pretrained(llm_pretrained, config=config_class.from_pretrained(llm_pretrained, **kwargs), torch_dtype=torch_dtype, attn_implementation=attn_implementation)
+    tokenizer = build_live_tokenizer_and_update_config(llm_pretrained, model.config)
+    if is_training:
+        lora_config = LoraConfig(
+            r=lora_r,
+            lora_alpha=lora_alpha,
+            target_modules=lora_modules,
+            lora_dropout=0.05,
+            task_type="CAUSAL_LM",
+            modules_to_save=finetune_modules,
+            inference_mode=False,
+        )
+        model = get_peft_model(model, lora_config)
+        model.print_trainable_parameters()
+    else:
+        if resume_from_checkpoint and os.path.exists(resume_from_checkpoint):
+            model = PeftModel.from_pretrained(model, resume_from_checkpoint, is_trainable=False)
+        else:
+            logger.warning(f'!!! Failed to load checkpoint: {resume_from_checkpoint}. Returning a newly initialized model.')
+        if set_vision_inside:
+            model.set_vision_inside()
+        model.requires_grad_(False)
+    return model, tokenizer
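
As a hedged sketch of how `fast_greedy_generate` is meant to be driven during streaming inference: the model/tokenizer construction, the `input_ids`/`frames` preparation, and the buffer length of 128 below are assumptions for illustration, not values fixed by this commit.

```python
import torch
from transformers import DynamicCache

# model, tokenizer are assumed to come from build_live(...); input_ids/frames prepared by the caller.
inplace_output_ids = torch.zeros(1, 128, dtype=torch.long, device=model.device)  # reusable output buffer
inputs_embeds = model.joint_embed(input_ids, frames)          # text + frame embeddings
output_ids, past_key_values = fast_greedy_generate(
    model=model,
    inputs_embeds=inputs_embeds,
    past_key_values=DynamicCache(),                           # grows as frames and turns stream in
    eos_token_id=tokenizer.eos_token_id,
    inplace_output_ids=inplace_output_ids,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```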
models/tokenization_live.py ADDED
@@ -0,0 +1,153 @@
+import torch
+from transformers import AutoTokenizer
+from functools import partial
+
+from .configuration_live import LiveConfigMixin
+
+def get_stream_placeholder_len(num_frames: int, model_config: LiveConfigMixin) -> int:
+    return num_frames * model_config.frame_num_tokens * len(model_config.v_placeholder) + len(model_config.frame_token_interval) * (num_frames - 1)
+
+def get_stream_placeholder_jinja2(model_config: LiveConfigMixin) -> str:
+    return f"'{model_config.frame_token_interval}'.join([{model_config.frame_num_tokens} * '{model_config.v_placeholder}'] * message['num_frames'])"
+
+def get_stream_learn_ranges(num_frames: int, model_config: LiveConfigMixin) -> torch.Tensor:
+    len_frame_placeholder_with_interval = model_config.frame_num_tokens * len(model_config.v_placeholder) + len(model_config.frame_token_interval)
+    intermediate_interval_idxs = torch.arange(
+        len_frame_placeholder_with_interval,
+        len_frame_placeholder_with_interval * num_frames + 1,
+        len_frame_placeholder_with_interval
+    ) - len(model_config.frame_token_interval)
+    len_learn = len(model_config.frame_token_interval) if model_config.frame_token_interval else len(model_config.v_placeholder)
+    learn_ranges = torch.stack([
+        intermediate_interval_idxs,
+        intermediate_interval_idxs + len_learn
+    ], dim=1)
+    return learn_ranges
+
+def chat_template(self, stream_placeholder_jinja2: str):
+    """
+    system prompt
+    [<v>,<v>,<v>]
+    User: ...
+    Assistant: ...</s>
+    [<v>,<v>]
+    Assistant: ...</s>
+    User: ...
+    Assistant: ...</s>
+    """
+    template = (
+        "{% if messages[0]['role'] == 'system' %}"
+        "{{ bos_token + messages[0]['content'] + '\n' }}" # system
+        "{% set messages = messages[1:] %}"
+        "{% endif %}"
+        "{% for message in messages %}"
+        "{% if message['role'] == 'user' %}"
+        "{% if add_stream_query_prompt %}"
+        "{{ ']\nUser: ' + message['content'] }}"
+        "{% else %}"
+        "{{ '\nUser: ' + message['content'] }}"
+        "{% endif %}"
+        "{% elif message['role'] == 'assistant' %}"
+        "{{ '\nAssistant: ' + message['content'] + eos_token }}"
+        "{% elif message['role'] == 'stream' and message['num_frames'] > 0: %}"
+        "{{ '\n[' + STREAM_PLACEHOLDER + ']' }}"
+        "{% endif %}"
+        "{% endfor %}"
+        "{% if add_generation_prompt %}"
+        "{{ '\nAssistant:' }}"
+        "{% elif add_stream_prompt %}"
+        "{{ '\n[' }}"
+        "{% elif add_stream_generation_prompt %}"
+        "{{ ']\nAssistant:' }}"
+        "{% endif %}"
+    )
+    template = template.replace('STREAM_PLACEHOLDER', stream_placeholder_jinja2)
+    return template
+
+def chat_template_transition(tokenizer):
+    return {
+        (None, 'system'): tokenizer.bos_token,
+        ('system', 'user'): '\n\nUser: ',
+        ('system', 'stream'): '\n\n[',
+        ('user', 'assistant'): '\nAssistant: ',
+        ('user', 'stream'): '\n[',
+        ('user', 'user'): '\nUser: ',
+        ('assistant', 'user'): f'{tokenizer.eos_token}\nUser: ',
+        ('assistant', 'stream'): f'{tokenizer.eos_token}\n[',
+        ('stream', 'user'): ']\nUser: ',
+        ('stream', 'assistant'): ']\nAssistant: ',
+        'assistant': 'Assistant: ',
+        'eos_token': tokenizer.eos_token,
+    }
+
+def chat_template_offsets(tokenizer):
+    return {k:len(v) for k, v in chat_template_transition(tokenizer).items()}
+
+def get_learn_ranges(conversation: list[dict], *, chat_template_offsets: dict[tuple, int], model_config: LiveConfigMixin):
+    offset = 0
+    learn_ranges = []
+    last_role = None
+    for message in conversation:
+        role = message['role']
+        offset += chat_template_offsets[(last_role, role)]
+        last_role = role
+        if role == 'stream':
+            if message.get('learn', False):
+                ranges = get_stream_learn_ranges(message['num_frames'], model_config) + offset
+                # the last range ends with ']\n', so also account for the '\n'
+                ranges[-1, 1] += 1
+                if not isinstance(message['learn'], bool):
+                    ranges = ranges[:message['learn']]
+                learn_ranges.extend([range(r[0], r[1]) for r in ranges])
+            offset += get_stream_placeholder_len(message['num_frames'], model_config)
+        else:
+            if role == 'assistant':
+                if message.get('learn', False):
+                    learn_ranges.append(range(offset - chat_template_offsets['assistant'], offset + len(message['content']) + chat_template_offsets['eos_token']))
+            offset += len(message['content'])
+    return learn_ranges
+
+def build_live_tokenizer_and_update_config(llm_pretrained: str, model_config: LiveConfigMixin) -> AutoTokenizer:
+    tokenizer = AutoTokenizer.from_pretrained(llm_pretrained, use_fast=True, padding_side='left')
+    tokenizer.add_special_tokens({'additional_special_tokens': [model_config.v_placeholder]})
+    v_placeholder_id = len(tokenizer) - 1
+    if model_config.frame_token_interval:
+        frame_token_interval_id = tokenizer.convert_tokens_to_ids(model_config.frame_token_interval)
+    else:
+        frame_token_interval_id = None
+    tokenizer.pad_token = tokenizer.eos_token
+    model_config.update(dict(v_placeholder_id=v_placeholder_id, frame_token_interval_id=frame_token_interval_id, eos_token_id=tokenizer.eos_token_id))
+    tokenizer.chat_template = chat_template(tokenizer, get_stream_placeholder_jinja2(model_config))
+    tokenizer.get_learn_ranges = partial(get_learn_ranges, chat_template_offsets=chat_template_offsets(tokenizer), model_config=model_config)
+    return tokenizer
+
+if __name__ == '__main__':
+    config = LiveConfigMixin(frame_token_interval=',', frame_token_cls=True, frame_token_pooled=[3,3], frame_num_tokens=10)
+    tokenizer = build_live_tokenizer_and_update_config('meta-llama/Meta-Llama-3-8B-Instruct', config)
+    chat = [
+        {'role': 'system', 'content': 'cool.'},
+        {'role': 'stream', 'num_frames': 2, 'learn': 1},
+        {'role': 'user', 'content': 'cool?'},
+        {'role': 'assistant', 'content': 'cool.', 'learn': True},
+        {'role': 'stream', 'num_frames': 3, 'learn': 3},
+        {'role': 'assistant', 'content': 'so cool.', 'learn': True},
+    ]
+    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)
+    learn_ranges = tokenizer.get_learn_ranges(chat)
+    batch = tokenizer([prompt], return_offsets_mapping=True, add_special_tokens=False, return_tensors="pt", padding=True)
+    batch_labels = torch.full_like(batch.input_ids, -100, dtype=torch.long)
+    for text, labels, input_ids, offset_mapping, learn_range in zip(
+        [prompt], batch_labels, batch.input_ids, batch.offset_mapping, [learn_ranges]
+    ):
+        for learn_r in learn_range:
+            start = torch.nonzero(offset_mapping[:,0] == learn_r.start).item()
+            if offset_mapping[:,0][-1] >= learn_r.stop:
+                stop = torch.nonzero(offset_mapping[:,0] == learn_r.stop).item()
+            else: # the last eos token
+                stop = len(input_ids)
+            labels[start-1:stop-1] = input_ids[start:stop]
+        # NOTE: input_ids may exceed len(tokenizer) - 1 (the 1 being the added vision placeholder),
+        # because some frame positions have v_placeholder_id as the target; replace those labels with the eos token.
+        labels[labels >= len(tokenizer) - 1] = tokenizer.eos_token_id
+    print(batch.input_ids)
+    print(batch_labels)
models/vision_live.py ADDED
@@ -0,0 +1,61 @@
+import math, torch
+from functools import partial
+from torch import nn, Tensor
+from torchvision.transforms.functional import normalize
+from transformers import AutoModel
+from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
+
+from .configuration_live import LiveConfigMixin
+
+def _siglip_vision_encode(vision_model: nn.Module, frames: Tensor, frame_token_cls: bool, frame_token_pooled: tuple,
+    mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], rescale_factor=0.00392156862745098, **kwargs):
+    frames = normalize(frames * rescale_factor, mean=mean, std=std)
+    with torch.cuda.amp.autocast():
+        vision_outputs = vision_model(frames)
+    last_hidden_state = vision_outputs.last_hidden_state
+    if frame_token_pooled:
+        s = int(math.sqrt(last_hidden_state.shape[1]))
+        spatial_tokens = torch.nn.functional.adaptive_avg_pool2d(
+            last_hidden_state.reshape(
+                last_hidden_state.shape[0], s, s, last_hidden_state.shape[-1]
+            ).permute(0, 3, 1, 2),
+            frame_token_pooled
+        ).flatten(2, 3).permute(0, 2, 1)
+        if not frame_token_cls:
+            return spatial_tokens
+    if frame_token_cls:
+        cls_token = vision_outputs.pooler_output[:, None]
+        if not frame_token_pooled:
+            return cls_token
+    return torch.cat([cls_token, spatial_tokens], dim=1)
+
+def _clip_vision_encode(vision_model: nn.Module, frames: Tensor, frame_token_cls: bool, frame_token_pooled: tuple,
+    mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, rescale_factor=0.00392156862745098, **kwargs):
+    frames = normalize(frames * rescale_factor, mean=mean, std=std)
+    with torch.cuda.amp.autocast():
+        vision_outputs = vision_model(frames)
+    last_hidden_state = vision_outputs.last_hidden_state
+    if frame_token_pooled:
+        s = int(math.sqrt(last_hidden_state.shape[1]))
+        spatial_tokens = torch.nn.functional.adaptive_avg_pool2d(
+            last_hidden_state[:,1:].reshape(
+                last_hidden_state.shape[0], s, s, last_hidden_state.shape[-1]
+            ).permute(0, 3, 1, 2),
+            frame_token_pooled
+        ).flatten(2, 3).permute(0, 2, 1)
+        if not frame_token_cls:
+            return spatial_tokens
+    if frame_token_cls:
+        cls_token = last_hidden_state[:,0]
+        if not frame_token_pooled:
+            return cls_token
+    return torch.cat([cls_token, spatial_tokens], dim=1)
+
+def build_live_vision(config: LiveConfigMixin):
+    model = AutoModel.from_pretrained(config.vision_pretrained).vision_model
+    if 'google/siglip-large-patch16-384' == config.vision_pretrained:
+        return model, partial(_siglip_vision_encode, frame_token_cls=config.frame_token_cls, frame_token_pooled=config.frame_token_pooled)
+    elif 'laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90k' == config.vision_pretrained or 'openai/clip-vit-large-patch14-336' == config.vision_pretrained:
+        return model, partial(_clip_vision_encode, frame_token_cls=config.frame_token_cls, frame_token_pooled=config.frame_token_pooled)
+    else:
+        raise ValueError(f'Unverified vision_pretrained: {config.vision_pretrained}')
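
A quick, hedged smoke test of the SigLIP branch above. The random frame tensor and the expected output shape in the comment are inferences from the code (1 CLS token + 3x3 pooled tokens, hidden size 1024 for siglip-large), and the snippet assumes the repository root is on `PYTHONPATH` and that the checkpoint can be downloaded:

```python
import torch
from models.configuration_live import LiveConfigMixin
from models.vision_live import build_live_vision

config = LiveConfigMixin(vision_pretrained='google/siglip-large-patch16-384',
                         frame_token_cls=True, frame_token_pooled=[3, 3], frame_num_tokens=10)
vision_model, vision_encode = build_live_vision(config)
frames = torch.randint(0, 256, (4, 3, 384, 384)).float()   # 4 fake frames in 0..255; rescaled inside
with torch.no_grad():
    tokens = vision_encode(vision_model, frames)
print(tokens.shape)   # expected (4, 10, 1024): CLS + 3x3 pooled tokens per frame
```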