# Author: Theo Gnassounou <theo.gnassounou@inria.fr>
#         Remi Flamary <remi.flamary@polytechnique.edu>
#         Oleksii Kachaiev <kachayev@gmail.com>
#         Antoine Collas <contact@antoinecollas.fr>
#
# License: BSD 3-Clause
from abc import abstractmethod
import numpy as np
from ot import da
from ot.gaussian import bures_wasserstein_barycenter, bures_wasserstein_mapping
from sklearn.linear_model import LogisticRegression
from sklearn.metrics.pairwise import pairwise_distances
from sklearn.svm import SVC
from ._pipeline import make_da_pipeline
from ._utils import Y_Type, _estimate_covariance, _find_y_type
from .base import BaseAdapter, clone
from .utils import (
    check_X_domain,
    extract_domains_indices,
    extract_source_indices,
    per_domain_split,
    source_target_merge,
    source_target_split,
    torch_minimize,
)
class BaseOTMappingAdapter(BaseAdapter):
    """Base class for all DA estimators implemented using OT mapping.
    Each implementation has to provide `_create_transport_estimator` callback
    to create OT object using parameters saved in the constructor.
    """
    def fit(self, X, y=None, *, sample_domain=None):
        """Fit adaptation parameters.
        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            The source data.
        y : array-like, shape (n_samples,)
            The source labels.
        sample_domain : array-like, shape (n_samples,)
            The domain labels (same as sample_domain).
        Returns
        -------
        self : object
            Returns self.
        """
        X, sample_domain = check_X_domain(X, sample_domain)
        X, X_target, y, y_target = source_target_split(
            X, y, sample_domain=sample_domain
        )
        transport = self._create_transport_estimator()
        self.ot_transport_ = clone(transport)
        self.ot_transport_.fit(Xs=X, ys=y, Xt=X_target, yt=y_target)
        return self
    def fit_transform(self, X, y=None, *, sample_domain=None, **params):
        """Predict adaptation (weights, sample or labels).
        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            The source data.
        y : array-like, shape (n_samples,)
            The source labels.
        sample_domain : array-like, shape (n_samples,)
            The domain labels.
        Returns
        -------
        X_t : array-like, shape (n_samples, n_components)
            The data transformed to the target subspace.
        """
        self.fit(X, y, sample_domain=sample_domain)
        return self.transform(X, sample_domain=sample_domain, allow_source=True)
    def transform(
        self, X, y=None, *, sample_domain=None, allow_source=False, **params
    ) -> np.ndarray:
        # xxx(okachaiev): implement auto-infer for sample_domain
        X, sample_domain = check_X_domain(
            X,
            sample_domain,
            allow_source=allow_source,
            allow_multi_source=True,
            allow_multi_target=True,
        )
        X_source, X_target = source_target_split(X, sample_domain=sample_domain)
        # in case of prediction we would get only target samples here,
        # thus there's no need to perform any transformations
        if X_source.shape[0] > 0:
            X_source = self.ot_transport_.transform(Xs=X_source)
        X_adapt, _ = source_target_merge(
            X_source, X_target, sample_domain=sample_domain
        )
        return X_adapt
    @abstractmethod
    def _create_transport_estimator(self):
        pass
[docs]
class OTMappingAdapter(BaseOTMappingAdapter):
    """Domain Adaptation Using Optimal Transport.
    See [6]_ for details.
    Parameters
    ----------
    metric : str, optional (default="sqeuclidean")
        The ground metric for the Wasserstein problem
    norm : {'median', 'max', 'log', 'loglog'} (default=None)
        If given, normalize the ground metric to avoid numerical errors that
        can occur with large metric values.
    max_iter : int, optional (default=100_000)
        The maximum number of iterations before stopping OT algorithm if it
        has not converged.
    Attributes
    ----------
    ot_transport_ : object
        The OT object based on Earth Mover's distance
        fitted on the source and target data.
    References
    ----------
    .. [6] N. Courty, R. Flamary, D. Tuia and A. Rakotomamonjy,
           Optimal Transport for Domain Adaptation, in IEEE
           Transactions on Pattern Analysis and Machine Intelligence
    """
    def __init__(
        self,
        metric="sqeuclidean",
        norm=None,
        max_iter=100_000,
    ):
        super().__init__()
        self.metric = metric
        self.norm = norm
        self.max_iter = max_iter
    def _create_transport_estimator(self):
        return da.EMDTransport(
            metric=self.metric,
            norm=self.norm,
            max_iter=self.max_iter,
        ) 
