philipp-zettl committed on
Commit
3479f48
1 Parent(s): 712ad44
Files changed (2) hide show
  1. app.py +23 -4
  2. model.py +28 -1
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import spaces
2
  import gradio as gr
3
- from model import DecoderTransformer
4
  from huggingface_hub import hf_hub_download
5
  import torch
6
 
@@ -12,17 +12,36 @@ n_layer=6
12
  n_head=6
13
  dropout=0.2
14
 
 
 
15
  model_id = "philipp-zettl/chessPT"
16
 
17
  model_path = hf_hub_download(repo_id=model_id, filename="chessPT.pkl")
 
18
 
19
  model = DecoderTransformer(vocab_size, n_embed, context_size, n_layer, n_head, dropout)
20
  model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
21
-
 
22
 
23
  @spaces.GPU
24
  def greet(prompt):
25
- return model.generate(prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
28
  demo.launch()
 
1
  import spaces
2
  import gradio as gr
3
+ from model import DecoderTransformer, Tokenizer
4
  from huggingface_hub import hf_hub_download
5
  import torch
6
 
 
12
  n_head=6
13
  dropout=0.2
14
 
15
+ device = 'cuda'
16
+
17
  model_id = "philipp-zettl/chessPT"
18
 
19
  model_path = hf_hub_download(repo_id=model_id, filename="chessPT.pkl")
20
+ tokenizer_path = hf_hub_download(repo_id=model_id, filename="tokenizer.json")
21
 
22
# Build the network with the hyperparameters defined above and restore the
# trained weights from the checkpoint downloaded from the Hub.
model = DecoderTransformer(vocab_size, n_embed, context_size, n_layer, n_head, dropout)
# NOTE(review): torch.load unpickles the downloaded file, which can execute
# arbitrary code. Only load checkpoints from a trusted repo, or consider
# passing weights_only=True if the installed torch version supports it.
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.to(device)
# This app only serves predictions: switch to eval mode so that any dropout
# layers built from the `dropout` argument are disabled during generation.
model.eval()
tokenizer = Tokenizer.from_pretrained(tokenizer_path)
26
 
27
@spaces.GPU
def greet(prompt):
    """Continue a PGN prompt with the ChessPT model.

    Args:
        prompt: PGN text of the game so far. Every character must be part of
            the tokenizer vocabulary.

    Returns:
        The decoded token sequence returned by ``model.generate`` (called with
        ``max_new_tokens=4``), or an empty string when ``prompt`` is empty.

    Raises:
        gr.Error: If ``prompt`` contains a character the tokenizer does not
            know, so the UI shows a readable message instead of a bare
            ``KeyError``.
    """
    # An empty sequence gives the model nothing to condition on.
    if not prompt:
        return ""

    try:
        token_ids = tokenizer.encode(prompt)
    except KeyError as exc:
        raise gr.Error(f"Unsupported character in PGN input: {exc}") from exc

    # Wrap the ids in a list to get a (1, n_tokens) batch. The shape must
    # follow the number of token ids, not the number of prompt characters.
    model_input = torch.tensor([token_ids], dtype=torch.long, device=device)

    # Inference only: skip autograd bookkeeping while sampling.
    with torch.no_grad():
        generated = model.generate(
            model_input, max_new_tokens=4, context_size=context_size
        )
    return tokenizer.decode(generated[0].tolist())
31
+
32
+
33
+ with gr.Blocks() as demo:
34
+ gr.Markdown("""
35
+ Welcome to ChessPT.
36
+
37
+ The Chess-Pre-trained-Transformer.
38
+
39
+ The rules are simple: provide a PGN string of your current game, the engine will predict the next token!
40
+ """)
41
+ prompt = gr.Text(label="PGN")
42
+ output = gr.Text(label="Next turn", interactive=False)
43
+
44
+ submit = gr.Button("Submit")
45
+ submit.click(greet, [prompt], [output])
46
 
 
47
  demo.launch()
model.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import torch
2
  import torch.nn as nn
3
  from torch.nn import functional as F
@@ -119,7 +120,11 @@ class DecoderTransformer(nn.Module):
119
  loss = F.cross_entropy(logits, targets)
120
  return logits, loss
121
 
122
- def generate(self, idx, max_new_tokens, context_size):
 
 
 
 
123
  for _ in range(max_new_tokens):
124
  idx_cond = idx[:, -context_size:]
125
  logits, loss = self(idx_cond)
@@ -129,3 +134,25 @@ class DecoderTransformer(nn.Module):
129
  idx = torch.cat([idx, idx_next], dim=1)
130
  return idx
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
  import torch
3
  import torch.nn as nn
4
  from torch.nn import functional as F
 
120
  loss = F.cross_entropy(logits, targets)
121
  return logits, loss
122
 
123
+ def generate(self, idx, max_new_tokens=50, context_size=None):
124
+ if context_size is None:
125
+ context_size = int(self.position_embedding_table.weight.shape[0])
126
+ print(context_size)
127
+
128
  for _ in range(max_new_tokens):
129
  idx_cond = idx[:, -context_size:]
130
  logits, loss = self(idx_cond)
 
134
  idx = torch.cat([idx, idx_next], dim=1)
135
  return idx
136
 
137
+
138
class Tokenizer:
    """Map vocabulary entries to integer ids and back.

    ``encode`` looks up each element of the input individually, so with a
    string input the vocabulary entries are presumably single characters.

    Attributes are plain and public:
        vocab: the vocabulary sequence exactly as passed in.
        stoi: entry -> id, where the id is the entry's position in ``vocab``.
        itos: id -> entry, the inverse lookup.
    """

    def __init__(self, vocab):
        """Build both lookup tables from ``vocab`` (ids follow its order)."""
        self.vocab = vocab
        self.stoi = {ch: idx for idx, ch in enumerate(vocab)}
        self.itos = {idx: ch for idx, ch in enumerate(vocab)}

    def encode(self, s):
        """Return the list of ids for each element of ``s``.

        Raises:
            KeyError: If an element of ``s`` is not in the vocabulary.
        """
        return [self.stoi[c] for c in s]

    def decode(self, i):
        """Return the string formed by joining the entries for the ids in ``i``.

        Raises:
            KeyError: If an id in ``i`` is not a known id.
        """
        return ''.join(self.itos[x] for x in i)

    @classmethod
    def from_pretrained(cls, path):
        """Create a tokenizer from the JSON vocabulary stored at ``path``."""
        # Explicit UTF-8 keeps loading independent of the platform locale.
        with open(path, 'r', encoding='utf-8') as f:
            vocab = json.load(f)
        return cls(vocab)

    def save_pretrained(self, path):
        """Write the vocabulary as JSON to ``path``."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self.vocab, f)