Source code for skada.base

# Author: Theo Gnassounou <theo.gnassounou@inria.fr>
#         Remi Flamary <remi.flamary@polytechnique.edu>
#         Oleksii Kachaiev <kachayev@gmail.com>
#
# License: BSD 3-Clause

from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Iterator, Literal, Optional, Tuple

import numpy as np
from sklearn.base import BaseEstimator, clone
from sklearn.exceptions import UnsetMetadataPassedError
from sklearn.utils.metadata_routing import (
    MetadataRouter,
    MethodMapping,
    _MetadataRequester,
    get_routing_for_object,
)
from sklearn.utils.metaestimators import available_if
from sklearn.utils.validation import check_is_fitted

from skada._utils import (
    _apply_domain_masks,
    _merge_domain_outputs,
    _remove_masked,
    _route_params
)
from skada.utils import check_X_domain, check_X_y_domain, extract_source_indices


def _estimator_has(attr, base_attr_name='base_estimator'):
    """Check if we can delegate a method to the underlying estimator.

    First, we check the first fitted classifier if available, otherwise we
    check the unfitted classifier.
    """
    def has_base_estimator(estimator) -> bool:
        return hasattr(estimator, base_attr_name) and hasattr(
            getattr(estimator, base_attr_name),
            attr
        )

    # xxx(okachaiev): there should be a simple way to access selector base estimator
    def has_estimator_selector(estimator) -> bool:
        return hasattr(estimator, "estimators_") and hasattr(
            estimator.estimators_[0],
            attr
        )

    return lambda estimator: (has_base_estimator(estimator) or
                              has_estimator_selector(estimator))


class IncompatibleMetadataError(UnsetMetadataPassedError):
    """The exception is designated to report the situation when the adapter output
    the key, like 'sample_weight', that is not explicitly consumed by the following
    estimator in the pipeline.

    The exception overrides :class:`~sklearn.exceptions.UnsetMetadataPassedError`
    when there is a reason to believe that the original exception was thrown because
    of the adapter output rather than being caused by the input to a specific function.
    """

    def __init__(self, message):
        super().__init__(message=message, unrequested_params={}, routed_params={})


class BaseAdapter(BaseEstimator):

    __metadata_request__fit = {'sample_domain': True}
    __metadata_request__transform = {'sample_domain': True, 'allow_source': True}

    @abstractmethod
    def fit_transform(self, X, y=None, *, sample_domain=None, **params):
        """Fit adapter and transforms samples, labels, and weights to be used
        to fit estimator (i.e. 'adapt' the data).
        """
        pass

    def transform(
        self,
        X,
        y=None,
        *,
        sample_domain=None,
        allow_source=False,
        **params
    ) -> np.ndarray:
        """Transforms (adapts) the data during evaluation. Default implementation
        passes through samples without changing them, as it is a default behavior
        for many adapters.
        """
        check_is_fitted(self)
        X, sample_domain = check_X_domain(
            X,
            sample_domain=sample_domain,
            allow_auto_sample_domain=True,
            allow_source=allow_source,
        )
        return X

    def fit(self, X, y=None, *, sample_domain=None, **params):
        """Fitting of the adapter is supposed to happen in `fit_transform`
        method, though for a convenience reason it might be redefined by
        by the specific implementation.
        """
        raise NotImplementedError('To fit adapter use `fit_transform` method.')


class _DAMetadataRequesterMixin(_MetadataRequester):
    """Mixin class for adding metadata related to the domain adaptation
    functionality. The mixin is primarily designed for the internal API
    and is expected to be rarely, if at all, required by end users.
    """

    __metadata_request__fit = {'sample_domain': True}
    __metadata_request__partial_fit = {'sample_domain': True}
    __metadata_request__predict = {'sample_domain': True, 'allow_source': True}
    __metadata_request__predict_proba = {'sample_domain': True, 'allow_source': True}
    __metadata_request__predict_log_proba = {
        'sample_domain': True,
        'allow_source': True
    }
    __metadata_request__score = {'sample_domain': True, 'allow_source': True}
    __metadata_request__decision_function = {
        'sample_domain': True,
        'allow_source': True
    }


