import torch
import torch.nn as nn


class RNN(nn.Module):
    """LSTM classifier over the time-series input; the note embedding is ignored."""

    def __init__(self, input_dim=12, hidden_dim=64, num_layers=2, num_classes=5,
                 cuda=True, device='cuda'):
        super(RNN, self).__init__()
        # `cuda` is unused; device placement is controlled by `device`.
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.device = device
        self.lstm = nn.LSTM(input_size=input_dim,
                            hidden_size=self.hidden_dim,
                            num_layers=self.num_layers,
                            batch_first=True)
        self.fc1 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x, notes):
        # `notes` is unused here; the argument is kept so RNN and MMRNN share
        # the same forward signature.
        # Hidden and cell states are Xavier-initialized on every forward pass.
        h = torch.zeros(self.num_layers, x.size(0), self.hidden_dim)
        c = torch.zeros(self.num_layers, x.size(0), self.hidden_dim)
        nn.init.xavier_normal_(h)
        nn.init.xavier_normal_(c)
        h = h.to(self.device)
        c = c.to(self.device)
        x = x.to(self.device)

        output, _ = self.lstm(x, (h, c))
        # Classify from the hidden state of the last time step.
        out = self.fc2(self.relu(self.fc1(output[:, -1, :])))
        return out


class MMRNN(nn.Module):
    """Multimodal LSTM: projects the time-series representation into the note-embedding
    space and fuses the two modalities by layer-normalized addition before classifying."""

    def __init__(self, input_dim=12, hidden_dim=64, num_layers=2, num_classes=5,
                 embed_size=768, device="cuda"):
        super(MMRNN, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.device = device
        self.lstm = nn.LSTM(input_size=input_dim,
                            hidden_size=self.hidden_dim,
                            num_layers=self.num_layers,
                            batch_first=True)
        self.fc1 = nn.Linear(self.hidden_dim, embed_size)
        self.fc2 = nn.Linear(embed_size, num_classes)
        self.lnorm_out = nn.LayerNorm(embed_size)
        self.lnorm_embed = nn.LayerNorm(embed_size)

    def forward(self, x, note):
        # Hidden and cell states are Xavier-initialized on every forward pass.
        h = torch.zeros(self.num_layers, x.size(0), self.hidden_dim)
        c = torch.zeros(self.num_layers, x.size(0), self.hidden_dim)
        nn.init.xavier_normal_(h)
        nn.init.xavier_normal_(c)
        h = h.to(self.device)
        c = c.to(self.device)
        x = x.to(self.device)
        note = note.to(self.device)

        output, _ = self.lstm(x, (h, c))
        # Take the last hidden state and project it into the note-embedding space.
        out = self.fc1(output[:, -1, :])
        # Normalize both modalities, then fuse them by addition.
        note = self.lnorm_embed(note)
        out = self.lnorm_out(out)
        out = note + out
        out = self.fc2(out)
        # squeeze(1) only takes effect if a singleton dimension is present,
        # e.g. when `note` is passed with shape (batch, 1, embed_size).
        return out.squeeze(1)
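

# Minimal smoke test (an assumed usage sketch, not part of the original module).
# The batch size and sequence length below are placeholders; the feature and
# embedding sizes match the defaults above (input_dim=12, embed_size=768,
# num_classes=5).
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch, seq_len = 4, 48                 # hypothetical batch size / sequence length
    x = torch.randn(batch, seq_len, 12)    # time-series features
    note = torch.randn(batch, 768)         # precomputed note embedding

    rnn = RNN(device=device).to(device)
    mmrnn = MMRNN(device=device).to(device)

    print(rnn(x, note).shape)     # expected: torch.Size([4, 5])
    print(mmrnn(x, note).shape)   # expected: torch.Size([4, 5])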