bgpt / model.py
bibekyess's picture
Upload 9 files
42e3a78
raw
history blame
No virus
670 Bytes
import torch.nn as nn
class NeuralNet(nn.Module):
    """Three-layer fully connected network.

    Architecture: Linear -> ReLU -> Dropout -> Linear -> ReLU -> Dropout -> Linear.
    The last layer's output is returned as-is (no activation, no softmax), i.e.
    raw unnormalized per-class scores.
    """

    def __init__(self, input_size, hidden_size, num_classes, *, dropout_p=0.5):
        """Create the layers.

        Args:
            input_size: Number of input features (size of the last dim of ``x``).
            hidden_size: Width of both hidden layers.
            num_classes: Number of output scores produced per sample.
            dropout_p: Drop probability of the dropout applied after each hidden
                activation. Defaults to 0.5, the value previously hard-coded.

        Raises:
            ValueError: From ``nn.Dropout`` if ``dropout_p`` is outside [0, 1].
        """
        super().__init__()
        # Attribute names are kept unchanged so previously saved state_dicts load.
        self.l1 = nn.Linear(input_size, hidden_size)
        self.l2 = nn.Linear(hidden_size, hidden_size)
        self.l3 = nn.Linear(hidden_size, num_classes)
        self.relu = nn.ReLU()
        # One Dropout module is reused after both hidden layers; it holds no
        # parameters, so sharing it is safe. It only drops units in train() mode.
        self.dropout = nn.Dropout(p=dropout_p)

    def forward(self, x):
        """Run the network on ``x``.

        Args:
            x: Tensor whose last dimension has size ``input_size``.

        Returns:
            Tensor with the same leading dims as ``x`` and last dim ``num_classes``,
            holding raw scores (no activation or softmax applied).
        """
        out = self.dropout(self.relu(self.l1(x)))
        out = self.dropout(self.relu(self.l2(out)))
        # No activation and no softmax at the end: return raw scores.
        return self.l3(out)