ECG_Classification / utils /RNN_utils.py
Tej3's picture
Committing App
71bd54f
raw
history blame contribute delete
No virus
7.65 kB
import torch
import matplotlib.pyplot as plt
from tqdm.autonotebook import tqdm
import pywt
import os
def display_eval(epoch, epochs, tlength, global_step, tcorrect, tsamples, t_valid_samples, average_train_loss, average_valid_loss, total_acc_val):
    """Write a one-line training/validation summary through ``tqdm.write``.

    Args:
        epoch: current epoch index; printed as ``epoch + 1``.
        epochs: total number of epochs.
        tlength: training-set length; the step denominator is ``epochs * tlength``.
        global_step: current value of the global step counter.
        tcorrect: correct training predictions so far this epoch.
        tsamples: training samples seen so far this epoch (accuracy denominator).
        t_valid_samples: validation-set size (validation accuracy denominator).
        average_train_loss: mean training loss to report.
        average_valid_loss: mean validation loss to report.
        total_acc_val: correct validation predictions.
    """
    # Implicit concatenation of separate f-strings keeps the message free of
    # source indentation. A backslash-newline inside one literal would copy the
    # next line's leading whitespace into the output.
    tqdm.write(
        f'Epoch: [{epoch + 1}/{epochs}], Step [{global_step}/{epochs*tlength}] | Train Loss: {average_train_loss: .3f} '
        f'| Train Accuracy: {tcorrect / tsamples: .3f} '
        f'| Val Loss: {average_valid_loss: .3f} '
        f'| Val Accuracy: {total_acc_val / t_valid_samples: .3f}')
def save_model(model, optimizer, valid_loss, epoch, path='model.pt'):
    """Serialize a training checkpoint to ``path`` with ``torch.save``.

    The saved dict contains:
        'valid_loss': the ``valid_loss`` argument,
        'model_state_dict': ``model.state_dict()``,
        'epoch': ``epoch + 1``,
        'optimizer': ``optimizer.state_dict()``.

    A confirmation line is then written through ``tqdm.write``.
    """
    checkpoint = {}
    checkpoint['valid_loss'] = valid_loss
    checkpoint['model_state_dict'] = model.state_dict()
    checkpoint['epoch'] = epoch + 1
    checkpoint['optimizer'] = optimizer.state_dict()
    torch.save(checkpoint, path)
    tqdm.write(f'Model saved to ==> {path}')
def save_metrics(train_loss_list, valid_loss_list, global_steps_list, path='metrics.pt'):
    """Persist recorded loss curves to ``path`` with ``torch.save``.

    The file holds a dict with keys 'train_loss_list', 'valid_loss_list' and
    'global_steps_list', mapped to the corresponding arguments unchanged.
    """
    metrics = dict(
        train_loss_list=train_loss_list,
        valid_loss_list=valid_loss_list,
        global_steps_list=global_steps_list,
    )
    torch.save(metrics, path)
def plot_losses(metrics_save_name='metrics', save_dir='./'):
    """Plot the train and validation loss curves stored in a metrics file.

    Loads ``{save_dir}metrics_{metrics_save_name}.pt`` with ``torch.load``.
    ``save_dir`` is prefixed verbatim, so it should end with a path separator.
    Both curves are drawn against the stored global steps, then
    ``plt.show()`` is called.
    """
    metrics = torch.load(f'{save_dir}metrics_{metrics_save_name}.pt')
    curves = {
        'Train': metrics['train_loss_list'],
        'Valid': metrics['valid_loss_list'],
    }
    steps = metrics['global_steps_list']
    for label, losses in curves.items():
        plt.plot(steps, losses, label=label)
    plt.xlabel('Global Steps')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()
