Label Propagation methods

This example shows how to use label propagation methods to perform domain adaptation (DA). Labels are propagated from the source domain to the target domain using the optimal transport (OT) plan between the two domains. This approach was originally proposed in [28] for semi-supervised learning, but it can also be used for DA.
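The propagation step can be sketched in a few lines of NumPy. This is a minimal sketch under stated assumptions, not skada's implementation: it swaps in an entropic (Sinkhorn) OT plan instead of an exact OT solver, assumes uniform weights on both domains, and uses the hypothetical helpers `sinkhorn_plan` and `propagate_labels`. Each target label is the average of the source labels, weighted by the mass the plan sends to that target sample.

```python
import numpy as np


def sinkhorn_plan(a, b, M, reg=0.1, n_iter=1000):
    """Entropic OT plan between weight vectors a and b for cost matrix M."""
    K = np.exp(-M / reg)  # Gibbs kernel
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)  # rescale rows toward marginal a
        v = b / (K.T @ u)  # rescale columns to marginal b
    return u[:, None] * K * v[None, :]


def propagate_labels(Xs, ys, Xt, reg=0.1):
    """Propagate source labels to target samples through an OT plan."""
    ns, nt = len(Xs), len(Xt)
    a = np.full(ns, 1.0 / ns)  # assumed uniform source weights
    b = np.full(nt, 1.0 / nt)  # assumed uniform target weights
    # Squared Euclidean cost, rescaled to [0, 1] so reg does not depend on scale
    M = ((Xs[:, None, :] - Xt[None, :, :]) ** 2).sum(-1)
    M = M / M.max()
    G = sinkhorn_plan(a, b, M, reg)
    # Barycentric transfer: weighted average of ys for each target column
    return (G.T @ ys) / G.sum(axis=0)


rng = np.random.default_rng(0)
Xs = rng.normal(size=(20, 2))
ys = Xs[:, 0] + 0.1 * rng.normal(size=20)
Xt = Xs + 0.5  # hypothetical shifted target inputs
yt_hat = propagate_labels(Xs, ys, Xt)
print(yt_hat.shape)  # (20,)
```

Because every propagated label is a convex combination of source labels, it always lies within the range of `ys`; the entropic regularization `reg` controls how much each target label is smoothed over several source samples.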

We illustrate the method on simple concept drift datasets, one for regression and one for classification. We first train a Kernel Ridge Regression (KRR) model and a Logistic Regression model on the source domain and evaluate them on both the source and target domains. We then train the same models with the label propagation method and evaluate them in the same way.

# Author: Remi Flamary
#
# License: BSD 3-Clause
# sphinx_gallery_thumbnail_number = 4
import matplotlib.pyplot as plt
import numpy as np
from sklearn.kernel_ridge import KernelRidge
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import mean_squared_error

from skada import OTLabelPropAdapter, make_da_pipeline, source_target_split
from skada.datasets import make_shifted_datasets

Generate concept drift regression dataset and plot it

We generate a simple 2D concept drift dataset.
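The dataset comes with a `sample_domain` array, and later code selects target samples with `sample_domain < 0`. The sketch below illustrates that sign convention with hypothetical toy arrays; it mimics the kind of split `source_target_split` returns here but is not skada's implementation.

```python
import numpy as np

# Hypothetical toy data: 3 source samples (domain 1) and 2 target samples (domain -2)
X = np.arange(10, dtype=float).reshape(5, 2)
y = np.array([0.5, 1.0, 1.5, 2.0, 2.5])
sample_domain = np.array([1, 1, 1, -2, -2])

# Convention used in this example: negative ids mark target samples
is_target = sample_domain < 0
Xs, ys = X[~is_target], y[~is_target]
Xt, yt = X[is_target], y[is_target]

print(Xs.shape, Xt.shape)  # (3, 2) (2, 2)
```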

X, y, sample_domain = make_shifted_datasets(
    n_samples_source=20,
    n_samples_target=20,
    shift="concept_drift",
    noise=0.3,
    label="regression",
    random_state=42,
)

y = (y - y.mean()) / y.std()

Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain)


plt.figure(1, (10, 5))
plt.subplot(1, 2, 1)
plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, label="Source")
plt.title("Source data")
ax = plt.axis()

plt.subplot(1, 2, 2)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, label="Target")
plt.title("Target data")
plt.axis(ax)
Figure: Source data, Target data

Train a regressor on source data

We train a simple Kernel Ridge Regression (KRR) model on the source domain and evaluate it on the source and target domains. Because of the shift, the mean squared error (MSE) is much higher on the target domain. We also plot the prediction function learned by the KRR.
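For intuition about what `KernelRidge(kernel="rbf", alpha=0.5)` fits, here is a minimal NumPy sketch of kernel ridge regression in its dual form: the dual coefficients solve (K + alpha I) c = y, and predictions are kernel-weighted sums of c. The helpers `rbf_kernel`, `krr_fit`, `krr_predict` and the bandwidth `gamma=0.5` are assumptions of this sketch (scikit-learn chooses its own default bandwidth), and the toy concept drift keeps the input distribution but changes the labeling function.

```python
import numpy as np


def rbf_kernel(A, B, gamma):
    """RBF kernel matrix exp(-gamma * ||a - b||^2)."""
    sq_dists = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
    return np.exp(-gamma * sq_dists)


def krr_fit(X, y, alpha=0.5, gamma=0.5):
    """Dual coefficients of kernel ridge regression: (K + alpha * I)^-1 y."""
    K = rbf_kernel(X, X, gamma)
    return np.linalg.solve(K + alpha * np.eye(len(X)), y)


def krr_predict(X_train, dual_coef, X_new, gamma=0.5):
    """Predictions are kernel-weighted sums of the dual coefficients."""
    return rbf_kernel(X_new, X_train, gamma) @ dual_coef


rng = np.random.default_rng(0)
Xs = rng.normal(size=(20, 2))
ys = np.sin(Xs[:, 0])
# Hypothetical concept drift: same input distribution, different labeling function
Xt = rng.normal(size=(20, 2))
yt = np.sin(Xt[:, 0] + 1.0)

coef = krr_fit(Xs, ys)
mse_s = np.mean((ys - krr_predict(Xs, coef, Xs)) ** 2)
mse_t = np.mean((yt - krr_predict(Xs, coef, Xt)) ** 2)
print(f"MSE on source: {mse_s:.2f}, MSE on target: {mse_t:.2f}")
```

Since the model only ever sees source labels, it keeps predicting the source labeling function on the target inputs, which is why the target MSE stays high.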

clf = KernelRidge(kernel="rbf", alpha=0.5)
clf.fit(Xs, ys)

# Compute MSE on source and target
ys_pred = clf.predict(Xs)
yt_pred = clf.predict(Xt)

mse_s = mean_squared_error(ys, ys_pred)
mse_t = mean_squared_error(yt, yt_pred)

print(f"MSE on source: {mse_s:.2f}")
print(f"MSE on target: {mse_t:.2f}")

XX, YY = np.meshgrid(np.linspace(ax[0], ax[1], 100), np.linspace(ax[2], ax[3], 100))
Z = clf.predict(np.c_[XX.ravel(), YY.ravel()]).reshape(XX.shape)


plt.figure(2, (10, 5))
plt.subplot(1, 2, 1)
plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, label="Prediction")
plt.imshow(Z, extent=(ax[0], ax[1], ax[2], ax[3]), origin="lower", alpha=0.5)
plt.title(f"KRR Prediction on source (MSE={mse_s:.2f})")
plt.axis(ax)

plt.subplot(1, 2, 2)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, label="Prediction")
plt.imshow(Z, extent=(ax[0], ax[1], ax[2], ax[3]), origin="lower", alpha=0.5)
plt.title(f"KRR Prediction on target (MSE={mse_t:.2f})")
plt.axis(ax)
Figure: KRR Prediction on source (MSE=0.06), KRR Prediction on target (MSE=0.77)
MSE on source: 0.06
MSE on target: 0.77

Train the full Label Propagation model

clf = make_da_pipeline(OTLabelPropAdapter(), KernelRidge(kernel="rbf", alpha=0.5))
clf.fit(X, y, sample_domain=sample_domain)

ys_pred = clf.predict(Xs)
yt_pred = clf.predict(Xt)

mse_s = mean_squared_error(ys, ys_pred)
mse_t = mean_squared_error(yt, yt_pred)

Z_lp = clf.predict(np.c_[XX.ravel(), YY.ravel()]).reshape(XX.shape)

print(f"LabelProp MSE on source: {mse_s:.2f}")
print(f"LabelProp MSE on target: {mse_t:.2f}")

plt.figure(3, (10, 5))
plt.subplot(1, 2, 1)
plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, label="Prediction")
plt.imshow(Z_lp, extent=(ax[0], ax[1], ax[2], ax[3]), origin="lower", alpha=0.5)
plt.title(f"LabelProp Prediction on source (MSE={mse_s:.2f})")
plt.axis(ax)

plt.subplot(1, 2, 2)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, label="Prediction")
plt.imshow(Z_lp, extent=(ax[0], ax[1], ax[2], ax[3]), origin="lower", alpha=0.5)
plt.title(f"LabelProp Prediction on target (MSE={mse_t:.2f})")
plt.axis(ax)
Figure: LabelProp Prediction on source (MSE=0.28), LabelProp Prediction on target (MSE=0.50)
/home/circleci/.local/lib/python3.10/site-packages/ot/lp/__init__.py:535: UserWarning: numItermax reached before optimality. Try to increase numItermax.
  result_code_string = check_result(result_code)
LabelProp MSE on source: 0.28
LabelProp MSE on target: 0.50

Illustration of the propagated labels

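We now look at the propagated labels themselves. The left panel shows the labeled source samples together with the unlabeled target samples (in gray); the right panel colors each target sample by the label propagated to it from the source domain through the OT plan.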
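A tiny worked case makes the propagation concrete. When source and target have the same number of samples and uniform weights, an exact OT plan can be chosen as a scaled permutation matrix, and barycentric propagation then simply copies the label of the matched source sample. The plan below is hand-written (hypothetical), not computed by `OTLabelPropAdapter`, and the averaging formula is an assumption about how regression labels are transferred.

```python
import numpy as np

# Hypothetical example: 3 source and 3 target samples with uniform weights 1/3
ys = np.array([-1.0, 0.5, 2.0])
perm = np.array([2, 0, 1])  # target j is matched to source perm[j]
G = np.zeros((3, 3))
G[perm, np.arange(3)] = 1.0 / 3  # rows = source samples, columns = target samples

# Barycentric propagation: yt_hat[j] = sum_i G[i, j] * ys[i] / sum_i G[i, j]
yt_hat = (G.T @ ys) / G.sum(axis=0)
print(yt_hat.tolist())  # [2.0, -1.0, 0.5]
```

With a non-permutation plan (for example an entropic one), each target label becomes a blend of several source labels instead of a copy.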

lp = OTLabelPropAdapter()

yh = lp.fit_transform(X, y, sample_domain=sample_domain)[1]

yht = yh[sample_domain < 0]  # keep only target samples (negative sample_domain)

plt.figure(4, (10, 5))
plt.subplot(1, 2, 1)
plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, label="Source")
plt.scatter(Xt[:, 0], Xt[:, 1], c="gray", label="Target")
plt.legend()

plt.title("Source and Target data")
ax = plt.axis()

plt.subplot(1, 2, 2)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yht, label="Target")


plt.title("Propagated labels data")
plt.axis(ax)
Figure: Source and Target data, Propagated labels data
/home/circleci/.local/lib/python3.10/site-packages/ot/lp/__init__.py:535: UserWarning: numItermax reached before optimality. Try to increase numItermax.
  result_code_string = check_result(result_code)

Generate concept drift classification dataset and plot it

We generate a simple 2D multiclass concept drift dataset.

X, y, sample_domain = make_shifted_datasets(
    n_samples_source=20,
    n_samples_target=20,
    shift="concept_drift",
    noise=0.2,
    label="multiclass",
    random_state=42,
)


Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain)


plt.figure(5, (10, 5))
plt.subplot(1, 2, 1)
plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap="tab10", vmax=9, label="Source")
plt.title("Source data")
ax = plt.axis()

plt.subplot(1, 2, 2)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Target")
plt.title("Target data")
plt.axis(ax)
Figure: Source data, Target data

Train a classifier on source data

We train a simple logistic regression classifier on the source domain and evaluate its accuracy on the source and target domains. Because of the shift, accuracy is much lower on the target domain. We also plot the decision regions of the classifier.
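To see why a source-trained classifier fails under concept drift, here is a minimal NumPy sketch with a simplified stand-in for `LogisticRegression`: multinomial logistic regression trained by plain gradient descent (not scikit-learn's solver; the L2 penalty is also applied to the bias for simplicity). The helpers `fit_softmax` and `predict_softmax`, the cluster centers, and the label permutation used as drift are all hypothetical.

```python
import numpy as np


def fit_softmax(X, y, n_classes, lr=0.5, n_iter=500, l2=1e-2):
    """Multinomial logistic regression trained by plain gradient descent."""
    Xb = np.c_[X, np.ones(len(X))]  # append a bias column
    W = np.zeros((Xb.shape[1], n_classes))
    Y = np.eye(n_classes)[y]  # one-hot targets
    for _ in range(n_iter):
        logits = Xb @ W
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        P = np.exp(logits)
        P /= P.sum(axis=1, keepdims=True)  # softmax probabilities
        grad = Xb.T @ (P - Y) / len(X) + l2 * W  # cross-entropy + L2 gradient
        W -= lr * grad
    return W


def predict_softmax(W, X):
    """Predicted class = argmax of the linear scores."""
    return np.argmax(np.c_[X, np.ones(len(X))] @ W, axis=1)


rng = np.random.default_rng(0)
centers = np.array([[0.0, 0.0], [3.0, 0.0], [0.0, 3.0]])
ys = np.repeat(np.arange(3), 20)
Xs = centers[ys] + 0.5 * rng.normal(size=(60, 2))
# Hypothetical concept drift: same input clusters, but class labels are permuted
Xt = centers[ys] + 0.5 * rng.normal(size=(60, 2))
yt = (ys + 1) % 3

W = fit_softmax(Xs, ys, n_classes=3)
acc_s = (predict_softmax(W, Xs) == ys).mean()
acc_t = (predict_softmax(W, Xt) == yt).mean()
print(f"Accuracy on source: {acc_s:.2f}")
print(f"Accuracy on target: {acc_t:.2f}")
```

The inputs look identical in both domains, so nothing in `Xt` tells the source classifier that the labels have changed; only target-side label information, such as propagated labels, can correct it.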

clf = LogisticRegression()
clf.fit(Xs, ys)

# Compute accuracy on source and target
ys_pred = clf.predict(Xs)
yt_pred = clf.predict(Xt)

acc_s = (ys_pred == ys).mean()
acc_t = (yt_pred == yt).mean()

print(f"Accuracy on source: {acc_s:.2f}")
print(f"Accuracy on target: {acc_t:.2f}")

XX, YY = np.meshgrid(np.linspace(ax[0], ax[1], 100), np.linspace(ax[2], ax[3], 100))
Z = clf.predict(np.c_[XX.ravel(), YY.ravel()]).reshape(XX.shape)


plt.figure(6, (10, 5))
plt.subplot(1, 2, 1)
plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap="tab10", vmax=9, label="Prediction")
plt.imshow(
    Z,
    extent=(ax[0], ax[1], ax[2], ax[3]),
    origin="lower",
    alpha=0.5,
    cmap="tab10",
    vmax=9,
)
plt.title(f"LogReg Prediction on source (ACC={acc_s:.2f})")

plt.subplot(1, 2, 2)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Prediction")
plt.imshow(
    Z,
    extent=(ax[0], ax[1], ax[2], ax[3]),
    origin="lower",
    alpha=0.5,
    cmap="tab10",
    vmax=9,
)
plt.title(f"LogReg Prediction on target (ACC={acc_t:.2f})")
plt.axis(ax)
Figure: LogReg Prediction on source (ACC=0.99), LogReg Prediction on target (ACC=0.50)
Accuracy on source: 0.99
Accuracy on target: 0.50

Train with LabelProp + classifier

clf = make_da_pipeline(OTLabelPropAdapter(), LogisticRegression())
clf.fit(X, y, sample_domain=sample_domain)


ys_pred = clf.predict(Xs)
yt_pred = clf.predict(Xt)

acc_s = (ys_pred == ys).mean()
acc_t = (yt_pred == yt).mean()

print(f"LabelProp Accuracy on source: {acc_s:.2f}")
print(f"LabelProp Accuracy on target: {acc_t:.2f}")

XX, YY = np.meshgrid(np.linspace(ax[0], ax[1], 100), np.linspace(ax[2], ax[3], 100))
Z = clf.predict(np.c_[XX.ravel(), YY.ravel()]).reshape(XX.shape)


plt.figure(7, (10, 5))
plt.subplot(1, 2, 1)
plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap="tab10", vmax=9, label="Prediction")
plt.imshow(
    Z,
    extent=(ax[0], ax[1], ax[2], ax[3]),
    origin="lower",
    alpha=0.5,
    cmap="tab10",
    vmax=9,
)
plt.title(f"LabelProp reglog on source (ACC={acc_s:.2f})")

plt.subplot(1, 2, 2)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Prediction")
plt.imshow(
    Z,
    extent=(ax[0], ax[1], ax[2], ax[3]),
    origin="lower",
    alpha=0.5,
    cmap="tab10",
    vmax=9,
)
plt.title(f"LabelProp reglog on target (ACC={acc_t:.2f})")
plt.axis(ax)
LabelProp reglog on source (ACC=0.76), LabelProp reglog on target (ACC=0.57)
/home/circleci/.local/lib/python3.10/site-packages/ot/lp/__init__.py:535: UserWarning: numItermax reached before optimality. Try to increase numItermax.
  result_code_string = check_result(result_code)
LabelProp Accuracy on source: 0.76
LabelProp Accuracy on target: 0.57

(np.float64(-2.3676169789533272), np.float64(4.53817259426296), np.float64(-1.7964573229829714), np.float64(4.336863129384494))

Illustration of the propagated labels

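As in the regression case, the left panel shows the labeled source samples together with the unlabeled target samples (in gray), and the right panel colors each target sample by the class label propagated to it from the source domain through the OT plan.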
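For classification, one common way to propagate labels is to one-hot encode the source classes, push them through the plan, and take the class that receives the most mass. This is a sketch under assumptions: we have not checked that `OTLabelPropAdapter` does exactly this, and the coupling `G` below is hand-written (hypothetical, with a uniform target marginal) rather than computed by an OT solver.

```python
import numpy as np

# Hypothetical toy coupling G (rows: 3 source samples, columns: 2 target samples).
# Column j says how much mass target sample j receives from each source sample.
G = np.array([
    [0.30, 0.05],
    [0.15, 0.05],
    [0.05, 0.40],
])
ys = np.array([0, 0, 1])  # source class labels
n_classes = 2

Ys = np.eye(n_classes)[ys]  # one-hot source labels, shape (3, 2)
scores = G.T @ Ys  # class mass received by each target sample, shape (2, 2)
proba = scores / scores.sum(axis=1, keepdims=True)  # normalize per target sample
yt_hat = proba.argmax(axis=1)  # propagated hard labels
print(yt_hat.tolist())  # [0, 1]
```

Keeping `proba` instead of only `yt_hat` retains a soft label per target sample, which shows how confident the propagation is.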

lp = OTLabelPropAdapter()

yh = lp.fit_transform(X, y, sample_domain=sample_domain)[1]

yht = yh[sample_domain < 0]

plt.figure(8, (10, 5))
plt.subplot(1, 2, 1)
plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap="tab10", vmax=9, label="Source")
plt.scatter(Xt[:, 0], Xt[:, 1], c="gray", label="Target")
plt.legend()
plt.title("Source and Target data")
ax = plt.axis()

plt.subplot(1, 2, 2)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yht, cmap="tab10", vmax=9, label="Target")


plt.title("Propagated labels data")
plt.axis(ax)
Figure: Source and Target data, Propagated labels data
/home/circleci/.local/lib/python3.10/site-packages/ot/lp/__init__.py:535: UserWarning: numItermax reached before optimality. Try to increase numItermax.
  result_code_string = check_result(result_code)

Total running time of the script: (0 minutes 1.482 seconds)
