Source code for skada.deep._adversarial

# Author: Theo Gnassounou <theo.gnassounou@inria.fr>
#         Ambroise Odonnat <ambroiseodonnattechnologie@gmail.com>
#         Antoine Collas <contact@antoinecollas.fr>
#
# License: BSD 3-Clause
import math

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from skada.deep.base import (
    BaseDALoss,
    DomainAwareCriterion,
    DomainAwareModule,
    DomainAwareNet,
    DomainBalancedDataLoader,
)

from .modules import DomainClassifier
from .utils import check_generator


class DANNLoss(BaseDALoss):
    """Loss DANN.

    This loss tries to minimize the divergence between features with
    adversarial method. The weights are updated to make harder
    to classify domains (i.e., remove domain-specific features).

    See [15]_ for details.

    Parameters
    ----------
    domain_criterion : torch criterion (class), default=None
        The initialized criterion (loss) used to compute the
        DANN loss. If None, a BCELoss is used.

    References
    ----------
    .. [15] Yaroslav Ganin et. al. Domain-Adversarial Training
            of Neural Networks. In Journal of Machine Learning
            Research, 2016.
    """

    def __init__(self, domain_criterion=None):
        super().__init__()
        if domain_criterion is None:
            self.domain_criterion_ = torch.nn.BCELoss()
        else:
            self.domain_criterion_ = domain_criterion

    def forward(
        self,
        domain_pred_s,
        domain_pred_t,
        **kwargs,
    ):
        """Compute the domain adaptation loss"""
        domain_label = torch.zeros(
            (domain_pred_s.size()[0]),
            device=domain_pred_s.device,
        )
        domain_label_target = torch.ones(
            (domain_pred_t.size()[0]),
            device=domain_pred_t.device,
        )

        # update classification function
        loss = self.domain_criterion_(
            domain_pred_s, domain_label
        ) + self.domain_criterion_(domain_pred_t, domain_label_target)

        return loss


