TokenTrails / handler.py
Kowsher's picture
Update handler.py
aad3543 verified
raw
history blame contribute delete
No virus
1.81 kB
from typing import Any, Dict, List

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.models.auto import modeling_auto
class EndpointHandler:
    """Callable request handler that wraps a causal language model.

    The constructor loads a tokenizer and a model; calling the instance
    tokenizes the request text, rewrites ``token_type_ids`` for the first
    sequence based on marker tokens found in ``input_ids``, runs
    ``generate`` and returns the decoded text.
    """

    # Token ids that drive the ``token_type_ids`` rewrite.
    # NOTE(review): presumably these are the tokenizer's ids for the role /
    # speaker markers of the prompt template -- confirm against the vocabulary.
    _MARKER_TOKEN_ID = 39
    _SEGMENT_ZERO_NEXT_ID = 5584
    _SEGMENT_ONE_NEXT_ID = 13359

    def __init__(self, path=""):
        """Load the tokenizer and the model.

        Args:
            path: Location passed to ``from_pretrained`` for both the
                tokenizer and the model weights.
        """
        print('starting machine')
        # NOTE(review): the config always comes from this fixed repository,
        # not from ``path`` -- confirm that this is intentional.
        config = AutoConfig.from_pretrained(
            "Kowsher/Egol_model", trust_remote_code=True
        )
        # Load model and tokenizer from path.
        self.tokenizer = AutoTokenizer.from_pretrained(
            path, trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            device_map="auto",
            torch_dtype=torch.float16,
            config=config,
            trust_remote_code=True,
        )
        # Device the tokenized request is moved to before generation.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    @classmethod
    def _segment_ids(cls, token_ids: List[int]) -> List[int]:
        """Return one segment id (0 or 1) per entry of ``token_ids``.

        The segment starts at 0. Whenever the marker token is immediately
        followed by one of the two selector tokens, the segment switches to
        the matching value, starting at the marker's own position. A marker
        in the last position has no follower and leaves the segment
        unchanged (the look-ahead is bounds-checked).
        """
        segment = 0
        last_index = len(token_ids) - 1
        segments = []
        for position, token_id in enumerate(token_ids):
            if token_id == cls._MARKER_TOKEN_ID and position < last_index:
                follower = token_ids[position + 1]
                if follower == cls._SEGMENT_ZERO_NEXT_ID:
                    segment = 0
                elif follower == cls._SEGMENT_ONE_NEXT_ID:
                    segment = 1
            segments.append(segment)
        return segments

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Generate text for one request.

        Args:
            data: Request payload. ``"inputs"`` holds the text to tokenize
                (the payload itself is used when the key is absent) and the
                optional ``"parameters"`` mapping is forwarded as keyword
                arguments to ``model.generate``. Both keys are popped from
                ``data``.

        Returns:
            A single-element list ``[{"generated_text": <decoded text>}]``
            built from the first generated sequence.

        Raises:
            KeyError: If the tokenizer output has no ``token_type_ids``.
        """
        # Process input.
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)

        # Preprocess.
        print("inputs......", inputs)
        inputs = self.tokenizer(inputs, return_tensors="pt").to(self.device)

        # Only the first sequence is rewritten, matching the single sequence
        # that is decoded below. The scan runs on a plain Python list so that
        # it does not index the (possibly CUDA) tensor element by element.
        type_ids = inputs["token_type_ids"]
        segments = self._segment_ids(inputs["input_ids"][0].tolist())
        type_ids[0] = torch.tensor(
            segments, dtype=type_ids.dtype, device=type_ids.device
        )

        # Pass inputs together with any caller-supplied generation kwargs.
        print("inputs......", inputs)
        outputs = self.model.generate(**inputs, **(parameters or {}))

        # Postprocess the prediction.
        prediction = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return [{"generated_text": prediction}]