sagar007 commited on
Commit
76cf633
1 Parent(s): d98db84

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -2
app.py CHANGED
@@ -129,8 +129,6 @@ def load_model(model_path):
129
  model.to(device)
130
  return model
131
 
132
- # Don't load the model here
133
- # model = load_model('gpt_model.pth')
134
  enc = tiktoken.get_encoding('gpt2')
135
 
136
  # Update the generate_text function
@@ -166,6 +164,14 @@ async def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
166
  if len(generated) == max_length:
167
  yield "... (output truncated due to length)"
168
 
 
 
 
 
 
 
 
 
169
 
170
  # # Your existing imports and model code here...
171
 
 
129
  model.to(device)
130
  return model
131
 
 
 
132
  enc = tiktoken.get_encoding('gpt2')
133
 
134
  # Update the generate_text function
 
164
  if len(generated) == max_length:
165
  yield "... (output truncated due to length)"
166
 
167
+ # Add the gradio_generate function
168
+ @spaces.GPU(duration=60)
169
+ async def gradio_generate(prompt, max_length, temperature, top_k):
170
+ output = ""
171
+ async for token in generate_text(prompt, max_length, temperature, top_k):
172
+ output += token
173
+ yield output
174
+
175
 
176
  # # Your existing imports and model code here...
177