# Author: Theo Gnassounou <theo.gnassounou@inria.fr>
#         Remi Flamary <remi.flamary@polytechnique.edu>
#         Yanis Lalou <yanis.lalou@polytechnique.edu>
#         Antoine Collas <contact@antoinecollas.fr>
#
# License: BSD 3-Clause
import torch
from skada.deep.base import (
    BaseDALoss,
    DomainAwareCriterion,
    DomainAwareModule,
    DomainAwareNet,
    DomainBalancedDataLoader,
)
from .callbacks import ComputeSourceCentroids
from .losses import cdd_loss, dan_loss, deepcoral_loss
class DeepCoralLoss(BaseDALoss):
    """Loss DeepCORAL.

    This loss reduces the distance between covariances
    of the source features and the target features.
    See [12]_.

    Parameters
    ----------
    assume_centered : bool, default=False
        If True, data are not centered before computation.

    References
    ----------
    .. [12]  Baochen Sun and Kate Saenko. Deep coral:
            Correlation alignment for deep domain
            adaptation. In ECCV Workshops, 2016.
    """

    def __init__(
        self,
        assume_centered=False,
    ):
        super().__init__()
        self.assume_centered = assume_centered

    def forward(
        self,
        features_s,
        features_t,
        **kwargs,
    ):
        """Compute the domain adaptation loss.

        Parameters
        ----------
        features_s : object
            Source features, forwarded to ``deepcoral_loss``.
        features_t : object
            Target features, forwarded to ``deepcoral_loss``.
        **kwargs : dict
            Accepted so the criterion can be called with extra keyword
            arguments; they are not used by this loss.

        Returns
        -------
        loss : object
            The value returned by ``deepcoral_loss`` for the given
            features and ``self.assume_centered``.
        """
        return deepcoral_loss(features_s, features_t, self.assume_centered)
 
def DeepCoral(
    module, layer_name, reg=1, assume_centered=False, base_criterion=None, **kwargs
):
    """Build a domain-aware network using the DeepCORAL adaptation loss.

    DeepCORAL domain adaptation method, from [12]_.

    Parameters
    ----------
    module : torch module (class or instance)
        A PyTorch :class:`~torch.nn.Module`.
    layer_name : str
        The name of the module's layer whose outputs are
        collected during the training for the adaptation.
    reg : float, optional (default=1)
        Regularization parameter for DA loss.
    assume_centered : bool, default=False
        If True, data are not centered before computation.
    base_criterion : torch criterion, optional (default=None)
        The base criterion used to compute the loss with source
        labels. If None, a ``torch.nn.CrossEntropyLoss()`` instance
        is created and used.
    **kwargs : dict
        Extra keyword arguments forwarded unchanged to ``DomainAwareNet``.

    Returns
    -------
    net : DomainAwareNet
        The ``DomainAwareNet`` instance built from the arguments above.

    References
    ----------
    .. [12]  Baochen Sun and Kate Saenko. Deep coral:
            Correlation alignment for deep domain
            adaptation. In ECCV Workshops, 2016.
    """
    source_criterion = (
        torch.nn.CrossEntropyLoss() if base_criterion is None else base_criterion
    )
    coral_criterion = DeepCoralLoss(assume_centered=assume_centered)
    return DomainAwareNet(
        module=DomainAwareModule,
        module__base_module=module,
        module__layer_name=layer_name,
        iterator_train=DomainBalancedDataLoader,
        criterion=DomainAwareCriterion,
        criterion__base_criterion=source_criterion,
        criterion__reg=reg,
        criterion__adapt_criterion=coral_criterion,
        **kwargs,
    )
