File size: 1,757 Bytes
28dc58b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bert_dataset import BERTDataset\n",
    "from torch.utils.data import DataLoader\n",
    "from bert_model import BERT, BERTLM\n",
    "from trainer import BERTTrainer\n",
    "from transformers import BertTokenizer\n",
    "from data import get_data\n",
    "\n",
    "# Handed to BERTDataset as seq_len below.\n",
    "MAX_LEN = 128\n",
    "\n",
    "# Read the movie-dialogue files and load the tokenizer from the local vocab file.\n",
    "pairs = get_data('datasets/movie_conversations.txt', \"datasets/movie_lines.txt\")\n",
    "tokenizer = BertTokenizer.from_pretrained(\"bert-it-1/bert-it-vocab.txt\")\n",
    "vocab_size = len(tokenizer.vocab)\n",
    "\n",
    "# Build the dataset from the loaded pairs; a bare BERTDataset() would leave\n",
    "# pairs and MAX_LEN unused and presumably give the dataset nothing to read.\n",
    "# NOTE(review): BERTDataset lives in bert_dataset.py, which is not visible from\n",
    "# this notebook -- confirm the seq_len / tokenizer parameter names.\n",
    "train_data = BERTDataset(pairs, seq_len=MAX_LEN, tokenizer=tokenizer)\n",
    "\n",
    "train_loader = DataLoader(\n",
    "    train_data, batch_size=32, shuffle=True, pin_memory=True)\n",
    "\n",
    "# BERT and BERTLM are both given the same tokenizer vocabulary size.\n",
    "bert_model = BERT(\n",
    "    vocab_size=vocab_size,\n",
    "    d_model=768,\n",
    "    n_layers=2,\n",
    "    heads=12,\n",
    "    dropout=0.1,\n",
    ")\n",
    "\n",
    "bert_lm = BERTLM(bert=bert_model, vocab_size=vocab_size)\n",
    "bert_trainer = BERTTrainer(bert_lm, train_loader, device='cpu')\n",
    "epochs = 20\n",
    "\n",
    "# One bert_trainer.train() call per epoch, passing the epoch index.\n",
    "for epoch in range(epochs):\n",
    "    bert_trainer.train(epoch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}