skada.deep.CANLoss

class skada.deep.CANLoss(distance_threshold=0.5, class_threshold=3, sigmas=None, target_kmeans=None, eps=1e-07)

Loss for the Contrastive Adaptation Network (CAN).

This loss implements the contrastive domain discrepancy (CDD) as described in [33].
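As we read [33], the CDD contrasts two quantities. With M classes and D^{cc'} the squared MMD between source samples of class c and target samples pseudo-labelled c', CDD = (1/M) Σ_c D^{cc} - 1/(M(M-1)) Σ_c Σ_{c'≠c} D^{cc'}, so minimising it shrinks the same-class discrepancy while keeping different classes apart. The NumPy sketch below illustrates that formula only and is not skada's implementation. It assumes target pseudo-labels are already available, averages Gaussian kernels over the given sigmas, and, when sigmas is None, falls back to a median-distance heuristic (median pairwise distance plus eps) with hypothetical multipliers (0.5, 1, 2). All function names are our own.

```python
import numpy as np


def multi_gaussian_kernel(x, y, sigmas):
    """Average of Gaussian kernels exp(-||x - y||^2 / (2 * sigma^2)) over all sigmas."""
    sq_dist = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
    return np.mean([np.exp(-sq_dist / (2.0 * s ** 2)) for s in sigmas], axis=0)


def default_sigmas(features, eps=1e-7, multipliers=(0.5, 1.0, 2.0)):
    """Hypothetical median heuristic: scale the median pairwise distance (plus eps)."""
    dist = np.sqrt(((features[:, None, :] - features[None, :, :]) ** 2).sum(axis=-1))
    upper = np.triu_indices(len(features), k=1)  # each pair once, no diagonal
    median = np.median(dist[upper]) + eps  # eps keeps the bandwidth strictly positive
    return [m * median for m in multipliers]


def class_discrepancy(xs, xt, sigmas):
    """Squared MMD between one source class and one target (pseudo-)class."""
    e1 = multi_gaussian_kernel(xs, xs, sigmas).mean()
    e2 = multi_gaussian_kernel(xt, xt, sigmas).mean()
    e3 = multi_gaussian_kernel(xs, xt, sigmas).mean()
    return e1 + e2 - 2.0 * e3


def cdd(features_s, y_s, features_t, y_t, sigmas=None, eps=1e-7):
    """Intra-class discrepancy minus inter-class discrepancy (illustrative sketch).

    y_t holds target pseudo-labels, assumed to be computed beforehand.
    """
    if sigmas is None:
        sigmas = default_sigmas(np.vstack([features_s, features_t]), eps=eps)
    # Only classes present on both sides can be compared.
    classes = np.intersect1d(np.unique(y_s), np.unique(y_t))
    intra = np.mean([
        class_discrepancy(features_s[y_s == c], features_t[y_t == c], sigmas)
        for c in classes
    ])
    if len(classes) < 2:
        return intra  # no pair of distinct classes to contrast
    # Mean over the M * (M - 1) ordered pairs (c, c2) with c != c2.
    inter = np.mean([
        class_discrepancy(features_s[y_s == c], features_t[y_t == c2], sigmas)
        for c in classes for c2 in classes if c2 != c
    ])
    return intra - inter
```

Averaging the inter-class term over all ordered pairs gives the 1/(M(M-1)) normalisation. On well-separated toy data, correctly aligned pseudo-labels yield a negative value, while swapped pseudo-labels make it positive.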

Parameters:
distance_threshold : float, optional (default=0.5)

Distance threshold for discarding samples that are too far from the centroids.

class_threshold : int, optional (default=3)

Minimum number of samples a class must contain to be considered in the loss.

sigmas : array-like, default=None

If an array, the sigmas used for the multi-Gaussian kernel. If None, the sigmas proposed in [1] are used.

target_kmeans : sklearn KMeans instance, default=None

Pre-computed target KMeans clustering model.

eps : float, default=1e-7

Small constant added to median distance calculation for numerical stability.
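The thresholds and the target clustering act together on the target side. As we understand [33], target features are clustered with spherical k-means initialised at the source class centroids, samples whose distance to their cluster centre exceeds a threshold are discarded, and clusters left with too few samples are excluded from the loss. The NumPy sketch below mirrors that procedure under stated assumptions: cosine distance as the metric compared against distance_threshold, a fixed number of iterations, and a hand-rolled clustering standing in for a pre-fitted target_kmeans. It does not show how CANLoss consumes these arguments internally, and cluster_target and filter_target are hypothetical helpers.

```python
import numpy as np


def _normalize(x):
    """Scale each row to unit L2 norm so dot products are cosine similarities."""
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def cluster_target(features_s, y_s, features_t, n_iter=10):
    """Spherical k-means on target features, initialised at source class centroids.

    Returns target pseudo-labels and each sample's cosine distance to its centre.
    """
    classes = np.unique(y_s)
    centers = _normalize(np.stack([features_s[y_s == c].mean(axis=0) for c in classes]))
    ft = _normalize(features_t)
    for _ in range(n_iter):
        assign = np.argmax(ft @ centers.T, axis=1)  # most similar centre
        for k in range(len(classes)):
            members = ft[assign == k]
            if len(members) == 0:
                continue  # keep the previous centre for an empty cluster
            mean = members.mean(axis=0)
            norm = np.linalg.norm(mean)
            if norm > 0:
                centers[k] = mean / norm
    assign = np.argmax(ft @ centers.T, axis=1)
    cos_dist = 1.0 - np.sum(ft * centers[assign], axis=1)
    return classes[assign], cos_dist


def filter_target(pseudo_labels, cos_dist, distance_threshold=0.5, class_threshold=3):
    """Drop far-away samples, then drop classes with too few remaining samples."""
    keep = cos_dist <= distance_threshold
    kept_classes, counts = np.unique(pseudo_labels[keep], return_counts=True)
    valid_classes = kept_classes[counts >= class_threshold]
    keep &= np.isin(pseudo_labels, valid_classes)
    return keep, valid_classes
```

In this sketch, the returned mask and the list of valid classes select which target samples and which classes would enter the CDD computation.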

References

[33]

Kang, G., Jiang, L., Yang, Y., & Hauptmann, A. G. (2019). Contrastive adaptation network for unsupervised domain adaptation. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 4893-4902).

forward(y_s, features_s, features_t, **kwargs)

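Compute the domain adaptation loss. Only y_s, features_s and features_t are named in the signature; the other inputs listed below (y_pred_s, y_pred_t, domain_pred_s, domain_pred_t) are accepted through **kwargs.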

Parameters:
y_s

True labels of the source samples.

y_pred_s

Predictions of the source domain.

y_pred_t

Predictions of the target domain.

domain_pred_s

Domain predictions of the source domain.

domain_pred_t

Domain predictions of the target domain.

features_s

Features from the chosen layer for the source domain.

features_t

Features from the chosen layer for the target domain.