.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/deep/plot_training_method.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_deep_plot_training_method.py:


Training setup for deep DA methods
==========================================

This example illustrates three ways to train a deep domain adaptation (DA)
method from Skada on a simple image classification task, with MNIST as the
source domain and USPS as the target domain:

- with the skorch-based estimator on arrays;
- with the same estimator on a dataset;
- with a plain PyTorch training loop.

.. GENERATED FROM PYTHON SOURCE LINES 8-13

.. code-block:: Python


    # Author: Théo Gnassounou
    #
    # License: BSD 3-Clause
    # sphinx_gallery_thumbnail_number = 4

.. GENERATED FROM PYTHON SOURCE LINES 14-26

.. code-block:: Python


    import torch
    from skorch.dataset import Dataset

    from skada.datasets import load_mnist_usps
    from skada.deep import DeepCoral, DeepCoralLoss
    from skada.deep.base import (
        DomainAwareCriterion,
        DomainAwareModule,
        DomainBalancedDataLoader,
    )
    from skada.deep.modules import MNISTtoUSPSNet

.. GENERATED FROM PYTHON SOURCE LINES 27-29

Load the image datasets
----------------------------------------------------------------------------

Only two classes are kept (``n_classes=2``) so that the example runs quickly.
``pack_train`` packs MNIST as the source domain and USPS as the target domain.
It returns the inputs, the labels, and a ``sample_domain`` array that gives the
domain of each sample.

.. GENERATED FROM PYTHON SOURCE LINES 29-34

.. code-block:: Python


    dataset = load_mnist_usps(n_classes=2, n_samples=0.5, return_dataset=True)
    X, y, sample_domain = dataset.pack_train(as_sources=["mnist"], as_targets=["usps"])
    X_test, y_test, sample_domain_test = dataset.pack_test(as_targets=["usps"])




.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /home/circleci/project/skada/datasets/_mnist_usps.py:72: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
      mnist_target = torch.tensor(mnist_dataset.targets)




.. GENERATED FROM PYTHON SOURCE LINES 35-37

Training parameters
----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 37-45

.. code-block:: Python


    max_epochs = 2
    batch_size = 256
    lr = 1e-3  # learning rate
    reg = 1  # weight of the domain adaptation (CORAL) loss
    # Use a GPU when one is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

.. GENERATED FROM PYTHON SOURCE LINES 46-48

Training with skorch
----------------------------------------------------------------------------

``DeepCoral`` builds a skorch estimator around the network. The
``layer_name`` argument selects the layer (here ``fc1``) whose outputs are
aligned between domains, and ``reg`` weights the adaptation loss.
``train_split=False`` disables skorch's internal validation split.

.. GENERATED FROM PYTHON SOURCE LINES 48-61

.. code-block:: Python


    model = DeepCoral(
        MNISTtoUSPSNet(),
        layer_name="fc1",
        batch_size=batch_size,
        max_epochs=max_epochs,
        train_split=False,
        reg=reg,
        lr=lr,
        device=device,
    )
    model.fit(X, y, sample_domain=sample_domain)




.. rst-class:: sphx-glr-script-out

 .. code-block:: none

      epoch    train_loss     dur
    -------  ------------  ------
          1        2.3413  8.1887
          2        2.2873  8.0990

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    <class 'skada.deep.base.DomainAwareNet'>[initialized](
      module_=DomainAwareModule(
        (base_module_): MNISTtoUSPSNet(
          (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
          (relu1): ReLU()
          (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
          (relu2): ReLU()
          (dropout1): Dropout(p=0.25, inplace=False)
          (dropout2): Dropout(p=0.5, inplace=False)
          (fc1): Linear(in_features=9216, out_features=128, bias=True)
          (relu3): ReLU()
          (fc2): Linear(in_features=128, out_features=10, bias=True)
          (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        )
      ),
    )


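``DeepCoral`` aligns the second-order statistics of the ``fc1`` features across domains, and ``reg`` weights this term. The sketch below shows that term in PyTorch. It follows the loss from the Deep CORAL paper: the squared Frobenius distance between the source and target feature covariances, scaled by ``1 / (4 * d**2)``. The function name ``coral_loss`` is hypothetical, and skada's ``DeepCoralLoss`` may differ in details such as normalization.

```python
import torch


def coral_loss(features_source, features_target):
    """Hypothetical sketch of the CORAL loss, not skada's implementation.

    Both inputs have shape (n_samples, d); the result is a scalar tensor.
    """
    d = features_source.shape[1]
    # torch.cov expects one variable per row, hence the transpose
    cov_source = torch.cov(features_source.T)
    cov_target = torch.cov(features_target.T)
    # Squared Frobenius norm of the covariance gap, scaled by 1 / (4 d^2)
    return ((cov_source - cov_target) ** 2).sum() / (4 * d**2)


# Source features vs. target features from the same or a rescaled distribution
torch.manual_seed(0)
fs = torch.randn(256, 128)
ft_same = torch.randn(256, 128)
ft_scaled = 3 * torch.randn(256, 128)
print(coral_loss(fs, ft_same).item(), coral_loss(fs, ft_scaled).item())
```

The loss is small when both domains share the same feature covariance and grows when the target features are rescaled. Minimizing it therefore pushes the network toward domain-invariant second-order statistics.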
.. GENERATED FROM PYTHON SOURCE LINES 62-64

Training with skorch on a dataset
----------------------------------------------------------------------------

Instead of arrays, ``fit`` also accepts a skorch ``Dataset`` whose inputs are a
dict with ``"X"`` and ``"sample_domain"`` keys. The labels and the domains are
already contained in the dataset, so ``y`` and ``sample_domain`` are passed as
``None``.

.. GENERATED FROM PYTHON SOURCE LINES 64-81

.. code-block:: Python


    # Pack the inputs and their domains into a dict, next to the labels
    X_dict = {"X": torch.tensor(X), "sample_domain": torch.tensor(sample_domain)}
    # TODO create a dataset also without skorch
    dataset = Dataset(X_dict, torch.tensor(y))

    model = DeepCoral(
        MNISTtoUSPSNet(),
        layer_name="fc1",
        batch_size=batch_size,
        max_epochs=max_epochs,
        train_split=False,
        reg=reg,
        lr=lr,
        device=device,
    )
    model.fit(dataset, y=None, sample_domain=None)




.. rst-class:: sphx-glr-script-out

 .. code-block:: none

      epoch    train_loss      dur
    -------  ------------  -------
          1        2.2562  12.7622
          2        2.1860   9.9964

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    <class 'skada.deep.base.DomainAwareNet'>[initialized](
      module_=DomainAwareModule(
        (base_module_): MNISTtoUSPSNet(
          (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
          (relu1): ReLU()
          (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
          (relu2): ReLU()
          (dropout1): Dropout(p=0.25, inplace=False)
          (dropout2): Dropout(p=0.5, inplace=False)
          (fc1): Linear(in_features=9216, out_features=128, bias=True)
          (relu3): ReLU()
          (fc2): Linear(in_features=128, out_features=10, bias=True)
          (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        )
      ),
    )

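The ``TODO`` in the dataset example mentions building such a dataset without skorch. The sketch below does that using only ``torch.utils.data``. It assumes each item should be an ``(inputs_dict, label)`` pair, as with skorch's ``Dataset`` on a dict. It also assumes skada's convention of positive domain ids for sources and negative ids for targets, with target labels masked as ``-1``. The class ``DictDataset`` is hypothetical, and whether skada's ``fit`` or ``DomainBalancedDataLoader`` accept it has not been verified.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class DictDataset(Dataset):
    """Hypothetical skorch-free dataset yielding ({"X", "sample_domain"}, y)."""

    def __init__(self, X, sample_domain, y):
        assert len(X) == len(sample_domain) == len(y)
        self.X = X
        self.sample_domain = sample_domain
        self.y = y

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        inputs = {"X": self.X[idx], "sample_domain": self.sample_domain[idx]}
        return inputs, self.y[idx]


# Toy data: 6 single-channel 28x28 images, 3 source (id 1) and 3 target (id -2)
X_toy = torch.randn(6, 1, 28, 28)
domains_toy = torch.tensor([1, 1, 1, -2, -2, -2])
y_toy = torch.tensor([0, 1, 0, -1, -1, -1])  # assumed: target labels masked as -1

# The default collate function stacks each dict entry into a batch tensor
loader = DataLoader(DictDataset(X_toy, domains_toy, y_toy), batch_size=4)
inputs, labels = next(iter(loader))
print(inputs["X"].shape, inputs["sample_domain"], labels)
```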

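The plain PyTorch loop relies on ``DomainAwareCriterion`` to combine a supervised loss with an adaptation loss weighted by ``reg``. The sketch below shows that combination. It assumes two things: the supervised term only uses source samples (``sample_domain > 0``), and the adaptation term compares source and target features. The function ``domain_aware_loss`` and its signature are hypothetical and do not reproduce skada's implementation. For brevity, the usage plugs in a simple mean-difference penalty as a stand-in for the CORAL term.

```python
import torch
import torch.nn.functional as F


def domain_aware_loss(logits, labels, sample_domain, features, da_loss, reg=1.0):
    """Hypothetical sketch: source cross-entropy + reg * adaptation loss."""
    is_source = sample_domain > 0
    # Supervised term: only source samples carry usable labels
    supervised = F.cross_entropy(logits[is_source], labels[is_source])
    # Adaptation term: compare source features with target features
    adaptation = da_loss(features[is_source], features[~is_source])
    return supervised + reg * adaptation


def mean_gap(fs, ft):
    # Stand-in DA loss (not CORAL): squared distance between feature means
    return ((fs.mean(0) - ft.mean(0)) ** 2).sum()


torch.manual_seed(0)
logits = torch.randn(8, 10)
features = torch.randn(8, 128)
sample_domain = torch.tensor([1, 1, 1, 1, -2, -2, -2, -2])
labels = torch.tensor([0, 1, 1, 0, -1, -1, -1, -1])  # masked target labels

total = domain_aware_loss(logits, labels, sample_domain, features, mean_gap, reg=1.0)
ce_only = domain_aware_loss(logits, labels, sample_domain, features, mean_gap, reg=0.0)
print(total.item(), ce_only.item())
```

With ``reg=0`` only the source cross-entropy remains. Masking the targets before the cross-entropy also keeps the ``-1`` labels out of the supervised term.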
.. GENERATED FROM PYTHON SOURCE LINES 82-84

Training with a plain PyTorch loop
----------------------------------------------------------------------------

The same model can also be trained without skorch:

- ``DomainAwareModule`` wraps the network and collects the ``fc1`` outputs needed by the adaptation loss.
- ``DomainBalancedDataLoader`` builds batches balanced across domains.
- ``DomainAwareCriterion`` adds ``reg`` times ``DeepCoralLoss`` to the cross-entropy loss.

.. GENERATED FROM PYTHON SOURCE LINES 84-112

.. code-block:: Python


    model = DomainAwareModule(MNISTtoUSPSNet(), layer_name="fc1").to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    dataloader = DomainBalancedDataLoader(dataset, batch_size=batch_size)
    # Cross-entropy loss plus reg times the CORAL loss
    loss_fn = DomainAwareCriterion(torch.nn.CrossEntropyLoss(), DeepCoralLoss(), reg=reg)

    # Training loop
    for epoch in range(max_epochs):
        model.train()
        running_loss = 0.0
        n_batches = 0
        for inputs, labels in dataloader:
            # Move the input dict ("X", "sample_domain") and the labels to the device
            inputs = {key: value.to(device) for key, value in inputs.items()}
            labels = labels.to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Forward pass in training mode (is_fit=True)
            outputs = model(**inputs, is_fit=True)

            loss = loss_fn(outputs, labels)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            n_batches += 1
        # Average training loss over the epoch
        print("Loss:", running_loss / n_batches)




.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Loss: 0.8751300573348999
    Loss: 0.18600143492221832





.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 58.587 seconds)


.. _sphx_glr_download_auto_examples_deep_plot_training_method.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_training_method.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_training_method.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_training_method.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_