McAwesomeville commited on
Commit
2c0a312
1 Parent(s): 8ed51d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -8
app.py CHANGED
@@ -5,13 +5,13 @@ import subprocess
5
  subprocess.run(["pip", "install", "-r", "requirements.txt"])
6
 
7
  import streamlit as st
8
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
9
 
10
  def main():
11
  st.title("Hugging Face SQL Generator")
12
 
13
  # Get user input
14
- prompt = st.text_area("Enter your prompt:")
15
 
16
  if st.button("Generate SQL"):
17
  # Call a function to generate SQL using the Hugging Face model
@@ -24,15 +24,13 @@ def main():
24
  def generate_sql(prompt):
25
  # Load the "NumbersStation/nsql-350M" model
26
  model_name = "NumbersStation/nsql-350M"
27
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
28
  tokenizer = AutoTokenizer.from_pretrained(model_name)
29
 
30
  # Tokenize and generate SQL
31
- inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
32
- outputs = model.generate(**inputs)
33
-
34
- # Decode the generated SQL
35
- sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
36
 
37
  return sql_query
38
 
 
5
  subprocess.run(["pip", "install", "-r", "requirements.txt"])
6
 
7
  import streamlit as st
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer
9
 
10
  def main():
11
  st.title("Hugging Face SQL Generator")
12
 
13
  # Get user input
14
+ prompt = st.text_area("Enter your SQL prompt:")
15
 
16
  if st.button("Generate SQL"):
17
  # Call a function to generate SQL using the Hugging Face model
 
24
  def generate_sql(prompt):
25
  # Load the "NumbersStation/nsql-350M" model
26
  model_name = "NumbersStation/nsql-350M"
27
+ model = AutoModelForCausalLM.from_pretrained(model_name)
28
  tokenizer = AutoTokenizer.from_pretrained(model_name)
29
 
30
  # Tokenize and generate SQL
31
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
32
+ generated_ids = model.generate(input_ids, max_length=500)
33
+ sql_query = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
 
 
34
 
35
  return sql_query
36