"""Chat-generation preference annotation app (dataset-viber "battle" setup).

Each example is built from a random row of the
``argilla/distilabel-capybara-dpo-7k-binarized`` train split. The last message
of the row's "chosen" conversation is paired with a fresh completion from a
Hugging Face Inference API model. The two are returned in random order.

Side effects at import time: the dataset is loaded and one ``InferenceClient``
is created per entry in ``MODEL_IDS``; ``HF_TOKEN`` must be set in the
environment (a missing variable raises ``KeyError``).
"""

import os
import random
import time

from dataset_viber import AnnotatorInterFace
from datasets import load_dataset
from huggingface_hub import InferenceClient

# https://huggingface.co/models?inference=warm&pipeline_tag=text-generation&sort=trending
MODEL_IDS = [
    "microsoft/Phi-3-mini-4k-instruct",
]
CLIENTS = [InferenceClient(model_id, token=os.environ["HF_TOKEN"]) for model_id in MODEL_IDS]

dataset = load_dataset("argilla/distilabel-capybara-dpo-7k-binarized", split="train")


def get_response(messages, *, max_retries=3, retry_delay=3):
    """Return a chat completion for ``messages`` from a randomly chosen client.

    Each attempt picks a client from ``CLIENTS`` at random and requests a
    non-streamed completion capped at 2000 tokens.

    Args:
        messages: Conversation passed straight to ``chat_completion``
            (presumably a list of ``{"role", "content"}`` dicts — confirm).
        max_retries: Total number of attempts (keyword-only, default 3).
        retry_delay: Seconds to sleep between failed attempts (keyword-only,
            default 3).

    Returns:
        The ``content`` of the first choice's message.

    Raises:
        ValueError: If ``max_retries`` is less than 1.
        Exception: Whatever the final failed attempt raised, re-raised as-is.
    """
    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")
    for attempt in range(max_retries):
        client = random.choice(CLIENTS)
        try:
            message = client.chat_completion(
                messages=messages, stream=False, max_tokens=2000
            )
            return message.choices[0].message.content
        except Exception as e:  # retry boundary around a remote call
            if attempt < max_retries - 1:
                print(f"An error occurred: {e}. Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
            else:
                print(f"Max retries reached. Last error: {e}")
                raise
    # Not reached: every iteration either returns or, on the last attempt, re-raises.


def next_input(_prompt, _completion_a, _completion_b):
    """Build the next ``(messages, completion_a, completion_b)`` triple.

    The three arguments are accepted to match the callback signature passed
    as ``fn_next_input``; they are ignored.

    Returns:
        The conversation without its final message, then the row's original
        final "chosen" message content and a freshly generated completion,
        in random order.
    """
    # Pick one row directly instead of re-shuffling the whole dataset
    # (dataset.shuffle() builds a new indices mapping on every call).
    row = dataset[random.randrange(len(dataset))]
    # Assumes "chosen" is a list of message dicts whose last entry is the
    # preferred response — TODO confirm against the dataset card.
    messages = row["chosen"][:-1]
    completions = [row["chosen"][-1]["content"], get_response(messages)]
    # Random order so the position does not reveal which completion is which.
    random.shuffle(completions)
    return messages, completions.pop(), completions.pop()


if __name__ == "__main__":
    interface = AnnotatorInterFace.for_chat_generation_preference(
        fn_next_input=next_input,
        # One flag per displayed field; presumably prompt read-only and both
        # completions editable — confirm against dataset_viber docs.
        interactive=[False, True, True],
        dataset_name="dataset-viber-chat-generation-preference-inference-endpoints-battle",
    )
    interface.launch()