Tej3's picture
Committing App
71bd54f
raw
history blame
No virus
7.13 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary
# Conv1d -> BatchNorm1d -> Dropout1d building block (used by MMCNN_SUM).
class Conv1d_layer(nn.Module):
    """1-D convolution followed by batch normalisation and channel dropout.

    No activation is applied inside this block; callers apply their own
    non-linearity to its output.

    Args:
        in_channel: Number of input channels.
        out_channel: Number of output channels of the convolution (and the
            feature count normalised by ``BatchNorm1d``).
        kernel_size: Convolution kernel size. No padding and stride 1 are
            used, so the temporal length shrinks by ``kernel_size - 1``.
        dropout_p: Drop probability for ``nn.Dropout1d`` (which zeroes whole
            channels while training). Defaults to 0.5, the value that was
            previously hard-coded.
    """

    def __init__(self, in_channel, out_channel, kernel_size, dropout_p: float = 0.5) -> None:
        super().__init__()
        self.conv = nn.Conv1d(in_channels=in_channel, out_channels=out_channel, kernel_size=kernel_size)
        self.batch_norm = torch.nn.BatchNorm1d(out_channel)
        self.dropout = nn.Dropout1d(p=dropout_p)

    def forward(self, x):
        """Apply conv -> batch norm -> dropout.

        Expects ``x`` shaped ``(batch, in_channel, length)``; returns
        ``(batch, out_channel, length - kernel_size + 1)``.
        """
        x = self.conv(x)
        x = self.batch_norm(x)
        x = self.dropout(x)
        return x
class CNN(nn.Module):
    """ECG-only 1-D CNN classifier (the clinical-notes input is ignored).

    Three ``Conv1d -> ReLU -> MaxPool1d(2, 2)`` stages (16, 32 and 48
    channels; kernels 7, 5 and 3; no padding) feed a
    ``fc0 -> ReLU -> fc1 -> ReLU -> fc2`` head.

    Args:
        ecg_channels: Number of input channels of the signal.
        signal_length: Samples per input signal. It determines the input
            width of ``fc0``; the default of 1000 gives the previously
            hard-coded 5856 = 48 * 122 features.
        num_classes: Number of outputs produced by ``fc2``.

    Raises:
        ValueError: If ``signal_length`` is too short to survive the three
            conv/pool stages.
    """

    def __init__(self, ecg_channels=12, signal_length=1000, num_classes=5):
        super(CNN, self).__init__()
        self.name = "CNN"
        self.conv1 = nn.Conv1d(ecg_channels, 16, 7)
        self.pool1 = nn.MaxPool1d(2, 2)
        self.conv2 = nn.Conv1d(16, 32, 5)
        self.pool2 = nn.MaxPool1d(2, 2)
        self.conv3 = nn.Conv1d(32, 48, 3)
        self.pool3 = nn.MaxPool1d(2, 2)
        # Derive fc0's width from the layers above instead of a magic number,
        # so a different signal_length gets a matching head.
        flat_features = self.conv3.out_channels * self._feature_length(signal_length)
        self.fc0 = nn.Linear(flat_features, 512)
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.activation = nn.ReLU()

    def _feature_length(self, signal_length):
        """Return the temporal length left after the three conv/pool stages."""
        length = signal_length
        for conv in (self.conv1, self.conv2, self.conv3):
            # Unpadded stride-1 conv: L -> L - (kernel_size - 1);
            # MaxPool1d(2, 2): L -> floor(L / 2).
            length = (length - conv.kernel_size[0] + 1) // 2
            if length < 1:
                raise ValueError(
                    f"signal_length={signal_length} is too short for the "
                    "three conv/pool stages of CNN"
                )
        return length

    def forward(self, x, notes=None):
        """Classify a batch of signals.

        Args:
            x: Input shaped ``(batch, ecg_channels, signal_length)``.
            notes: Ignored. Accepted so this model can be called with the
                same ``(x, notes)`` arguments as the multimodal models.

        Returns:
            Unnormalised scores shaped ``(batch, num_classes)`` (no softmax).
            The trailing ``squeeze(1)`` only changes the shape when
            ``num_classes == 1``, giving ``(batch,)``.
        """
        x = self.pool1(self.activation(self.conv1(x)))
        x = self.pool2(self.activation(self.conv2(x)))
        x = self.pool3(self.activation(self.conv3(x)))
        x = x.view(x.size(0), -1)
        x = self.activation(self.fc0(x))
        x = self.activation(self.fc1(x))
        x = self.fc2(x)
        return x.squeeze(1)
