skada.deep.CAN

skada.deep.CAN(module, layer_name, reg=1, distance_threshold=0.5, class_threshold=3, sigmas=None, base_criterion=None, callbacks=None, **kwargs)

Contrastive Adaptation Network (CAN) domain adaptation method.

From [33].
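CAN trains with a contrastive domain discrepancy (CDD). The CDD is an intra-class term (source class c against target class c) minus an inter-class term (source class c against target class c' with c != c'), with each term measured by the maximum mean discrepancy (MMD). The following is a minimal pure-Python sketch of that quantity. It assumes a single Gaussian kernel, hard labels (ground truth for source, cluster assignments for target), and a biased MMD estimator. The helper names `gaussian_kernel`, `mean_kernel`, `mmd2` and `cdd` are hypothetical and are not part of the skada API, which works on minibatch tensors and may differ in its details.

```python
import math


def gaussian_kernel(x, y, sigma=1.0):
    # Gaussian (RBF) kernel on the squared Euclidean distance
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / (2 * sigma ** 2))


def mean_kernel(xs, ys, sigma):
    # Average kernel value over all pairs (x, y) with x in xs and y in ys
    return sum(gaussian_kernel(x, y, sigma) for x in xs for y in ys) / (len(xs) * len(ys))


def mmd2(xs, ys, sigma=1.0):
    # Biased estimate of the squared MMD between two sample sets
    return mean_kernel(xs, xs, sigma) + mean_kernel(ys, ys, sigma) - 2 * mean_kernel(xs, ys, sigma)


def cdd(source_by_class, target_by_class, sigma=1.0):
    """Intra-class discrepancy minus inter-class discrepancy.

    source_by_class / target_by_class map a class label to a list of
    feature vectors (true labels for source, cluster labels for target).
    """
    classes = sorted(set(source_by_class) & set(target_by_class))
    m = len(classes)
    if m < 2:
        raise ValueError("CDD needs at least two classes present in both domains")
    # Same class across domains: should be small after adaptation
    intra = sum(mmd2(source_by_class[c], target_by_class[c], sigma) for c in classes) / m
    # Different classes across domains: should stay large
    inter = sum(
        mmd2(source_by_class[c1], target_by_class[c2], sigma)
        for c1 in classes
        for c2 in classes
        if c1 != c2
    ) / (m * (m - 1))
    return intra - inter
```

When the target clusters line up with the source classes, the intra-class term is small and the inter-class term is large, so the CDD is negative. Minimizing it pulls same-class features together across domains and pushes different classes apart.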

Parameters:
module : torch module (class or instance)

A PyTorch Module.

layer_name : str

The name of the module's layer whose outputs are collected during training and used for the adaptation.

reg : float, optional (default=1)

Regularization parameter for the domain adaptation (DA) loss.

distance_threshold : float, optional (default=0.5)

Distance threshold above which samples are discarded for being too far from the centroids.

class_threshold : int, optional (default=3)

Minimum number of samples a class must contain to be considered in the loss.

sigmas : array-like, optional (default=None)

If an array, the sigmas (kernel bandwidths) used for the multi-Gaussian kernel. If None, the sigmas proposed in [1] are used.

base_criterion : torch criterion (class), optional (default=None)

The base criterion used to compute the supervised loss on the labeled source samples. If None, the default is torch.nn.CrossEntropyLoss.

callbacks : list, optional (default=None)

List of callbacks to be used during training.
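The distance_threshold and class_threshold parameters control which target samples enter the class-aware loss. The following is a minimal sketch of that filtering step under stated assumptions: the centroids are given (for example, source class means), the cosine dissimilarity 0.5 * (1 - cos) is used as the distance, and the clustering step that refines the centroids is omitted. The function names `cosine_distance` and `filter_target` are hypothetical; skada's internal distance and filtering logic may differ.

```python
import math


def cosine_distance(a, b):
    # Cosine dissimilarity 0.5 * (1 - cos(a, b)), which lies in [0, 1]
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return 0.5 * (1.0 - dot / norm)


def filter_target(features, centroids, distance_threshold=0.5, class_threshold=3):
    """Return {class: [sample indices]} for the target samples kept for the loss."""
    clusters = {}
    for i, f in enumerate(features):
        dists = [cosine_distance(f, c) for c in centroids]
        # Assign the sample to its closest centroid
        label = min(range(len(centroids)), key=dists.__getitem__)
        # Discard samples that are too far from that centroid
        if dists[label] <= distance_threshold:
            clusters.setdefault(label, []).append(i)
    # Keep only classes with at least class_threshold remaining samples
    return {c: idx for c, idx in clusters.items() if len(idx) >= class_threshold}
```

For example, with centroids [[1, 0], [0, 1]], distance_threshold=0.1 and class_threshold=2, three samples close to [1, 0] form a valid class 0, while class 1 is dropped if only one sample near [0, 1] survives the distance filter.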
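The sigmas parameter sets the bandwidths of the multi-Gaussian kernel used in the discrepancy. The following is a minimal sketch of such a kernel, assuming the form exp(-d^2 / (2 * sigma^2)) and an unweighted sum over bandwidths. Both the exact parameterization and the normalization in skada may differ, and `multi_gaussian_kernel` is a hypothetical name.

```python
import math


def multi_gaussian_kernel(x, y, sigmas):
    # Squared Euclidean distance between the two feature vectors
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    # Sum of Gaussian kernels, one per bandwidth in sigmas
    return sum(math.exp(-sq_dist / (2 * s ** 2)) for s in sigmas)
```

Summing several bandwidths avoids committing to a single length scale. Small sigmas react to fine differences between features, and large sigmas still give a useful signal when the domains are far apart.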
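The reg and base_criterion parameters together define the training objective. The following sketch assumes the common additive form: the base criterion on the labeled source samples, plus reg times the DA loss. The cross-entropy below mirrors torch.nn.CrossEntropyLoss with mean reduction on plain Python lists. The helper names `cross_entropy` and `total_loss` are hypothetical and are not taken from skada.

```python
import math


def cross_entropy(logits, label):
    # Numerically stable -log softmax(logits)[label]
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def total_loss(source_logits, source_labels, da_loss, reg=1.0):
    # Supervised loss averaged over labeled source samples (mean reduction)
    base = sum(cross_entropy(z, y) for z, y in zip(source_logits, source_labels))
    base /= len(source_labels)
    # The DA loss is scaled by reg before being added
    return base + reg * da_loss
```

With reg=0 this reduces to plain supervised training on the source domain. Larger values put more weight on aligning the domains.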

References

[33]

Kang, G., Jiang, L., Yang, Y., & Hauptmann, A. G. (2019). Contrastive adaptation network for unsupervised domain adaptation. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 4893-4902).