Vira21 committed
Commit 3c50dcb
1 Parent(s): bd031de

Update app.py

Files changed (1)
  1. app.py +5 -4
app.py CHANGED
@@ -1,5 +1,5 @@
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 import gradio as gr
 
 # Constants for the Model
@@ -10,11 +10,12 @@ MODEL_DESC = "A demo for the SeaLLMs-v3-7B-Chat language model."
 # Load the tokenizer
 tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
 
-# Load the model with efficient settings
+# Load the model with efficient settings, using 8-bit quantization to reduce memory usage
+quantization_config = BitsAndBytesConfig(load_in_8bit=True)
 model = AutoModelForCausalLM.from_pretrained(
     MODEL_PATH,
     device_map="auto",
-    torch_dtype=torch.float32,  # Switching to float32 to avoid numerical instability
+    quantization_config=quantization_config,
     low_cpu_mem_usage=True
 )
 
@@ -31,7 +32,7 @@ def generate_response(prompt):
     try:
         outputs = model.generate(
            **inputs,
-            max_length=512,
+            max_length=256,  # Reduced max_length to lower memory usage
             num_return_sequences=1,
             no_repeat_ngram_size=2,
             early_stopping=True,
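
For reference, here is a minimal sketch of what the model-loading and generation path looks like after this commit. Only the import, the BitsAndBytesConfig quantization, and the generate() settings come from the diff; the MODEL_PATH value and the parts of generate_response outside the shown hunks are assumptions filled in for illustration. Loading with load_in_8bit=True additionally requires the bitsandbytes package and a CUDA-capable GPU.

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Assumed model id; the real MODEL_PATH constant is elided from the diff
MODEL_PATH = "SeaLLMs/SeaLLMs-v3-7B-Chat"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# 8-bit weights use roughly 1 byte per parameter, about a quarter of the
# float32 footprint the previous revision loaded
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",
    quantization_config=quantization_config,
    low_cpu_mem_usage=True,
)

def generate_response(prompt):
    # Tokenization and error handling here are assumptions; the diff only
    # shows the generate() call itself
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    try:
        outputs = model.generate(
            **inputs,
            max_length=256,  # reduced from 512 to lower memory usage
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            early_stopping=True,
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        return f"Error during generation: {e}"

For a 7B-parameter model, 8-bit weights need on the order of 7 GB instead of the ~28 GB that float32 required, which, together with the smaller max_length, is the memory saving this commit is after.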