mamiksik commited on
Commit
519c766
1 Parent(s): fe137a1

Add gradio interface

Browse files
Files changed (2) hide show
  1. app.py +72 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ import gradio as gr
4
+ from transformers import RobertaTokenizer, pipeline, AutoModelForMaskedLM
5
+
6
+ tokenizer = RobertaTokenizer.from_pretrained("mamiksik/CommitPredictor")
7
+ model = AutoModelForMaskedLM.from_pretrained("mamiksik/CommitPredictor")
8
+ pipe = pipeline("fill-mask", model=model, tokenizer=tokenizer)
9
+
10
+
11
+ def parse_files(patch) -> str:
12
+ accumulator = []
13
+ lines = patch.splitlines()
14
+
15
+ filename_before = None
16
+ for line in lines:
17
+ if line.startswith("index") or line.startswith("diff"):
18
+ continue
19
+ if line.startswith("---"):
20
+ filename_before = line.split(" ", 1)[1][1:]
21
+ continue
22
+
23
+ if line.startswith("+++"):
24
+ filename_after = line.split(" ", 1)[1][1:]
25
+
26
+ if filename_before == filename_after:
27
+ accumulator.append(f"<ide><path>{filename_before}")
28
+ else:
29
+ accumulator.append(f"<add><path>{filename_after}")
30
+ accumulator.append(f"<del><path>{filename_before}")
31
+ continue
32
+
33
+ line = re.sub("@@[^@@]*@@", "", line)
34
+ if len(line) == 0:
35
+ continue
36
+
37
+ if line[0] == "+":
38
+ line = line.replace("+", "<add>", 1)
39
+ elif line[0] == "-":
40
+ line = line.replace("-", "<del>", 1)
41
+ else:
42
+ line = f"<ide>{line}"
43
+
44
+ accumulator.append(line)
45
+
46
+ return '\n'.join(accumulator)
47
+
48
+
49
+ def predict(patch, commit_message):
50
+ input_text = parse_files(patch) + "\n<msg> " + commit_message
51
+ token_count = tokenizer(input_text, return_tensors="pt").input_ids.shape[1]
52
+ result = pipe.predict(input_text)
53
+
54
+ return token_count, input_text, {pred['token_str']: round(pred['score'], 3) for pred in result}
55
+
56
+
57
+ iface = gr.Interface(fn=predict, inputs=[
58
+ gr.Textbox(label="Patch (as generated by git diff)"),
59
+ gr.Textbox(label="Commit message (with one <mask> token)"),
60
+ ], outputs=[
61
+ gr.Textbox(label="Token count"),
62
+ gr.Textbox(label="Parsed patch"),
63
+ gr.Label(label="Predictions")
64
+ ], examples=[["""
65
+ def main():
66
+ - name = "John"
67
+ print("Hello World")
68
+ """, "Remove <mask> variable"]
69
+ ])
70
+
71
+ if __name__ == "__main__":
72
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio~=3.16.2
2
+ transformers~=4.25.1
3
+ torch~=1.13.1