# Source code for skada._ot

# Author: Remi Flamary <remi.flamary@polytechnique.edu>
#
# License: BSD 3-Clause


import warnings

import numpy as np
import ot
from sklearn.base import clone
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.preprocessing import OneHotEncoder
from sklearn.svm import SVC
from sklearn.utils.metaestimators import available_if
from sklearn.utils.validation import check_is_fitted

from ._pipeline import make_da_pipeline
from ._utils import Y_Type, _find_y_type
from .base import BaseAdapter, DAEstimator
from .utils import check_X_y_domain, per_domain_split, source_target_split


def get_jdot_class_cost_matrix(Ys, Xt, estimator=None, metric="multinomial"):
    """Cost matrix for joint distribution optimal transport classification problem.

    Entry ``M[i, j]`` is the classification loss obtained when the one-hot label
    of source sample ``i`` is compared with the prediction of ``estimator`` on
    target sample ``j``.

    Parameters
    ----------
    Ys : array-like of shape (n_samples, n_classes)
        Source domain labels one hot encoded.
    Xt : array-like of shape (m_samples, n_features)
        Target domain samples.
    estimator : object
        The already fitted estimator to be used for the classification task. This
        estimator should optimize a classification loss corresponding to the
        given metric and provide a compatible predict method (decision_function
        or predict_proba / predict_log_proba). If None, a constant cost of 10 is
        returned for every pair.
    metric : str, default='multinomial'
        The metric to use for the cost matrix. Can be 'multinomial' for cross-entropy
        cost/ multinomial logistic regression or 'hinge' for hinge cost (SVM/SVC).

    Returns
    -------
    M : array-like of shape (n_samples, m_samples)
        The cost matrix.

    Raises
    ------
    ValueError
        If the metric is unknown or the estimator lacks the prediction method
        required by the metric.

    References
    ----------
    [10] N. Courty, R. Flamary, A. Habrard, A. Rakotomamonjy, Joint Distribution
         Optimal Transportation for Domain Adaptation, Neural Information Processing
         Systems (NIPS), 2017.

    """
    if estimator is None:
        # no model available yet: every source/target pair gets the same cost
        return np.ones((Ys.shape[0], Xt.shape[0])) * 10

    if metric == "multinomial":
        if hasattr(estimator, "predict_log_proba"):
            Yt_pred = estimator.predict_log_proba(Xt)
            M = -np.sum(Ys[:, None, :] * Yt_pred[None, :, :], 2)
        elif hasattr(estimator, "predict_proba"):
            Yt_pred = estimator.predict_proba(Xt)
            # small offset avoids log(0) on hard 0/1 probabilities
            M = -np.sum(Ys[:, None, :] * np.log(Yt_pred[None, :, :] + 1e-16), 2)
        else:
            raise ValueError(
                "Estimator must have predict_proba or predict_log_proba"
                " method for cce loss"
            )

    elif metric == "hinge":
        Ys = 2 * Ys - 1  # make Y -1/1 for hinge loss

        if hasattr(estimator, "decision_function"):
            Yt_pred = estimator.decision_function(Xt)
            if len(Yt_pred.shape) == 1:
                # Binary case: a single score per sample, where (scikit-learn
                # convention) a positive value favors the second class. The
                # first-class column must therefore carry the opposite sign,
                # otherwise both source classes would get the same cost.
                Yt_pred = np.stack([-Yt_pred, Yt_pred], axis=1)
            M = np.maximum(0, 1 - Ys[:, None, :] * Yt_pred[None, :, :]).sum(2)
        else:
            raise ValueError(
                "Estimator must have decision_function method for hinge loss"
            )
    else:
        raise ValueError("Unknown metric")

    return M


def get_data_jdot_class(Xt, Yth, labels, thr_weights=1e-6):
    """Build a weighted training set from soft labels on the target samples.

    Each target sample is repeated once per class; the repetition associated
    with class ``k`` carries the weight ``Yth[j, k]``. Pairs whose weight does
    not exceed ``thr_weights * max(Yth)`` are dropped, which allows training a
    standard classifier on uncertain labels.

    Parameters
    ----------
    Xt : array-like of shape (m_samples, n_features)
        Target domain samples.
    Yth : array-like of shape (m_samples, n_classes)
        Transported source domain labels (one weight per target sample and
        class).
    labels : array-like of shape (n_classes,)
        The labels of the classification problem.
    thr_weights : float, default=1e-6
        The relative threshold for the weights

    Returns
    -------
    Xh : array-like of shape (n_kept, n_features)
        The repeated target samples that were kept.
    yh : array-like of shape (n_kept,)
        The label attached to each kept sample.
    wh : array-like of shape (n_kept,)
        The weight attached to each kept sample.

    References
    ----------
    [10] N. Courty, R. Flamary, A. Habrard, A. Rakotomamonjy, Joint Distribution
         Optimal Transportation for Domain Adaptation, Neural Information Processing
         Systems (NIPS), 2017.

    """
    n_target, n_classes = Yth.shape[0], Yth.shape[1]
    cutoff = thr_weights * np.max(Yth)

    # one entry per (target sample, class) pair, in the row-major order of Yth
    weights_all = Yth.flatten()
    samples_all = np.repeat(Xt, n_classes, axis=0)
    labels_all = np.tile(labels, n_target)

    # discard the pairs that carry (almost) no weight
    keep = weights_all > cutoff
    return samples_all[keep], labels_all[keep], weights_all[keep]


