Note
Go to the end to download the full example code.
Adversarial domain adaptation methods.
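This example illustrates an adversarial domain adaptation method, DANN (Domain-Adversarial Neural Network, available as `skada.deep.DANN`), on a simple image classification task. It compares DANN against a model trained on the source domain only.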
# Author: Théo Gnassounou
#
# License: BSD 3-Clause
# sphinx_gallery_thumbnail_number = 4
import torch
from skorch import NeuralNetClassifier
from torch import nn

from skada.datasets import load_mnist_usps
from skada.deep import DANN
from skada.deep.modules import MNISTtoUSPSNet
Load the image datasets
# Load MNIST and USPS restricted to 2 classes. ``n_samples=0.5`` presumably
# subsamples each dataset (fraction) — confirm against load_mnist_usps.
# ``return_dataset=True`` returns a dataset object that is packed below.
dataset = load_mnist_usps(
    return_dataset=True,
    n_classes=2,
    n_samples=0.5,
)

# Training arrays: MNIST is the labelled source domain; USPS is the target
# domain and its labels are masked, so only source labels are usable to fit.
X, y, sample_domain = dataset.pack(
    as_targets=["usps"],
    as_sources=["mnist"],
    mask_target_labels=True,
)

# Test arrays: USPS only (no source domain), with its true labels kept so
# that accuracy on the target domain can be measured.
X_test, y_test, sample_domain_test = dataset.pack(
    as_targets=["usps"],
    as_sources=[],
    mask_target_labels=False,
)
/home/circleci/project/skada/datasets/_mnist_usps.py:72: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
mnist_target = torch.tensor(mnist_dataset.targets)
Train a classic model on the source domain only (baseline without adaptation)
model = NeuralNetClassifier(
MNISTtoUSPSNet(),
criterion=nn.CrossEntropyLoss(),
batch_size=128,
max_epochs=5,
train_split=False,
lr=1e-2,
)
model.fit(X[sample_domain > 0], y[sample_domain > 0])
model.score(X_test, y=y_test)
epoch train_loss dur
------- ------------ ------
1 1.5035 5.0989
2 0.2546 4.8037
3 0.0831 5.1934
4 0.0490 4.5016
5 0.0356 4.8983
0.9003215434083601
Train a DANN model
# DANN: the same backbone, trained jointly on labelled source samples and
# unlabelled target samples (target labels were masked when packing X, y).
# Use the same seed as the baseline so both backbones start from identical
# weights and the accuracy comparison is reproducible.
torch.manual_seed(0)

model = DANN(
    MNISTtoUSPSNet(),
    # Backbone layer whose output is fed to the domain classifier.
    layer_name="fc1",
    batch_size=128,
    max_epochs=5,
    train_split=False,  # no internal validation split
    # Weight of the adversarial domain loss.
    reg=0.01,
    # Input size of the domain classifier; must match the output size of
    # ``layer_name`` (assumed 128 for MNISTtoUSPSNet.fc1 — confirm).
    num_features=128,
    lr=1e-2,
)
model.fit(X, y, sample_domain=sample_domain)

# Store and label the score so it is visible (and comparable with the
# baseline) when the file is run as a plain script.
score_dann = model.score(X_test, y_test, sample_domain=sample_domain_test)
print(f"DANN accuracy on USPS (target): {score_dann:.4f}")
/home/circleci/.local/lib/python3.10/site-packages/sklearn/utils/deprecation.py:132: FutureWarning: 'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.
warnings.warn(
/home/circleci/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1773: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.
return self._call_impl(*args, **kwargs)
epoch train_loss dur
------- ------------ -------
1 2.4881 10.0910
2 1.2773 10.7986
3 1.0984 10.3991
4 1.0555 10.8971
5 1.0409 10.6012
/home/circleci/.local/lib/python3.10/site-packages/sklearn/utils/deprecation.py:132: FutureWarning: 'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.
warnings.warn(
/home/circleci/.local/lib/python3.10/site-packages/sklearn/utils/deprecation.py:132: FutureWarning: 'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.
warnings.warn(
/home/circleci/.local/lib/python3.10/site-packages/sklearn/utils/deprecation.py:132: FutureWarning: 'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.
warnings.warn(
0.9035369774919614
Total running time of the script: (1 minutes 21.695 seconds)