skada.deep.DeepJDOT

skada.deep.DeepJDOT(module, layer_name, reg=1, reg_cl=1, target_criterion=None, **kwargs)

DeepJDOT: Deep Joint Distribution Optimal Transport for unsupervised domain adaptation.

See [13].

Parameters:
module : torch module (class or instance)

A PyTorch Module.

layer_name : str

The name of the module's layer whose outputs are collected during training for the adaptation.
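
As an illustration of how a layer's outputs can be collected by name, the sketch below uses a plain PyTorch forward hook. It shows the general mechanism only and is not taken from skada's internals; the TinyNet model and the layer name "feature_extractor" are hypothetical.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical two-stage network, used only for illustration."""

    def __init__(self):
        super().__init__()
        self.feature_extractor = nn.Sequential(nn.Linear(4, 3), nn.ReLU())
        self.classifier = nn.Linear(3, 2)

    def forward(self, x):
        return self.classifier(self.feature_extractor(x))


model = TinyNet()
collected = {}


def save_output(module, inputs, output):
    # Store the intermediate activations produced by the hooked layer.
    collected["features"] = output


# Look the layer up by its attribute name and attach the hook to it.
layer = model.get_submodule("feature_extractor")
handle = layer.register_forward_hook(save_output)

logits = model(torch.randn(5, 4))
handle.remove()  # detach the hook once it is no longer needed
```

After the forward pass, collected["features"] holds the output of the named layer for the batch, which is the kind of intermediate representation an adaptation loss can operate on.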

reg : float, default=1

Regularization parameter.

reg_cl : float, default=1

Regularization parameter for the class distance term.

target_criterion : torch criterion (class), default=None

The uninitialized criterion (loss) used to compute the DeepJDOT loss. The criterion should support reduction='none'.
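
In the formulation of [13], the transport cost between a source sample and a target sample combines a squared feature distance with a label loss evaluated for that specific pair, which is presumably why an unreduced (per-element) criterion is needed. The NumPy sketch below illustrates that cost matrix under stated assumptions: the function names and the weight alpha are hypothetical, cross-entropy stands in for the criterion, and a uniform coupling is used as a placeholder for the transport plan that an optimal transport solver would provide. It is not skada's implementation.

```python
import numpy as np


def pairwise_cross_entropy(y_source, proba_target):
    """Cross-entropy of every source label against every target prediction.

    Returns an (n_source, n_target) matrix, i.e. the unreduced loss.
    """
    log_p = np.log(proba_target + 1e-12)  # shape (n_target, n_classes)
    return -log_p[:, y_source].T          # shape (n_source, n_target)


def deepjdot_cost(z_source, z_target, y_source, proba_target,
                  alpha=1.0, reg_cl=1.0):
    """Pairwise cost: alpha * squared feature distance + reg_cl * label loss."""
    diff = z_source[:, None, :] - z_target[None, :, :]
    feature_dist = (diff ** 2).sum(axis=-1)
    label_loss = pairwise_cross_entropy(y_source, proba_target)
    return alpha * feature_dist + reg_cl * label_loss


# Toy data (made up for illustration): 2 source and 3 target samples.
z_source = np.array([[0.0, 0.0], [1.0, 0.0]])
z_target = np.array([[0.0, 1.0], [1.0, 1.0], [2.0, 0.0]])
y_source = np.array([0, 1])
proba_target = np.array([[0.9, 0.1], [0.2, 0.8], [0.5, 0.5]])

cost = deepjdot_cost(z_source, z_target, y_source, proba_target)

# Placeholder coupling: uniform mass on every pair. In practice the coupling
# would be the optimal transport plan computed from `cost`.
gamma = np.full(cost.shape, 1.0 / cost.size)
loss = float((gamma * cost).sum())
```

Because every entry of the cost matrix needs its own loss value, a criterion that averages or sums over the batch would not provide the required (n_source, n_target) structure.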

References

[13] Bharath Bhushan Damodaran, Benjamin Kellenberger, Remi Flamary, Devis Tuia, and Nicolas Courty. DeepJDOT: Deep Joint Distribution Optimal Transport for Unsupervised Domain Adaptation. In ECCV 2018 15th European Conference on Computer Vision, September 2018. Springer.

Examples using skada.deep.DeepJDOT

Optimal transport domain adaptation methods.