class MMCNN_SUM(nn.Module):
    """Multimodal CNN that fuses ECG and clinical-note features by addition.

    ECG branch: three ``Conv1d_layer`` blocks (conv + batch norm + dropout),
    each followed by ReLU and ``MaxPool1d(2, 2)``, then ``fc0``/``fc1`` with
    ReLU down to 128 features. Notes branch: ``fc_emb`` + ReLU maps 768
    features to 128. The two vectors are summed, layer-normalised, and mapped
    to 5 outputs by ``fc2``.
    """

    def __init__(self, ecg_channels=12):
        super().__init__()
        self.name = "MMCNN_SUM"
        # Register conv1, pool1, conv2, pool2, conv3, pool3 in that order so
        # parameter names and registration order are unchanged.
        stage_specs = ((ecg_channels, 16, 7), (16, 32, 5), (32, 48, 3))
        for idx, (c_in, c_out, kernel) in enumerate(stage_specs, start=1):
            setattr(self, f"conv{idx}", Conv1d_layer(c_in, c_out, kernel))
            setattr(self, f"pool{idx}", nn.MaxPool1d(2, 2))
        # 5856 = 48 channels * 122 samples, which holds for 1000-sample inputs.
        self.fc0 = nn.Linear(5856, 512)
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, 5)
        # Clinical-notes projection and the fusion normalisation.
        self.fc_emb = nn.Linear(768, 128)
        self.norm = nn.LayerNorm(128)
        self.activation = nn.ReLU()

    def forward(self, x, notes):
        """Score a batch of ECGs together with their note features.

        Args:
            x: ECG batch shaped ``(batch, ecg_channels, 1000)``; the length
                must give 5856 flattened features for ``fc0``.
            notes: Per-sample features that flatten to 768 values
                (``fc_emb``'s input size); presumably a text-encoder
                embedding — confirm against the data pipeline.

        Returns:
            ``(batch, 5)`` unnormalised scores (no softmax is applied).
        """
        stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        )
        for conv, pool in stages:
            x = pool(self.activation(conv(x)))
        ecg_feat = x.view(x.shape[0], -1)
        for fc in (self.fc0, self.fc1):
            ecg_feat = self.activation(fc(ecg_feat))
        note_feat = self.activation(self.fc_emb(notes.view(notes.shape[0], -1)))
        logits = self.fc2(self.norm(ecg_feat + note_feat))
        # squeeze(1) is a no-op while fc2 has 5 outputs.
        return logits.squeeze(1)
