ChessPT / app.py
philipp-zettl's picture
Update app.py
712ad44 verified
raw
history blame
628 Bytes
import spaces
import gradio as gr
from model import DecoderTransformer
from huggingface_hub import hf_hub_download
import torch
# Model hyperparameters. The size-related values must match the checkpoint
# loaded below: load_state_dict is strict by default and raises on key/shape
# mismatches. `dropout` presumably only matters in training mode — confirm
# against DecoderTransformer in model.py.
vocab_size = 33
n_embed = 384
context_size = 256
n_layer = 6
n_head = 6
dropout = 0.2

# Fetch the trained weights from the Hugging Face Hub (hf_hub_download caches
# them locally) and rebuild the model with its parameters on the CPU.
model_id = "philipp-zettl/chessPT"
model_path = hf_hub_download(repo_id=model_id, filename="chessPT.pkl")
model = DecoderTransformer(vocab_size, n_embed, context_size, n_layer, n_head, dropout)
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from a trusted repo (consider weights_only=True if the
# installed torch version supports it).
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
# Switch to inference mode. Assuming DecoderTransformer is a torch.nn.Module
# (it exposes load_state_dict), it starts in training mode, which would keep
# dropout active while generating.
model.eval()
# NOTE(review): the model stays on the CPU even though greet is decorated with
# @spaces.GPU; the ZeroGPU pattern is to call .to('cuda') at module level.
# Confirm that model.generate moves its inputs to the model's device before
# changing this.
@spaces.GPU
def greet(prompt):
    """Run the ChessPT model on the text entered in the Gradio UI.

    Args:
        prompt: Raw text from the Gradio textbox, passed unchanged to
            ``model.generate``. Presumably a move sequence in the notation the
            model was trained on — confirm against model.py.

    Returns:
        Whatever ``model.generate`` returns; Gradio renders it as text.
    """
    # Pure inference: disable autograd so no computation graph is recorded
    # (saves memory and time; harmless if generate already does this).
    with torch.no_grad():
        return model.generate(prompt)
# Wrap greet in a minimal text-in / text-out web UI and start serving it.
demo = gr.Interface(
    fn=greet,
    inputs="text",
    outputs="text",
)
demo.launch()