[docs]
def OTMapping(base_estimator=None, metric="sqeuclidean", norm=None, max_iter=100000):
    """OTmapping pipeline with adapter and estimator.
    See [6]_ for details.
    Parameters
    ----------
    base_estimator : object, optional (default=None)
        The base estimator to fit on the target dataset.
    metric : str, optional (default="sqeuclidean")
        The ground metric for the Wasserstein problem
    norm : {'median', 'max', 'log', 'loglog'} (default=None)
        If given, normalize the ground metric to avoid numerical errors that
        can occur with large metric values.
    max_iter : int, optional (default=100_000)
        The maximum number of iterations before stopping OT algorithm if it
        has not converged.
    Returns
    -------
    pipeline : Pipeline
        Pipeline containing OTMapping adapter and base estimator.
    References
    ----------
    .. [6] N. Courty, R. Flamary, D. Tuia and A. Rakotomamonjy,
           Optimal Transport for Domain Adaptation, in IEEE
           Transactions on Pattern Analysis and Machine Intelligence
    """
    if base_estimator is None:
        base_estimator = SVC(kernel="rbf")
    return make_da_pipeline(
        OTMappingAdapter(metric=metric, norm=norm, max_iter=max_iter),
        base_estimator,
    ) 
[docs]
class EntropicOTMappingAdapter(BaseOTMappingAdapter):
    """Domain Adaptation Using Optimal Transport.
    See [6]_ for details.
    Parameters
    ----------
    reg_e : float, default=1
        Entropic regularization parameter.
    metric : str, optional (default="sqeuclidean")
        The ground metric for the Wasserstein problem.
    norm : {'median', 'max', 'log', 'loglog'} (default=None)
        If given, normalize the ground metric to avoid numerical errors that
        can occur with large metric values.
    max_iter : int, float, optional (default=1000)
        The minimum number of iteration before stopping the optimization
        of the Sinkhorn algorithm if it has not converged
    tol : float, optional (default=10e-9)
        The precision required to stop the optimization of the Sinkhorn
        algorithm.
    Attributes
    ----------
    ot_transport_ : object
        The OT object based on Sinkhorn Algorithm
        fitted on the source and target data.
    References
    ----------
    .. [6] N. Courty, R. Flamary, D. Tuia and A. Rakotomamonjy,
           Optimal Transport for Domain Adaptation, in IEEE
           Transactions on Pattern Analysis and Machine Intelligence
    """
    def __init__(
        self,
        reg_e=1.0,
        metric="sqeuclidean",
        norm=None,
        max_iter=1000,
        tol=10e-9,
    ):
        super().__init__()
        self.reg_e = reg_e
        self.metric = metric
        self.norm = norm
        self.max_iter = max_iter
        self.tol = tol
    def _create_transport_estimator(self):
        return da.SinkhornTransport(
            reg_e=self.reg_e,
            metric=self.metric,
            norm=self.norm,
            max_iter=self.max_iter,
            tol=self.tol,
        ) 
[docs]
def EntropicOTMapping(
    base_estimator=None,
    metric="sqeuclidean",
    norm=None,
    max_iter=1000,
    reg_e=1.0,
    tol=1e-8,
):
    """EntropicOTMapping pipeline with adapter and estimator.
    see [6]_ for details.
    Parameters
    ----------
    base_estimator : object, optional (default=None)
        The base estimator to fit on the target dataset.
    reg_e : float, default=1
        Entropic regularization parameter.
    metric : str, optional (default="sqeuclidean")
        The ground metric for the Wasserstein problem.
    norm : {'median', 'max', 'log', 'loglog'} (default=None)
        If given, normalize the ground metric to avoid numerical errors that
        can occur with large metric values.
    max_iter : int, float, optional (default=1000)
        The minimum number of iteration before stopping the optimization
        of the Sinkhorn algorithm if it has not converged
    tol : float, optional (default=10e-9)
        The precision required to stop the optimization of the Sinkhorn
        algorithm.
    Returns
    -------
    pipeline : Pipeline
        Pipeline containing EntropicOTMapping adapter and base estimator.
    References
    ----------
    .. [6] N. Courty, R. Flamary, D. Tuia and A. Rakotomamonjy,
           Optimal Transport for Domain Adaptation, in IEEE
           Transactions on Pattern Analysis and Machine Intelligence
    """
    if base_estimator is None:
        base_estimator = SVC(kernel="rbf")
    return make_da_pipeline(
        EntropicOTMappingAdapter(
            metric=metric, norm=norm, max_iter=max_iter, reg_e=reg_e, tol=tol
        ),
        base_estimator,
    ) 
