rachith commited on
Commit
828b751
1 Parent(s): fd8df62
Files changed (1) hide show
  1. app.py +16 -11
app.py CHANGED
@@ -1,19 +1,24 @@
1
  import gradio as gr
2
- from transformers import pipeline
3
 
4
- # def greet(name):
5
- # return "Hello " + name + "!!"
6
 
7
- # iface = gr.Interface(fn=greet, inputs="text", outputs="text")
8
- # iface.launch()
9
 
10
- model = pipeline("text-generation")
 
 
11
 
12
 
13
- def predict(prompt):
14
- completion = model(prompt)[0]["generated_text"]
15
- return completion
16
 
17
- predict("My favorite programming language is")
 
 
 
 
 
 
18
 
19
- gr.Interface(fn=predict, inputs="text", outputs="text").launch()
 
1
  import gradio as gr
2
+ from transformers import BartForSequenceClassification, BartTokenizer
3
 
 
 
4
 
5
+ # model = pipeline("text-generation")
 
6
 
7
+ # following https://joeddav.github.io/blog/2020/05/29/ZSL.html
8
+ tokenizer_bart = BartTokenizer.from_pretrained('facebook/bart-large-mnli')
9
+ model_bart_sq = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')
10
 
11
 
12
+ # def predict(prompt):
13
+ # completion = model(prompt)[0]["generated_text"]
14
+ # return completion
15
 
16
+ def zs(premise,hypothesis):
17
+ input_ids = tokenizer_bart.encode(premise, hypothesis, return_tensors='pt')
18
+ logits = model_bart_sq(input_ids)[0]
19
+ entail_contradiction_logits = logits[:,[0,2]]
20
+ probs = entail_contradiction_logits.softmax(dim=1)
21
+ true_prob = probs[:,1].item() * 100
22
+ return true_prob
23
 
24
+ gr.Interface(fn=zs, inputs="text", outputs="text").launch()