Gomoku-GPT2 / ai.py
snoop2head's picture
initial commit
0de5849
raw
history blame
No virus
1.06 kB
import numpy as np
import torch
from transformers import GPT2LMHeadModel
def load_model(model_name: str = "snoop2head/Gomoku-GPT2") -> GPT2LMHeadModel:
    """Load a pretrained GPT-2 language-model head.

    Args:
        model_name: Hub id or local path forwarded unchanged to
            ``GPT2LMHeadModel.from_pretrained``. Defaults to
            ``"snoop2head/Gomoku-GPT2"``.

    Returns:
        The ``GPT2LMHeadModel`` produced by ``from_pretrained``.
    """
    return GPT2LMHeadModel.from_pretrained(model_name)
# Special token ids (beginning-of-sequence, padding, end-of-sequence, going by
# their names). PAD and EOS are handed to ``model.generate``.
# NOTE(review): presumably ids below 401 encode board moves — confirm against
# the tokenizer/vocabulary the model was trained with.
BOS_TOKEN_ID, PAD_TOKEN_ID, EOS_TOKEN_ID = 401, 402, 403
def generate_gpt2(
    model: GPT2LMHeadModel,
    input_ids: torch.LongTensor,
    max_length: int = 128,
    num_beams: int = 5,
    temperature: float = 0.7,
) -> list:
    """Run beam-search generation and return the resulting token ids.

    Args:
        model: GPT-2 model whose ``generate`` method is called.
        input_ids: prompt token ids, shape ``[batch_size, seq_len]``.
        max_length: maximum total sequence length passed to ``generate``.
        num_beams: beam width passed to ``generate``.
        temperature: passed to ``generate``. NOTE(review): transformers only
            uses temperature when sampling is enabled (``do_sample``); for
            plain beam search it is ignored — confirm whether it is intended.

    Returns:
        For ``batch_size == 1``, a flat list of token ids (the ``[seq_len]``
        row returned by ``generate``). For ``batch_size > 1``, one list per
        batch row.
    """
    output_ids = model.generate(
        input_ids,
        max_length=max_length,
        num_beams=num_beams,
        temperature=temperature,
        pad_token_id=PAD_TOKEN_ID,
        eos_token_id=EOS_TOKEN_ID,
    )
    # Drop only the batch dimension. A bare ``squeeze()`` would also collapse a
    # length-1 sequence, making ``tolist()`` return an int instead of a list.
    return output_ids.squeeze(0).tolist()
def change_to_1d_coordinate(board: np.ndarray, x: int, y: int) -> int:
    """Flatten a 2-D board position into a single row-major index.

    ``x`` indexes the first board axis and ``y`` the second. The row width is
    taken from ``board.shape[1]``. No bounds checking is performed.
    """
    n_cols = board.shape[1]
    return x * n_cols + y
def change_to_2d_coordinate(board: np.ndarray, coordinate: int) -> tuple:
    """Turn a row-major flat index back into a ``(row, col)`` pair.

    The row width is ``board.shape[1]``. This is the inverse of
    ``change_to_1d_coordinate``.
    """
    width = board.shape[1]
    row = coordinate // width
    col = coordinate % width
    return row, col