Tej3's picture
Committing App
71bd54f
raw
history blame
No virus
2.36 kB
import torch
import torch.nn as nn
class RNN(nn.Module):
    """LSTM sequence classifier: stacked LSTM -> last time step -> 2-layer MLP head.

    Args:
        input_dim: Number of features per time step fed to the LSTM.
        hidden_dim: LSTM hidden size; also the width of the first linear layer.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output logits.
        cuda: Ignored; kept only so existing call sites keep working.
        device: Stored as ``self.device`` for backward compatibility. ``forward``
            places its tensors on the device of the module's parameters, so the
            model keeps working after ``.to()``/``.cpu()``/``.cuda()`` and with
            the default ``'cuda'`` on a CPU-only machine.
    """

    def __init__(self, input_dim=12, hidden_dim=64, num_layers=2, num_classes=5, cuda=True, device='cuda'):
        super(RNN, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.device = device
        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=self.hidden_dim,
                            num_layers=self.num_layers, batch_first=True)
        self.fc1 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, num_classes)
        self.relu = nn.ReLU()

    def _initial_state(self, batch_size, device, dtype):
        """Return randomly initialised ``(h0, c0)`` for the LSTM on ``device``/``dtype``.

        Both tensors have shape ``(num_layers, batch_size, hidden_dim)``. They are
        drawn on the CPU in float32 (h first, then c) and only then moved/cast,
        so the random draws do not depend on the target device or dtype.
        """
        shape = (self.num_layers, batch_size, self.hidden_dim)
        h = torch.zeros(shape)
        c = torch.zeros(shape)
        # NOTE(review): xavier_normal_ on a 3-D tensor derives its std from
        # (batch_size + num_layers) * hidden_dim, so the initial state is random
        # on every call and its scale shrinks as the batch grows. Outputs are
        # therefore non-deterministic and batch-size dependent; kept as-is since
        # existing checkpoints were presumably trained this way — confirm.
        nn.init.xavier_normal_(h)
        nn.init.xavier_normal_(c)
        return h.to(device=device, dtype=dtype), c.to(device=device, dtype=dtype)

    def forward(self, x, notes=None):
        """Classify a batch of sequences.

        Args:
            x: ``(batch, seq_len, input_dim)`` tensor (the LSTM is batch-first).
                It is moved to the parameters' device and cast to their dtype.
            notes: Unused; accepted so this model shares a call signature with
                models that take a second input.

        Returns:
            ``(batch, num_classes)`` logits.
        """
        param = next(self.parameters())
        x = x.to(device=param.device, dtype=param.dtype)
        h, c = self._initial_state(x.size(0), param.device, param.dtype)
        output, _ = self.lstm(x, (h, c))
        # Classify from the last time step only (batch_first => dim 1 is time).
        out = self.fc2(self.relu(self.fc1(output[:, -1, :])))
        return out
# nn.Module, not nn.ModuleList: this model is not a list of layers, and
# ModuleList's list API (indexing, len, iteration) would be misleading here.
# Submodule names, and therefore state_dict keys, are the same either way.
class MMRNN(nn.Module):
    """Multimodal classifier fusing an LSTM sequence encoding with a note embedding.

    The last LSTM hidden state is projected to ``embed_size`` and layer-normalised.
    It is added to the layer-normalised note embedding and mapped to class logits.

    Args:
        input_dim: Number of features per time step fed to the LSTM.
        hidden_dim: LSTM hidden size.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output logits.
        embed_size: Width of the projected sequence feature; the note's last
            dimension must have this size (both go through LayerNorm(embed_size)).
        device: Stored as ``self.device`` for backward compatibility. ``forward``
            places its tensors on the device of the module's parameters, so the
            model keeps working after ``.to()``/``.cpu()``/``.cuda()`` and with
            the default ``"cuda"`` on a CPU-only machine.
    """

    def __init__(self, input_dim=12, hidden_dim=64, num_layers=2, num_classes=5, embed_size=768, device="cuda"):
        super(MMRNN, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.device = device
        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=self.hidden_dim,
                            num_layers=self.num_layers, batch_first=True)
        self.fc1 = nn.Linear(self.hidden_dim, embed_size)
        self.fc2 = nn.Linear(embed_size, num_classes)
        self.lnorm_out = nn.LayerNorm(embed_size)
        self.lnorm_embed = nn.LayerNorm(embed_size)

    def _initial_state(self, batch_size, device, dtype):
        """Return randomly initialised ``(h0, c0)`` for the LSTM on ``device``/``dtype``.

        Both tensors have shape ``(num_layers, batch_size, hidden_dim)``. They are
        drawn on the CPU in float32 (h first, then c) and only then moved/cast,
        so the random draws do not depend on the target device or dtype.
        """
        shape = (self.num_layers, batch_size, self.hidden_dim)
        h = torch.zeros(shape)
        c = torch.zeros(shape)
        # NOTE(review): xavier_normal_ on a 3-D tensor derives its std from
        # (batch_size + num_layers) * hidden_dim, so the initial state is random
        # on every call and its scale shrinks as the batch grows. Outputs are
        # therefore non-deterministic and batch-size dependent; kept as-is since
        # existing checkpoints were presumably trained this way — confirm.
        nn.init.xavier_normal_(h)
        nn.init.xavier_normal_(c)
        return h.to(device=device, dtype=dtype), c.to(device=device, dtype=dtype)

    def forward(self, x, note):
        """Classify a batch of sequences together with their note embeddings.

        Args:
            x: ``(batch, seq_len, input_dim)`` tensor (the LSTM is batch-first).
            note: Tensor whose last dimension is ``embed_size`` and that broadcasts
                against ``(batch, embed_size)`` — presumably one precomputed
                embedding per sample; confirm with callers.
            Both inputs are moved to the parameters' device and cast to their dtype.

        Returns:
            ``(batch, num_classes)`` logits, or ``(batch,)`` when
            ``num_classes == 1`` (``squeeze(1)`` drops the singleton class axis).
        """
        param = next(self.parameters())
        x = x.to(device=param.device, dtype=param.dtype)
        note = note.to(device=param.device, dtype=param.dtype)
        h, c = self._initial_state(x.size(0), param.device, param.dtype)
        output, _ = self.lstm(x, (h, c))
        # Take last hidden state and project it into the note embedding space.
        out = self.fc1(output[:, -1, :])
        # Normalise each modality separately so neither dominates the sum.
        out = self.lnorm_embed(note) + self.lnorm_out(out)
        out = self.fc2(out)
        return out.squeeze(1)