[docs]
class ClassRegularizerOTMappingAdapter(BaseOTMappingAdapter):
    """Domain Adaptation Using Optimal Transport.
    See [6]_ for details.
    Parameters
    ----------
    reg_e : float, default=1
        Entropic regularization parameter.
    reg_cl : float, default=0.1
        Class regularization parameter.
    norm : str, default="lpl1"
        Norm use for the regularizer of the class labels.
        If "lpl1", use the lp l1 norm.
        If "l1l2", use the l1 l2 norm.
    metric : str, optional (default="sqeuclidean")
        The ground metric for the Wasserstein problem
    max_iter : int, float, optional (default=10)
        The minimum number of iteration before stopping the optimization
        algorithm if it has not converged
    max_inner_iter : int, float, optional (default=200)
        The number of iteration in the inner loop
    tol : float, optional (default=10e-9)
        Stop threshold on error (inner sinkhorn solver) (>0)
    Attributes
    ----------
    ot_transport_ : object
        The OT object based on Sinkhorn Algorithm
        + class regularization fitted on the source
        and target data.
    References
    ----------
    .. [6] N. Courty, R. Flamary, D. Tuia and A. Rakotomamonjy,
           Optimal Transport for Domain Adaptation, in IEEE
           Transactions on Pattern Analysis and Machine Intelligence
    """
    def __init__(
        self,
        reg_e=1.0,
        reg_cl=0.1,
        norm="lpl1",
        metric="sqeuclidean",
        max_iter=10,
        max_inner_iter=200,
        tol=10e-9,
    ):
        super().__init__()
        self.reg_e = reg_e
        self.reg_cl = reg_cl
        self.norm = norm
        self.metric = metric
        self.max_iter = max_iter
        self.max_inner_iter = max_inner_iter
        self.tol = tol
    def _create_transport_estimator(self):
        assert self.norm in ["lpl1", "l1l2"], "Unknown norm"
        if self.norm == "lpl1":
            transport_cls = da.SinkhornLpl1Transport
        elif self.norm == "l1l2":
            transport_cls = da.SinkhornL1l2Transport
        return transport_cls(
            reg_e=self.reg_e,
            reg_cl=self.reg_cl,
            metric=self.metric,
            max_iter=self.max_iter,
            max_inner_iter=self.max_inner_iter,
            tol=self.tol,
        ) 
[docs]
def ClassRegularizerOTMapping(
    base_estimator=SVC(kernel="rbf"),
    metric="sqeuclidean",
    norm="lpl1",
    max_iter=10,
    max_inner_iter=200,
    reg_e=1.0,
    reg_cl=0.1,
    tol=1e-8,
):
    """ClassRegularizedOTMapping pipeline with adapter and estimator.
    see [6]_ for details.
    Parameters
    ----------
    base_estimator : object, optional (default=SVC(kernel="rbf"))
        The base estimator to fit on the target dataset.
    reg_e : float, default=1
        Entropic regularization parameter.
    reg_cl : float, default=0.1
        Class regularization parameter.
    norm : str, default="lpl1"
        Norm use for the regularizer of the class labels.
        If "lpl1", use the lp l1 norm.
        If "l1l2", use the l1 l2 norm.
    metric : str, optional (default="sqeuclidean")
        The ground metric for the Wasserstein problem
    max_iter : int, float, optional (default=10)
        The minimum number of iteration before stopping the optimization
        algorithm if it has not converged
    max_inner_iter : int, float, optional (default=200)
        The number of iteration in the inner loop
    tol : float, optional (default=10e-9)
        Stop threshold on error (inner sinkhorn solver) (>0)
    Returns
    -------
    pipeline : Pipeline
        Pipeline containing ClassRegularizerOTMapping adapter and base estimator.
    References
    ----------
    .. [6] N. Courty, R. Flamary, D. Tuia and A. Rakotomamonjy,
           Optimal Transport for Domain Adaptation, in IEEE
           Transactions on Pattern Analysis and Machine Intelligence
    """
    ot_mapping = make_da_pipeline(
        ClassRegularizerOTMappingAdapter(
            metric=metric,
            norm=norm,
            max_iter=max_iter,
            max_inner_iter=max_inner_iter,
            reg_e=reg_e,
            reg_cl=reg_cl,
            tol=tol,
        ),
        base_estimator,
    )
    return ot_mapping 
[docs]
class LinearOTMappingAdapter(BaseOTMappingAdapter):
    """Domain Adaptation Using Optimal Transport.
    Uses Gaussian Monge mapping to align source and target domains as proposed
    in [7].
    Parameters
    ----------
    reg : float, (default=1e-08)
        regularization added to the diagonals of covariances.
    bias: bool, optional (default=True)
        estimate bias.
    Attributes
    ----------
    ot_transport_ : object
        The OT object based on linear operator between empirical
        distributions fitted on the source
        and target data.
    References
    ----------
    .. [7] Flamary, R., Lounici, K., & Ferrari, A. (2019). Concentration bounds
        for linear monge mapping estimation and optimal transport domain
        adaptation. arXiv preprint arXiv:1905.10155.
    """
    def __init__(self, reg=1e-08, bias=True):
        super().__init__()
        self.reg = reg
        self.bias = bias
    def _create_transport_estimator(self):
        return da.LinearTransport(reg=self.reg, bias=self.bias) 
