llm-tolkien / handler.py

import logging
from typing import Any, Dict, List

import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

LOGGER = logging.getLogger(__name__)


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Resolve the base model from the adapter config and load it in 8-bit
        config = PeftConfig.from_pretrained(path)
        model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name_or_path,
            load_in_8bit=True,
            device_map="auto",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
        # Attach the LoRA adapter weights on top of the base model
        self.model = PeftModel.from_pretrained(model, path)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Args:
            data (Dict): The payload with the text prompt and generation parameters,
                e.g. {"prompt": "...", "parameters": {"max_new_tokens": 50}}.

        Returns:
            A list containing a single dict with the generated text.
        """
        LOGGER.info(f"Received data: {data}")
        # Get inputs
        prompt = data.pop("prompt", data)
        parameters = data.pop("parameters", None)
        # Preprocess: tokenize and move tensors to the model's device
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        # Forward: generate without tracking gradients
        with torch.no_grad():
            if parameters is not None:
                outputs = self.model.generate(**inputs, **parameters)
            else:
                outputs = self.model.generate(**inputs)
        # Postprocess: decode token ids back into text
        prediction = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        LOGGER.info(f"Generated text: {prediction}")
        return [{"generated_text": prediction}]
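

# --- Local smoke test (not part of the Inference Endpoints contract) ---
# A minimal sketch of how the handler can be exercised before deployment.
# The adapter repo id and the generation parameters below are assumptions
# for illustration; point `path` at the actual Hub repo (or local directory)
# containing the PEFT adapter.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    handler = EndpointHandler(path="JeremyArancio/llm-tolkien")  # hypothetical repo id
    payload = {
        "prompt": "The hobbit looked out over the Shire",
        "parameters": {"max_new_tokens": 50, "temperature": 0.8, "do_sample": True},
    }
    print(handler(payload))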