.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/deep/plot_training_method.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_deep_plot_training_method.py:


Training setups for deep DA methods
==========================================

This example illustrates three ways to train a deep domain adaptation (DA)
method in Skada on a simple image classification task (MNIST to USPS): with a
skorch estimator fitted on arrays, with a skorch estimator fitted on a
``Dataset``, and with a plain PyTorch training loop.

.. GENERATED FROM PYTHON SOURCE LINES 8-13

.. code-block:: Python


    # Author: Théo Gnassounou
    #
    # License: BSD 3-Clause
    # sphinx_gallery_thumbnail_number = 4

.. GENERATED FROM PYTHON SOURCE LINES 14-26

.. code-block:: Python


    import torch
    from skorch.dataset import Dataset

    from skada.datasets import load_mnist_usps
    from skada.deep import DeepCoral, DeepCoralLoss
    from skada.deep.base import (
        DomainAwareCriterion,
        DomainAwareModule,
        DomainBalancedDataLoader,
    )
    from skada.deep.modules import MNISTtoUSPSNet

.. GENERATED FROM PYTHON SOURCE LINES 27-29

Load the image datasets
----------------------------------------------------------------------------

MNIST is used as the source domain and USPS as the target domain. Target
labels are masked in the training data, while ``X_test`` and ``y_test`` hold
the USPS samples together with their labels.

.. GENERATED FROM PYTHON SOURCE LINES 29-38

.. code-block:: Python


    dataset = load_mnist_usps(n_classes=2, n_samples=0.5, return_dataset=True)
    X, y, sample_domain = dataset.pack(
        as_sources=["mnist"], as_targets=["usps"], mask_target_labels=True
    )
    X_test, y_test, sample_domain_test = dataset.pack(
        as_sources=[], as_targets=["usps"], mask_target_labels=False
    )

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /home/circleci/project/skada/datasets/_mnist_usps.py:72: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
      mnist_target = torch.tensor(mnist_dataset.targets)

.. GENERATED FROM PYTHON SOURCE LINES 39-41

Training parameters
----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 41-49

.. code-block:: Python


    max_epochs = 2  # number of passes over the training data
    batch_size = 256
    lr = 1e-3  # learning rate
    reg = 1  # weight of the DeepCORAL adaptation loss
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

.. GENERATED FROM PYTHON SOURCE LINES 50-52

Training with skorch
----------------------------------------------------------------------------

``DeepCoral`` returns a skorch-based estimator, so it is trained with ``fit``
on the arrays packed above. Passing ``sample_domain`` lets the estimator tell
source samples from target samples.

.. GENERATED FROM PYTHON SOURCE LINES 52-65

.. code-block:: Python


    model = DeepCoral(
        MNISTtoUSPSNet(),
        layer_name="fc1",  # layer whose output features are aligned
        batch_size=batch_size,
        max_epochs=max_epochs,
        train_split=False,  # no internal validation split
        reg=reg,
        lr=lr,
        device=device,
    )
    model.fit(X, y, sample_domain=sample_domain)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

      epoch    train_loss      dur
    -------  ------------  -------
          1        2.2379  10.5502
          2        2.0952  10.6040

.. code-block:: none

    <class 'skada.deep.base.DomainAwareNet'>[initialized](
      module_=DomainAwareModule(
        (base_module_): MNISTtoUSPSNet(
          (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
          (relu1): ReLU()
          (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
          (relu2): ReLU()
          (dropout1): Dropout(p=0.25, inplace=False)
          (dropout2): Dropout(p=0.5, inplace=False)
          (fc1): Linear(in_features=9216, out_features=128, bias=True)
          (relu3): ReLU()
          (fc2): Linear(in_features=128, out_features=10, bias=True)
          (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        )
      ),
    )


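With ``layer_name="fc1"``, the adaptation term is computed on the
128-dimensional ``fc1`` features of the source and target samples, and ``reg``
weights it against the classification loss. As a rough illustration, the NumPy
sketch below computes the standard Deep CORAL distance
:math:`\ell_{\mathrm{CORAL}} = \frac{1}{4d^2} \lVert C_S - C_T \rVert_F^2`
between the feature covariances :math:`C_S` and :math:`C_T` of two batches.
The helpers ``coral_distance`` and ``total_loss`` are hypothetical names, and
skada's ``DeepCoralLoss`` may differ in details such as normalization.

.. code-block:: Python

    import numpy as np


    def coral_distance(features_source, features_target):
        """Hypothetical helper: Deep CORAL distance between two feature batches.

        Both inputs have shape (n_samples, d). Returns the squared Frobenius
        norm of the covariance difference, scaled by 1 / (4 d^2).
        """
        d = features_source.shape[1]
        # rowvar=False: columns are features; np.cov uses the 1 / (n - 1) estimator
        cov_source = np.cov(features_source, rowvar=False)
        cov_target = np.cov(features_target, rowvar=False)
        return np.sum((cov_source - cov_target) ** 2) / (4 * d**2)


    def total_loss(classification_loss, features_source, features_target, reg=1.0):
        """Hypothetical combination: classification loss + reg * CORAL distance."""
        return classification_loss + reg * coral_distance(features_source, features_target)


    rng = np.random.default_rng(0)
    source = rng.normal(size=(256, 128))  # stand-in for fc1 features of a source batch
    target = 2.0 * rng.normal(size=(256, 128))  # stand-in target batch with a larger spread
    print(coral_distance(source, source))  # close to 0: identical covariances
    print(coral_distance(source, target) > 0)  # True: the covariances differ

Since the distance only depends on second-order statistics, a larger ``reg``
pushes the network harder to give source and target features the same spread,
at the expense of fitting the source labels.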
.. GENERATED FROM PYTHON SOURCE LINES 66-68

Training with skorch using a dataset
----------------------------------------------------------------------------

The inputs can also be wrapped in a skorch ``Dataset``. The labels are stored
in the dataset and ``sample_domain`` is stored next to the features in the
input dictionary, so ``fit`` is called with ``y=None`` and
``sample_domain=None``.

.. GENERATED FROM PYTHON SOURCE LINES 68-85

.. code-block:: Python


    # Bundle the features and their domain labels in a single input dict
    X_dict = {"X": torch.tensor(X), "sample_domain": torch.tensor(sample_domain)}
    # TODO create a dataset also without skorch
    dataset = Dataset(X_dict, torch.tensor(y))

    model = DeepCoral(
        MNISTtoUSPSNet(),
        layer_name="fc1",
        batch_size=batch_size,
        max_epochs=max_epochs,
        train_split=False,
        reg=reg,
        lr=lr,
        device=device,
    )
    model.fit(dataset, y=None, sample_domain=None)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

      epoch    train_loss      dur
    -------  ------------  -------
          1        2.2868  10.3190
          2        2.2233  10.2003

.. code-block:: none

    <class 'skada.deep.base.DomainAwareNet'>[initialized](
      module_=DomainAwareModule(
        (base_module_): MNISTtoUSPSNet(
          (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
          (relu1): ReLU()
          (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
          (relu2): ReLU()
          (dropout1): Dropout(p=0.25, inplace=False)
          (dropout2): Dropout(p=0.5, inplace=False)
          (fc1): Linear(in_features=9216, out_features=128, bias=True)
          (relu3): ReLU()
          (fc2): Linear(in_features=128, out_features=10, bias=True)
          (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        )
      ),
    )


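A domain adaptation loss such as DeepCORAL compares source and target
features, so every training batch has to contain samples from both domains.
Since each sample carries its ``sample_domain``, a loader can enforce this. The
NumPy sketch below shows one simple way to do it, assuming skada's convention
that source samples have positive ``sample_domain`` values and target samples
negative ones. The generator ``balanced_batches`` and the toy domain labels are
hypothetical; this is not the implementation of ``DomainBalancedDataLoader``,
whose sampling details may differ.

.. code-block:: Python

    import numpy as np


    def balanced_batches(sample_domain, batch_size, rng):
        """Hypothetical sampler yielding index arrays with batch_size source
        indices followed by batch_size target indices.

        Assumes sample_domain > 0 marks source samples and sample_domain < 0
        marks target samples. Samples that do not fill a complete batch are dropped.
        """
        source_idx = rng.permutation(np.flatnonzero(sample_domain > 0))
        target_idx = rng.permutation(np.flatnonzero(sample_domain < 0))
        n_batches = min(len(source_idx), len(target_idx)) // batch_size
        for b in range(n_batches):
            batch = slice(b * batch_size, (b + 1) * batch_size)
            yield np.concatenate([source_idx[batch], target_idx[batch]])


    rng = np.random.default_rng(0)
    toy_sample_domain = np.array([1] * 10 + [-2] * 7)  # 10 source, 7 target samples
    for indices in balanced_batches(toy_sample_domain, batch_size=3, rng=rng):
        print(toy_sample_domain[indices])  # three source labels, then three target labels

Dropping incomplete batches keeps every batch balanced at the cost of some
unused samples; a loader could instead resample the smaller domain so that each
epoch covers the larger one.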
.. GENERATED FROM PYTHON SOURCE LINES 86-88

Training with a PyTorch loop
----------------------------------------------------------------------------

The same model can also be trained without skorch. ``DomainAwareModule`` wraps
the network, ``DomainBalancedDataLoader`` draws domain-balanced batches from the
``dataset`` built above, and ``DomainAwareCriterion`` combines the cross-entropy
loss with ``DeepCoralLoss`` weighted by ``reg``.

.. GENERATED FROM PYTHON SOURCE LINES 88-116

.. code-block:: Python


    model = DomainAwareModule(MNISTtoUSPSNet(), layer_name="fc1").to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    dataloader = DomainBalancedDataLoader(dataset, batch_size=batch_size)
    loss_fn = DomainAwareCriterion(torch.nn.CrossEntropyLoss(), DeepCoralLoss(), reg=reg)

    # Training loop
    for epoch in range(max_epochs):
        model.train()
        running_loss = 0.0
        n_batches = 0
        for inputs, labels in dataloader:
            # Move every tensor of the input dict (X and sample_domain) to the device
            inputs = {key: value.to(device) for key, value in inputs.items()}
            labels = labels.to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(**inputs, is_fit=True)

            loss = loss_fn(outputs, labels)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            n_batches += 1
        print("Loss:", running_loss / n_batches)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Loss: 0.8678950443863869
    Loss: 0.20251541212201118


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (1 minutes 4.089 seconds)


.. _sphx_glr_download_auto_examples_deep_plot_training_method.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_training_method.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_training_method.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_training_method.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_