[docs]
def LinearOTMapping(
    base_estimator=None,
    reg=1.0,
    bias=True,
):
    """Returns a the linear OT mapping method with adapter and estimator.
    Uses Gaussian Monge mapping to align source and target domains as proposed
    in [7].
    Parameters
    ----------
    base_estimator : object, optional (default=None)
        The base estimator to fit on the target dataset.
    reg : float, (default=1e-08)
        regularization added to the diagonals of covariances.
    bias: bool, optional (default=True)
        estimate bias.
    Returns
    -------
    pipeline : Pipeline
        Pipeline containing linear OT mapping adapter and base estimator.
    References
    ----------
    .. [6] N. Courty, R. Flamary, D. Tuia and A. Rakotomamonjy,
           Optimal Transport for Domain Adaptation, in IEEE
           Transactions on Pattern Analysis and Machine Intelligence
    """
    if base_estimator is None:
        base_estimator = SVC(kernel="rbf")
    return make_da_pipeline(
        LinearOTMappingAdapter(
            reg=reg,
            bias=bias,
        ),
        base_estimator,
    ) 
def _get_cov_mean(X, w=None, bias=True):
    """Returns covariance and mean
    Parameters
    ----------
    X : array-like, shape (n_samples, n_features)
        The source data.
    w : array-like, shape (n_samples,)
        The weights of the samples.
    bias: bool, optional (default=True)
        estimate bias (mean).
    Returns
    -------
    cov : array-like, shape (n_features, n_features)
        The covariance matrix.
    mean : array-like, shape (n_features,)
        The mean vector.
    """
    if w is None:
        w = np.ones(X.shape[0])
    if bias:
        mean = np.average(X, axis=0, weights=w)
    else:
        mean = np.zeros(X.shape[1])
    X = X - mean
    cov = np.dot(w * X.T, X) / np.sum(w)
    return cov, mean
[docs]
class MultiLinearMongeAlignmentAdapter(BaseAdapter):
    """Aligns multiple domains using Gaussian Monge mapping to a barycenter.
    The method is a simplified extension of [29] using the Bures-Wasserstein
    distance and mapping of [7] to align multiple source domains to a
    barycenter. The sued of barycenter alignment with gaussien assumption was
    proposed in [30].
    Parameters
    ----------
    reg : float, optional (default=1e-08)
        Regularization parameter added to the diagonal of the covariance.
    bias : bool, optional (default=True)
        Estimate bias.
    test_time : bool, optional (default=False)
        If True, the estimator can be updated at test time to map new
        target domains unseen during training
    Attributes
    ----------
    cov_means_sources_ : dict
        Dictionary of covariance and mean for each source domain.
    cov_means_targets_ : dict
        Dictionary of covariance and mean for each target domain.
    barycenter_ : tuple
        Barycenter of the source domains (mean, cov).
    _mappings_ : dict
        Dictionary of mappings for each domain.
    References
    ----------
    .. [29] Montesuma, Eduardo Fernandes, and Fred Maurice Ngole Mboula.
        "Wasserstein barycenter for multi-source domain adaptation." In Proceedings
        of the IEEE/CVF conference on computer vision and pattern recognition, pp.
        16785-16793. 2021.
    .. [7] Flamary, R., Lounici, K., & Ferrari, A. (2019). Concentration bounds
        for linear monge mapping estimation and optimal transport domain
        adaptation. arXiv preprint arXiv:1905.10155.
    .. [30] Gnassounou, Theo, Rémi Flamary, and Alexandre Gramfort. "Convolution
        Monge Mapping Normalization for learning on sleep data." Advances in
        Neural Information Processing Systems 36 (2024).
    """
    def __init__(self, reg=1e-08, bias=True, test_time=False):
        super().__init__()
        self.reg = reg
        self.bias = bias
        self.test_time = test_time
    def fit(self, X, y=None, *, sample_domain=None):
        """Fit adaptation parameters.
        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            The source data.
        y : array-like, shape (n_samples,)
            The source labels.
        sample_domain : array-like, shape (n_samples,)
            The domain labels (same as sample_domain).
        Returns
        -------
        self : object
            Returns self.
        """
        X, sample_domain = check_X_domain(X, sample_domain)
        sources, targets = per_domain_split(X, y, None, sample_domain=sample_domain)
        self.cov_means_sources_ = {
            domain: _get_cov_mean(X, w, bias=self.bias)
            for domain, (X, y, w) in sources.items()
        }
        self.cov_means_targets_ = {
            domain: _get_cov_mean(X, w, bias=self.bias)
            for domain, (X, y, w) in targets.items()
        }
        C = np.stack([cov for cov, mean in self.cov_means_sources_.values()])
        m = np.stack([mean for cov, mean in self.cov_means_sources_.values()])
        self.barycenter_ = bures_wasserstein_barycenter(
            m,
            C,
            eps=self.reg,
        )
        self.mappings_ = {
            domain: bures_wasserstein_mapping(
                mean,
                self.barycenter_[0],
                cov,
                self.barycenter_[1],
            )
            for domain, (cov, mean) in self.cov_means_sources_.items()
        }
        mapping_target = {
            domain: bures_wasserstein_mapping(
                mean,
                self.barycenter_[0],
                cov,
                self.barycenter_[1],
            )
            for domain, (cov, mean) in self.cov_means_targets_.items()
        }
        self.mappings_.update(mapping_target)
        return self
    def fit_transform(self, X, y=None, sample_domain=None, **params):
        """Predict adaptation (weights, sample or labels).
        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            The source data.
        y : array-like, shape (n_samples,)
            The source labels.
        sample_domain : array-like, shape (n_samples,)
            The domain labels (same as sample_domain).
        Returns
        -------
        X_t : array-like, shape (n_samples, n_components)
            The data (same as X).
        """
        self.fit(X, y, sample_domain=sample_domain)
        return self.transform(X, sample_domain=sample_domain, allow_source=True)
    def transform(
        self, X, y=None, *, sample_domain=None, allow_source=False, **params
    ) -> np.ndarray:
        X, sample_domain = check_X_domain(
            X, sample_domain, allow_multi_source=True, allow_multi_target=True
        )
        idx = extract_domains_indices(sample_domain)
        X_adapt = X.copy()
        for domain, sel in idx.items():
            A, b = self.mappings_[domain]
            X_adapt[sel] = X[sel].dot(A) + b
        return X_adapt 
