Note
Go to the end to download the full example code.
Optimal transport domain adaptation methods.
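This example illustrates DeepJDOT (Damodaran et al., ECCV 2018), an optimal transport deep domain adaptation (DA) method, on a simple image classification task: adapting a digit classifier from MNIST (source) to USPS (target).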
# Author: Théo Gnassounou
#
# License: BSD 3-Clause
# sphinx_gallery_thumbnail_number = 4
from skorch import NeuralNetClassifier
from torch import nn
from skada.datasets import load_mnist_usps
from skada.deep import DeepJDOT
from skada.deep.modules import MNISTtoUSPSNet
Load the image datasets
# Load the MNIST/USPS digit data as a domain-aware dataset object.
# NOTE(review): ``n_classes=2`` keeps two digit classes and ``n_samples=0.5``
# presumably subsamples to half of the data — confirm in load_mnist_usps docs.
dataset = load_mnist_usps(
    return_dataset=True,
    n_classes=2,
    n_samples=0.5,
)

# Training arrays: MNIST is packed as the source domain and USPS as the
# target domain; ``sample_domain`` records which domain each row came from.
X, y, sample_domain = dataset.pack_train(
    as_targets=["usps"],
    as_sources=["mnist"],
)

# Test arrays: only USPS, packed as the target domain.
X_test, y_test, sample_domain_test = dataset.pack_test(
    as_targets=["usps"],
)
/home/circleci/project/skada/datasets/_mnist_usps.py:72: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
mnist_target = torch.tensor(mnist_dataset.targets)
Train a classic model (source domain only, no adaptation)
# Baseline without domain adaptation: a plain skorch classifier that only
# sees rows with a positive ``sample_domain`` (the source domain, MNIST,
# under skada's sign convention — TODO confirm).
source_mask = sample_domain > 0

model = NeuralNetClassifier(
    MNISTtoUSPSNet(),
    lr=0.01,
    max_epochs=5,
    batch_size=128,
    criterion=nn.CrossEntropyLoss(),
    train_split=False,  # disable skorch's internal train/validation split
)
model.fit(X[source_mask], y[source_mask])

# Mean accuracy on the USPS (target-domain) test set.
model.score(X_test, y=y_test)
epoch train_loss dur
------- ------------ ------
1 1.5280 2.6382
2 0.2267 2.7048
3 0.0864 2.8013
4 0.0485 2.5987
5 0.0394 2.4990
0.8906752411575563
Train a DeepJDOT model
# DeepJDOT: the same backbone as the baseline, but fit on both domains.
# NOTE(review): ``layer_name`` presumably selects the layer whose output
# is used as the feature space for the optimal transport alignment, and
# ``reg`` / ``reg_cl`` weight the alignment loss and its class-label cost
# term — confirm against skada's DeepJDOT documentation.
model = DeepJDOT(
    MNISTtoUSPSNet(),
    layer_name="fc1",
    reg=0.1,
    reg_cl=0.01,
    lr=0.01,
    max_epochs=5,
    batch_size=128,
    train_split=False,  # disable skorch's internal train/validation split
)

# Source and target rows are passed together; ``sample_domain`` tells the
# domain-aware estimator which domain each row belongs to.
model.fit(X, y, sample_domain=sample_domain)

# Score on the USPS test set, flagged as target domain.
model.score(X_test, y_test, sample_domain=sample_domain_test)
epoch train_loss dur
------- ------------ -------
1 1.9599 17.4901
2 0.9898 7.3927
3 0.6751 6.0022
4 0.5494 17.7027
5 0.4724 19.3976
0.9067524115755627
Total running time of the script: (1 minute 24.711 seconds)