Larisa Kolesnichenko committed on
Commit
c0de02a
1 Parent(s): e566d5e

Add the wrapper and app files

Browse files
Files changed (2) hide show
  1. app.py +33 -0
  2. model_wrapper.py +107 -0
app.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import model_wrapper
3
+
4
+
5
+
6
# Instantiate the prediction model once at import time so every Gradio
# request reuses the already-loaded weights.
model = model_wrapper.PredictionModel()
7
+
8
def pretty_print_opinion(opinion_dict):
    """Format a single opinion dict as human-readable lines.

    Each key becomes one output line.  The 'Polarity' value is printed
    as-is; every other value is assumed to be a pair whose first element
    (value[0]) is a list of tokens, joined into a single quoted phrase.

    Returns the lines joined by newlines, with a trailing newline.
    """
    # NOTE(review): the original computed the longest key width for
    # column alignment and then immediately overwrote it with 0.  The
    # dead computation is removed (it also raised ValueError on an empty
    # dict); padding stays disabled so output is byte-identical.
    maxlen = 0
    res = []
    for key, value in opinion_dict.items():
        if key == 'Polarity':
            res.append(f'{(key + ":").ljust(maxlen)} {value}')
        else:
            res.append(f'{(key + ":").ljust(maxlen)} \'{" ".join(value[0])}\'')
    return '\n'.join(res) + '\n'
18
+
19
+
20
def predict(text):
    """Run the global model on one input string and render its opinions.

    Returns 'No opinions detected' when the prediction carries no
    opinions; otherwise the pretty-printed opinions joined by newlines.
    """
    prediction = model.predict([text])[0]
    opinions = prediction['opinions']
    if not opinions:
        return 'No opinions detected'
    return '\n'.join(pretty_print_opinion(opinion) for opinion in opinions)
30
+
31
+
32
# Plain text-in/text-out Gradio UI over predict(); launch() blocks and
# serves the app (intended to run at import time, e.g. on HF Spaces).
iface = gr.Interface(fn=predict, inputs="text", outputs="text")
iface.launch()
model_wrapper.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import tempfile
4
+ import sys
5
+ import datetime
6
+ import re
7
+ sys.path.append('mtool')
8
+
9
+ import torch
10
+
11
+ from model.model import Model
12
+ from data.dataset import Dataset
13
+ from config.params import Params
14
+ from utility.initialize import initialize
15
+ from data.batch import Batch
16
+ from mtool.main import main as mtool_main
17
+
18
+
19
+ from tqdm import tqdm
20
+
21
class PredictionModel:
    """Wraps the trained structured-sentiment model plus the mtool
    MRP-to-text conversion behind a simple ``predict(texts)`` API."""

    def __init__(self, checkpoint_path=os.path.join('models', 'checkpoint.bin'), default_mrp_path=os.path.join('models', 'default.mrp'), verbose=False):
        """Load the checkpoint, restore its hyper-parameters, and build the
        dataset/model pair ready for CPU or GPU inference.

        checkpoint_path: path to the serialized training checkpoint.
        default_mrp_path: placeholder .mrp file used for every dataset split
            (the Dataset constructor requires all three paths even though
            only inference happens here).
        verbose: when True, show a tqdm progress bar during prediction.
        """
        self.verbose = verbose
        # BUG FIX: the original ignored `checkpoint_path` and hardcoded
        # './models/checkpoint.bin'; honor the parameter (same default).
        self.checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        self.args = Params().load_state_dict(self.checkpoint['params'])
        self.args.log_wandb = False
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        self.args.training_data = default_mrp_path
        self.args.validation_data = default_mrp_path
        self.args.test_data = default_mrp_path
        self.args.only_train = False
        self.args.encoder = os.path.join('models', 'encoder')
        initialize(self.args, init_wandb=False)
        self.dataset = Dataset(self.args, verbose=False)
        self.model = Model(self.dataset, self.args).to(self.device)
        self.model.load_state_dict(self.checkpoint["model"])
        self.model.eval()

    def _mrp_to_text(self, mrp_list, graph_mode='labeled-edge'):
        """Convert an iterable of MRP graph dicts to text via mtool.

        mrp_list: iterable of JSON-serializable MRP graph dicts.
        graph_mode: 'labeled-edge' or 'node-centric'.

        Returns the JSON structure mtool writes for the 'norec' framework.
        Raises Exception on an unknown graph_mode.
        """
        # Validate before creating temp files so a bad mode leaks nothing.
        if graph_mode not in ('labeled-edge', 'node-centric'):
            raise Exception(f'Unknown graph mode: {graph_mode}')

        framework = 'norec'
        # delete=False: the files must survive the `with` so mtool can
        # read/write them by name; they are unlinked in `finally` below
        # (the original leaked both files when mtool raised).
        with tempfile.NamedTemporaryFile(delete=False, mode='w') as output_text_file:
            output_text_filename = output_text_file.name
        with tempfile.NamedTemporaryFile(delete=False, mode='w') as mrp_file:
            mrp_file.write('\n'.join(json.dumps(entry) for entry in mrp_list))
            mrp_filename = mrp_file.name

        argv = [
            '--strings',
            '--ids',
            '--read', 'mrp',
            '--write', framework,
            mrp_filename, output_text_filename,
        ]
        if graph_mode == 'node-centric':
            argv.insert(0, '--node_centric')

        try:
            mtool_main(argv)
            with open(output_text_filename) as f:
                texts = json.load(f)
        finally:
            os.unlink(output_text_filename)
            os.unlink(mrp_filename)
        return texts

    def clean_texts(self, texts):
        """Collapse runs of spaces in each text to a single space."""
        return [re.sub(r' +', ' ', t) for t in texts]

    def _predict_to_mrp(self, texts, graph_mode='labeled-edge'):
        """Run the model over `texts` and assemble per-sentence MRP dicts.

        Returns a dict mapping sentence index (as str) to an MRP graph
        dict seeded with input/id/time/framework/language and filled in
        with the model's per-sentence predictions.

        NOTE(review): `graph_mode` is accepted for signature parity with
        _mrp_to_text but is not used here.
        """
        texts = self.clean_texts(texts)
        framework, language = self.args.framework, self.args.language
        data = self.dataset.load_sentences(texts, self.args)
        date_str = datetime.datetime.now().date().isoformat()
        # One MRP skeleton per input sentence, keyed by its index.
        res_sentences = {
            str(i): {
                'input': sentence,
                'id': str(i),
                'time': date_str,
                'framework': framework,
                'language': language,
                'nodes': [],
                'edges': [],
                'tops': [],
            }
            for i, sentence in enumerate(texts)
        }
        iterator = tqdm(data) if self.verbose else data
        for batch in iterator:
            with torch.no_grad():
                predictions = self.model(Batch.to(batch, self.device), inference=True)
            # Merge each prediction's fields into its sentence's skeleton,
            # matched by the 'id' the model emits.
            for prediction in predictions:
                for field, value in prediction.items():
                    res_sentences[prediction['id']][field] = value
        return res_sentences

    def predict(self, text_list, graph_mode='labeled-edge', language='no'):
        """Predict opinions for a list of raw texts.

        text_list: list of input strings.
        graph_mode: graph representation, 'labeled-edge' or 'node-centric'.
        language: accepted for API compatibility but unused — the language
            actually comes from the checkpoint's stored args.

        Returns mtool's text-level predictions for the inputs.
        """
        mrp_predictions = self._predict_to_mrp(text_list, graph_mode)
        return self._mrp_to_text(mrp_predictions.values(), graph_mode)