"""ECG classifiers: a 1-D CNN baseline and variants fused with a notes embedding.

Every model shares the same ECG backbone (three Conv1d/MaxPool1d stages
followed by two linear layers) and emits 5 logits per sample.  The
multimodal variants additionally consume a 768-value notes embedding and
differ only in how the two 128-d representations are fused.
"""

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: F401  (kept: may be used by importers)

try:
    from torchinfo import summary
except ImportError:  # torchinfo is only needed by the __main__ demo below
    summary = None

# Backbone hyper-parameters shared by every model in this file.
_CONV_CHANNELS = (16, 32, 48)
_KERNEL_SIZES = (7, 5, 3)
_DEFAULT_SEQ_LEN = 1000
_FC0_DIM = 512
_HIDDEN_DIM = 128
_NOTES_DIM = 768
_NUM_CLASSES = 5
_NUM_HEADS = 8


def _flattened_size(seq_len: int) -> int:
    """Return the number of features left after the three conv/pool stages.

    Each stage is an unpadded, stride-1 convolution (length ``L - k + 1``)
    followed by ``MaxPool1d(2, 2)`` (length ``L // 2``).  For the default
    ``seq_len`` of 1000 this yields ``48 * 122 == 5856``.

    Raises:
        ValueError: if ``seq_len`` is too short to survive all three stages.
    """
    length = seq_len
    for kernel_size in _KERNEL_SIZES:
        length = (length - kernel_size + 1) // 2
        if length < 1:
            raise ValueError(
                f"seq_len={seq_len} is too short for the convolutional backbone"
            )
    return _CONV_CHANNELS[-1] * length


def _build_ecg_encoder(model: nn.Module, ecg_channels: int, seq_len: int,
                       conv_factory) -> None:
    """Attach the shared ECG backbone layers (conv1..pool3, fc0, fc1) to *model*.

    ``conv_factory`` is called as ``conv_factory(in_channels, out_channels,
    kernel_size)``; pass ``nn.Conv1d`` or ``Conv1d_layer``.  Layers are created
    in the same order as the original hand-written constructors so parameter
    names and seeded initialisation are unchanged.
    """
    # Validate before any layer is created (and before any RNG is consumed).
    flat_features = _flattened_size(seq_len)
    c1, c2, c3 = _CONV_CHANNELS
    k1, k2, k3 = _KERNEL_SIZES

    model.conv1 = conv_factory(ecg_channels, c1, k1)
    model.pool1 = nn.MaxPool1d(2, 2)
    model.conv2 = conv_factory(c1, c2, k2)
    model.pool2 = nn.MaxPool1d(2, 2)
    model.conv3 = conv_factory(c2, c3, k3)
    model.pool3 = nn.MaxPool1d(2, 2)
    model.fc0 = nn.Linear(flat_features, _FC0_DIM)
    model.fc1 = nn.Linear(_FC0_DIM, _HIDDEN_DIM)


def _encode_ecg(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Run the shared backbone of *model* and return a ``(batch, 128)`` tensor."""
    act = model.activation
    x = model.pool1(act(model.conv1(x)))
    x = model.pool2(act(model.conv2(x)))
    x = model.pool3(act(model.conv3(x)))
    x = x.flatten(start_dim=1)
    x = act(model.fc0(x))
    return act(model.fc1(x))


def _embed_notes(model: nn.Module, notes: torch.Tensor) -> torch.Tensor:
    """Flatten *notes* per sample and project it to ``(batch, 128)`` with ReLU."""
    return model.activation(model.fc_emb(notes.flatten(start_dim=1)))


class Conv1d_layer(nn.Module):
    """Conv1d followed by BatchNorm1d and channel-wise dropout (p=0.5).

    Used as the convolution block of ``MMCNN_SUM``.
    """

    def __init__(self, in_channel, out_channel, kernel_size) -> None:
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=in_channel,
            out_channels=out_channel,
            kernel_size=kernel_size,
        )
        self.batch_norm = torch.nn.BatchNorm1d(out_channel)
        self.dropout = nn.Dropout1d(p=0.5)

    def forward(self, x):
        """Apply convolution, batch normalisation and dropout, in that order."""
        x = self.conv(x)
        x = self.batch_norm(x)
        return self.dropout(x)


class CNN(nn.Module):
    """ECG-only baseline classifier.

    Args:
        ecg_channels: number of input channels of the ECG signal.
        seq_len: number of samples along the last input axis; determines the
            input width of ``fc0``.
    """

    def __init__(self, ecg_channels=12, seq_len=_DEFAULT_SEQ_LEN):
        super().__init__()
        self.name = "CNN"
        _build_ecg_encoder(self, ecg_channels, seq_len, nn.Conv1d)
        self.fc2 = nn.Linear(_HIDDEN_DIM, _NUM_CLASSES)
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor,
                notes: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return ``(batch, 5)`` logits for ``x`` of shape ``(batch, ecg_channels, seq_len)``.

        ``notes`` is accepted only so all models share one call signature; it
        is ignored.
        """
        return self.fc2(_encode_ecg(self, x))


class MMCNN_SUM(nn.Module):
    """Multimodal classifier that sums the ECG and notes representations.

    Unlike the other models, its convolution stages are ``Conv1d_layer``
    blocks (convolution + batch norm + dropout).

    Args:
        ecg_channels: number of input channels of the ECG signal.
        seq_len: number of samples along the last input axis.
    """

    def __init__(self, ecg_channels=12, seq_len=_DEFAULT_SEQ_LEN):
        super().__init__()
        # ECG processing layers
        self.name = "MMCNN_SUM"
        _build_ecg_encoder(self, ecg_channels, seq_len, Conv1d_layer)
        self.fc2 = nn.Linear(_HIDDEN_DIM, _NUM_CLASSES)
        # Clinical notes processing layers
        self.fc_emb = nn.Linear(_NOTES_DIM, _HIDDEN_DIM)
        self.norm = nn.LayerNorm(_HIDDEN_DIM)
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor, notes: torch.Tensor) -> torch.Tensor:
        """Return ``(batch, 5)`` logits from ``fc2(LayerNorm(ecg + notes))``.

        ``notes`` must hold 768 values per sample after flattening.
        """
        ecg = _encode_ecg(self, x)
        text = _embed_notes(self, notes)
        return self.fc2(self.norm(ecg + text))


