File size: 5,552 Bytes
a3a6cb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import torch.nn.functional as F
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
import cv2
import numpy as np
import matplotlib.pyplot as plt
import pywt

from utils.config import cfg
from dataset.real_n_fake_dataloader import Extracted_Frames_Dataset
from utils.data_transforms import get_transforms_train, get_transforms_val
from net.Multimodalmodel import Image_n_DCT



import os
import json
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import numpy as np
import pandas as pd
import cv2
import argparse

class Test_Dataset(Dataset):
    """Evaluation dataset pairing each RGB image with a frequency-domain view.

    Two construction modes are supported:

    * Directory mode (``test_data_path`` given): every file in
      ``<test_data_path>/real`` is labelled ``1`` and every file in
      ``<test_data_path>/fake`` is labelled ``0``.
    * Single-image mode (``test_data_path`` is ``None`` and ``image_path``
      given): the dataset holds only that image, with the placeholder
      label ``2`` (ground truth unknown).

    Each item is a dict with keys ``"rgb_image"``, ``"dct_image"`` and
    ``"label"``. Despite its key name, ``"dct_image"`` holds whichever
    auxiliary representation ``multi_modal`` selects:

    * ``"dct"``: per-channel 2-D DCT of the OpenCV-loaded colour image.
    * ``"fft"``: centred log-magnitude FFT spectrum of the grayscale image.
    * ``"hh"``: HH sub-band of a single-level Haar wavelet transform of the
      grayscale image.

    The same ``transform`` (if any) is applied to the PIL RGB image and to
    the numpy auxiliary representation.
    """

    # Accepted values for ``multi_modal``.
    _MODALITIES = ("dct", "fft", "hh")

    def __init__(self, test_data_path=None, transform=None, image_path=None, multi_modal="dct"):
        """
        Args:
            test_data_path: directory containing ``real/`` and ``fake/``
                sub-directories. Takes precedence over ``image_path``.
            transform: optional callable applied to both the RGB image and
                the auxiliary representation.
            image_path: path of a single image; used only when
                ``test_data_path`` is ``None``.
            multi_modal: one of ``"dct"``, ``"fft"`` or ``"hh"``.

        Raises:
            ValueError: if ``multi_modal`` is unsupported, or if neither
                ``test_data_path`` nor ``image_path`` is given.
        """
        self._check_modality(multi_modal)
        self.multi_modal = multi_modal
        self.transform = transform

        if test_data_path is None:
            if image_path is None:
                raise ValueError("Either test_data_path or image_path must be provided")
            # Label 2 is a placeholder: the ground truth of a lone image is unknown.
            self.dataset = [[image_path, 2]]
            return

        real_dir = os.path.join(test_data_path, "real")
        fake_dir = os.path.join(test_data_path, "fake")
        # os.listdir returns entries in arbitrary order; sort so the sample
        # order (and therefore any evaluation output) is reproducible.
        self.real_data = sorted(os.listdir(real_dir))
        self.fake_data = sorted(os.listdir(fake_dir))
        self.dataset = [[os.path.join(real_dir, name), 1] for name in self.real_data]
        self.dataset += [[os.path.join(fake_dir, name), 0] for name in self.fake_data]

    @classmethod
    def _check_modality(cls, multi_modal):
        """Raise ValueError unless ``multi_modal`` is a supported modality."""
        if multi_modal not in cls._MODALITIES:
            raise ValueError(
                "multi_modal must be one of (dct: discrete cosine transform, "
                f"fft: fast Fourier transform, hh: Haar wavelet HH sub-band); got {multi_modal!r}"
            )

    @staticmethod
    def _read_cv2_image(path, grayscale=False):
        """Load ``path`` with OpenCV, raising instead of returning ``None`` on failure."""
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE) if grayscale else cv2.imread(path)
        if image is None:
            # cv2.imread signals a missing/unreadable file by returning None,
            # which would otherwise surface later as an obscure numpy/cv2 error.
            raise FileNotFoundError(f"Could not read image: {path}")
        return image

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.get_sample_input(idx)

    def get_sample_input(self, idx):
        """Build the sample dict for index ``idx``.

        Returns:
            dict with ``"rgb_image"``, ``"dct_image"`` (the auxiliary
            representation chosen by ``multi_modal``) and ``"label"``.

        Raises:
            ValueError: if ``self.multi_modal`` is not a supported modality.
        """
        # Validate before any image I/O so a bad mode fails with a clear error
        # (``multi_modal`` is a public attribute and may be reassigned).
        self._check_modality(self.multi_modal)
        rgb_image = self.get_rgb_image(idx)
        label = self.get_label(idx)
        if self.multi_modal == "dct":
            aux_image = self.get_dct_image(idx)
        elif self.multi_modal == "fft":
            aux_image = self.get_fft_image(idx)
        else:  # "hh" -- the only remaining valid modality
            aux_image = self.get_hh_image(idx)
        # The key stays "dct_image" for every modality because consumers
        # (e.g. Test.testimage) read the auxiliary input under that name.
        return {"rgb_image": rgb_image, "dct_image": aux_image, "label": label}

    def get_fft_image(self, idx):
        """Return the log-magnitude FFT of image ``idx`` read as grayscale, transformed if set."""
        gray_image = self._read_cv2_image(self.dataset[idx][0], grayscale=True)
        fft_image = self.compute_fft(gray_image)
        if self.transform:
            fft_image = self.transform(fft_image)
        return fft_image

    def compute_fft(self, image):
        """Return the centred log-magnitude spectrum ``20 * ln(|FFT2(image)| + 1)``.

        Uses the natural logarithm, so the values are not decibels in the
        strict (log10) sense. The output has the same shape as ``image``.
        """
        spectrum = np.fft.fftshift(np.fft.fft2(image))
        # +1 keeps the log finite where the magnitude is zero.
        return 20 * np.log(np.abs(spectrum) + 1)

    def get_hh_image(self, idx):
        """Return the Haar HH sub-band of image ``idx`` read as grayscale, transformed if set."""
        gray_image = self._read_cv2_image(self.dataset[idx][0], grayscale=True)
        hh_image = self.compute_hh(gray_image)
        if self.transform:
            hh_image = self.transform(hh_image)
        return hh_image

    def compute_hh(self, image):
        """Return the HH (diagonal detail) coefficients of a single-level 2-D Haar DWT."""
        _, (_, _, hh) = pywt.dwt2(image, 'haar')
        return hh

    def get_rgb_image(self, idx):
        """Return image ``idx`` as a 3-channel RGB PIL image, transformed if set."""
        with Image.open(self.dataset[idx][0]) as img:
            # convert() loads the pixel data (so the file can be closed here)
            # and normalises grayscale/RGBA/palette files to 3-channel RGB.
            rgb_image = img.convert("RGB")
        if self.transform:
            rgb_image = self.transform(rgb_image)
        return rgb_image

    def get_dct_image(self, idx):
        """Return the per-channel DCT of image ``idx`` as loaded by OpenCV, transformed if set."""
        # cv2.imread yields channels in BGR order, unlike the PIL RGB branch.
        # NOTE(review): presumably this matches how the model was trained -- confirm
        # before changing the channel order.
        color_image = self._read_cv2_image(self.dataset[idx][0])
        dct_image = self.compute_dct_color(color_image)
        if self.transform:
            dct_image = self.transform(dct_image)
        return dct_image

    def get_label(self, idx):
        """Return the label of sample ``idx`` (1 real, 0 fake, 2 unknown)."""
        return self.dataset[idx][1]

    def compute_dct_color(self, image):
        """Return the 2-D DCT of each of the first three channels of ``image`` as float32.

        ``image`` is indexed as (H, W, C) with at least three channels; any
        channels beyond the third are left as zeros in the result.
        """
        image_float = np.float32(image)
        dct_image = np.zeros_like(image_float)
        for channel in range(3):
            # NOTE(review): OpenCV documents cv::dct as supporting even-sized
            # arrays; odd-sized images may need padding -- confirm for the installed version.
            dct_image[:, :, channel] = cv2.dct(image_float[:, :, channel])
        return dct_image

    
