Divergence domain adaptation methods.
This example illustrates the DeepCoral method from [1] on a simple image classification task.
# Author: Théo Gnassounou
#
# License: BSD 3-Clause
# sphinx_gallery_thumbnail_number = 4
from skorch import NeuralNetClassifier
from torch import nn
from skada.datasets import load_mnist_usps
from skada.deep import DeepCoral
from skada.deep.modules import MNISTtoUSPSNet
Load the image datasets
dataset = load_mnist_usps(n_classes=2, n_samples=0.5, return_dataset=True)
X, y, sample_domain = dataset.pack_train(as_sources=["mnist"], as_targets=["usps"])
X_test, y_test, sample_domain_test = dataset.pack_test(as_targets=["usps"])
Train a classic model on the source domain only
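model = NeuralNetClassifier(
    MNISTtoUSPSNet(),
    criterion=nn.CrossEntropyLoss(),
    batch_size=128,
    max_epochs=5,
    train_split=False,
    lr=1e-2,
)
# Fit on source (MNIST) samples only: source domains carry positive sample_domain labels.
model.fit(X[sample_domain > 0], y[sample_domain > 0])
# Evaluate on the held-out target (USPS) test set.
model.score(X_test, y=y_test)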
  epoch    train_loss     dur
-------  ------------  ------
      1        1.5518  4.4002
      2        0.3568  3.6063
      3        0.1240  3.0916
      4        0.0682  4.9007
      5        0.0498  3.3028
0.9421221864951769
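The baseline above keeps only the samples where `sample_domain > 0`. The sketch below illustrates that sign convention on toy data, assuming (as in skada's packed training data) that source domains have positive labels, target domains have negative labels, and target class labels are masked with -1. The helper `split_by_domain` and the toy values are hypothetical and use plain lists to stay dependency-free. With NumPy arrays the same selection is the boolean mask used in the script.

```python
def split_by_domain(X, y, sample_domain):
    """Separate labelled source samples from target samples by domain sign."""
    source_X = [x for x, d in zip(X, sample_domain) if d > 0]
    source_y = [label for label, d in zip(y, sample_domain) if d > 0]
    target_X = [x for x, d in zip(X, sample_domain) if d < 0]
    return source_X, source_y, target_X


# Toy stand-ins for images (hypothetical values, not real data).
X_toy = ["mnist_0", "mnist_1", "usps_0", "usps_1"]
# Assumption: target labels are masked with -1 in the packed training data.
y_toy = [0, 1, -1, -1]
# Assumption: positive labels mark source domains, negative labels mark targets.
domain_toy = [1, 1, -2, -2]

source_X, source_y, target_X = split_by_domain(X_toy, y_toy, domain_toy)
print(source_X, source_y, target_X)
```

A supervised baseline can only use `source_X` and `source_y`, which is why it never sees USPS images during training.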
Train a DeepCoral model
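model = DeepCoral(
    MNISTtoUSPSNet(),
    layer_name="fc1",  # layer whose output features are used for the adaptation loss
    batch_size=128,
    max_epochs=5,
    train_split=False,
    reg=1,  # weight of the domain adaptation (CORAL) loss
    lr=1e-2,
)
# Fit on source and target samples together; sample_domain tells them apart.
model.fit(X, y, sample_domain=sample_domain)
# Evaluate on the held-out target (USPS) test set.
model.score(X_test, y_test, sample_domain=sample_domain_test)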
  epoch    train_loss      dur
-------  ------------  -------
      1        1.6425   8.6921
      2        0.4032  10.9031
      3        0.1406  10.1052
      4        0.0876  11.7941
      5        0.0639  10.4980
0.9292604501607717
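To make the alignment term concrete, here is a minimal dependency-free sketch of the CORAL loss as it is usually stated: the squared Frobenius distance between the source and target feature covariance matrices, scaled by 1 / (4 d^2) for d-dimensional features. This is not skada's implementation. The names `covariance` and `coral_loss` are illustrative, and it is an assumption here that the library applies such a term to the `fc1` activations and weights it by `reg`.

```python
def covariance(features):
    """Sample covariance matrix (d x d) of n feature vectors of length d."""
    n = len(features)
    d = len(features[0])
    means = [sum(row[j] for row in features) / n for j in range(d)]
    centered = [[row[j] - means[j] for j in range(d)] for row in features]
    return [
        [sum(r[i] * r[j] for r in centered) / (n - 1) for j in range(d)]
        for i in range(d)
    ]


def coral_loss(source, target):
    """Squared Frobenius distance between covariances, scaled by 1 / (4 d^2)."""
    d = len(source[0])
    cov_s = covariance(source)
    cov_t = covariance(target)
    squared_diff = sum(
        (cov_s[i][j] - cov_t[i][j]) ** 2 for i in range(d) for j in range(d)
    )
    return squared_diff / (4 * d * d)


# Toy 2-D "features" (hypothetical values, not taken from the example run).
source = [[0.0, 1.0], [2.0, 3.0], [4.0, 0.0]]
shifted = [[a + 10.0, b + 10.0] for a, b in source]  # same covariance, new mean
stretched = [[2.0 * a, b] for a, b in source]  # first feature rescaled

print(coral_loss(source, source))
print(coral_loss(source, shifted))
print(coral_loss(source, stretched))
```

Because the loss compares only second-order statistics, a pure shift of the features leaves it at essentially zero, while rescaling a feature makes it positive.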
Total running time of the script: (1 minutes 20.827 seconds)