p208p2002 committed on
Commit f8bb73e
1 Parent(s): d37721a

Update app.py

Files changed (1)
  1. app.py +59 -5
app.py CHANGED
@@ -1,25 +1,79 @@
 import gradio as gr
-from transformers import BertTokenizerFast, BertForSequenceClassification
+from transformers import BertTokenizerFast, BertForSequenceClassification,GPT2LMHeadModel,BartForConditionalGeneration
 import torch
+import math
+
+class CHSentenceSmoothScorer():
+    def __init__(self) -> None:
+        super().__init__()
+
+        self.tokenizer = BertTokenizerFast.from_pretrained(
+            "fnlp/bart-base-chinese")
+        self.model = BartForConditionalGeneration.from_pretrained(
+            "fnlp/bart-base-chinese")
+
+    def __call__(self, sentences):
+        input_ids = self.tokenizer.batch_encode_plus(
+            sentences, return_tensors='pt',
+            padding=True,
+            max_length=50,
+            truncation='longest_first'
+        )['input_ids']
+        logits = self.model(input_ids).logits
+
+        softmax = torch.softmax(logits, dim=-1)
+
+        out = []
+        for i, sentence in enumerate(sentences):
+            sent_token_ids = input_ids[i].tolist()
+            sent_token_ids = list(
+                filter(lambda x: x not in [self.tokenizer.pad_token_id], sent_token_ids))
+            ppl = 0.0
+            for j, token_id in enumerate(sent_token_ids):
+                ppl += math.log(softmax[i][j][token_id].item())
+            ppl = -1*(ppl/len(sent_token_ids))
+            prob_socre = math.exp(ppl*-1)
+            out.append(prob_socre)
+
+        return out
+
 
 model = BertForSequenceClassification.from_pretrained('./ch-sent-check-model')
 tokenizer = BertTokenizerFast.from_pretrained('./ch-sent-check-model')
+smooth_scorer = CHSentenceSmoothScorer()
 
 def judge(sentence):
     input_ids = tokenizer(sentence,return_tensors='pt')['input_ids']
     out = model(input_ids)
     logits = out.logits
-    pred = torch.argmax(logits,dim=-1).item()
+    prob = torch.softmax(logits,dim=-1)
+    pred = torch.argmax(prob,dim=-1).item()
     pred_text = 'Incorrect' if pred == 0 else 'Correct'
-    return pred_text
+
+    correct_prob = prob[0][1].item()
+    pred_text = pred_text + f", score: {round(correct_prob*100,2)}"
+    smooth_score = round(smooth_scorer([sentence])[0]*100,2)
+    return pred_text,smooth_score
 
 iface = gr.Interface(
     fn=judge,
     inputs=gr.Textbox(
         label="請輸入一段中文句子來檢測正確性",
         lines=1,
-        value="請注意用字的鄭確性",
     ),
-    outputs="text"
+    outputs=[
+        gr.Textbox(
+            label="正確性檢查",
+            lines=1
+        ),
+        gr.Textbox(
+            label="流暢性檢查",
+            lines=1
+        )
+    ],
+    examples = [
+        '請注意用字的鄭確性',
+        '請注意用字的正確性'
+    ]
 )
 iface.launch()
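For readers skimming the diff: the input label asks the user to enter a Chinese sentence to check its correctness, the two new output boxes are labelled 正確性檢查 ("correctness check") and 流暢性檢查 ("fluency check"), and the two examples are the same sentence with and without a typo (鄭確性 vs. 正確性). Below is a minimal single-sentence sketch of what the new CHSentenceSmoothScorer computes: the geometric mean of the per-token probabilities that fnlp/bart-base-chinese assigns to the sentence (equivalently, the reciprocal of a pseudo-perplexity). The checkpoint name and the example sentences come from the diff above; the helper name fluency_score is ours, not part of the commit, and exact numbers depend on the downloaded weights.

import math
import torch
from transformers import BertTokenizerFast, BartForConditionalGeneration

# Same checkpoint the commit loads inside CHSentenceSmoothScorer.
tokenizer = BertTokenizerFast.from_pretrained("fnlp/bart-base-chinese")
model = BartForConditionalGeneration.from_pretrained("fnlp/bart-base-chinese")

def fluency_score(sentence: str) -> float:
    # Hypothetical helper, not in the commit: a single-sentence version of
    # CHSentenceSmoothScorer.__call__ without padding/truncation handling.
    input_ids = tokenizer(sentence, return_tensors="pt")["input_ids"]
    # Probability distribution over the vocabulary at every position.
    probs = torch.softmax(model(input_ids).logits, dim=-1)
    token_ids = input_ids[0].tolist()
    # Sum the log-probabilities of the actual tokens; averaging and
    # exponentiating yields the geometric mean of the per-token probabilities,
    # i.e. the commit's prob_socre value that judge() multiplies by 100.
    log_prob = sum(math.log(probs[0][j][tid].item()) for j, tid in enumerate(token_ids))
    return math.exp(log_prob / len(token_ids))

# The commit's second example (the correctly spelled 正確性) should score
# higher than the first, which contains the typo 鄭確性.
print(fluency_score("請注意用字的鄭確性"))
print(fluency_score("請注意用字的正確性"))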