.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/methods/plot_monge_alignment_da.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_methods_plot_monge_alignment_da.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_methods_plot_monge_alignment_da.py:


Multi-domain Linear Monge Alignment
===================================

This example illustrates the use of the ``MultiLinearMongeAlignmentAdapter``,
first on a single source/target pair and then on several source and target
domains.

.. GENERATED FROM PYTHON SOURCE LINES 8-14

.. code-block:: Python


    # Author: Remi Flamary
    #
    # License: BSD 3-Clause
    # sphinx_gallery_thumbnail_number = 4








.. GENERATED FROM PYTHON SOURCE LINES 15-26

.. code-block:: Python

    import matplotlib.pyplot as plt
    import numpy as np
    from sklearn.linear_model import LogisticRegression

    from skada import (
        MultiLinearMongeAlignmentAdapter,
        make_da_pipeline,
        source_target_split,
    )
    from skada.datasets import make_shifted_datasets








.. GENERATED FROM PYTHON SOURCE LINES 27-31

Generate a concept drift classification dataset and plot it
-----------------------------------------------------------

We generate a simple 2D concept drift dataset with one source domain and one
target domain, then plot both on the same axes.

.. GENERATED FROM PYTHON SOURCE LINES 31-56

.. code-block:: Python


    X, y, sample_domain = make_shifted_datasets(
        n_samples_source=20,
        n_samples_target=20,
        shift="concept_drift",
        noise=0.2,
        label="multiclass",
        random_state=42,
    )


    Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain)


    plt.figure(5, (10, 5))

    plt.subplot(1, 2, 1)
    plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap="tab10", vmax=9, label="Source")
    plt.title("Source data")
    ax = plt.axis()

    plt.subplot(1, 2, 2)
    plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Target")
    plt.title("Target data")
    plt.axis(ax)




.. image-sg:: /auto_examples/methods/images/sphx_glr_plot_monge_alignment_da_001.png
   :alt: Source data, Target data
   :srcset: /auto_examples/methods/images/sphx_glr_plot_monge_alignment_da_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    (np.float64(-2.3676169789533272), np.float64(4.53817259426296), np.float64(-1.7964573229829714), np.float64(4.336863129384494))



.. GENERATED FROM PYTHON SOURCE LINES 57-63

Adapt the data with linear Monge alignment
------------------------------------------

We fit a ``MultiLinearMongeAlignmentAdapter`` on all samples and transform
them, passing ``allow_source=True`` so that source samples are transformed as
well. No classifier is trained at this stage. We then plot the source data,
the target data and the adapted data side by side: in the last panel, source
samples are drawn as circles and target samples as crosses.

.. GENERATED FROM PYTHON SOURCE LINES 63-107

.. code-block:: Python


    clf = MultiLinearMongeAlignmentAdapter()
    clf.fit(X, sample_domain=sample_domain)

    X_adapt = clf.transform(X, sample_domain=sample_domain, allow_source=True)


    plt.figure(5, (10, 3))
    plt.subplot(1, 3, 1)
    plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap="tab10", vmax=9, label="Source")
    plt.title("Source data")
    ax = plt.axis()

    plt.subplot(1, 3, 2)
    plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Target")
    plt.title("Target data")
    plt.axis(ax)

    plt.subplot(1, 3, 3)
    # Source domains have non-negative labels in sample_domain.
    plt.scatter(
        X_adapt[sample_domain >= 0, 0],
        X_adapt[sample_domain >= 0, 1],
        c=y[sample_domain >= 0],
        marker="o",
        cmap="tab10",
        vmax=9,
        label="Source",
        alpha=0.5,
    )
    # Target domains have negative labels in sample_domain.
    plt.scatter(
        X_adapt[sample_domain < 0, 0],
        X_adapt[sample_domain < 0, 1],
        c=y[sample_domain < 0],
        marker="x",
        cmap="tab10",
        vmax=9,
        label="Target",
        alpha=1,
    )
    plt.legend()
    plt.title("Adapted data")




.. image-sg:: /auto_examples/methods/images/sphx_glr_plot_monge_alignment_da_002.png
   :alt: Source data, Target data, Adapted data
   :srcset: /auto_examples/methods/images/sphx_glr_plot_monge_alignment_da_002.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    Text(0.5, 1.0, 'Adapted data')



.. GENERATED FROM PYTHON SOURCE LINES 108-110

Train a classifier on adapted data
----------------------------------

We chain the adapter with a logistic regression in a single pipeline, fit it
on all domains, and report the accuracy over the samples of every domain
(source and target).

.. GENERATED FROM PYTHON SOURCE LINES 110-123

.. code-block:: Python


    clf = make_da_pipeline(
        MultiLinearMongeAlignmentAdapter(),
        LogisticRegression(),
    )

    clf.fit(X, y, sample_domain=sample_domain)

    print(
        "Average accuracy on all domains:",
        clf.score(X, y, sample_domain=sample_domain, allow_source=True),
    )





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Average accuracy on all domains: 0.9875




The adapter also handles several source and target domains. The helper below
builds a dataset with ``n_sources`` source domains and ``n_targets`` target
domains: it starts from one source/target pair and stacks additional shifted
domains, each generated with a random ``mean`` and ``sigma``. Here we use
three source domains and two target domains.

.. GENERATED FROM PYTHON SOURCE LINES 124-202