class MMCNN_CAT(nn.Module):
    """Multimodal classifier that concatenates the ECG and notes representations.

    Args:
        ecg_channels: number of input channels of the ECG signal.
        seq_len: number of samples along the last input axis.
    """

    def __init__(self, ecg_channels=12, seq_len=_DEFAULT_SEQ_LEN):
        super().__init__()
        # ECG processing layers
        self.name = "MMCNN_CAT"
        _build_ecg_encoder(self, ecg_channels, seq_len, nn.Conv1d)
        # fc2 consumes the 128-d ECG and 128-d notes vectors side by side.
        self.fc2 = nn.Linear(2 * _HIDDEN_DIM, _NUM_CLASSES)
        # Clinical notes processing layers
        self.fc_emb = nn.Linear(_NOTES_DIM, _HIDDEN_DIM)
        # NOTE(review): `norm` is never used in forward(); it is kept only so
        # existing checkpoints (which contain norm.weight / norm.bias) load.
        self.norm = nn.LayerNorm(_HIDDEN_DIM)
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor, notes: torch.Tensor) -> torch.Tensor:
        """Return ``(batch, 5)`` logits from ``fc2(cat(ecg, notes))``.

        ``notes`` must hold 768 values per sample after flattening.
        """
        ecg = _encode_ecg(self, x)
        text = _embed_notes(self, notes)
        return self.fc2(torch.cat((ecg, text), dim=1))


class MMCNN_ATT(nn.Module):
    """Multimodal classifier using attention with notes as query and ECG as key/value.

    Args:
        ecg_channels: number of input channels of the ECG signal.
        seq_len: number of samples along the last input axis.
    """

    def __init__(self, ecg_channels=12, seq_len=_DEFAULT_SEQ_LEN):
        super().__init__()
        # ECG processing layers
        self.name = "MMCNN_ATT"
        _build_ecg_encoder(self, ecg_channels, seq_len, nn.Conv1d)
        self.fc2 = nn.Linear(_HIDDEN_DIM, _NUM_CLASSES)
        # Clinical notes processing layers
        self.fc_emb = nn.Linear(_NOTES_DIM, _HIDDEN_DIM)
        self.norm1 = nn.LayerNorm(_HIDDEN_DIM)
        self.norm2 = nn.LayerNorm(_HIDDEN_DIM)
        self.attention = nn.MultiheadAttention(
            _HIDDEN_DIM, _NUM_HEADS, batch_first=True
        )
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor, notes: torch.Tensor) -> torch.Tensor:
        """Return ``(batch, 5)`` logits.

        ``notes`` must hold 768 values per sample after flattening.
        """
        ecg = self.norm1(_encode_ecg(self, x))
        text = self.norm2(_embed_notes(self, notes))

        # Treat each modality as a length-1 sequence: (batch, 1, 128).
        query = text.unsqueeze(1)
        key_value = ecg.unsqueeze(1)
        # NOTE(review): with a single key/value position the softmax weight is
        # always 1, so the attention output does not depend on the query
        # (i.e. on `notes`).  Confirm this is intended before relying on the
        # notes branch of this model.
        attended, _ = self.attention(query, key_value, key_value)
        # Drop the length-1 sequence axis: (batch, 1, 5) -> (batch, 5).
        return self.fc2(attended).squeeze(1)


class MMCNN_SUM_ATT(nn.Module):
    """Multimodal classifier: summed fusion followed by self-attention.

    Args:
        ecg_channels: number of input channels of the ECG signal.
        seq_len: number of samples along the last input axis.
    """

    def __init__(self, ecg_channels=12, seq_len=_DEFAULT_SEQ_LEN):
        super().__init__()
        # ECG processing layers
        self.name = "MMCNN_SUM_ATT"
        _build_ecg_encoder(self, ecg_channels, seq_len, nn.Conv1d)
        self.fc2 = nn.Linear(_HIDDEN_DIM, _NUM_CLASSES)
        # Clinical notes processing layers
        self.fc_emb = nn.Linear(_NOTES_DIM, _HIDDEN_DIM)
        self.norm = nn.LayerNorm(_HIDDEN_DIM)
        self.attention = nn.MultiheadAttention(
            _HIDDEN_DIM, _NUM_HEADS, batch_first=True
        )
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor, notes: torch.Tensor) -> torch.Tensor:
        """Return ``(batch, 5)`` logits.

        ``notes`` must hold 768 values per sample after flattening.
        """
        ecg = _encode_ecg(self, x)
        text = _embed_notes(self, notes)

        # Fuse by summation, then view the result as a length-1 sequence.
        fused = self.norm(ecg + text).unsqueeze(1)
        # NOTE(review): self-attention over a length-1 sequence reduces to the
        # value and output projections applied to `fused`.
        attended, _ = self.attention(fused, fused, fused)
        # Drop the length-1 sequence axis: (batch, 1, 5) -> (batch, 5).
        return self.fc2(attended).squeeze(1)


if __name__ == "__main__":
    if summary is None:
        raise ImportError("torchinfo is required to print the model summary")
    model = CNN()
    summary(model, input_size=(1, 12, 1000))