[docs]
def MultiLinearMongeAlignment(
    base_estimator=None, reg=1e-08, bias=True, test_time=False
):
    """MultiLinearMongeAlignment pipeline with adapter and estimator.
    The method is a simplified extension of [29] using the Bures-Wasserstein
    distance and mapping of [7] to align multiple source domains to a
    barycenter. The sued of barycenter alignment with gaussien assumption was
    proposed in [30].
    Parameters
    ----------
    base_estimator : object, optional (default=None)
        The base estimator to fit on the target dataset.
    reg : float, optional (default=1e-08)
        Regularization parameter added to the diagonal of the covariance.
    bias : bool, optional (default=True)
        Estimate bias.
    test_time : bool, optional (default=False)
        If True, the estimator can be updated at test time to map new
        target domains unseen during training
    Returns
    -------
    pipeline : Pipeline
        Pipeline containing MultiLinearMongeAlignment adapter and base
        estimator.
    References
    ----------
    .. [29] Montesuma, Eduardo Fernandes, and Fred Maurice Ngole Mboula.
        "Wasserstein barycenter for multi-source domain adaptation." In Proceedings
        of the IEEE/CVF conference on computer vision and pattern recognition, pp.
        16785-16793. 2021.
    .. [7] Flamary, R., Lounici, K., & Ferrari, A. (2019). Concentration bounds
        for linear monge mapping estimation and optimal transport domain
        adaptation. arXiv preprint arXiv:1905.10155.
    .. [30] Gnassounou, Theo, Rémi Flamary, and Alexandre Gramfort. "Convolution
        Monge Mapping Normalization for learning on sleep data." Advances in
        Neural Information Processing Systems 36 (2024).
    """
    if base_estimator is None:
        base_estimator = LogisticRegression()
    return make_da_pipeline(
        MultiLinearMongeAlignmentAdapter(reg=reg, bias=bias, test_time=test_time),
        base_estimator,
    ) 
def _sqrtm(C):
    r"""Square root of SPD matrices.
    The matrix square root of a SPD matrix C is defined by:
    .. math::
        \mathbf{D} =
        \mathbf{V} \left( \mathbf{\Lambda} \right)^{1/2} \mathbf{V}^\top
    where :math:`\mathbf{\Lambda}` is the diagonal matrix of eigenvalues
    and :math:`\mathbf{V}` the eigenvectors of :math:`\mathbf{C}`.
    Parameters
    ----------
    C : ndarray, shape (n, n)
        SPD matrix.
    Returns
    -------
    D : ndarray, shape (n, n)
        Matrix inverse square root of C.
    """
    eigvals, eigvecs = np.linalg.eigh(C)
    return (eigvecs * np.sqrt(eigvals)) @ eigvecs.T
