# Dataset/dataloader utilities for the real and fake (deepfake) frame dataset.
import os

import cv2
import numpy as np
import pandas as pd
import pywt
from PIL import Image
from torch.utils.data import Dataset

class Extracted_Frames_Dataset(Dataset):
    def __init__(self, root_dir, split="train", transform=None, extend="None", multi_modal="dct"):
        """
        Args:
            root_dir (str): directory containing the split CSV files.
            split (str): one of "train", "val", or "test".
            transform (callable, optional): transform applied to every image.
            extend (str): which extended CSV to load ("faceswap", "fsgan", or "None").
            multi_modal (str): frequency-domain modality paired with the RGB
                image; one of "dct", "fft", or "hh".
        """
        assert split in ["train", "val", "test"], "Split must be one of (train, val, test)"
        self.multi_modal = multi_modal
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        # Each CSV row holds an image path (column 0) and a label (column 1).
        if extend == "faceswap":
            self.dataset = pd.read_csv(os.path.join(root_dir, f"faceswap_extended_{self.split}.csv"))
        elif extend == "fsgan":
            self.dataset = pd.read_csv(os.path.join(root_dir, f"fsgan_extended_{self.split}.csv"))
        else:
            self.dataset = pd.read_csv(os.path.join(root_dir, f"{self.split}.csv"))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample_input = self.get_sample_input(idx)
        return sample_input
    

    def get_sample_input(self, idx):
        rgb_image = self.get_rgb_image(idx)
        label = self.get_label(idx)
        # The frequency-domain image is stored under the "dct_image" key for
        # every modality, so downstream code can stay modality-agnostic.
        if self.multi_modal == "dct":
            dct_image = self.get_dct_image(idx)
            sample_input = {"rgb_image": rgb_image, "dct_image": dct_image, "label": label}
        elif self.multi_modal == "fft":
            fft_image = self.get_fft_image(idx)
            sample_input = {"rgb_image": rgb_image, "dct_image": fft_image, "label": label}
        elif self.multi_modal == "hh":
            hh_image = self.get_hh_image(idx)
            sample_input = {"rgb_image": rgb_image, "dct_image": hh_image, "label": label}
        else:
            raise ValueError("multi_modal must be one of (dct: discrete cosine transform, fft: fast Fourier transform, hh: Haar wavelet HH subband)")

        return sample_input

    
    def get_fft_image(self, idx):
        gray_image_path = self.dataset.iloc[idx, 0]
        gray_image = cv2.imread(gray_image_path, cv2.IMREAD_GRAYSCALE)
        fft_image = self.compute_fft(gray_image)
        if self.transform:
            fft_image = self.transform(fft_image)
        return fft_image

    def compute_fft(self, image):
        # Shift the zero-frequency component to the centre and return the
        # log-magnitude spectrum, which compresses the dynamic range.
        f = np.fft.fft2(image)
        fshift = np.fft.fftshift(f)
        magnitude_spectrum = 20 * np.log(np.abs(fshift) + 1)  # +1 avoids log(0)
        return magnitude_spectrum


    def get_hh_image(self, idx):
        gray_image_path = self.dataset.iloc[idx, 0]
        gray_image = cv2.imread(gray_image_path, cv2.IMREAD_GRAYSCALE)
        hh_image = self.compute_hh(gray_image)
        if self.transform:
            hh_image = self.transform(hh_image)
        return hh_image

    def compute_hh(self, image):
        # Single-level 2-D Haar wavelet decomposition; keep only the HH
        # (diagonal detail) subband, which is half the input resolution.
        coeffs2 = pywt.dwt2(image, 'haar')
        LL, (LH, HL, HH) = coeffs2
        return HH
        
    def get_rgb_image(self, idx):
        rgb_image_path = self.dataset.iloc[idx, 0]
        # Convert to RGB so grayscale or RGBA frames yield a consistent 3 channels.
        rgb_image = Image.open(rgb_image_path).convert("RGB")
        if self.transform:
            rgb_image = self.transform(rgb_image)
        return rgb_image
    
    def get_dct_image(self, idx):
        rgb_image_path = self.dataset.iloc[idx, 0]
        rgb_image = cv2.imread(rgb_image_path)
        dct_image = self.compute_dct_color(rgb_image)
        if self.transform:
            dct_image = self.transform(dct_image)
        return dct_image

    def get_label(self, idx):
        return self.dataset.iloc[idx, 1]

    def compute_dct_color(self, image):
        # Apply the 2-D DCT to each colour channel independently.
        # cv2.dct requires float32 input (and even-sized arrays).
        image_float = np.float32(image)
        dct_image = np.zeros_like(image_float)
        for i in range(3):
            dct_image[:, :, i] = cv2.dct(image_float[:, :, i])
        return dct_image
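

# A minimal usage sketch, not part of the dataset class itself. It assumes a
# hypothetical root directory "data/extracted_frames" containing the split
# CSVs in the layout the accessors above expect (image path in column 0,
# integer label in column 1), and that all frames share one fixed size so the
# default collate can batch them. Note the ndarray-based modalities
# (dct/fft/hh) bypass PIL, so the transform must accept NumPy arrays as well;
# transforms.ToTensor handles both PIL images and arrays.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from torchvision import transforms

    dataset = Extracted_Frames_Dataset(
        root_dir="data/extracted_frames",  # hypothetical path
        split="train",
        transform=transforms.ToTensor(),
        multi_modal="dct",
    )
    loader = DataLoader(dataset, batch_size=32, shuffle=True)

    batch = next(iter(loader))
    print(batch["rgb_image"].shape, batch["dct_image"].shape, batch["label"].shape)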