def get_tgt_loss_jdot_class(Xh, yh, wh, estimator, metric="multinomial"):
    """Get target loss for joint distribution optimal transport classification problem.

    Parameters
    ----------
    Xh : array-like of shape (n_samples, n_features)
        The (repeated) target samples used for training.
    yh : array-like of shape (n_samples, n_classes)
        The labels attached to ``Xh``, one hot encoded (they are multiplied
        element-wise with the per-class predictions).
    wh : array-like of shape (n_samples,)
        The weights attached to ``Xh``.
    estimator : object
        The already fitted estimator to be used for the classification task. This
        estimator should optimize a classification loss corresponding to the
        given metric and provide a compatible predict method (decision_function
        or predict_proba / predict_log_proba).
    metric : str, default='multinomial'
        The metric to use for the cost matrix. Can be 'multinomial' for cross-entropy
        cost/ multinomial logistic regression or 'hinge' for hinge cost
        (SVM/SVC).

    Returns
    -------
    loss : float
        The weighted sum of the per-sample losses.

    Raises
    ------
    ValueError
        If the metric is unknown or the estimator lacks the prediction method
        required by the metric.

    References
    ----------
    [10] N. Courty, R. Flamary, A. Habrard, A. Rakotomamonjy, Joint Distribution
         Optimal Transportation for Domain Adaptation, Neural Information Processing
         Systems (NIPS), 2017.

    """
    if metric == "multinomial":
        if hasattr(estimator, "predict_log_proba"):
            Yh_pred = estimator.predict_log_proba(Xh)
            loss = -np.sum(yh * Yh_pred, 1).dot(wh)
        elif hasattr(estimator, "predict_proba"):
            Yh_pred = estimator.predict_proba(Xh)
            # small offset avoids log(0) on hard 0/1 probabilities
            loss = -np.sum(yh * np.log(Yh_pred + 1e-16), 1).dot(wh)
        else:
            raise ValueError(
                "Estimator must have predict_proba or predict_log_proba method"
                " for multinomial loss"
            )

    elif metric == "hinge":
        yh = 2 * yh - 1  # make Y -1/1 for hinge loss

        if hasattr(estimator, "decision_function"):
            Yh_pred = estimator.decision_function(Xh)
            if len(Yh_pred.shape) == 1:  # handle binary classification
                # A single score per sample: (scikit-learn convention) positive
                # favors the second class, so the first-class column needs the
                # opposite sign. Using the same sign for both columns would make
                # the loss independent of the label.
                Yh_pred = np.stack([-Yh_pred, Yh_pred], axis=1)
            loss = np.sum(np.maximum(0, 1 - yh * Yh_pred), 1).dot(wh)
        else:
            raise ValueError(
                "Estimator must have decision_function method for hinge loss"
            )
    else:
        raise ValueError("Unknown metric")

    return loss