class Test:
    """Run a trained ``Image_n_DCT`` model on single images.

    Args:
        model_path: path of a saved ``state_dict`` for ``Image_n_DCT``.
        multi_modal: auxiliary representation passed to ``Test_Dataset``
            for the model's second input (``"dct"``, ``"fft"`` or ``"hh"``).
    """

    def __init__(self, model_path, multi_modal="dct"):
        self.model_path = model_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(self.device)
        self.model = Image_n_DCT()
        # SECURITY NOTE: torch.load unpickles the checkpoint, which can execute
        # arbitrary code. Only load files from trusted sources; consider
        # weights_only=True on torch versions that support it.
        state_dict = torch.load(self.model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        # Put the module in inference mode (affects e.g. dropout/batch-norm if present).
        self.model.eval()
        self.multi_modal = multi_modal

    def testimage(self, image_path):
        """Classify a single image file.

        Args:
            image_path: path of the image to classify.

        Returns:
            ``'real'`` if the predicted class index is 1, otherwise ``'fake'``
            (matching the 1 = real / 0 = fake labels used by ``Test_Dataset``).
        """
        test_dataset = Test_Dataset(
            transform=get_transforms_val(),
            image_path=image_path,
            multi_modal=self.multi_modal,
        )
        inputs = test_dataset[0]
        rgb_image = inputs['rgb_image'].to(self.device)
        dct_image = inputs['dct_image'].to(self.device)
        # Inference only: skip building the autograd graph to save memory/time.
        with torch.no_grad():
            # unsqueeze(0) adds a batch dimension of size 1.
            output = self.model(rgb_image.unsqueeze(0), dct_image.unsqueeze(0))
        # Class scores are read along dim 1; with a batch of one this yields
        # a single class index.
        predicted = output.argmax(dim=1).item()
        return 'real' if predicted == 1 else 'fake'