rachith commited on
Commit
b613751
1 Parent(s): 14b51c3

removed neutral

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -12,10 +12,11 @@ model_bart_sq = BartForSequenceClassification.from_pretrained('facebook/bart-lar
12
  def zs(premise,hypothesis):
13
  input_ids = tokenizer_bart.encode(premise, hypothesis, return_tensors='pt')
14
  logits = model_bart_sq(input_ids)[0]
15
- entail_contradiction_logits = logits[:,[0,1,2]]
 
16
  probs = entail_contradiction_logits.softmax(dim=1)
17
  contra_prob = round(probs[:,0].item() * 100,2)
18
- neut_prob = round(probs[:,1].item() * 100,2)
19
  entail_prob = round(probs[:,2].item() * 100,2)
20
  return contra_prob, neut_prob, entail_prob
21
 
 
12
  def zs(premise,hypothesis):
13
  input_ids = tokenizer_bart.encode(premise, hypothesis, return_tensors='pt')
14
  logits = model_bart_sq(input_ids)[0]
15
+ # entail_contradiction_logits = logits[:,[0,1,2]]
16
+ entail_contradiction_logits = logits[:,[0,2]]
17
  probs = entail_contradiction_logits.softmax(dim=1)
18
  contra_prob = round(probs[:,0].item() * 100,2)
19
+ # neut_prob = round(probs[:,1].item() * 100,2)
20
  entail_prob = round(probs[:,2].item() * 100,2)
21
  return contra_prob, neut_prob, entail_prob
22