def solve_jdot_regression(
    base_estimator,
    Xs,
    ys,
    Xt,
    alpha=0.5,
    ws=None,
    wt=None,
    n_iter_max=100,
    tol=1e-5,
    verbose=False,
    **kwargs,
):
    """Solve the joint distribution optimal transport regression problem [10]

    The solver alternates between (1) solving an OT problem whose cost mixes
    the feature distance and the label/prediction distance, and (2) refitting
    the estimator on the target samples using the transported labels.

    .. warning::
        This estimator assumes that the loss function optimized by the base
        estimator is the quadratic loss. For instance, the base estimator should
        optimize an L2 loss (e.g. LinearRegression() or Ridge() or even
        MLPRegressor ()). While any estimator providing the necessary prediction
        functions can be used, the convergence of the fixed point is not guaranteed
        and behavior can be unpredictable.

    Parameters
    ----------
    base_estimator : object
        The base estimator to be used for the regression task. This estimator
        should solve a least squares regression problem (regularized or not)
        to correspond to JDOT theoretical regression problem but other
        approaches can be used with the risk that the fixed point might not converge.
        It is cloned, the object passed in is left untouched.
    Xs : array-like of shape (n_samples, n_features)
        Source domain samples.
    ys : array-like of shape (n_samples,)
        Source domain labels.
    Xt : array-like of shape (m_samples, n_features)
        Target domain samples.
    alpha : float, default=0.5
        The trade-off parameter between the feature and label loss in OT metric
    ws : array-like of shape (n_samples,)
        Source domain weights (will be normalized to sum to 1).
    wt : array-like of shape (m_samples,)
        Target domain weights (will be normalized to sum to 1). When given,
        they are also passed to the estimator ``fit`` as ``sample_weight``.
    n_iter_max: int
        Max number of JDOT alternate optimization iterations (at least 1).
    tol: float>0
        Tolerance for loss variations (OT and mse) stopping iterations.
    verbose: bool
        Print loss along iterations if True.
    kwargs : dict
        Additional parameters to be passed to the ``fit`` method of the
        estimator.


    Returns
    -------
    estimator : object
        The fitted estimator.
    lst_loss_ot : list
        The list of OT losses at each iteration.
    lst_loss_tgt_labels : list
        The list of target labels losses at each iteration.
    sol : object
        The solution of the OT problem.

    Raises
    ------
    ValueError
        If ``n_iter_max`` is smaller than 1.

    References
    ----------
    [10] N. Courty, R. Flamary, A. Habrard, A. Rakotomamonjy, Joint Distribution
        Optimal Transportation for Domain Adaptation, Neural Information Processing
        Systems (NIPS), 2017.

    """
    if n_iter_max < 1:
        # without at least one iteration there is neither a fitted estimator
        # nor an OT solution to return
        raise ValueError("n_iter_max must be at least 1.")

    estimator = clone(base_estimator)

    # compute feature distance matrix (rescaled by its mean)
    Mf = ot.dist(Xs, Xt)
    Mf = Mf / Mf.mean()

    nt = Xt.shape[0]
    if ws is None:
        a = np.ones((len(ys),)) / len(ys)
    else:
        a = ws / ws.sum()
    if wt is None:
        b = np.ones((nt,)) / nt
    else:
        b = wt / wt.sum()
        kwargs["sample_weight"] = wt  # add it as sample_weight for fit

    lst_loss_ot = []
    lst_loss_tgt_labels = []
    # initial label cost: source labels against an all-zero target prediction
    # (only used to report the loss of the first iteration)
    Ml = ot.dist(ys.reshape(-1, 1), np.zeros((nt, 1)))

    for i in range(n_iter_max):
        if i > 0:
            # update the cost matrix
            M = (1 - alpha) * Mf + alpha * Ml
        else:
            # first iteration: no prediction yet, use the feature cost only
            M = (1 - alpha) * Mf

        # solve OT problem
        sol = ot.solve(M, a, b)

        T = sol.plan
        loss_ot = sol.value

        if i == 0:
            # add the label term that was left out of the first cost matrix
            loss_ot += alpha * np.sum(Ml * T)

        lst_loss_ot.append(loss_ot)

        # compute the transported labels
        yth = ys.T.dot(T) / b

        # fit the estimator
        estimator.fit(Xt, yth, **kwargs)
        y_pred = estimator.predict(Xt)

        # label cost for the next iteration
        Ml = ot.dist(ys.reshape(-1, 1), y_pred.reshape(-1, 1))

        # compute the loss
        loss_tgt_labels = np.mean((yth - y_pred) ** 2)
        lst_loss_tgt_labels.append(loss_tgt_labels)

        if verbose:
            print(f"iter={i}, loss_ot={loss_ot}, loss_tgt_labels={loss_tgt_labels}")

        # break on tol OT loss
        if i > 0 and abs(lst_loss_ot[-1] - lst_loss_ot[-2]) < tol:
            break

        # break on tol target loss
        if i > 0 and abs(lst_loss_tgt_labels[-1] - lst_loss_tgt_labels[-2]) < tol:
            break

        # last iteration reached without meeting a stopping criterion
        if i == n_iter_max - 1:
            warnings.warn("Maximum number of iterations reached.")

    return estimator, lst_loss_ot, lst_loss_tgt_labels, sol