class DAEstimator(BaseEstimator, _DAMetadataRequesterMixin):
    """Generic DA estimator class."""

    @abstractmethod
    def fit(self, X, y=None, sample_domain=None, *, sample_weight=None):
        """Fit adaptation parameters"""
        pass

    @abstractmethod
    def predict(self, X, sample_domain=None, *, sample_weight=None):
        """Predict using the model"""
        pass


# xxx(okachaiev): selectors + container should eventually go either
# into `skada.pipeline` or into `skada.selectors`, literally no reason
# to keep them in base module
# xxx(okachaiev): this one needs good procedure for serialize/deserialize
@dataclass
class MetadataContainer:
    """Container to carry samples, labels, and metadata
    throughout domain adaptation pipeline. Selectors are
    responsible for extracting proper parametrization for
    the underlying estimator(s) as well as merging back
    samples, labels, and/or metadata generated by the call
    to an estimator method.

    Designed for internal use only.
    """

    _features: np.ndarray
    _labels: np.ndarray
    _metadata: dict

    @classmethod
    def from_input(cls, X, y=None) -> 'MetadataContainer':
        return X if isinstance(X, cls) else cls(_features=X, _labels=y, _metadata={})

    # xxx(okachaiev): `merge_in` and `merge_out` method names are too technical
    def merge_in(self, X_container):
        if isinstance(X_container, tuple) and len(X_container) == 2:
            X, params = X_container
            assert isinstance(params, dict)
            self._features = X
            self._metadata.update(params)
        elif isinstance(X_container, tuple) and len(X_container) == 3:
            X, y, params = X_container
            assert isinstance(params, dict)
            self._features = X
            self._labels = y
            self._metadata.update(params)
        elif isinstance(X_container, np.ndarray):
            self._features = X_container
        else:
            raise ValueError("Unsupported container")
        return self

    def merge_out(self, y, **params):
        params.update(self._metadata)
        y_out = self._labels if self._labels is not None else y
        return self._features, y_out, params

    def iter_metadata(self) -> Iterator[Tuple[str, Any]]:
        return self._metadata.items()

    def __len__(self):
        return self._features.shape[0]


