Source code for skada.deep._graph_alignment

# Author: Theo Gnassounou <theo.gnassounou@inria.fr>
#
# License: BSD 3-Clause

import numpy as np
import torch

from skada.deep.base import (
    BaseDALoss,
    DomainAwareCriterion,
    DomainAwareModule,
    DomainAwareNet,
    DomainBalancedDataLoader,
)
from skada.deep.callbacks import CountEpochs, MemoryBank
from skada.deep.losses import gda_loss, nap_loss

from .modules import DomainClassifier


class SPALoss(BaseDALoss):
    """Loss SPA.

    This loss is the sum of three terms:

    * an adversarial domain-classification loss, weighted by a schedule that
      ramps up with training progress; it is meant to make domains harder to
      classify (i.e., to remove domain-specific features),
    * a graph alignment loss between source and target features
      (``gda_loss`` with a Gaussian metric),
    * a NAP loss (``nap_loss``) on the target samples, computed with the
      memory banks and weighted by a linear schedule.

    See [36]_ for details.

    Parameters
    ----------
    max_epochs : int
        Maximum number of epochs to train the model. Used to normalize the
        current epoch count ``n_epochs`` in the schedulers.
    domain_criterion : torch criterion (instance), default=None
        The initialized criterion (loss) used to compute the adversarial
        loss. If None, a ``torch.nn.BCELoss`` is used.
    memory_features : torch.Tensor, default=None
        Memory bank of target features forwarded to ``nap_loss``.
        NOTE(review): presumably filled during training by the
        ``MemoryBank`` callback that ``SPA`` registers; confirm.
    memory_outputs : torch.Tensor, default=None
        Memory bank of target outputs forwarded to ``nap_loss``. Same
        assumption as ``memory_features``.
    K : int, default=5
        Forwarded to ``nap_loss`` as ``K`` (presumably the number of nearest
        neighbors used for pseudo-labelling; confirm in ``nap_loss``).
    reg_adv : float, default=1
        Regularization parameter for the adversarial loss.
    reg_gsa : float, default=1
        Regularization parameter for the graph alignment loss.
    reg_nap : float, default=1
        Regularization parameter for the NAP loss.

    Attributes
    ----------
    n_epochs : int
        Epoch counter used by the schedulers. It is initialized to 0 and never
        incremented here. NOTE(review): presumably updated by the
        ``CountEpochs`` callback that ``SPA`` registers.

    References
    ----------
    .. [36] Xiao et al. SPA: A Graph Spectral Alignment Perspective for
            Domain Adaptation. In NeurIPS, 2023.
    """

    def __init__(
        self,
        max_epochs,
        domain_criterion=None,
        memory_features=None,
        memory_outputs=None,
        K=5,
        reg_adv=1,
        reg_gsa=1,
        reg_nap=1,
    ):
        super().__init__()
        if domain_criterion is None:
            self.domain_criterion_ = torch.nn.BCELoss()
        else:
            self.domain_criterion_ = domain_criterion

        self.reg_adv = reg_adv
        self.reg_gsa = reg_gsa
        self.reg_nap = reg_nap
        self.K = K
        self.memory_features = memory_features
        self.memory_outputs = memory_outputs
        self.max_epochs = max_epochs
        self.n_epochs = 0

    def _scheduler_adv(self, high=1.0, low=0.0, alpha=10.0):
        """Return the weight of the adversarial term for the current epoch.

        With ``p = n_epochs / max_epochs`` this is
        ``2 * (high - low) / (1 + exp(-alpha * p)) - (high - low) + low``.
        It equals ``low`` at ``p = 0`` and increases toward ``high`` as
        ``p`` grows.
        """
        max_epochs = self.max_epochs
        n_epochs = self.n_epochs
        return (
            2.0 * (high - low) / (1.0 + np.exp(-alpha * n_epochs / max_epochs))
            - (high - low)
            + low
        )

    def _scheduler_nap(self):
        """Return the weight of the NAP term: linear ``n_epochs / max_epochs``."""
        return self.n_epochs / self.max_epochs

    def forward(
        self,
        y_pred_t,
        domain_pred_s,
        domain_pred_t,
        features_s,
        features_t,
        sample_idx_t,
        **kwargs,
    ):
        """Compute the domain adaptation loss.

        Parameters
        ----------
        y_pred_t : torch.Tensor
            Model predictions on the target batch (forwarded to ``nap_loss``).
        domain_pred_s : torch.Tensor
            Domain-classifier outputs on the source batch (label 0).
        domain_pred_t : torch.Tensor
            Domain-classifier outputs on the target batch (label 1).
        features_s : torch.Tensor
            Source features used by the graph alignment loss.
        features_t : torch.Tensor
            Target features used by the graph alignment and NAP losses.
        sample_idx_t : torch.Tensor
            Indices of the target samples, forwarded to ``nap_loss``.
        **kwargs
            Ignored; accepted for interface compatibility.

        Returns
        -------
        loss : torch.Tensor
            Sum of the weighted adversarial, graph alignment and NAP losses.
        """
        # Domain labels: 0 for source, 1 for target; one label per row
        # (size along the first dimension of the predictions).
        domain_label = torch.zeros(
            (domain_pred_s.size()[0]),
            device=domain_pred_s.device,
        )
        domain_label_target = torch.ones(
            (domain_pred_t.size()[0]),
            device=domain_pred_t.device,
        )

        # Adversarial term, ramped up by the sigmoid-shaped schedule.
        scale = self._scheduler_adv()
        loss_adv = (
            self.reg_adv
            * scale
            * (
                self.domain_criterion_(domain_pred_s, domain_label)
                + self.domain_criterion_(domain_pred_t, domain_label_target)
            )
        )

        # Graph alignment term between source and target features.
        loss_gda = self.reg_gsa * gda_loss(features_s, features_t, metric="gauss")

        # NAP term on target samples, ramped up linearly with training.
        scale = self._scheduler_nap()
        loss_pl = (
            self.reg_nap
            * scale
            * nap_loss(
                features_t=features_t,
                y_pred_t=y_pred_t,
                memory_features=self.memory_features,
                memory_outputs=self.memory_outputs,
                K=self.K,
                sample_idx_t=sample_idx_t,
            )
        )

        loss = loss_adv + loss_gda + loss_pl

        return loss
