.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/methods/plot_gradual_da.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_methods_plot_gradual_da.py:


Gradual Domain Adaptation Using Optimal Transport
=================================================

This example illustrates the GOAT method from [38]_ on a simple
classification task. Whereas the original work uses a convolutional neural
network (CNN) as the base classifier, this example uses a multilayer
perceptron (MLP).

.. [38] Y. He, H. Wang, B. Li, and H. Zhao. Gradual Domain Adaptation:
   Theory and Algorithms. *Journal of Machine Learning Research*, 2024.

.. GENERATED FROM PYTHON SOURCE LINES 14-19

.. code-block:: Python


    # Authors: Félix Lefebvre and Julie Alberge
    #
    # License: BSD 3-Clause

.. GENERATED FROM PYTHON SOURCE LINES 20-28

.. code-block:: Python


    import matplotlib.pyplot as plt
    from sklearn.inspection import DecisionBoundaryDisplay
    from sklearn.neural_network import MLPClassifier

    from skada import source_target_split
    from skada._gradual_da import GradualEstimator
    from skada.datasets import make_shifted_datasets

.. GENERATED FROM PYTHON SOURCE LINES 29-31

Generate conditional shift dataset
----------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 31-41

.. code-block:: Python


    n, m = 20, 25  # size parameters for source and target (not the total sample counts)
    X, y, sample_domain = make_shifted_datasets(
        n_samples_source=n,
        n_samples_target=m,
        shift="conditional_shift",
        noise=0.1,
        random_state=42,
    )

.. GENERATED FROM PYTHON SOURCE LINES 42-44

Plot source and target datasets
-------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 44-65

.. code-block:: Python


    X_source, X_target, y_source, y_target = source_target_split(
        X, y, sample_domain=sample_domain
    )

    lims = (min(X[:, 0]) - 0.5, max(X[:, 0]) + 0.5, min(X[:, 1]) - 0.5, max(X[:, 1]) + 0.5)

    n_tot_source = X_source.shape[0]
    n_tot_target = X_target.shape[0]

    plt.figure(1, figsize=(8, 3.5))
    plt.subplot(121)

    plt.scatter(X_source[:, 0], X_source[:, 1], c=y_source, vmax=9, cmap="tab10", alpha=0.7)
    plt.title("Source domain")
    plt.axis(lims)

    plt.subplot(122)
    plt.scatter(X_target[:, 0], X_target[:, 1], c=y_target, vmax=9, cmap="tab10", alpha=0.7)
    plt.title("Target domain")
    plt.axis(lims)


.. image-sg:: /auto_examples/methods/images/sphx_glr_plot_gradual_da_001.png
   :alt: Source domain, Target domain
   :srcset: /auto_examples/methods/images/sphx_glr_plot_gradual_da_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (np.float64(-2.3407051687007465), np.float64(4.329230428885397), np.float64(-1.765584745112177), np.float64(4.406935559355198))

.. GENERATED FROM PYTHON SOURCE LINES 66-70

Fit Gradual Domain Adaptation
-----------------------------

We use an MLP classifier with two hidden layers of 50 units as the base
estimator; its other parameters are not set explicitly.

.. GENERATED FROM PYTHON SOURCE LINES 70-87

.. code-block:: Python


    base_estimator = MLPClassifier(hidden_layer_sizes=(50, 50))

    gradual_adapter = GradualEstimator(
        n_steps=40,  # number of adaptation steps
        base_estimator=base_estimator,
        advanced_ot_plan_sampling=True,
        save_estimators=True,
        save_intermediate_data=True,
    )

    gradual_adapter.fit(
        X,
        y,
        sample_domain=sample_domain,
    )


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /home/circleci/.local/lib/python3.10/site-packages/sklearn/neural_network/_multilayer_perceptron.py:781: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (200) reached and the optimization hasn't converged yet.
      warnings.warn(


.. raw:: html

    <pre>GradualEstimator(advanced_ot_plan_sampling=True,
                     base_estimator=MLPClassifier(hidden_layer_sizes=(50, 50),
                                                  max_iter=4200),
                     n_steps=40, save_estimators=True, save_intermediate_data=True)</pre>


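The call to ``fit`` above hides the whole procedure inside ``GradualEstimator``. The following is a minimal sketch of the general GOAT idea as we understand it from [38], not skada's implementation: build intermediate domains by moving every source point toward its optimal transport (OT) partner in the target, :math:`x_i^{(k)} = (1 - k/K)\,x_i^S + (k/K)\,x_{\sigma(i)}^T` for :math:`k = 1, \dots, K`, then self-train on these domains one at a time, pseudo-labelling each one with the previous classifier. To stay short, it assumes equal source and target sizes with uniform weights, in which case an optimal plan for the squared Euclidean cost is a permutation :math:`\sigma` that SciPy's ``linear_sum_assignment`` computes exactly. It also swaps the MLP for ``LogisticRegression``; the helpers ``ot_interpolation`` and ``gradual_self_training`` and the toy data are hypothetical.

.. code-block:: Python

    import numpy as np
    from scipy.optimize import linear_sum_assignment
    from scipy.spatial.distance import cdist
    from sklearn.linear_model import LogisticRegression


    def ot_interpolation(X_source, X_target, n_steps):
        """Hypothetical helper: intermediate domains along an exact OT matching.

        Assumes len(X_source) == len(X_target) and uniform weights, so an optimal
        plan for the squared Euclidean cost is a permutation (an assignment).
        """
        cost = cdist(X_source, X_target, metric="sqeuclidean")
        rows, cols = linear_sum_assignment(cost)  # exact OT plan as a matching
        X_start, X_end = X_source[rows], X_target[cols]
        ts = np.linspace(0.0, 1.0, n_steps + 1)[1:]  # t = 1/K, 2/K, ..., 1
        return [(1.0 - t) * X_start + t * X_end for t in ts]


    def gradual_self_training(X_source, y_source, X_target, n_steps=10):
        """Hypothetical helper: fit on source, then self-train step by step."""
        clf = LogisticRegression().fit(X_source, y_source)
        estimators = [clf]
        for X_step in ot_interpolation(X_source, X_target, n_steps):
            y_pseudo = clf.predict(X_step)  # pseudo-labels from the previous model
            clf = LogisticRegression().fit(X_step, y_pseudo)
            estimators.append(clf)
        return estimators


    # Hypothetical toy data: two Gaussian classes; the target domain is the
    # source distribution translated by +3 along the first axis.
    rng = np.random.default_rng(0)
    centers = np.array([[-2.0, 0.0], [2.0, 0.0]])
    y_src = np.repeat([0, 1], 50)
    y_tgt = np.repeat([0, 1], 50)
    X_src = centers[y_src] + 0.5 * rng.standard_normal((100, 2))
    X_tgt = centers[y_tgt] + [3.0, 0.0] + 0.5 * rng.standard_normal((100, 2))

    estimators = gradual_self_training(X_src, y_src, X_tgt, n_steps=10)
    print(f"Source-only accuracy on target: {estimators[0].score(X_tgt, y_tgt):.3f}")
    print(f"Final accuracy on target: {estimators[-1].score(X_tgt, y_tgt):.3f}")

On this toy translation, the source-only model misclassifies most target points of one class, while the final estimator should recover nearly all of them. The matching is what keeps the classes apart: for shifts where the cheapest matching pairs points across classes (for example, large rotations), the intermediate domains and their pseudo-labels can drift to the wrong class.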

.. GENERATED FROM PYTHON SOURCE LINES 88-92

Check results
-------------

We compute the accuracy of the initial and the final estimators on both the
source and the target domains.

.. GENERATED FROM PYTHON SOURCE LINES 92-110

.. code-block:: Python


    clfs = gradual_adapter.get_intermediate_estimators()
    ACC_source_init = clfs[0].score(X_source, y_source)
    ACC_target_init = clfs[0].score(X_target, y_target)

    print(f"Initial accuracy on source domain: {ACC_source_init:.3f}")
    print(f"Initial accuracy on target domain: {ACC_target_init:.3f}")
    print("")

    ACC_source = gradual_adapter.score(X_source, y_source)
    ACC_target = gradual_adapter.score(X_target, y_target)

    print(f"Final accuracy on source domain: {ACC_source:.3f}")
    print(f"Final accuracy on target domain: {ACC_target:.3f}")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Initial accuracy on source domain: 1.000
    Initial accuracy on target domain: 0.500

    Final accuracy on source domain: 0.869
    Final accuracy on target domain: 0.995

.. GENERATED FROM PYTHON SOURCE LINES 111-115

Inspect intermediate states
---------------------------

We plot the intermediate datasets at every fifth adaptation step (out of 40),
together with the decision boundary of the corresponding intermediate
estimator.

.. GENERATED FROM PYTHON SOURCE LINES 115-145

.. code-block:: Python


    intermediate_data = gradual_adapter.intermediate_data_

    fig, axes = plt.subplots(2, 4, figsize=(12, 6))
    axes = axes.ravel()

    # Define which steps to plot
    steps_to_plot = [5, 10, 15, 20, 25, 30, 35, 40]

    for i, step in enumerate(steps_to_plot):
        ax = axes[i]
        X_step, y_step = intermediate_data[step - 1]
        clf = clfs[step - 1]

        ax.scatter(X_step[:, 0], X_step[:, 1], c=y_step, vmax=9, cmap="tab10", alpha=0.7)
        DecisionBoundaryDisplay.from_estimator(
            clf,
            X,
            response_method="predict",
            cmap="gray_r",
            alpha=0.15,
            ax=ax,
            grid_resolution=200,
        )
        ax.set_title(f"t = {step}")
        ax.axis(lims)

    plt.tight_layout()


.. image-sg:: /auto_examples/methods/images/sphx_glr_plot_gradual_da_002.png
   :alt: t = 5, t = 10, t = 15, t = 20, t = 25, t = 30, t = 35, t = 40
   :srcset: /auto_examples/methods/images/sphx_glr_plot_gradual_da_002.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 146-151

Plot decision boundaries on source and target datasets
------------------------------------------------------

Finally, we compare the decision boundary of the initial estimator on the
source domain (left) with that of the final estimator on the target domain
(right), which shows how gradual domain adaptation has moved the boundary.

.. GENERATED FROM PYTHON SOURCE LINES 151-209

.. code-block:: Python


    figure, axis = plt.subplots(1, 2, figsize=(9, 4))

    cm = "gray_r"
    DecisionBoundaryDisplay.from_estimator(
        clfs[0],
        X,
        response_method="predict",
        cmap=cm,
        alpha=0.15,
        ax=axis[0],
        grid_resolution=200,
    )
    axis[0].scatter(
        X_source[:, 0],
        X_source[:, 1],
        c=y_source,
        vmax=9,
        cmap="tab10",
        alpha=0.7,
    )
    axis[0].set_title("Source domain")
    DecisionBoundaryDisplay.from_estimator(
        clfs[-1],
        X,
        response_method="predict",
        cmap=cm,
        alpha=0.15,
        ax=axis[1],
        grid_resolution=200,
    )
    axis[1].scatter(
        X_target[:, 0],
        X_target[:, 1],
        c=y_target,
        vmax=9,
        cmap="tab10",
        alpha=0.7,
    )
    axis[1].set_title("Target domain")

    axis[0].text(
        0.05,
        0.1,
        f"Accuracy: {clfs[0].score(X_source, y_source):.1%}",
        transform=axis[0].transAxes,
        ha="left",
        bbox={"boxstyle": "round", "facecolor": "white", "alpha": 0.5},
    )
    axis[1].text(
        0.05,
        0.1,
        f"Accuracy: {gradual_adapter.score(X_target, y_target):.1%}",
        transform=axis[1].transAxes,
        ha="left",
        bbox={"boxstyle": "round", "facecolor": "white", "alpha": 0.5},
    )

    plt.show()


.. image-sg:: /auto_examples/methods/images/sphx_glr_plot_gradual_da_003.png
   :alt: Source domain, Target domain
   :srcset: /auto_examples/methods/images/sphx_glr_plot_gradual_da_003.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 9.290 seconds)