Optimal transport domain adaptation methods.

This example illustrates DeepJDOT, an optimal transport deep domain adaptation (DA) method from [Damodaran et al., 2018], on a simple image classification task.

# Author: Théo Gnassounou
#
# License: BSD 3-Clause
# sphinx_gallery_thumbnail_number = 4
from skorch import NeuralNetClassifier
from torch import nn

from skada.datasets import load_mnist_usps
from skada.deep import DeepJDOT
from skada.deep.modules import MNISTtoUSPSNet

Load the image datasets (MNIST as source domain, USPS as target domain)

# MNIST is packed as the source domain and USPS as the target domain.
# The training pack hides the target labels (mask_target_labels=True) so the
# adaptation method cannot use them; the test pack is target-only and keeps
# the true USPS labels for evaluation.
dataset = load_mnist_usps(n_classes=2, n_samples=0.5, return_dataset=True)
X, y, sample_domain = dataset.pack(
    as_sources=["mnist"],
    as_targets=["usps"],
    mask_target_labels=True,
)
X_test, y_test, sample_domain_test = dataset.pack(
    as_sources=[],
    as_targets=["usps"],
    mask_target_labels=False,
)
/home/circleci/project/skada/datasets/_mnist_usps.py:72: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
  mnist_target = torch.tensor(mnist_dataset.targets)

Train a classic model

# Baseline without adaptation: an ordinary skorch classifier trained only on
# the rows whose ``sample_domain`` is positive. NOTE(review): skada encodes
# source domains with positive labels, so these should be the MNIST samples;
# confirm. The model is then scored on the USPS test pack.
baseline_net = MNISTtoUSPSNet()
model = NeuralNetClassifier(
    baseline_net,
    criterion=nn.CrossEntropyLoss(),
    batch_size=128,
    max_epochs=5,
    train_split=False,
    lr=1e-2,
)
is_source = sample_domain > 0
model.fit(X[is_source], y[is_source])
# Left as a bare expression so Sphinx-Gallery renders the accuracy.
model.score(X_test, y=y_test)
  epoch    train_loss     dur
-------  ------------  ------
      1        1.4460  5.3023
      2        0.1985  5.7021
      3        0.0777  5.3999
      4        0.0422  6.9946
      5        0.0334  7.5007

0.8906752411575563

Train a DeepJDOT model

# Domain-adapted model: DeepJDOT is fit on the full packed data (source rows
# plus target rows with masked labels), with ``sample_domain`` telling it
# which rows come from which domain. NOTE(review): ``layer_name="fc1"``
# presumably selects the layer whose activations are aligned via optimal
# transport, and ``reg_dist`` / ``reg_cl`` presumably weight the feature and
# label terms of the DeepJDOT loss — confirm against the skada docs.
jdot_net = MNISTtoUSPSNet()
model = DeepJDOT(
    jdot_net,
    layer_name="fc1",
    batch_size=128,
    max_epochs=5,
    train_split=False,
    reg_dist=0.1,
    reg_cl=0.01,
    lr=1e-2,
)
model.fit(X, y, sample_domain=sample_domain)
# Left as a bare expression so Sphinx-Gallery renders the accuracy.
model.score(X_test, y_test, sample_domain=sample_domain_test)
/home/circleci/.local/lib/python3.10/site-packages/sklearn/utils/deprecation.py:132: FutureWarning: 'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.
  warnings.warn(
  epoch    train_loss      dur
-------  ------------  -------
      1        2.0113  14.6628
      2        1.1483  14.0983
      3        0.7808  12.7022
      4        0.6364  11.3017
      5        0.5516  13.8933
/home/circleci/.local/lib/python3.10/site-packages/sklearn/utils/deprecation.py:132: FutureWarning: 'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.
  warnings.warn(
/home/circleci/.local/lib/python3.10/site-packages/sklearn/utils/deprecation.py:132: FutureWarning: 'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.
  warnings.warn(
/home/circleci/.local/lib/python3.10/site-packages/sklearn/utils/deprecation.py:132: FutureWarning: 'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.
  warnings.warn(

0.8585209003215434

Total running time of the script: (1 minutes 42.555 seconds)

Gallery generated by Sphinx-Gallery