def solve_jdot_classification(
    base_estimator,
    Xs,
    ys,
    Xt,
    alpha=0.5,
    ws=None,
    wt=None,
    metric="multinomial",
    n_iter_max=100,
    tol=1e-5,
    verbose=False,
    thr_weights=1e-6,
    **kwargs,
):
    """Solve the joint distribution optimal transport classification problem [10]

    The solver alternates between (1) solving an OT problem whose cost mixes
    the feature distance and the classification loss of the current estimator,
    and (2) refitting the estimator on reweighted target samples labeled with
    the transported source labels.

    .. warning::
        This estimator assumes that the loss function optimized by the base
        estimator is compatible with the given metric. For instance, if the
        metric is 'multinomial', the base estimator should optimize a
        cross-entropy loss (e.g. LogisticRegression with multi_class='multinomial')
        or a hinge loss (e.g. SVC with kernel='linear' and one versus rest) if the
        metric is 'hinge'. While any estimator providing the necessary prediction
        functions can be used, the convergence of the fixed point is not guaranteed
        and behavior can be unpredictable.


    Parameters
    ----------
    base_estimator : object
        The base estimator to be used for the classification task. This estimator
        should solve a classification problem to correspond to JDOT theoretical
        classification problem but other approaches can be used with the risk
        that the fixed point might not converge. It is cloned, the object
        passed in is left untouched. Its ``fit`` must accept ``sample_weight``.
    Xs : array-like of shape (n_samples, n_features)
        Source domain samples.
    ys : array-like of shape (n_samples,)
        Source domain labels.
    Xt : array-like of shape (m_samples, n_features)
        Target domain samples.
    alpha : float, default=0.5
        The trade-off parameter between the feature and label loss in OT metric
    ws : array-like of shape (n_samples,)
        Source domain weights (will be normalized to sum to 1).
    wt : array-like of shape (m_samples,)
        Target domain weights (will be normalized to sum to 1).
    metric : str, default='multinomial'
        The metric to use for the cost matrix. Can be 'multinomial' for
        cross-entropy cost/ multinomial logistic regression or 'hinge' for
        hinge cost (SVM/SVC).
    n_iter_max: int
        Max number of JDOT alternate optimization iterations (at least 1).
    tol: float>0
        Tolerance for loss variations (OT and target label loss) stopping
        iterations.
    verbose: bool
        Print loss along iterations if True.
    thr_weights : float, default=1e-6
        The relative threshold for the weights
    kwargs : dict
        Additional parameters to be passed to the ``fit`` method of the
        estimator.

    Returns
    -------
    estimator : object
        The fitted estimator.
    lst_loss_ot : list
        The list of OT losses at each iteration.
    lst_loss_tgt_labels : list
        The list of target labels losses at each iteration.
    sol : object
        The solution of the OT problem.

    Raises
    ------
    ValueError
        If ``n_iter_max`` is smaller than 1.

    References
    ----------
    [10] N. Courty, R. Flamary, A. Habrard, A. Rakotomamonjy, Joint Distribution
         Optimal Transportation for Domain Adaptation, Neural Information Processing
         Systems (NIPS), 2017.

    """
    if n_iter_max < 1:
        # without at least one iteration there is neither a fitted estimator
        # nor an OT solution to return
        raise ValueError("n_iter_max must be at least 1.")

    estimator = clone(base_estimator)

    # compute feature distance matrix (rescaled by its mean)
    Mf = ot.dist(Xs, Xt)
    Mf = Mf / Mf.mean()

    nt = Xt.shape[0]
    if ws is None:
        a = np.ones((len(ys),)) / len(ys)
    else:
        a = ws / ws.sum()
    if wt is None:
        b = np.ones((nt,)) / nt
    else:
        b = wt / wt.sum()

    encoder = OneHotEncoder(sparse_output=False)
    Ys = encoder.fit_transform(ys.reshape(-1, 1))
    labels = encoder.categories_[0]

    lst_loss_ot = []
    lst_loss_tgt_labels = []
    # initial (constant) label cost, only used to report the first loss
    Ml = get_jdot_class_cost_matrix(Ys, Xt, None, metric=metric)

    for i in range(n_iter_max):
        if i > 0:
            # update the cost matrix
            M = (1 - alpha) * Mf + alpha * Ml
        else:
            # first iteration: no estimator yet, use the feature cost only
            M = (1 - alpha) * Mf

        # solve OT problem
        sol = ot.solve(M, a, b)

        T = sol.plan
        loss_ot = sol.value

        if i == 0:
            # add the label term that was left out of the first cost matrix
            loss_ot += alpha * np.sum(Ml * T)

        lst_loss_ot.append(loss_ot)

        # compute the transported labels
        Yth = T.T.dot(Ys) * nt  # not normalized because weights used in fit

        # create reweighted target data for classification
        Xh, yh, wh = get_data_jdot_class(Xt, Yth, labels, thr_weights=thr_weights)

        # fit the estimator
        estimator.fit(Xh, yh, sample_weight=wh, **kwargs)

        # label cost for the next iteration
        Ml = get_jdot_class_cost_matrix(Ys, Xt, estimator, metric=metric)

        # compute the losses
        loss_tgt_labels = (
            get_tgt_loss_jdot_class(
                Xh, encoder.transform(yh[:, None]), wh, estimator, metric=metric
            )
            / nt
        )
        lst_loss_tgt_labels.append(loss_tgt_labels)

        if verbose:
            print(f"iter={i}, loss_ot={loss_ot}, loss_tgt_labels={loss_tgt_labels}")

        # break on tol OT loss
        if i > 0 and abs(lst_loss_ot[-1] - lst_loss_ot[-2]) < tol:
            break

        # break on tol target loss
        if i > 0 and abs(lst_loss_tgt_labels[-1] - lst_loss_tgt_labels[-2]) < tol:
            break

        # last iteration reached without meeting a stopping criterion
        if i == n_iter_max - 1:
            warnings.warn("Maximum number of iterations reached.")

    return estimator, lst_loss_ot, lst_loss_tgt_labels, sol


