skada.deep.DeepJDOT
- skada.deep.DeepJDOT(module, layer_name, reg_dist=1, reg_cl=1, base_criterion=None, target_criterion=None, **kwargs)
DeepJDOT.
See [13].
- Parameters:
- module : torch module (class or instance)
A PyTorch Module.
- layer_name : str
The name of the module's layer whose outputs are collected during the training for the adaptation.
- reg_dist : float, default=1
Regularization parameter for the distance term of the domain adaptation (DA) loss.
- reg_cl : float, default=1
Class distance term regularization parameter.
- base_criterion : torch criterion (class), default=None
The base criterion used to compute the loss with source labels. If None, the default is torch.nn.CrossEntropyLoss.
- target_criterion : torch criterion (class), default=None
The uninitialized criterion (loss) used to compute the DeepJDOT loss. The criterion should support reduction='none'.
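The requirement that target_criterion support reduction='none' can be illustrated with a small PyTorch sketch. DeepJDOT [13] couples every source sample with every target sample, so the class term needs one loss value per (source, target) pair rather than a single averaged scalar. The helper below, pairwise_jdot_cost, is a hypothetical name used only for illustration; it is not skada's implementation, and the way reg_dist and reg_cl weight the two terms here is an assumption based on the parameter descriptions above.

```python
import torch
import torch.nn as nn


def pairwise_jdot_cost(feat_s, feat_t, y_s, logits_t,
                       criterion_cls=nn.CrossEntropyLoss,
                       reg_dist=1.0, reg_cl=1.0):
    """Hypothetical helper: cost for every (source, target) sample pair."""
    n_s, n_t = feat_s.shape[0], feat_t.shape[0]

    # Squared Euclidean distance between features, shape (n_s, n_t).
    dist = torch.cdist(feat_s, feat_t, p=2) ** 2

    # The criterion is passed uninitialized (as a class) and instantiated
    # with reduction="none" so that it returns one value per sample.
    criterion = criterion_cls(reduction="none")

    # Pair each source label with each target prediction:
    # flat index i * n_t + j corresponds to source i and target j.
    logits_rep = logits_t.repeat(n_s, 1)        # (n_s * n_t, n_classes)
    labels_rep = y_s.repeat_interleave(n_t)     # (n_s * n_t,)
    class_loss = criterion(logits_rep, labels_rep).reshape(n_s, n_t)

    # Assumed weighting of the two terms (illustration only).
    return reg_dist * dist + reg_cl * class_loss


torch.manual_seed(0)
feat_s, feat_t = torch.randn(4, 8), torch.randn(5, 8)  # layer outputs
y_s = torch.tensor([0, 1, 2, 1])                       # source labels
logits_t = torch.randn(5, 3)                           # target predictions
cost = pairwise_jdot_cost(feat_s, feat_t, y_s, logits_t)
print(cost.shape)  # torch.Size([4, 5])
```

A criterion that only returns a mean or sum over the batch cannot be reshaped into such a pairwise matrix, which is presumably why per-sample reduction is required.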
References
[13] Bharath Bhushan Damodaran, Benjamin Kellenberger, Remi Flamary, Devis Tuia, and Nicolas Courty. DeepJDOT: Deep Joint Distribution Optimal Transport for Unsupervised Domain Adaptation. In ECCV 2018, 15th European Conference on Computer Vision, September 2018. Springer.
Examples using skada.deep.DeepJDOT
Optimal transport domain adaptation methods.