skada.deep.losses.deepjdot_loss

skada.deep.losses.deepjdot_loss(y_s, y_pred_t, features_s, features_t, reg_dist, reg_cl, sample_weights=None, target_sample_weights=None, criterion=None)

Compute the optimal transport (OT) loss for the DeepJDOT method [13].
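Following the DeepJDOT formulation of [13], the loss can be sketched as below. This assumes that the function builds an OT cost matrix from a squared Euclidean feature distance plus the criterion applied to each (source label, target prediction) pair, then solves an exact OT problem between the source and target sample weights. The symbols a (source weights), b (target weights), f^s, f^t (features), M (cost matrix) and γ (transport plan) are introduced here for illustration only.

```latex
M_{ij} = \mathrm{reg\_dist}\,\lVert f^s_i - f^t_j \rVert_2^2
       + \mathrm{reg\_cl}\,\mathcal{L}\big(y^s_i, \hat{y}^t_j\big)

\gamma^\star = \operatorname*{arg\,min}_{\gamma \in \Pi(a, b)} \langle \gamma, M \rangle_F,
\qquad
\Pi(a, b) = \{\gamma \ge 0 : \gamma \mathbf{1} = a,\ \gamma^\top \mathbf{1} = b\}

\mathrm{loss} = \langle \gamma^\star, M \rangle_F = \sum_{i,j} \gamma^\star_{ij} M_{ij}
```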

Parameters:
y_s : tensor

Labels of the source data, used to compute the OT cost matrix.

y_pred_t : tensor

Predictions for the target data, used to compute the OT cost matrix.

features_s : tensor

Features of the source data, used to compute the OT cost matrix.

features_t : tensor

Features of the target data, used to compute the OT cost matrix.

reg_dist : float

Regularization parameter (weight) of the divergence term, i.e. the feature-distance term of the cost matrix.

reg_cl : float

Regularization parameter (weight) of the class-distance term of the cost matrix.

sample_weights : tensor, default=None

Weights of the source samples. If None, uniform weights are used.

target_sample_weights : tensor, default=None

Weights of the target samples. If None, uniform weights are used.

criterion : torch criterion (class), default=None

The criterion (loss function) used to compute the class-distance term of the DeepJDOT loss. If None, CrossEntropyLoss is used.

Returns:
loss : ndarray

The value of the DeepJDOT loss.
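To make the computation concrete, here is a minimal NumPy/SciPy sketch of the same steps. It is not skada's implementation (which works on torch tensors), and it rests on stated assumptions: `y_s` holds integer class indices, `y_pred_t` holds unnormalized logits, the criterion is cross-entropy (the documented default), the feature distance is squared Euclidean, and the exact OT plan is obtained by solving a linear program with `scipy.optimize.linprog` instead of a dedicated OT solver. The helper names `exact_ot_plan` and `deepjdot_loss_sketch` are hypothetical.

```python
import numpy as np
from scipy.optimize import linprog


def exact_ot_plan(a, b, M):
    """Hypothetical helper: solve min <gamma, M> s.t. gamma 1 = a, gamma^T 1 = b, gamma >= 0.

    Stand-in for an exact OT solver, written as a plain linear program.
    """
    n_s, n_t = M.shape
    # gamma is flattened row-major: variable index i * n_t + j.
    # Row-sum constraints: sum_j gamma_ij = a_i
    A_rows = np.kron(np.eye(n_s), np.ones((1, n_t)))
    # Column-sum constraints: sum_i gamma_ij = b_j
    A_cols = np.kron(np.ones((1, n_s)), np.eye(n_t))
    res = linprog(
        M.ravel(),
        A_eq=np.vstack([A_rows, A_cols]),
        b_eq=np.concatenate([a, b]),
        bounds=(0, None),
        method="highs",
    )
    if not res.success:
        raise RuntimeError(res.message)
    return res.x.reshape(n_s, n_t)


def deepjdot_loss_sketch(y_s, y_pred_t, features_s, features_t, reg_dist, reg_cl,
                         sample_weights=None, target_sample_weights=None):
    """Hypothetical NumPy sketch of a DeepJDOT-style OT loss (cross-entropy criterion)."""
    # Squared Euclidean distance between every source/target feature pair: (n_s, n_t)
    dist = ((features_s[:, None, :] - features_t[None, :, :]) ** 2).sum(axis=-1)
    # Log-softmax of the target logits, shifted by the row max for numerical stability
    z = y_pred_t - y_pred_t.max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    # Cross-entropy of target prediction j against source label i: (n_s, n_t)
    class_cost = -log_probs[:, y_s].T
    # Ground cost combining the feature-distance and class-distance terms
    M = reg_dist * dist + reg_cl * class_cost
    n_s, n_t = M.shape
    # Uniform weights when none are given, as documented above
    a = np.full(n_s, 1.0 / n_s) if sample_weights is None else sample_weights
    b = np.full(n_t, 1.0 / n_t) if target_sample_weights is None else target_sample_weights
    gamma = exact_ot_plan(a, b, M)
    # Loss = total transport cost <gamma, M>
    return float(np.sum(gamma * M))
```

With identical source and target features and confident, correct target predictions, the optimal plan pairs each source sample with its matching target sample, so the loss is close to zero. Raising `reg_cl` relative to `reg_dist` makes label disagreement more costly than feature distance when the plan is chosen.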

References

[13]

Bharath Bhushan Damodaran, Benjamin Kellenberger, Rémi Flamary, Devis Tuia, and Nicolas Courty. DeepJDOT: Deep Joint Distribution Optimal Transport for Unsupervised Domain Adaptation. In ECCV 2018, 15th European Conference on Computer Vision, September 2018. Springer.