Note
Go to the end to download the full example code.
Training setup for deep DA methods.
This example illustrates the use of deep domain adaptation (DA) methods in Skada on a simple image classification task.
# Author: Théo Gnassounou
#
# License: BSD 3-Clause
# sphinx_gallery_thumbnail_number = 4
import torch
from skorch.dataset import Dataset
from skada.datasets import load_mnist_usps
from skada.deep import DeepCoral, DeepCoralLoss
from skada.deep.base import (
DomainAwareCriterion,
DomainAwareModule,
DomainBalancedDataLoader,
)
from skada.deep.modules import MNISTtoUSPSNet
Load the image datasets
# Two-class MNIST -> USPS problem, returned as a skada dataset object so the
# source/target roles can be assigned when packing.
dataset = load_mnist_usps(
    n_classes=2,
    n_samples=0.5,
    return_dataset=True,
)
# Training data: MNIST is labelled as the source domain, USPS as the target.
X, y, sample_domain = dataset.pack_train(
    as_sources=["mnist"],
    as_targets=["usps"],
)
# Test data: only the USPS (target) domain.
X_test, y_test, sample_domain_test = dataset.pack_test(
    as_targets=["usps"],
)
/home/circleci/project/skada/datasets/_mnist_usps.py:72: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
mnist_target = torch.tensor(mnist_dataset.targets)
Training parameters
# Hyper-parameters shared by every training variant in this example.
max_epochs, batch_size = 2, 256
# ``lr`` feeds the optimizer; ``reg`` is forwarded as ``reg=`` to DeepCoral and
# DomainAwareCriterion below.
lr, reg = 1e-3, 1

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
Training with skorch
# DeepCoral through its skorch-style estimator API: plain arrays plus the
# per-sample domain labels are passed straight to ``fit``.
model = DeepCoral(
    MNISTtoUSPSNet(),
    # ``fc1`` is the 9216 -> 128 Linear layer of MNISTtoUSPSNet; presumably the
    # layer whose activations DeepCoral aligns (confirm in the skada docs).
    layer_name="fc1",
    reg=reg,
    # Optimisation settings.
    lr=lr,
    batch_size=batch_size,
    max_epochs=max_epochs,
    # No internal validation split, so only ``train_loss`` is reported.
    train_split=False,
    device=device,
)
model.fit(X, y, sample_domain=sample_domain)
epoch train_loss dur
------- ------------ ------
1 2.2427 9.2629
2 2.1317 13.1939
<class 'skada.deep.base.DomainAwareNet'>[initialized](
module_=DomainAwareModule(
(base_module_): MNISTtoUSPSNet(
(conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
(relu1): ReLU()
(conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
(relu2): ReLU()
(dropout1): Dropout(p=0.25, inplace=False)
(dropout2): Dropout(p=0.5, inplace=False)
(fc1): Linear(in_features=9216, out_features=128, bias=True)
(relu3): ReLU()
(fc2): Linear(in_features=128, out_features=10, bias=True)
(maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
),
)
Training with skorch using a dataset
# Bundle the inputs and the domain labels into one dict so that a single
# skorch Dataset carries everything ``fit`` needs.
X_dict = dict(
    X=torch.tensor(X),
    sample_domain=torch.tensor(sample_domain),
)
# TODO: also show how to build an equivalent dataset without skorch.
dataset = Dataset(X_dict, torch.tensor(y))

model = DeepCoral(
    MNISTtoUSPSNet(),
    layer_name="fc1",
    reg=reg,
    lr=lr,
    batch_size=batch_size,
    max_epochs=max_epochs,
    train_split=False,
    device=device,
)
# Targets and domain labels already live inside ``dataset``, so both explicit
# arguments are left as None.
model.fit(dataset, y=None, sample_domain=None)
epoch train_loss dur
------- ------------ ------
1 2.3003 7.3555
2 2.2037 8.7028
<class 'skada.deep.base.DomainAwareNet'>[initialized](
module_=DomainAwareModule(
(base_module_): MNISTtoUSPSNet(
(conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
(relu1): ReLU()
(conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
(relu2): ReLU()
(dropout1): Dropout(p=0.25, inplace=False)
(dropout2): Dropout(p=0.5, inplace=False)
(fc1): Linear(in_features=9216, out_features=128, bias=True)
(relu3): ReLU()
(fc2): Linear(in_features=128, out_features=10, bias=True)
(maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
),
)
Training with torch
# The same DeepCoral training written as an explicit PyTorch loop.
# DomainAwareModule wraps the network and is told to use its "fc1" layer.
model = DomainAwareModule(MNISTtoUSPSNet(), layer_name="fc1").to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Reuses the skorch ``dataset`` whose inputs are a dict with "X" and
# "sample_domain" entries.
dataloader = DomainBalancedDataLoader(dataset, batch_size=batch_size)
# Cross-entropy and the DeepCoral loss, combined by DomainAwareCriterion with
# ``reg`` passed as its weighting argument.
loss_fn = DomainAwareCriterion(torch.nn.CrossEntropyLoss(), DeepCoralLoss(), reg=reg)

# Training loop
for epoch in range(max_epochs):
    model.train()
    running_loss = 0.0
    n_batches = 0  # named to avoid shadowing the built-in ``iter``
    for inputs, labels in dataloader:
        # ``inputs`` is unpacked as keyword arguments to the model, so it is a
        # mapping. Every tensor in it must be moved to the model's device, not
        # only the labels; otherwise the forward pass fails when ``device`` is CUDA.
        inputs = {
            name: value.to(device) if isinstance(value, torch.Tensor) else value
            for name, value in inputs.items()
        }
        labels = labels.to(device)
        # Zero the gradients
        optimizer.zero_grad()
        # Forward pass (``is_fit=True`` is forwarded to DomainAwareModule)
        outputs = model(**inputs, is_fit=True)
        loss = loss_fn(outputs, labels)
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        n_batches += 1
    # Mean training loss over this epoch's batches.
    print("Loss:", running_loss / n_batches)
Loss: 1.262551188468933
Loss: 0.20536349713802338
Total running time of the script: (1 minutes 3.350 seconds)