File size: 5,184 Bytes
71bd54f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d0f8eba
71bd54f
 
 
 
 
d0f8eba
71bd54f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import shutil
import tempfile

import gradio as gr
import numpy as np
import wfdb
import torch
from wfdb.plot.plot import plot_wfdb
from wfdb.io.record import Record, rdrecord

from models.CNN import CNN, MMCNN_CAT
from models.RNN import MMRNN
from utils.helper_functions import predict

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

from transformers import AutoTokenizer, AutoModel
from langdetect import DetectorFactory, detect
from langdetect.lang_detect_exception import LangDetectException

# All data paths are built from the directory the app is launched from, so start
# it from the project root (or edit these paths) before running.
CWD = os.getcwd()
_DEMO_DATA_DIR = f"{CWD}/demo_data"

# Trained checkpoints for the two multimodal classifiers; each file is read with
# torch.load and its 'model_state_dict' entry is loaded into the model.
MMCNN_CAT_ckpt_path = f"{_DEMO_DATA_DIR}/model_MMCNN_CAT_epoch_30_acc_84.pt"
MMRNN_ckpt_path = f"{_DEMO_DATA_DIR}/model_MMRNN_undersampled_augmented_rn_epoch_20_acc_84.pt"

# Hugging Face model ids used to embed clinical notes: one English, one German.
en_clin_bert = 'emilyalsentzer/Bio_ClinicalBERT'
ger_clin_bert = 'smanjil/German-MedBERT'

# Loaded once at import time (fetched from the Hub or the local HF cache) and
# shared by every request. Tuple elements evaluate left to right, so the load
# order is tokenizer then model, English before German.
en_tokenizer, en_model = AutoTokenizer.from_pretrained(en_clin_bert), AutoModel.from_pretrained(en_clin_bert)
g_tokenizer, g_model = AutoTokenizer.from_pretrained(ger_clin_bert), AutoModel.from_pretrained(ger_clin_bert)

def preprocess(data_file_path):
    """Read a WFDB record and return its signal with a leading batch axis of 1.

    Parameters
    ----------
    data_file_path : str
        WFDB record name, i.e. the path without the ``.hea``/``.dat`` extension.

    Returns
    -------
    numpy.ndarray
        The signal array from ``wfdb.rdsamp`` wrapped in a one-element batch.
        wfdb documents that signal as (samples, channels), giving
        (1, samples, channels).
    """
    # rdsamp returns (signal, fields); the metadata dict is not needed here.
    signal, _fields = wfdb.rdsamp(data_file_path)
    return np.array([signal])

def embed(notes):
    """Embed free-text clinical notes with a language-appropriate clinical BERT.

    Notes that ``langdetect`` identifies as English use Bio_ClinicalBERT.
    Everything else uses German-MedBERT, including text whose language cannot be
    detected at all (e.g. an empty string).

    Parameters
    ----------
    notes : str
        The clinical note text.

    Returns
    -------
    torch.Tensor
        1-D tensor: the model's last hidden states averaged over all tokens.
    """
    # langdetect is randomized by default; a fixed seed makes the same note
    # always select the same model (and therefore the same prediction).
    DetectorFactory.seed = 0
    try:
        is_english = detect(notes) == 'en'
    except LangDetectException:
        # Raised when the text has no detectable features (empty, digits or
        # punctuation only). Treat it as non-English, like every other
        # non-'en' result.
        is_english = False

    tokenizer, model = (en_tokenizer, en_model) if is_english else (g_tokenizer, g_model)

    # Truncate to the model's positional-embedding limit; longer notes would
    # otherwise fail inside the BERT forward pass.
    tokens = tokenizer(
        notes,
        return_tensors='pt',
        truncation=True,
        max_length=model.config.max_position_embeddings,
    )
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(**tokens)

    # last_hidden_state is (batch=1, tokens, hidden); average over tokens, drop batch.
    embeddings = outputs.last_hidden_state
    return torch.mean(embeddings, dim=1).squeeze(0)

def plot_ecg(path):
    """Render the WFDB record at *path* as a matplotlib figure.

    Parameters
    ----------
    path : str
        WFDB record name (path without the ``.hea``/``.dat`` extension).

    Returns
    -------
    matplotlib.figure.Figure
        The plotted figure, suitable for a ``gr.Plot`` output.
    """
    record = rdrecord(path)
    fig = plot_wfdb(record=record, title='ECG Signal Graph', figsize=(12, 10), return_fig=True)
    # wfdb builds the figure through pyplot, which keeps a global reference to
    # every open figure. Without closing, each request leaks one figure in this
    # long-running server. Closing only detaches it from pyplot; the Figure
    # object can still be saved/rendered by the caller.
    plt.close(fig)
    return fig

