# Source code for mexca.video.anfl

"""Action unit (AU) relationship-aware node feature learning (ANFL).

Implementation of the ANFL module from the paper:

    Luo, C., Song, S., Xie, W., Shen, L., Gunes, H. (2022). Learning multi-dimensional edge
    feature-based AU relation graph for facial action unit recognition. *arXiv*.
    `<https://arxiv.org/pdf/2205.01782.pdf>`_

Code adapted from the `OpenGraphAU <https://github.com/lingjivoo/OpenGraphAU/tree/main>`_ code base
(licensed under Apache 2.0).

"""

# pylint: disable=invalid-name

import math

import torch
from torch import nn

from mexca.video.helper_classes import AUPredictor, LinearBlock


class GNN(nn.Module):
    """Apply a graph neural network (GNN) layer.

    Transform action unit (AU) features using digraph connectivity.
    Inputs and outputs correspond to AU features.

    Parameters
    ----------
    in_features: int
        Size of each input sample.
    n_nodes: int
        Number of nodes in the digraph.
    n_neighbors: int, default=4
        Number of top K similar neighbors for computing graph connectivity.

    Notes
    -----
    See eq. 1 in the corresponding `paper <https://arxiv.org/abs/2205.01782>`_.
    Functions :math:`{g, r}` are linear and the nonlinear activation function
    :math:`\\sigma` is ReLU. Linear layer weights are initialized with
    :math:`N(0, \\sqrt{\\frac{2}{in\\_features}})` (both linear layers are square,
    so this equals ``out_features``); linear biases keep their default init.
    Batch norm weights are initialized as 1 and biases as 0.

    The adjacency and connectivity matrices are created with the dtype and on
    the device of the input features. The layer therefore works on any device
    and in any floating point precision (e.g. after ``.double()``).

    """

    def __init__(self, in_features: int, n_nodes: int, n_neighbors: int = 4):
        super().__init__()
        self.in_features = in_features
        self.n_nodes = n_nodes
        self.n_neighbors = n_neighbors

        # Layers
        self.linear_u = nn.Linear(self.in_features, self.in_features)
        self.linear_v = nn.Linear(self.in_features, self.in_features)
        # Normalizes over dim 1 of a 3D input, so inputs must carry n_nodes there
        self.bnv = nn.BatchNorm1d(n_nodes)
        self.relu = nn.ReLU()

        # Param init: zero-mean normal with std sqrt(2 / in_features)
        self.linear_u.weight.data.normal_(0, math.sqrt(2.0 / self.in_features))
        self.linear_v.weight.data.normal_(0, math.sqrt(2.0 / self.in_features))
        self.bnv.weight.data.fill_(1)
        self.bnv.bias.data.zero_()

    @staticmethod
    def _calc_adj_mat(x: torch.Tensor, k: int) -> torch.Tensor:
        """Build a 0/1 adjacency matrix from thresholded dot product similarity.

        Parameters
        ----------
        x: torch.Tensor
            Node features of shape ``(batch, n, features)``.
        k: int
            Number of most similar nodes each node is connected to. Ties
            with the k-th largest similarity can yield more than ``k``
            connections. Must not exceed ``n``.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(batch, n, n)`` with the dtype of ``x``. Entry
            ``[b, i, j]`` is 1 if node ``j`` is among the top ``k`` most
            similar nodes of ``i`` (self-similarity included), else 0.

        """
        # Connectivity is computed on detached features (no gradient flows through it)
        feats = x.detach()
        sim = torch.einsum("b i j, b j k -> b i k", feats, feats.transpose(1, 2))
        # k-th largest similarity per row, kept as (batch, n, 1) for broadcasting
        threshold = sim.topk(k=k, dim=-1, largest=True).values[..., -1:]
        # Match the input dtype so later matmuls with the features do not mix dtypes
        return (sim >= threshold).to(x.dtype)

    @staticmethod
    def _normalize_digraph(adj_mat: torch.Tensor) -> torch.Tensor:
        """Symmetrically normalize an adjacency matrix by node degree.

        Computes :math:`D^{-1/2} A D^{-1/2}`, where :math:`D` is the diagonal
        matrix of row sums (degrees) of :math:`A`. Rows with zero degree
        produce non-finite values.

        Parameters
        ----------
        adj_mat: torch.Tensor
            Adjacency matrix of shape ``(batch, n, n)``.

        Returns
        -------
        torch.Tensor
            Normalized matrix with the shape, dtype and device of ``adj_mat``.

        """
        degs_inv_sqrt = adj_mat.detach().sum(dim=-1) ** -0.5
        # Scaling rows and columns via broadcasting equals multiplying by the
        # diagonal degree matrix on both sides. This avoids building an
        # identity matrix (and choosing its device/dtype) and two batched matmuls.
        return degs_inv_sqrt.unsqueeze(-1) * adj_mat * degs_inv_sqrt.unsqueeze(-2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Transform node features with one residual GNN step (eq. 1).

        Parameters
        ----------
        x: torch.Tensor
            Node features of shape ``(batch, n_nodes, in_features)``.

        Returns
        -------
        torch.Tensor
            Transformed features with the same shape as ``x``.

        """
        # Calc adjacency matrix (0, 1)
        adj_mat = self._calc_adj_mat(x, self.n_neighbors)
        # Calc connectivity matrix
        con_mat = self._normalize_digraph(adj_mat)
        # eq. 1
        aggregate = torch.einsum(
            "b i j, b j k -> b i k", con_mat, self.linear_v(x)
        )
        return self.relu(x + self.bnv(aggregate + self.linear_u(x)))
class AUFeatureGenerator(nn.Module):
    """Generate action unit (AU) features.

    Inputs correspond to face representations (embeddings) and outputs to AU
    features.

    Parameters
    ----------
    in_features: int
        Size of each input sample.
    out_features: int, default=27
        Number of AU features to generate; one linear block is created per AU.

    Notes
    -----
    AU specific features are generated by individual linear and global average
    pooling transformations.

    """

    def __init__(self, in_features: int, out_features: int = 27):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # One FC block per AU, each constructed as LinearBlock(in_features, in_features)
        self.main_node_linear_layers = nn.ModuleList(
            LinearBlock(self.in_features, self.in_features)
            for _ in range(self.out_features)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every AU block to ``x`` and average-pool the results.

        The per-AU outputs are stacked along a new dimension 1. Then the
        second-to-last dimension is averaged (global average pooling).
        """
        per_au = [block(x) for block in self.main_node_linear_layers]
        stacked = torch.stack(per_au, dim=1)
        return stacked.mean(dim=-2)
class FacialGraphGenerator(AUPredictor):
    """Generate action unit (AU) activations from AU features using a facial graph.

    Inputs correspond to AU features and outputs to AU activations. Main plus
    sub nodes represent facial AUs. Sub nodes represent left and right
    activations of AUs 1, 2, 4, 6, 10, 12, and 14.

    Parameters
    ----------
    in_features: int
        Size of each input sample.
    n_main_nodes: int, default=27
        Number of main nodes in the facial graph.
    n_sub_nodes: int, default=14
        Number of sub nodes in the facial graph.
    n_neighbors: int, default=4
        Number of top K similar neighbors for computing graph connectivity.

    Notes
    -----
    First applies a graph neural network (:class:`GNN`) transformation to AU
    features. The transformed features are then passed to the parent
    :class:`AUPredictor`. Per this module's design, that parent applies
    similarity calculating (SC) layers for main and sub nodes as in eq. 2 of the
    corresponding `paper <https://arxiv.org/abs/2205.01782>`_. Sub node
    activations are calculated based on matching main node features. SC layer
    weights are initialized using Glorot initialization (see
    :func:`torch.nn.init.xavier_uniform_`).

    """

    def __init__(
        self,
        in_features: int,
        n_main_nodes: int = 27,
        n_sub_nodes: int = 14,
        n_neighbors: int = 4,
    ):
        super().__init__(in_features, n_main_nodes, n_sub_nodes)
        # Graph layer operating on one node per main AU
        self.gnn = GNN(
            self.in_features, self.n_main_nodes, n_neighbors=n_neighbors
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Refine AU features with the GNN, then predict AU activations."""
        return super().forward(self.gnn(x))
class ANFL(nn.Module):
    """Apply AU relationship-aware node feature learning (ANFL).

    Transform face representations into facial action unit (AU) activations.
    Inputs correspond to facial representations (embeddings) and outputs to AU
    activations.

    Parameters
    ----------
    in_features: int
        Size of each input sample.
    n_main_aus: int, default=27
        Number of main AUs.
    n_sub_aus: int, default=14
        Number of sub AUs.
    n_neighbors: int, default=4
        Number of top K similar neighbors for computing graph connectivity.

    Notes
    -----
    Two stages are chained. First, AU features are generated from face
    representations (see :class:`AUFeatureGenerator`). Then they are transformed
    into activations using a facial graph (see :class:`FacialGraphGenerator`).

    """

    def __init__(
        self,
        in_features: int,
        n_main_aus: int = 27,
        n_sub_aus: int = 14,
        n_neighbors: int = 4,
    ):
        super().__init__()
        self.in_features = in_features
        self.n_main_aus = n_main_aus
        self.n_sub_aus = n_sub_aus
        self.n_neighbors = n_neighbors

        # Submodules are registered in this order: feature generator, then graph
        self.afg = AUFeatureGenerator(in_features, n_main_aus)
        self.fgg = FacialGraphGenerator(
            in_features, n_main_aus, n_sub_aus, n_neighbors=n_neighbors
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map face representations to AU activations (features, then graph)."""
        return self.fgg(self.afg(x))