McAwesomeville commited on
Commit
332bc83
1 Parent(s): 1b7f8f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -7
app.py CHANGED
@@ -1,9 +1,11 @@
1
  # app.py
2
  import subprocess
 
 
3
  subprocess.run(["pip", "install", "-r", "requirements.txt"])
4
- import streamlit as st
5
- from transformers import AutoModelForSeq2SeqLM
6
 
 
 
7
 
8
  def main():
9
  st.title("Hugging Face SQL Generator")
@@ -20,17 +22,17 @@ def main():
20
  st.code(sql_result, language="sql")
21
 
22
  def generate_sql(prompt):
23
- # Load the "NumbersStation/nsql-350M" model
24
- model_name = "NumbersStation/nsql-350M"
25
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
26
  tokenizer = AutoTokenizer.from_pretrained(model_name)
27
 
28
  # Tokenize and generate SQL
29
- inputs = tokenizer(prompt, return_tensors="pt")
30
- outputs = model(**inputs)
31
 
32
  # Decode the generated SQL
33
- sql_query = tokenizer.batch_decode(outputs["output_ids"], skip_special_tokens=True)[0]
34
 
35
  return sql_query
36
 
 
1
  # app.py
2
  import subprocess
3
+
4
+ # Install dependencies from requirements.txt
5
  subprocess.run(["pip", "install", "-r", "requirements.txt"])
 
 
6
 
7
+ import streamlit as st
8
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
9
 
10
  def main():
11
  st.title("Hugging Face SQL Generator")
 
22
  st.code(sql_result, language="sql")
23
 
24
  def generate_sql(prompt):
25
+ # Load the "facebook/bart-large-cnn" model
26
+ model_name = "facebook/bart-large-cnn"
27
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
28
  tokenizer = AutoTokenizer.from_pretrained(model_name)
29
 
30
  # Tokenize and generate SQL
31
+ inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
32
+ outputs = model.generate(**inputs)
33
 
34
  # Decode the generated SQL
35
+ sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
36
 
37
  return sql_query
38