def _invsqrtm(C):
    r"""Inverse square root of SPD matrices.
    The matrix inverse square root of a SPD matrix C is defined by:
    .. math::
        \mathbf{D} =
        \mathbf{V} \left( \mathbf{\Lambda} \right)^{-1/2} \mathbf{V}^\top
    where :math:`\mathbf{\Lambda}` is the diagonal matrix of eigenvalues
    and :math:`\mathbf{V}` the eigenvectors of :math:`\mathbf{C}`.
    Parameters
    ----------
    C : ndarray, shape (n, n)
        SPD matrix.
    Returns
    -------
    D : ndarray, shape (n, n)
        Matrix inverse square root of C.
    """
    eigvals, eigvecs = np.linalg.eigh(C)
    return (eigvecs * 1.0 / np.sqrt(eigvals)) @ eigvecs.T
[docs]
class CORALAdapter(BaseAdapter):
    """Estimator based on Correlation Alignment [1]_.
    See [5]_ for details.
    Parameters
    ----------
    reg : 'auto' or float, default="auto"
        The regularization parameter of the covariance estimator.
        Possible values:
          - None: no shrinkage).
          - 'auto': automatic shrinkage using the Ledoit-Wolf lemma.
          - float between 0 and 1: fixed shrinkage parameter.
    assume_centered: bool, default=False
        If True, data are not centered before computation.
    Attributes
    ----------
    mean_source_: array, shape (n_features,)
    mean_target_: array, shape (n_features,)
    cov_source_inv_sqrt_: array, shape (n_features, n_features)
        Inverse of the square root of covariance of the source data with regularization.
    cov_target_sqrt_: array, shape (n_features, n_features)
        Square root of covariance of the target data with regularization.
    References
    ----------
    .. [5] Baochen Sun, Jiashi Feng, and Kate Saenko.
           Correlation Alignment for Unsupervised Domain Adaptation.
           In Advances in Computer Vision and Pattern Recognition, 2017.
    """
    def __init__(self, reg="auto", assume_centered=False):
        super().__init__()
        self.reg = reg
        self.assume_centered = assume_centered
    def fit(self, X, y=None, sample_domain=None):
        """Fit adaptation parameters.
        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            The source data.
        y : array-like, shape (n_samples,)
            The source labels.
        sample_domain : array-like, shape (n_samples,)
            The domain labels (same as sample_domain).
        Returns
        -------
        self : object
            Returns self.
        """
        X, sample_domain = check_X_domain(
            X, sample_domain, allow_multi_source=True, allow_multi_target=True
        )
        X_source, X_target = source_target_split(X, sample_domain=sample_domain)
        self.mean_source_ = np.mean(X_source, axis=0)
        self.mean_target_ = np.mean(X_target, axis=0)
        cov_source_ = _estimate_covariance(
            X_source, shrinkage=self.reg, assume_centered=self.assume_centered
        )
        cov_target_ = _estimate_covariance(
            X_target, shrinkage=self.reg, assume_centered=self.assume_centered
        )
        self.cov_source_inv_sqrt_ = _invsqrtm(cov_source_)
        self.cov_target_sqrt_ = _sqrtm(cov_target_)
        return self
    def fit_transform(self, X, y=None, *, sample_domain=None, **params):
        """Predict adaptation (weights, sample or labels).
        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            The source data.
        y : array-like, shape (n_samples,)
            The source labels.
        sample_domain : array-like, shape (n_samples,)
            The domain labels (same as sample_domain).
        Returns
        -------
        X_t : array-like, shape (n_samples, n_features)
            The data transformed to the target space.
        """
        self.fit(X, y, sample_domain=sample_domain)
        return self.transform(X, sample_domain=sample_domain, allow_source=True)
    def transform(
        self, X, y=None, *, sample_domain=None, allow_source=False, **params
    ) -> np.ndarray:
        X, sample_domain = check_X_domain(
            X,
            sample_domain,
            allow_source=allow_source,
            allow_multi_source=True,
            allow_multi_target=True,
        )
        X_source_adapt, X_target_adapt = source_target_split(
            X, sample_domain=sample_domain
        )
        # Adapt the source data
        if X_source_adapt.shape[0] > 0:
            # Center data
            if not self.assume_centered:
                X_source_adapt = X_source_adapt - self.mean_source_
            # Whitening and coloring source data
            X_source_adapt = np.dot(X_source_adapt, self.cov_source_inv_sqrt_)
            X_source_adapt = np.dot(X_source_adapt, self.cov_target_sqrt_)
        # Adapt the target data
        if X_target_adapt.shape[0] > 0 and not self.assume_centered:
            X_target_adapt = X_target_adapt - self.mean_target_
        X_adapt, _ = source_target_merge(
            X_source_adapt, X_target_adapt, sample_domain=sample_domain
        )
        return X_adapt 