class DANLoss(BaseDALoss):
    """Loss DAN.

    This loss reduces the MMD distance between
    source features and target features.
    From [14]_.

    Parameters
    ----------
    sigmas : array-like, optional (default=None)
        The sigmas for the Gaussian kernel.
    eps : float, default=1e-7
        Small constant added to median distance calculation for numerical stability.

    References
    ----------
    .. [14]  Mingsheng Long et al. Learning Transferable
            Features with Deep Adaptation Networks.
            In ICML, 2015.
    """

    def __init__(self, sigmas=None, eps=1e-7):
        super().__init__()
        self.sigmas = sigmas
        self.eps = eps

    def forward(
        self,
        features_s,
        features_t,
        **kwargs,
    ):
        """Compute the domain adaptation loss.

        Parameters
        ----------
        features_s : object
            Source features, forwarded to ``dan_loss``.
        features_t : object
            Target features, forwarded to ``dan_loss``.
        **kwargs : dict
            Accepted so the criterion can be called with extra keyword
            arguments; they are not used by this loss.

        Returns
        -------
        loss : object
            The value returned by ``dan_loss`` for the given features,
            ``self.sigmas`` and ``self.eps``.
        """
        return dan_loss(features_s, features_t, sigmas=self.sigmas, eps=self.eps)
 
def DAN(module, layer_name, reg=1, sigmas=None, base_criterion=None, **kwargs):
    """Build a domain-aware network using the DAN adaptation loss.

    DAN domain adaptation method, see [14]_.

    Parameters
    ----------
    module : torch module (class or instance)
        A PyTorch :class:`~torch.nn.Module`.
    layer_name : str
        The name of the module's layer whose outputs are
        collected during the training for the adaptation.
    reg : float, optional (default=1)
        Regularization parameter for DA loss.
    sigmas : array-like, optional (default=None)
        The sigmas for the Gaussian kernel.
    base_criterion : torch criterion, optional (default=None)
        The base criterion used to compute the loss with source
        labels. If None, a ``torch.nn.CrossEntropyLoss()`` instance
        is created and used.
    **kwargs : dict
        Extra keyword arguments forwarded unchanged to ``DomainAwareNet``.

    Returns
    -------
    net : DomainAwareNet
        The ``DomainAwareNet`` instance built from the arguments above.

    References
    ----------
    .. [14]  Mingsheng Long et al. Learning Transferable
            Features with Deep Adaptation Networks.
            In ICML, 2015.
    """
    source_criterion = (
        torch.nn.CrossEntropyLoss() if base_criterion is None else base_criterion
    )
    # The adaptation loss keeps its default ``eps``; only ``sigmas`` is set here.
    mmd_criterion = DANLoss(sigmas=sigmas)
    return DomainAwareNet(
        module=DomainAwareModule,
        module__base_module=module,
        module__layer_name=layer_name,
        iterator_train=DomainBalancedDataLoader,
        criterion=DomainAwareCriterion,
        criterion__base_criterion=source_criterion,
        criterion__reg=reg,
        criterion__adapt_criterion=mmd_criterion,
        **kwargs,
    )
class CANLoss(BaseDALoss):
    """Loss for Contrastive Adaptation Network (CAN).

    This loss implements the contrastive domain discrepancy (CDD)
    as described in [33]_.

    Parameters
    ----------
    distance_threshold : float, optional (default=0.5)
        Distance threshold for discarding the samples that are
        too far from the centroids.
    class_threshold : int, optional (default=3)
        Minimum number of samples in a class to be considered for the loss.
    sigmas : array-like, default=None
        If array, sigmas used for the multi gaussian kernel.
        If None, ``None`` is forwarded to ``cdd_loss``, which is then
        responsible for choosing the sigmas.
    target_kmeans : sklearn KMeans instance, default=None
        Pre-computed target KMeans clustering model.
    eps : float, default=1e-7
        Small constant added to median distance calculation for numerical stability.

    References
    ----------
    .. [33] Kang, G., Jiang, L., Yang, Y., & Hauptmann, A. G. (2019).
           Contrastive adaptation network for unsupervised domain adaptation.
           In Proceedings of the IEEE/CVF Conference on Computer Vision
           and Pattern Recognition (pp. 4893-4902).
    """

    def __init__(
        self,
        distance_threshold=0.5,
        class_threshold=3,
        sigmas=None,
        target_kmeans=None,
        eps=1e-7,
    ):
        super().__init__()
        self.distance_threshold = distance_threshold
        self.class_threshold = class_threshold
        self.sigmas = sigmas
        self.target_kmeans = target_kmeans
        self.eps = eps

    def forward(
        self,
        y_s,
        features_s,
        features_t,
        **kwargs,
    ):
        """Compute the domain adaptation loss.

        Parameters
        ----------
        y_s : object
            Presumably the labels of the source samples (inferred from
            the name); forwarded as the first argument of ``cdd_loss``.
        features_s : object
            Source features, forwarded to ``cdd_loss``.
        features_t : object
            Target features, forwarded to ``cdd_loss``.
        **kwargs : dict
            Accepted so the criterion can be called with extra keyword
            arguments; they are not used by this loss.

        Returns
        -------
        loss : object
            The value returned by ``cdd_loss`` for the given inputs and
            the hyper-parameters stored on this instance.
        """
        return cdd_loss(
            y_s,
            features_s,
            features_t,
            sigmas=self.sigmas,
            target_kmeans=self.target_kmeans,
            distance_threshold=self.distance_threshold,
            class_threshold=self.class_threshold,
            eps=self.eps,
        )
 
