Source code for expert.core.congruence.audio_emotions.audio_model

from __future__ import annotations

import torch
from torch import nn

from expert.core.functional_tools import get_model_weights


class AudioModel(nn.Module):
    """Model for emotion classification by audio signal.

    Architecture: four convolutional stages (Conv2d -> ReLU -> MaxPool2d ->
    BatchNorm2d) with 512 -> 256 -> 128 -> 64 output channels, followed by
    flatten, dropout (p=0.3), a linear layer with 7 outputs and a softmax
    over the class dimension.
    """

    def __init__(
        self, pretrained: bool = True, device: torch.device | None = None
    ) -> None:
        """
        Args:
            pretrained (bool, optional): Whether or not to load saved pretrained
                weights. Defaults to True.
            device (torch.device | None, optional): Device type on local machine
                (GPU recommended). Anything accepted by ``torch.device`` (for
                example the string ``"cuda"``) is also allowed. Defaults to
                None, which keeps the model on the CPU.
        """
        super().__init__()
        # The attribute names and the layer order inside each stage define the
        # state_dict keys (e.g. "conv1.0.weight", "conv1.3.running_mean").
        # Keep them stable so the pretrained checkpoint loads with strict=True.
        self.conv1 = self._conv_block(1, 512)
        self.conv2 = self._conv_block(512, 256)
        self.conv3 = self._conv_block(256, 128)
        self.conv4 = self._conv_block(128, 64)
        self.flatten = nn.Flatten()
        self.dropout = nn.Dropout(p=0.3)
        # 5504 = 64 channels * 86 spatial positions. The conv4 output therefore
        # must have H * W == 86, which fixes the accepted input size. For
        # example, H in 2..17 and W in 658..673 reduce to a 2 x 43 map.
        self.linear = nn.Linear(5504, 7)
        self.softmax = nn.Softmax(dim=1)

        # Weights are always loaded onto the CPU first and moved afterwards.
        self._device = torch.device("cpu")
        if pretrained:
            url = "https://drive.google.com/uc?export=view&id=1DU5pu9D0BSvXBCj_J3kqzgdallxebcLX"
            model_name = "audio_model.pth"
            cached_file = get_model_weights(model_name=model_name, url=url)
            # NOTE(security): torch.load unpickles the downloaded file. Where
            # the installed torch supports it, consider weights_only=True,
            # since the checkpoint comes from a remote URL.
            state_dict = torch.load(cached_file, map_location=self._device)
            self.load_state_dict(state_dict, strict=True)

        if device is not None:
            # Normalize so the ``device`` property always returns a
            # torch.device, even when a string such as "cuda" was passed.
            self._device = torch.device(device)
        self.to(self._device)

    @staticmethod
    def _conv_block(in_channels: int, out_channels: int) -> nn.Sequential:
        """Build one Conv2d(k=3, s=1, p=2) -> ReLU -> MaxPool2d(2) -> BatchNorm2d stage."""
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding=2,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.BatchNorm2d(out_channels),
        )

    @property
    def device(self) -> torch.device:
        """Check the device type.

        Returns:
            torch.device: Device the model was moved to in ``__init__``.
            This value is not updated by later ``.to()`` calls.
        """
        return self._device

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        """Predict emotion-class probabilities.

        Args:
            input_data (torch.Tensor): Batched input of shape (N, 1, H, W).
                H and W must reduce to 86 positions after ``conv4``.

        Returns:
            torch.Tensor: Tensor of shape (N, 7) holding softmax probabilities
            over the 7 classes.
        """
        x = self.conv1(input_data)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        # Flatten (N, 64, H, W) -> (N, 64 * H * W) once, then apply dropout.
        x = self.dropout(self.flatten(x))
        logits = self.linear(x)
        return self.softmax(logits)