"""Training, evaluation and bookkeeping helpers for a two-input classifier.

Every batch yielded by the data loaders is expected to be a triple
``(images, labels, notes)``.  ``images`` is wavelet-denoised along axis 1 on
the CPU before being moved to the target device, and the model is called as
``model(images, notes)``.
"""
import os

import matplotlib.pyplot as plt
import pywt
import torch
from tqdm.autonotebook import tqdm

# Wavelet-denoising settings shared by training, validation and testing.
_WAVELET = 'db4'
_WAVELET_LEVEL = 3
_THRESHOLD_SCALE = 0.1


def _wavelet_denoise(batch, wavelet=_WAVELET, level=_WAVELET_LEVEL,
                     threshold_scale=_THRESHOLD_SCALE):
    """Hard-threshold the wavelet coefficients of ``batch`` along axis 1.

    The threshold is ``threshold_scale`` times the median absolute value of
    the last coefficient array returned by ``pywt.wavedec`` and is applied to
    every coefficient array (including the first one).

    Args:
        batch: CPU tensor or array accepted by ``pywt.wavedec``.
            NOTE(review): assumed to be at least 2-D with the signal on
            axis 1 -- confirm against the dataset.
        wavelet: Wavelet name passed to PyWavelets.
        level: Decomposition level.
        threshold_scale: Multiplier applied to the median absolute
            coefficient to obtain the threshold.

    Returns:
        A float32 CPU tensor holding the reconstructed signal.
    """
    coeffs = pywt.wavedec(batch, wavelet, level=level, axis=1)
    # Hand PyWavelets a plain Python float rather than a 0-dim torch tensor.
    threshold = (
        threshold_scale
        * torch.median(torch.abs(torch.from_numpy(coeffs[-1])))
    ).item()
    denoised_coeffs = [
        pywt.threshold(data=c, mode='hard', value=threshold) for c in coeffs
    ]
    reconstructed = pywt.waverec(denoised_coeffs, wavelet, axis=1)
    return torch.tensor(reconstructed).float()


def _count_correct(output, labels):
    """Count samples whose highest-scoring class is marked 1 in ``labels``.

    ``labels`` is indexed as ``labels[sample][class]``, i.e. it is treated as
    a (batch, classes) indicator matrix.
    """
    _, predicted = torch.max(output, dim=1)
    # Vectorised equivalent of checking labels[s][predicted[s]] == 1 per row.
    hits = labels.gather(1, predicted.unsqueeze(1)).eq(1)
    return int(hits.sum().item())


def _evaluate_loader(model, loader, device, loss_fn=None):
    """Run ``model`` over ``loader`` without gradients.

    The caller is responsible for switching the model to eval mode.

    Args:
        model: Callable invoked as ``model(images, notes)``.
        loader: Iterable of ``(images, labels, notes)`` batches.
        device: Device the batches are moved to.
        loss_fn: Optional loss; when given, it is called as
            ``loss_fn(output, labels.float())``.

    Returns:
        ``(loss_sum, correct)`` where ``loss_sum`` is the sum of per-batch
        losses weighted by batch size (0.0 when ``loss_fn`` is None) and
        ``correct`` is the number of correctly classified samples.
    """
    loss_sum = 0.0
    correct = 0
    with torch.no_grad():
        for images, labels, notes in loader:
            images = _wavelet_denoise(images).to(device)
            labels = labels.to(device)
            notes = notes.to(device)
            output = model(images, notes)
            if loss_fn is not None:
                batch_loss = loss_fn(output, labels.float()).item()
                loss_sum += batch_loss * len(images)
            correct += _count_correct(output, labels)
    return loss_sum, correct