[docs]
def CORAL(
    base_estimator=None,
    reg="auto",
    assume_centered=False,
):
    """CORAL pipeline with adapter and estimator.
    See [5]_ for details.
    Parameters
    ----------
    base_estimator : object, optional (default=None)
        The base estimator to fit on the target dataset.
    reg : 'auto' or float, default="auto"
        The regularization parameter of the covariance estimator.
        Possible values:
          - None: no shrinkage).
          - 'auto': automatic shrinkage using the Ledoit-Wolf lemma.
          - float between 0 and 1: fixed shrinkage parameter.
    assume_centered: bool, default=False
        If True, data are not centered before computation.
    Returns
    -------
    pipeline : Pipeline
        Pipeline containing CORAL adapter and base estimator.
    References
    ----------
    .. [5] Baochen Sun, Jiashi Feng, and Kate Saenko.
           Correlation Alignment for Unsupervised Domain Adaptation.
           In Advances in Computer Vision and Pattern Recognition, 2017.
    """
    if base_estimator is None:
        base_estimator = SVC(kernel="rbf")
    return make_da_pipeline(
        CORALAdapter(reg=reg, assume_centered=assume_centered),
        base_estimator,
    ) 
# xxx(okachaiev): we should move this to 'skada.deep.*' I guess
# to avoid defining things that won't work anyways
[docs]
class MMDLSConSMappingAdapter(BaseAdapter):
    r"""Location-Scale mapping minimizing the MMD with a Gaussian kernel.
    MMDLSConSMapping finds a linear transformation that minimizes the Maximum Mean
    Discrepancy (MMD) between the source and target domains, such that
    $X^t = W(y^s) \\odot X^s + B(y^s)$, where $W(y^s)$ and $B(y^s)$ are the scaling
    and bias of the linear transformation, respectively.
    See Section 4 of [21]_ for details.
    Parameters
    ----------
    gamma : float
        Parameter for the Gaussian kernel.
    reg_k : float, default=1e-10
        Regularization parameter for the labels kernel matrix.
    reg_m : float, default=1e-10
        Regularization parameter for the mapping parameters.
    tol : float, default=1e-5
        Tolerance for the stopping criterion in the optimization.
    max_iter : int, default=100
        Number of maximum iteration before stopping the optimization.
    Attributes
    ----------
    `W_` : array-like, shape (n_samples, n_features)
        The scaling matrix.
    `B_` : array-like, shape (n_samples, n_features)
        The bias matrix.
    `G_` : array-like, shape (n_classes, n_features) or (n_samples, n_features)
        The learned kernel scaling matrix.
    `H_` : array-like, shape (n_classes, n_features) or (n_samples, n_features)
        The learned kernel bias matrix.
    `X_source_` : array-like, shape (n_samples, n_features)
        The source data.
    References
    ----------
    .. [21] Kun Zhang et. al. Domain Adaptation under Target and Conditional Shift
           In ICML, 2013.
    """
    def __init__(self, gamma, reg_k=1e-10, reg_m=1e-10, tol=1e-5, max_iter=100):
        super().__init__()
        self.gamma = gamma
        self.reg_k = reg_k
        self.reg_m = reg_m
        self.tol = tol
        self.max_iter = max_iter
        self.W_ = None
        self.B_ = None
    def _mapping_optimization(self, X_source, X_target, y_source):
        """Mapping optimization"""
        try:
            import torch
        except ImportError:
            raise ImportError(
                "MMDLSConSMappingAdapter requires pytorch to be installed."
            )
        # check y is discrete or continuous
        self.discrete_ = discrete = _find_y_type(y_source) == Y_Type.DISCRETE
        # convert to pytorch tensors
        X_source = torch.tensor(X_source, dtype=torch.float64)
        X_target = torch.tensor(X_target, dtype=torch.float64)
        y_source = torch.tensor(
            y_source, dtype=torch.int64 if discrete else torch.float64
        )
        # get shapes
        m, n = X_source.shape[0], X_target.shape[0]
        d = X_source.shape[1]
        # compute omega
        L = torch.exp(-self.gamma * torch.cdist(X_source, X_source, p=2))
        omega = L @ torch.linalg.inv(L + self.reg_k * torch.eye(m))
        # compute R
        if discrete:
            self.classes_ = classes = torch.unique(y_source).numpy()
            R = torch.zeros((m, len(classes)), dtype=torch.float64)
            for i, c in enumerate(classes):
                R[:, i] = (y_source == c).int()
        else:
            self.classes_ = None
            R = L @ torch.linalg.inv(L + self.reg_k * torch.eye(m))
        # solve the optimization problem
        # min_{G, H} MMD(W \odot X^s + B, X^t)
        # s.t. W = RG, B = RH
        k = R.shape[1]
        def func(G, H):
            W = R @ G
            B = R @ H
            X_new = W * X_source + B
            K = torch.exp(-self.gamma * torch.cdist(X_new, X_new, p=2))
            K_cross = torch.exp(-self.gamma * torch.cdist(X_target, X_new, p=2))
            J_cons = (1 / (m**2)) * torch.sum(omega @ K @ omega.T)
            J_cons -= (2 / (m * n)) * torch.sum(K_cross @ omega.T)
            J_reg = (1 / m) * (torch.sum((W - 1) ** 2) + torch.sum(B**2))
            return J_cons + self.reg_m * J_reg
        # optimize using torch solver
        G = torch.ones((k, d), dtype=torch.float64, requires_grad=True)
        H = torch.zeros((k, d), dtype=torch.float64, requires_grad=True)
        (G, H), _ = torch_minimize(func, (G, H), tol=self.tol, max_iter=self.max_iter)
        R = R.detach().numpy()
        W = R @ G
        B = R @ H
        return W, B, G, H
    def fit(self, X, y, sample_domain=None):
        """Fit adaptation parameters.
        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            The source data.
        y : array-like, shape (n_samples,)
            The source labels.
        sample_domain : array-like, shape (n_samples,)
            The domain labels (same as sample_domain).
        Returns
        -------
        self : object
            Returns self.
        """
        # xxx(okachaiev): we can't test X_y here because y might
        # have NaNs, thought it might be better to keep this as an
        # argument of a checker
        X, sample_domain = check_X_domain(X, sample_domain)
        X_source, X_target, y_source, _ = source_target_split(
            X, y, sample_domain=sample_domain
        )
        self.X_source_ = X_source
        self.W_, self.B_, self.G_, self.H_ = self._mapping_optimization(
            X_source, X_target, y_source
        )
        return self
    def fit_transform(self, X, y=None, sample_domain=None, **params):
        """Predict adaptation (weights, sample or labels).
        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            The source data.
        y : array-like, shape (n_samples,)
            The source labels.
        sample_domain : array-like, shape (n_samples,)
            The domain labels (same as sample_domain).
        Returns
        -------
        X_t : array-like, shape (n_samples, n_components)
            The data (same as X).
        """
        self.fit(X, y, sample_domain=sample_domain)
        return self.transform(X, sample_domain=sample_domain, allow_source=True)
    def transform(
        self, X, y=None, *, sample_domain=None, allow_source=False, **params
    ) -> np.ndarray:
        X, sample_domain = check_X_domain(X, sample_domain, allow_source=allow_source)
        source_idx = extract_source_indices(sample_domain)
        X_source, X_target = X[source_idx], X[~source_idx]
        if X_source.shape[0] == 0:
            X_source_adapt = X_source
        else:
            if np.array_equal(self.X_source_, X[source_idx]):
                W, B = self.W_, self.B_
            else:
                if self.discrete_ and y is not None:
                    # recompute the mapping
                    X, sample_domain = check_X_domain(X, sample_domain)
                    source_idx = extract_source_indices(sample_domain)
                    y_source = y[source_idx]
                    classes = self.classes_
                    R = np.zeros((source_idx.sum(), len(classes)))
                    for i, c in enumerate(classes):
                        R[:, i] = (y_source == c).astype(int)
                    W, B = R @ self.G_, R @ self.H_
                else:
                    # assign the nearest neighbor's mapping to the source samples
                    C = pairwise_distances(X[source_idx], self.X_source_)
                    idx = np.argmin(C, axis=1)
                    W, B = self.W_[idx], self.B_[idx]
            X_source_adapt = W * X_source + B
        X_adapt, _ = source_target_merge(
            X_source_adapt, X_target, sample_domain=sample_domain
        )
        return X_adapt 
