shuklaji9810 commited on
Commit
a3a6cb5
1 Parent(s): 7bf182f

addition of test file

Browse files
Files changed (2) hide show
  1. .gitignore +0 -1
  2. test_image.py +162 -0
.gitignore CHANGED
@@ -42,5 +42,4 @@ Thumbs.db
42
 
43
 
44
  ___pycache__/
45
- test_image.py
46
  *.pyc
 
42
 
43
 
44
  ___pycache__/
 
45
  *.pyc
test_image.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn.functional as F
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ from torch.utils.data import DataLoader
6
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support
7
+ from torch.optim.lr_scheduler import CosineAnnealingLR
8
+ from tqdm import tqdm
9
+ import warnings
10
+ warnings.filterwarnings("ignore")
11
+ import cv2
12
+ import numpy as np
13
+ import matplotlib.pyplot as plt
14
+ import pywt
15
+
16
+ from utils.config import cfg
17
+ from dataset.real_n_fake_dataloader import Extracted_Frames_Dataset
18
+ from utils.data_transforms import get_transforms_train, get_transforms_val
19
+ from net.Multimodalmodel import Image_n_DCT
20
+
21
+
22
+
23
+ import os
24
+ import json
25
+ import torch
26
+ from torchvision import transforms
27
+ from torch.utils.data import DataLoader, Dataset
28
+ from PIL import Image
29
+ import numpy as np
30
+ import pandas as pd
31
+ import cv2
32
+ import argparse
33
+
34
+ class Test_Dataset(Dataset):
35
+ def __init__(self, test_data_path = None, transform = None, image_path = None, multi_modal = "dct"):
36
+ """
37
+ Args:
38
+ returns:
39
+ """
40
+ self.multi_modal = multi_modal
41
+ if test_data_path is None and image_path is not None:
42
+ self.dataset = [[image_path, 2]]
43
+ self.transform = transform
44
+
45
+ else:
46
+ self.transform = transform
47
+
48
+ self.real_data = os.listdir(test_data_path + "/real")
49
+ self.fake_data = os.listdir(test_data_path + "/fake")
50
+ self.dataset = []
51
+ for image in self.real_data:
52
+ self.dataset.append([test_data_path + "/real/" + image, 1])
53
+
54
+ for image in self.fake_data:
55
+ self.dataset.append([test_data_path + "/fake/" + image, 0])
56
+
57
+ def __len__(self):
58
+ return len(self.dataset)
59
+
60
+ def __getitem__(self, idx):
61
+ sample_input = self.get_sample_input(idx)
62
+ return sample_input
63
+
64
+ def get_sample_input(self, idx):
65
+ rgb_image = self.get_rgb_image(idx)
66
+ label = self.get_label(idx)
67
+ if self.multi_modal == "dct":
68
+ dct_image = self.get_dct_image(idx)
69
+ sample_input = {"rgb_image": rgb_image, "dct_image": dct_image, "label": label}
70
+
71
+ # dct_image = self.get_dct_image(idx)
72
+ elif self.multi_modal == "fft":
73
+ fft_image = self.get_fft_image(idx)
74
+ sample_input = {"rgb_image": rgb_image, "dct_image": fft_image, "label": label}
75
+ elif self.multi_modal == "hh":
76
+ hh_image = self.get_hh_image(idx)
77
+ sample_input = {"rgb_image": rgb_image, "dct_image": hh_image, "label": label}
78
+ else:
79
+ AssertionError("multi_modal must be one of (dct:discrete cosine transform, fft: fast forier transform, hh)")
80
+
81
+ return sample_input
82
+
83
+
84
+ def get_fft_image(self, idx):
85
+ gray_image_path = self.dataset[idx][0]
86
+ gray_image = cv2.imread(gray_image_path, cv2.IMREAD_GRAYSCALE)
87
+ fft_image = self.compute_fft(gray_image)
88
+ if self.transform:
89
+ fft_image = self.transform(fft_image)
90
+
91
+ return fft_image
92
+
93
+
94
+ def compute_fft(self, image):
95
+ f = np.fft.fft2(image)
96
+ fshift = np.fft.fftshift(f)
97
+ magnitude_spectrum = 20 * np.log(np.abs(fshift) + 1) # Add 1 to avoid log(0)
98
+ return magnitude_spectrum
99
+
100
+
101
+ def get_hh_image(self, idx):
102
+ gray_image_path = self.dataset[idx][0]
103
+ gray_image = cv2.imread(gray_image_path, cv2.IMREAD_GRAYSCALE)
104
+ hh_image = self.compute_hh(gray_image)
105
+ if self.transform:
106
+ hh_image = self.transform(hh_image)
107
+ return hh_image
108
+
109
+ def compute_hh(self, image):
110
+ coeffs2 = pywt.dwt2(image, 'haar')
111
+ LL, (LH, HL, HH) = coeffs2
112
+ return HH
113
+
114
+ def get_rgb_image(self, idx):
115
+ rgb_image_path = self.dataset[idx][0]
116
+ rgb_image = Image.open(rgb_image_path)
117
+ if self.transform:
118
+ rgb_image = self.transform(rgb_image)
119
+ return rgb_image
120
+
121
+ def get_dct_image(self, idx):
122
+ rgb_image_path = self.dataset[idx][0]
123
+ rgb_image = cv2.imread(rgb_image_path)
124
+ dct_image = self.compute_dct_color(rgb_image)
125
+ if self.transform:
126
+ dct_image = self.transform(dct_image)
127
+
128
+ return dct_image
129
+
130
+ def get_label(self, idx):
131
+ return self.dataset[idx][1]
132
+
133
+
134
+ def compute_dct_color(self, image):
135
+ image_float = np.float32(image)
136
+ dct_image = np.zeros_like(image_float)
137
+ for i in range(3):
138
+ dct_image[:, :, i] = cv2.dct(image_float[:, :, i])
139
+ return dct_image
140
+
141
+
142
+ class Test:
143
+ def __init__(self, model_path, multi_modal = "dct"):
144
+ self.model_path = model_path
145
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
146
+ print(self.device)
147
+ # Load the model
148
+ self.model = Image_n_DCT()
149
+ self.model.load_state_dict(torch.load(self.model_path, map_location = self.device))
150
+ self.model.to(self.device)
151
+ self.model.eval()
152
+ self.multi_modal = multi_modal
153
+
154
+
155
+ def testimage(self, image_path):
156
+ test_dataset = Test_Dataset(transform = get_transforms_val(), image_path = image_path, multi_modal = self.multi_modal)
157
+ inputs = test_dataset[0]
158
+ rgb_image, dct_image = inputs['rgb_image'].to(self.device), inputs['dct_image'].to(self.device)
159
+ output = self.model(rgb_image.unsqueeze(0), dct_image.unsqueeze(0))
160
+ # print(output.shape)
161
+ _, predicted = torch.max(output.data, 1)
162
+ return 'real' if predicted==1 else 'fake'