Training setup for deep DA methods

This example illustrates the use of deep DA methods in Skada on a simple image classification task.

# Author: Théo Gnassounou
#
# License: BSD 3-Clause
# sphinx_gallery_thumbnail_number = 4
import torch
from skorch.dataset import Dataset

from skada.datasets import load_mnist_usps
from skada.deep import DeepCoral, DeepCoralLoss
from skada.deep.base import (
    DomainAwareCriterion,
    DomainAwareModule,
    DomainBalancedDataLoader,
)
from skada.deep.modules import MNISTtoUSPSNet

Load the image datasets

dataset = load_mnist_usps(n_classes=2, n_samples=0.5, return_dataset=True)
X, y, sample_domain = dataset.pack(
    as_sources=["mnist"], as_targets=["usps"], mask_target_labels=True
)
X_test, y_test, sample_domain_test = dataset.pack(
    as_sources=[], as_targets=["usps"], mask_target_labels=False
)
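
Before training, the packed arrays can be inspected. A minimal sketch (exact counts depend on the `n_classes` and `n_samples` arguments): positive `sample_domain` entries mark source (MNIST) samples, negative entries mark target (USPS) samples, and with `mask_target_labels=True` the target labels in `y` are masked (encoded as -1 in skada).

import numpy as np

# sample_domain is positive for source samples and negative for target samples
print("X shape:", X.shape)
print("source samples:", int(np.sum(sample_domain > 0)))
print("target samples:", int(np.sum(sample_domain < 0)))
print("labels in y:", np.unique(y))  # masked target labels appear as -1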

Training parameters

max_epochs = 2
batch_size = 256
lr = 1e-3
reg = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
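
The `reg` parameter weights the domain adaptation term against the classification loss. As a rough sketch of the Deep CORAL objective optimized below (following Sun and Saenko, 2016; the exact normalization inside `DeepCoralLoss` may differ):

$$\mathcal{L} = \mathcal{L}_{\text{CE}}\big(y_s, \hat{y}_s\big) + \mathrm{reg} \cdot \frac{1}{4 d^2}\,\lVert C_S - C_T \rVert_F^2$$

where $C_S$ and $C_T$ are the covariance matrices of the source and target features extracted at `layer_name`, and $d$ is the feature dimension.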

Training with skorch

model = DeepCoral(
    MNISTtoUSPSNet(),
    layer_name="fc1",
    batch_size=batch_size,
    max_epochs=max_epochs,
    train_split=False,
    reg=reg,
    lr=lr,
    device=device,
)
model.fit(X, y, sample_domain=sample_domain)
  epoch    train_loss      dur
-------  ------------  -------
      1        2.2747  10.9690
      2        2.2173  12.5042
<class 'skada.deep.base.DomainAwareNet'>[initialized](
  module_=DomainAwareModule(
    (base_module_): MNISTtoUSPSNet(
      (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
      (relu1): ReLU()
      (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
      (relu2): ReLU()
      (dropout1): Dropout(p=0.25, inplace=False)
      (dropout2): Dropout(p=0.5, inplace=False)
      (fc1): Linear(in_features=9216, out_features=128, bias=True)
      (relu3): ReLU()
      (fc2): Linear(in_features=128, out_features=10, bias=True)
      (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
  ),
)
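
The fitted estimator can then be evaluated on the target-only test split packed above. A minimal sketch, assuming the net accepts the same dict input structure at predict time and returns class labels, as skorch classifiers do; the exact evaluation API may differ.

from sklearn.metrics import accuracy_score

# Pack the test data in the same dict structure used for training
X_test_dict = {
    "X": torch.tensor(X_test),
    "sample_domain": torch.tensor(sample_domain_test),
}
y_pred = model.predict(X_test_dict)
print("Accuracy on USPS test data:", accuracy_score(y_test, y_pred))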


Training with skorch using a skorch Dataset

X_dict = {"X": torch.tensor(X), "sample_domain": torch.tensor(sample_domain)}

# TODO create a dataset also without skorch
dataset = Dataset(X_dict, torch.tensor(y))

model = DeepCoral(
    MNISTtoUSPSNet(),
    layer_name="fc1",
    batch_size=batch_size,
    max_epochs=max_epochs,
    train_split=False,
    reg=reg,
    lr=lr,
    device=device,
)
model.fit(dataset, y=None, sample_domain=None)
  epoch    train_loss      dur
-------  ------------  -------
      1        2.2298  13.2825
      2        2.1103  13.1990
<class 'skada.deep.base.DomainAwareNet'>[initialized](
  module_=DomainAwareModule(
    (base_module_): MNISTtoUSPSNet(
      (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
      (relu1): ReLU()
      (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
      (relu2): ReLU()
      (dropout1): Dropout(p=0.25, inplace=False)
      (dropout2): Dropout(p=0.5, inplace=False)
      (fc1): Linear(in_features=9216, out_features=128, bias=True)
      (relu3): ReLU()
      (fc2): Linear(in_features=128, out_features=10, bias=True)
      (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
  ),
)
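
The skorch `Dataset` above is only a convenience: the same dict-plus-label structure can be served by a plain `torch.utils.data.Dataset`. A minimal sketch of such a class (hypothetical, not part of skada):

from torch.utils.data import Dataset as TorchDataset


class DomainDictDataset(TorchDataset):
    """Hypothetical map-style dataset yielding ({"X": x, "sample_domain": d}, y)."""

    def __init__(self, X, y, sample_domain):
        self.X = torch.as_tensor(X)
        self.y = torch.as_tensor(y)
        self.sample_domain = torch.as_tensor(sample_domain)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        item = {"X": self.X[idx], "sample_domain": self.sample_domain[idx]}
        return item, self.y[idx]

Whether this drops in for the skorch Dataset below depends on how DomainBalancedDataLoader extracts sample_domain from the dataset items, so treat it as a starting point rather than a tested replacement.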


Training with torch

model = DomainAwareModule(MNISTtoUSPSNet(), layer_name="fc1").to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
dataloader = DomainBalancedDataLoader(dataset, batch_size=batch_size)
loss_fn = DomainAwareCriterion(torch.nn.CrossEntropyLoss(), DeepCoralLoss(), reg=reg)

# Training loop
for epoch in range(max_epochs):
    model.train()
    running_loss = 0.0
    n_batches = 0
    for inputs, labels in dataloader:
        # Move the batch tensors to the training device
        inputs = {key: value.to(device) for key, value in inputs.items()}
        labels = labels.to(device)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(**inputs, is_fit=True)
        loss = loss_fn(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1
    print("Loss:", running_loss / n_batches)
Loss: 0.9618009179830551
Loss: 0.21983920969069004
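
For reference, the adaptation term computed by `DeepCoralLoss` aligns the second-order statistics of source and target features. A self-contained sketch of a textbook CORAL loss in plain torch (not skada's implementation):

def coral_loss(source_features, target_features):
    """Textbook CORAL loss: distance between source and target feature covariances."""
    d = source_features.size(1)

    def covariance(features):
        centered = features - features.mean(dim=0, keepdim=True)
        return centered.T @ centered / (features.size(0) - 1)

    diff = covariance(source_features) - covariance(target_features)
    return (diff**2).sum() / (4 * d * d)


# e.g. coral_loss(torch.randn(32, 128), torch.randn(32, 128))

In the training loop above, `DomainAwareCriterion` combines such an adaptation term with the cross-entropy loss, weighted by `reg`.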

Total running time of the script: 1 minute 17.445 seconds

Gallery generated by Sphinx-Gallery