juierror commited on
Commit
4124a85
1 Parent(s): fface4a

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +33 -0
README.md ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: en
3
+ datasets:
4
+ - wikisql
5
+ widget:
6
+ - text: 'question: get people name with age equal 25 table: id, name, age'
7
+ ---
8
+
9
+ # How to use
10
+ ```python
11
+ from typing import List
12
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
13
+
14
+ tokenizer = AutoTokenizer.from_pretrained("juierror/text-to-sql-with-table-schema")
15
+ model = AutoModelForSeq2SeqLM.from_pretrained("juierror/text-to-sql-with-table-schema")
16
+
17
+ def prepare_input(question: str, table: List[str]):
18
+ table_prefix = "table:"
19
+ question_prefix = "question:"
20
+ join_table = ",".join(table)
21
+ inputs = f"{question_prefix} {question} {table_prefix} {join_table}"
22
+ input_ids = tokenizer(inputs, max_length=512, return_tensors="pt").input_ids
23
+ return input_ids
24
+
25
+ def inference(question: str, table: List[str]) -> str:
26
+ input_data = prepare_input(question=question, table=table)
27
+ input_data = input_data.to(model.device)
28
+ outputs = model.generate(inputs=input_data, num_beams=10, top_k=10, max_length=700)
29
+ result = tokenizer.decode(token_ids=outputs[0], skip_special_tokens=True)
30
+ return result
31
+
32
+ print(inference(question="get people name with age equal 25", table=["id", "name", "age"]))
33
+ ```