Source code for skada.deep._divergence

# Author: Theo Gnassounou <theo.gnassounou@inria.fr>
#         Remi Flamary <remi.flamary@polytechnique.edu>
#         Yanis Lalou <yanis.lalou@polytechnique.edu>
#         Antoine Collas <contact@antoinecollas.fr>
#
# License: BSD 3-Clause
import torch

from skada.deep.base import (
    BaseDALoss,
    DomainAwareCriterion,
    DomainAwareModule,
    DomainAwareNet,
    DomainBalancedDataLoader,
)

from .callbacks import ComputeSourceCentroids
from .losses import cdd_loss, dan_loss, deepcoral_loss


class DeepCoralLoss(BaseDALoss):
    """DeepCORAL domain adaptation loss.

    Penalizes the distance between the covariance of the source features
    and the covariance of the target features. See [12]_.

    Parameters
    ----------
    assume_centered : bool, default=False
        If True, data are not centered before computation.

    References
    ----------
    .. [12] Baochen Sun and Kate Saenko. Deep coral: Correlation
            alignment for deep domain adaptation. In ECCV Workshops, 2016.
    """

    def __init__(self, assume_centered=False):
        super().__init__()
        # Stored as-is and handed to ``deepcoral_loss`` on every forward pass.
        self.assume_centered = assume_centered

    def forward(self, features_s, features_t, **kwargs):
        """Compute the domain adaptation loss.

        Parameters
        ----------
        features_s : source features
            Forwarded unchanged to ``deepcoral_loss``.
        features_t : target features
            Forwarded unchanged to ``deepcoral_loss``.
        **kwargs
            Accepted for interface compatibility; not used here.

        Returns
        -------
        loss
            The value returned by ``deepcoral_loss``.
        """
        return deepcoral_loss(features_s, features_t, self.assume_centered)
def DeepCoral(
    module, layer_name, reg=1, assume_centered=False, base_criterion=None, **kwargs
):
    """Build a network trained with the DeepCORAL adaptation method.

    From [12]_.

    Parameters
    ----------
    module : torch module (class or instance)
        A PyTorch :class:`~torch.nn.Module`.
    layer_name : str
        The name of the module's layer whose outputs are
        collected during the training for the adaptation.
    reg : float, optional (default=1)
        Regularization parameter for DA loss.
    assume_centered : bool, default=False
        If True, data are not centered before computation.
    base_criterion : torch criterion, optional
        The base criterion used to compute the loss with source
        labels. If None, a ``torch.nn.CrossEntropyLoss()`` instance is used.
    **kwargs
        Extra keyword arguments forwarded to ``DomainAwareNet``.

    Returns
    -------
    net : DomainAwareNet
        The configured network.

    References
    ----------
    .. [12] Baochen Sun and Kate Saenko. Deep coral: Correlation
            alignment for deep domain adaptation. In ECCV Workshops, 2016.
    """
    source_criterion = (
        torch.nn.CrossEntropyLoss() if base_criterion is None else base_criterion
    )
    adaptation_criterion = DeepCoralLoss(assume_centered=assume_centered)

    return DomainAwareNet(
        module=DomainAwareModule,
        module__base_module=module,
        module__layer_name=layer_name,
        iterator_train=DomainBalancedDataLoader,
        criterion=DomainAwareCriterion,
        criterion__base_criterion=source_criterion,
        criterion__reg=reg,
        criterion__adapt_criterion=adaptation_criterion,
        **kwargs,
    )
