.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/deep/plot_divergence.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_deep_plot_divergence.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_deep_plot_divergence.py:


Divergence domain adaptation methods
==========================================

This example illustrates the DeepCoral method from [1] on a simple image
classification task: adapting a two-class digit classifier from MNIST
(source domain) to USPS (target domain).

.. [1] Baochen Sun and Kate Saenko. Deep CORAL: Correlation alignment for
   deep domain adaptation. In ECCV Workshops, 2016.

.. GENERATED FROM PYTHON SOURCE LINES 13-18

.. code-block:: Python


    # Author: Théo Gnassounou
    #
    # License: BSD 3-Clause
    # sphinx_gallery_thumbnail_number = 4

.. GENERATED FROM PYTHON SOURCE LINES 19-26

.. code-block:: Python

    from skorch import NeuralNetClassifier
    from torch import nn

    from skada.datasets import load_mnist_usps
    from skada.deep import DeepCoral
    from skada.deep.modules import MNISTtoUSPSNet

.. GENERATED FROM PYTHON SOURCE LINES 27-29

Load the image datasets
----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 29-34

.. code-block:: Python

    dataset = load_mnist_usps(n_classes=2, n_samples=0.5, return_dataset=True)
    X, y, sample_domain = dataset.pack_train(as_sources=["mnist"], as_targets=["usps"])
    X_test, y_test, sample_domain_test = dataset.pack_test(as_targets=["usps"])
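
DeepCoral [1] trains the network on labeled source data and adds a CORAL term
that penalizes the mismatch between the second-order statistics of source and
target features,
:math:`\ell_{\mathrm{CORAL}} = \frac{1}{4d^2}\lVert C_S - C_T \rVert_F^2`,
where :math:`d` is the feature dimension and :math:`C_S`, :math:`C_T` are the
unbiased (:math:`1/(n-1)`) covariance matrices of the source and target
features. The following is a minimal, dependency-free sketch of that loss,
written from the formula in [1]; it is not skada's implementation, and
``covariance`` and ``coral_loss`` are hypothetical helper names. In skada's
``DeepCoral``, ``layer_name`` presumably selects the layer whose activations
are aligned and ``reg`` presumably weights this term against the
classification loss; treat that mapping as an assumption.

.. code-block:: Python

    # Minimal sketch of the CORAL loss from [1], written from the paper's formula.
    # Hypothetical helper names; this is NOT skada's implementation.
    from typing import List

    Matrix = List[List[float]]


    def covariance(features: Matrix) -> Matrix:
        """Unbiased (1 / (n - 1)) covariance of an n x d feature matrix."""
        n = len(features)
        if n < 2:
            raise ValueError("need at least two samples to estimate a covariance")
        d = len(features[0])
        # Column-wise means of the features.
        means = [sum(row[j] for row in features) / n for j in range(d)]
        # Center every feature column.
        centered = [[row[j] - means[j] for j in range(d)] for row in features]
        # C[i][j] = sum_k centered[k][i] * centered[k][j] / (n - 1)
        return [
            [sum(row[i] * row[j] for row in centered) / (n - 1) for j in range(d)]
            for i in range(d)
        ]


    def coral_loss(source: Matrix, target: Matrix) -> float:
        """Squared Frobenius distance between covariances, divided by 4 d^2."""
        d = len(source[0])
        if len(target[0]) != d:
            raise ValueError("source and target features must have the same dimension")
        c_s = covariance(source)
        c_t = covariance(target)
        squared_frobenius = sum(
            (c_s[i][j] - c_t[i][j]) ** 2 for i in range(d) for j in range(d)
        )
        return squared_frobenius / (4 * d * d)


    # Toy 2-D "features": the target batch is the source batch scaled by 2.
    source = [[0.0, 0.0], [2.0, 0.0], [0.0, 2.0], [2.0, 2.0]]
    target = [[0.0, 0.0], [4.0, 0.0], [0.0, 4.0], [4.0, 4.0]]
    print(coral_loss(source, source))  # 0.0: identical statistics
    print(coral_loss(source, target))  # 2.0

Because every column is centered before the covariance is computed, adding the
same vector to every target sample leaves this loss at 0: CORAL aligns
covariances, not means.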
.. code-block:: Python

    ... 0], y[sample_domain > 0])
    model.score(X_test, y=y_test)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

      epoch    train_loss     dur
    -------  ------------  ------
          1        1.4900  9.5993
          2        0.3168  8.1940
          3        0.1119  7.3040
          4        0.0610  6.7043
          5        0.0458  8.6105

    0.8906752411575563

.. GENERATED FROM PYTHON SOURCE LINES 49-51

Train a DeepCoral model
----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 51-62

.. code-block:: Python

    model = DeepCoral(
        MNISTtoUSPSNet(),
        layer_name="fc1",
        batch_size=128,
        max_epochs=5,
        train_split=False,
        reg=1,
        lr=1e-2,
    )
    model.fit(X, y, sample_domain=sample_domain)
    model.score(X_test, y_test, sample_domain=sample_domain_test)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

      epoch    train_loss      dur
    -------  ------------  -------
          1        1.6827  30.1814
          2        0.4499  17.5877
          3        0.1586  14.8993
          4        0.0891  13.2113
          5        0.0644  12.6976

    0.8938906752411575


.. rst-class:: sphx-glr-timing

**Total running time of the script:** (2 minutes 22.076 seconds)


.. _sphx_glr_download_auto_examples_deep_plot_divergence.py:

.. only:: html

    .. container:: sphx-glr-footer sphx-glr-footer-example

        .. container:: sphx-glr-download sphx-glr-download-jupyter

            :download:`Download Jupyter notebook: plot_divergence.ipynb <plot_divergence.ipynb>`

        .. container:: sphx-glr-download sphx-glr-download-python

            :download:`Download Python source code: plot_divergence.py <plot_divergence.py>`

        .. container:: sphx-glr-download sphx-glr-download-zip

            :download:`Download zipped: plot_divergence.zip <plot_divergence.zip>`


.. only:: html

    .. rst-class:: sphx-glr-signature

    Gallery generated by Sphinx-Gallery