def display_eval(epoch, epochs, tlength, global_step, tcorrect, tsamples,
                 t_valid_samples, average_train_loss, average_valid_loss,
                 total_acc_val):
    """Write a one-line training/validation summary through ``tqdm.write``.

    Args:
        epoch: Zero-based index of the current epoch.
        epochs: Total number of epochs.
        tlength: Number of training samples per epoch; ``epochs * tlength``
            is shown as the total step count.
        global_step: Number of training samples processed so far.
        tcorrect: Correct training predictions accumulated this epoch.
        tsamples: Training samples seen this epoch.
        t_valid_samples: Number of validation samples.
        average_train_loss: Mean training loss to report.
        average_valid_loss: Mean validation loss to report.
        total_acc_val: Number of correct validation predictions.
    """
    # Adjacent f-strings keep the message on one line without embedding the
    # source indentation that backslash continuations would introduce.
    tqdm.write(
        f'Epoch: [{epoch + 1}/{epochs}], '
        f'Step [{global_step}/{epochs * tlength}] '
        f'| Train Loss: {average_train_loss: .3f} '
        f'| Train Accuracy: {tcorrect / tsamples: .3f} '
        f'| Val Loss: {average_valid_loss: .3f} '
        f'| Val Accuracy: {total_acc_val / t_valid_samples: .3f}'
    )


def save_model(model, optimizer, valid_loss, epoch, path='model.pt'):
    """Save a checkpoint with model and optimizer state to ``path``.

    The stored ``'epoch'`` value is ``epoch + 1`` (one-based).
    """
    checkpoint = {
        'valid_loss': valid_loss,
        'model_state_dict': model.state_dict(),
        'epoch': epoch + 1,
        'optimizer': optimizer.state_dict(),
    }
    torch.save(checkpoint, path)
    tqdm.write(f'Model saved to ==> {path}')


def save_metrics(train_loss_list, valid_loss_list, global_steps_list,
                 path='metrics.pt'):
    """Save the loss histories and their matching step counts to ``path``."""
    metrics = {
        'train_loss_list': train_loss_list,
        'valid_loss_list': valid_loss_list,
        'global_steps_list': global_steps_list,
    }
    torch.save(metrics, path)


def plot_losses(metrics_save_name='metrics', save_dir='./'):
    """Plot the train/validation loss curves stored by ``save_metrics``.

    Loads ``{save_dir}metrics_{metrics_save_name}.pt``; ``save_dir`` is used
    as a plain string prefix, so it must end with a path separator.
    """
    path = f'{save_dir}metrics_{metrics_save_name}.pt'
    state = torch.load(path)
    steps = state['global_steps_list']

    plt.plot(steps, state['train_loss_list'], label='Train')
    plt.plot(steps, state['valid_loss_list'], label='Valid')
    plt.xlabel('Global Steps')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()