class DANLoss(BaseDALoss):
    """DAN domain adaptation loss.

    Penalizes the MMD distance between the source features and the
    target features. From [14]_.

    Parameters
    ----------
    sigmas : array-like, optional (default=None)
        The sigmas for the Gaussian kernel.
    eps : float, default=1e-7
        Small constant added to median distance calculation
        for numerical stability.

    References
    ----------
    .. [14] Mingsheng Long et. al. Learning Transferable
            Features with Deep Adaptation Networks.
            In ICML, 2015.
    """

    def __init__(self, sigmas=None, eps=1e-7):
        super().__init__()
        # Both settings are forwarded to ``dan_loss`` on every forward pass.
        self.eps = eps
        self.sigmas = sigmas

    def forward(self, features_s, features_t, **kwargs):
        """Compute the domain adaptation loss.

        Parameters
        ----------
        features_s : source features
            Forwarded unchanged to ``dan_loss``.
        features_t : target features
            Forwarded unchanged to ``dan_loss``.
        **kwargs
            Accepted for interface compatibility; not used here.

        Returns
        -------
        loss
            The value returned by ``dan_loss``.
        """
        return dan_loss(
            features_s,
            features_t,
            sigmas=self.sigmas,
            eps=self.eps,
        )
def DAN(
    module,
    layer_name,
    reg=1,
    sigmas=None,
    base_criterion=None,
    eps=1e-7,
    **kwargs,
):
    """DAN domain adaptation method.

    See [14]_.

    Parameters
    ----------
    module : torch module (class or instance)
        A PyTorch :class:`~torch.nn.Module`.
    layer_name : str
        The name of the module's layer whose outputs are
        collected during the training for the adaptation.
    reg : float, optional (default=1)
        Regularization parameter for DA loss.
    sigmas : array-like, optional (default=None)
        The sigmas for the Gaussian kernel.
    base_criterion : torch criterion, optional
        The base criterion used to compute the loss with source
        labels. If None, a ``torch.nn.CrossEntropyLoss()`` instance is used.
    eps : float, default=1e-7
        Small constant added to median distance calculation
        for numerical stability. Forwarded to :class:`DANLoss`; the
        default matches the default of :class:`DANLoss`.
    **kwargs
        Extra keyword arguments forwarded to ``DomainAwareNet``.

    Returns
    -------
    net : DomainAwareNet
        The configured network.

    References
    ----------
    .. [14] Mingsheng Long et. al. Learning Transferable
            Features with Deep Adaptation Networks.
            In ICML, 2015.
    """
    if base_criterion is None:
        base_criterion = torch.nn.CrossEntropyLoss()

    net = DomainAwareNet(
        module=DomainAwareModule,
        module__base_module=module,
        module__layer_name=layer_name,
        iterator_train=DomainBalancedDataLoader,
        criterion=DomainAwareCriterion,
        criterion__base_criterion=base_criterion,
        criterion__reg=reg,
        # ``eps`` is exposed here because the adaptation criterion is built
        # inside this factory and cannot be replaced through ``**kwargs``.
        criterion__adapt_criterion=DANLoss(sigmas=sigmas, eps=eps),
        **kwargs,
    )
    return net
class CANLoss(BaseDALoss):
    """Loss for Contrastive Adaptation Network (CAN).

    This loss implements the contrastive domain discrepancy (CDD)
    as described in [33]_.

    Parameters
    ----------
    distance_threshold : float, optional (default=0.5)
        Distance threshold for discarding the samples that are
        too far from the centroids.
    class_threshold : int, optional (default=3)
        Minimum number of samples in a class to be considered for the loss.
    sigmas : array-like, default=None
        If array, sigmas used for the multi gaussian kernel.
        If None, the choice of sigmas is left to ``cdd_loss``.
    target_kmeans : sklearn KMeans instance, default=None
        Pre-computed target KMeans clustering model.
    eps : float, default=1e-7
        Small constant added to median distance calculation
        for numerical stability.

    References
    ----------
    .. [33] Kang, G., Jiang, L., Yang, Y., & Hauptmann, A. G. (2019).
            Contrastive adaptation network for unsupervised domain adaptation.
            In Proceedings of the IEEE/CVF Conference on Computer Vision
            and Pattern Recognition (pp. 4893-4902).
    """

    def __init__(
        self,
        distance_threshold=0.5,
        class_threshold=3,
        sigmas=None,
        target_kmeans=None,
        eps=1e-7,
    ):
        super().__init__()
        # Every setting is kept verbatim and forwarded to ``cdd_loss``.
        self.sigmas = sigmas
        self.eps = eps
        self.target_kmeans = target_kmeans
        self.distance_threshold = distance_threshold
        self.class_threshold = class_threshold

    def forward(self, y_s, features_s, features_t, **kwargs):
        """Compute the contrastive domain discrepancy loss.

        Parameters
        ----------
        y_s : source labels
            Forwarded unchanged to ``cdd_loss``.
        features_s : source features
            Forwarded unchanged to ``cdd_loss``.
        features_t : target features
            Forwarded unchanged to ``cdd_loss``.
        **kwargs
            Accepted for interface compatibility; not used here.

        Returns
        -------
        loss
            The value returned by ``cdd_loss``.
        """
        cdd_options = dict(
            sigmas=self.sigmas,
            target_kmeans=self.target_kmeans,
            distance_threshold=self.distance_threshold,
            class_threshold=self.class_threshold,
            eps=self.eps,
        )
        return cdd_loss(y_s, features_s, features_t, **cdd_options)