class JDOTRegressor(DAEstimator):
    """Joint Distribution Optimal Transport Regressor proposed in [10]

    .. warning::
        This estimator assumes that the loss function optimized by the base
        estimator is the quadratic loss. For instance, the base estimator should
        optimize an L2 loss (e.g. LinearRegression() or Ridge() or even
        MLPRegressor ()). While any estimator providing the necessary prediction
        functions can be used, the convergence of the fixed point is not guaranteed
        and behavior can be unpredictable.

    Parameters
    ----------
    base_estimator : object
        The base estimator to be used for the regression task. This estimator
        should solve a least squares regression problem (regularized or not)
        to correspond to JDOT theoretical regression problem but other
        approaches can be used with the risk that the fixed point might not
        converge. default value is LinearRegression() from scikit-learn.
    alpha : float, default=0.5
        The trade-off parameter between the feature and label loss in OT metric
    n_iter_max: int
        Max number of JDOT alternate optimization iterations.
    tol: float>0
        Tolerance for loss variations (OT and mse) stopping iterations.
    verbose: bool
        Print loss along iterations if True.
    **kwargs : dict
        Additional parameters forwarded to ``solve_jdot_regression``, which
        passes them to the ``fit`` method of the base estimator.

    Attributes
    ----------
    estimator_ : object
        The fitted estimator.
    lst_loss_ot_ : list
        The list of OT losses at each iteration.
    lst_loss_tgt_labels_ : list
        The list of target labels losses at each iteration.
    sol_ : object
        The solution of the OT problem.

    References
    ----------
    [10] N. Courty, R. Flamary, A. Habrard, A. Rakotomamonjy, Joint Distribution
         Optimal Transportation for Domain Adaptation, Neural Information Processing
         Systems (NIPS), 2017.

    """

    def __init__(
        self,
        base_estimator=None,
        alpha=0.5,
        n_iter_max=100,
        tol=1e-5,
        verbose=False,
        **kwargs,
    ):
        if base_estimator is None:
            base_estimator = LinearRegression()
        else:
            if not hasattr(base_estimator, "fit") or not hasattr(
                base_estimator, "predict"
            ):
                raise ValueError(
                    "base_estimator must be a regressor with"
                    " fit and predict methods"
                )
        self.base_estimator = base_estimator
        self.kwargs = kwargs
        self.alpha = alpha
        self.n_iter_max = n_iter_max
        self.tol = tol
        self.verbose = verbose

    def fit(self, X, y=None, sample_domain=None, *, sample_weight=None):
        """Fit adaptation parameters

        Splits the data into source and target with ``source_target_split``
        and runs ``solve_jdot_regression`` on it.

        Returns
        -------
        self : JDOTRegressor
            The fitted estimator (allows chaining ``fit(...).predict(...)``).
        """
        Xs, Xt, ys, yt, ws, wt = source_target_split(
            X, y, sample_weight, sample_domain=sample_domain
        )

        res = solve_jdot_regression(
            self.base_estimator,
            Xs,
            ys,
            Xt,
            ws=ws,
            wt=wt,
            alpha=self.alpha,
            n_iter_max=self.n_iter_max,
            tol=self.tol,
            verbose=self.verbose,
            **self.kwargs,
        )

        self.estimator_, self.lst_loss_ot_, self.lst_loss_tgt_labels_, self.sol_ = res
        return self

    def predict(self, X, sample_domain=None, *, sample_weight=None):
        """Predict using the model

        A warning is emitted when ``sample_domain`` contains non-negative
        values, i.e. when source samples are passed.
        """
        check_is_fitted(self)
        # np.asarray lets a plain list of domain ids be compared element-wise
        if sample_domain is not None and np.any(np.asarray(sample_domain) >= 0):
            warnings.warn(
                "Source domain detected. Predictor is trained on target "
                "and prediction might be biased."
            )
        return self.estimator_.predict(X)

    def score(self, X, y, sample_domain=None, *, sample_weight=None):
        """Return the coefficient of determination R^2 of the prediction

        Delegates to the ``score`` method of the fitted estimator; a warning
        is emitted when ``sample_domain`` contains non-negative values.
        """
        check_is_fitted(self)
        if sample_domain is not None and np.any(np.asarray(sample_domain) >= 0):
            warnings.warn(
                "Source domain detected. Predictor is trained on target "
                "and score might be biased."
            )
        return self.estimator_.score(X, y, sample_weight=sample_weight)
