skada.deep.DeepJDOT

skada.deep.DeepJDOT(module, layer_name, reg_dist=1, reg_cl=1, base_criterion=None, target_criterion=None, **kwargs)

DeepJDOT (Deep Joint Distribution Optimal Transport) domain adaptation method.

See [13].
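
For orientation, the per-minibatch objective given in [13] is written out below. Here g is the feature extractor (its output plays the role of the layer_name activations), f is the classifier head, and γ is a coupling between the n_s source and n_t target samples with uniform marginals. L_s corresponds to base_criterion and L_t to target_criterion. Reading α as reg_dist and λ_t as reg_cl is an assumption inferred from the parameter descriptions below, not checked against the skada source.

```latex
\min_{\gamma \in \Pi(\mu_s, \mu_t),\; g,\; f}\;
  \frac{1}{n_s} \sum_{i=1}^{n_s} L_s\big(y_i^s,\, f(g(x_i^s))\big)
  + \sum_{i=1}^{n_s} \sum_{j=1}^{n_t} \gamma_{ij}
    \Big( \alpha \,\lVert g(x_i^s) - g(x_j^t) \rVert_2^2
        + \lambda_t \, L_t\big(y_i^s,\, f(g(x_j^t))\big) \Big)
```

In [13] the problem is solved by alternating two steps. First, γ is computed with g and f fixed, as an exact OT problem on the bracketed cost. Then g and f are updated by gradient descent with γ held fixed.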

Parameters:
module : torch module (class or instance)

A PyTorch Module.

layer_name : str

The name of the module's layer whose outputs are collected during training for the adaptation.

reg_dist : float, default=1

Regularization parameter for the distance (embedding) term of the domain adaptation (DA) loss.

reg_cl : float, default=1

Class distance term regularization parameter.

base_criterion : torch criterion (class)

The base criterion used to compute the loss with source labels. If None, the default is torch.nn.CrossEntropyLoss.

target_criterion : torch criterion (class)

The uninitialized criterion (loss) used to compute the DeepJDOT loss. The criterion should support reduction='none'.
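
The NumPy sketch below illustrates the per-batch DeepJDOT alignment term from [13]. It builds a joint cost that mixes squared embedding distances (weighted by reg_dist) with a per-pair target loss (weighted by reg_cl), computes an exact OT plan, and returns the plan-weighted loss. The sketch makes several assumptions and is not skada's implementation. All function names are hypothetical. The pairwise cross-entropy stands in for a target_criterion used with reduction='none', which is why that reduction is required: the loss must yield one value per (source, target) pair. The brute-force permutation search replaces a real OT solver and is exact only for tiny, equal-size batches with uniform weights.

```python
import itertools

import numpy as np


def cross_entropy_matrix(y_source, target_logits):
    """Pairwise cross-entropy: entry (i, j) = -log p_j(y_i).

    Hypothetical stand-in for a target criterion with reduction='none':
    source label i is scored against the prediction for target sample j.
    """
    # Numerically stable log-softmax over the class axis, shape (n_t, C).
    shifted = target_logits - target_logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # log_probs[:, y_source] has shape (n_t, n_s); transpose to (n_s, n_t).
    return -log_probs[:, y_source].T


def joint_cost(emb_s, emb_t, y_s, logits_t, reg_dist=1.0, reg_cl=1.0):
    """Joint DeepJDOT-style cost between every source/target pair."""
    # Squared Euclidean distances between embeddings, shape (n_s, n_t).
    feat = ((emb_s[:, None, :] - emb_t[None, :, :]) ** 2).sum(axis=2)
    return reg_dist * feat + reg_cl * cross_entropy_matrix(y_s, logits_t)


def exact_uniform_plan(cost):
    """Exact OT plan for equal-size batches with uniform weights.

    With n source and n target samples of mass 1/n, some optimal plan is a
    permutation matrix scaled by 1/n, so trying every permutation is exact.
    Only practical for very small n.
    """
    n, m = cost.shape
    assert n == m, "this brute-force sketch needs equal batch sizes"
    rows = np.arange(n)
    best = min(itertools.permutations(range(n)),
               key=lambda p: cost[rows, list(p)].sum())
    plan = np.zeros_like(cost)
    plan[rows, list(best)] = 1.0 / n
    return plan


def deepjdot_alignment_loss(emb_s, emb_t, y_s, logits_t,
                            reg_dist=1.0, reg_cl=1.0):
    """Plan-weighted joint cost; returns (loss, plan)."""
    cost = joint_cost(emb_s, emb_t, y_s, logits_t, reg_dist, reg_cl)
    # In DeepJDOT the plan is treated as a constant when updating the network.
    plan = exact_uniform_plan(cost)
    return float((plan * cost).sum()), plan
```

As a usage example, when the target embeddings are an exact permutation of the source embeddings and reg_cl=0, the plan recovers that permutation and the loss is zero. A practical implementation would use a proper OT solver instead of enumerating permutations.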

References

[13] Bharath Bhushan Damodaran, Benjamin Kellenberger, Remi Flamary, Devis Tuia, and Nicolas Courty. DeepJDOT: Deep Joint Distribution Optimal Transport for Unsupervised Domain Adaptation. In ECCV 2018, 15th European Conference on Computer Vision, September 2018. Springer.

Examples using skada.deep.DeepJDOT

Optimal transport domain adaptation methods.