kirankunapuli committed on
Commit
e0f9553
1 Parent(s): 7cb8138

Update app.py with Gemma Hinglish Inference

Browse files
Files changed (1) hide show
  1. app.py +56 -1
app.py CHANGED
@@ -1,3 +1,58 @@
1
  import gradio as gr
2
 
3
- gr.load("models/kirankunapuli/Gemma-2B-Hinglish-LORA-v1.0").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
 
3
+ # gr.load("models/kirankunapuli/Gemma-2B-Hinglish-LORA-v1.0").launch()
4
+
5
+ import re
6
+ import torch
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM
8
+
9
# Use the first CUDA device when one is available; otherwise run on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# The tokenizer and the model weights are loaded from the same repository id;
# the model is moved to the selected device right after loading.
tokenizer = AutoTokenizer.from_pretrained("kirankunapuli/Gemma-2B-Hinglish-LORA-v1.0")
model = AutoModelForCausalLM.from_pretrained(
    "kirankunapuli/Gemma-2B-Hinglish-LORA-v1.0"
).to(device)
14
+
15
# Alpaca-style prompt template. It has three positional "{}" slots that are
# filled, in order, with the instruction, the input, and the response.
alpaca_prompt = (
    "Below is an instruction that describes a task, paired with an input "
    "that provides further context. Write a response that appropriately "
    "completes the request.\n"
    "\n"
    "### Instruction:\n"
    "{}\n"
    "\n"
    "### Input:\n"
    "{}\n"
    "\n"
    "### Response:\n"
    "{}"
)
25
+
26
+
27
def get_response(input_text: str) -> str:
    """Generate the model's reply to ``input_text``.

    The text is placed in the input slot of ``alpaca_prompt`` under a fixed
    instruction, the prompt is tokenized and moved to ``device``, and the
    model generates at most 256 new tokens.

    Args:
        input_text: The user's message, inserted verbatim into the prompt.

    Returns:
        The generated reply with surrounding whitespace removed, or
        ``"Response not found"`` when the model produced no text.
    """
    prompt = alpaca_prompt.format(
        "Please answer the following sentence as requested",  # instruction
        input_text,  # input
        "",  # output - leave this blank for generation!
    )
    inputs = tokenizer([prompt], return_tensors="pt").to(device)

    outputs = model.generate(**inputs, max_new_tokens=256)

    # Decode only what was generated after the prompt instead of searching
    # the full decoded text for "### Response:\n ... <eos>". That search
    # dropped valid answers whenever generation stopped at the token limit
    # without emitting <eos>, and it could anchor on a "### Response:" marker
    # typed by the user inside input_text.
    # NOTE(review): relies on generate() returning prompt + continuation,
    # as it does for decoder-only causal LMs - confirm for this checkpoint.
    prompt_length = inputs["input_ids"].shape[1]
    response = tokenizer.decode(
        outputs[0][prompt_length:],
        skip_special_tokens=True,  # drops <eos> and other special tokens
    ).strip()

    return response or "Response not found"
49
+
50
+
51
# Text-in / text-out Gradio UI wrapped around get_response; launching it
# starts serving the app.
interface = gr.Interface(
    get_response,
    "text",
    "text",
    title="Gemma Hinglish Model Inference",
)
interface.launch()