skada.deep.losses.deepjdot_loss
- skada.deep.losses.deepjdot_loss(y_s, y_pred_t, features_s, features_t, reg_dist, reg_cl, sample_weights=None, target_sample_weights=None, criterion=None)
Compute the optimal transport (OT) loss for the DeepJDOT method [13].
- Parameters:
- y_s : tensor
Labels of the source data, used to compute the distance matrix.
- y_pred_t : tensor
Predictions for the target data, used to compute the distance matrix.
- features_s : tensor
Features of the source data, used to compute the distance matrix.
- features_t : tensor
Features of the target data, used to compute the distance matrix.
- reg_dist : float
Regularization parameter for the divergence term.
- reg_cl : float
Regularization parameter for the class distance term.
- sample_weights : tensor, default=None
Weights of the source samples. If None, uniform weights are created.
- target_sample_weights : tensor, default=None
Weights of the target samples. If None, uniform weights are created.
- criterion : torch criterion (class), default=None
The criterion (loss) used to compute the DeepJDOT loss. If None, CrossEntropyLoss is used.
- Returns:
- loss : ndarray
The loss of the method.
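Notes
The idea behind this loss can be illustrated outside of skada. The sketch below is not skada's implementation. It uses NumPy arrays instead of torch tensors, assumes integer class labels for y_s and raw logits for y_pred_t, and replaces the OT solver with a brute-force search over permutations, which is only valid for tiny, equal-size batches with uniform sample weights. The helper names joint_cost and brute_force_plan are hypothetical.

```python
import itertools

import numpy as np


def joint_cost(y_s, y_pred_t, features_s, features_t, reg_dist, reg_cl):
    """Pairwise cost: weighted feature distance plus weighted label loss."""
    # Squared Euclidean distance between every source/target feature pair.
    diff = features_s[:, None, :] - features_t[None, :, :]
    dist = (diff ** 2).sum(axis=-1)  # shape (n_s, n_t)

    # Log-softmax of the target logits (shifted for numerical stability).
    shifted = y_pred_t - y_pred_t.max(axis=1, keepdims=True)
    log_prob_t = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

    # Cross-entropy of target prediction j against source label i.
    loss_cl = -log_prob_t[:, y_s].T  # shape (n_s, n_t)

    return reg_dist * dist + reg_cl * loss_cl


def brute_force_plan(cost):
    """Optimal coupling for uniform weights and a square cost matrix.

    Assumption: with uniform weights and equal batch sizes, an optimal
    plan is a permutation scaled by 1/n, so a small problem can be solved
    by enumeration. A real implementation would call an OT solver.
    """
    n = cost.shape[0]
    rows = np.arange(n)
    best = min(
        itertools.permutations(range(n)),
        key=lambda perm: cost[rows, list(perm)].sum(),
    )
    plan = np.zeros_like(cost)
    plan[rows, list(best)] = 1.0 / n
    return plan


# Toy batch: three source samples with labels, three target samples with logits.
features_s = np.array([[0.0, 0.0], [1.0, 1.0], [4.0, 4.0]])
y_s = np.array([0, 1, 1])
features_t = np.array([[4.1, 3.9], [0.1, 0.0], [1.0, 1.2]])
y_pred_t = np.array([[0.0, 3.0], [3.0, 0.0], [0.0, 3.0]])

cost = joint_cost(y_s, y_pred_t, features_s, features_t, reg_dist=1.0, reg_cl=1.0)
plan = brute_force_plan(cost)

# The loss is the total cost transported under the optimal plan.
loss = (plan * cost).sum()
print(plan.argmax(axis=1), loss)
```

In this toy batch each source sample is matched to the target sample that is close in feature space and whose prediction agrees with the source label, which is the coupling behaviour the two regularization parameters trade off.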
References
[13] Bharath Bhushan Damodaran, Benjamin Kellenberger, Remi Flamary, Devis Tuia, and Nicolas Courty. DeepJDOT: Deep Joint Distribution Optimal Transport for Unsupervised Domain Adaptation. In ECCV 2018: 15th European Conference on Computer Vision, September 2018. Springer.