.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/deep/plot_optimal_transport.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_deep_plot_optimal_transport.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_deep_plot_optimal_transport.py:


Optimal transport domain adaptation methods
===========================================

This example illustrates DeepJDOT, an optimal transport (OT) deep domain
adaptation (DA) method, on a simple image classification task: adapting a
classifier trained on MNIST digits to USPS digits.

.. GENERATED FROM PYTHON SOURCE LINES 8-13

.. code-block:: Python

    # Author: Théo Gnassounou
    #
    # License: BSD 3-Clause
    # sphinx_gallery_thumbnail_number = 4

.. GENERATED FROM PYTHON SOURCE LINES 14-21

.. code-block:: Python

    from skorch import NeuralNetClassifier
    from torch import nn

    from skada.datasets import load_mnist_usps
    from skada.deep import DeepJDOT
    from skada.deep.modules import MNISTtoUSPSNet

.. GENERATED FROM PYTHON SOURCE LINES 22-24

Load the image datasets
----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 24-29

.. code-block:: Python

    dataset = load_mnist_usps(n_classes=2, n_samples=0.5, return_dataset=True)
    X, y, sample_domain = dataset.pack_train(as_sources=["mnist"], as_targets=["usps"])
    X_test, y_test, sample_domain_test = dataset.pack_test(as_targets=["usps"])

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    /home/circleci/project/skada/datasets/_mnist_usps.py:72: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
      mnist_target = torch.tensor(mnist_dataset.targets)

.. GENERATED FROM PYTHON SOURCE LINES 30-32

Train a classic model
----------------------------------------------------------------------------
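
The classic model is fit only on the samples selected by the mask ``sample_domain > 0``. This assumes skada's convention that source samples carry positive domain ids and target samples carry negative ones. The NumPy sketch below shows the same masking on hypothetical toy arrays. The values and the domain ids ``1`` and ``-2`` are made up for illustration and are not the real ``pack_train`` output.

.. code-block:: Python

    import numpy as np

    # Hypothetical toy stand-ins for the arrays returned by ``dataset.pack_train``:
    # three source samples followed by two target samples.
    X = np.arange(10, dtype=np.float32).reshape(5, 2)
    y = np.array([0, 1, 1, 0, 1])
    # Assumed convention: positive ids mark source samples, negative ids target samples.
    sample_domain = np.array([1, 1, 1, -2, -2])

    is_source = sample_domain > 0

    # A source-only model is trained on these arrays ...
    X_source, y_source = X[is_source], y[is_source]
    # ... while the target samples are left out of its training set.
    X_target = X[~is_source]

    print(X_source.shape, X_target.shape)  # (3, 2) (2, 2)
    print(y_source)  # [0 1 1]

DeepJDOT, by contrast, receives the full ``X`` together with ``sample_domain``, so it can also use the target samples during training.
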

.. GENERATED FROM PYTHON SOURCE LINES 32-43

.. code-block:: Python

    model = NeuralNetClassifier(
        MNISTtoUSPSNet(),
        criterion=nn.CrossEntropyLoss(),
        batch_size=128,
        max_epochs=5,
        train_split=False,
        lr=1e-2,
    )
    model.fit(X[sample_domain > 0], y[sample_domain > 0])
    model.score(X_test, y=y_test)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

      epoch    train_loss     dur
    -------  ------------  ------
          1        1.6460  6.1998
          2        0.4327  6.4993
          3        0.1498  6.6010
          4        0.0746  12.6988
          5        0.0511  13.5012

    0.8938906752411575

.. GENERATED FROM PYTHON SOURCE LINES 44-46

Train a DeepJDOT model
----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 46-58

.. code-block:: Python

    model = DeepJDOT(
        MNISTtoUSPSNet(),
        layer_name="fc1",
        batch_size=128,
        max_epochs=5,
        train_split=False,
        reg_dist=0.1,
        reg_cl=0.01,
        lr=1e-2,
    )
    model.fit(X, y, sample_domain=sample_domain)
    model.score(X_test, y_test, sample_domain=sample_domain_test)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

      epoch    train_loss      dur
    -------  ------------  -------
          1        2.1794  38.5859
          2        1.3292  13.6001
          3        0.8222  10.7061
          4        0.6452  10.4921
          5        0.5480  10.5015

    0.9389067524115756

.. rst-class:: sphx-glr-timing

**Total running time of the script:** (2 minutes 14.978 seconds)


.. _sphx_glr_download_auto_examples_deep_plot_optimal_transport.py:

.. only:: html

    .. container:: sphx-glr-footer sphx-glr-footer-example

        .. container:: sphx-glr-download sphx-glr-download-jupyter

            :download:`Download Jupyter notebook: plot_optimal_transport.ipynb <plot_optimal_transport.ipynb>`

        .. container:: sphx-glr-download sphx-glr-download-python

            :download:`Download Python source code: plot_optimal_transport.py <plot_optimal_transport.py>`

        .. container:: sphx-glr-download sphx-glr-download-zip

            :download:`Download zipped: plot_optimal_transport.zip <plot_optimal_transport.zip>`


.. only:: html

    .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_