skada.deep.MDDLoss

class skada.deep.MDDLoss(gamma=4.0)

Margin Disparity Discrepancy (MDD) loss.

This loss minimizes the margin disparity discrepancy between the source and target domains. The discrepancy is estimated through adversarial training of three networks: an encoder, a task network and a discriminator.

See [35] for details.
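For reference, the objective below is a transcription of the MDD formulation as we read it in [35], not taken from skada's source. The notation is ours: ψ is the encoder, f the task network, f' the discriminator (an auxiliary classifier over the same classes), σ_k the k-th softmax output and η a trade-off weight. The discriminator is trained to agree with the task network's predicted label on source samples and to disagree with it on target samples; gamma scales the source agreement term. skada's exact signs and reductions may differ.

```latex
\hat{y}(x) = \arg\max_k \, f_k(\psi(x))

D_\gamma = \mathbb{E}_{x_t}\Big[\log\big(1 - \sigma_{\hat{y}(x_t)}(f'(\psi(x_t)))\big)\Big]
         + \gamma \, \mathbb{E}_{x_s}\Big[\log \sigma_{\hat{y}(x_s)}(f'(\psi(x_s)))\Big]

\min_{\psi, f} \; \mathbb{E}_{(x_s, y_s)}\big[\mathrm{CE}(f(\psi(x_s)), y_s)\big] + \eta \, D_\gamma,
\qquad \max_{f'} \; D_\gamma
```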

Parameters:
gamma : float (default=4.0)

Margin parameter following [35].

References

[35] Yuchen Zhang et al. Bridging Theory and Algorithm for Domain Adaptation. In International Conference on Machine Learning, 2019.

forward(y_pred_s, y_pred_t, domain_pred_s, domain_pred_t, **kwargs)

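Compute the domain adaptation loss from the task-network predictions (y_pred_s, y_pred_t) and the discriminator predictions (domain_pred_s, domain_pred_t) on source and target samples.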
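As an illustration only, the function below computes a batch estimate of the margin disparity discrepancy from four inputs named after the forward arguments, following our reading of the formulation in [35]. It assumes that y_pred_s and y_pred_t hold task-network logits and that domain_pred_s and domain_pred_t hold discriminator logits over the same classes. The name mdd_discrepancy is hypothetical, the sketch uses plain Python lists instead of PyTorch tensors, and it is not skada's implementation, whose sign and reduction conventions may differ.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax of one row of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values: Sequence[float]) -> int:
    """Index of the largest entry (first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def mdd_discrepancy(
    y_pred_s: Sequence[Sequence[float]],
    y_pred_t: Sequence[Sequence[float]],
    domain_pred_s: Sequence[Sequence[float]],
    domain_pred_t: Sequence[Sequence[float]],
    gamma: float = 4.0,
    eps: float = 1e-8,
) -> float:
    """Hypothetical batch estimate of the margin disparity discrepancy.

    Assumption: y_pred_* are task-network logits and domain_pred_* are
    discriminator (auxiliary classifier) logits, one row per sample.
    """
    # Source term: mean log-probability the discriminator assigns to the
    # task network's predicted class (rewards agreement on source).
    source_term = sum(
        math.log(softmax(aux)[argmax(task)] + eps)
        for task, aux in zip(y_pred_s, domain_pred_s)
    ) / len(y_pred_s)
    # Target term: mean log of the probability mass the discriminator puts
    # on all other classes (rewards disagreement on target).
    target_term = sum(
        math.log(1.0 - softmax(aux)[argmax(task)] + eps)
        for task, aux in zip(y_pred_t, domain_pred_t)
    ) / len(y_pred_t)
    # The discriminator is trained to increase this value; the encoder and
    # task network are trained to decrease it.
    return target_term + gamma * source_term


# Uninformative discriminator on a 2-class problem: every probability is 0.5,
# so the estimate is (1 + gamma) * log(0.5) with the default gamma = 4.
print(round(mdd_discrepancy([[2.0, 0.0]], [[0.0, 3.0]], [[0.0, 0.0]], [[0.0, 0.0]]), 4))
# prints -3.4657
```

In training, the discriminator would be updated to increase this estimate and the encoder to decrease it; a gradient reversal layer between the encoder and the discriminator is one common way to obtain both updates from a single backward pass.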