# Header captured from a web file viewer when this file was copied:
# uploaded by mbar0075, commit "Testing Commit" (c9baa67),
# virus scan: "No virus", file size 4.83 kB.
from collections import OrderedDict
import importlib
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import model_zoo
from .modules import FeatureExtractor, Finalizer, DeepGazeIIIMixture, MixtureModel
from .layers import (
Conv2dMultiInput,
LayerNorm,
LayerNormMultiInput,
Bias,
)
# Feature backbones used by DeepGazeIIE, one entry per backbone.
#
# Each entry has:
#   'type'          -- dotted path of the feature class; it is resolved with
#                      import_class() and instantiated without arguments.
#   'used_features' -- names of the backbone layers handed to FeatureExtractor.
#   'channels'      -- passed to build_saliency_network() as its input channel
#                      count; presumably the summed channel count of the
#                      'used_features' outputs (TODO confirm against
#                      FeatureExtractor).
BACKBONES = [
    # ShapeNet-C (ResNet-style) backbone
    {
        'type': 'deepgaze_pytorch.features.shapenet.RGBShapeNetC',
        'used_features': [
            '1.module.layer3.0.conv2',
            '1.module.layer3.3.conv2',
            '1.module.layer3.5.conv1',
            '1.module.layer3.5.conv2',
            '1.module.layer4.1.conv2',
            '1.module.layer4.2.conv2',
        ],
        'channels': 2048,
    },
    # EfficientNet-B5 backbone
    {
        'type': 'deepgaze_pytorch.features.efficientnet.RGBEfficientNetB5',
        'used_features': [
            '1._blocks.24._depthwise_conv',
            '1._blocks.26._depthwise_conv',
            '1._blocks.35._project_conv',
        ],
        'channels': 2416,
    },
    # DenseNet-201 backbone
    {
        'type': 'deepgaze_pytorch.features.densenet.RGBDenseNet201',
        'used_features': [
            '1.features.denseblock4.denselayer32.norm1',
            '1.features.denseblock4.denselayer32.conv1',
            '1.features.denseblock4.denselayer31.conv2',
        ],
        'channels': 2048,
    },
    # ResNeXt-50 backbone
    {
        'type': 'deepgaze_pytorch.features.resnext.RGBResNext50',
        'used_features': [
            '1.layer3.5.conv1',
            '1.layer3.5.conv2',
            '1.layer3.4.conv2',
            '1.layer4.2.conv2',
        ],
        'channels': 2560,
    },
]
def build_saliency_network(input_channels):
    """Create one saliency readout head.

    The head is three stacked 1x1-convolution stages, each of the form
    ``LayerNorm -> Conv2d (bias=False) -> Bias -> Softplus``. The stages map
    ``input_channels`` feature channels to 8, then 16, then a single channel.

    Args:
        input_channels: Number of channels in the feature tensor fed into the
            head.

    Returns:
        nn.Sequential: the head. Its submodules are registered in order as
        layernorm0, conv0, bias0, softplus0, layernorm1, conv1, bias1,
        softplus1, layernorm2, conv2, bias2, softplus3.
    """
    head = nn.Sequential()

    # Stage 0: input_channels -> 8
    head.add_module('layernorm0', LayerNorm(input_channels))
    head.add_module('conv0', nn.Conv2d(input_channels, 8, (1, 1), bias=False))
    head.add_module('bias0', Bias(8))
    head.add_module('softplus0', nn.Softplus())

    # Stage 1: 8 -> 16
    head.add_module('layernorm1', LayerNorm(8))
    head.add_module('conv1', nn.Conv2d(8, 16, (1, 1), bias=False))
    head.add_module('bias1', Bias(16))
    head.add_module('softplus1', nn.Softplus())

    # Stage 2: 16 -> 1. The last activation is registered as 'softplus3'
    # (not 'softplus2'); the name is kept so module names stay stable.
    head.add_module('layernorm2', LayerNorm(16))
    head.add_module('conv2', nn.Conv2d(16, 1, (1, 1), bias=False))
    head.add_module('bias2', Bias(1))
    head.add_module('softplus3', nn.Softplus())

    return head
