skada.deep.MDD
- skada.deep.MDD(module, layer_name, reg=1, gamma=4.0, disc_classifier=None, num_features=None, n_classes=None, base_criterion=None, **kwargs)
Margin Disparity Discrepancy (MDD).
From [35].
- Parameters:
- module : torch module (class or instance)
A PyTorch Module. In general, the uninstantiated class should be passed, although instantiated modules will also work.
- layer_name : str
The name of the module's layer whose outputs are collected during training.
- reg : float, default=1
Regularization parameter weighting the domain adaptation (DA) loss.
- disc_classifier : torch module, default=None
A PyTorch Module used as a discriminator. It should have the same architecture as the classifier used on the source. If None, a domain classifier is created following [1]_.
- num_features : int, default=None
Size of the input of the domain classifier, e.g. the size of the last layer of the feature extractor. If disc_classifier is None, num_features has to be provided.
- n_classes : int, default=None
Number of classes. If disc_classifier is None, n_classes has to be provided.
- base_criterion : torch criterion (class), default=None
The base criterion used to compute the loss with source labels. If None, the default is torch.nn.CrossEntropyLoss.
- gamma : float, default=4.0
Margin parameter following [35].
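The parameter list above implies a dependency between arguments: when no discriminator is passed (`disc_classifier` in the signature is None), both num_features and n_classes are needed so that a default one can be built. The following is a minimal sketch of that rule only. `check_mdd_args` is a hypothetical helper written for illustration; it is not part of skada, and the library's actual validation and error messages may differ.

```python
def check_mdd_args(disc_classifier=None, num_features=None, n_classes=None):
    """Hypothetical helper (not part of skada) mirroring the documented rule.

    Returns "custom" when a discriminator is supplied, "default" when enough
    information is given to build one, and raises ValueError otherwise.
    """
    if disc_classifier is not None:
        # A user-supplied discriminator makes the two sizes unnecessary.
        return "custom"

    # Without a discriminator, both sizes are required to create one.
    required = (("num_features", num_features), ("n_classes", n_classes))
    missing = [name for name, value in required if value is None]
    if missing:
        raise ValueError(
            "disc_classifier is None, so these must be provided: "
            + ", ".join(missing)
        )
    return "default"


# Enough information to build a default discriminator.
print(check_mdd_args(num_features=10, n_classes=2))
```

Checking the combination up front, before any module is instantiated, keeps the failure close to the call site instead of surfacing later as a shape mismatch during training.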
References