def SPA(
    module,
    layer_name,
    reg_adv=1,
    reg_gsa=1,
    reg_nap=1,
    domain_classifier=None,
    num_features=None,
    base_criterion=None,
    domain_criterion=None,
    callbacks=None,
    max_epochs=100,
    K=5,
    **kwargs,
):
    """Domain Adaptation with SPA.

    From [36]_.

    Parameters
    ----------
    module : torch module (class or instance)
        A PyTorch :class:`~torch.nn.Module`. In general, the
        uninstantiated class should be passed, although instantiated
        modules will also work.
    layer_name : str
        The name of the module's layer whose outputs are
        collected during the training.
    reg_adv : float, default=1
        Regularization parameter for the adversarial loss.
    reg_gsa : float, default=1
        Regularization parameter for the graph alignment loss.
    reg_nap : float, default=1
        Regularization parameter for the NAP loss.
    domain_classifier : torch module, default=None
        A PyTorch :class:`~torch.nn.Module` used to classify the
        domain. If None, a :class:`DomainClassifier` is created with
        ``num_features`` inputs.
    num_features : int, default=None
        Size of the input of domain classifier,
        e.g size of the last layer of
        the feature extractor.
        If domain_classifier is None, num_features has to be
        provided.
    base_criterion : torch criterion (instance), default=None
        The base criterion used to compute the loss with source
        labels. If None, the default is `torch.nn.CrossEntropyLoss()`.
    domain_criterion : torch criterion (instance), default=None
        The criterion (loss) used to compute the
        adversarial loss. If None, a BCELoss is used.
    callbacks : callback or list of callbacks, default=None
        Callbacks to use during training. A ``MemoryBank`` and a
        ``CountEpochs`` callback are always appended after them. A list is
        copied and is not modified; any other single object is wrapped in
        a new list.
    max_epochs : int, default=100
        Maximum number of epochs to train the model.
    K : int, default=5
        Forwarded to :class:`SPALoss` (and from there to ``nap_loss``).
    **kwargs
        Additional keyword arguments forwarded to ``DomainAwareNet``.

    Returns
    -------
    net : DomainAwareNet
        The configured (unfitted) network.

    Raises
    ------
    ValueError
        If both ``domain_classifier`` and ``num_features`` are None.

    References
    ----------
    .. [36] Xiao et al. SPA: A Graph Spectral Alignment Perspective for
            Domain Adaptation. In NeurIPS, 2023.
    """
    if domain_classifier is None:
        # raise error if num_feature is None
        if num_features is None:
            raise ValueError(
                "If domain_classifier is None, num_features has to be provided"
            )
        domain_classifier = DomainClassifier(num_features=num_features)

    # Always build a fresh list: appending to the caller's list in place would
    # mutate their object and accumulate duplicate MemoryBank/CountEpochs
    # callbacks if the same list is reused across calls. A non-list value
    # (e.g. a single callback or a skorch ``(name, callback)`` tuple) is
    # wrapped as a single entry, exactly as before.
    if callbacks is None:
        callbacks = []
    elif isinstance(callbacks, list):
        callbacks = list(callbacks)
    else:
        callbacks = [callbacks]
    callbacks = callbacks + [MemoryBank(), CountEpochs()]

    if base_criterion is None:
        base_criterion = torch.nn.CrossEntropyLoss()

    net = DomainAwareNet(
        module=DomainAwareModule,
        module__base_module=module,
        module__layer_name=layer_name,
        module__domain_classifier=domain_classifier,
        iterator_train=DomainBalancedDataLoader,
        criterion=DomainAwareCriterion,
        criterion__base_criterion=base_criterion,
        # The per-term weights live inside SPALoss, so the global DA weight is 1.
        criterion__reg=1,
        criterion__adapt_criterion=SPALoss(
            domain_criterion=domain_criterion,
            reg_adv=reg_adv,
            reg_gsa=reg_gsa,
            reg_nap=reg_nap,
            max_epochs=max_epochs,
            K=K,
        ),
        callbacks=callbacks,
        max_epochs=max_epochs,
        **kwargs,
    )
    return net