# Khmer_SeaLLM / app.py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import gradio as gr
# Constants for the Model
MODEL_PATH = "SeaLLMs/SeaLLMs-v3-7B-Chat"
MODEL_TITLE = "SeaLLMs Chat Model"
MODEL_DESC = "A demo for the SeaLLMs-v3-7B-Chat language model."
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# Load the model with efficient settings, using 8-bit quantization to reduce memory usage
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
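# If 8-bit loading still exceeds the available GPU memory, 4-bit NF4 quantization is a
# common fallback (a sketch, assuming the installed bitsandbytes build supports it):
# quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")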
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",
    quantization_config=quantization_config,
    low_cpu_mem_usage=True
)
# Gradient checkpointing only saves memory during training; for this inference-only
# demo, putting the model in evaluation mode is what matters.
model.eval()
def generate_response(prompt):
    inputs = tokenizer(prompt, return_tensors="pt")
    # Move inputs to the same device as the model
    inputs = {key: value.to(model.device) for key, value in inputs.items()}
    # Generate a response without tracking gradients to save memory
    try:
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=256,   # Cap newly generated tokens to limit memory usage
                num_return_sequences=1,
                no_repeat_ngram_size=2,
                do_sample=True,       # Sampling is required for temperature to take effect
                temperature=0.7       # Temperature scaling to control randomness
            )
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    except RuntimeError as e:
        # Handle out-of-memory or numerical errors gracefully
        print(f"RuntimeError: {e}")
        response = "An error occurred during generation. Please try again with a different prompt."
    return response
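# SeaLLMs-v3-7B-Chat is a chat-tuned model, so wrapping the raw prompt with the
# tokenizer's chat template usually produces better-formed replies. A minimal,
# optional sketch (assumes the tokenizer ships a chat template); it is not wired
# into generate_response above:
def build_chat_prompt(user_message):
    messages = [{"role": "user", "content": user_message}]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)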
# Create the Gradio interface
iface = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(lines=5, label="Enter your message:"),
    outputs=gr.Textbox(label="Model's response:"),
    title=MODEL_TITLE,
    description=MODEL_DESC,
    theme="default"  # You can specify any custom theme or remove this line
)
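# On Hugging Face Spaces, enabling Gradio's request queue can help when generation is
# slow (a sketch, assuming a Gradio version that exposes .queue()):
# iface.queue()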
if __name__ == "__main__":
    iface.launch()