def CAN(
    module,
    layer_name,
    reg=1,
    distance_threshold=0.5,
    class_threshold=3,
    sigmas=None,
    base_criterion=None,
    callbacks=None,
    eps=1e-7,
    **kwargs,
):
    """Contrastive Adaptation Network (CAN) domain adaptation method.

    From [33]_.

    Parameters
    ----------
    module : torch module (class or instance)
        A PyTorch :class:`~torch.nn.Module`.
    layer_name : str
        The name of the module's layer whose outputs are
        collected during the training for the adaptation.
    reg : float, optional (default=1)
        Regularization parameter for DA loss.
    distance_threshold : float, optional (default=0.5)
        Distance threshold for discarding the samples that are
        too far from the centroids.
    class_threshold : int, optional (default=3)
        Minimum number of samples in a class to be considered for the loss.
    sigmas : array-like, default=None
        If array, sigmas used for the multi gaussian kernel.
        If None, the choice of sigmas is left to the underlying loss.
    base_criterion : torch criterion, optional
        The base criterion used to compute the loss with source
        labels. If None, a ``torch.nn.CrossEntropyLoss()`` instance is used.
    callbacks : list or single callback, optional
        Callbacks to be used during training. A ``ComputeSourceCentroids``
        callback is always added after them. The object passed in is never
        modified: a new list is built for the network.
    eps : float, default=1e-7
        Small constant added to median distance calculation
        for numerical stability. Forwarded to :class:`CANLoss`; the
        default matches the default of :class:`CANLoss`.
    **kwargs
        Extra keyword arguments forwarded to ``DomainAwareNet``.

    Returns
    -------
    net : DomainAwareNet
        The configured network.

    References
    ----------
    .. [33] Kang, G., Jiang, L., Yang, Y., & Hauptmann, A. G. (2019).
            Contrastive adaptation network for unsupervised domain adaptation.
            In Proceedings of the IEEE/CVF Conference on Computer Vision
            and Pattern Recognition (pp. 4893-4902).
    """
    if base_criterion is None:
        base_criterion = torch.nn.CrossEntropyLoss()

    if callbacks is None:
        callbacks = [ComputeSourceCentroids()]
    elif isinstance(callbacks, list):
        # Build a new list instead of appending in place: mutating the
        # caller's list would stack an extra ComputeSourceCentroids on it
        # every time the same list is reused for another CAN(...) call.
        callbacks = [*callbacks, ComputeSourceCentroids()]
    else:
        # Any non-list value is treated as one single callback entry.
        callbacks = [callbacks, ComputeSourceCentroids()]

    net = DomainAwareNet(
        module=DomainAwareModule,
        module__base_module=module,
        module__layer_name=layer_name,
        iterator_train=DomainBalancedDataLoader,
        criterion=DomainAwareCriterion,
        criterion__base_criterion=base_criterion,
        criterion__reg=reg,
        criterion__adapt_criterion=CANLoss(
            distance_threshold=distance_threshold,
            class_threshold=class_threshold,
            sigmas=sigmas,
            eps=eps,
        ),
        callbacks=callbacks,
        **kwargs,
    )
    return net