# Author: Theo Gnassounou <theo.gnassounou@inria.fr>
# Remi Flamary <remi.flamary@polytechnique.edu>
#
# License: BSD 3-Clause
import torch
from skada.deep.base import (
BaseDALoss,
DomainAwareCriterion,
DomainAwareModule,
DomainAwareNet,
DomainBalancedDataLoader,
)
from .losses import dan_loss, deepcoral_loss
class DeepCoralLoss(BaseDALoss):
    """Loss DeepCORAL

    This loss reduces the distance between covariances
    of the source features and the target features.
    See [12]_.

    Parameters
    ----------
    assume_centered : bool, default=False
        If True, data are not centered before computation.

    References
    ----------
    .. [12] Baochen Sun and Kate Saenko. Deep CORAL:
            Correlation alignment for deep domain
            adaptation. In ECCV Workshops, 2016.
    """

    def __init__(self, assume_centered=False):
        super().__init__()
        # Forwarded unchanged to ``deepcoral_loss`` on every forward pass.
        self.assume_centered = assume_centered

    def forward(
        self,
        y_s,
        y_pred_s,
        y_pred_t,
        domain_pred_s,
        domain_pred_t,
        features_s,
        features_t,
    ):
        """Compute the domain adaptation loss.

        Only ``features_s`` and ``features_t`` contribute to the result;
        the labels, predictions and domain predictions are accepted but
        not used by this loss.

        Returns
        -------
        loss
            The value returned by ``deepcoral_loss`` for the source and
            target features.
        """
        return deepcoral_loss(features_s, features_t, self.assume_centered)
def DeepCoral(
    module, layer_name, reg=1, assume_centered=False, base_criterion=None, **kwargs
):
    """DeepCORAL domain adaptation method.

    From [12]_.

    Parameters
    ----------
    module : torch module (class or instance)
        A PyTorch :class:`~torch.nn.Module`.
    layer_name : str
        The name of the module's layer whose outputs are
        collected during the training for the adaptation.
    reg : float, optional (default=1)
        Regularization parameter for DA loss.
    assume_centered : bool, default=False
        If True, data are not centered before computation.
    base_criterion : torch criterion (instance), optional (default=None)
        The base criterion used to compute the loss with source
        labels. If None, an instance of `torch.nn.CrossEntropyLoss`
        is created and used.
    **kwargs : dict
        Additional keyword arguments forwarded unchanged to
        :class:`DomainAwareNet`.

    Returns
    -------
    net : DomainAwareNet
        A net whose module is a :class:`DomainAwareModule` wrapping
        ``module`` (collecting the outputs of ``layer_name``), trained with
        a :class:`DomainBalancedDataLoader`, and whose criterion is a
        :class:`DomainAwareCriterion` built from ``base_criterion``,
        ``reg`` and a :class:`DeepCoralLoss`.

    References
    ----------
    .. [12] Baochen Sun and Kate Saenko. Deep CORAL:
            Correlation alignment for deep domain
            adaptation. In ECCV Workshops, 2016.
    """
    if base_criterion is None:
        # The default is an instance, not a class: it is passed as-is
        # to the criterion below.
        base_criterion = torch.nn.CrossEntropyLoss()

    net = DomainAwareNet(
        module=DomainAwareModule,
        module__base_module=module,
        module__layer_name=layer_name,
        iterator_train=DomainBalancedDataLoader,
        criterion=DomainAwareCriterion,
        criterion__base_criterion=base_criterion,
        criterion__reg=reg,
        criterion__adapt_criterion=DeepCoralLoss(assume_centered=assume_centered),
        **kwargs,
    )
    return net
class DANLoss(BaseDALoss):
    """Loss DAN

    This loss reduces the MMD distance between
    source features and target features.
    From [14]_.

    Parameters
    ----------
    sigmas : array-like, optional (default=None)
        The sigmas for the Gaussian kernel. The value (including None)
        is passed through unchanged to ``dan_loss``.

    References
    ----------
    .. [14] Mingsheng Long et al. Learning Transferable
            Features with Deep Adaptation Networks.
            In ICML, 2015.
    """

    def __init__(self, sigmas=None):
        super().__init__()
        self.sigmas = sigmas

    def forward(
        self,
        y_s,
        y_pred_s,
        y_pred_t,
        domain_pred_s,
        domain_pred_t,
        features_s,
        features_t,
    ):
        """Compute the domain adaptation loss.

        Only ``features_s`` and ``features_t`` contribute to the result;
        the labels, predictions and domain predictions are accepted but
        not used by this loss.

        Returns
        -------
        loss
            The value returned by ``dan_loss`` for the source and target
            features with the configured ``sigmas``.
        """
        return dan_loss(features_s, features_t, sigmas=self.sigmas)
def DAN(module, layer_name, reg=1, sigmas=None, base_criterion=None, **kwargs):
    """DAN domain adaptation method.

    See [14]_.

    Parameters
    ----------
    module : torch module (class or instance)
        A PyTorch :class:`~torch.nn.Module`.
    layer_name : str
        The name of the module's layer whose outputs are
        collected during the training for the adaptation.
    reg : float, optional (default=1)
        Regularization parameter for DA loss.
    sigmas : array-like, optional (default=None)
        The sigmas for the Gaussian kernel.
    base_criterion : torch criterion (instance), optional (default=None)
        The base criterion used to compute the loss with source
        labels. If None, an instance of `torch.nn.CrossEntropyLoss`
        is created and used.
    **kwargs : dict
        Additional keyword arguments forwarded unchanged to
        :class:`DomainAwareNet`.

    Returns
    -------
    net : DomainAwareNet
        A net whose module is a :class:`DomainAwareModule` wrapping
        ``module`` (collecting the outputs of ``layer_name``), trained with
        a :class:`DomainBalancedDataLoader`, and whose criterion is a
        :class:`DomainAwareCriterion` built from ``base_criterion``,
        ``reg`` and a :class:`DANLoss`.

    References
    ----------
    .. [14] Mingsheng Long et al. Learning Transferable
            Features with Deep Adaptation Networks.
            In ICML, 2015.
    """
    if base_criterion is None:
        # The default is an instance, not a class: it is passed as-is
        # to the criterion below.
        base_criterion = torch.nn.CrossEntropyLoss()

    net = DomainAwareNet(
        module=DomainAwareModule,
        module__base_module=module,
        module__layer_name=layer_name,
        iterator_train=DomainBalancedDataLoader,
        criterion=DomainAwareCriterion,
        criterion__base_criterion=base_criterion,
        criterion__reg=reg,
        criterion__adapt_criterion=DANLoss(sigmas=sigmas),
        **kwargs,
    )
    return net