Training setup for deep domain adaptation (DA) methods.

This example illustrates the use of deep DA methods in Skada on a simple image classification task.

# Author: Théo Gnassounou
#
# License: BSD 3-Clause
# sphinx_gallery_thumbnail_number = 4
import torch
from skorch.dataset import Dataset

from skada.datasets import load_mnist_usps
from skada.deep import DeepCoral, DeepCoralLoss
from skada.deep.base import (
    DomainAwareCriterion,
    DomainAwareModule,
    DomainBalancedDataLoader,
)
from skada.deep.modules import MNISTtoUSPSNet

Load the image datasets

# Load the MNIST/USPS pair restricted to two classes
# (n_samples=0.5 presumably keeps a fraction of the samples — confirm in skada docs).
dataset = load_mnist_usps(
    n_classes=2,
    n_samples=0.5,
    return_dataset=True,
)
# Training arrays: MNIST is packed as the source domain, USPS as the target domain.
X, y, sample_domain = dataset.pack_train(
    as_sources=["mnist"],
    as_targets=["usps"],
)
# Evaluation arrays: only USPS, packed as a target domain.
X_test, y_test, sample_domain_test = dataset.pack_test(as_targets=["usps"])
/home/circleci/project/skada/datasets/_mnist_usps.py:72: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  mnist_target = torch.tensor(mnist_dataset.targets)

Training parameters

# Hyperparameters shared by the skorch and plain-PyTorch training variants.
max_epochs = 2
batch_size = 256
lr = 1e-3
reg = 1  # presumably the weight of the DA loss term; passed as ``reg`` to the models/criterion

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

Training with skorch

# Fit DeepCoral through its skorch-style estimator interface, passing the
# domain of each sample alongside X and y. Intermediate features are taken
# from the "fc1" layer of the network.
model = DeepCoral(
    MNISTtoUSPSNet(),
    layer_name="fc1",
    reg=reg,
    lr=lr,
    max_epochs=max_epochs,
    batch_size=batch_size,
    device=device,
    train_split=False,  # presumably disables the internal validation split
)
model.fit(X, y, sample_domain=sample_domain)
  epoch    train_loss      dur
-------  ------------  -------
      1        2.2764  10.0830
      2        2.1936  9.4029

<class 'skada.deep.base.DomainAwareNet'>[initialized](
  module_=DomainAwareModule(
    (base_module_): MNISTtoUSPSNet(
      (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
      (relu1): ReLU()
      (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
      (relu2): ReLU()
      (dropout1): Dropout(p=0.25, inplace=False)
      (dropout2): Dropout(p=0.5, inplace=False)
      (fc1): Linear(in_features=9216, out_features=128, bias=True)
      (relu3): ReLU()
      (fc2): Linear(in_features=128, out_features=10, bias=True)
      (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
  ),
)

Training with skorch using a Dataset

# Same model, but fed through a skorch Dataset whose X is a dict carrying
# both the images and the per-sample domain indicator.
X_dict = {
    "X": torch.tensor(X),
    "sample_domain": torch.tensor(sample_domain),
}

# TODO create a dataset also without skorch
dataset = Dataset(X_dict, torch.tensor(y))

model = DeepCoral(
    MNISTtoUSPSNet(),
    layer_name="fc1",
    reg=reg,
    lr=lr,
    max_epochs=max_epochs,
    batch_size=batch_size,
    device=device,
    train_split=False,
)
# y and sample_domain are passed as None because ``dataset`` already holds the
# labels (second Dataset argument) and the domains (inside X_dict).
model.fit(dataset, y=None, sample_domain=None)
  epoch    train_loss     dur
-------  ------------  ------
      1        2.2715  8.8677
      2        2.1959  9.2985

<class 'skada.deep.base.DomainAwareNet'>[initialized](
  module_=DomainAwareModule(
    (base_module_): MNISTtoUSPSNet(
      (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
      (relu1): ReLU()
      (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
      (relu2): ReLU()
      (dropout1): Dropout(p=0.25, inplace=False)
      (dropout2): Dropout(p=0.5, inplace=False)
      (fc1): Linear(in_features=9216, out_features=128, bias=True)
      (relu3): ReLU()
      (fc2): Linear(in_features=128, out_features=10, bias=True)
      (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
  ),
)

Training with torch

# Same DA objective, trained with an explicit PyTorch loop instead of skorch.
model = DomainAwareModule(MNISTtoUSPSNet(), layer_name="fc1").to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
dataloader = DomainBalancedDataLoader(dataset, batch_size=batch_size)
# Cross-entropy classification loss combined with the DeepCORAL loss via ``reg``.
loss_fn = DomainAwareCriterion(torch.nn.CrossEntropyLoss(), DeepCoralLoss(), reg=reg)

# Training loop
for epoch in range(max_epochs):
    model.train()
    running_loss = 0.0
    n_batches = 0  # named to avoid shadowing the builtin ``iter``
    for inputs, labels in dataloader:
        # ``inputs`` is a mapping (it is unpacked with ** below), presumably the
        # "X" / "sample_domain" entries of X_dict. Every tensor in it must be on
        # the model's device, not only the labels; otherwise the forward pass
        # fails whenever the model lives on the GPU.
        inputs = {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in inputs.items()
        }
        labels = labels.to(device)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(**inputs, is_fit=True)
        loss = loss_fn(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1
    # max(..., 1) guards against ZeroDivisionError if the dataloader is empty.
    print("Loss:", running_loss / max(n_batches, 1))
Loss: 0.8874167464673519
Loss: 0.18708933144807816

Total running time of the script: (1 minutes 1.895 seconds)

Gallery generated by Sphinx-Gallery