Source code for expert.core.congruence.video_emotions.video_model

from __future__ import annotations

from typing import Tuple

import torch
import torch.nn.functional as F
import torch.nn.init as init
from torch import Tensor, nn
from torchvision import models

from expert.core.functional_tools import get_model_weights


[docs]class DAN(nn.Module): """Distract Your Attention Network implementation. Distract Your Attention Network performs facial expression recognition on tensor face images. """ def __init__( self, num_class: int = 8, num_head: int = 4, pretrained: bool = True, device: torch.device | None = None, ) -> None: """ Args: num_class (int, optional): Number of model output classes. Defaults to 8. num_head (int, optional): Number of heads in multihead classification. Defaults to 4. pretrained (bool, optional): Whether or not to load saved pretrained weights. Defaults to True. device (torch.device | None, optional): Device type on local machine (GPU recommended). Defaults to None. """ super(DAN, self).__init__() resnet = models.resnet18(pretrained=False) self.features = nn.Sequential(*list(resnet.children())[:-2]) self.num_head = num_head for idx in range(num_head): setattr(self, "cat_head{}".format(idx), CrossAttentionHead()) self.sig = nn.Sigmoid() self.fc = nn.Linear(512, num_class) self.bn = nn.BatchNorm1d(num_class) self._device = torch.device("cpu") if pretrained: url = "https://drive.google.com/uc?export=view&id=1uHNADViICyJEjJljv747nfvrGu12kjtu" model_name = "affecnet8_epoch5_acc0.6209.pth" cached_file = get_model_weights(model_name=model_name, url=url) state_dict = torch.load(cached_file, map_location=self._device) self.load_state_dict(state_dict["model_state_dict"], strict=True) if device is not None: self._device = device self.to(self._device) @property def device(self) -> torch.device: """Check the device type. Returns: torch.device: Device type on local machine. """ return self._device
[docs] def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: x = self.features(x) heads = [] for idx in range(self.num_head): heads.append(getattr(self, "cat_head{}".format(idx))(x)) heads = torch.stack(heads).permute([1, 0, 2]) if heads.size(1) > 1: heads = F.log_softmax(heads, dim=1) out = self.fc(heads.sum(dim=1)) out = self.bn(out) return out, x, heads
[docs]class CrossAttentionHead(nn.Module): def __init__(self) -> None: super().__init__() self.sa = SpatialAttention() self.ca = ChannelAttention() self.init_weights() def init_weights(self) -> None: for m in self.modules(): if isinstance(m, nn.Conv2d): init.kaiming_normal_(m.weight, mode="fan_out") if m.bias is not None: init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): init.constant_(m.weight, 1) init.constant_(m.bias, 0) elif isinstance(m, nn.Linear): init.normal_(m.weight, std=0.001) if m.bias is not None: init.constant_(m.bias, 0)
[docs] def forward(self, x: Tensor) -> Tensor: sa = self.sa(x) ca = self.ca(sa) return ca
[docs]class SpatialAttention(nn.Module): def __init__(self) -> None: super().__init__() self.conv1x1 = nn.Sequential( nn.Conv2d(512, 256, kernel_size=1), nn.BatchNorm2d(256), ) self.conv_3x3 = nn.Sequential( nn.Conv2d(256, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), ) self.conv_1x3 = nn.Sequential( nn.Conv2d(256, 512, kernel_size=(1, 3), padding=(0, 1)), nn.BatchNorm2d(512), ) self.conv_3x1 = nn.Sequential( nn.Conv2d(256, 512, kernel_size=(3, 1), padding=(1, 0)), nn.BatchNorm2d(512), ) self.relu = nn.ReLU()
[docs] def forward(self, x: Tensor) -> Tensor: y = self.conv1x1(x) y = self.relu(self.conv_3x3(y) + self.conv_1x3(y) + self.conv_3x1(y)) y = y.sum(dim=1, keepdim=True) out = x * y return out
[docs]class ChannelAttention(nn.Module): def __init__(self) -> None: super().__init__() self.gap = nn.AdaptiveAvgPool2d(1) self.attention = nn.Sequential( nn.Linear(512, 32), nn.BatchNorm1d(32), nn.ReLU(inplace=True), nn.Linear(32, 512), nn.Sigmoid(), )
[docs] def forward(self, sa: Tensor) -> Tensor: sa = self.gap(sa) sa = sa.view(sa.size(0), -1) y = self.attention(sa) out = sa * y return out