import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import gradio as gr

# Constants for the model
MODEL_PATH = "SeaLLMs/SeaLLMs-v3-7B-Chat"
MODEL_TITLE = "SeaLLMs Chat Model"
MODEL_DESC = "A demo for the SeaLLMs-v3-7B-Chat language model."

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# Load the model with efficient settings, using 8-bit quantization to reduce memory usage
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
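
# A 4-bit NF4 configuration is a possible alternative for tighter memory
# budgets (a sketch; the right settings depend on your hardware):
# quantization_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16,
# )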

model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",
    quantization_config=quantization_config,
    low_cpu_mem_usage=True,
)
# Put the model in inference mode; gradient checkpointing only saves memory
# during training, so it is not needed for a generation-only demo
model.eval()

def generate_response(prompt):
    inputs = tokenizer(prompt, return_tensors="pt")
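    # For a chat-tuned checkpoint, wrapping the prompt in the tokenizer's chat
    # template usually yields better answers. A minimal sketch, assuming the
    # checkpoint ships a chat template:
    # messages = [{"role": "user", "content": prompt}]
    # text = tokenizer.apply_chat_template(
    #     messages, tokenize=False, add_generation_prompt=True
    # )
    # inputs = tokenizer(text, return_tensors="pt")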
    # Move inputs to the same device as the model
    inputs = {key: value.to(model.device) for key, value in inputs.items()}

    # Generate response
    try:
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,    # cap generated tokens to keep memory usage low
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,        # sampling is required for temperature to take effect
            temperature=0.7,       # temperature scaling to control randomness
        )
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    except RuntimeError as e:
        # Fail gracefully on runtime errors such as CUDA out-of-memory
        response = "An error occurred during generation. Please try again with a different prompt."
        print(f"RuntimeError: {e}")
    return response
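
# Quick smoke test (a sketch; uncomment to try the function without the UI):
# print(generate_response("What is the capital of Vietnam?"))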

# Create the Gradio interface
iface = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(lines=5, label="Enter your message:"),
    outputs=gr.Textbox(label="Model's response:"),
    title=MODEL_TITLE,
    description=MODEL_DESC,
    theme="default",  # optional: specify any custom theme or remove this line
)
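
# For a multi-turn UI, gr.ChatInterface is a possible drop-in alternative; note
# that it calls fn with (message, history), so a wrapper is needed (a sketch):
# iface = gr.ChatInterface(fn=lambda message, history: generate_response(message))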

if __name__ == "__main__":
    iface.launch()
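    # Pass share=True for a temporary public link, e.g. iface.launch(share=True);
    # on Hugging Face Spaces, the default launch() is sufficient.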