.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/methods/plot_label_prop_da.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_methods_plot_label_prop_da.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_methods_plot_label_prop_da.py:

Label Propagation methods
=========================

This example shows how to use label propagation methods to perform domain
adaptation. This is done by propagating labels from the source domain to the
target domain using the OT plan. The approach was originally proposed in [28]_
for semi-supervised learning, but it can also be used for DA.

We illustrate the method on simple regression and classification concept drift
datasets. We first train a Kernel Ridge Regression (KRR) and a Logistic
Regression on the source domain and evaluate their performance on the source
and target domains. We then train the same models with the label propagation
method and evaluate their performance again.

.. [28] Solomon, J., Rustamov, R., Guibas, L., & Butscher, A. (2014, January).
   Wasserstein propagation for semi-supervised learning. In International
   Conference on Machine Learning (pp. 306-314). PMLR.

.. GENERATED FROM PYTHON SOURCE LINES 22-28

.. code-block:: Python


    # Author: Remi Flamary
    #
    # License: BSD 3-Clause
    # sphinx_gallery_thumbnail_number = 4

.. GENERATED FROM PYTHON SOURCE LINES 29-44

.. code-block:: Python


    import matplotlib.pyplot as plt
    import numpy as np
    from sklearn.kernel_ridge import KernelRidge
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import mean_squared_error
    from sklearn.svm import SVC

    from skada import (
        JCPOTLabelPropAdapter,
        OTLabelPropAdapter,
        make_da_pipeline,
        source_target_split,
    )
    from skada.datasets import make_shifted_datasets

.. GENERATED FROM PYTHON SOURCE LINES 45-49

Generate concept drift regression dataset and plot it
------------------------------------------------------

We generate a simple 2D concept drift regression dataset.

.. GENERATED FROM PYTHON SOURCE LINES 49-75

.. code-block:: Python


    X, y, sample_domain = make_shifted_datasets(
        n_samples_source=20,
        n_samples_target=20,
        shift="concept_drift",
        noise=0.3,
        label="regression",
        random_state=42,
    )

    y = (y - y.mean()) / y.std()

    Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain)

    plt.figure(1, (10, 5))
    plt.subplot(1, 2, 1)
    plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, label="Source")
    plt.title("Source data")
    ax = plt.axis()

    plt.subplot(1, 2, 2)
    plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, label="Target")
    plt.title("Target data")
    plt.axis(ax)

.. image-sg:: /auto_examples/methods/images/sphx_glr_plot_label_prop_da_001.png
   :alt: Source data, Target data
   :srcset: /auto_examples/methods/images/sphx_glr_plot_label_prop_da_001.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (np.float64(-2.570318895725525), np.float64(4.744989537549497), np.float64(-1.9610814126500358), np.float64(4.45933195939893))

.. GENERATED FROM PYTHON SOURCE LINES 76-83

Train a regressor on source data
--------------------------------

We train a simple Kernel Ridge Regression (KRR) on the source domain and
evaluate its performance on the source and target domains. Performance is much
lower on the target domain due to the shift. We also plot the prediction
surface learned by the KRR.

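A quick note before fitting the baseline: ``source_target_split`` above splits
the data according to the ``sample_domain`` vector. Assuming skada's usual
convention that positive ``sample_domain`` values mark source samples and
negative values mark target samples (the same convention behind the
``sample_domain < 0`` indexing used later in this example), the split could
equivalently be written by hand, as in this minimal sketch:

.. code-block:: Python


    # Minimal sketch (assumption: positive sample_domain = source samples,
    # negative sample_domain = target samples, as in skada's convention).
    Xs_manual, ys_manual = X[sample_domain > 0], y[sample_domain > 0]
    Xt_manual, yt_manual = X[sample_domain < 0], y[sample_domain < 0]
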
.. GENERATED FROM PYTHON SOURCE LINES 83-116

.. code-block:: Python


    clf = KernelRidge(kernel="rbf", alpha=0.5)
    clf.fit(Xs, ys)

    # Compute the prediction error (MSE) on source and target
    ys_pred = clf.predict(Xs)
    yt_pred = clf.predict(Xt)

    mse_s = mean_squared_error(ys, ys_pred)
    mse_t = mean_squared_error(yt, yt_pred)

    print(f"MSE on source: {mse_s:.2f}")
    print(f"MSE on target: {mse_t:.2f}")

    XX, YY = np.meshgrid(np.linspace(ax[0], ax[1], 100), np.linspace(ax[2], ax[3], 100))
    Z = clf.predict(np.c_[XX.ravel(), YY.ravel()]).reshape(XX.shape)

    plt.figure(2, (10, 5))
    plt.subplot(1, 2, 1)
    plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, label="Prediction")
    plt.imshow(Z, extent=(ax[0], ax[1], ax[2], ax[3]), origin="lower", alpha=0.5)
    plt.title(f"KRR Prediction on source (MSE={mse_s:.2f})")
    plt.axis(ax)

    plt.subplot(1, 2, 2)
    plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, label="Prediction")
    plt.imshow(Z, extent=(ax[0], ax[1], ax[2], ax[3]), origin="lower", alpha=0.5)
    plt.title(f"KRR Prediction on target (MSE={mse_t:.2f})")
    plt.axis(ax)

.. image-sg:: /auto_examples/methods/images/sphx_glr_plot_label_prop_da_002.png
   :alt: KRR Prediction on source (MSE=0.06), KRR Prediction on target (MSE=0.77)
   :srcset: /auto_examples/methods/images/sphx_glr_plot_label_prop_da_002.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    MSE on source: 0.06
    MSE on target: 0.77
    (np.float64(-2.570318895725525), np.float64(4.744989537549497), np.float64(-1.9610814126500358), np.float64(4.45933195939893))

.. GENERATED FROM PYTHON SOURCE LINES 117-120

Train the full Label Propagation model
--------------------------------------

We now wrap the same KRR estimator in a skada pipeline with
``OTLabelPropAdapter``, which propagates the source labels to the target
samples through the OT plan before fitting the regressor.

.. GENERATED FROM PYTHON SOURCE LINES 120-148

.. code-block:: Python


    clf = make_da_pipeline(OTLabelPropAdapter(), KernelRidge(kernel="rbf", alpha=0.5))

    clf.fit(X, y, sample_domain=sample_domain)

    ys_pred = clf.predict(Xs)
    yt_pred = clf.predict(Xt)

    mse_s = mean_squared_error(ys, ys_pred)
    mse_t = mean_squared_error(yt, yt_pred)

    Zjdot = clf.predict(np.c_[XX.ravel(), YY.ravel()]).reshape(XX.shape)

    print(f"LabelProp MSE on source: {mse_s:.2f}")
    print(f"LabelProp MSE on target: {mse_t:.2f}")

    plt.figure(3, (10, 5))
    plt.subplot(1, 2, 1)
    plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, label="Prediction")
    plt.imshow(Zjdot, extent=(ax[0], ax[1], ax[2], ax[3]), origin="lower", alpha=0.5)
    plt.title(f"LabelProp Prediction on source (MSE={mse_s:.2f})")
    plt.axis(ax)

    plt.subplot(1, 2, 2)
    plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, label="Prediction")
    plt.imshow(Zjdot, extent=(ax[0], ax[1], ax[2], ax[3]), origin="lower", alpha=0.5)
    plt.title(f"LabelProp Prediction on target (MSE={mse_t:.2f})")
    plt.axis(ax)

.. image-sg:: /auto_examples/methods/images/sphx_glr_plot_label_prop_da_003.png
   :alt: LabelProp Prediction on source (MSE=0.28), LabelProp Prediction on target (MSE=0.50)
   :srcset: /auto_examples/methods/images/sphx_glr_plot_label_prop_da_003.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /home/circleci/.local/lib/python3.10/site-packages/ot/lp/__init__.py:576: UserWarning: numItermax reached before optimality. Try to increase numItermax.
      result_code_string = check_result(result_code)
    LabelProp MSE on source: 0.28
    LabelProp MSE on target: 0.50
    (np.float64(-2.570318895725525), np.float64(4.744989537549497), np.float64(-1.9610814126500358), np.float64(4.45933195939893))

.. GENERATED FROM PYTHON SOURCE LINES 149-154

Illustration of the propagated labels
-------------------------------------

We now extract and plot the labels that the OT plan propagates from the source
samples to the target samples.

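To make the propagation mechanism concrete, here is a minimal sketch of the
idea for the regression case, written directly with the POT library; the
internal implementation of ``OTLabelPropAdapter`` may differ. An exact OT
coupling is computed between source and target samples, and each target sample
receives a coupling-weighted (barycentric) average of the source labels.

.. code-block:: Python


    import ot  # POT: Python Optimal Transport (sketch of the idea only)

    # Uniform empirical weights and squared Euclidean ground cost
    a, b = ot.unif(Xs.shape[0]), ot.unif(Xt.shape[0])
    M = ot.dist(Xs, Xt)

    # Exact OT coupling between source and target samples
    G = ot.emd(a, b, M)

    # Barycentric label propagation: weighted average of the source labels
    yt_prop = G.T @ ys / G.sum(axis=0)
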
.. GENERATED FROM PYTHON SOURCE LINES 154-179

.. code-block:: Python


    lp = OTLabelPropAdapter()

    yh = lp.fit_transform(X, y, sample_domain=sample_domain)[1]

    yht = yh[sample_domain < 0]

    plt.figure(1, (10, 5))
    plt.subplot(1, 2, 1)
    plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, label="Source")
    plt.scatter(Xt[:, 0], Xt[:, 1], c="gray", label="Target")
    plt.legend()
    plt.title("Source and Target data")
    ax = plt.axis()

    plt.subplot(1, 2, 2)
    plt.scatter(Xt[:, 0], Xt[:, 1], c=yht, label="Target")
    plt.title("Propagated labels data")
    plt.axis(ax)

.. image-sg:: /auto_examples/methods/images/sphx_glr_plot_label_prop_da_004.png
   :alt: Source and Target data, Propagated labels data
   :srcset: /auto_examples/methods/images/sphx_glr_plot_label_prop_da_004.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /home/circleci/.local/lib/python3.10/site-packages/ot/lp/__init__.py:576: UserWarning: numItermax reached before optimality. Try to increase numItermax.
      result_code_string = check_result(result_code)
    (np.float64(-2.570318895725525), np.float64(4.744989537549497), np.float64(-1.9610814126500358), np.float64(4.45933195939893))

.. GENERATED FROM PYTHON SOURCE LINES 180-184

Generate concept drift classification dataset and plot it
----------------------------------------------------------

We generate a simple 2D concept drift classification dataset.

.. GENERATED FROM PYTHON SOURCE LINES 184-209

.. code-block:: Python


    X, y, sample_domain = make_shifted_datasets(
        n_samples_source=20,
        n_samples_target=20,
        shift="concept_drift",
        noise=0.2,
        label="multiclass",
        random_state=42,
    )

    Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain)

    plt.figure(5, (10, 5))
    plt.subplot(1, 2, 1)
    plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap="tab10", vmax=9, label="Source")
    plt.title("Source data")
    ax = plt.axis()

    plt.subplot(1, 2, 2)
    plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Target")
    plt.title("Target data")
    plt.axis(ax)

.. image-sg:: /auto_examples/methods/images/sphx_glr_plot_label_prop_da_005.png
   :alt: Source data, Target data
   :srcset: /auto_examples/methods/images/sphx_glr_plot_label_prop_da_005.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (np.float64(-2.3676169789533272), np.float64(4.53817259426296), np.float64(-1.7964573229829714), np.float64(4.336863129384494))

.. GENERATED FROM PYTHON SOURCE LINES 210-216

Train a classifier on source data
---------------------------------

We train a simple Logistic Regression classifier on the source domain and
evaluate its performance on the source and target domains. Performance is much
lower on the target domain due to the shift. We also plot the decision
boundary learned by the classifier.

.. GENERATED FROM PYTHON SOURCE LINES 216-262

.. code-block:: Python


    clf = LogisticRegression()
    clf.fit(Xs, ys)

    # Compute accuracy on source and target
    ys_pred = clf.predict(Xs)
    yt_pred = clf.predict(Xt)

    acc_s = (ys_pred == ys).mean()
    acc_t = (yt_pred == yt).mean()

    print(f"Accuracy on source: {acc_s:.2f}")
    print(f"Accuracy on target: {acc_t:.2f}")

    XX, YY = np.meshgrid(np.linspace(ax[0], ax[1], 100), np.linspace(ax[2], ax[3], 100))
    Z = clf.predict(np.c_[XX.ravel(), YY.ravel()]).reshape(XX.shape)

    plt.figure(6, (10, 5))
    plt.subplot(1, 2, 1)
    plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap="tab10", vmax=9, label="Prediction")
    plt.imshow(
        Z,
        extent=(ax[0], ax[1], ax[2], ax[3]),
        origin="lower",
        alpha=0.5,
        cmap="tab10",
        vmax=9,
    )
    plt.title(f"LogReg Prediction on source (ACC={acc_s:.2f})")

    plt.subplot(1, 2, 2)
    plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Prediction")
    plt.imshow(
        Z,
        extent=(ax[0], ax[1], ax[2], ax[3]),
        origin="lower",
        alpha=0.5,
        cmap="tab10",
        vmax=9,
    )
    plt.title(f"LogReg Prediction on target (ACC={acc_t:.2f})")
    plt.axis(ax)

.. image-sg:: /auto_examples/methods/images/sphx_glr_plot_label_prop_da_006.png
   :alt: LogReg Prediction on source (ACC=0.99), LogReg Prediction on target (ACC=0.50)
   :srcset: /auto_examples/methods/images/sphx_glr_plot_label_prop_da_006.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Accuracy on source: 0.99
    Accuracy on target: 0.50
    (np.float64(-2.3676169789533272), np.float64(4.53817259426296), np.float64(-1.7964573229829714), np.float64(4.336863129384494))

.. GENERATED FROM PYTHON SOURCE LINES 263-266

Train with LabelProp + classifier
---------------------------------

We now train the same Logistic Regression inside a skada pipeline with
``OTLabelPropAdapter``, so that the classifier is fitted on the labels
propagated to the target domain.

.. GENERATED FROM PYTHON SOURCE LINES 266-311

.. code-block:: Python


    clf = make_da_pipeline(OTLabelPropAdapter(), LogisticRegression())
    clf.fit(X, y, sample_domain=sample_domain)

    ys_pred = clf.predict(Xs)
    yt_pred = clf.predict(Xt)

    acc_s = (ys_pred == ys).mean()
    acc_t = (yt_pred == yt).mean()

    print(f"LabelProp Accuracy on source: {acc_s:.2f}")
    print(f"LabelProp Accuracy on target: {acc_t:.2f}")

    XX, YY = np.meshgrid(np.linspace(ax[0], ax[1], 100), np.linspace(ax[2], ax[3], 100))
    Z = clf.predict(np.c_[XX.ravel(), YY.ravel()]).reshape(XX.shape)

    plt.figure(7, (10, 5))
    plt.subplot(1, 2, 1)
    plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap="tab10", vmax=9, label="Prediction")
    plt.imshow(
        Z,
        extent=(ax[0], ax[1], ax[2], ax[3]),
        origin="lower",
        alpha=0.5,
        cmap="tab10",
        vmax=9,
    )
    plt.title(f"LabelProp reglog on source (ACC={acc_s:.2f})")

    plt.subplot(1, 2, 2)
    plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Prediction")
    plt.imshow(
        Z,
        extent=(ax[0], ax[1], ax[2], ax[3]),
        origin="lower",
        alpha=0.5,
        cmap="tab10",
        vmax=9,
    )
    plt.title(f"LabelProp reglog on target (ACC={acc_t:.2f})")
    plt.axis(ax)

.. image-sg:: /auto_examples/methods/images/sphx_glr_plot_label_prop_da_007.png
   :alt: LabelProp reglog on source (ACC=0.76), LabelProp reglog on target (ACC=0.57)
   :srcset: /auto_examples/methods/images/sphx_glr_plot_label_prop_da_007.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /home/circleci/.local/lib/python3.10/site-packages/ot/lp/__init__.py:576: UserWarning: numItermax reached before optimality. Try to increase numItermax.
      result_code_string = check_result(result_code)
    LabelProp Accuracy on source: 0.76
    LabelProp Accuracy on target: 0.57
    (np.float64(-2.3676169789533272), np.float64(4.53817259426296), np.float64(-1.7964573229829714), np.float64(4.336863129384494))

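For classification, the quantity propagated through the coupling is label mass
rather than a continuous value: one-hot encoded source labels are transported
to the target samples, and each target sample keeps the dominant class. The
sketch below illustrates one way to obtain such hard labels with POT directly;
it is an illustration of the idea, and the exact implementation of
``OTLabelPropAdapter`` may differ.

.. code-block:: Python


    import ot  # POT: Python Optimal Transport (sketch of the idea only)

    # One-hot encode the source labels
    classes = np.unique(ys)
    Ys = (ys[:, None] == classes[None, :]).astype(float)

    # Exact OT coupling between source and target samples
    G = ot.emd(ot.unif(Xs.shape[0]), ot.unif(Xt.shape[0]), ot.dist(Xs, Xt))

    # Transport the label mass and keep the dominant class per target sample
    yt_prop = classes[(G.T @ Ys).argmax(axis=1)]
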
.. GENERATED FROM PYTHON SOURCE LINES 312-317

Illustration of the propagated labels
-------------------------------------

As in the regression case, we plot the labels that the OT plan propagates from
the source samples to the target samples.

.. GENERATED FROM PYTHON SOURCE LINES 317-341

.. code-block:: Python


    lp = OTLabelPropAdapter()

    yh = lp.fit_transform(X, y, sample_domain=sample_domain)[1]

    yht = yh[sample_domain < 0]

    plt.figure(1, (10, 5))
    plt.subplot(1, 2, 1)
    plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap="tab10", vmax=9, label="Source")
    plt.scatter(Xt[:, 0], Xt[:, 1], c="gray", label="Target")
    plt.legend()
    plt.title("Source and Target data")
    ax = plt.axis()

    plt.subplot(1, 2, 2)
    plt.scatter(Xt[:, 0], Xt[:, 1], c=yht, cmap="tab10", vmax=9, label="Target")
    plt.title("Propagated labels data")
    plt.axis(ax)

.. image-sg:: /auto_examples/methods/images/sphx_glr_plot_label_prop_da_008.png
   :alt: Source and Target data, Propagated labels data
   :srcset: /auto_examples/methods/images/sphx_glr_plot_label_prop_da_008.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /home/circleci/.local/lib/python3.10/site-packages/ot/lp/__init__.py:576: UserWarning: numItermax reached before optimality. Try to increase numItermax.
      result_code_string = check_result(result_code)
    (np.float64(-2.3676169789533272), np.float64(4.53817259426296), np.float64(-1.7964573229829714), np.float64(4.336863129384494))

.. GENERATED FROM PYTHON SOURCE LINES 342-346

Generate target shift classification dataset and plot it
---------------------------------------------------------

We generate a simple 2D target shift dataset.

.. GENERATED FROM PYTHON SOURCE LINES 346-371

.. code-block:: Python


    X, y, sample_domain = make_shifted_datasets(
        n_samples_source=20,
        n_samples_target=20,
        shift="target_shift",
        noise=0.2,
        random_state=42,
    )

    Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain)

    plt.figure(5, (10, 5))
    plt.subplot(1, 2, 1)
    plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap="tab10", vmax=9, label="Source")
    plt.title("Source data")
    ax = plt.axis()

    plt.subplot(1, 2, 2)
    plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Target")
    plt.title("Target data")
    plt.axis(ax)

.. image-sg:: /auto_examples/methods/images/sphx_glr_plot_label_prop_da_009.png
   :alt: Source data, Target data
   :srcset: /auto_examples/methods/images/sphx_glr_plot_label_prop_da_009.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (np.float64(-2.4269298070320042), np.float64(4.352173719537829), np.float64(-1.5585726101363702), np.float64(4.367467726151268))

.. GENERATED FROM PYTHON SOURCE LINES 372-379

Train with LabelProp and JCPOT + classifier
-------------------------------------------

On this target shift dataset, the plain label propagation method does not work
well because the OT plan matches source and target samples that belong to
different classes. JCPOT is more robust to this kind of shift because it
estimates the class proportions in the target domain.

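To see why the uniform-marginal OT plan struggles here, we can compare the
empirical class proportions of the two domains, which differ under target
shift. In this synthetic example the true target labels are available for the
comparison; JCPOT has to estimate these proportions without them. A small
illustrative check:

.. code-block:: Python


    # Empirical class proportions in each domain (they differ under target shift)
    classes = np.unique(y)
    prop_s = np.array([(ys == c).mean() for c in classes])
    prop_t = np.array([(yt == c).mean() for c in classes])
    print("Source class proportions:", prop_s)
    print("Target class proportions:", prop_t)
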
.. GENERATED FROM PYTHON SOURCE LINES 379-432

.. code-block:: Python


    clf = make_da_pipeline(OTLabelPropAdapter(), SVC())
    clf.fit(X, y, sample_domain=sample_domain)

    clf_jcpot = make_da_pipeline(JCPOTLabelPropAdapter(reg=0.1), SVC())
    clf_jcpot.fit(X, y, sample_domain=sample_domain)

    yt_pred = clf.predict(Xt)
    acc_t = (yt_pred == yt).mean()

    print(f"LabelProp Accuracy on target: {acc_t:.2f}")

    yt_pred = clf_jcpot.predict(Xt)
    acc_s_jcpot = (yt_pred == yt).mean()

    print(f"JCPOT Accuracy on target: {acc_s_jcpot:.2f}")

    XX, YY = np.meshgrid(np.linspace(ax[0], ax[1], 100), np.linspace(ax[2], ax[3], 100))
    Z = clf.predict(np.c_[XX.ravel(), YY.ravel()]).reshape(XX.shape)
    Z_jcpot = clf_jcpot.predict(np.c_[XX.ravel(), YY.ravel()]).reshape(XX.shape)

    plt.figure(7, (10, 5))
    plt.subplot(1, 2, 1)
    plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Prediction")
    plt.imshow(
        Z,
        extent=(ax[0], ax[1], ax[2], ax[3]),
        origin="lower",
        alpha=0.5,
        cmap="tab10",
        vmax=9,
    )
    plt.title(f"LabelProp reglog on target (ACC={acc_t:.2f})")
    plt.axis(ax)

    plt.subplot(1, 2, 2)
    plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Prediction")
    plt.imshow(
        Z_jcpot,
        extent=(ax[0], ax[1], ax[2], ax[3]),
        origin="lower",
        alpha=0.5,
        cmap="tab10",
        vmax=9,
    )
    plt.title(f"JCPOT reglog on target (ACC={acc_s_jcpot:.2f})")
    plt.axis(ax)

.. image-sg:: /auto_examples/methods/images/sphx_glr_plot_label_prop_da_010.png
   :alt: LabelProp reglog on target (ACC=0.87), JCPOT reglog on target (ACC=0.99)
   :srcset: /auto_examples/methods/images/sphx_glr_plot_label_prop_da_010.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /home/circleci/.local/lib/python3.10/site-packages/ot/lp/__init__.py:576: UserWarning: numItermax reached before optimality. Try to increase numItermax.
      result_code_string = check_result(result_code)
    /home/circleci/.local/lib/python3.10/site-packages/ot/bregman/_barycenter.py:1047: UserWarning: Algorithm did not converge. You might want to increase the number of iterations `numItermax` or the regularization parameter `reg`.
      warnings.warn(
    LabelProp Accuracy on target: 0.87
    JCPOT Accuracy on target: 0.99
    (np.float64(-2.4269298070320042), np.float64(4.352173719537829), np.float64(-1.5585726101363702), np.float64(4.367467726151268))

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 1.907 seconds)


.. _sphx_glr_download_auto_examples_methods_plot_label_prop_da.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_label_prop_da.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_label_prop_da.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_label_prop_da.zip `

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_