shuklaji9810 committed
Commit 5e014de
1 Parent(s): c329af9

first commit

.gitignore ADDED
@@ -0,0 +1,46 @@
+ working.ipynb
+ training.py
+
+ # Compiled source #
+ ###################
+ *.com
+ *.class
+ *.dll
+ *.exe
+ *.o
+ *.so
+
+ # Packages #
+ ############
+ # it's better to unpack these files and commit the raw source because
+ # git has its own built-in compression methods
+ *.7z
+ *.dmg
+ *.gz
+ *.iso
+ *.jar
+ *.rar
+ *.tar
+ *.zip
+
+ # Logs and databases #
+ ######################
+ *.log
+ *.sql
+ *.sqlite
+
+ # OS generated files #
+ ######################
+ .DS_Store
+ .DS_Store?
+ ._*
+ .Spotlight-V100
+ .Trashes
+ ehthumbs.db
+ Thumbs.db
+
+
+
+ __pycache__/
+ test_image.py
+ *.pyc
Examples/DeepFakes_10.png ADDED
Examples/DeepFakes_2.png ADDED
Examples/DeepFakes_4.png ADDED
Examples/DeepFakes_8.png ADDED
Examples/DeepFakes_9.png ADDED
Examples/SimSwap_8.png ADDED
Examples/StyleGAN_7.png ADDED
Examples/o_11.jpg ADDED
Examples/o_3.jpg ADDED
Examples/o_5.jpg ADDED
Examples/o_6.jpg ADDED
Examples/o_7.jpg ADDED
app.py ADDED
@@ -0,0 +1,206 @@
+ import gradio as gr
+ from PIL import Image
+ import numpy as np
+ import os
+ from face_cropper import detect_and_label_faces
+ # Define a custom function to convert an image to grayscale
+ def to_grayscale(input_image):
+     grayscale_image = Image.fromarray(np.array(input_image).mean(axis=-1).astype(np.uint8))
+     return grayscale_image
+
+
+ description_markdown = """
+ # Fake Face Detection tool from TrustWorthy BiometraVision Lab IISER Bhopal
+
+ ## Usage
+ This tool expects a face image as input. Upon submission, it will process the image and provide an output with bounding boxes drawn on the face. Alongside the visual markers, the tool will give a detection result indicating whether the face is fake or real.
+
+ ## Disclaimer
+ Please note that this tool is for research purposes only and may not always be 100% accurate. Users are advised to exercise discretion and supervise the tool's usage accordingly.
+
+ ## Licensing and Permissions
+ This tool has been developed solely for research and demonstrative purposes. Any commercial utilization of this tool is strictly prohibited unless explicit permission has been obtained from the developers.
+
+ ## Developer Contact
+ For further inquiries or permissions, you can reach out to the developer through the following social media accounts:
+ - [LAB Webpage](https://sites.google.com/iiitd.ac.in/agarwalakshay/labiiserb?authuser=0)
+ - [LinkedIn](https://www.linkedin.com/in/shivam-shukla-0a50ab1a2/)
+ - [GitHub](https://github.com/SaShukla090)
+ """
+
+
+
+
+ # Create the Gradio app
+ app = gr.Interface(
+     fn=detect_and_label_faces,
+     inputs=gr.Image(type="pil"),
+     outputs="image",
+     # examples=[
+     #     "path_to_example_image_1.jpg",
+     #     "path_to_example_image_2.jpg"
+     # ]
+     examples=[
+         os.path.join("Examples", image_name) for image_name in os.listdir("Examples")
+     ],
+     title="Fake Face Detection",
+     description=description_markdown,
+ )
+
+ # Run the app
+ app.launch()
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ # import torch.nn.functional as F
+ # import torch
+ # import torch.nn as nn
+ # import torch.optim as optim
+ # from torch.utils.data import DataLoader
+ # from sklearn.metrics import accuracy_score, precision_recall_fscore_support
+ # from torch.optim.lr_scheduler import CosineAnnealingLR
+ # from tqdm import tqdm
+ # import warnings
+ # warnings.filterwarnings("ignore")
+
+ # from utils.config import cfg
+ # from dataset.real_n_fake_dataloader import Extracted_Frames_Dataset
+ # from utils.data_transforms import get_transforms_train, get_transforms_val
+ # from net.Multimodalmodel import Image_n_DCT
+ # import gradio as gr
+
+
+
+
+ # import os
+ # import json
+ # import torch
+ # from torchvision import transforms
+ # from torch.utils.data import DataLoader, Dataset
+ # from PIL import Image
+ # import numpy as np
+ # import pandas as pd
+ # import cv2
+ # import argparse
+
+
+
+
+
+
+ # from sklearn.metrics import classification_report, confusion_matrix
+ # import matplotlib.pyplot as plt
+ # import seaborn as sns
+
+
+
+
+
+ # class Test_Dataset(Dataset):
+ #     def __init__(self, test_data_path = None, transform = None, image = None):
+ #         """
+ #         Args:
+ #         returns:
+ #         """
+
+ #         if test_data_path is None and image is not None:
+ #             self.dataset = [(image, 2)]
+ #             self.transform = transform
+
+ #     def __len__(self):
+ #         return len(self.dataset)
+
+ #     def __getitem__(self, idx):
+ #         sample_input = self.get_sample_input(idx)
+ #         return sample_input
+
+
+ #     def get_sample_input(self, idx):
+ #         rgb_image = self.get_rgb_image(self.dataset[idx][0])
+ #         dct_image = self.compute_dct_color(self.dataset[idx][0])
+ #         # label = self.get_label(idx)
+ #         sample_input = {"rgb_image": rgb_image, "dct_image": dct_image}
+
+ #         return sample_input
+
+
+ #     def get_rgb_image(self, rgb_image):
+ #         # rgb_image_path = self.dataset[idx][0]
+ #         # rgb_image = Image.open(rgb_image_path)
+ #         if self.transform:
+ #             rgb_image = self.transform(rgb_image)
+ #         return rgb_image
+
+ #     def get_dct_image(self, idx):
+ #         rgb_image_path = self.dataset[idx][0]
+ #         rgb_image = cv2.imread(rgb_image_path)
+ #         dct_image = self.compute_dct_color(rgb_image)
+ #         if self.transform:
+ #             dct_image = self.transform(dct_image)
+
+ #         return dct_image
+
+ #     def get_label(self, idx):
+ #         return self.dataset[idx][1]
+
+
+ #     def compute_dct_color(self, image):
+ #         image_float = np.float32(image)
+ #         dct_image = np.zeros_like(image_float)
+ #         for i in range(3):
+ #             dct_image[:, :, i] = cv2.dct(image_float[:, :, i])
+ #         if self.transform:
+ #             dct_image = self.transform(dct_image)
+ #         return dct_image
+
+
+ # device = torch.device("cpu")
+ # # print(device)
+ # model = Image_n_DCT()
+ # model.load_state_dict(torch.load('weights/best_model.pth', map_location = device))
+ # model.to(device)
+ # model.eval()
+
+
+ # def classify(image):
+ #     test_dataset = Test_Dataset(transform = get_transforms_val(), image = image)
+ #     inputs = test_dataset[0]
+ #     rgb_image, dct_image = inputs['rgb_image'].to(device), inputs['dct_image'].to(device)
+ #     output = model(rgb_image.unsqueeze(0), dct_image.unsqueeze(0))
+ #     # _, predicted = torch.max(output.data, 1)
+ #     # print(f"the face is {'real' if predicted==1 else 'fake'}")
+ #     return {'Fake': output[0][0], 'Real': output[0][1]}
+
+ # iface = gr.Interface(fn=classify, inputs="image", outputs="label")
+ # if __name__ == "__main__":
+ #     iface.launch()
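The same `detect_and_label_faces` function that backs the Gradio interface can be exercised directly. A minimal sketch, not part of the commit: it assumes the weights/ files above are present and reuses one of the bundled Examples images.

# Hypothetical local check; importing face_cropper loads the fusion models and MediaPipe.
from PIL import Image
from face_cropper import detect_and_label_faces

annotated = detect_and_label_faces(Image.open("Examples/o_3.jpg"))  # PIL image in, PIL image out
annotated.save("labelled_example.png")  # green boxes mark "real", red boxes mark "fake"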
dataset/real_n_fake_dataloader.py ADDED
@@ -0,0 +1,119 @@
+ # We will use this file to create a dataloader for the real and fake dataset
+ import os
+ import json
+ import torch
+ from torchvision import transforms
+ from torch.utils.data import DataLoader, Dataset
+ from PIL import Image
+ import numpy as np
+ import pandas as pd
+ import cv2
+
+ import cv2
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import pywt
+
+ class Extracted_Frames_Dataset(Dataset):
+     def __init__(self, root_dir, split = "train", transform = None, extend = 'None', multi_modal = "dct"):
+         """
+         Args: root_dir containing the split CSVs, split in (train, val, test), optional transform, extend ('faceswap' or 'fsgan'), multi_modal in (dct, fft, hh).
+         Returns: samples as dicts with "rgb_image", "dct_image" (the chosen second modality) and "label".
+         """
+         assert split in ["train", "val", "test"], "Split must be one of (train, val, test)"
+         self.multi_modal = multi_modal
+         self.root_dir = root_dir
+         self.split = split
+         self.transform = transform
+         if extend == 'faceswap':
+             self.dataset = pd.read_csv(os.path.join(root_dir, f"faceswap_extended_{self.split}.csv"))
+         elif extend == 'fsgan':
+             self.dataset = pd.read_csv(os.path.join(root_dir, f"fsgan_extended_{self.split}.csv"))
+         else:
+             self.dataset = pd.read_csv(os.path.join(root_dir, f"{self.split}.csv"))
+
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx):
+         sample_input = self.get_sample_input(idx)
+         return sample_input
+
+
+     def get_sample_input(self, idx):
+         rgb_image = self.get_rgb_image(idx)
+         label = self.get_label(idx)
+         if self.multi_modal == "dct":
+             dct_image = self.get_dct_image(idx)
+             sample_input = {"rgb_image": rgb_image, "dct_image": dct_image, "label": label}
+
+             # dct_image = self.get_dct_image(idx)
+         elif self.multi_modal == "fft":
+             fft_image = self.get_fft_image(idx)
+             sample_input = {"rgb_image": rgb_image, "dct_image": fft_image, "label": label}
+         elif self.multi_modal == "hh":
+             hh_image = self.get_hh_image(idx)
+             sample_input = {"rgb_image": rgb_image, "dct_image": hh_image, "label": label}
+         else:
+             raise ValueError("multi_modal must be one of (dct: discrete cosine transform, fft: fast Fourier transform, hh: Haar wavelet HH sub-band)")
+
+         return sample_input
+
+
+     def get_fft_image(self, idx):
+         gray_image_path = self.dataset.iloc[idx, 0]
+         gray_image = cv2.imread(gray_image_path, cv2.IMREAD_GRAYSCALE)
+         fft_image = self.compute_fft(gray_image)
+         if self.transform:
+             fft_image = self.transform(fft_image)
+
+         return fft_image
+
+
+     def compute_fft(self, image):
+         f = np.fft.fft2(image)
+         fshift = np.fft.fftshift(f)
+         magnitude_spectrum = 20 * np.log(np.abs(fshift) + 1)  # Add 1 to avoid log(0)
+         return magnitude_spectrum
+
+
+     def get_hh_image(self, idx):
+         gray_image_path = self.dataset.iloc[idx, 0]
+         gray_image = cv2.imread(gray_image_path, cv2.IMREAD_GRAYSCALE)
+         hh_image = self.compute_hh(gray_image)
+         if self.transform:
+             hh_image = self.transform(hh_image)
+         return hh_image
+
+     def compute_hh(self, image):
+         coeffs2 = pywt.dwt2(image, 'haar')
+         LL, (LH, HL, HH) = coeffs2
+         return HH
+
+     def get_rgb_image(self, idx):
+         rgb_image_path = self.dataset.iloc[idx, 0]
+         rgb_image = Image.open(rgb_image_path)
+         if self.transform:
+             rgb_image = self.transform(rgb_image)
+         return rgb_image
+
+     def get_dct_image(self, idx):
+         rgb_image_path = self.dataset.iloc[idx, 0]
+         rgb_image = cv2.imread(rgb_image_path)
+         dct_image = self.compute_dct_color(rgb_image)
+         if self.transform:
+             dct_image = self.transform(dct_image)
+
+         return dct_image
+
+     def get_label(self, idx):
+         return self.dataset.iloc[idx, 1]
+
+
+     def compute_dct_color(self, image):
+         image_float = np.float32(image)
+         dct_image = np.zeros_like(image_float)
+         for i in range(3):
+             dct_image[:, :, i] = cv2.dct(image_float[:, :, i])
+         return dct_image
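For reference, a hedged sketch of how this dataset would typically be wrapped in a DataLoader for training; the root_dir value and the presence of a train.csv listing image paths and labels are assumptions, not part of this commit.

# Sketch only: assumes <root_dir>/train.csv exists with an image-path column and a label column.
from torch.utils.data import DataLoader
from dataset.real_n_fake_dataloader import Extracted_Frames_Dataset
from utils.data_transforms import get_transforms_train
from utils.config import cfg

train_ds = Extracted_Frames_Dataset(root_dir="data", split="train",
                                    transform=get_transforms_train(), multi_modal="fft")
train_loader = DataLoader(train_ds, batch_size=cfg.BATCH_SIZE, shuffle=True,
                          num_workers=cfg.NUM_WORKERS)
batch = next(iter(train_loader))  # dict with "rgb_image", "dct_image" and "label" tensors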
face_cropper.py ADDED
@@ -0,0 +1,99 @@
+ import cv2
+ import mediapipe as mp
+ import os
+ from gradio_client import Client
+ from test_image_fusion import Test
+ from test_image_fusion import Test
+ import numpy as np
+
+
+
+ from PIL import Image
+ import numpy as np
+ import cv2
+
+ # client = Client("https://tbvl-real-and-fake-face-detection.hf.space/--replicas/40d41jxhhx/")
+
+ data = 'faceswap'
+ dct = 'fft'
+
+
+ testet = Test(model_paths = [f"weights/{data}-hh-best_model.pth",
+                              f"weights/{data}-fft-best_model.pth"],
+               multi_modal = ['hh', 'fft'])
+
+ # Initialize MediaPipe Face Detection
+ mp_face_detection = mp.solutions.face_detection
+ mp_drawing = mp.solutions.drawing_utils
+ face_detection = mp_face_detection.FaceDetection(model_selection=1, min_detection_confidence=0.35)
+
+ # Create a directory to save the cropped face images if it does not exist
+ save_dir = "cropped_faces"
+ os.makedirs(save_dir, exist_ok=True)
+
+ # def detect_and_label_faces(image_path):
+
+
+ # Function to crop faces from a video and save them as images
+ # def crop_faces_from_video(video_path):
+ #     # Read the video
+ #     cap = cv2.VideoCapture(video_path)
+ #     frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+ #     frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ #     fps = int(cap.get(cv2.CAP_PROP_FPS))
+ #     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+ #     # Define the codec and create VideoWriter object
+ #     out = cv2.VideoWriter(f'output_{real}_{data}_fusion.avi', cv2.VideoWriter_fourcc('M','J','P','G'), fps, (frame_width, frame_height))
+
+ #     if not cap.isOpened():
+ #         print("Error: Could not open video.")
+ #         return
+ # Convert PIL Image to NumPy array for OpenCV
+ def pil_to_opencv(pil_image):
+     open_cv_image = np.array(pil_image)
+     # Convert RGB to BGR for OpenCV
+     open_cv_image = open_cv_image[:, :, ::-1].copy()
+     return open_cv_image
+
+ # Convert OpenCV NumPy array to PIL Image
+ def opencv_to_pil(opencv_image):
+     # Convert BGR to RGB
+     pil_image = Image.fromarray(opencv_image[:, :, ::-1])
+     return pil_image
+
+
+
+
+ def detect_and_label_faces(frame):
+     frame = pil_to_opencv(frame)
+
+
+     print(type(frame))
+     # Convert the frame to RGB
+     frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+     # Perform face detection
+     results = face_detection.process(frame_rgb)
+
+     # If faces are detected, crop and save each face as an image
+     if results.detections:
+         for face_count, detection in enumerate(results.detections):
+             bboxC = detection.location_data.relative_bounding_box
+             ih, iw, _ = frame.shape
+             x, y, w, h = int(bboxC.xmin * iw), int(bboxC.ymin * ih), int(bboxC.width * iw), int(bboxC.height * ih)
+             # Crop the face region and make sure the bounding box is within the frame dimensions
+             crop_img = frame[max(0, y):min(ih, y+h), max(0, x):min(iw, x+w)]
+             if crop_img.size > 0:
+                 face_filename = os.path.join(save_dir, f'face_{face_count}.jpg')
+                 cv2.imwrite(face_filename, crop_img)
+
+                 label = testet.testimage(face_filename)
+
+                 if os.path.exists(face_filename):
+                     os.remove(face_filename)
+
+                 color = (0, 0, 255) if label == 'fake' else (0, 255, 0)
+                 cv2.rectangle(frame, (x, y), (x + w, y + h), color, 2)
+                 cv2.putText(frame, label, (x, y + 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
+     return opencv_to_pil(frame)
+
net/Multimodalmodel.py ADDED
@@ -0,0 +1,41 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from utils.config import cfg
+ from utils.basicblocks import BasicBlock
+ from utils.feature_fusion_block import DCT_Attention_Fusion_Conv
+ from utils.classifier import ClassifierModel
+
+ class Image_n_DCT(nn.Module):
+     def __init__(self):
+         super(Image_n_DCT, self).__init__()
+         self.Img_Block = nn.ModuleList()
+         self.DCT_Block = nn.ModuleList()
+         self.RGB_n_DCT_Fusion = nn.ModuleList()
+         self.num_classes = len(cfg.CLASSES)
+
+
+
+         for i in range(len(cfg.MULTIMODAL_FUSION.IMG_CHANNELS) - 1):
+             self.Img_Block.append(BasicBlock(cfg.MULTIMODAL_FUSION.IMG_CHANNELS[i], cfg.MULTIMODAL_FUSION.IMG_CHANNELS[i+1], stride=1))
+             self.DCT_Block.append(BasicBlock(cfg.MULTIMODAL_FUSION.DCT_CHANNELS[i], cfg.MULTIMODAL_FUSION.IMG_CHANNELS[i+1], stride=1))
+             self.RGB_n_DCT_Fusion.append(DCT_Attention_Fusion_Conv(cfg.MULTIMODAL_FUSION.IMG_CHANNELS[i+1]))
+
+
+         self.classifier = ClassifierModel(self.num_classes)
+
+
+
+     def forward(self, rgb_image, dct_image):
+         image = [rgb_image]
+         dct_image = [dct_image]
+
+         for i in range(len(self.Img_Block)):
+             image.append(self.Img_Block[i](image[-1]))
+             dct_image.append(self.DCT_Block[i](dct_image[-1]))
+             image[-1] = self.RGB_n_DCT_Fusion[i](image[-1], dct_image[-1])
+             dct_image[-1] = image[-1]
+         out = self.classifier(image[-1])
+
+         return out
+
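A quick shape check, as a sketch: with the default config (five channel entries, hence four BasicBlocks, each halving the spatial resolution), a 224x224 input reaches the classifier as a 512x14x14 feature map and comes out as two class scores.

import torch
from net.Multimodalmodel import Image_n_DCT

model = Image_n_DCT().eval()
rgb = torch.randn(1, 3, 224, 224)  # RGB branch input
aux = torch.randn(1, 1, 224, 224)  # single-channel DCT/FFT/HH branch input
with torch.no_grad():
    out = model(rgb, aux)
print(out.shape)  # expected: torch.Size([1, 2])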
test_image_fusion.py ADDED
@@ -0,0 +1,182 @@
+ import torch.nn.functional as F
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import DataLoader
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support
+ from torch.optim.lr_scheduler import CosineAnnealingLR
+ from tqdm import tqdm
+ import warnings
+ warnings.filterwarnings("ignore")
+ import cv2
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import pywt
+
+ from utils.config import cfg
+ from dataset.real_n_fake_dataloader import Extracted_Frames_Dataset
+ from utils.data_transforms import get_transforms_train, get_transforms_val
+ from net.Multimodalmodel import Image_n_DCT
+
+
+
+ import os
+ import json
+ import torch
+ from torchvision import transforms
+ from torch.utils.data import DataLoader, Dataset
+ from PIL import Image
+ import numpy as np
+ import pandas as pd
+ import cv2
+ import argparse
+
+ class Test_Dataset(Dataset):
+     def __init__(self, test_data_path = None, transform = None, image_path = None, multi_modal = "dct"):
+         """
+         Args: either test_data_path with real/ and fake/ sub-folders, or a single image_path; optional transform; multi_modal in (dct, fft, hh).
+         Returns: samples as dicts with "rgb_image", "dct_image" and "label" (label 2 marks an unlabelled single image).
+         """
+         self.multi_modal = multi_modal
+         if test_data_path is None and image_path is not None:
+             self.dataset = [[image_path, 2]]
+             self.transform = transform
+
+         else:
+             self.transform = transform
+
+             self.real_data = os.listdir(test_data_path + "/real")
+             self.fake_data = os.listdir(test_data_path + "/fake")
+             self.dataset = []
+             for image in self.real_data:
+                 self.dataset.append([test_data_path + "/real/" + image, 1])
+
+             for image in self.fake_data:
+                 self.dataset.append([test_data_path + "/fake/" + image, 0])
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx):
+         sample_input = self.get_sample_input(idx)
+         return sample_input
+
+     def get_sample_input(self, idx):
+         rgb_image = self.get_rgb_image(idx)
+         label = self.get_label(idx)
+         if self.multi_modal == "dct":
+             dct_image = self.get_dct_image(idx)
+             sample_input = {"rgb_image": rgb_image, "dct_image": dct_image, "label": label}
+
+             # dct_image = self.get_dct_image(idx)
+         elif self.multi_modal == "fft":
+             fft_image = self.get_fft_image(idx)
+             sample_input = {"rgb_image": rgb_image, "dct_image": fft_image, "label": label}
+         elif self.multi_modal == "hh":
+             hh_image = self.get_hh_image(idx)
+             sample_input = {"rgb_image": rgb_image, "dct_image": hh_image, "label": label}
+         else:
+             raise ValueError("multi_modal must be one of (dct: discrete cosine transform, fft: fast Fourier transform, hh: Haar wavelet HH sub-band)")
+
+         return sample_input
+
+
+     def get_fft_image(self, idx):
+         gray_image_path = self.dataset[idx][0]
+         gray_image = cv2.imread(gray_image_path, cv2.IMREAD_GRAYSCALE)
+         fft_image = self.compute_fft(gray_image)
+         if self.transform:
+             fft_image = self.transform(fft_image)
+
+         return fft_image
+
+
+     def compute_fft(self, image):
+         f = np.fft.fft2(image)
+         fshift = np.fft.fftshift(f)
+         magnitude_spectrum = 20 * np.log(np.abs(fshift) + 1)  # Add 1 to avoid log(0)
+         return magnitude_spectrum
+
+
+     def get_hh_image(self, idx):
+         gray_image_path = self.dataset[idx][0]
+         gray_image = cv2.imread(gray_image_path, cv2.IMREAD_GRAYSCALE)
+         hh_image = self.compute_hh(gray_image)
+         if self.transform:
+             hh_image = self.transform(hh_image)
+         return hh_image
+
+     def compute_hh(self, image):
+         coeffs2 = pywt.dwt2(image, 'haar')
+         LL, (LH, HL, HH) = coeffs2
+         return HH
+
+     def get_rgb_image(self, idx):
+         rgb_image_path = self.dataset[idx][0]
+         rgb_image = Image.open(rgb_image_path)
+         if self.transform:
+             rgb_image = self.transform(rgb_image)
+         return rgb_image
+
+     def get_dct_image(self, idx):
+         rgb_image_path = self.dataset[idx][0]
+         rgb_image = cv2.imread(rgb_image_path)
+         dct_image = self.compute_dct_color(rgb_image)
+         if self.transform:
+             dct_image = self.transform(dct_image)
+
+         return dct_image
+
+     def get_label(self, idx):
+         return self.dataset[idx][1]
+
+
+     def compute_dct_color(self, image):
+         image_float = np.float32(image)
+         dct_image = np.zeros_like(image_float)
+         for i in range(3):
+             dct_image[:, :, i] = cv2.dct(image_float[:, :, i])
+         return dct_image
+
+
+ class Test:
+     def __init__(self, model_paths = ['weights/faceswap-hh-best_model.pth',
+                                       'weights/faceswap-fft-best_model.pth',
+                                       ],
+                  multi_modal = ["hh", "fft"]):
+         self.model_path = model_paths
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         print(self.device)
+         # Load the two single-modality models
+         self.model1 = Image_n_DCT()
+         self.model1.load_state_dict(torch.load(self.model_path[0], map_location = self.device))
+         self.model1.to(self.device)
+         self.model1.eval()
+
+         self.model2 = Image_n_DCT()
+         self.model2.load_state_dict(torch.load(self.model_path[1], map_location = self.device))
+         self.model2.to(self.device)
+         self.model2.eval()
+
+
+         self.multi_modal = multi_modal
+
+
+     def testimage(self, image_path):
+         test_dataset1 = Test_Dataset(transform = get_transforms_val(), image_path = image_path, multi_modal = self.multi_modal[0])
+         test_dataset2 = Test_Dataset(transform = get_transforms_val(), image_path = image_path, multi_modal = self.multi_modal[1])
+
+         inputs1 = test_dataset1[0]
+         rgb_image1, dct_image1 = inputs1['rgb_image'].to(self.device), inputs1['dct_image'].to(self.device)
+
+         inputs2 = test_dataset2[0]
+         rgb_image2, dct_image2 = inputs2['rgb_image'].to(self.device), inputs2['dct_image'].to(self.device)
+
+         output1 = self.model1(rgb_image1.unsqueeze(0), dct_image1.unsqueeze(0))
+
+         output2 = self.model2(rgb_image2.unsqueeze(0), dct_image2.unsqueeze(0))
+
+         output = (output1 + output2) / 2
+         # print(output.shape)
+         _, predicted = torch.max(output.data, 1)
+         return 'real' if predicted == 1 else 'fake'
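A short usage sketch mirroring face_cropper.py; the weight files are the ones shipped in this commit, while the cropped-face path below is only an assumed example and is not created by this file.

from test_image_fusion import Test

tester = Test(model_paths=["weights/faceswap-hh-best_model.pth",
                           "weights/faceswap-fft-best_model.pth"],
              multi_modal=["hh", "fft"])
print(tester.testimage("cropped_faces/face_0.jpg"))  # -> 'real' or 'fake'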
utils/__init__.py ADDED
@@ -0,0 +1 @@
+ import os
utils/basicblocks.py ADDED
@@ -0,0 +1,32 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ BatchNorm2d = nn.BatchNorm2d
+
+ def conv3x3(in_planes, out_planes, stride = 1):
+     """3x3 convolution with padding"""
+     return nn.Conv2d(in_planes, out_planes, kernel_size = 3, stride = stride,
+                      padding = 1, bias = False)
+
+ def conv1x1(in_planes, out_planes, stride = 1):
+     """1x1 convolution"""
+     return nn.Conv2d(in_planes, out_planes, kernel_size = 1, stride = stride,
+                      padding = 0, bias = False)
+
+ class BasicBlock(nn.Module):
+     def __init__(self, inplanes, outplanes, stride = 1):
+         super(BasicBlock, self).__init__()
+         self.conv1 = conv3x3(inplanes, outplanes, stride)
+         self.bn1 = BatchNorm2d(outplanes)
+         self.relu = nn.ReLU(inplace = True)
+         self.conv2 = conv3x3(outplanes, outplanes, 2*stride)  # second conv downsamples (doubled stride)
+
+     def forward(self, x):
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+         out = self.conv2(out)
+
+         return out
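Since conv2 runs at 2*stride, the default stride=1 block halves the spatial resolution while mapping inplanes to outplanes. A sanity-check sketch:

import torch
from utils.basicblocks import BasicBlock

block = BasicBlock(inplanes=3, outplanes=64)  # stride defaults to 1
x = torch.randn(1, 3, 224, 224)
print(block(x).shape)  # expected: torch.Size([1, 64, 112, 112])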
utils/classifier.py ADDED
@@ -0,0 +1,32 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class ClassifierModel(nn.Module):
+     def __init__(self, num_classes):
+         super(ClassifierModel, self).__init__()
+         # Apply adaptive average pooling to convert (512, 14, 14) to (512, 1, 1)
+         self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))
+
+         # Define multiple fully connected layers
+         self.fc1 = nn.Linear(512, 256)          # First FC layer, reducing to 256 features
+         self.fc2 = nn.Linear(256, 128)          # Second FC layer, reducing to 128 features
+         self.fc3 = nn.Linear(128, num_classes)  # Final FC layer, outputting num_classes for classification
+
+         # Dropout for regularization
+         self.dropout = nn.Dropout(0.2)
+
+     def forward(self, x):
+         # Pool spatially, then flatten to a 512-dimensional vector
+         x = self.adaptive_pool(x)
+         x = torch.flatten(x, 1)
+
+         # Pass through the fully connected layers with ReLU activations and dropout
+         x = F.relu(self.fc1(x))
+         x = self.dropout(x)
+         x = F.relu(self.fc2(x))
+         x = self.dropout(x)
+         x = self.fc3(x)          # Raw class scores
+         x = F.softmax(x, dim=1)  # Converted to class probabilities
+
+         return x
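Because of the adaptive pooling, the head accepts any spatial size as long as it sees 512 channels, and the final softmax makes the output a probability distribution over the classes. A small check, as a sketch:

import torch
from utils.classifier import ClassifierModel

head = ClassifierModel(num_classes=2).eval()
feats = torch.randn(1, 512, 14, 14)  # the shape produced by the backbone
probs = head(feats)
print(probs.shape, float(probs.sum()))  # torch.Size([1, 2]), sums to ~1.0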
utils/config.py ADDED
@@ -0,0 +1,38 @@
+ from easydict import EasyDict as edict
+ import numpy as np
+
+ __C = edict()
+ cfg = __C
+
+ # 0. basic config
+ __C.TAG = 'default'
+ __C.CLASSES = ['Real', 'Fake']
+
+
+ # config of network input
+ __C.MULTIMODAL_FUSION = edict()
+ __C.MULTIMODAL_FUSION.IMG_CHANNELS = [3, 64, 128, 256, 512]
+ __C.MULTIMODAL_FUSION.DCT_CHANNELS = [1, 64, 128, 256, 512]
+
+
+ __C.NUM_EPOCHS = 100
+
+ __C.BATCH_SIZE = 64
+
+ __C.NUM_WORKERS = 4
+
+ __C.LEARNING_RATE = 0.0001
+
+ __C.PRETRAINED = False
+
+ __C.PRETRAINED_PATH = "/home/user/Documents/Real_and_DeepFake/src/best_model.pth"
+
+
+
+
+ __C.TEST_BATCH_SIZE = 512
+
+ __C.TEST_CSV = "/home/user/Documents/Real_and_DeepFake/src/dataset/extended_val.csv"
+
+ __C.MODEL_PATH = "/home/user/Documents/Real_and_DeepFake/src/best_model.pth"
+
utils/data_transforms.py ADDED
@@ -0,0 +1,33 @@
+ from torchvision import transforms
+
+
+
+ def get_transforms_train():
+     # Define the training transforms
+     transform = transforms.Compose([
+         transforms.ToTensor(),
+         transforms.Lambda(lambda x: x.float()),
+         transforms.Resize((224, 224)),
+         transforms.RandomHorizontalFlip(),
+         transforms.RandomRotation(10),
+         transforms.Normalize(mean=[(0.485 + 0.456 + 0.406) / 3], std=[(0.229 + 0.224 + 0.225) / 3]),
+     ])
+
+     return transform
+
+
+
+
+ def get_transforms_val():
+     transform = transforms.Compose([
+         transforms.ToTensor(),
+         transforms.Lambda(lambda x: x.float()),
+         transforms.Resize((224, 224)),
+         # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+         transforms.Normalize(mean=[(0.485 + 0.456 + 0.406) / 3], std=[(0.229 + 0.224 + 0.225) / 3]),
+
+
+     ])
+
+
+     return transform
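A sketch of what these transforms produce: ToTensor plus the single-element Normalize (which broadcasts the averaged ImageNet statistics over however many channels arrive) lets the same pipeline handle both the 3-channel RGB input and the single-channel frequency inputs. The shapes below are the expected outcome, not something asserted by the commit:

import numpy as np
from PIL import Image
from utils.data_transforms import get_transforms_val

tf = get_transforms_val()
rgb = tf(Image.new("RGB", (300, 200)))                  # -> tensor of shape (3, 224, 224)
freq = tf(np.random.rand(300, 200).astype(np.float32))  # 2-D array -> (1, 224, 224)
print(rgb.shape, freq.shape)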
utils/feature_fusion_block.py ADDED
@@ -0,0 +1,46 @@
+
+ from torch import nn
+ from torch.nn import functional as F
+
+ class SpatialAttention(nn.Module):
+     def __init__(self, in_channels):
+         super(SpatialAttention, self).__init__()
+         self.conv1 = nn.Conv2d(in_channels, 1, kernel_size=1, stride=1, padding=0)
+
+     def forward(self, x):
+         # Calculate attention scores
+         attention_scores = self.conv1(x)
+         attention_scores = F.softmax(attention_scores, dim=2)
+
+         # Apply attention to input features
+         attended_features = x * attention_scores
+
+         return attended_features
+
+ class DCT_Attention_Fusion_Conv(nn.Module):
+     def __init__(self, channels):
+         super(DCT_Attention_Fusion_Conv, self).__init__()
+         self.rgb_attention = SpatialAttention(channels)
+         self.depth_attention = SpatialAttention(channels)
+         self.rgb_pooling = nn.AdaptiveAvgPool2d(1)
+         self.depth_pooling = nn.AdaptiveAvgPool2d(1)
+
+     def forward(self, rgb_features, DCT_features):
+         # Spatial attention for both modalities
+         rgb_attended_features = self.rgb_attention(rgb_features)
+         depth_attended_features = self.depth_attention(DCT_features)
+
+         # Adaptive pooling for both modalities
+         rgb_pooled = self.rgb_pooling(rgb_attended_features)
+         depth_pooled = self.depth_pooling(depth_attended_features)
+
+         # Upsample attended and pooled features to the original size
+         rgb_upsampled = F.interpolate(rgb_pooled, size=rgb_features.size()[2:], mode='bilinear', align_corners=False)
+         depth_upsampled = F.interpolate(depth_pooled, size=DCT_features.size()[2:], mode='bilinear', align_corners=False)
+
+         # Fuse the upsampled features by element-wise addition followed by ReLU
+         fused_features = F.relu(rgb_upsampled + depth_upsampled)
+         # fused_features = fused_features.sum(dim=1)
+
+         return fused_features
+
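Shape-wise the fusion block is transparent: both inputs and the fused output share the same (B, C, H, W) layout, which is what lets Image_n_DCT feed the fused map straight into the next RGB block. A quick check (sketch):

import torch
from utils.feature_fusion_block import DCT_Attention_Fusion_Conv

fusion = DCT_Attention_Fusion_Conv(channels=64)
rgb_feat = torch.randn(1, 64, 56, 56)
dct_feat = torch.randn(1, 64, 56, 56)
print(fusion(rgb_feat, dct_feat).shape)  # expected: torch.Size([1, 64, 56, 56])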
weights/faceswap-fft-best_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9c42f82049bed6db4edb5e933ffe4ce6e3612e7fbf351c29327d9cfe81f8c5ff
+ size 38189260
weights/faceswap-hh-best_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:15272d1439ef629566cf43b3d4d1bc4f2091f3db1c0d0430038b56880c7ef385
+ size 38189178