def train_RNN(epochs, train_loader, valid_loader, model, loss_fn, optimizer,
              eval_every=0.25, best_valid_loss=float("Inf"), device='cuda',
              model_save_name='', save_dir='./'):
    """Train ``model`` with periodic validation and best-checkpoint saving.

    Progress is counted in samples: ``global_step`` grows by the batch size
    after every batch, and validation runs whenever ``global_step`` falls
    within one batch of a multiple of ``eval_every * len(train dataset)``.
    When the validation loss improves, the model checkpoint and the metrics
    are written to ``{save_dir}model_{model_save_name}.pt`` and
    ``{save_dir}metrics_{model_save_name}.pt``; the metrics are written once
    more when training finishes.

    Args:
        epochs: Number of passes over ``train_loader``.
        train_loader: Loader yielding ``(images, labels, notes)``; its
            ``dataset`` and ``batch_size`` attributes are read.
        valid_loader: Validation loader with the same batch layout.
        model: Module called as ``model(images, notes)``.
        loss_fn: Loss called as ``loss_fn(output, labels.float())``.
        optimizer: Optimizer updating ``model``.
        eval_every: Fraction of the training set between validations.
        best_valid_loss: Validation loss a new checkpoint has to beat.
        device: Device the batches are moved to (the model is not moved).
        model_save_name: Suffix used in the checkpoint/metrics file names.
        save_dir: String prefix for the output files; must end with a path
            separator because it is concatenated, not joined.

    Returns:
        The model in its state after the final epoch (not necessarily the
        best checkpoint that was saved).
    """
    model.train()
    global_step = 0
    train_loss_list = []
    valid_loss_list = []
    global_steps_list = []

    model_path = f'{save_dir}model_{model_save_name}.pt'
    metrics_path = f'{save_dir}metrics_{model_save_name}.pt'
    train_size = len(train_loader.dataset)
    valid_size = len(valid_loader.dataset)
    # Clamp to 1 so a tiny dataset or tiny eval_every cannot cause a modulo
    # by zero below.
    eval_interval = max(1, int(eval_every * train_size))

    for epoch in tqdm(range(epochs)):
        running_loss = 0.0
        t_correct = 0
        t_samples = 0
        for images, labels, notes in train_loader:
            optimizer.zero_grad()
            images = _wavelet_denoise(images).to(device)
            labels = labels.to(device)
            notes = notes.to(device)

            output = model(images, notes)
            loss = loss_fn(output, labels.float())
            running_loss += loss.item() * len(labels)
            loss.backward()
            global_step += len(images)
            optimizer.step()

            t_correct += _count_correct(output, labels)
            t_samples += len(output)

            # Batches rarely land exactly on a multiple of the interval, so
            # trigger on the first batch that crosses it.
            if global_step % eval_interval >= train_loader.batch_size:
                continue

            model.eval()
            valid_running_loss, total_acc_val = _evaluate_loader(
                model, valid_loader, device, loss_fn=loss_fn)

            average_train_loss = running_loss / t_samples
            average_valid_loss = valid_running_loss / valid_size
            train_loss_list.append(average_train_loss)
            valid_loss_list.append(average_valid_loss)
            global_steps_list.append(global_step)

            display_eval(epoch, epochs, train_size, global_step, t_correct,
                         t_samples, valid_size, average_train_loss,
                         average_valid_loss, total_acc_val)

            # Back to training mode before the next batch.
            model.train()

            if best_valid_loss > average_valid_loss:
                best_valid_loss = average_valid_loss
                save_model(model, optimizer, best_valid_loss, epoch,
                           path=model_path)
                save_metrics(train_loss_list, valid_loss_list,
                             global_steps_list, path=metrics_path)

    save_metrics(train_loss_list, valid_loss_list, global_steps_list,
                 path=metrics_path)
    print("Training complete.")
    return model


def evaluate_RNN(model, test_loader, device="cuda"):
    """Compute and print the accuracy of ``model`` on ``test_loader``.

    Args:
        model: Module called as ``model(images, notes)``.
        test_loader: Loader yielding ``(images, labels, notes)``; its
            ``dataset`` length is used as the denominator.
        device: Device the batches are moved to.

    Returns:
        The fraction of test samples whose top-scoring class is marked 1 in
        the labels.
    """
    model.eval()
    _, total_acc_test = _evaluate_loader(model, test_loader, device)
    test_accuracy = total_acc_test / len(test_loader.dataset)
    print(f'Test Accuracy: {test_accuracy: .3f}')
    return test_accuracy


def rename_with_acc(save_name, save_dir, acc):
    """Append the rounded accuracy percentage to the saved file names.

    Renames ``{save_dir}model_{save_name}.pt`` and
    ``{save_dir}metrics_{save_name}.pt`` to ``..._acc_{percent}.pt``,
    overwriting any existing files with those names.

    Args:
        save_name: Name suffix used when the files were saved.
        save_dir: String prefix of the files; must end with a path separator.
        acc: Accuracy as a fraction; ``round(acc * 100)`` is used in the name.

    Raises:
        OSError: If a source file is missing or cannot be renamed.
    """
    acc = round(acc * 100)
    for kind in ('model', 'metrics'):
        source = f'{save_dir}{kind}_{save_name}.pt'
        target = f'{save_dir}{kind}_{save_name}_acc_{acc}.pt'
        # os.replace overwrites an existing target in one step, so an old
        # file is never deleted unless its replacement actually exists.
        os.replace(source, target)