import string

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Text2TextGenerationPipeline,
)


class KeyphraseGenerationPipeline(Text2TextGenerationPipeline):
    """Text2Text pipeline that turns generated text into a list of keyphrases.

    The generated text is split on ``keyphrase_sep_token``; every fragment has
    the characters in ``string.punctuation`` removed and surrounding
    whitespace stripped, and fragments that end up empty are dropped.
    """

    # Translation table that deletes every character in string.punctuation.
    # Built once for the class rather than twice per keyphrase.
    _PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)

    def __init__(self, model, keyphrase_sep_token=";", *args, **kwargs):
        """Load model and tokenizer from ``model`` and set the separator.

        Args:
            model: Identifier or path passed to both
                ``AutoModelForSeq2SeqLM.from_pretrained`` and
                ``AutoTokenizer.from_pretrained``.
            keyphrase_sep_token: String that separates keyphrases in the
                generated text. Defaults to ``";"``.
            *args, **kwargs: Forwarded to ``Text2TextGenerationPipeline``.
        """
        # NOTE(review): positional *args are unpacked after the model=/tokenizer=
        # keywords, so a non-empty *args would bind to the base class's first
        # positional parameter and likely collide with model= — confirm intent.
        super().__init__(
            model=AutoModelForSeq2SeqLM.from_pretrained(model),
            tokenizer=AutoTokenizer.from_pretrained(model, truncation=True),
            *args,
            **kwargs
        )
        self.keyphrase_sep_token = keyphrase_sep_token

    def postprocess(self, model_outputs, **postprocess_params):
        """Convert the model output into a cleaned list of keyphrases.

        Args:
            model_outputs: Output of the pipeline's forward step, passed
                unchanged to the base class ``postprocess``.
            **postprocess_params: Extra options forwarded unchanged to
                ``Text2TextGenerationPipeline.postprocess`` so that call-time
                pipeline arguments routed here do not raise ``TypeError``.

        Returns:
            list[str]: Keyphrases taken from the first result's
            ``"generated_text"``, with punctuation removed, whitespace
            stripped, and empty entries omitted.
        """
        results = super().postprocess(
            model_outputs=model_outputs, **postprocess_params
        )
        # Only the first result is used, as before; any further results
        # (presumably additional generated sequences) are ignored.
        generated_text = results[0].get("generated_text")

        keyphrases = []
        for fragment in generated_text.split(self.keyphrase_sep_token):
            # Remove punctuation first, then strip, then test for emptiness so
            # whitespace-only fragments (e.g. from "a; ; b") are dropped and
            # no trailing space survives (e.g. "foo ." -> "foo").
            cleaned = fragment.translate(self._PUNCTUATION_TABLE).strip()
            if cleaned:
                keyphrases.append(cleaned)
        return keyphrases