def build_fixation_selection_network():
    """Create one fixation-selection head.

    The first stage uses the multi-input layer variants with the per-input
    spec ``[1, 0]``. Presumably these are per-input channel counts: a
    1-channel saliency map and no scanpath features (TODO confirm against
    ``.layers``). The head then reduces 128 -> 16 -> 1 channels with 1x1
    convolutions. The final conv has no bias and no activation after it.

    Returns:
        nn.Sequential: the head. Its submodules are registered in order as
        layernorm0, conv0, bias0, softplus0, layernorm1, conv1, bias1,
        softplus1, conv2.
    """
    head = nn.Sequential()

    # Stage 0: multi-input -> 128
    head.add_module('layernorm0', LayerNormMultiInput([1, 0]))
    head.add_module('conv0', Conv2dMultiInput([1, 0], 128, (1, 1), bias=False))
    head.add_module('bias0', Bias(128))
    head.add_module('softplus0', nn.Softplus())

    # Stage 1: 128 -> 16
    head.add_module('layernorm1', LayerNorm(128))
    head.add_module('conv1', nn.Conv2d(128, 16, (1, 1), bias=False))
    head.add_module('bias1', Bias(16))
    head.add_module('softplus1', nn.Softplus())

    # Output: 16 -> 1, a plain linear 1x1 conv
    head.add_module('conv2', nn.Conv2d(16, 1, (1, 1), bias=False))

    return head
def build_deepgaze_mixture(backbone_config, components=10):
    """Assemble a ``DeepGazeIIIMixture`` for a single backbone.

    Args:
        backbone_config: One entry of ``BACKBONES``, i.e. a dict with keys
            ``'type'`` (dotted class path, resolved via ``import_class`` and
            instantiated with no arguments), ``'used_features'`` (layer names
            for ``FeatureExtractor``) and ``'channels'`` (input channel count
            for each saliency head).
        components: Number of mixture components. Each component gets its own
            saliency head, fixation-selection head and
            ``Finalizer(sigma=8.0, learn_sigma=True, saliency_map_factor=2)``,
            and no scanpath network (``None``).

    Returns:
        DeepGazeIIIMixture built with downsample=2, readout_factor=16,
        saliency_map_factor=2 and an empty ``included_fixations`` list.
    """
    backbone_cls = import_class(backbone_config['type'])
    extractor = FeatureExtractor(backbone_cls(), backbone_config['used_features'])

    saliency_heads, selection_heads, finalizer_list = [], [], []
    # Build one component at a time (saliency head, then selection head, then
    # finalizer) so modules are created in the same order as before. Random
    # weight initialisation therefore consumes the RNG identically.
    for _ in range(components):
        saliency_heads.append(build_saliency_network(backbone_config['channels']))
        selection_heads.append(build_fixation_selection_network())
        finalizer_list.append(Finalizer(sigma=8.0, learn_sigma=True, saliency_map_factor=2))

    return DeepGazeIIIMixture(
        features=extractor,
        saliency_networks=saliency_heads,
        scanpath_networks=[None] * components,
        fixation_selection_networks=selection_heads,
        finalizers=finalizer_list,
        downsample=2,
        readout_factor=16,
        saliency_map_factor=2,
        included_fixations=[],
    )
class DeepGazeIIE(MixtureModel):
    """DeepGazeIIE saliency model.

    Builds one ``DeepGazeIIIMixture`` per entry in ``BACKBONES`` (30
    components each) and combines them through ``MixtureModel``.

    Note:
        See Linardos, A., Kümmerer, M., Press, O., & Bethge, M. (2021).
        Calibrated prediction in and out-of-domain for state-of-the-art
        saliency modeling. ArXiv:2105.12441 [Cs],
        http://arxiv.org/abs/2105.12441

    Args:
        pretrained: If true, download the released ``deepgaze2e.pth`` weights
            and load them onto the CPU.
    """

    def __init__(self, pretrained=True):
        # 3 trained instances per backbone x 10 cross-validation folds each
        components_per_backbone = 3 * 10
        backbone_models = []
        for config in BACKBONES:
            backbone_models.append(
                build_deepgaze_mixture(config, components=components_per_backbone)
            )
        super().__init__(backbone_models)

        if not pretrained:
            return

        weights_url = 'https://github.com/matthias-k/DeepGaze/releases/download/v1.0.0/deepgaze2e.pth'
        state_dict = model_zoo.load_url(weights_url, map_location=torch.device('cpu'))
        self.load_state_dict(state_dict)
def import_class(name):
    """Resolve a fully qualified dotted path to the object it names.

    ``name`` is split at its last dot. Everything before the dot is imported
    with :func:`importlib.import_module`, and the part after it is looked up
    on that module with :func:`getattr`.

    Args:
        name: Dotted path such as ``'package.module.ClassName'``.

    Returns:
        The attribute named by the final path component.

    Raises:
        ValueError: if ``name`` has no module part (e.g. ``'ClassName'`` or
            ``'.ClassName'``).
        ImportError: if the module part cannot be imported.
        AttributeError: if the module has no attribute with that name.
    """
    module_name, _, class_name = name.rpartition('.')
    if not module_name:
        # Without this check a bare name fails with an opaque tuple-unpacking
        # error (or importlib's "Empty module name"); name the real problem.
        raise ValueError(
            f"import_class expects a fully qualified 'module.ClassName' path, got {name!r}"
        )
    module = importlib.import_module(module_name)
    return getattr(module, class_name)