class BaseSelector(BaseEstimator, _DAMetadataRequesterMixin):

    __metadata_request__transform = {'sample_domain': True}

    def __init__(self, base_estimator: BaseEstimator, **kwargs):
        super().__init__()
        self.base_estimator = base_estimator
        self.base_estimator.set_params(**kwargs)
        self._is_final = False
        self._is_transformer = hasattr(base_estimator, 'transform')

    def get_metadata_routing(self):
        return (
            MetadataRouter(owner=self.__class__.__name__)
            .add_self_request(self)
            .add(estimator=self.base_estimator, method_mapping=MethodMapping()
                 .add(callee='fit', caller='fit')
                 .add(callee='partial_fit', caller='partial_fit')
                 .add(callee='transform', caller='transform')
                 .add(callee='predict', caller='predict')
                 .add(callee='predict_proba', caller='predict_proba')
                 .add(callee='predict_log_proba', caller='predict_log_proba')
                 .add(callee='decision_function', caller='decision_function')
                 .add(callee='score', caller='score'))
        )

    @abstractmethod
    def get_estimator(self, *params) -> BaseEstimator:
        """Returns estimator associated with `params`.

        The set of available estimators and access to them has to be provided
        by specific implementations.
        """

    def get_params(self, deep=True):
        """Get parameters for this estimator.

        Returns the parameters of the base estimator provided in the constructor.

        Parameters
        ----------
        deep : bool, default=True
            If True, will return the parameters for this estimator and
            contained sub-objects that are estimators.

        Returns
        -------
        params : mapping of string to any
            Parameter names mapped to their values.
        """
        params = self.base_estimator.get_params(deep=deep)
        params['base_estimator'] = self.base_estimator
        return params

    def set_params(self, base_estimator=None, **kwargs):
        """Set the parameters of this estimator.

        Valid parameter keys can be listed with ``get_params()``. Note that
        you can directly set the parameters of the estimator using `base_estimator`
        attribute.

        Parameters
        ----------
        **kwargs : dict
            Parameters of of the base estimator.

        Returns
        -------
        self : object
            Selector class instance.
        """
        if base_estimator is not None:
            self.base_estimator = base_estimator
        self.base_estimator.set_params(**kwargs)
        return self

    def _mark_as_final(self):
        self._is_final = True
        return self

    def _unmark_as_final(self):
        self._is_final = False
        return self

    @abstractmethod
    def _route_to_estimator(self, method_name, X, y=None, **params) -> np.ndarray:
        """Abstract method for calling method of a base estimator based on
        the input and the routing logic associated with domain labels.
        """

    @available_if(_estimator_has('transform'))
    def transform(self, X, **params):
        return self._route_to_estimator('transform', X, **params)

    @available_if(_estimator_has('predict'))
    def predict(self, X, **params):
        return self._route_to_estimator('predict', X, **params)

    @available_if(_estimator_has('predict_proba'))
    def predict_proba(self, X, **params):
        return self._route_to_estimator('predict_proba', X, **params)

    @available_if(_estimator_has('predict_log_proba'))
    def predict_log_proba(self, X, **params):
        return self._route_to_estimator('predict_log_proba', X, **params)

    @available_if(_estimator_has('decision_function'))
    def decision_function(self, X, **params):
        return self._route_to_estimator('decision_function', X, **params)

    @available_if(_estimator_has('score'))
    def score(self, X, y, **params):
        return self._route_to_estimator('score', X, y=y, **params)

    def _prepare_routing(self, routing_request, metadata_container, params):
        if self._is_final or not self._is_transformer:
            try:
                routed_params = _route_params(routing_request, params, self)
            except UnsetMetadataPassedError as e:
                # check if every parameter given from the metadata container
                # was accepted by the downstream (base) estimator
                # xxx(okachaiev): there's still a way for this to fail,
                # if non-final estimator consumed the input that was
                # generated by one of the adapters
                if isinstance(metadata_container, dict):
                    iter = metadata_container.items()
                else:
                    iter = metadata_container.iter_metadata()
                for k, v in iter:
                    marker = routing_request.requests.get(k)
                    if v is not None and marker is None:
                        method = routing_request.method
                        raise IncompatibleMetadataError(
                            f"The adapter provided '{k}' parameter which is not explicitly set as "  # noqa
                            f"requested or not for '{routing_request.owner}.{method}'.\n"  # noqa
                            f"Make sure that metadata routing is properly setup, e.g. by calling 'set_{method}_request()'. "  # noqa
                            "See documentation at https://scikit-learn.org/stable/metadata_routing.html"  # noqa
                        ) from e
                # re-raise exception if the problem was not caused by the adapter
                raise e
        else:
            routed_params = {k: params[k] for k in routing_request._consumes(params=params)}
        return routed_params

    def _remove_masked(self, X, y, routed_params):
        """Removes masked inputs before passing them to a downstream (base) estimator,
        ensuring their compatibility with the DA pipeline, particularly for estimators
        that do not natively support DA setups, such as those from scikit-learn.

        Masked inputs are removed under the following conditions:
        - Labels `y` are provided (necessary for mask detection).
        - The estimator is not a transformer (i.e., does not define a
          'transform' function).
        - The estimator does not accept a `sample_domain` parameter
          through routing.

        In scenarios not meeting these criteria, masked input samples are retained.

        Note: This API is intended for internal use.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Input data
        y : array-like of shape (n_samples,)
            Labels for the data
        params : dict
            Additional parameters declared in the routing

        Returns
        -------
        X : array-like of shape (n_samples, n_features)
            Input data
        y : array-like of shape (n_samples,)
            Labels for the data
        params : dict
            Additional parameters declared in the routing
        """
        if (y is not None
                and not hasattr(self, 'transform')
                and 'sample_domain' not in routed_params):
            X, y, routed_params = _remove_masked(X, y, routed_params)
        return X, y, routed_params


