# Khmer_SeaLLM / app.py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
# Constants for the Model
MODEL_PATH = "SeaLLMs/SeaLLMs-v3-7B-Chat"
MODEL_TITLE = "SeaLLMs Chat Model"
MODEL_DESC = "A demo for the SeaLLMs-v3-7B-Chat language model."
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# Load the model with CPU offloading to reduce memory usage
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",
    torch_dtype=torch.float32,   # Use float32 to avoid numerical issues
    offload_folder="./offload",  # Folder for offloaded weights to manage memory
    low_cpu_mem_usage=True,
)
# Put the model in evaluation mode; gradient checkpointing only saves memory during
# training, so it is not needed for inference.
model.eval()
def generate_response(prompt):
    # Limit the input length to prevent excessive memory usage
    prompt = prompt[:512]
    inputs = tokenizer(prompt, return_tensors="pt")
    # Move the inputs to the same device as the model
    inputs = {key: value.to(model.device) for key, value in inputs.items()}
    # Generate the response
    try:
        with torch.no_grad():  # Disable gradient tracking to save memory during inference
            outputs = model.generate(
                **inputs,
                max_new_tokens=128,  # Cap newly generated tokens to keep memory usage low
                num_return_sequences=1,
                no_repeat_ngram_size=2,
                do_sample=True,   # Sampling must be enabled for temperature to take effect
                temperature=0.7,  # Temperature scaling to control output randomness
            )
        # Decode only the newly generated tokens so the prompt is not echoed back
        response = tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True
        )
    except RuntimeError as e:
        # Handle numerical instability or out-of-memory errors gracefully
        response = "An error occurred during generation. Please try again with a different prompt."
        print(f"RuntimeError: {e}")
    return response
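# Optional sketch (an assumption, not part of the original app): SeaLLMs-v3-7B-Chat is an
# instruction-tuned model, so wrapping the user message in the tokenizer's chat template
# (which Qwen2-style chat models normally ship) tends to give better answers than raw text.
# `generate_chat_response` is a hypothetical helper name; wire it into the interface only if
# the tokenizer actually provides a chat template.
def generate_chat_response(message):
    messages = [{"role": "user", "content": message}]
    # Render the conversation with the model's own chat template and append the
    # assistant-turn marker so generation starts at the reply.
    chat_prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    return generate_response(chat_prompt)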
# Create the Gradio interface
iface = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(lines=5, label="Enter your message:"),
    outputs=gr.Textbox(label="Model's response:"),
    title=MODEL_TITLE,
    description=MODEL_DESC,
    theme="default",  # Specify a custom theme here, or remove this line
)
if __name__ == "__main__":
    iface.launch()