def CAN(
    module,
    layer_name,
    reg=1,
    distance_threshold=0.5,
    class_threshold=3,
    sigmas=None,
    base_criterion=None,
    callbacks=None,
    **kwargs,
):
    """Contrastive Adaptation Network (CAN) domain adaptation method.

    From [33]_.

    Parameters
    ----------
    module : torch module (class or instance)
        A PyTorch :class:`~torch.nn.Module`.
    layer_name : str
        The name of the module's layer whose outputs are
        collected during the training for the adaptation.
    reg : float, optional (default=1)
        Regularization parameter for DA loss.
    distance_threshold : float, optional (default=0.5)
        Distance threshold for discarding the samples that are
        too far from the centroids.
    class_threshold : int, optional (default=3)
        Minimum number of samples in a class to be considered for the loss.
    sigmas : array-like, default=None
        If array, sigmas used for the multi gaussian kernel.
        If None, ``None`` is forwarded to the adaptation loss, which is
        then responsible for choosing the sigmas.
    base_criterion : torch criterion, optional (default=None)
        The base criterion used to compute the loss with source
        labels. If None, a ``torch.nn.CrossEntropyLoss()`` instance
        is created and used.
    callbacks : list, optional
        List of callbacks to be used during training. A
        ``ComputeSourceCentroids`` callback is always added after them.
        A list is copied, never modified in place; any other non-None
        value is treated as a single callback.
    **kwargs : dict
        Extra keyword arguments forwarded unchanged to ``DomainAwareNet``.

    Returns
    -------
    net : DomainAwareNet
        The ``DomainAwareNet`` instance built from the arguments above.

    References
    ----------
    .. [33] Kang, G., Jiang, L., Yang, Y., & Hauptmann, A. G. (2019).
           Contrastive adaptation network for unsupervised domain adaptation.
           In Proceedings of the IEEE/CVF Conference on Computer Vision
           and Pattern Recognition (pp. 4893-4902).
    """
    if base_criterion is None:
        base_criterion = torch.nn.CrossEntropyLoss()

    centroid_callback = ComputeSourceCentroids()
    if callbacks is None:
        callbacks = [centroid_callback]
    elif isinstance(callbacks, list):
        # Build a new list instead of appending: mutating the caller's list
        # would make the extra callback pile up if that list is reused for
        # several CAN(...) calls.
        callbacks = [*callbacks, centroid_callback]
    else:
        # Any non-list value is kept as one single callback entry.
        callbacks = [callbacks, centroid_callback]

    net = DomainAwareNet(
        module=DomainAwareModule,
        module__base_module=module,
        module__layer_name=layer_name,
        iterator_train=DomainBalancedDataLoader,
        criterion=DomainAwareCriterion,
        criterion__base_criterion=base_criterion,
        criterion__reg=reg,
        criterion__adapt_criterion=CANLoss(
            distance_threshold=distance_threshold,
            class_threshold=class_threshold,
            sigmas=sigmas,
        ),
        callbacks=callbacks,
        **kwargs,
    )
    return net