skada.deep.DeepJDOTLoss

class skada.deep.DeepJDOTLoss(reg_dist=1, reg_cl=1, target_criterion=None)

DeepJDOT loss.

This loss reduces the distance between the source and target domains by minimizing an optimal-transport-based discrepancy between the joint distributions of deep representations and labels. See [13].
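To make the idea concrete, here is a minimal NumPy sketch of the optimal transport term described in [13]. It is not skada's implementation, and several points are assumptions. The ground cost between source sample i and target sample j is taken to be `reg_dist * ||feat_s[i] - feat_t[j]||^2 + reg_cl * CE(y_s[i], logits_t[j])`, mirroring the paper's alpha and lambda_t weights. The exact OT plan is found by brute force over permutations, which is exact only for equal-size batches with uniform weights and feasible only for tiny batches; a practical implementation would call an OT solver such as POT's `ot.emd`. The helper names (`pairwise_sq_dist`, `pairwise_cross_entropy`, `exact_uniform_ot_plan`, `deepjdot_ot_term`) are hypothetical.

```python
import itertools

import numpy as np


def pairwise_sq_dist(a, b):
    # ||a_i - b_j||^2 for every source/target pair -> shape (n_s, n_t)
    return ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)


def pairwise_cross_entropy(y_s, logits_t):
    # -log softmax(logits_t[j])[y_s[i]] for every pair -> shape (n_s, n_t)
    z = logits_t - logits_t.max(axis=1, keepdims=True)  # numerical stability
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))  # (n_t, k)
    return -log_probs[:, y_s].T


def exact_uniform_ot_plan(cost):
    # With n source and n target samples and uniform weights 1/n, some optimal
    # plan is a permutation matrix scaled by 1/n (Birkhoff-von Neumann), so
    # enumerating permutations is exact. Only viable for very small n.
    n = cost.shape[0]
    assert cost.shape == (n, n), "sketch assumes equal batch sizes"
    rows = np.arange(n)
    best = min(itertools.permutations(range(n)),
               key=lambda p: cost[rows, np.array(p)].sum())
    plan = np.zeros_like(cost)
    plan[rows, np.array(best)] = 1.0 / n
    return plan


def deepjdot_ot_term(y_s, logits_t, feat_s, feat_t, reg_dist=1.0, reg_cl=1.0):
    # Joint ground cost on (feature, label) pairs, as assumed above.
    cost = (reg_dist * pairwise_sq_dist(feat_s, feat_t)
            + reg_cl * pairwise_cross_entropy(y_s, logits_t))
    # In DeepJDOT training the plan is treated as a constant (no gradient).
    plan = exact_uniform_ot_plan(cost)
    return float((plan * cost).sum()), plan
```

For example, with a batch of four source and four target samples:

```python
rng = np.random.default_rng(0)
feat_s, feat_t = rng.normal(size=(4, 3)), rng.normal(size=(4, 3))
y_s, logits_t = rng.integers(0, 2, size=4), rng.normal(size=(4, 2))
loss, plan = deepjdot_ot_term(y_s, logits_t, feat_s, feat_t)
```

Raising `reg_cl` relative to `reg_dist` makes the coupling favor pairs whose target predictions agree with the source labels over pairs that are merely close in feature space.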

Parameters:
reg_dist : float, default=1

Regularization parameter for the divergence (feature-distance) term.

reg_cl : float, default=1

Regularization parameter for the class distance term.

target_criterion : torch criterion (class), default=None

The uninitialized criterion (loss) used to compute the DeepJDOT loss. The criterion should support reduction='none'.
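The reduction='none' requirement matters because the OT cost matrix needs one loss value per (source, target) pair, whereas a reduced criterion returns a single scalar. The sketch below illustrates this shape logic in NumPy. It is an illustration under assumptions, not skada's code: `cross_entropy` is a hypothetical stand-in for a torch-style criterion with a `reduction` argument, and `pairwise_label_cost` is a hypothetical helper that pairs every source label with every target prediction.

```python
import numpy as np


def cross_entropy(logits, labels, reduction="mean"):
    # Hypothetical NumPy stand-in for a torch-style classification criterion.
    z = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    per_sample = -log_probs[np.arange(len(labels)), labels]
    if reduction == "none":
        return per_sample  # one loss per row, needed to fill a cost matrix
    return per_sample.mean()  # a scalar: cannot be reshaped to (n_s, n_t)


def pairwise_label_cost(y_s, logits_t, criterion):
    n_s, n_t = len(y_s), len(logits_t)
    labels = np.repeat(y_s, n_t)          # each source label repeated n_t times
    logits = np.tile(logits_t, (n_s, 1))  # all target rows, once per source
    losses = criterion(logits, labels, reduction="none")
    return losses.reshape(n_s, n_t)       # entry [i, j]: source i vs target j
```

Calling `pairwise_label_cost(np.array([0, 2, 1]), logits_t, cross_entropy)` with `logits_t` of shape (4, 3) yields a (3, 4) matrix; with `reduction="mean"` the call would produce a scalar and the reshape would fail.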

References

[13] Bharath Bhushan Damodaran, Benjamin Kellenberger, Remi Flamary, Devis Tuia, and Nicolas Courty. DeepJDOT: Deep Joint Distribution Optimal Transport for Unsupervised Domain Adaptation. In ECCV 2018, 15th European Conference on Computer Vision, September 2018. Springer.

forward(y_s, y_pred_t, features_s, features_t, **kwargs)

Compute the domain adaptation loss.