[docs] def DANN( module, layer_name, reg=1, domain_classifier=None, num_features=None, base_criterion=None, domain_criterion=None, **kwargs, ): """Domain-Adversarial Training of Neural Networks (DANN). From [15]_. Parameters ---------- module : torch module (class or instance) A PyTorch :class:`~torch.nn.Module`. In general, the uninstantiated class should be passed, although instantiated modules will also work. layer_name : str The name of the module's layer whose outputs are collected during the training. reg : float, default=1 Regularization parameter for DA loss. domain_classifier : torch module, default=None A PyTorch :class:`~torch.nn.Module` used to classify the domain. If None, a domain classifier is created following [1]_. num_features : int, default=None Size of the input of domain classifier, e.g size of the last layer of the feature extractor. If domain_classifier is None, num_features has to be provided. base_criterion : torch criterion (class) The base criterion used to compute the loss with source labels. If None, the default is `torch.nn.CrossEntropyLoss`. domain_criterion : torch criterion (class) The criterion (loss) used to compute the DANN loss. If None, a BCELoss is used. References ---------- .. [15] Yaroslav Ganin et. al. Domain-Adversarial Training of Neural Networks. In Journal of Machine Learning Research, 2016. """ if domain_classifier is None: # raise error if num_feature is None if num_features is None: raise ValueError( "If domain_classifier is None, num_features has to be provided" ) domain_classifier = DomainClassifier(num_features=num_features) if base_criterion is None: base_criterion = torch.nn.CrossEntropyLoss() net = DomainAwareNet( module=DomainAwareModule, module__base_module=module, module__layer_name=layer_name, module__domain_classifier=domain_classifier, iterator_train=DomainBalancedDataLoader, criterion=DomainAwareCriterion, criterion__base_criterion=base_criterion, criterion__reg=reg, criterion__adapt_criterion=DANNLoss(domain_criterion=domain_criterion), **kwargs, ) return net
[docs] class CDANLoss(BaseDALoss): """Conditional Domain Adversarial Networks (CDAN) loss. This loss tries to minimize the divergence between features with adversarial method. The weights are updated to make harder to classify domains (i.e., remove domain-specific features) via multilinear conditioning that captures the crosscovariance between feature representations and classifier predictions From [16]_. Parameters ---------- domain_criterion : torch criterion (class), default=None The initialized criterion (loss) used to compute the CDAN loss. If None, a BCELoss is used. References ---------- .. [16] Mingsheng Long et. al. Conditional Adversarial Domain Adaptation In NeurIPS, 2016. """ def __init__(self, domain_criterion=None): super().__init__() if domain_criterion is None: self.domain_criterion_ = torch.nn.BCELoss() else: self.domain_criterion_ = domain_criterion
[docs] def forward( self, domain_pred_s, domain_pred_t, features_s, features_t, **kwargs, ): """Compute the domain adaptation loss""" dtype = torch.float32 # create domain label domain_label = torch.zeros( (features_s.size()[0]), device=features_s.device, dtype=dtype ) domain_label_target = torch.ones( (features_t.size()[0]), device=features_s.device, dtype=dtype ) # update classification function loss = self.domain_criterion_( domain_pred_s, domain_label ) + self.domain_criterion_(domain_pred_t, domain_label_target) return loss
class CDANModule(DomainAwareModule): """Conditional Domain Adversarial Networks (CDAN) module. From [16]_. Parameters ---------- module : torch module (class or instance) A PyTorch :class:`~torch.nn.Module`. layer_name : str The name of the module's layer whose outputs are collected during the training for adaptation. domain_classifier : torch module A PyTorch :class:`~torch.nn.Module` used to classify the domain. max_features : int, default=4096 Maximum size of the input for the domain classifier. 4096 is the largest number of units in typical deep network according to [1]_. References ---------- .. [16] Mingsheng Long et. al. Conditional Adversarial Domain Adaptation In NeurIPS, 2016. """ def __init__( self, base_module, layer_name, domain_classifier, max_features=4096, random_state=42, ): super().__init__(base_module, layer_name, domain_classifier) self.max_features = max_features self.random_state = random_state def forward( self, X, sample_domain=None, sample_idx=None, is_fit=False, return_features=False, ): if is_fit: # predict y_pred = self.base_module_(X) features = self.intermediate_layers[self.layer_name] n_classes = y_pred.shape[1] n_features = features.shape[1] if n_features * n_classes > self.max_features: random_layer = _RandomLayer( self.random_state, input_dims=[n_features, n_classes], output_dim=self.max_features, ) else: random_layer = None # Compute the input for the domain classifier if random_layer is None: multilinear_map = torch.bmm(y_pred.unsqueeze(2), features.unsqueeze(1)) multilinear_map = multilinear_map.view(-1, n_features * n_classes) else: multilinear_map = random_layer.forward([features, y_pred]) domain_pred = self.domain_classifier_(multilinear_map) return ( y_pred, domain_pred, features, sample_domain, sample_idx, ) else: if return_features: return self.base_module_(X), self.intermediate_layers[self.layer_name] else: return self.base_module_(X)
[docs] def CDAN( module, layer_name, reg=1, max_features=4096, domain_classifier=None, num_features=None, n_classes=None, base_criterion=None, domain_criterion=None, **kwargs, ): """Conditional Domain Adversarial Networks (CDAN). From [16]_. Parameters ---------- module : torch module (class or instance) A PyTorch :class:`~torch.nn.Module`. In general, the uninstantiated class should be passed, although instantiated modules will also work. layer_name : str The name of the module's layer whose outputs are collected during the training. reg : float, default=1 Regularization parameter for DA loss. max_features : int, default=4096 Maximum size of the input for the domain classifier. 4096 is the largest number of units in typical deep network according to [1]_. domain_classifier : torch module, default=None A PyTorch :class:`~torch.nn.Module` used to classify the domain. If None, a domain classifier is created following [1]_. num_features : int, default=None Size of the embedding space e.g. the size of the output of layer_name. If domain_classifier is None, num_features has to be provided. n_classes : int, default None Number of output classes. If domain_classifier is None, n_classes has to be provided. base_criterion : torch criterion (class) The base criterion used to compute the loss with source labels. If None, the default is `torch.nn.CrossEntropyLoss`. domain_criterion : torch criterion (class) The criterion (loss) used to compute the CDAN loss. If None, a BCELoss is used. References ---------- .. [16] Mingsheng Long et. al. Conditional Adversarial Domain Adaptation In NeurIPS, 2016. """ if domain_classifier is None: if num_features is None: raise ValueError( "If domain_classifier is None, num_features has to be provided" ) if n_classes is None: raise ValueError( "If domain_classifier is None, n_classes has to be provided" ) num_features = np.min([num_features * n_classes, max_features]) domain_classifier = DomainClassifier(num_features=num_features) if base_criterion is None: base_criterion = torch.nn.CrossEntropyLoss() net = DomainAwareNet( module=CDANModule, module__base_module=module, module__layer_name=layer_name, module__domain_classifier=domain_classifier, module__max_features=max_features, iterator_train=DomainBalancedDataLoader, criterion=DomainAwareCriterion, criterion__base_criterion=base_criterion, criterion__reg=reg, criterion__adapt_criterion=CDANLoss(domain_criterion=domain_criterion), **kwargs, ) return net
class ModifiedCrossEntropyLoss(torch.nn.Module): """Modified CrossEntropyLoss as described in (29) from [35]_ with label smoothing. Parameters ---------- smoothing : float, default=0.1 Smoothing factor for label smoothing. """ def __init__(self, smoothing=0.1): super().__init__() self.smoothing = smoothing def forward(self, input, target): """ Compute the modified CrossEntropyLoss with label smoothing applied to predictions. Parameters ---------- input : torch.Tensor Predictions from the model. target : torch.Tensor Target labels. Returns ------- loss : torch.Tensor The modified CrossEntropyLoss with label smoothing applied to predictions. """ # Compute probabilities prob = F.softmax(input, dim=-1) # Apply label smoothing to predictions num_classes = input.shape[-1] smooth_prob = (1 - self.smoothing) * prob + self.smoothing / num_classes # Gather probabilities of the target classes target_prob = smooth_prob[..., target] # Compute log(1 - probability) for the target class log_one_minus_prob = torch.log(1 - target_prob) # Return mean loss return torch.mean(log_one_minus_prob)
[docs] class MDDLoss(BaseDALoss): """Loss MDD. This loss tries to minimize the disparity discrepancy between the source and target domains. The discrepancy is estimated through adversarial training of three networks: an encoder, a task network and a discriminator. See [35]_ for details. Parameters ---------- gamma : float (default=4.0) Margin parameter following [35]_ References ---------- .. [35] Yuchen Zhang et. al. Bridging Theory and Algorithm for Domain Adaptation. In International Conference on Machine Learning, 2019. """ def __init__(self, gamma=4.0): super().__init__() self.gamma = gamma self.disc_criterion_s = torch.nn.CrossEntropyLoss() self.disc_criterion_t = ModifiedCrossEntropyLoss()
[docs] def forward( self, y_pred_s, y_pred_t, domain_pred_s, domain_pred_t, **kwargs, ): """Compute the domain adaptation loss""" # TODO: handle binary classification # Multiclass classification pseudo_label_s = torch.argmax(y_pred_s, axis=-1) pseudo_label_t = torch.argmax(y_pred_t, axis=-1) disc_loss_s = self.disc_criterion_s(domain_pred_s, pseudo_label_s) disc_loss_t = self.disc_criterion_t(domain_pred_t, pseudo_label_t) # Compute the MDD loss value disc_loss = self.gamma * disc_loss_s - disc_loss_t return disc_loss
[docs] def MDD( module, layer_name, reg=1, gamma=4.0, disc_classifier=None, num_features=None, n_classes=None, base_criterion=None, **kwargs, ): """Margin Disparity Discrepancy (MDD). From [35]_. Parameters ---------- module : torch module (class or instance) A PyTorch :class:`~torch.nn.Module`. In general, the uninstantiated class should be passed, although instantiated modules will also work. layer_name : str The name of the module's layer whose outputs are collected during the training. reg : float, default=1 Regularization parameter for DA loss. disc_classifier : torch module, default=None A PyTorch :class:`~torch.nn.Module` used as a discriminator. It should have the same architecture than the classifier used on the source. If None, a domain classifier is created following [1]_. num_features : int, default=None Size of the input of domain classifier, e.g size of the last layer of the feature extractor. If domain_classifier is None, num_features has to be provided. n_classes : int, default=None Number of classes. If domain_classifier is None, n_classes has to be provided. base_criterion : torch criterion (class) The base criterion used to compute the loss with source labels. If None, the default is `torch.nn.CrossEntropyLoss`. gamma : float (default=4.0) Margin parameter following [35]_. References ---------- .. [35] Yuchen Zhang et. al. Bridging Theory and Algorithm for Domain Adaptation. In International Conference on Machine Learning, 2019. """ if disc_classifier is None: # raise error if num_feature is None if num_features is None: raise ValueError( "If disc_classifier is None, num_features has to be provided" ) disc_classifier = DomainClassifier( num_features=num_features, n_classes=n_classes ) if base_criterion is None: base_criterion = torch.nn.CrossEntropyLoss() net = DomainAwareNet( module=DomainAwareModule, module__base_module=module, module__layer_name=layer_name, module__domain_classifier=disc_classifier, iterator_train=DomainBalancedDataLoader, criterion=DomainAwareCriterion, criterion__base_criterion=base_criterion, criterion__reg=reg, criterion__adapt_criterion=MDDLoss( gamma=gamma, ), **kwargs, ) return net
class _RandomLayer(nn.Module): """Randomized multilinear map layer. Parameters ---------- random_state : int, Generator instance or None Determines random number generation for random layer creation input_dims : list of int List of input dimensions. output_dims : int Output dimension wanted. """ def __init__(self, random_state, input_dims, output_dim=4096): super().__init__() gen = check_generator(random_state) self.output_dim = output_dim self.random_matrix = [ torch.randn(size=(input_dims[i], output_dim), generator=gen) for i in range(len(input_dims)) ] def forward(self, input_list): device = input_list[0].device return_list = [ torch.mm(input_list[i], self.random_matrix[i].to(device)) for i in range(len(input_list)) ] return_tensor = return_list[0] / math.pow( float(self.output_dim), 1.0 / len(input_list) ) for single in return_list[1:]: return_tensor = torch.mul(return_tensor, single) return return_tensor