import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import gradio as gr

# Constants for the model
MODEL_PATH = "SeaLLMs/SeaLLMs-v3-7B-Chat"
MODEL_TITLE = "SeaLLMs Chat Model"
MODEL_DESC = "A demo for the SeaLLMs-v3-7B-Chat language model."

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# Load the model with 8-bit quantization to reduce memory usage
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",
    quantization_config=quantization_config,
    low_cpu_mem_usage=True,
)
# Inference-only demo: put the model in eval mode. (Gradient checkpointing is a
# training-time memory optimization and has no effect during generation.)
model.eval()


def generate_response(prompt):
    inputs = tokenizer(prompt, return_tensors="pt")
    # Move inputs to the same device as the model
    inputs = {key: value.to(model.device) for key, value in inputs.items()}

    # Generate a response
    try:
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=256,      # cap generated tokens to limit memory usage
                num_return_sequences=1,
                no_repeat_ngram_size=2,
                do_sample=True,          # sampling must be enabled for temperature to apply
                temperature=0.7,
            )
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    except RuntimeError as e:
        # Handle out-of-memory or numerical errors gracefully
        response = "An error occurred during generation. Please try again with a different prompt."
        print(f"RuntimeError: {e}")

    return response


# Create the Gradio interface
iface = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(lines=5, label="Enter your message:"),
    outputs=gr.Textbox(label="Model's response:"),
    title=MODEL_TITLE,
    description=MODEL_DESC,
    theme="default",  # specify a custom theme here, or remove this line
)

if __name__ == "__main__":
    iface.launch()
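
# --- Optional: chat-template formatting (sketch) ---
# SeaLLMs-v3-7B-Chat is an instruction-tuned chat model, so wrapping the user
# prompt with the tokenizer's chat template (assuming the checkpoint bundles
# one) usually produces better replies than feeding raw text. A minimal sketch
# of how generate_response could build its inputs instead:
#
#     messages = [{"role": "user", "content": prompt}]
#     input_ids = tokenizer.apply_chat_template(
#         messages,
#         add_generation_prompt=True,   # append the assistant-turn marker
#         return_tensors="pt",
#     ).to(model.device)
#     outputs = model.generate(
#         input_ids, max_new_tokens=256, do_sample=True, temperature=0.7
#     )
#     # Decode only the newly generated tokens, not the echoed prompt
#     response = tokenizer.decode(
#         outputs[0][input_ids.shape[-1]:], skip_special_tokens=True
#     )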