# handler.py
import torch
from transformers import pipeline

# select the GPU if one is available, otherwise fall back to CPU
device = 0 if torch.cuda.is_available() else -1

# models to load: an embedding model and a reranker
multi_model_list = [
    {"model_id": "BAAI/bge-base-en-v1.5", "task": "feature-extraction"},
    {"model_id": "BAAI/bge-reranker-base", "task": "text-classification"},
]


class EndpointHandler:
    def __init__(self, path=""):
        self.multi_model = {}
        # load every model in the list onto the device
        for model in multi_model_list:
            self.multi_model[model["model_id"]] = pipeline(
                model["task"], model=model["model_id"], device=device
            )

    def __call__(self, data):
        # deserialize the incoming request
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)
        model_id = data.pop("model_id", None)

        # validate that the requested model was loaded
        if model_id is None or model_id not in self.multi_model:
            raise ValueError(
                f"model_id: {model_id} is not valid. "
                f"Available models are: {list(self.multi_model.keys())}"
            )

        # run inference, forwarding any extra parameters to the pipeline
        if parameters is not None:
            prediction = self.multi_model[model_id](inputs, **parameters)
        else:
            prediction = self.multi_model[model_id](inputs)

        # return the raw pipeline output
        return prediction
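
# --- Usage sketch (not part of the deployed handler) ---
# A minimal local smoke test, assuming this file is saved as handler.py and
# run directly. The payload shape ("inputs", "model_id", optional
# "parameters") mirrors the __call__ signature above; the example texts are
# illustrative. The reranker request uses the text-classification pipeline's
# {"text", "text_pair"} input format to score a query/passage pair.
if __name__ == "__main__":
    handler = EndpointHandler()

    # embedding request against the feature-extraction model
    embedding = handler({
        "inputs": "What is the capital of France?",
        "model_id": "BAAI/bge-base-en-v1.5",
    })

    # reranking request against the text-classification model
    score = handler({
        "inputs": {
            "text": "What is the capital of France?",
            "text_pair": "Paris is the capital of France.",
        },
        "model_id": "BAAI/bge-reranker-base",
    })
    print(score)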