skada.deep.CANLoss
- class skada.deep.CANLoss(distance_threshold=0.5, class_threshold=3, sigmas=None, target_kmeans=None, eps=1e-07)
Loss for the Contrastive Adaptation Network (CAN).
This loss implements the contrastive domain discrepancy (CDD) as described in [33].
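To illustrate what CDD measures, here is a small NumPy sketch that follows the definition in [33]. A class-aware MMD is computed between source samples of class c1 and target samples pseudo-labeled c2. The intra-class terms (c1 = c2) are averaged, and the mean of the inter-class terms (c1 ≠ c2) is subtracted from that average. Minimizing the result therefore pulls same-class features together across domains and pushes different classes apart. This is not skada's PyTorch implementation. The names `multi_gaussian_kernel`, `class_mmd` and `cdd` are hypothetical, the bandwidths `(1.0, 2.0, 4.0)` are arbitrary placeholders, and hard target pseudo-labels are assumed to be given.

```python
import numpy as np


def multi_gaussian_kernel(a, b, sigmas):
    """Sum over bandwidths of exp(-||x - y||^2 / (2 sigma^2)) for all pairs (x in a, y in b)."""
    sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
    return sum(np.exp(-sq_dists / (2.0 * s**2)) for s in sigmas)


def class_mmd(xs, xt, sigmas):
    """Biased squared MMD estimate between two sample sets."""
    return (
        multi_gaussian_kernel(xs, xs, sigmas).mean()
        + multi_gaussian_kernel(xt, xt, sigmas).mean()
        - 2.0 * multi_gaussian_kernel(xs, xt, sigmas).mean()
    )


def cdd(features_s, y_s, features_t, y_t_pseudo, sigmas=(1.0, 2.0, 4.0)):
    """Sketch of CDD: mean intra-class MMD minus mean inter-class MMD."""
    # Only classes present in both the source labels and the target pseudo-labels.
    classes = np.intersect1d(np.unique(y_s), np.unique(y_t_pseudo))
    intra = np.mean([
        class_mmd(features_s[y_s == c], features_t[y_t_pseudo == c], sigmas)
        for c in classes
    ])
    if len(classes) < 2:
        return intra  # no inter-class pairs to contrast against
    # Mean over the M * (M - 1) ordered pairs of distinct classes.
    inter = np.mean([
        class_mmd(features_s[y_s == c1], features_t[y_t_pseudo == c2], sigmas)
        for c1 in classes
        for c2 in classes
        if c1 != c2
    ])
    return intra - inter
```

When the pseudo-labels match the true class structure, the intra-class term is small and the inter-class term is large, so the value is negative. Swapping the pseudo-labels of two well-separated classes flips the sign.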
- Parameters:
- distance_threshold : float, optional (default=0.5)
Distance threshold above which samples are discarded for being too far from the centroids.
- class_threshold : int, optional (default=3)
Minimum number of samples a class must contain to be included in the loss.
- sigmas : array-like, default=None
If an array, the bandwidths used for the multi-Gaussian kernel. If None, the bandwidths proposed in [33] are used.
- target_kmeans : sklearn KMeans instance, default=None
Pre-computed KMeans clustering model for the target domain.
- eps : float, default=1e-7
Small constant added to median distance calculation for numerical stability.
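The two thresholds can be read as a filtering step in the spirit of [33]. Each target sample is assigned to its nearest centroid, samples farther than `distance_threshold` are discarded, and classes left with fewer than `class_threshold` samples are dropped entirely. The sketch below assumes the paper's cosine dissimilarity 0.5 * (1 - cos), which lies in [0, 1], and uses the hypothetical helpers `cosine_distance_to_centroids` and `filter_target`. skada's actual distance and filtering code may differ.

```python
import numpy as np


def cosine_distance_to_centroids(features, centroids):
    """Assumed metric: 0.5 * (1 - cosine similarity) between every sample and every centroid."""
    f = features / np.linalg.norm(features, axis=1, keepdims=True)
    c = centroids / np.linalg.norm(centroids, axis=1, keepdims=True)
    return 0.5 * (1.0 - f @ c.T)


def filter_target(features_t, centroids, distance_threshold=0.5, class_threshold=3):
    """Return (mask of kept target samples, nearest-centroid pseudo-labels)."""
    dists = cosine_distance_to_centroids(features_t, centroids)
    labels = dists.argmin(axis=1)
    # Discard samples that are too far from their nearest centroid.
    keep = dists[np.arange(len(labels)), labels] < distance_threshold
    # Discard whole classes that keep fewer than class_threshold samples.
    counts = np.bincount(labels[keep], minlength=len(centroids))
    keep &= counts[labels] >= class_threshold
    return keep, labels
```

Under this assumed metric, the default `distance_threshold=0.5` keeps exactly the samples with positive cosine similarity to their centroid, that is, within 90 degrees of it.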
References
[33]Kang, G., Jiang, L., Yang, Y., & Hauptmann, A. G. (2019). Contrastive adaptation network for unsupervised domain adaptation. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 4893-4902).
- forward(y_s, features_s, features_t, **kwargs)
Compute the domain adaptation loss
- Parameters:
- y_s
The true labels of the source domain.
- y_pred_s
Predictions of the source domain.
- y_pred_t
Predictions of the target domain.
- domain_pred_s
Domain predictions of the source domain.
- domain_pred_t
Domain predictions of the target domain.
- features_s
Features of the chosen layer of source domain.
- features_t
Features of the chosen layer of target domain.
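In [33], the target pseudo-labels that CDD relies on come from clustering the target features, with the cluster centers initialized at the source class centroids. This is presumably why `forward` takes `y_s` along with `features_s` and `features_t`. The sketch below implements that initialization with a small spherical (cosine) k-means in NumPy. `source_centroids` and `spherical_kmeans` are hypothetical names. skada's own clustering may differ: judging by the `target_kmeans` parameter, it is based on scikit-learn's `KMeans`.

```python
import numpy as np


def source_centroids(features_s, y_s, n_classes):
    """Mean source feature vector for each class 0..n_classes-1."""
    return np.stack([features_s[y_s == k].mean(axis=0) for k in range(n_classes)])


def spherical_kmeans(features_t, init_centroids, n_iter=10):
    """K-means with cosine similarity on L2-normalized features, seeded by init_centroids."""
    x = features_t / np.linalg.norm(features_t, axis=1, keepdims=True)
    centroids = init_centroids / np.linalg.norm(init_centroids, axis=1, keepdims=True)
    for _ in range(n_iter):
        # Assign each target sample to the most similar centroid.
        labels = (x @ centroids.T).argmax(axis=1)
        # Move each centroid to the normalized mean of its members.
        for k in range(len(centroids)):
            members = x[labels == k]
            if len(members) > 0:  # empty clusters keep their previous centroid
                mean = members.mean(axis=0)
                centroids[k] = mean / np.linalg.norm(mean)
    # Final assignment against the updated centroids.
    labels = (x @ centroids.T).argmax(axis=1)
    return labels, centroids
```

Seeding the clusters with source centroids ties cluster index k to source class k. The resulting labels can then serve directly as target pseudo-labels for the class-aware discrepancy terms.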