class JDOTClassifier(DAEstimator):
    """Joint Distribution Optimal Transport Classifier proposed in [10]

    .. warning::
        This estimator assumes that the loss function optimized by the base
        estimator is compatible with the given metric. For instance, if the
        metric is 'multinomial', the base estimator should optimize a
        cross-entropy loss (e.g. LogisticRegression with multi_class='multinomial')
        or a hinge loss (e.g. SVC with kernel='linear' and one versus rest) if the
        metric is 'hinge'. While any estimator providing the necessary prediction
        functions can be used, the convergence of the fixed point is not guaranteed
        and behavior can be unpredictable.

    Parameters
    ----------
    base_estimator : object
        The base estimator to be used for the classification task. This
        estimator should solve a classification problem to correspond to JDOT
        theoretical classification problem but other approaches can be used with
        the risk that the fixed point might not converge. default value is
        LogisticRegression() from scikit-learn.
    alpha : float, default=0.5
        The trade-off parameter between the feature and label loss in OT metric
    metric : str, default='multinomial'
        The metric to use for the cost matrix. Can be 'multinomial' for
        cross-entropy cost/ multinomial logistic regression or 'hinge' for hinge
        cost (SVM/SVC).
    n_iter_max: int
        Max number of JDOT alternate optimization iterations.
    tol: float>0
        Tolerance for loss variations (OT and target label loss) stopping
        iterations.
    verbose: bool
        Print loss along iterations if True.
    thr_weights : float, default=1e-6
        The relative threshold for the weights
    **kwargs : dict
        Additional parameters forwarded to ``solve_jdot_classification``, which
        passes them to the ``fit`` method of the base estimator.

    Attributes
    ----------
    estimator_ : object
        The fitted estimator.
    lst_loss_ot_ : list
        The list of OT losses at each iteration.
    lst_loss_tgt_labels_ : list
        The list of target labels losses at each iteration.
    sol_ : object
        The solution of the OT problem.

    References
    ----------
    [10] N. Courty, R. Flamary, A. Habrard, A. Rakotomamonjy, Joint Distribution
         Optimal Transportation for Domain Adaptation, Neural Information Processing
         Systems (NIPS), 2017.

    """

    def __init__(
        self,
        base_estimator=None,
        alpha=0.5,
        metric="multinomial",
        n_iter_max=100,
        tol=1e-5,
        verbose=False,
        thr_weights=1e-6,
        **kwargs,
    ):
        if base_estimator is None:
            base_estimator = LogisticRegression(multi_class="multinomial")
        else:
            if not hasattr(base_estimator, "fit") or not hasattr(
                base_estimator, "predict"
            ):
                raise ValueError(
                    "base_estimator must be a classifier with"
                    " fit and predict methods"
                )
        self.base_estimator = base_estimator
        self.kwargs = kwargs
        self.alpha = alpha
        self.metric = metric
        self.n_iter_max = n_iter_max
        self.tol = tol
        self.verbose = verbose
        self.thr_weights = thr_weights

    def fit(self, X, y=None, sample_domain=None, *, sample_weight=None):
        """Fit adaptation parameters

        Splits the data into source and target with ``source_target_split``
        and runs ``solve_jdot_classification`` on it.

        Returns
        -------
        self : JDOTClassifier
            The fitted estimator (allows chaining ``fit(...).predict(...)``).
        """
        Xs, Xt, ys, yt, ws, wt = source_target_split(
            X, y, sample_weight, sample_domain=sample_domain
        )

        res = solve_jdot_classification(
            self.base_estimator,
            Xs,
            ys,
            Xt,
            ws=ws,
            wt=wt,
            alpha=self.alpha,
            metric=self.metric,
            n_iter_max=self.n_iter_max,
            tol=self.tol,
            verbose=self.verbose,
            thr_weights=self.thr_weights,
            **self.kwargs,
        )

        self.estimator_, self.lst_loss_ot_, self.lst_loss_tgt_labels_, self.sol_ = res
        return self

    def predict(self, X, sample_domain=None, *, sample_weight=None, allow_source=False):
        """Predict using the model

        A warning is emitted when ``sample_domain`` contains non-negative
        values, i.e. when source samples are passed.
        """
        check_is_fitted(self)
        # np.asarray lets a plain list of domain ids be compared element-wise
        if sample_domain is not None and np.any(np.asarray(sample_domain) >= 0):
            warnings.warn(
                "Source domain detected. Predictor is trained on target "
                "and prediction might be biased."
            )
        return self.estimator_.predict(X)

    def _check_proba(self):
        # used by available_if: predict_proba is only exposed when the base
        # estimator provides it
        if hasattr(self.base_estimator, "predict_proba"):
            return True
        else:
            raise AttributeError(
                "The base estimator does not have a predict_proba method"
            )

    @available_if(_check_proba)
    def predict_proba(
        self, X, sample_domain=None, *, sample_weight=None, allow_source=False
    ):
        """Predict class probabilities using the model

        A warning is emitted when ``sample_domain`` contains non-negative
        values, i.e. when source samples are passed.
        """
        check_is_fitted(self)
        if sample_domain is not None and np.any(np.asarray(sample_domain) >= 0):
            warnings.warn(
                "Source domain detected. Predictor is trained on target "
                "and prediction might be biased."
            )
        return self.estimator_.predict_proba(X)

    def score(self, X, y, sample_domain=None, *, sample_weight=None, **kwargs):
        """Return the scores of the prediction

        Delegates to the ``score`` method of the fitted estimator; a warning
        is emitted when ``sample_domain`` contains non-negative values.
        """
        check_is_fitted(self)
        if sample_domain is not None and np.any(np.asarray(sample_domain) >= 0):
            warnings.warn(
                "Source domain detected. Predictor is trained on target "
                "and score might be biased."
            )
        return self.estimator_.score(X, y, sample_weight=sample_weight)