[docs]
def MMDLSConSMapping(
    base_estimator=None, gamma=1.0, reg_k=1e-10, reg_m=1e-10, tol=1e-5, max_iter=100
):
    """MMDLSConSMapping pipeline with adapter and estimator.
    See [21]_ for details.
    Parameters
    ----------
    base_estimator : object, optional (default=None)
        The base estimator to fit on the target dataset.
    gamma : float
        Parameter for the Gaussian kernel.
    reg_k : float, default=1e-10
        Regularization parameter for the labels kernel matrix.
    reg_m : float, default=1e-10
        Regularization parameter for the mapping parameters.
    tol : float, default=1e-5
        Tolerance for the stopping criterion in the optimization.
    max_iter : int, default=100
        Number of maximum iteration before stopping the optimization.
    Returns
    -------
    pipeline : Pipeline
        Pipeline containing CORAL adapter and base estimator.
    References
    ----------
    .. [21] Kun Zhang et. al. Domain Adaptation under Target and Conditional Shift
            In ICML, 2013.
    """
    if base_estimator is None:
        base_estimator = SVC(kernel="rbf")
    return make_da_pipeline(
        MMDLSConSMappingAdapter(
            gamma=gamma, reg_k=reg_k, reg_m=reg_m, tol=tol, max_iter=max_iter
        ),
        base_estimator,
    )