class MMCNN_CAT(nn.Module):
    """Multimodal CNN that fuses ECG and clinical-note features by concatenation.

    ECG branch: three ``Conv1d -> ReLU -> MaxPool1d(2, 2)`` stages, then
    ``fc0``/``fc1`` with ReLU down to 128 features. Notes branch: ``fc_emb``
    + ReLU maps 768 features to 128. The two 128-d vectors are concatenated
    (256 features) and mapped to 5 outputs by ``fc2``.
    """

    def __init__(self, ecg_channels=12):
        super().__init__()
        self.name = "MMCNN_CAT"
        # Register conv1, pool1, conv2, pool2, conv3, pool3 in that order so
        # parameter names, registration order and initialisation are unchanged.
        stage_specs = ((ecg_channels, 16, 7), (16, 32, 5), (32, 48, 3))
        for idx, (c_in, c_out, kernel) in enumerate(stage_specs, start=1):
            setattr(self, f"conv{idx}", nn.Conv1d(c_in, c_out, kernel))
            setattr(self, f"pool{idx}", nn.MaxPool1d(2, 2))
        # 5856 = 48 channels * 122 samples, which holds for 1000-sample inputs.
        self.fc0 = nn.Linear(5856, 512)
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(2 * 128, 5)
        self.fc_emb = nn.Linear(768, 128)
        # NOTE(review): registered but never used in forward; kept so the
        # parameter set / state_dict keys stay unchanged.
        self.norm = nn.LayerNorm(128)
        self.activation = nn.ReLU()

    def forward(self, x, notes):
        """Score a batch of ECGs together with their note features.

        Args:
            x: ECG batch shaped ``(batch, ecg_channels, 1000)``; the length
                must give 5856 flattened features for ``fc0``.
            notes: Per-sample features that flatten to 768 values
                (``fc_emb``'s input size); presumably a text-encoder
                embedding — confirm against the data pipeline.

        Returns:
            ``(batch, 5)`` unnormalised scores (no softmax is applied).
        """
        stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        )
        for conv, pool in stages:
            x = pool(self.activation(conv(x)))
        ecg_feat = x.view(x.shape[0], -1)
        for fc in (self.fc0, self.fc1):
            ecg_feat = self.activation(fc(ecg_feat))
        note_feat = self.activation(self.fc_emb(notes.view(notes.shape[0], -1)))
        logits = self.fc2(torch.cat([ecg_feat, note_feat], dim=1))
        # squeeze(1) is a no-op while fc2 has 5 outputs.
        return logits.squeeze(1)
class MMCNN_ATT(nn.Module):
    """Multimodal CNN where the note features attend over the ECG features.

    ECG branch: three ``Conv1d -> ReLU -> MaxPool1d(2, 2)`` stages, then
    ``fc0``/``fc1`` with ReLU and ``norm1`` -> one 128-d token. Notes branch:
    ``fc_emb`` + ReLU + ``norm2`` -> one 128-d token. The note token is the
    attention query; the ECG token is key and value. ``fc2`` maps the result
    to 5 outputs.

    NOTE(review): the key/value sequence has length 1, so the attention
    softmax is over a single element and its weight is always 1. The output
    is therefore ``out_proj(v_proj(ecg_token))`` whatever the query is: the
    notes branch (``fc_emb``, ``norm2``) does not affect the scores and should
    get zero gradient. Confirm whether this is intended before treating the
    model as multimodal.
    """

    def __init__(self, ecg_channels=12):
        super().__init__()
        self.name = "MMCNN_ATT"
        # Register conv1, pool1, conv2, pool2, conv3, pool3 in that order so
        # parameter names, registration order and initialisation are unchanged.
        stage_specs = ((ecg_channels, 16, 7), (16, 32, 5), (32, 48, 3))
        for idx, (c_in, c_out, kernel) in enumerate(stage_specs, start=1):
            setattr(self, f"conv{idx}", nn.Conv1d(c_in, c_out, kernel))
            setattr(self, f"pool{idx}", nn.MaxPool1d(2, 2))
        # 5856 = 48 channels * 122 samples, which holds for 1000-sample inputs.
        self.fc0 = nn.Linear(5856, 512)
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, 5)
        # Notes projection, per-modality normalisation and the fusion attention.
        self.fc_emb = nn.Linear(768, 128)
        self.norm1 = nn.LayerNorm(128)
        self.norm2 = nn.LayerNorm(128)
        self.attention = nn.MultiheadAttention(128, 8, batch_first=True)
        self.activation = nn.ReLU()

    def forward(self, x, notes):
        """Score a batch of ECGs together with their note features.

        Args:
            x: ECG batch shaped ``(batch, ecg_channels, 1000)``; the length
                must give 5856 flattened features for ``fc0``.
            notes: Per-sample features that flatten to 768 values
                (``fc_emb``'s input size); presumably a text-encoder
                embedding — confirm against the data pipeline.

        Returns:
            ``(batch, 5)`` unnormalised scores (no softmax is applied).
        """
        stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        )
        for conv, pool in stages:
            x = pool(self.activation(conv(x)))
        ecg_feat = x.view(x.shape[0], -1)
        for fc in (self.fc0, self.fc1):
            ecg_feat = self.activation(fc(ecg_feat))
        # Each modality becomes a length-1 token sequence: (batch, 1, 128).
        ecg_token = self.norm1(ecg_feat).unsqueeze(1)
        note_feat = self.activation(self.fc_emb(notes.view(notes.shape[0], -1)))
        note_token = self.norm2(note_feat).unsqueeze(1)
        attended, _ = self.attention(note_token, ecg_token, ecg_token)
        # (batch, 1, 5) -> (batch, 5)
        return self.fc2(attended).squeeze(1)
