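# NOTE(review): the following lines are residue copied from a web file viewer,
# not part of the program; kept as comments so the module remains importable.
# mbar0075's picture
# Testing Commit
# c9baa67
# raw
# history blame
# No virus
# 3.78 kB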
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import model_zoo
from .features.densenet import RGBDenseNet201
from .modules import FeatureExtractor, Finalizer, DeepGazeIIIMixture
from .layers import FlexibleScanpathHistoryEncoding
from .layers import (
Conv2dMultiInput,
LayerNorm,
LayerNormMultiInput,
Bias,
)
def build_saliency_network(input_channels):
    """Build one saliency readout head.

    The head is three identical stages of
    ``LayerNorm -> 1x1 Conv2d (no bias) -> Bias -> Softplus`` that map the
    channel count ``input_channels -> 8 -> 16 -> 1``.

    :param input_channels: number of channels of the incoming feature map.
    :return: ``nn.Sequential`` with submodules named ``layernorm{i}``,
        ``conv{i}``, ``bias{i}`` and ``softplus{i}`` for ``i`` in 0, 1, 2.
        NOTE(review): these names become state-dict keys, so pretrained
        checkpoints presumably depend on them staying exactly as they are.
    """
    channel_plan = [input_channels, 8, 16, 1]
    named_layers = []
    # Walk consecutive (in, out) channel pairs; each pair yields one stage.
    for stage, (c_in, c_out) in enumerate(zip(channel_plan[:-1], channel_plan[1:])):
        named_layers.extend([
            (f'layernorm{stage}', LayerNorm(c_in)),
            (f'conv{stage}', nn.Conv2d(c_in, c_out, (1, 1), bias=False)),
            (f'bias{stage}', Bias(c_out)),
            (f'softplus{stage}', nn.Softplus()),
        ])
    return nn.Sequential(OrderedDict(named_layers))
def build_scanpath_network():
    """Build one scanpath (fixation-history) head.

    The head starts with a ``FlexibleScanpathHistoryEncoding`` configured with
    ``in_fixations=4``, ``channels_per_fixation=3``, ``out_channels=128``, a
    1x1 kernel and a bias, followed by Softplus. It then applies
    ``LayerNorm -> 1x1 Conv2d (128 -> 16, no bias) -> Bias -> Softplus``.

    NOTE(review): ``in_fixations=4`` presumably has to agree with the four
    entries of ``included_fixations`` used by ``DeepGazeIII``. The meaning of
    the 3 channels per fixation is defined by
    ``FlexibleScanpathHistoryEncoding``; confirm there.

    :return: ``nn.Sequential`` with submodules ``encoding0``, ``softplus0``,
        ``layernorm1``, ``conv1``, ``bias1`` and ``softplus1``.
    """
    scanpath_head = nn.Sequential()
    scanpath_head.add_module('encoding0', FlexibleScanpathHistoryEncoding(
        in_fixations=4,
        channels_per_fixation=3,
        out_channels=128,
        kernel_size=[1, 1],
        bias=True,
    ))
    scanpath_head.add_module('softplus0', nn.Softplus())
    scanpath_head.add_module('layernorm1', LayerNorm(128))
    scanpath_head.add_module('conv1', nn.Conv2d(128, 16, (1, 1), bias=False))
    scanpath_head.add_module('bias1', Bias(16))
    scanpath_head.add_module('softplus1', nn.Softplus())
    return scanpath_head
def build_fixation_selection_network():
    """Build one fixation-selection head.

    The head fuses two inputs with 1 and 16 channels through
    ``LayerNormMultiInput`` and ``Conv2dMultiInput`` (to 128 channels).
    It then applies a 128 -> 16 hidden stage and ends in a bare
    1x1 conv to a single channel. The last conv has no Bias and no Softplus
    after it.

    NOTE(review): the input widths ``[1, 16]`` presumably correspond to the
    1-channel output of ``build_saliency_network`` and the 16-channel output
    of ``build_scanpath_network``; confirm against ``DeepGazeIIIMixture``.

    :return: ``nn.Sequential`` with submodules ``layernorm0``, ``conv0``,
        ``bias0``, ``softplus0``, ``layernorm1``, ``conv1``, ``bias1``,
        ``softplus1`` and ``conv2``.
    """
    # Stage 0: normalise and merge the two input maps into 128 channels.
    fuse_inputs = [
        ('layernorm0', LayerNormMultiInput([1, 16])),
        ('conv0', Conv2dMultiInput([1, 16], 128, (1, 1), bias=False)),
        ('bias0', Bias(128)),
        ('softplus0', nn.Softplus()),
    ]
    # Stage 1: hidden 128 -> 16 projection.
    hidden = [
        ('layernorm1', LayerNorm(128)),
        ('conv1', nn.Conv2d(128, 16, (1, 1), bias=False)),
        ('bias1', Bias(16)),
        ('softplus1', nn.Softplus()),
    ]
    # Stage 2: unconstrained single-channel readout.
    readout = [
        ('conv2', nn.Conv2d(16, 1, (1, 1), bias=False)),
    ]
    return nn.Sequential(OrderedDict(fuse_inputs + hidden + readout))
class DeepGazeIII(DeepGazeIIIMixture):
    """DeepGaze III model.

    Builds ``NUM_MIXTURE_COMPONENTS`` independent sets of heads on top of one
    shared ``RGBDenseNet201`` feature extractor. Each set holds a saliency,
    a scanpath, a fixation-selection and a finalizer head. All of them are
    passed to ``DeepGazeIIIMixture``.

    .. note::
        See Kümmerer, M., Bethge, M., & Wallis, T.S.A. (2022). DeepGaze III:
        Modeling free-viewing human scanpaths with deep learning. Journal of
        Vision 2022, https://doi.org/10.1167/jov.22.5.7

    :param pretrained: when true, download the state dict from
        ``PRETRAINED_WEIGHTS_URL``, mapped onto the CPU, and load it into the
        freshly built model with ``load_state_dict``.
    """

    # Number of mixture components built in __init__.
    # NOTE(review): the released checkpoint presumably contains exactly this
    # many heads, so overriding it will likely make load_state_dict fail
    # when pretrained=True.
    NUM_MIXTURE_COMPONENTS = 10

    # Location of the released pretrained weights (overridable in subclasses).
    PRETRAINED_WEIGHTS_URL = 'https://github.com/matthias-k/DeepGaze/releases/download/v1.1.0/deepgaze3.pth'

    def __init__(self, pretrained=True):
        features = RGBDenseNet201()
        # Intermediate DenseNet-201 activations used as readout features.
        # NOTE(review): their channel counts presumably add up to the 2048
        # input channels passed to build_saliency_network below; confirm.
        feature_extractor = FeatureExtractor(features, [
            '1.features.denseblock4.denselayer32.norm1',
            '1.features.denseblock4.denselayer32.conv1',
            '1.features.denseblock4.denselayer31.conv2',
        ])

        saliency_networks = []
        scanpath_networks = []
        fixation_selection_networks = []
        finalizers = []
        # Heads are built component by component (interleaved), not list by
        # list. Keep this order so random parameter initialisation stays
        # reproducible when pretrained=False.
        for _ in range(self.NUM_MIXTURE_COMPONENTS):
            saliency_networks.append(build_saliency_network(2048))
            scanpath_networks.append(build_scanpath_network())
            fixation_selection_networks.append(build_fixation_selection_network())
            finalizers.append(Finalizer(sigma=8.0, learn_sigma=True, saliency_map_factor=4))

        super().__init__(
            features=feature_extractor,
            saliency_networks=saliency_networks,
            scanpath_networks=scanpath_networks,
            fixation_selection_networks=fixation_selection_networks,
            finalizers=finalizers,
            downsample=2,
            readout_factor=4,
            saliency_map_factor=4,
            included_fixations=[-1, -2, -3, -4],
        )

        if pretrained:
            # Map onto the CPU so loading also works on machines without CUDA.
            state_dict = model_zoo.load_url(
                self.PRETRAINED_WEIGHTS_URL,
                map_location=torch.device('cpu'),
            )
            self.load_state_dict(state_dict)