[docs] class Shared(BaseSelector):
[docs] def get_estimator(self) -> BaseEstimator: """Provides access to the fitted estimator.""" check_is_fitted(self) return self.base_estimator_
def fit(self, X, y=None, **params): X_container = MetadataContainer.from_input(X) self._fit('fit', X_container, y, **params) return self @available_if(_estimator_has('transform')) def fit_transform(self, X, y=None, **params): X_container = MetadataContainer.from_input(X) if hasattr(self.base_estimator, "fit_transform"): output = self._fit('fit_transform', X_container, y, **params) else: self._fit('fit', X_container, y, **params) X, y, method_params = X_container.merge_out(y, **params) transform_params = _route_params(self.routing_.transform, method_params, self) output = self.transform(X, **transform_params) return X_container.merge_in(output) # xxx(okachaiev): solve the problem with parameter renaming def _fit(self, routing_method, X_container, y=None, **params): X, y, params = X_container.merge_out(y, **params) routing = get_routing_for_object(self.base_estimator) routing_request = getattr(routing, routing_method) routed_params = self._prepare_routing(routing_request, X_container, params) X, y, routed_params = self._remove_masked(X, y, routed_params) estimator = clone(self.base_estimator) output = getattr(estimator, routing_method)(X, y, **routed_params) self.base_estimator_ = estimator self.routing_ = routing return output # xxx(okachaiev): fail if unknown domain is given def _route_to_estimator(self, method_name, X, y=None, **params): check_is_fitted(self) request = getattr(self.routing_, method_name) routed_params = self._prepare_routing(request, {}, params) X, y, routed_params = self._remove_masked(X, y, routed_params) method = getattr(self.base_estimator_, method_name) output = method(X, **routed_params) if y is None else method( X, y, **routed_params ) return output
[docs] class PerDomain(BaseSelector):
[docs] def get_estimator(self, domain_label: int) -> BaseEstimator: """Provides access to the fitted estimator based on the domain label.""" check_is_fitted(self) return self.estimators_[domain_label]
def fit(self, X, y, **params): X_container = MetadataContainer.from_input(X) self._fit('fit', X_container, y, **params) return self def _fit(self, method_name, X_container, y, **params): X, y, params = X_container.merge_out(y, **params) sample_domain = params['sample_domain'] routing = get_routing_for_object(self.base_estimator) routing_request = getattr(routing, method_name) routed_params = self._prepare_routing(routing_request, X_container, params) X, y, routed_params = self._remove_masked(X, y, routed_params) estimators, outputs = {}, {} for domain_label in np.unique(sample_domain): idx, = np.where(sample_domain == domain_label) estimator = clone(self.base_estimator) outputs[domain_label] = (idx, getattr(estimator, method_name)( X[idx], y[idx] if y is not None else None, **{k: v[idx] for k, v in routed_params.items()} )) estimators[domain_label] = estimator self.estimators_ = estimators self.routing_ = routing return outputs @available_if(_estimator_has('transform')) def fit_transform(self, X, y=None, **params): X_container = MetadataContainer.from_input(X) if hasattr(self.base_estimator, "fit_transform"): domain_outputs = self._fit('fit_transform', X_container, y=y, **params) output = _merge_domain_outputs(len(X_container), domain_outputs, allow_containers=True) else: self._fit(X_container, y, **params) X, y, method_params = X_container.merge_out(y, **params) transform_params = _route_params(self.routing_.transform, method_params, self) # the output of the transform call is already merged into a single ndarray output = self.transform(X, **transform_params) return X_container.merge_in(output) def _route_to_estimator(self, method_name, X, y=None, **params): check_is_fitted(self) request = getattr(self.routing_, method_name) routed_params = self._prepare_routing(request, {}, params) # xxx(okachaiev): use check_*_domain to derive default domain labels sample_domain = params['sample_domain'] domain_outputs = {} # test if default target domain and unique target during fit and replace label for domain_label in np.unique(sample_domain): # xxx(okachaiev): fail if unknown domain is given try: method = getattr(self.estimators_[domain_label], method_name) except KeyError: raise ValueError( f"Domain label {domain_label} is not present in the " "fitted estimators." ) idx, = np.where(sample_domain == domain_label) X_domain = X[idx] y_domain = y[idx] if y is not None else None domain_params = {k: v[idx] for k, v in routed_params.items()} if y is None: domain_output = method(X_domain, **domain_params) else: domain_output = method(X_domain, y_domain, **domain_params) domain_outputs[domain_label] = (idx, domain_output) return _merge_domain_outputs(X.shape[0], domain_outputs)
class _BaseSelectDomain(Shared): """Abstract class to implement selectors that are intended for picking specific subset of samples from the input for fitting the base estimator, e.g. only source, only targets, specific domain by its index or name, and more. Specific functionality is given by providing implementation for the `_select_indices` abstract method that receives an array with domain indices (i.e. `sample_domain`) and returns masks that should be used to filter out necessary samples from the input. """ @abstractmethod def _select_indices(self, sample_domain): """Calculates masks for input samples. Parameters ---------- sample_domain : array-like, shape (n_samples,) The domain labels. Returns ------- filter_masks : array-like, shape (n_samples,) Array of boolean masks. """ def _pre_filter(self, method_name, X_container, y=None, **params): X, y, params = X_container.merge_out(y, **params) filter_masks = self._select_indices(params.get('sample_domain')) X_input, y, params = _apply_domain_masks(X, y, params, masks=filter_masks) return getattr(super(), method_name)(X_container.merge_in((X_input, y, params)), y, **params) def fit(self, X, y=None, **params): X_container = MetadataContainer.from_input(X) return self._pre_filter('fit', X_container, y=y, **params) @available_if(_estimator_has('transform')) def fit_transform(self, X, y=None, **params): X_container = MetadataContainer.from_input(X) return self._pre_filter('fit_transform', X_container, y=y, **params) class SelectSource(_BaseSelectDomain): """Selects only source domains for fitting base estimator.""" def _select_indices(self, sample_domain): return extract_source_indices(sample_domain) class SelectTarget(_BaseSelectDomain): """Selects only target domains for fitting base estimator.""" def _select_indices(self, sample_domain): return ~extract_source_indices(sample_domain) class SelectSourceTarget(BaseSelector): def __init__(self, source_estimator: BaseEstimator, target_estimator: Optional[BaseEstimator] = None): if target_estimator is not None \ and hasattr(source_estimator, 'transform') \ and not hasattr(target_estimator, 'transform'): raise TypeError("The provided source and target estimators must " "both be transformers, or neither should be.") self.source_estimator = source_estimator self.target_estimator = target_estimator # xxx(okachaiev): the fact that we need to put those variables # here means that the choice of the base class is suboptimal self._is_final = False self._is_transformer = hasattr(source_estimator, 'transform') def get_metadata_routing(self): routing = MetadataRouter(owner=self.__class__.__name__).add_self_request(self) routing.add(estimator=self.source_estimator, method_mapping=MethodMapping() .add(callee='fit', caller='fit') .add(callee='partial_fit', caller='partial_fit') .add(callee='transform', caller='transform') .add(callee='predict', caller='predict') .add(callee='predict_proba', caller='predict_proba') .add(callee='predict_log_proba', caller='predict_log_proba') .add(callee='decision_function', caller='decision_function') .add(callee='score', caller='score')) if self.target_estimator is not None: routing.add(estimator=self.target_estimator, method_mapping=MethodMapping() .add(callee='fit', caller='fit') .add(callee='partial_fit', caller='partial_fit') .add(callee='transform', caller='transform') .add(callee='predict', caller='predict') .add(callee='predict_proba', caller='predict_proba') .add(callee='predict_log_proba', caller='predict_log_proba') .add(callee='decision_function', caller='decision_function') .add(callee='score', caller='score')) return routing def get_params(self, deep=True): """Get parameters for this estimator. Returns the parameters of the base estimator provided in the constructor. Parameters ---------- deep : bool, default=True If True, will return the parameters for this estimator and contained sub-objects that are estimators. Returns ------- params : mapping of string to any Parameter names mapped to their values. """ return super(BaseSelector, self).get_params(deep=deep) def set_params(self, **kwargs): """Set the parameters of this estimator. Valid parameter keys can be listed with ``get_params()``. Note that you can directly set the parameters of the estimator using `base_estimator` attribute. Parameters ---------- **kwargs : dict Parameters of of the base estimator. Returns ------- self : object Selector class instance. """ super(BaseSelector, self).set_params(**kwargs) return self def get_estimator(self, domain: Literal['source', 'target']) -> BaseEstimator: """Provides access to the fitted estimator based on the domain type.""" assert domain in ('source', 'target') check_is_fitted(self) return self.estimators_[domain] def fit(self, X, y=None, **params): X_container = MetadataContainer.from_input(X) self._fit('fit', X_container, y=y, **params) return self def _fit(self, method_name, X_container, y=None, **params): X, y, params = X_container.merge_out(y, **params) if y is not None: X, y, sample_domain = check_X_y_domain(X, y, sample_domain=params.get('sample_domain')) else: X, sample_domain = check_X_domain(X, sample_domain=params.get('sample_domain')) params['sample_domain'] = sample_domain source_masks = extract_source_indices(sample_domain) estimators, outputs = {}, {} target_estimator = self.target_estimator if self.target_estimator is not None else self.source_estimator for domain_type, base_estimator, domain_masks in [ ('source', self.source_estimator, source_masks), ('target', target_estimator, ~source_masks) ]: if domain_masks.sum() == 0: # if we don't have either source or target, we should conclude that fitting # was not successful, otherwise prediction might be not possible raise ValueError( "`SelectSourceTarget` requires both source and target samples for fitting. " f"'{domain_type}' samples are missing in the input provided." ) X_masked, y_masked, params_masked = _apply_domain_masks(X, y, params, masks=domain_masks) routing = getattr(get_routing_for_object(base_estimator), method_name) routed_params = self._prepare_routing(routing, X_masked, params_masked) X_masked, y_masked, routed_params = self._remove_masked(X_masked, y_masked, routed_params) estimator = clone(base_estimator) estimator_method = getattr(estimator, method_name) domain_output = estimator_method(X_masked, y_masked, **routed_params) outputs[domain_type] = (domain_masks, domain_output) estimators[domain_type] = estimator self.estimators_ = estimators return outputs @available_if(_estimator_has('transform', base_attr_name='source_estimator')) def fit_transform(self, X, y=None, **params): X_container = MetadataContainer.from_input(X) if hasattr(self.source_estimator, "fit_transform"): domain_outputs = self._fit('fit_transform', X_container, y=y, **params) output = _merge_domain_outputs(len(X_container), domain_outputs, allow_containers=True) else: self.fit(X_container, y, **params) X, y, method_params = X_container.merge_out(y, **params) transform_params = _route_params(self.routing_.transform, method_params, self) output = self.transform(X, **transform_params) return X_container.merge_in(output) def _route_to_estimator(self, method_name, X, y=None, **params): check_is_fitted(self) if y is not None: X, y, sample_domain = check_X_y_domain(X, y, sample_domain=params.get('sample_domain')) else: X, sample_domain = check_X_domain(X, sample_domain=params.get('sample_domain')) params['sample_domain'] = sample_domain source_masks = extract_source_indices(sample_domain) outputs = {} for domain_label, domain_masks in [('source', source_masks), ('target', ~source_masks)]: domain_estimator = self.estimators_[domain_label] if domain_masks.sum() == 0: # if domain type is not present, just skip continue X_domain, y_domain, params_domain = _apply_domain_masks(X, y, params, masks=domain_masks) request = getattr(get_routing_for_object(domain_estimator), method_name) routed_params = self._prepare_routing(request, {}, params_domain) method = getattr(domain_estimator, method_name) if y_domain is None or method_name == 'transform': domain_output = method(X_domain, **routed_params) else: domain_output = method(X_domain, y_domain, **routed_params) assert isinstance(domain_output, np.ndarray) outputs[domain_label] = (domain_masks, domain_output) return _merge_domain_outputs(X.shape[0], outputs) @available_if(_estimator_has('transform', base_attr_name='source_estimator')) def transform(self, X, **params): return self._route_to_estimator('transform', X, **params) # xxx(okachaiev): i guess this should return 'True' only if both # estimators have the same method. though practical advantage is, # surely, questionable @available_if(_estimator_has('predict_proba', base_attr_name='source_estimator')) def predict_proba(self, X, **params): return self._route_to_estimator('predict_proba', X, **params) @available_if(_estimator_has('predict_log_proba', base_attr_name='source_estimator')) def predict_log_proba(self, X, **params): return self._route_to_estimator('predict_log_proba', X, **params) @available_if(_estimator_has('decision_function', base_attr_name='source_estimator')) def decision_function(self, X, **params): return self._route_to_estimator('decision_function', X, **params) @available_if(_estimator_has('score', base_attr_name='source_estimator')) def score(self, X, y, **params): return self._route_to_estimator('score', X, y=y, **params)