.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/deep/plot_adversarial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_deep_plot_adversarial.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_deep_plot_adversarial.py:


Adversarial domain adaptation methods
==========================================

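This example illustrates an adversarial domain adaptation method from SKADA,
DANN, on a simple image classification task: a classifier is trained on MNIST
digits (source domain) and evaluated on USPS digits (target domain).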

.. GENERATED FROM PYTHON SOURCE LINES 8-13

.. code-block:: Python

    # Author: Théo Gnassounou
    #
    # License: BSD 3-Clause
    # sphinx_gallery_thumbnail_number = 4








.. GENERATED FROM PYTHON SOURCE LINES 14-21

.. code-block:: Python

    from skorch import NeuralNetClassifier
    from torch import nn

    from skada.datasets import load_mnist_usps
    from skada.deep import DANN
    from skada.deep.modules import MNISTtoUSPSNet








.. GENERATED FROM PYTHON SOURCE LINES 22-24

Load the image datasets
----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 24-29

.. code-block:: Python


    dataset = load_mnist_usps(n_classes=2, n_samples=0.5, return_dataset=True)
    X, y, sample_domain = dataset.pack_train(as_sources=["mnist"], as_targets=["usps"])
    X_test, y_test, sample_domain_test = dataset.pack_test(as_targets=["usps"])









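The model trained in the next step selects source samples with
``sample_domain > 0``. The toy sketch below illustrates the convention this
relies on, assuming (as SKADA's packing helpers do) that source samples carry
positive domain ids, target samples carry negative ids, and target labels are
masked with ``-1`` in the training set. The arrays are made-up values, not the
MNIST/USPS data.

.. code-block:: Python

    import numpy as np

    # Hypothetical toy data: 6 samples with 2 features each
    X = np.arange(12, dtype=float).reshape(6, 2)
    # Assumed convention: positive ids = source, negative ids = target
    sample_domain = np.array([1, 1, 1, -1, -1, -1])
    # Assumed masking: target labels are hidden as -1
    y = np.array([0, 1, 0, -1, -1, -1])

    # Boolean mask selecting the labelled source samples only
    is_source = sample_domain > 0
    X_source, y_source = X[is_source], y[is_source]

    print(X_source.shape)  # (3, 2)
    print(y_source)  # [0 1 0]
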
.. GENERATED FROM PYTHON SOURCE LINES 30-32

Train a source-only baseline model
----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 32-43

.. code-block:: Python

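    # Baseline: a standard skorch classifier trained without any adaptation
    model = NeuralNetClassifier(
        MNISTtoUSPSNet(),
        criterion=nn.CrossEntropyLoss(),
        batch_size=128,
        max_epochs=5,
        train_split=False,  # no internal validation split
        lr=1e-2,
    )
    # Fit only on the labelled source (MNIST) samples: positive domain ids
    model.fit(X[sample_domain > 0], y[sample_domain > 0])
    # Evaluate on the target (USPS) test set
    model.score(X_test, y=y_test)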





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

      epoch    train_loss     dur
    -------  ------------  ------
          1        1.5304  2.7989
          2        0.2732  2.7034
          3        0.0964  2.5977
          4        0.0555  3.0986
          5        0.0359  3.1007

    0.9260450160771704
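
The value printed last is the return value of ``model.score``. Because
``NeuralNetClassifier`` follows the scikit-learn classifier interface, this is
the mean accuracy on the USPS test set. The helper below is a hypothetical
stand-alone version of that metric, written for illustration rather than taken
from skorch or SKADA.

.. code-block:: Python

    import numpy as np


    def accuracy(y_true, y_pred):
        """Fraction of predictions equal to the true labels."""
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)
        return float(np.mean(y_true == y_pred))


    # Hypothetical labels and predictions for four test images
    print(accuracy([0, 1, 1, 0], [0, 1, 0, 0]))  # 0.75
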



.. GENERATED FROM PYTHON SOURCE LINES 44-46

Train a DANN model
----------------------------------------------------------------------------
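
DANN attaches a domain classifier :math:`d` to the output of the layer named by
``layer_name`` (here ``fc1``). Writing :math:`f` for the network up to that
layer and :math:`g` for the remaining layers, a common way to state the DANN
objective is given below; treating ``reg`` as the trade-off weight
:math:`\lambda` is an assumption based on its role as the weight of the
adaptation loss.

.. math::

    \min_{f,\,g}\,\max_{d}\;
    \frac{1}{n_s}\sum_{i=1}^{n_s} \mathcal{L}_{\mathrm{CE}}\big(g(f(x_i^s)),\, y_i^s\big)
    \;-\; \lambda\,\mathcal{L}_{\mathrm{dom}}(f, d),

    \mathcal{L}_{\mathrm{dom}}(f, d) = -\frac{1}{n}\sum_{j=1}^{n}
    \Big[z_j \log d(f(x_j)) + (1 - z_j)\log\big(1 - d(f(x_j))\big)\Big],

where :math:`n_s` is the number of labelled source samples, :math:`n` the number
of source and target samples together, and :math:`z_j = 1` for source and
:math:`z_j = 0` for target samples. The domain classifier :math:`d` minimizes
:math:`\mathcal{L}_{\mathrm{dom}}`, while the feature extractor :math:`f`
maximizes it, which pushes source and target features to become
indistinguishable.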

.. GENERATED FROM PYTHON SOURCE LINES 46-58

.. code-block:: Python

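    model = DANN(
        MNISTtoUSPSNet(),
        layer_name="fc1",  # layer whose output feeds the domain classifier
        batch_size=128,
        max_epochs=5,
        train_split=False,
        reg=0.01,  # weight of the adversarial domain adaptation loss
        num_features=128,  # input size of the domain classifier (output of "fc1")
        lr=1e-2,
    )
    # DANN trains on source and target samples; sample_domain tells them apart
    model.fit(X, y, sample_domain=sample_domain)
    model.score(X_test, y_test, sample_domain=sample_domain_test)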




.. rst-class:: sphx-glr-script-out

 .. code-block:: none

      epoch    train_loss     dur
    -------  ------------  ------
          1        2.4363  6.6828
          2        1.2963  7.3972
          3        1.1106  7.1004
          4        1.0687  8.1048
          5        1.0472  6.6900

    0.8681672025723473
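
The adversarial part of DANN is usually implemented with a gradient reversal
layer: it leaves the features unchanged in the forward pass and multiplies the
gradient coming from the domain classifier by a negative factor in the backward
pass, so the feature extractor is pushed towards features the domain classifier
cannot separate. The NumPy sketch below shows that mechanism with a linear
domain classifier and hand-written gradients. The class and function names and
the scaling factor ``lambd`` are hypothetical; this is not SKADA's PyTorch
implementation.

.. code-block:: Python

    import numpy as np


    class GradientReversal:
        """Identity in the forward pass; scales gradients by -lambd backward."""

        def __init__(self, lambd=1.0):
            self.lambd = lambd

        def forward(self, h):
            return h  # features pass through unchanged

        def backward(self, grad_output):
            return -self.lambd * grad_output  # flip and scale the gradient


    def domain_loss_grad(h, w, z):
        """Gradient of the mean logistic domain loss w.r.t. the features h.

        h: (n, k) features, w: (k,) weights of a linear domain classifier,
        z: (n,) domain labels (1 = source, 0 = target).
        """
        p = 1.0 / (1.0 + np.exp(-h @ w))  # predicted probability of "source"
        return np.outer(p - z, w) / len(z)  # d(loss)/dh, one row per sample


    rng = np.random.default_rng(0)
    h = rng.normal(size=(4, 3))  # hypothetical features from a layer like fc1
    w = rng.normal(size=3)  # hypothetical domain-classifier weights
    z = np.array([1.0, 1.0, 0.0, 0.0])  # two source, two target samples

    grl = GradientReversal(lambd=0.5)
    g_dom = domain_loss_grad(grl.forward(h), w, z)  # gradient seen by d
    g_features = grl.backward(g_dom)  # gradient sent back to the extractor
    print(np.allclose(g_features, -0.5 * g_dom))  # True

Because the reversal only changes the sign of the gradient flowing into the
feature extractor, a single backward pass trains the domain classifier to
separate the domains while training the features to confuse it.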




.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 54.228 seconds)


.. _sphx_glr_download_auto_examples_deep_plot_adversarial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_adversarial.ipynb <plot_adversarial.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_adversarial.py <plot_adversarial.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_adversarial.zip <plot_adversarial.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_