# Source code for skada.deep._optimal_transport

# Author: Theo Gnassounou <theo.gnassounou@inria.fr>
#
# License: BSD 3-Clause
from torch import nn

from skada.deep.base import (
    BaseDALoss,
    DomainAwareCriterion,
    DomainAwareModule,
    DomainAwareNet,
    DomainBalancedDataLoader,
)

from .losses import deepjdot_loss


class DeepJDOTLoss(BaseDALoss):
    """DeepJDOT domain adaptation loss.

    Reduces the distance between the source and target domains through an
    optimal-transport-based measure of discrepancy on the joint deep
    representations/labels. See [13]_.

    Parameters
    ----------
    reg_cl : float, default=1
        Class distance term regularization parameter.
    target_criterion : torch criterion (class), default=None
        The uninitialized criterion (loss) used to compute the DeepJDOT
        loss. The criterion should support reduction='none'.
        It is stored as-is and handed to ``deepjdot_loss`` as ``criterion``.

    References
    ----------
    .. [13] Bharath Bhushan Damodaran, Benjamin Kellenberger, Remi Flamary,
            Devis Tuia, and Nicolas Courty. DeepJDOT: Deep Joint Distribution
            Optimal Transport for Unsupervised Domain Adaptation.
            In ECCV 2018 15th European Conference on Computer Vision,
            September 2018. Springer.
    """

    def __init__(self, reg_cl=1, target_criterion=None):
        super().__init__()
        # Both values are kept untouched; they are only read in forward().
        self.criterion_ = target_criterion
        self.reg_cl = reg_cl

    def forward(
        self,
        y_s,
        y_pred_s,
        y_pred_t,
        domain_pred_s,
        domain_pred_t,
        features_s,
        features_t,
    ):
        """Compute the domain adaptation loss.

        Only ``y_s``, ``y_pred_t``, ``features_s`` and ``features_t`` are
        forwarded to ``deepjdot_loss``; ``y_pred_s``, ``domain_pred_s`` and
        ``domain_pred_t`` are accepted but not used here.

        Returns
        -------
        loss
            Whatever ``deepjdot_loss`` returns for the given inputs.
        """
        # Positional arguments, in the order deepjdot_loss receives them.
        ot_inputs = (y_s, y_pred_t, features_s, features_t, self.reg_cl)
        return deepjdot_loss(*ot_inputs, criterion=self.criterion_)
def DeepJDOT(module, layer_name, reg=1, reg_cl=1, target_criterion=None, **kwargs):
    """Build a domain-aware network trained with the DeepJDOT loss.

    See [13]_.

    Parameters
    ----------
    module : torch module (class or instance)
        A PyTorch :class:`~torch.nn.Module`.
    layer_name : str
        The name of the module's layer whose outputs are
        collected during the training for the adaptation.
    reg : float, default=1
        Regularization parameter, passed on as ``criterion__reg``.
    reg_cl : float, default=1
        Class distance term regularization parameter.
    target_criterion : torch criterion (class)
        The uninitialized criterion (loss) used to compute the
        DeepJDOT loss. The criterion should support reduction='none'.
    **kwargs
        Extra keyword arguments forwarded unchanged to ``DomainAwareNet``.

    Returns
    -------
    net : DomainAwareNet
        The configured network wrapper.

    References
    ----------
    .. [13] Bharath Bhushan Damodaran, Benjamin Kellenberger, Remi Flamary,
            Devis Tuia, and Nicolas Courty. DeepJDOT: Deep Joint Distribution
            Optimal Transport for Unsupervised Domain Adaptation.
            In ECCV 2018 15th European Conference on Computer Vision,
            September 2018. Springer.
    """
    # Criterion passed as criterion__criterion, paired with the DeepJDOT
    # adaptation term below.
    base_loss = nn.CrossEntropyLoss()
    adapt_loss = DeepJDOTLoss(reg_cl=reg_cl, target_criterion=target_criterion)

    return DomainAwareNet(
        module=DomainAwareModule,
        module__base_module=module,
        module__layer_name=layer_name,
        iterator_train=DomainBalancedDataLoader,
        criterion=DomainAwareCriterion,
        criterion__criterion=base_loss,
        criterion__adapt_criterion=adapt_loss,
        criterion__reg=reg,
        **kwargs,
    )