def _wavelet_denoise(signals, device, wavelet='db4', level=3):
    """Wavelet-denoise a batch along axis 1; return it as float32 on ``device``.

    Every coefficient array from ``pywt.wavedec`` is hard-thresholded at
    0.1 x median(|coeffs[-1]|). Per PyWavelets' ordering, coeffs[-1] is the
    finest detail level. The batch is then rebuilt with ``pywt.waverec``.

    NOTE(review): ``pywt.waverec`` may return one extra sample along axis 1
    when the input length there is odd; the output is not trimmed.
    """
    coeffs = pywt.wavedec(signals, wavelet, level=level, axis=1)
    # torch.median (lower middle value for even counts) is kept on purpose.
    # np.median averages the two middle values and would change the threshold.
    threshold = 0.1 * torch.median(torch.abs(torch.from_numpy(coeffs[-1])))
    denoised = [pywt.threshold(data=c, mode='hard', value=threshold) for c in coeffs]
    rebuilt = pywt.waverec(denoised, wavelet, axis=1)
    return torch.tensor(rebuilt).float().to(device)


def _count_hits(output, labels):
    """Count rows whose arg-max output index lands on a position where ``labels`` == 1."""
    _, predicted = torch.max(output, dim=1)
    rows = torch.arange(predicted.shape[0], device=predicted.device)
    # One vectorized gather replaces a per-element Python loop over device tensors.
    return int((labels[rows, predicted] == 1).sum().item())


