# Hosting-page residue (not part of the program): zhanghaoji · init · eb0678a · raw · history blame · No virus · 14.1 kB
# This file may have been modified by Flash-VStream Authors ("Flash-VStream Modifications"). All Flash-VStream Modifications are Copyright 2024 Flash-VStream Authors.
# Based on https://github.com/haotian-liu/LLaVA.
"""
This file demonstrates an implementation of a multiprocess real-time long video understanding system, together with a multiprocess logging module.
main process: CLI server I/O, LLM inference
process-1: logger listener
process-2: frame generator
process-3: frame memory manager
Author: Haoji Zhang, Haotian Liu
(This code is based on https://github.com/haotian-liu/LLaVA)
"""
import argparse
import requests
import logging
import torch
import numpy as np
import time
import os
from flash_vstream.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from flash_vstream.conversation import conv_templates, SeparatorStyle
from flash_vstream.model.builder import load_pretrained_model
from flash_vstream.utils import disable_torch_init
from flash_vstream.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from torch.multiprocessing import Process, Queue, Manager
from transformers import TextStreamer
from decord import VideoReader
from datetime import datetime
from PIL import Image
from io import BytesIO
class _Metric:
    """Running statistics (latest, average, maximum) for one series of numbers."""

    def __init__(self):
        self._latest_value = None
        self._sum = 0.0
        self._max = 0.0
        self._count = 0

    @property
    def val(self):
        """Most recently added value, or ``None`` if nothing has been added."""
        return self._latest_value

    @property
    def max(self):
        """Largest value added so far (``0.0`` while the series is empty)."""
        return self._max

    @property
    def avg(self):
        """Arithmetic mean of all added values, or NaN while the series is empty."""
        if self._count == 0:
            return float('nan')
        return self._sum / self._count

    def add(self, value):
        """Record ``value`` and update the latest/sum/count/max statistics."""
        self._latest_value = value
        self._sum += value
        self._count += 1
        # The first sample always becomes the maximum; comparing it against the
        # initial 0.0 would make an all-negative series report a max of 0.0.
        if self._count == 1 or value > self._max:
            self._max = value

    def __str__(self):
        """Format as ``"<latest> (<average>, <max>)"`` with six decimals each."""
        latest_formatted = f"{self.val:.6f}" if self.val is not None else "None"
        average_formatted = f"{self.avg:.6f}"
        max_formatted = f"{self.max:.6f}"
        return f"{latest_formatted} ({average_formatted}, {max_formatted})"


class MetricMeter:
    """Collection of independently tracked metrics addressed by key.

    ``val``/``avg``/``max`` raise ``ValueError`` for a key that has no values,
    while ``meter[key]`` raises ``KeyError``; both conventions are kept for
    backward compatibility.
    """

    def __init__(self):
        self._metrics = {}

    def _require(self, key):
        """Return the metric stored under ``key`` or raise ``ValueError``."""
        metric = self._metrics.get(key)
        if metric is None:
            raise ValueError(f"No values have been added for key '{key}'.")
        return metric

    def add(self, key, value):
        """Record ``value`` under ``key``, creating the metric on first use."""
        if key not in self._metrics:
            self._metrics[key] = _Metric()
        self._metrics[key].add(value)

    def val(self, key):
        """Return the latest value recorded under ``key``.

        Raises:
            ValueError: if ``key`` is unknown or holds no value yet.
        """
        metric = self._require(key)
        if metric.val is None:
            raise ValueError(f"No values have been added for key '{key}'.")
        return metric.val

    def avg(self, key):
        """Return the mean of the values recorded under ``key``.

        Raises:
            ValueError: if ``key`` is unknown.
        """
        return self._require(key).avg

    def max(self, key):
        """Return the largest value recorded under ``key``.

        Raises:
            ValueError: if ``key`` is unknown.
        """
        return self._require(key).max

    def __getitem__(self, key):
        """Return the formatted summary string ``"<latest> (<avg>, <max>)"``.

        Raises:
            KeyError: if ``key`` is unknown.
        """
        metric = self._metrics.get(key)
        if metric is None:
            raise KeyError(f"The key '{key}' does not exist.")
        return str(metric)
def load_image(image_file):
    """Load an image from a local path or an HTTP(S) URL as an RGB PIL image.

    Args:
        image_file: Filesystem path, or a URL starting with ``http://`` or
            ``https://``.

    Returns:
        A ``PIL.Image.Image`` converted to mode ``'RGB'``.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection problems or timeout.
    """
    if image_file.startswith(('http://', 'https://')):
        # Bound the wait and fail on HTTP errors, so that an error page is not
        # handed to PIL as if it were image data.
        response = requests.get(image_file, timeout=30)
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_file
    # convert() returns an independent image, so the handle opened by
    # Image.open can be released as soon as the conversion is done.
    with Image.open(source) as image:
        return image.convert('RGB')