class MMCNN_SUM_ATT(nn.Module):
    """Multimodal CNN: additive ECG/notes fusion followed by self-attention.

    ECG branch: three ``Conv1d -> ReLU -> MaxPool1d(2, 2)`` stages, then
    ``fc0``/``fc1`` with ReLU down to 128 features. Notes branch: ``fc_emb``
    + ReLU maps 768 features to 128. The two vectors are summed and
    layer-normalised, passed through self-attention as a single token, and
    mapped to 5 outputs by ``fc2``.

    NOTE(review): self-attention over a length-1 sequence gives an attention
    weight of exactly 1, so that step reduces to ``out_proj(v_proj(token))``.
    The query/key projections do not affect the output.
    """

    def __init__(self, ecg_channels=12):
        super().__init__()
        self.name = "MMCNN_SUM_ATT"
        # Register conv1, pool1, conv2, pool2, conv3, pool3 in that order so
        # parameter names, registration order and initialisation are unchanged.
        stage_specs = ((ecg_channels, 16, 7), (16, 32, 5), (32, 48, 3))
        for idx, (c_in, c_out, kernel) in enumerate(stage_specs, start=1):
            setattr(self, f"conv{idx}", nn.Conv1d(c_in, c_out, kernel))
            setattr(self, f"pool{idx}", nn.MaxPool1d(2, 2))
        # 5856 = 48 channels * 122 samples, which holds for 1000-sample inputs.
        self.fc0 = nn.Linear(5856, 512)
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, 5)
        # Notes projection, fusion normalisation and self-attention.
        self.fc_emb = nn.Linear(768, 128)
        self.norm = nn.LayerNorm(128)
        self.attention = nn.MultiheadAttention(128, 8, batch_first=True)
        self.activation = nn.ReLU()

    def forward(self, x, notes):
        """Score a batch of ECGs together with their note features.

        Args:
            x: ECG batch shaped ``(batch, ecg_channels, 1000)``; the length
                must give 5856 flattened features for ``fc0``.
            notes: Per-sample features that flatten to 768 values
                (``fc_emb``'s input size); presumably a text-encoder
                embedding — confirm against the data pipeline.

        Returns:
            ``(batch, 5)`` unnormalised scores (no softmax is applied).
        """
        stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        )
        for conv, pool in stages:
            x = pool(self.activation(conv(x)))
        ecg_feat = x.view(x.shape[0], -1)
        for fc in (self.fc0, self.fc1):
            ecg_feat = self.activation(fc(ecg_feat))
        note_feat = self.activation(self.fc_emb(notes.view(notes.shape[0], -1)))
        # Fused vector as a single-token sequence: (batch, 1, 128).
        token = self.norm(ecg_feat + note_feat).unsqueeze(1)
        attended, _ = self.attention(token, token, token)
        # (batch, 1, 5) -> (batch, 5)
        return self.fc2(attended).squeeze(1)
if __name__ == "__main__":
    # Print a torchinfo layer summary of the ECG-only CNN for a single
    # 12-channel, 1000-sample input (batch size 1).
    summary(CNN(), input_size=(1, 12, 1000))