.. code-block:: Python



    def get_multidomain_data(
        n_samples_source=100,
        n_samples_target=100,
        noise=0.1,
        random_state=None,
        n_sources=3,
        n_targets=2,
    ):
        # Note: random_state must be an integer, it is offset per domain below.
        np.random.seed(random_state)
        # First source/target pair.
        X, y, sample_domain = make_shifted_datasets(
            n_samples_source=n_samples_source,
            n_samples_target=n_samples_target,
            noise=noise,
            shift="concept_drift",
            label="multiclass",
            random_state=random_state,
        )
        # Additional source domains: the shifted part of each new dataset is
        # appended with a positive domain label.
        for ns in range(n_sources - 1):
            Xi, yi, sample_domaini = make_shifted_datasets(
                n_samples_source=n_samples_source,
                n_samples_target=n_samples_target,
                noise=noise,
                shift="concept_drift",
                label="multiclass",
                random_state=random_state + ns,
                mean=np.random.randn(2),
                sigma=np.random.rand(2) * 0.5 + 0.5,
            )
            Xs, Xt, ys, yt = source_target_split(Xi, yi, sample_domain=sample_domaini)
            X = np.vstack([X, Xt])
            y = np.hstack([y, yt])
            sample_domain = np.hstack([sample_domain, np.ones(Xt.shape[0]) * (ns + 2)])

        # Additional target domains: appended with a negative domain label.
        for nt in range(n_targets - 1):
            Xi, yi, sample_domaini = make_shifted_datasets(
                n_samples_source=n_samples_source,
                n_samples_target=n_samples_target,
                noise=noise,
                shift="concept_drift",
                label="multiclass",
                random_state=random_state + nt + 42,
                mean=np.random.randn(2),
                sigma=np.random.rand(2) * 0.5 + 0.5,
            )
            Xs, Xt, ys, yt = source_target_split(Xi, yi, sample_domain=sample_domaini)
            X = np.vstack([X, Xt])
            y = np.hstack([y, yt])
            sample_domain = np.hstack([sample_domain, -np.ones(Xt.shape[0]) * (nt + 1)])

        return X, y, sample_domain


    X, y, sample_domain = get_multidomain_data(
        n_samples_source=50,
        n_samples_target=50,
        noise=0.1,
        random_state=43,
        n_sources=3,
        n_targets=2,
    )

    Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain)

    plt.figure(5, (10, 5))

    plt.subplot(1, 2, 1)
    plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap="tab10", vmax=9, label="Source")
    plt.title("Source data")
    ax = plt.axis()

    plt.subplot(1, 2, 2)
    plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Target")
    plt.title("Target domains")
    plt.axis(ax)




.. image-sg:: /auto_examples/methods/images/sphx_glr_plot_monge_alignment_da_003.png
   :alt: Source data, Target domains
   :srcset: /auto_examples/methods/images/sphx_glr_plot_monge_alignment_da_003.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    (np.float64(-2.310098338155625), np.float64(4.756925382279493), np.float64(-2.1443686989830857), np.float64(4.464886123797522))



.. GENERATED FROM PYTHON SOURCE LINES 203-245

.. code-block:: Python


    clf = MultiLinearMongeAlignmentAdapter()
    clf.fit(X, sample_domain=sample_domain)

    X_adapt = clf.transform(X, sample_domain=sample_domain, allow_source=True)


    plt.figure(5, (10, 3))
    plt.subplot(1, 3, 1)
    plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap="tab10", vmax=9, label="Source")
    plt.title("Source data")
    ax = plt.axis()

    plt.subplot(1, 3, 2)
    plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Target")
    plt.title("Target data")
    plt.axis(ax)

    plt.subplot(1, 3, 3)
    plt.scatter(
        X_adapt[sample_domain >= 0, 0],
        X_adapt[sample_domain >= 0, 1],
        c=y[sample_domain >= 0],
        marker="o",
        cmap="tab10",
        vmax=9,
        label="Source",
        alpha=0.5,
    )
    plt.scatter(
        X_adapt[sample_domain < 0, 0],
        X_adapt[sample_domain < 0, 1],
        c=y[sample_domain < 0],
        marker="x",
        cmap="tab10",
        vmax=9,
        label="Target",
        alpha=1,
    )
    plt.legend()
    plt.axis(ax)
    plt.title("Adapted data")




.. image-sg:: /auto_examples/methods/images/sphx_glr_plot_monge_alignment_da_004.png
   :alt: Source data, Target data, Adapted data
   :srcset: /auto_examples/methods/images/sphx_glr_plot_monge_alignment_da_004.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    Text(0.5, 1.0, 'Adapted data')



.. GENERATED FROM PYTHON SOURCE LINES 246-248

Train a classifier on adapted data
----------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 248-260

.. code-block:: Python


    clf = make_da_pipeline(
        MultiLinearMongeAlignmentAdapter(),
        LogisticRegression(),
    )

    clf.fit(X, y, sample_domain=sample_domain)

    print(
        "Average accuracy on all domains:",
        clf.score(X, y, sample_domain=sample_domain, allow_source=True),
    )




.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Average accuracy on all domains: 1.0





.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.856 seconds)


.. _sphx_glr_download_auto_examples_methods_plot_monge_alignment_da.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_monge_alignment_da.ipynb <plot_monge_alignment_da.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_monge_alignment_da.py <plot_monge_alignment_da.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_monge_alignment_da.zip <plot_monge_alignment_da.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_