Plot dataset source domain and shifted target domain

This example illustrates the make_shifted_datasets() dataset generator. For each type of shift, it generates source data and correspondingly shifted target data. We illustrate four shifts here: covariate shift, target shift, concept drift, and subspace shift. See [1] for a detailed description of each shift.
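As a rough intuition for the first three shift types, the sketch below builds 1-D toy data with the standard library only. It is not how make_shifted_datasets() generates its data: every function name here is hypothetical, the distributions are chosen purely for illustration, and the subspace case is not covered.

```python
import random
from statistics import mean


def labeling_rule(x, threshold=0.0):
    """Deterministic p(y | x): class 1 when x exceeds the threshold."""
    return int(x > threshold)


def covariate_shift_domain(rng, n, center):
    """Only p(x) depends on the domain; the labeling rule is shared."""
    xs = [rng.gauss(center, 1.0) for _ in range(n)]
    ys = [labeling_rule(x) for x in xs]
    return xs, ys


def target_shift_domain(rng, n, class1_prior):
    """Only p(y) depends on the domain; p(x | y) is shared."""
    ys = [int(rng.random() < class1_prior) for _ in range(n)]
    # Class 0 is centered at -1, class 1 at +1.
    xs = [rng.gauss(2 * y - 1, 0.3) for y in ys]
    return xs, ys


def concept_drift_domain(rng, n, threshold):
    """p(x) is shared; the labeling rule p(y | x) depends on the domain."""
    xs = [rng.gauss(0.0, 1.0) for _ in range(n)]
    ys = [labeling_rule(x, threshold) for x in xs]
    return xs, ys


rng = random.Random(42)
n = 2000

# Covariate shift: the feature distribution moves, the rule stays put.
cov_xs, cov_ys = covariate_shift_domain(rng, n, center=-1.0)
cov_xt, cov_yt = covariate_shift_domain(rng, n, center=1.0)
print("covariate shift, mean feature:", mean(cov_xs), mean(cov_xt))

# Target shift: the class proportions move, the class shapes stay put.
tgt_xs, tgt_ys = target_shift_domain(rng, n, class1_prior=0.5)
tgt_xt, tgt_yt = target_shift_domain(rng, n, class1_prior=0.9)
print("target shift, fraction of class 1:", mean(tgt_ys), mean(tgt_yt))

# Concept drift: same features, but the decision threshold moves.
drift_xs, drift_ys = concept_drift_domain(rng, n, threshold=0.0)
drift_xt, drift_yt = concept_drift_domain(rng, n, threshold=1.0)
print("concept drift, fraction of class 1:", mean(drift_ys), mean(drift_yt))
```

In each pair, the first call plays the role of the source domain and the second the target domain; only one ingredient of the joint distribution changes between them.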

import matplotlib.pyplot as plt

from skada import source_target_split
from skada.datasets import make_shifted_datasets


def plot_shifted_dataset(shift, random_state=42):
    """Plot source and shifted target data for a given type of shift.

    The possible shifts are 'covariate_shift', 'target_shift',
    'concept_drift', or 'subspace'.

    The same random seed is used across calls so that each call
    draws from the same distributions.
    """
    # Generate labeled source and target samples for the requested shift.
    X, y, sample_domain = make_shifted_datasets(
        n_samples_source=20,
        n_samples_target=20,
        shift=shift,
        noise=0.3,
        label="multiclass",
        random_state=random_state,
    )

    # Split the samples into source and target parts using sample_domain.
    X_source, X_target, y_source, y_target = source_target_split(
        X, y, sample_domain=sample_domain
    )

    fig, (ax1, ax2) = plt.subplots(1, 2, sharex="row", sharey="row", figsize=(8, 4))
    fig.suptitle(shift.replace("_", " ").title(), fontsize=14)
    plt.subplots_adjust(bottom=0.15)
    ax1.scatter(
        X_source[:, 0],
        X_source[:, 1],
        c=y_source,
        cmap="tab10",
        vmax=10,
        alpha=0.5,
    )
    ax1.set_title("Source data")
    ax1.set_xlabel("Feature 1")
    ax1.set_ylabel("Feature 2")

    # Draw the source data faintly behind the target data for reference.
    ax2.scatter(
        X_source[:, 0],
        X_source[:, 1],
        c=y_source,
        cmap="tab10",
        vmax=10,
        alpha=0.1,
    )
    ax2.scatter(
        X_target[:, 0],
        X_target[:, 1],
        c=y_target,
        cmap="tab10",
        vmax=10,
        alpha=0.5,
    )
    ax2.set_title("Target data")
    ax2.set_xlabel("Feature 1")
    ax2.set_ylabel("Feature 2")

    plt.show()


for shift in [
    "covariate_shift",
    "target_shift",
    "concept_drift",
    "subspace",
]:
    plot_shifted_dataset(shift)
[Figures: one per shift (Covariate Shift, Target Shift, Concept Drift, Subspace), each with a "Source data" panel and a "Target data" panel.]