class OTLabelPropAdapter(BaseAdapter):
    """Label propagation using optimal transport plan.

    This adapter uses the optimal transport plan to propagate labels from
    source to target domain. This was proposed originally in [28] for
    semi-supervised learning and can be used for domain adaptation.

    Parameters
    ----------
    metric : str, default='sqeuclidean'
        The metric to use for the cost matrix. Can be 'sqeuclidean' for
        squared euclidean distance, 'euclidean' for euclidean distance,
    reg : float, default=None
        The entropic regularization parameter for the optimal transport
        problem. If None, the exact OT is solved, else it is used to weight
        the entropy regularization of the coupling matrix.
    n_iter_max: int
        Maximum number of iterations for the OT solver.

    Attributes
    ----------
    G_ : array-like of shape (n_samples, m_samples)
        The optimal transport plan.
    Xt_ : array-like of shape (m_samples, n_features)
        The target domain samples.
    yht_ : array-like of shape (m_samples,)
        The transported source domain labels.
    yht_continuous_ : array-like
        The transported labels before the argmax (per-class scores) for
        discrete labels, or the transported values for continuous labels.
    discrete_ : bool
        Whether the source labels were detected as discrete.
    classes_ : array-like
        The unique source labels (only set for discrete labels).

    References
    ----------
    [28] Solomon, J., Rustamov, R., Guibas, L., & Butscher, A. (2014, January).
         Wasserstein propagation for semi-supervised learning. In International
         Conference on Machine Learning (pp. 306-314). PMLR.
    """

    __metadata_request__fit = {"sample_weight": True}
    __metadata_request__fit_transform = {"sample_weight": True}

    def __init__(self, metric="sqeuclidean", reg=None, n_iter_max=200):
        super().__init__()
        self.metric = metric
        self.reg = reg
        self.n_iter_max = n_iter_max

    def fit_transform(self, X, y, sample_domain=None, *, sample_weight=None):
        """Fit adaptation parameters

        Returns
        -------
        X : array-like
            The input samples (after validation).
        yout : array-like
            Labels where target entries hold the propagated labels and the
            other entries hold -1 (discrete labels) or NaN (continuous labels).
        params : dict
            Contains ``sample_weight`` only if it was provided.
        """
        X, y, sample_domain = check_X_y_domain(X, y, sample_domain)
        if sample_weight is not None:
            Xs, Xt, ys, yt, ws, wt = source_target_split(
                X, y, sample_weight, sample_domain=sample_domain
            )
            ws = ws / ws.sum()
            wt = wt / wt.sum()
        else:
            Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain)
            ws = ot.unif(Xs.shape[0])
            wt = ot.unif(Xt.shape[0])

        M = ot.dist(Xs, Xt, metric=self.metric)
        G = ot.solve(M, ws, wt, reg=self.reg, max_iter=self.n_iter_max).plan

        self.discrete_ = discrete = _find_y_type(ys) == Y_Type.DISCRETE

        if discrete:
            self.classes_ = classes = np.unique(ys)
            # one hot encoding of the source labels, one column per class
            Y = (np.asarray(ys)[:, None] == classes[None, :]).astype(float)
            yht = G.T.dot(Y)
            self.yht_continuous_ = yht
            # most transported class for each target sample
            yht = classes[np.argmax(yht, axis=1)]
            yout = -np.ones_like(y)
        else:
            Y = ys
            # barycentric mapping of the source values
            yht = G.T.dot(Y) / wt
            self.yht_continuous_ = yht
            yout = np.ones_like(y) * np.nan

        self.G_ = G
        self.Xt_ = Xt
        self.yht_ = yht

        # set estimated labels
        yout[sample_domain < 0] = yht

        # return sample weight only if it was provided
        dico = {}
        if sample_weight is not None:
            dico["sample_weight"] = sample_weight

        return X, yout, dico


def OTLabelProp(base_estimator=None, reg=0, metric="sqeuclidean", n_iter_max=200):
    """Label propagation using optimal transport plan.

    This adapter uses the optimal transport plan to propagate labels from
    source to target domain. This was proposed originally in [28] for
    semi-supervised learning and can be used for domain adaptation.

    Parameters
    ----------
    base_estimator : object
        The base estimator trained on the propagated labels. If None, an
        ``SVC(kernel="rbf")`` requesting ``sample_weight`` at fit is used.
    reg : float, default=0
        The entropic regularization parameter for the optimal transport
        problem. If None, the exact OT is solved, else it is used to weight
        the entropy regularization of the coupling matrix.
    metric : str, default='sqeuclidean'
        The metric to use for the cost matrix. Can be 'sqeuclidean' for
        squared euclidean distance, 'euclidean' for euclidean distance,
    n_iter_max: int
        Maximum number of iterations for the OT solver.

    Returns
    -------
    pipeline : object
        The pipeline built by ``make_da_pipeline`` chaining an
        OTLabelPropAdapter and the base estimator.

    References
    ----------
    [28] Solomon, J., Rustamov, R., Guibas, L., & Butscher, A. (2014, January).
         Wasserstein propagation for semi-supervised learning. In International
         Conference on Machine Learning (pp. 306-314). PMLR.
    """
    if base_estimator is None:
        base_estimator = SVC(kernel="rbf").set_fit_request(sample_weight=True)

    return make_da_pipeline(
        OTLabelPropAdapter(reg=reg, metric=metric, n_iter_max=n_iter_max),
        base_estimator,
    )


