Divergence domain adaptation methods.

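This example illustrates the DeepCoral method from [1] on a simple image classification task: adapting a digit classifier trained on MNIST so that it performs well on USPS.

[1] B. Sun and K. Saenko. Deep CORAL: Correlation Alignment for Deep Domain Adaptation. In ECCV Workshops, 2016.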

# Author: Théo Gnassounou
#
# License: BSD 3-Clause
# sphinx_gallery_thumbnail_number = 4
from skorch import NeuralNetClassifier
from torch import nn

from skada.datasets import load_mnist_usps
from skada.deep import DeepCoral
from skada.deep.modules import MNISTtoUSPSNet

Load the image datasets (MNIST as the labelled source domain, USPS as the target domain)

dataset = load_mnist_usps(n_classes=2, n_samples=0.5, return_dataset=True)

# Training data: MNIST as the source domain and USPS as the target domain,
# with the target labels masked so that only source labels are usable.
X, y, sample_domain = dataset.pack(
    as_sources=["mnist"],
    as_targets=["usps"],
    mask_target_labels=True,
)

# Evaluation data: USPS only, keeping its true labels for scoring.
X_test, y_test, sample_domain_test = dataset.pack(
    as_sources=[],
    as_targets=["usps"],
    mask_target_labels=False,
)
  0%|          | 0.00/9.91M [00:00<?, ?B/s]
100%|██████████| 9.91M/9.91M [00:00<00:00, 145MB/s]

  0%|          | 0.00/28.9k [00:00<?, ?B/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 27.9MB/s]

  0%|          | 0.00/1.65M [00:00<?, ?B/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 200MB/s]

  0%|          | 0.00/4.54k [00:00<?, ?B/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 12.0MB/s]
/home/circleci/project/skada/datasets/_mnist_usps.py:72: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
  mnist_target = torch.tensor(mnist_dataset.targets)

  0%|          | 0.00/1.83M [00:00<?, ?B/s]
  2%|▏         | 32.8k/1.83M [00:00<00:10, 165kB/s]
  5%|▌         | 98.3k/1.83M [00:00<00:06, 260kB/s]
 11%|█         | 197k/1.83M [00:00<00:04, 365kB/s]
 16%|█▌        | 295k/1.83M [00:00<00:03, 414kB/s]
 20%|█▉        | 360k/1.83M [00:01<00:05, 281kB/s]
 32%|███▏      | 590k/1.83M [00:01<00:02, 517kB/s]
 47%|████▋     | 852k/1.83M [00:01<00:01, 738kB/s]
 54%|█████▎    | 983k/1.83M [00:02<00:01, 555kB/s]
 82%|████████▏ | 1.51M/1.83M [00:02<00:00, 1.05MB/s]
 93%|█████████▎| 1.70M/1.83M [00:02<00:00, 1.03MB/s]
100%|██████████| 1.83M/1.83M [00:02<00:00, 759kB/s]

Train a source-only baseline model (no domain adaptation)

# Optimisation settings for the baseline network.
training_kwargs = {
    "batch_size": 128,
    "max_epochs": 5,
    "train_split": False,
    "lr": 1e-2,
}
model = NeuralNetClassifier(
    MNISTtoUSPSNet(),
    criterion=nn.CrossEntropyLoss(),
    **training_kwargs,
)

# Fit only on rows with a positive sample_domain (the source domain under
# skada's convention), so the target data plays no role in training.
is_source = sample_domain > 0
model.fit(X[is_source], y[is_source])

# Kept as the final expression so the gallery displays the target accuracy.
model.score(X_test, y=y_test)
  epoch    train_loss      dur
-------  ------------  -------
      1        1.6299  10.7472
      2        0.3340  11.3004
      3        0.1122  14.0317
      4        0.0668  10.3682
      5        0.0521  11.0014

0.887459807073955

Train a DeepCoral model on the labelled source data together with the unlabelled target data

# Same optimisation settings as the baseline, so the two scores are comparable.
training_kwargs = {
    "batch_size": 128,
    "max_epochs": 5,
    "train_split": False,
    "lr": 1e-2,
}
# Method-specific settings: the CORAL term is computed on the features of the
# layer named "fc1", and `reg` weights that term in the loss.
model = DeepCoral(
    MNISTtoUSPSNet(),
    layer_name="fc1",
    reg=1,
    **training_kwargs,
)

# Unlike the baseline, fit on every sample. sample_domain tells the method
# which rows are source and which are target; target labels were masked at
# pack time.
model.fit(X, y, sample_domain=sample_domain)

# Kept as the final expression so the gallery displays the target accuracy.
model.score(X_test, y_test, sample_domain=sample_domain_test)
  epoch    train_loss      dur
-------  ------------  -------
      1        1.4917  16.0562
      2        0.3264  14.4978
      3        0.1183  12.3996
      4        0.0758  13.8300
      5        0.0648  11.1976

0.9067524115755627

Total running time of the script: (2 minutes 17.279 seconds)

Gallery generated by Sphinx-Gallery