def listener(queue, filename):
    """Drain log records from ``queue`` into the log file ``filename``.

    Intended to run as the dedicated logging process (sub process-1). It sets
    the root logger to DEBUG, attaches a file handler, and dispatches every
    record taken from the queue until a ``None`` sentinel arrives.

    Args:
        queue: Queue delivering ``logging.LogRecord`` objects; ``None`` stops
            the loop.
        filename: Path of the log file to write to.
    """
    ############## Start sub process-1: Listener #############
    import sys
    import traceback

    root = logging.getLogger()
    root.setLevel(logging.DEBUG)
    handler = logging.FileHandler(filename)
    handler.setFormatter(
        logging.Formatter('%(asctime)s %(processName)-10s %(name)s %(levelname)-8s %(message)s')
    )
    root.addHandler(handler)
    try:
        while True:
            try:
                record = queue.get()
                if record is None:  # None is a signal to finish
                    break
                logger = logging.getLogger(record.name)
                logger.handle(record)  # No level or filter logic applied - just do it!
            except Exception:
                # Best effort: report the problem and keep serving records.
                print('Whoops! Problem:', file=sys.stderr)
                traceback.print_exc(file=sys.stderr)
    finally:
        # Detach and close the handler so the log file is released on exit.
        root.removeHandler(handler)
        handler.close()
def worker_configurer(queue):
    """Route all logging of the calling process into ``queue``.

    Adds a ``QueueHandler`` bound to ``queue`` to the root logger and sets the
    root level to DEBUG.

    Args:
        queue: Queue that receives the process's log records.
    """
    # `import logging` alone does not guarantee that the `logging.handlers`
    # submodule is loaded, so import it explicitly before using QueueHandler.
    import logging.handlers

    handler = logging.handlers.QueueHandler(queue)  # Just the one handler needed
    root = logging.getLogger()
    root.addHandler(handler)
    root.setLevel(logging.DEBUG)
def video_stream_similator(video_file, frame_queue, log_queue, video_fps=1.0, play_speed=1.0):
    """Replay a video file as a paced stream of single-frame clips (sub process-2).

    The video is sampled at roughly ``video_fps`` frames per second, and each
    one-frame clip is put on ``frame_queue`` with a pause of
    ``1 / video_fps / play_speed`` seconds between clips. A ``None`` sentinel
    is always put on the queue at the end, including after a failure, so the
    consumer is never left blocking.

    Args:
        video_file: Path passed to ``decord.VideoReader``.
        frame_queue: Queue receiving the frame arrays, then the ``None`` sentinel.
        log_queue: Queue used to forward this process's log records.
        video_fps: Target sampling rate in frames per second; must be > 0.
        play_speed: Playback speed multiplier; must be > 0.

    Raises:
        ValueError: if ``video_fps`` or ``play_speed`` is not positive.
    """
    ############## Start sub process-2: Simulator #############
    worker_configurer(log_queue)
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)

    time_meter = MetricMeter()
    try:
        if video_fps <= 0 or play_speed <= 0:
            raise ValueError(
                f'video_fps and play_speed must be positive, got video_fps={video_fps}, play_speed={play_speed}'
            )

        vr = VideoReader(video_file)
        # Keep every `sample_fps`-th frame. Clamp to 1 because a requested rate
        # well above the native fps would round to 0, an invalid range() step.
        sample_fps = max(1, round(vr.get_avg_fps() / video_fps))
        frame_idx = list(range(0, len(vr), sample_fps))
        video = vr.get_batch(frame_idx).asnumpy()
        # Each sampled frame is repeated 6 times along axis 0.
        # NOTE(review): presumably to lengthen the simulated stream - confirm.
        video = np.repeat(video, 6, axis=0)
        length = video.shape[0]
        sleep_time = 1 / video_fps / play_speed

        logger.info(f'Simulator Process: start, length = {length}')
        try:
            last_start = None
            for start in range(length):
                start_time = time.perf_counter()
                end = start + 1  # one frame per clip
                video_clip = video[start:end]
                frame_queue.put(video_clip)
                if start > 0:
                    # Measured interval between two consecutive emissions.
                    time_meter.add('real_sleep', start_time - last_start)
                    logger.info(f'Simulator: write {end - start} frames,\t{start} to {end},\treal_sleep={time_meter["real_sleep"]}')
                if end < length:
                    time.sleep(sleep_time)
                last_start = start_time
        except Exception as e:
            # Keep the traceback in the log file as well as on the console.
            logger.exception('Simulator Exception')
            print(f'Simulator Exception: {e}')
            time.sleep(0.1)
    finally:
        # Always signal end-of-stream, even when loading or streaming failed.
        frame_queue.put(None)
    logger.info(f'Simulator Process: end')
