.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/deep/plot_divergence.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_deep_plot_divergence.py:


Divergence domain adaptation methods
==========================================

This example illustrates the DeepCoral method from [1]_ on a simple image
classification task: adapting a two-class digit classifier from MNIST (source
domain) to USPS (target domain).

.. [1] Baochen Sun and Kate Saenko. Deep CORAL: Correlation alignment for deep
   domain adaptation. In ECCV Workshops, 2016.

.. GENERATED FROM PYTHON SOURCE LINES 13-18

.. code-block:: Python


    # Author: Théo Gnassounou
    #
    # License: BSD 3-Clause
    # sphinx_gallery_thumbnail_number = 4

.. GENERATED FROM PYTHON SOURCE LINES 19-26

.. code-block:: Python


    from skorch import NeuralNetClassifier
    from torch import nn

    from skada.datasets import load_mnist_usps
    from skada.deep import DeepCoral
    from skada.deep.modules import MNISTtoUSPSNet

.. GENERATED FROM PYTHON SOURCE LINES 27-29

Load the image datasets
----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 29-34

.. code-block:: Python


    dataset = load_mnist_usps(n_classes=2, n_samples=0.5, return_dataset=True)
    X, y, sample_domain = dataset.pack_train(as_sources=["mnist"], as_targets=["usps"])
    X_test, y_test, sample_domain_test = dataset.pack_test(as_targets=["usps"])


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    0%| | 0.00/9.91M [00:00 0], y[sample_domain > 0])
    model.score(X_test, y=y_test)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

      epoch    train_loss     dur
    -------  ------------  ------
          1        1.6348  3.6962
          2        0.3176  3.4033
          3        0.1082  3.2970
          4        0.0643  3.3030
          5        0.0456  3.4945

    0.8842443729903537

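DeepCoral (Sun and Saenko, 2016) optimizes the usual classification loss on
labeled source samples plus a CORAL term that aligns the covariances of source
and target features at an intermediate layer. Below is a minimal NumPy sketch
of that term, written from the formula in the paper rather than taken from
skada, which computes it on PyTorch activations of the layer named by
``layer_name`` and may differ in details. ``coral_loss``, ``feats_s`` and
``feats_t`` are hypothetical names; the random arrays only stand in for layer
activations.

.. code-block:: Python

    import numpy as np


    def coral_loss(source_features, target_features):
        """Squared Frobenius distance between the source and target feature
        covariances, scaled by 1 / (4 d^2) as in Sun and Saenko (2016)."""
        d = source_features.shape[1]
        # rowvar=False: rows are samples, columns are features.
        # np.cov uses the unbiased 1 / (n - 1) normalization, as in the paper.
        cov_source = np.cov(source_features, rowvar=False)
        cov_target = np.cov(target_features, rowvar=False)
        return np.sum((cov_source - cov_target) ** 2) / (4 * d**2)


    rng = np.random.default_rng(0)
    # Hypothetical stand-ins for a batch of 16-dimensional features per domain.
    feats_s = rng.normal(size=(128, 16))
    feats_t = 3.0 * rng.normal(size=(128, 16))  # larger variance: covariance shift

    print(coral_loss(feats_s, feats_s))  # identical statistics -> 0.0
    print(coral_loss(feats_s, feats_t) > 0)  # True

Minimizing this term alongside the classification loss pushes the network
toward features whose second-order statistics match across domains; the
``reg`` argument of ``DeepCoral`` presumably sets how strongly it is weighted.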
.. GENERATED FROM PYTHON SOURCE LINES 49-51

Train a DeepCoral model
----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 51-62

.. code-block:: Python


    model = DeepCoral(
        MNISTtoUSPSNet(),
        layer_name="fc1",
        batch_size=128,
        max_epochs=5,
        train_split=False,
        reg=1,
        lr=1e-2,
    )
    model.fit(X, y, sample_domain=sample_domain)
    model.score(X_test, y_test, sample_domain=sample_domain_test)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

      epoch    train_loss     dur
    -------  ------------  ------
          1        1.4456  7.0768
          2        0.2629  6.5027
          3        0.1114  6.7957
          4        0.0737  6.6992
          5        0.0648  7.1999

    0.9517684887459807


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (1 minutes 0.276 seconds)
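Both training calls in this example rely on the ``sample_domain`` array: the
first ``fit`` call keeps only samples with ``sample_domain > 0``, while
``DeepCoral.fit`` receives the full array so that it can also use the target
samples, without their labels, when aligning features. A minimal NumPy sketch
of that split on toy data, assuming skada's convention that positive domain
ids mark source samples and negative ids mark target samples (all arrays here
are hypothetical):

.. code-block:: Python

    import numpy as np

    # Hypothetical toy data: 6 samples with 2 features each.
    X = np.arange(12, dtype=float).reshape(6, 2)
    y = np.array([0, 1, 0, 1, 1, 0])
    # Assumed convention: positive ids = source domain, negative ids = target.
    sample_domain = np.array([1, 1, 1, -2, -2, -2])

    is_source = sample_domain > 0
    X_source, y_source = X[is_source], y[is_source]  # what a source-only model sees
    X_target = X[~is_source]  # a DA method may also use these, without labels

    print(X_source.shape, X_target.shape)  # (3, 2) (3, 2)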