.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/deep/plot_training_method.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_deep_plot_training_method.py:


Training setup for deep DA method.
==========================================

This example illustrates the use of deep domain adaptation (DA) methods in Skada
on a simple image classification task.

.. GENERATED FROM PYTHON SOURCE LINES 8-13

.. code-block:: Python


    # Author: Théo Gnassounou
    #
    # License: BSD 3-Clause
    # sphinx_gallery_thumbnail_number = 4


.. GENERATED FROM PYTHON SOURCE LINES 14-26

.. code-block:: Python

    import torch
    from skorch.dataset import Dataset

    from skada.datasets import load_mnist_usps
    from skada.deep import DeepCoral, DeepCoralLoss
    from skada.deep.base import (
        DomainAwareCriterion,
        DomainAwareModule,
        DomainBalancedDataLoader,
    )
    from skada.deep.modules import MNISTtoUSPSNet


.. GENERATED FROM PYTHON SOURCE LINES 27-29

Load the image datasets
----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 29-34

.. code-block:: Python


    dataset = load_mnist_usps(n_classes=2, n_samples=0.5, return_dataset=True)
    X, y, sample_domain = dataset.pack_train(as_sources=["mnist"], as_targets=["usps"])
    X_test, y_test, sample_domain_test = dataset.pack_test(as_targets=["usps"])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /home/circleci/project/skada/datasets/_mnist_usps.py:72: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
      mnist_target = torch.tensor(mnist_dataset.targets)


.. GENERATED FROM PYTHON SOURCE LINES 35-37

Training parameters
----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 37-45

.. code-block:: Python


    max_epochs = 2
    batch_size = 256
    lr = 1e-3
    reg = 1
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


.. GENERATED FROM PYTHON SOURCE LINES 46-48

Training with skorch
----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 48-61

.. code-block:: Python

    model = DeepCoral(
        MNISTtoUSPSNet(),
        layer_name="fc1",
        batch_size=batch_size,
        max_epochs=max_epochs,
        train_split=False,
        reg=reg,
        lr=lr,
        device=device,
    )
    model.fit(X, y, sample_domain=sample_domain)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

      epoch    train_loss      dur
    -------  ------------  -------
          1        2.2764  10.0830
          2        2.1936   9.4029

    [initialized](
      module_=DomainAwareModule(
        (base_module_): MNISTtoUSPSNet(
          (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
          (relu1): ReLU()
          (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
          (relu2): ReLU()
          (dropout1): Dropout(p=0.25, inplace=False)
          (dropout2): Dropout(p=0.5, inplace=False)
          (fc1): Linear(in_features=9216, out_features=128, bias=True)
          (relu3): ReLU()
          (fc2): Linear(in_features=128, out_features=10, bias=True)
          (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        )
      ),
    )


.. GENERATED FROM PYTHON SOURCE LINES 62-64

Training with skorch using a dataset
----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 64-81

.. code-block:: Python

    X_dict = {"X": torch.tensor(X), "sample_domain": torch.tensor(sample_domain)}

    # TODO create a dataset also without skorch
    dataset = Dataset(X_dict, torch.tensor(y))

    model = DeepCoral(
        MNISTtoUSPSNet(),
        layer_name="fc1",
        batch_size=batch_size,
        max_epochs=max_epochs,
        train_split=False,
        reg=reg,
        lr=lr,
        device=device,
    )
    model.fit(dataset, y=None, sample_domain=None)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

      epoch    train_loss     dur
    -------  ------------  ------
          1        2.2715  8.8677
          2        2.1959  9.2985

    [initialized](
      module_=DomainAwareModule(
        (base_module_): MNISTtoUSPSNet(
          (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
          (relu1): ReLU()
          (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
          (relu2): ReLU()
          (dropout1): Dropout(p=0.25, inplace=False)
          (dropout2): Dropout(p=0.5, inplace=False)
          (fc1): Linear(in_features=9216, out_features=128, bias=True)
          (relu3): ReLU()
          (fc2): Linear(in_features=128, out_features=10, bias=True)
          (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        )
      ),
    )


.. GENERATED FROM PYTHON SOURCE LINES 82-84

Training with torch
----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 84-112

.. code-block:: Python

    model = DomainAwareModule(MNISTtoUSPSNet(), layer_name="fc1").to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    dataloader = DomainBalancedDataLoader(dataset, batch_size=batch_size)
    loss_fn = DomainAwareCriterion(torch.nn.CrossEntropyLoss(), DeepCoralLoss(), reg=reg)

    # Training loop
    for epoch in range(max_epochs):
        model.train()
        running_loss = 0.0
        n_batches = 0  # avoids shadowing the built-in ``iter``
        for inputs, labels in dataloader:
            # Move every input tensor and the labels to the training device
            inputs = {key: value.to(device) for key, value in inputs.items()}
            labels = labels.to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(**inputs, is_fit=True)
            loss = loss_fn(outputs, labels)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            n_batches += 1
        # Report the mean loss over the batches of this epoch
        print("Loss:", running_loss / n_batches)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Loss: 0.8874167464673519
    Loss: 0.18708933144807816


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (1 minutes 1.895 seconds)


.. _sphx_glr_download_auto_examples_deep_plot_training_method.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_training_method.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_training_method.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_training_method.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_
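
The example above passes ``DeepCoralLoss()`` and ``reg`` to the training setups without showing what the adaptation term measures. As a rough illustration only, the sketch below computes the CORAL penalty from the Deep CORAL formulation (squared Frobenius distance between source and target feature covariances, scaled by ``1 / (4 * d**2)``) in plain PyTorch. The helper names ``feature_covariance`` and ``coral_penalty`` are hypothetical, the random tensors stand in for real ``fc1`` activations, and the combination ``base_loss + reg * penalty`` is an assumption about how the terms are weighted. This is not Skada's implementation.

.. code-block:: Python

    import torch


    def feature_covariance(features):
        """Unbiased covariance of a (n_samples, n_features) matrix."""
        n_samples = features.shape[0]
        centered = features - features.mean(dim=0, keepdim=True)
        return centered.T @ centered / (n_samples - 1)


    def coral_penalty(source_features, target_features):
        """Squared Frobenius distance between covariances, scaled by 1 / (4 d^2)."""
        n_features = source_features.shape[1]
        diff = feature_covariance(source_features) - feature_covariance(target_features)
        return (diff**2).sum() / (4 * n_features**2)


    # Hypothetical stand-ins for the "fc1" activations of a source and a target batch
    torch.manual_seed(0)
    source_features = torch.randn(64, 128)
    target_features = 2.0 * torch.randn(64, 128) + 1.0

    # Assumed combination: classification loss on labeled samples plus the
    # adaptation penalty weighted by reg (same role as ``reg`` in the example)
    logits = torch.randn(64, 10)
    labels = torch.randint(0, 10, (64,))
    reg = 1
    base_loss = torch.nn.functional.cross_entropy(logits, labels)
    total_loss = base_loss + reg * coral_penalty(source_features, target_features)

    print("CORAL penalty:", coral_penalty(source_features, target_features).item())
    print("Total loss:", total_loss.item())

Because the penalty depends only on second-order statistics, shifting all target features by a constant leaves it unchanged, while rescaling them does not.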