def frame_memory_manager(model, image_processor, frame_queue, log_queue):
    """Consume frame clips from ``frame_queue`` and embed them (sub process-3).

    Each clip is preprocessed with ``image_processor``, moved to the model's
    device as float16, and passed to ``model.embed_video_streaming``. The loop
    stops when a ``None`` sentinel is received. A failure on one clip is
    logged and the loop goes on to the next clip.

    Args:
        model: Object exposing ``device`` and ``embed_video_streaming``.
            NOTE(review): presumably the Flash-VStream model, whose memory is
            updated by ``embed_video_streaming`` - confirm against the model code.
        image_processor: Object whose ``preprocess(..., return_tensors='pt')``
            returns a mapping with a ``'pixel_values'`` tensor.
        frame_queue: Queue delivering frame arrays, terminated by ``None``.
        log_queue: Queue used to forward this process's log records.
    """
    ############## Start sub process-3: Memory Manager #############
    worker_configurer(log_queue)
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)

    time_meter = MetricMeter()
    logger.info(f'MemManager Process: start')
    frame_cnt = 0
    while True:
        try:
            video_clip = frame_queue.get()
            start_time = time.perf_counter()
            if video_clip is None:
                logger.info(f'MemManager: Ooops, get None')
                break
            num_frames = video_clip.shape[0]
            logger.info(f'MemManager: get {num_frames} frames from queue')

            image = image_processor.preprocess(video_clip, return_tensors='pt')['pixel_values']
            image = image.unsqueeze(0)  # add a leading dimension
            image_tensor = image.to(model.device, dtype=torch.float16)

            logger.info(f'MemManager: Start embedding')
            with torch.inference_mode():
                model.embed_video_streaming(image_tensor)
            logger.info(f'MemManager: End embedding')
            end_time = time.perf_counter()

            if frame_cnt > 0:
                time_meter.add('memory_latency', end_time - start_time)
                logger.info(f'MemManager: embedded {num_frames} frames,\tidx={frame_cnt},\tmemory_latency={time_meter["memory_latency"]}')
            else:
                # The first clip's latency is reported but kept out of the statistics.
                logger.info(f'MemManager: embedded {num_frames} frames,\tidx={frame_cnt},\tmemory_latency={end_time - start_time:.6f}, not logged')
            frame_cnt += num_frames
        except Exception as e:
            # Best effort: record the full traceback, then continue with the next clip.
            logger.exception('MemManager Exception')
            print(f'MemManager Exception: {e}')
            time.sleep(0.1)
    logger.info(f'MemManager Process: end')
