.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/deep/plot_optimal_transport.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_deep_plot_optimal_transport.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_deep_plot_optimal_transport.py:


Optimal transport domain adaptation methods.
============================================

This example illustrates the DeepJDOT optimal transport deep domain
adaptation (DA) method on a simple image classification task.

.. GENERATED FROM PYTHON SOURCE LINES 8-13

.. code-block:: Python


    # Author: Théo Gnassounou
    #
    # License: BSD 3-Clause
    # sphinx_gallery_thumbnail_number = 4








.. GENERATED FROM PYTHON SOURCE LINES 14-21

.. code-block:: Python

    from skorch import NeuralNetClassifier
    from torch import nn

    from skada.datasets import load_mnist_usps
    from skada.deep import DeepJDOT
    from skada.deep.modules import MNISTtoUSPSNet








.. GENERATED FROM PYTHON SOURCE LINES 22-24

Load the image datasets
----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 24-29

.. code-block:: Python


    # MNIST is the labeled source domain and USPS is the target domain.
    dataset = load_mnist_usps(n_classes=2, n_samples=0.5, return_dataset=True)
    X, y, sample_domain = dataset.pack_train(as_sources=["mnist"], as_targets=["usps"])
    X_test, y_test, sample_domain_test = dataset.pack_test(as_targets=["usps"])





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /home/circleci/project/skada/datasets/_mnist_usps.py:72: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
      mnist_target = torch.tensor(mnist_dataset.targets)




.. GENERATED FROM PYTHON SOURCE LINES 30-32

Train a classic model
----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 32-43

.. code-block:: Python

    model = NeuralNetClassifier(
        MNISTtoUSPSNet(),
        criterion=nn.CrossEntropyLoss(),
        batch_size=128,
        max_epochs=5,
        train_split=False,
        lr=1e-2,
    )
    # Source samples carry a positive sample_domain value, so this baseline
    # is fitted on MNIST only and then scored on the USPS test set.
    model.fit(X[sample_domain > 0], y[sample_domain > 0])
    model.score(X_test, y=y_test)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

      epoch    train_loss     dur
    -------  ------------  ------
          1        1.4850  2.5033
          2        0.2735  2.6029
          3        0.0966  3.5964
          4        0.0510  2.6985
          5        0.0392  2.6949

    0.8778135048231511



.. GENERATED FROM PYTHON SOURCE LINES 44-46

Train a DeepJDOT model
----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 46-58

.. code-block:: Python

    model = DeepJDOT(
        MNISTtoUSPSNet(),
        layer_name="fc1",
        batch_size=128,
        max_epochs=5,
        train_split=False,
        reg_dist=0.1,
        reg_cl=0.01,
        lr=1e-2,
    )
    # Source and target samples are passed together; sample_domain tells them apart.
    model.fit(X, y, sample_domain=sample_domain)
    model.score(X_test, y_test, sample_domain=sample_domain_test)




.. rst-class:: sphx-glr-script-out

 .. code-block:: none

      epoch    train_loss     dur
    -------  ------------  ------
          1        2.0085  6.9701
          2        1.0176  6.6018
          3        0.6904  8.9007
          4        0.5663  7.2969
          5        0.4764  8.3985

    0.9292604501607717



.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 56.363 seconds)


.. _sphx_glr_download_auto_examples_deep_plot_optimal_transport.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_optimal_transport.ipynb <plot_optimal_transport.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_optimal_transport.py <plot_optimal_transport.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_optimal_transport.zip <plot_optimal_transport.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
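
The DeepJDOT model above is configured with ``reg_dist`` and ``reg_cl``. As a
rough illustration of what such weights could control, the sketch below builds
a DeepJDOT-style joint cost between a toy source batch and a toy target batch,
then finds the cheapest one-to-one matching. It assumes that ``reg_dist``
weights the squared distance between embeddings and ``reg_cl`` weights the
cross-entropy between a source label and a target prediction. The helpers
``joint_cost`` and ``best_matching`` and all the toy values are hypothetical
and are not part of skada; this is not the library's implementation.

.. code-block:: Python

    import itertools
    import math


    def joint_cost(src_feats, src_labels, tgt_feats, tgt_probs,
                   reg_dist=0.1, reg_cl=0.01):
        """Hypothetical helper: cost[i][j] = reg_dist * ||z_i - z_j||^2
        + reg_cl * cross_entropy(y_i, p_j)."""
        cost = []
        for z_s, y_s in zip(src_feats, src_labels):
            row = []
            for z_t, p_t in zip(tgt_feats, tgt_probs):
                sq_dist = sum((a - b) ** 2 for a, b in zip(z_s, z_t))
                cross_entropy = -math.log(p_t[y_s])
                row.append(reg_dist * sq_dist + reg_cl * cross_entropy)
            cost.append(row)
        return cost


    def best_matching(cost):
        """Hypothetical helper: brute-force the cheapest one-to-one matching.

        For two equally sized batches with uniform weights, an optimal
        transport plan can always be chosen as such a matching.
        """
        n = len(cost)
        return min(
            itertools.permutations(range(n)),
            key=lambda perm: sum(cost[i][perm[i]] for i in range(n)),
        )


    # Toy embeddings, source labels, and target class probabilities (made up).
    src_feats = [(0.0, 0.0), (1.0, 1.0), (4.0, 4.0)]
    src_labels = [0, 1, 1]
    tgt_feats = [(4.2, 3.9), (0.1, -0.1), (1.1, 0.8)]
    tgt_probs = [(0.1, 0.9), (0.8, 0.2), (0.3, 0.7)]

    cost = joint_cost(src_feats, src_labels, tgt_feats, tgt_probs)
    matching = best_matching(cost)
    print(matching)  # (1, 2, 0): source i is matched to target matching[i]

Brute force over permutations is only practical for a handful of samples; it
is used here to keep the sketch dependency-free. A real batch would call for a
dedicated optimal transport solver.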