def _run_validation(model, valid_loader, loss_fn, device, wavelet, level):
    """Return (batch-size-weighted loss sum, correct count) over ``valid_loader``.

    Runs under ``torch.no_grad()`` with the model in eval mode, then puts the
    model back into train mode.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    with torch.no_grad():
        for images, labels, notes in valid_loader:
            images = _wavelet_denoise(images, device, wavelet, level)
            labels = labels.to(device)
            notes = notes.to(device)
            output = model(images, notes)
            loss_sum += loss_fn(output, labels.float()).item() * len(images)
            correct += _count_hits(output, labels)
    model.train()
    return loss_sum, correct


def train_RNN(epochs, train_loader, valid_loader, model, loss_fn, optimizer, eval_every=0.25, best_valid_loss=float("Inf"), device='cuda', model_save_name='', save_dir='./'):
    """Train ``model`` with periodic validation and best-checkpoint saving.

    Each batch is wavelet-denoised ('db4', level 3) before the forward pass.
    Validation runs each time the global sample counter crosses a multiple of
    ``eval_every * len(train_loader.dataset)``. Whenever the validation loss
    beats ``best_valid_loss``, the checkpoint and loss curves are written to
    ``{save_dir}model_{model_save_name}.pt`` and
    ``{save_dir}metrics_{model_save_name}.pt``. The metrics file is written
    once more when training ends.

    Args:
        epochs: number of passes over ``train_loader``.
        train_loader: yields ``(images, labels, notes)`` batches and exposes
            ``.dataset``. ``labels`` is indexed as ``labels[row, class]`` and
            compared to 1 (presumably one-hot; confirm).
        valid_loader: same batch format as ``train_loader``.
        model: called as ``model(images, notes)``; updated in place by ``optimizer``.
        loss_fn: called as ``loss_fn(output, labels.float())``. Its value is
            weighted by batch size, so it is assumed to be a per-batch mean.
        optimizer: zeroed and stepped once per training batch.
        eval_every: validation interval as a fraction of the training-set size.
        best_valid_loss: loss a checkpoint must beat to be saved.
        device: device that images, labels and notes are moved to.
        model_save_name: name fragment used in the checkpoint/metrics file names.
        save_dir: path prefix, used verbatim (needs a trailing separator).

    Returns:
        The trained ``model``, left in train mode.
    """
    wavelet = 'db4'
    level = 3
    # Clamp to >= 1 so the crossing test below can never divide by zero.
    eval_interval = max(1, int(eval_every * len(train_loader.dataset)))
    global_step = 0
    train_loss_list = []
    valid_loss_list = []
    global_steps_list = []
    model.train()
    for epoch in tqdm(range(epochs)):
        running_loss = 0.0
        t_correct = 0
        t_samples = 0
        for images, labels, notes in train_loader:
            optimizer.zero_grad()
            images = _wavelet_denoise(images, device, wavelet, level)
            labels = labels.to(device)
            notes = notes.to(device)
            output = model(images, notes)
            loss = loss_fn(output, labels.float())
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * len(labels)
            previous_step = global_step
            global_step += len(images)
            t_correct += _count_hits(output, labels)
            t_samples += len(labels)

            # Evaluate exactly when global_step crosses a multiple of
            # eval_interval. This stays correct for a short last batch and for
            # loaders whose batch_size is None.
            if global_step // eval_interval > previous_step // eval_interval:
                valid_running_loss, total_acc_val = _run_validation(
                    model, valid_loader, loss_fn, device, wavelet, level)
                average_train_loss = running_loss / t_samples
                average_valid_loss = valid_running_loss / len(valid_loader.dataset)
                train_loss_list.append(average_train_loss)
                valid_loss_list.append(average_valid_loss)
                global_steps_list.append(global_step)
                display_eval(epoch, epochs, len(train_loader.dataset), global_step, t_correct, t_samples,
                             len(valid_loader.dataset), average_train_loss, average_valid_loss, total_acc_val)
                if best_valid_loss > average_valid_loss:
                    best_valid_loss = average_valid_loss
                    save_model(model, optimizer, best_valid_loss, epoch,
                               path=f'{save_dir}model_{model_save_name}.pt')
                    save_metrics(train_loss_list, valid_loss_list, global_steps_list,
                                 path=f'{save_dir}metrics_{model_save_name}.pt')
    save_metrics(train_loss_list, valid_loss_list, global_steps_list,
                 path=f'{save_dir}metrics_{model_save_name}.pt')
    print("Training complete.")
    return model
def evaluate_RNN(model, test_loader, device="cuda"):
    """Compute, print and return the top-1 accuracy of ``model`` on ``test_loader``.

    Each batch is wavelet-denoised before the forward pass: 'db4' wavelet,
    level 3, hard threshold at 0.1 x median(|coeffs[-1]|). This matches the
    preprocessing used in training. A prediction counts as correct when
    ``labels[row, argmax(output[row])] == 1``.

    Args:
        model: called as ``model(images, notes)``; switched to eval mode.
        test_loader: yields ``(images, labels, notes)`` batches and exposes
            ``.dataset``.
        device: device that images, labels and notes are moved to.

    Returns:
        Correct predictions divided by ``len(test_loader.dataset)``.
    """
    model.eval()
    wavelet = 'db4'
    level = 3
    total_acc_test = 0
    with torch.no_grad():
        for images, labels, notes in test_loader:
            coeffs = pywt.wavedec(images, wavelet, level=level, axis=1)
            threshold = 0.1 * torch.median(torch.abs(torch.from_numpy(coeffs[-1])))
            denoised_coeffs = [pywt.threshold(data=c, mode='hard', value=threshold) for c in coeffs]
            images = pywt.waverec(denoised_coeffs, wavelet, axis=1)
            images = torch.tensor(images).float().to(device)
            labels = labels.to(device)
            notes = notes.to(device)
            output = model(images, notes)
            _, indices = torch.max(output, dim=1)
            # Vectorized hit count: no per-element Python loop and no
            # per-batch .tolist() device syncs. The old y_pred/y_true
            # lists were never used.
            rows = torch.arange(indices.shape[0], device=indices.device)
            total_acc_test += int((labels[rows, indices] == 1).sum().item())
    test_accuracy = total_acc_test / len(test_loader.dataset)
    print(f'Test Accuracy: {test_accuracy: .3f}')
    return test_accuracy
def rename_with_acc(save_name, save_dir, acc):
    """Append the rounded accuracy percentage to a saved model/metrics file pair.

    Renames ``{save_dir}model_{save_name}.pt`` and
    ``{save_dir}metrics_{save_name}.pt`` to ``..._acc_{N}.pt``, where
    ``N = round(acc * 100)``. Existing files with the target names are
    overwritten. ``save_dir`` is prefixed verbatim, so it should end with a
    path separator.

    Raises:
        FileNotFoundError: if a source file does not exist. An existing
            target is not removed in that case.
    """
    acc = round(acc * 100)
    for prefix in ('model', 'metrics'):
        old_name = f'{save_dir}{prefix}_{save_name}.pt'
        new_name = f'{save_dir}{prefix}_{save_name}_acc_{acc}.pt'
        # os.replace overwrites an existing destination on every platform,
        # unlike os.rename on Windows. No pre-delete is needed, so a missing
        # source no longer destroys an existing target.
        os.replace(old_name, new_name)