class JCPOTLabelPropAdapter(BaseAdapter):
    """JCPOT Label Propagation Adapter for multi source target shift

    This adapter uses the optimal transport plan to propagate labels from
    sources to target domain with target shift (change in proportion of
    classes). This was proposed in [31].

    Parameters
    ----------
    metric : str, default='sqeuclidean'
        The metric to use for the cost matrix. Can be 'sqeuclidean' for
        squared euclidean distance, 'euclidean' for euclidean distance,
    reg : float, default=1
        The entropic regularization parameter for the optimal transport
        problem.
    max_iter : int, default=10
        Maximum number of iterations for the JCPOT solver.
    tol : float, default=1e-9
        Tolerance for loss variations (OT and mse) stopping iterations.
    verbose : bool, default=False
        Print loss along iterations if True.

    Attributes
    ----------
    ot_adapter_ : ot.da.JCPOTTransport
        The fitted JCPOT transport object.
    yh_continuous_ : array-like
        The propagated per-class scores of the target samples (before argmax).
    classes_ : array-like
        The unique labels found across the source domains.

    References
    ----------
    [31] Redko, Ievgen, Nicolas Courty, Rémi Flamary, and Devis Tuia.
         "Optimal transport for multi-source domain adaptation under target
         shift." In The 22nd International Conference on artificial
         intelligence and statistics, pp. 849-858. PMLR, 2019.
    """

    def __init__(
        self, metric="sqeuclidean", reg=1, max_iter=10, tol=1e-9, verbose=False
    ):
        super().__init__()
        self.metric = metric
        self.reg = reg
        self.max_iter = max_iter
        self.tol = tol
        self.verbose = verbose

    def fit_transform(self, X, y, sample_domain=None, *, sample_weight=None):
        """Fit JCPOT on the source domains and propagate labels to the target.

        Returns
        -------
        X : array-like
            The input samples (after validation).
        yout : array-like
            Labels where target entries hold the propagated labels and the
            other entries hold -1.
        params : dict
            Always empty.
        """
        X, y, sample_domain = check_X_y_domain(X, y, sample_domain)

        sources, targets = per_domain_split(X, y, sample_domain=sample_domain)

        Xs = [X for X, y in sources.values()]
        ys = [y for X, y in sources.values()]

        if len(ys) == 1:
            # a single source domain is passed twice to the solver
            # NOTE(review): presumably JCPOT needs at least two sources - confirm
            Xs = Xs * 2
            ys = ys * 2

        Xt = [X for X, y in targets.values()]
        Xt = np.concatenate(Xt, axis=0)

        self.ot_adapter_ = ot.da.JCPOTTransport(
            reg_e=self.reg,
            metric=self.metric,
            max_iter=self.max_iter,
            tol=self.tol,
            verbose=self.verbose,
            log=True,
        )

        self.ot_adapter_.fit(Xs=Xs, ys=ys, Xt=Xt)
        yh = self.ot_adapter_.transform_labels(ys)

        self.yh_continuous_ = yh

        # map the winning column back to the actual label value, as done in
        # OTLabelPropAdapter (identical to the raw index for labels 0..K-1)
        # NOTE(review): assumes columns follow the sorted unique labels - confirm
        self.classes_ = classes = np.unique(np.concatenate(ys))
        yh = classes[np.argmax(yh, axis=1)]

        yout = -np.ones_like(y)
        yout[sample_domain < 0] = yh

        return X, yout, {}


def JCPOTLabelProp(
    base_estimator=None,
    reg=1,
    metric="sqeuclidean",
    max_iter=10,
    tol=1e-9,
    verbose=False,
):
    """JCPOT Label Propagation Adapter for multi source target shift

    This adapter uses the optimal transport plan to propagate labels from
    sources to target domain with target shift (change in proportion of
    classes). This was proposed in [31].

    Parameters
    ----------
    base_estimator : object, default=LogisticRegression()
        The base estimator trained on the propagated labels.
    reg : float, default=1
        The entropic regularization parameter for the optimal transport
        problem.
    metric : str, default='sqeuclidean'
        The metric to use for the cost matrix. Can be 'sqeuclidean' for
        squared euclidean distance, 'euclidean' for euclidean distance,
    max_iter : int, default=10
        Maximum number of iterations for the JCPOT solver.
    tol : float, default=1e-9
        Tolerance for loss variations (OT and mse) stopping iterations.
    verbose : bool, default=False
        Print loss along iterations if True.

    Returns
    -------
    pipeline : object
        The pipeline built by ``make_da_pipeline`` chaining a
        JCPOTLabelPropAdapter and the base estimator.

    References
    ----------
    [31] Redko, Ievgen, Nicolas Courty, Rémi Flamary, and Devis Tuia.
         "Optimal transport for multi-source domain adaptation under target
         shift." In The 22nd International Conference on artificial
         intelligence and statistics, pp. 849-858. PMLR, 2019.
    """
    if base_estimator is None:
        base_estimator = LogisticRegression()

    return make_da_pipeline(
        JCPOTLabelPropAdapter(
            reg=reg, metric=metric, max_iter=max_iter, tol=tol, verbose=verbose
        ),
        base_estimator,
    )