Training setup for deep DA method.

This example illustrates the use of deep DA methods in Skada on a simple image classification task.

# Author: Théo Gnassounou
#
# License: BSD 3-Clause
# sphinx_gallery_thumbnail_number = 4
import torch
from skorch.dataset import Dataset

from skada.datasets import load_mnist_usps
from skada.deep import DeepCoral, DeepCoralLoss
from skada.deep.base import (
    DomainAwareCriterion,
    DomainAwareModule,
    DomainBalancedDataLoader,
)
from skada.deep.modules import MNISTtoUSPSNet

Load the image datasets

# Fetch the MNIST/USPS data as a dataset object (return_dataset=True) so the
# train and test splits can be packed from it below.
# NOTE(review): n_classes / n_samples presumably restrict the number of classes
# and subsample the data to keep the example fast -- confirm in load_mnist_usps.
dataset = load_mnist_usps(return_dataset=True, n_samples=0.5, n_classes=2)

# Training arrays: "mnist" is packed as the source domain, "usps" as the target.
X, y, sample_domain = dataset.pack_train(
    as_targets=["usps"],
    as_sources=["mnist"],
)

# Test arrays: only the "usps" target domain is packed.
X_test, y_test, sample_domain_test = dataset.pack_test(as_targets=["usps"])
/home/circleci/project/skada/datasets/_mnist_usps.py:72: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  mnist_target = torch.tensor(mnist_dataset.targets)

Training parameters

# Hyper-parameters reused by every training variant in this example.
max_epochs, batch_size = 2, 256
lr, reg = 1e-3, 1

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

Training with skorch

# Build the DeepCoral estimator around a fresh MNISTtoUSPSNet, using the
# hyper-parameters defined above.
# NOTE(review): layer_name="fc1" presumably selects the layer whose features
# feed the adaptation loss -- confirm against the DeepCoral documentation.
model = DeepCoral(
    MNISTtoUSPSNet(),
    layer_name="fc1",
    device=device,
    lr=lr,
    reg=reg,
    train_split=False,
    max_epochs=max_epochs,
    batch_size=batch_size,
)

# Fit directly on the packed arrays; the domain labels travel alongside X and y.
model.fit(X, y, sample_domain=sample_domain)
  epoch    train_loss     dur
-------  ------------  ------
      1        2.3413  8.1887
      2        2.2873  8.0990
<class 'skada.deep.base.DomainAwareNet'>[initialized](
  module_=DomainAwareModule(
    (base_module_): MNISTtoUSPSNet(
      (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
      (relu1): ReLU()
      (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
      (relu2): ReLU()
      (dropout1): Dropout(p=0.25, inplace=False)
      (dropout2): Dropout(p=0.5, inplace=False)
      (fc1): Linear(in_features=9216, out_features=128, bias=True)
      (relu3): ReLU()
      (fc2): Linear(in_features=128, out_features=10, bias=True)
      (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
  ),
)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.


Training with skorch using a dataset

# Bundle the features and their domain labels into one dict of tensors so that
# a single skorch Dataset carries both.
X_dict = dict(X=torch.tensor(X), sample_domain=torch.tensor(sample_domain))

# TODO create a dataset also without skorch
dataset = Dataset(X_dict, torch.tensor(y))

# Same estimator configuration as in the array-based training above.
model = DeepCoral(
    MNISTtoUSPSNet(),
    layer_name="fc1",
    device=device,
    lr=lr,
    reg=reg,
    train_split=False,
    max_epochs=max_epochs,
    batch_size=batch_size,
)

# y and sample_domain are passed as None: the dataset built above already holds
# the label tensor and the "sample_domain" entry.
model.fit(dataset, y=None, sample_domain=None)
  epoch    train_loss      dur
-------  ------------  -------
      1        2.2562  12.7622
      2        2.1860  9.9964
<class 'skada.deep.base.DomainAwareNet'>[initialized](
  module_=DomainAwareModule(
    (base_module_): MNISTtoUSPSNet(
      (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
      (relu1): ReLU()
      (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
      (relu2): ReLU()
      (dropout1): Dropout(p=0.25, inplace=False)
      (dropout2): Dropout(p=0.5, inplace=False)
      (fc1): Linear(in_features=9216, out_features=128, bias=True)
      (relu3): ReLU()
      (fc2): Linear(in_features=128, out_features=10, bias=True)
      (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
  ),
)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.


Training with torch

# Plain PyTorch training: wrap the base network, then create the optimizer,
# the data loader and the combined criterion by hand.
# The module is moved to `device`, so every batch must be moved there as well.
model = DomainAwareModule(MNISTtoUSPSNet(), layer_name="fc1").to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
dataloader = DomainBalancedDataLoader(dataset, batch_size=batch_size)
# Cross-entropy combined with the DeepCoral loss; `reg` is handed to the
# criterion (presumably as the weight of the adaptation term -- confirm in
# the DomainAwareCriterion documentation).
loss_fn = DomainAwareCriterion(torch.nn.CrossEntropyLoss(), DeepCoralLoss(), reg=reg)

# Training loop
for epoch in range(max_epochs):
    model.train()
    running_loss = 0.0
    n_batches = 0  # counts batches to average the loss (avoids shadowing `iter`)
    for inputs, labels in dataloader:
        # `inputs` is a mapping unpacked as keyword arguments in the forward
        # call below. Move each tensor it holds to the model's device;
        # otherwise the forward pass fails when `device` is a GPU.
        inputs = {
            key: value.to(device) if torch.is_tensor(value) else value
            for key, value in inputs.items()
        }
        labels = labels.to(device)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(**inputs, is_fit=True)
        loss = loss_fn(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1
    # Mean loss over the batches seen in this epoch.
    print("Loss:", running_loss / n_batches)
Loss: 0.8751300573348999
Loss: 0.18600143492221832

Total running time of the script: (0 minutes 58.587 seconds)

Gallery generated by Sphinx-Gallery