def infer(model, data, notes):
    """Classify one ECG record plus its clinical notes with the selected model.

    Parameters
    ----------
    model : str
        Architecture to use: ``"CNN"`` (MMCNN_CAT) or ``"RNN"`` (MMRNN).
    data : numpy.ndarray
        Signal batch as produced by ``preprocess``. Presumably
        (1, samples, channels) — TODO confirm against the model definitions.
    notes : str
        Clinical note text; embedded with ``embed``.

    Returns
    -------
    dict[str, float]
        Sigmoid probability per diagnostic class, rounded to two decimals.

    Raises
    ------
    ValueError
        If *model* is neither ``"CNN"`` nor ``"RNN"``.
    """
    # Validate first: an unknown name (e.g. None from an empty selection) used
    # to fall through and crash later on ``str.eval()``.
    if model not in ("CNN", "RNN"):
        raise ValueError(f"Unknown model {model!r}; expected 'CNN' or 'RNN'.")

    # Output index -> class name, in the order the models emit them.
    labels = (
        'Conduction Disturbance',
        'Hypertrophy',
        'Myocardial Infarction',
        'Normal ECG',
        'ST/T Change',
    )

    # Pure inference: no gradients needed for embedding or classification.
    with torch.no_grad():
        embed_notes = embed(notes).unsqueeze(0)
        signal = torch.tensor(data)
        if model == "CNN":
            net = MMCNN_CAT()
            checkpoint = torch.load(MMCNN_CAT_ckpt_path, map_location="cpu")
            # Swap the last two axes for the CNN (presumably channels-first input).
            signal = signal.transpose(1, 2).float()
        else:
            net = MMRNN(device='cpu')
            checkpoint = torch.load(MMRNN_ckpt_path, map_location="cpu")
            signal = signal.float()
        net.load_state_dict(checkpoint['model_state_dict'])
        net.eval()
        outputs, _predicted = predict(net, signal, embed_notes, device='cpu')
        probs = torch.sigmoid(outputs)[0]

    return {label: round(probs[idx].item(), 2) for idx, label in enumerate(labels)}

def run(model_name, header_file, data_file, notes):
    """Gradio callback: load an uploaded WFDB record, plot it and classify it.

    Parameters
    ----------
    model_name : str | None
        Radio selection, ``"CNN"`` or ``"RNN"``.
    header_file, data_file
        Uploaded ``.hea`` / ``.dat`` files. Accepted either as file objects
        exposing a ``.name`` path or as plain path strings.
    notes : str
        Clinical notes passed through to the model.

    Returns
    -------
    tuple
        ``(class probabilities dict, matplotlib figure)`` for the Label and
        Plot outputs.

    Raises
    ------
    gr.Error
        If no model is selected or either file is missing, so the UI shows a
        readable message instead of a traceback.
    """
    if not model_name:
        raise gr.Error("Please select a model.")
    if header_file is None or data_file is None:
        raise gr.Error("Please upload both the .hea header file and the .dat data file.")

    header_path = getattr(header_file, "name", header_file)
    data_path = getattr(data_file, "name", data_file)
    header_name = os.path.basename(header_path)

    # The header refers to its signal file by name, so both files must sit side
    # by side under their original basenames. A private temporary directory keeps
    # concurrent requests from overwriting each other (and never touches files
    # under demo_data). It is removed even if reading or plotting raises.
    with tempfile.TemporaryDirectory() as work_dir:
        shutil.copyfile(data_path, os.path.join(work_dir, os.path.basename(data_path)))
        shutil.copyfile(header_path, os.path.join(work_dir, header_name))
        # Record name = header file name minus its extension; dots inside the
        # stem are preserved (split('.')[0] used to cut them off).
        record_path = os.path.join(work_dir, os.path.splitext(header_name)[0])
        data = preprocess(record_path)
        ecg_graph = plot_ecg(record_path)

    output = infer(model_name, data, notes)
    return output, ecg_graph

with gr.Blocks() as demo:
    with gr.Row():
        # Preselect a model so "Predict Class" never sends an empty (None) choice.
        model = gr.Radio(['CNN', 'RNN'], value='CNN', label="Select Model")
    with gr.Row():
        with gr.Column(scale=1):
            header_file = gr.File(label="header_file", file_types=[".hea"])
            data_file = gr.File(label="data_file", file_types=[".dat"])
            notes = gr.Textbox(label="Clinical Notes")
        with gr.Column(scale=1):
            output_prob = gr.Label(
                {'Normal ECG': 0, 'Myocardial Infarction': 0, 'ST/T Change': 0, 'Conduction Disturbance': 0, 'Hypertrophy': 0},
                show_label=False,
            )
    with gr.Row():
        ecg_graph = gr.Plot(label="ECG Signal Visualisation")
    with gr.Row():
        predict_btn = gr.Button("Predict Class")
        predict_btn.click(fn=run, inputs=[model, header_file, data_file, notes], outputs=[output_prob, ecg_graph])
    with gr.Row():
        # Bundled sample records under demo_data/test: (record id, clinical note).
        example_dir = f"{CWD}/demo_data/test"
        example_records = [
            ("00001", "sinusrhythmus periphere niederspannung"),
            ("00008", "sinusrhythmus linkstyp qrs(t) abnormal    inferiorer infarkt     alter unbest."),
            ("00045", "sinusrhythmus unvollstÄndiger rechtsschenkelblock sonst normales ekg"),
            ("00257", "premature atrial contraction(s). sinus rhythm. left atrial enlargement. qs complexes in v2. st segments are slightly elevated in v2,3. st segments are depressed in i, avl. t waves are low or flat in i, v5,6 and inverted in avl. consistent with ischaemic h"),
        ]
        gr.Examples(
            examples=[
                [f"{example_dir}/{record_id}_lr.hea", f"{example_dir}/{record_id}_lr.dat", note]
                for record_id, note in example_records
            ],
            inputs=[header_file, data_file, notes],
        )

if __name__ == "__main__":
    demo.launch()