skada.deep.CAN
- skada.deep.CAN(module, layer_name, reg=1, distance_threshold=0.5, class_threshold=3, sigmas=None, base_criterion=None, callbacks=None, **kwargs)
Contrastive Adaptation Network (CAN) domain adaptation method.
From [33].
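In [33], CAN trains the network with a contrastive domain discrepancy (CDD): a class-conditional MMD between source and target samples of the same class is minimized, while the MMD between different classes is maximized. The minimal NumPy sketch below computes that quantity on features such as the outputs of the layer named by `layer_name`. It assumes that target labels are pseudo-labels (obtained by clustering in the paper) and that the multi-Gaussian kernel is a plain sum of `exp(-d^2 / (2 sigma^2))` terms over `sigmas`. All function names are hypothetical and this is not skada's internal code.

```python
import numpy as np


def multi_gaussian_kernel(a, b, sigmas):
    """Sum of Gaussian kernels exp(-d^2 / (2 sigma^2)) over all sigmas (assumed form)."""
    # Pairwise squared Euclidean distances between rows of a and rows of b.
    sq = np.sum(a**2, axis=1)[:, None] + np.sum(b**2, axis=1)[None, :] - 2.0 * a @ b.T
    sq = np.maximum(sq, 0.0)  # clip tiny negatives caused by round-off
    return sum(np.exp(-sq / (2.0 * s**2)) for s in sigmas)


def class_mmd(xs_c, xt_c, sigmas):
    """Squared MMD estimate between one source class and one target class."""
    e1 = multi_gaussian_kernel(xs_c, xs_c, sigmas).mean()  # source-source
    e2 = multi_gaussian_kernel(xt_c, xt_c, sigmas).mean()  # target-target
    e3 = multi_gaussian_kernel(xs_c, xt_c, sigmas).mean()  # source-target
    return e1 + e2 - 2.0 * e3


def contrastive_domain_discrepancy(xs, ys, xt, yt_pseudo, classes, sigmas):
    """Mean intra-class MMD minus mean inter-class MMD (needs at least 2 classes)."""
    intra = np.mean([class_mmd(xs[ys == c], xt[yt_pseudo == c], sigmas) for c in classes])
    inter = np.mean([
        class_mmd(xs[ys == c], xt[yt_pseudo == c2], sigmas)
        for c in classes for c2 in classes if c2 != c
    ])
    return intra - inter


xs = np.array([[0.0, 0.0], [0.0, 0.1], [5.0, 5.0], [5.0, 5.1]])
ys = np.array([0, 0, 1, 1])
# Target features identical to the source, with perfect pseudo-labels.
cdd = contrastive_domain_discrepancy(xs, ys, xs.copy(), ys.copy(), classes=[0, 1], sigmas=[1.0])
print(cdd < 0)  # True
```

With well-aligned classes the intra-class term vanishes and the well-separated inter-class term dominates, so the discrepancy being minimized becomes negative.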
- Parameters:
- module : torch module (class or instance)
A PyTorch Module.
- layer_name : str
The name of the module's layer whose outputs are collected during training for the adaptation.
- reg : float, optional (default=1)
Regularization parameter for the domain adaptation (DA) loss.
- distance_threshold : float, optional (default=0.5)
Distance threshold for discarding the samples that are too far from the centroids.
- class_threshold : int, optional (default=3)
Minimum number of samples in a class to be considered for the loss.
- sigmas : array-like, optional (default=None)
If an array, the sigmas used for the multi-Gaussian kernel. If None, the sigmas proposed in [33] are used.
- base_criterion : torch criterion (class), optional (default=None)
The base criterion used to compute the loss with source labels. If None, the default is torch.nn.CrossEntropyLoss.
- callbacks : list, optional (default=None)
List of callbacks to be used during training.
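The roles of `distance_threshold`, `class_threshold` and `reg` can be illustrated with a minimal NumPy sketch. It assumes, loosely following the clustering step described in [33], that target samples are assigned to their nearest centroid under a cosine distance (taken here as 1 minus cosine similarity), that samples farther than `distance_threshold` from their centroid are discarded, that classes left with fewer than `class_threshold` samples are excluded from the loss, and that `reg` weights the DA loss added to the base criterion. The function names are hypothetical; this is not skada's internal implementation.

```python
import numpy as np


def filter_target_samples(features, centroids, distance_threshold=0.5, class_threshold=3):
    """Assign target samples to the nearest centroid and filter them (assumed rules).

    Returns (labels, keep_mask, valid_classes).
    """
    # Cosine distance = 1 - cosine similarity (an assumption for this sketch).
    f = features / np.linalg.norm(features, axis=1, keepdims=True)
    c = centroids / np.linalg.norm(centroids, axis=1, keepdims=True)
    dist = 1.0 - f @ c.T  # shape (n_samples, n_classes)
    labels = dist.argmin(axis=1)
    # Drop samples farther than distance_threshold from their assigned centroid.
    keep = dist[np.arange(len(f)), labels] <= distance_threshold
    # Drop classes that keep fewer than class_threshold samples.
    counts = np.bincount(labels[keep], minlength=len(c))
    valid_classes = np.flatnonzero(counts >= class_threshold)
    keep &= np.isin(labels, valid_classes)
    return labels, keep, valid_classes


def total_loss(base_loss, da_loss, reg=1.0):
    """Supervised source loss plus the reg-weighted DA loss (assumed weighting)."""
    return base_loss + reg * da_loss


feats = np.array([[1.0, 0.0], [0.9, 0.1], [1.0, 0.05], [0.0, 1.0], [-1.0, 0.0]])
cents = np.array([[1.0, 0.0], [0.0, 1.0]])
labels, keep, valid = filter_target_samples(feats, cents)
print(labels.tolist(), keep.tolist(), valid.tolist())
# [0, 0, 0, 1, 1] [True, True, True, False, False] [0]
```

Here the sample `[-1, 0]` is discarded for being too far from every centroid, and class 1 is dropped because only one sample survives, which is below the default `class_threshold=3`.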
References
[33] Kang, G., Jiang, L., Yang, Y., & Hauptmann, A. G. (2019). Contrastive adaptation network for unsupervised domain adaptation. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 4893-4902).