skada.deep.MDD

skada.deep.MDD(module, layer_name, reg=1, gamma=4.0, disc_classifier=None, num_features=None, n_classes=None, base_criterion=None, **kwargs)

Margin Disparity Discrepancy (MDD).

From [35].

Parameters:
module : torch module (class or instance)

A PyTorch Module. In general, the uninstantiated class should be passed, although instantiated modules will also work.

layer_name : str

The name of the module's layer whose outputs are collected during training.

reg : float, default=1

Regularization parameter weighting the domain adaptation (DA) loss.

disc_classifier : torch module, default=None

A PyTorch Module used as a discriminator. It should have the same architecture as the classifier used on the source. If None, a domain classifier is created following [1]_.

num_features : int, default=None

Size of the input of the domain classifier, e.g. the size of the last layer of the feature extractor. If disc_classifier is None, num_features has to be provided.

n_classes : int, default=None

Number of classes. If disc_classifier is None, n_classes has to be provided.

base_criterion : torch criterion (class), default=None

The base criterion used to compute the loss with source labels. If None, the default is torch.nn.CrossEntropyLoss.

gamma : float, default=4.0

Margin parameter following [35].
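To make the role of gamma more concrete, the following is a minimal, standard-library sketch of the adversarial objective as it is commonly written for MDD in [35]: gamma times the mean log-probability that the auxiliary classifier assigns to the main classifier's predicted label on source samples, plus the mean log of one minus that probability on target samples. This is an illustration under stated assumptions, not skada's implementation; the function names (softmax, argmax, mdd_discrepancy), the argument names and the eps smoothing term are all hypothetical.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    """Index of the largest entry."""
    return max(range(len(values)), key=values.__getitem__)


def mdd_discrepancy(f_src, adv_src, f_tgt, adv_tgt, gamma=4.0, eps=1e-6):
    """Hypothetical helper: empirical margin disparity term.

    f_src / f_tgt are per-sample logits of the main classifier,
    adv_src / adv_tgt are per-sample logits of the auxiliary
    (adversarial) classifier. eps is an assumed smoothing constant
    that keeps the logarithms finite.
    """
    src_term = 0.0
    for f_logits, adv_logits in zip(f_src, adv_src):
        h = argmax(f_logits)  # label predicted by the main classifier
        src_term += math.log(softmax(adv_logits)[h] + eps)

    tgt_term = 0.0
    for f_logits, adv_logits in zip(f_tgt, adv_tgt):
        h = argmax(f_logits)
        tgt_term += math.log(1.0 - softmax(adv_logits)[h] + eps)

    # gamma weights the source agreement term against the target term.
    return gamma * src_term / len(f_src) + tgt_term / len(f_tgt)


# Toy logits: the auxiliary classifier agrees with the main one on source.
f_src, adv_src = [[5.0, 0.0]], [[5.0, 0.0]]
f_tgt = [[0.0, 5.0]]
disagree_on_target = mdd_discrepancy(f_src, adv_src, f_tgt, [[5.0, 0.0]])
agree_on_target = mdd_discrepancy(f_src, adv_src, f_tgt, [[0.0, 5.0]])
print(disagree_on_target > agree_on_target)
```

In this sketch the value is largest when the auxiliary classifier agrees with the main classifier on the source domain and disagrees with it on the target domain, and a larger gamma puts more weight on the source agreement term.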

References

[35]

Yuchen Zhang et al. Bridging Theory and Algorithm for Domain Adaptation. In International Conference on Machine Learning, 2019.