File size: 2,426 Bytes
9dcb348
ee2b517
3479f48
3a76146
f8bdf54
4249eba
 
9ede36f
d41eec1
 
 
49b8bf0
3a76146
 
 
 
 
 
 
 
 
c51abf0
3479f48
3a76146
 
 
3479f48
3a76146
 
aa43f32
3479f48
 
3a76146
49b8bf0
 
796a2f3
3479f48
4249eba
22bfac3
49b8bf0
 
 
 
 
 
 
 
 
d41eec1
 
22bfac3
d41eec1
 
 
49b8bf0
 
 
153ec16
3479f48
 
 
 
c51abf0
3479f48
 
c51abf0
3479f48
 
 
 
 
d41eec1
3479f48
68a79d8
 
796a2f3
 
 
 
 
 
4249eba
796a2f3
 
9dcb348
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import tempfile
from io import StringIO

import chess
import chess.pgn
import chess.svg
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from reportlab.graphics import renderPM
from svglib.svglib import svg2rlg

from model import DecoderTransformer, Tokenizer


# Hyperparameters passed to DecoderTransformer below.
# NOTE(review): presumably these must match the values the published
# checkpoint was trained with -- confirm against the training script.
vocab_size = 33
n_embed = 384
context_size = 256
n_layer = 6
n_head = 6
dropout = 0.2

# Inference runs on CPU; every tensor and the model itself are placed here.
device = 'cpu'

# Hugging Face Hub repository holding the weights and the tokenizer file.
model_id = "philipp-zettl/chessPT"

# hf_hub_download returns the local path of the downloaded (cached) file.
model_path = hf_hub_download(repo_id=model_id, filename="chessPT.pkl")
tokenizer_path = hf_hub_download(repo_id=model_id, filename="tokenizer.json")

model = DecoderTransformer(vocab_size, n_embed, context_size, n_layer, n_head, dropout)
# NOTE(review): torch.load unpickles the downloaded file; only point model_id
# at a repository you trust.
# map_location follows `device` so the two settings cannot drift apart.
model.load_state_dict(torch.load(model_path, map_location=torch.device(device)))
model.to(device)
# Switch to inference mode: without this, dropout (0.2 above) stays active and
# randomly perturbs every prediction.
# NOTE(review): assumes DecoderTransformer is a torch.nn.Module, as its
# load_state_dict()/to() usage suggests.
model.eval()
tokenizer = Tokenizer.from_pretrained(tokenizer_path)

# Placeholder picture returned by generate() when the predicted move is illegal.
invalid_move_plot = Image.open('./invalid_move.png')

def generate(prompt):
    """Extend a PGN prompt with the model and render the resulting position.

    Args:
        prompt: PGN move text of the game so far (e.g. ``"1. e4 g6 2."``).

    Returns:
        A ``(pgn, image)`` tuple. ``pgn`` is the decoded model output (as
        returned by ``tokenizer.decode`` on the generated sequence) and
        ``image`` is a PIL image of the board after replaying it, or
        ``invalid_move_plot`` when the output contains an illegal move or
        cannot be parsed into a game at all.

    Raises:
        Exception: any failure while parsing or replaying the game that is
            not an "illegal san" error is re-raised unchanged.
    """
    token_ids = tokenizer.encode(prompt)
    # Size the batch from the token count rather than the character count of
    # the prompt, so the reshape stays valid even if the two ever differ.
    model_input = torch.tensor(token_ids, dtype=torch.long, device=device).view(1, -1)
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        generated = model.generate(model_input, max_new_tokens=4, context_size=context_size)
    pgn = tokenizer.decode(generated[0].tolist())

    try:
        game = chess.pgn.read_game(StringIO(pgn))
        if game is None:
            # Nothing parseable was produced, so there is no board to draw.
            return pgn, invalid_move_plot
        board = game.board()
        for move in game.mainline_moves():
            board.push(move)
    except Exception as e:
        if 'illegal san' in str(e):
            return pgn, invalid_move_plot
        # Previously this fell through and failed later with an unrelated
        # UnboundLocalError; surface the real error instead.
        raise

    # Current python-chess releases collect parse errors on the game object
    # instead of raising them, so illegal moves must be looked up there too.
    if any('illegal san' in str(err) for err in getattr(game, 'errors', ())):
        return pgn, invalid_move_plot

    svg_markup = chess.svg.board(board)

    # Render inside a private scratch directory: the generated PGN text is not
    # a safe file name (spaces, slashes, newlines) and concurrent requests
    # must not overwrite each other's files. The directory is removed on exit.
    with tempfile.TemporaryDirectory() as tmp_dir:
        svg_path = os.path.join(tmp_dir, 'board.svg')
        png_path = os.path.join(tmp_dir, 'board.png')
        with open(svg_path, 'w', encoding='utf-8') as f:
            f.write(svg_markup)
        drawing = svg2rlg(svg_path)
        renderPM.drawToFile(drawing, png_path, fmt="PNG")
        # Image.open is lazy; copy the pixels into memory before the backing
        # file disappears together with the scratch directory.
        with Image.open(png_path) as rendered:
            plot = rendered.copy()

    return pgn, plot


# Gradio front end: one PGN text box in, the model's continuation and a board
# picture out.
with gr.Blocks() as demo:
    # Introductory text shown at the top of the page.
    gr.Markdown("""
    # ChessPT
    Welcome to ChessPT.

    The **C**hess-**P**re-trained-**T**ransformer.

    The rules are simple: provide a PGN string of your current game, the engine will predict the next token!
    """)

    # Widgets, in the order they appear on the page.
    prompt = gr.Text(label="PGN")
    output = gr.Text(label="Next turn", interactive=False)
    img = gr.Image()
    submit = gr.Button("Submit")

    # Pressing the button feeds the prompt to generate() and routes its
    # (text, image) result to the two output widgets.
    submit.click(
        fn=generate,
        inputs=[prompt],
        outputs=[output, img],
    )

    # Sample prompts offered to the user; each inner list is one example row.
    gr.Examples(
        examples=[
            ["1. e4"],
            ["1. e4 g6 2."],
        ],
        inputs=[prompt],
        outputs=[output, img],
        fn=generate,
    )

# Start the web server for the interface defined above.
demo.launch()