skada.deep.SPA
- skada.deep.SPA(module, layer_name, reg_adv=1, reg_gsa=1, reg_nap=1, domain_classifier=None, num_features=None, base_criterion=None, domain_criterion=None, callbacks=None, max_epochs=100, **kwargs)
Domain Adaptation with SPA.
From [36].
- Parameters:
- module : torch module (class or instance)
A PyTorch Module. In general, the uninstantiated class should be passed, although instantiated modules will also work.
- layer_name : str
The name of the module's layer whose outputs are collected during the training.
- reg_adv : float, default=1
Regularization parameter for the adversarial loss.
- reg_gsa : float, default=1
Regularization parameter for the graph spectral alignment loss.
- reg_nap : float, default=1
Regularization parameter for the neighbor-aware propagation loss.
- domain_classifier : torch module, default=None
A PyTorch Module used to classify the domain. If None, a domain classifier is created following [1]; see the sketch after this parameter list.
- num_features : int, default=None
Size of the input of the domain classifier, e.g. the size of the last layer of the feature extractor. If domain_classifier is None, num_features must be provided.
- base_criterion : torch criterion (class), default=None
The base criterion used to compute the loss with source labels. If None, the default is torch.nn.CrossEntropyLoss.
- domain_criterion : torch criterion (class), default=None
The criterion (loss) used to compute the adversarial loss. If None, a BCELoss is used.
- callbacks : list, default=None
List of callbacks to use during training.
- max_epochs : int, default=100
Maximum number of epochs to train the model.
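Since domain_criterion defaults to a BCELoss, a custom domain_classifier presumably needs to output a probability per sample. A minimal sketch of such a module (the layer sizes are illustrative; the input size should equal num_features, i.e. the dimensionality of the outputs collected at layer_name):

from torch import nn

# Illustrative custom domain classifier: maps the 32-dimensional features
# collected at layer_name to a single domain probability (suitable for BCELoss).
domain_classifier = nn.Sequential(
    nn.Linear(32, 16),
    nn.ReLU(),
    nn.Linear(16, 1),
    nn.Sigmoid(),
)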
References
[36] Xiao et al. SPA: A Graph Spectral Alignment Perspective for Domain Adaptation. In NeurIPS, 2023.
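Examples
A minimal usage sketch, assuming the standard skada deep estimator interface in which fit takes X, y and a sample_domain array (positive values mark source samples, negative values mark target samples). The network, layer name and toy data below are illustrative, not part of the library.

import numpy as np
from torch import nn

from skada.deep import SPA

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Feature extractor whose outputs SPA collects (see layer_name below).
        self.feature_extractor = nn.Sequential(
            nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 32), nn.ReLU()
        )
        self.classifier = nn.Linear(32, 2)

    def forward(self, x):
        return self.classifier(self.feature_extractor(x))

# Toy data: 50 source samples (sample_domain > 0) and 50 target samples
# (sample_domain < 0); only source labels are used for the supervised loss.
X = np.random.randn(100, 20).astype(np.float32)
y = np.random.randint(0, 2, size=100)
sample_domain = np.concatenate([np.ones(50), -np.ones(50)])

model = SPA(
    Net(),
    layer_name="feature_extractor",
    reg_adv=1,
    reg_gsa=1,
    reg_nap=1,
    num_features=32,  # dimensionality of the collected layer's output
    max_epochs=5,
)
model.fit(X, y, sample_domain=sample_domain)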