def main(args):
    """Run the multiprocess streaming video question-answering demo.

    Starts the log listener, loads the model, then starts the frame simulator
    and the frame memory manager as child processes. The main process asks a
    fixed question every 5 seconds and logs latency statistics. Child
    processes are terminated when the function exits for any reason, including
    an exception or Ctrl-C.

    Args:
        args: Parsed command-line namespace. Reads ``log_file``,
            ``model_path``, ``model_base``, ``load_8bit``, ``load_4bit``,
            ``device``, ``conv_mode``, ``video_max_frames``, ``video_fps``,
            ``play_speed``, ``video_file``, ``temperature`` and
            ``max_new_tokens``.
    """
    # torch.multiprocessing.log_to_stderr(logging.DEBUG)
    torch.multiprocessing.set_start_method('spawn', force=True)
    disable_torch_init()

    log_queue = Queue()
    frame_queue = Queue(maxsize=10)
    processes = []

    ############## Start listener process #############
    p1 = Process(target=listener, args=(log_queue, args.log_file))
    processes.append(p1)
    p1.start()

    try:
        ############## Start main process #############
        worker_configurer(log_queue)
        logger = logging.getLogger(__name__)

        model_name = get_model_name_from_path(args.model_path)
        tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)

        logger.info(f'Using conv_mode={args.conv_mode}')

        conv = conv_templates[args.conv_mode].copy()
        if "mpt" in model_name.lower():
            roles = ('user', 'assistant')
        else:
            roles = conv.roles

        with Manager() as manager:
            image_tensor = None
            model.use_video_streaming_mode = True
            model.video_embedding_memory = manager.list()
            if args.video_max_frames is not None:
                model.config.video_max_frames = args.video_max_frames
                logger.info(f'Important: set model.config.video_max_frames = {model.config.video_max_frames}')

            logger.info(f'Important: set video_fps = {args.video_fps}')
            logger.info(f'Important: set play_speed = {args.play_speed}')

            ############## Start simulator process #############
            p2 = Process(target=video_stream_similator,
                         args=(args.video_file, frame_queue, log_queue, args.video_fps, args.play_speed))
            processes.append(p2)
            p2.start()

            ############## Start memory manager process #############
            p3 = Process(target=frame_memory_manager,
                         args=(model, image_processor, frame_queue, log_queue))
            processes.append(p3)
            p3.start()

            # start QA server
            start_time = datetime.now()
            time_meter = MetricMeter()
            conv_cnt = 0
            last_conv_start_time = None
            while True:
                time.sleep(5)
                # The question is fixed for this demo; an interactive prompt
                # such as input(f"{roles[0]}: ") could be substituted here.
                inp = "what is in the video?"
                if not inp:
                    print("exit...")
                    break

                # Get the current time
                now = datetime.now()
                conv_start_time = time.perf_counter()
                # Format the current time as a string
                current_time = now.strftime("%H:%M:%S")
                duration = now.timestamp() - start_time.timestamp()
                # Print the current time
                print("\nCurrent Time:", current_time, "Run for:", duration)
                print(f"{roles[0]}: {inp}", end="\n")
                print(f"{roles[1]}: ", end="")

                # every conversation is a new conversation
                conv = conv_templates[args.conv_mode].copy()
                inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
                conv.append_message(conv.roles[0], inp)
                conv.append_message(conv.roles[1], None)
                prompt = conv.get_prompt()

                input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)

                stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
                keywords = [stop_str]
                stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
                streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

                llm_start_time = time.perf_counter()
                with torch.inference_mode():
                    output_ids = model.generate(
                        input_ids,
                        images=image_tensor,
                        do_sample=args.temperature > 0,
                        temperature=args.temperature,
                        max_new_tokens=args.max_new_tokens,
                        streamer=streamer,
                        use_cache=True,
                        stopping_criteria=[stopping_criteria]
                    )
                llm_end_time = time.perf_counter()

                outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
                conv.messages[-1][-1] = outputs
                conv_end_time = time.perf_counter()

                if conv_cnt > 0:
                    time_meter.add('conv_latency', conv_end_time - conv_start_time)
                    time_meter.add('llm_latency', llm_end_time - llm_start_time)
                    time_meter.add('real_sleep', conv_start_time - last_conv_start_time)
                    logger.info(f'CliServer: idx={conv_cnt},\treal_sleep={time_meter["real_sleep"]},\tconv_latency={time_meter["conv_latency"]},\tllm_latency={time_meter["llm_latency"]}')
                else:
                    # The first round is reported but kept out of the statistics.
                    logger.info(f'CliServer: idx={conv_cnt},\tconv_latency={conv_end_time - conv_start_time},\tllm_latency={llm_end_time - llm_start_time}')
                conv_cnt += 1
                last_conv_start_time = conv_start_time
    finally:
        # Stop the children on every exit path. They are non-daemon processes,
        # so leaving one alive would keep the parent from exiting.
        for p in processes:
            if p.is_alive():
                p.terminate()
                p.join()
        print("All processes finished.")
if __name__ == "__main__":
    # Command-line entry point: declare the options from a table (in the same
    # order as before), parse them, and hand the namespace to main().
    parser = argparse.ArgumentParser()
    option_table = (
        ("--model-path", {"type": str, "default": "facebook/opt-350m"}),
        ("--model-base", {"type": str, "default": None}),
        ("--image-file", {"type": str, "default": None}),
        ("--video-file", {"type": str, "default": None}),
        ("--device", {"type": str, "default": "cuda"}),
        ("--conv-mode", {"type": str, "default": "vicuna_v1"}),
        ("--temperature", {"type": float, "default": 0.2}),
        ("--max-new-tokens", {"type": int, "default": 512}),
        ("--load-8bit", {"action": "store_true"}),
        ("--load-4bit", {"action": "store_true"}),
        ("--debug", {"action": "store_true"}),
        ("--log-file", {"type": str, "default": "tmp_cli.log"}),
        ("--use_1process", {"action": "store_true"}),
        ("--video_max_frames", {"type": int, "default": None}),
        ("--video_fps", {"type": float, "default": 1.0}),
        ("--play_speed", {"type": float, "default": 1.0}),
    )
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()
    main(args)