# Author: Oleksii Kachaiev <kachayev@gmail.com>
# Yanis Lalou <yanis.lalou@polytechnique.edu>
#
# License: BSD 3-Clause
import os
import warnings
from functools import reduce
from typing import Dict, Iterable, List, Literal, Mapping, Optional, Tuple, Union
import numpy as np
from sklearn.utils import Bunch, deprecated
try:
import torch
_IS_TORCH_IMPORTED = True
except (ImportError, ModuleNotFoundError):
_IS_TORCH_IMPORTED = False
try:
from skada.deep.base import DeepDADataset as DeepDADataset
_IS_DEEPDADATASET_IMPORTED = True
except (ImportError, ModuleNotFoundError):
_IS_DEEPDADATASET_IMPORTED = False
# Environment variable that can override the default data folder.
_DEFAULT_HOME_FOLDER_KEY = "SKADA_DATA_FOLDER"
# Data folder used when neither an explicit path nor the env var is given.
_DEFAULT_HOME_FOLDER = "~/skada_datasets"

# A numpy array or a torch tensor; torch is optional, hence the forward ref.
ArrayLike = Union[np.ndarray, "torch.Tensor"]

# xxx(okachaiev): if we use -1 as a detector for targets,
# we should not allow non-labeled dataset or... we need
# to come up with a way to pack them properly
# Accepted shapes of a single domain definition. The element order matches
# how `DomainAwareDataset.__init__` unpacks them: (X, y, name).
DomainDataType = Union[
    # (X, y, name)
    Tuple[ArrayLike, ArrayLike, str],
    # (X, y)
    Tuple[ArrayLike, ArrayLike],
    # (X,)
    Tuple[ArrayLike],
]
# Every shape `DomainAwareDataset.pack` can hand back: a Bunch, a plain
# (X, y, sample_domain) tuple, or a DeepDADataset (quoted because the deep
# module is an optional import).
PackedDatasetType = Union[
    Bunch,
    Tuple[ArrayLike, ArrayLike, ArrayLike],
    "DeepDADataset",
]
def get_data_home(data_home: Union[str, os.PathLike, None] = None) -> str:
    """Return the path of the `skada` data folder.

    This folder is used by some large dataset loaders to avoid downloading the
    data several times.

    By default the data directory is set to a folder named 'skada_datasets' in
    the user home folder.

    Alternatively, it can be set by the 'SKADA_DATA_FOLDER' environment
    variable or programmatically by giving an explicit folder path. The '~'
    symbol is expanded to the user home folder.

    If the folder does not already exist, it is automatically created.

    Parameters
    ----------
    data_home : str or path-like, default=None
        The path to `skada` data folder. If `None`, the value of the
        'SKADA_DATA_FOLDER' environment variable is used when it is set to a
        non-empty value, and `~/skada_datasets` otherwise.

    Returns
    -------
    data_home : str
        The path to `skada` data folder, with '~' expanded.
    """
    if data_home is None:
        # An empty env var would make `os.makedirs("")` fail; treat it as unset.
        data_home = os.environ.get(_DEFAULT_HOME_FOLDER_KEY) or _DEFAULT_HOME_FOLDER
    # `expanduser` also accepts path-like objects and returns a plain str.
    data_home = os.path.expanduser(data_home)
    os.makedirs(data_home, exist_ok=True)
    return data_home
def _is_tensor(data) -> bool:
    """Return True when ``data`` is a torch tensor (always False without torch)."""
    return _IS_TORCH_IMPORTED and isinstance(data, torch.Tensor)


class DomainAwareDataset:
    """
    Container carrying all dataset domains.

    This class allows to store and manipulate datasets from multiple domains,
    keeping track of the domain information for each sample.

    Parameters
    ----------
    domains : list of tuple or dict of tuple or None, optional
        Domains to add at initialization. When a list is given, each item can
        be a tuple (X,), (X, y) or (X, y, name); domains without a name get an
        autogenerated one. When a dict is given, its keys are used as domain
        names and its values are tuples (X,) or (X, y).

    Attributes
    ----------
    domains_ : list
        List of domains added, each as a tuple (X, y) or (X,).
    domain_names_ : dict
        Dictionary mapping each domain name to its internal identifier
        (a positive integer, starting at 1, in insertion order).
    """

    def __init__(
        self,
        # xxx(okachaiev): not sure if dictionary is a good format :thinking:
        domains: Union[List[DomainDataType], Dict[str, DomainDataType], None] = None,
    ):
        self.domains_ = []
        self.domain_names_ = {}
        if domains is None:
            return
        # xxx(okachaiev): there should be a simpler way for adding those
        if isinstance(domains, Mapping):
            # Iterating a dict directly would yield its keys (the names), so
            # walk the items and take the name from the key.
            for domain_name, domain in domains.items():
                X, y, inner_name = self._split_domain(domain)
                if inner_name is not None and inner_name != domain_name:
                    raise ValueError(
                        f"Domain name {inner_name!r} does not match its "
                        f"dictionary key {domain_name!r}."
                    )
                self.add_domain(X, y=y, domain_name=domain_name)
        else:
            for domain in domains:
                X, y, domain_name = self._split_domain(domain)
                self.add_domain(X, y=y, domain_name=domain_name)

    @staticmethod
    def _split_domain(domain):
        """Split a (X,), (X, y) or (X, y, name) tuple into ``(X, y, name)``.

        Missing parts are returned as None.

        Raises
        ------
        ValueError
            If the tuple does not have 1, 2 or 3 elements.
        """
        if len(domain) == 1:
            return domain[0], None, None
        if len(domain) == 2:
            return domain[0], domain[1], None
        if len(domain) == 3:
            return domain[0], domain[1], domain[2]
        raise ValueError(
            "Invalid definition for domain data: expected a tuple "
            "(X,), (X, y) or (X, y, name)."
        )

    def add_domain(
        self, X, y=None, domain_name: Optional[str] = None
    ) -> "DomainAwareDataset":
        """
        Add a new domain to the dataset.

        Parameters
        ----------
        X : ArrayLike
            Feature matrix for the domain.
        y : ArrayLike or None, optional
            Labels for the domain. If None, labels are not provided.
        domain_name : str, optional
            Name of the domain. If None, a unique name of the form ``_<k>``
            is autogenerated.

        Returns
        -------
        self : DomainAwareDataset
            The updated dataset.

        Raises
        ------
        AssertionError
            If ``domain_name`` is already used by another domain.
        """
        if domain_name is not None:
            # Check the name is unique. An explicit raise (instead of a bare
            # `assert`, which `python -O` strips) keeps the AssertionError
            # type for backward compatibility.
            # xxx(okachaiev): ValueError would be more appropriate
            if domain_name in self.domain_names_:
                raise AssertionError(
                    f"Domain name {domain_name!r} is already in use."
                )
        else:
            # Skip numbers already taken (e.g. by an explicit "_2") so that an
            # existing entry of `domain_names_` is never silently overwritten.
            index = len(self.domain_names_) + 1
            domain_name = f"_{index}"
            while domain_name in self.domain_names_:
                index += 1
                domain_name = f"_{index}"
        domain_id = len(self.domains_) + 1
        self.domains_.append((X, y) if y is not None else (X,))
        self.domain_names_[domain_name] = domain_id
        return self

    def merge(
        self, dataset: "DomainAwareDataset", names_mapping: Optional[Mapping] = None
    ) -> "DomainAwareDataset":
        """
        Merge another DomainAwareDataset into this one.

        Parameters
        ----------
        dataset : DomainAwareDataset
            The dataset to merge. Both labeled (X, y) and unlabeled (X,)
            domains are copied.
        names_mapping : mapping, optional
            Mapping from old domain names to new domain names. Names missing
            from the mapping are kept as they are.

        Returns
        -------
        self : DomainAwareDataset
            The updated dataset.
        """
        # Snapshot the names: merging a dataset into itself would otherwise
        # iterate over a dict that is being modified.
        for domain_name in list(dataset.domain_names_):
            # A stored domain is either (X, y) or (X,).
            X, *labels = dataset.get_domain(domain_name)
            y = labels[0] if labels else None
            if names_mapping is not None:
                domain_name = names_mapping.get(domain_name, domain_name)
            self.add_domain(X, y, domain_name)
        return self

    def get_domain(self, domain_name: str) -> Tuple[ArrayLike, Optional[ArrayLike]]:
        """
        Retrieve the data and labels for a given domain.

        Parameters
        ----------
        domain_name : str
            Name of the domain to retrieve.

        Returns
        -------
        domain : tuple
            Tuple containing (X, y) or (X,) for the specified domain.

        Raises
        ------
        KeyError
            If no domain is registered under ``domain_name``.
        """
        domain_id = self.domain_names_[domain_name]
        # identifiers are 1-based positions in `domains_`
        return self.domains_[domain_id - 1]

    def select_domain(
        self, sample_domain: ArrayLike, domains: Union[str, Iterable[str]]
    ) -> ArrayLike:
        """
        Select samples belonging to one or more domains.

        Parameters
        ----------
        sample_domain : ArrayLike
            Array of domain labels for each sample.
        domains : str or iterable of str
            Domain name(s) to select.

        Returns
        -------
        mask : ArrayLike
            Boolean mask indicating selected samples.
        """
        return select_domain(self.domain_names_, sample_domain, domains)

    def _get_domain_for_packing(self, domain_name: str):
        """Return ``(domain_id, X, y)`` for a domain; unlabeled domains get -1 labels.

        The -1 labels are a torch tensor when X is a tensor, and a numpy array
        otherwise, so that they can be concatenated with the rest of the data.
        """
        domain_id = self.domain_names_[domain_name]
        domain = self.get_domain(domain_name)
        if len(domain) == 1:
            (X,) = domain
            # xxx(okachaiev): for what it's worth, we should likely to
            # move the decision about dtype to the very end of the list
            if _is_tensor(X):
                y = -torch.ones(X.shape[0], dtype=torch.int32, device=X.device)
            else:
                y = -np.ones(X.shape[0], dtype=np.int32)
        elif len(domain) == 2:
            X, y = domain
        else:
            raise ValueError("Invalid definition for domain data")
        return domain_id, X, y

    @staticmethod
    def _mask_labels(y, n_samples, mask, dtype):
        """Replace ``n_samples`` target labels by a mask value.

        When ``mask`` is None, it is chosen from the label dtype: -1 for
        signed integers and nan for floating point labels.

        Returns
        -------
        (masked_y, mask, dtype) : tuple
            The masked labels plus the mask value and numpy dtype to pass back
            for the next target domain, so that all targets share one mask.
        """
        if _is_tensor(y):
            if mask is not None:
                return torch.full((n_samples,), mask, device=y.device), mask, dtype
            # note: floating dtypes also report `is_signed`, so test them first
            if y.dtype.is_floating_point:
                mask = float("nan")
            elif y.dtype.is_signed:
                mask = -1
            else:
                warnings.warn(
                    f"Target labels of dtype {y.dtype} cannot be masked "
                    "automatically and are left unchanged; pass an explicit "
                    "`mask` value to hide them.",
                    UserWarning,
                    stacklevel=3,
                )
                return y, mask, dtype
            masked = torch.full((n_samples,), mask, dtype=y.dtype, device=y.device)
            return masked, mask, dtype

        if mask is not None:
            # dtype is None unless a previous target already fixed it
            return np.full(n_samples, mask, dtype=dtype), mask, dtype
        if np.issubdtype(y.dtype, np.signedinteger):
            return np.full(n_samples, -1, dtype=y.dtype), -1, y.dtype
        if np.issubdtype(y.dtype, np.floating):
            return np.full(n_samples, np.nan, dtype=y.dtype), np.nan, y.dtype
        warnings.warn(
            f"Target labels of dtype {y.dtype} cannot be masked automatically "
            "and are left unchanged; pass an explicit `mask` value to hide them.",
            UserWarning,
            stacklevel=3,
        )
        return y, mask, dtype

    @staticmethod
    def _concat_tensors(arrays, device=None):
        """Convert each array/tensor to a torch tensor and concatenate on dim 0."""
        return torch.cat([torch.as_tensor(a, device=device) for a in arrays], dim=0)

    # xxx(okachaiev): i guess, if we are using names to pack domains into array,
    # we should not autogenerate them... otherwise it might be not obvious at all
    def pack(
        self,
        as_sources: List[str],
        as_targets: List[str],
        mask_target_labels: bool,
        return_X_y: Optional[bool] = None,
        return_type: Literal[
            "auto", "array", "tensor", "DeepDADataset", "Bunch"
        ] = "auto",
        train: Optional[bool] = None,
        mask: Union[None, int, float] = None,
    ) -> PackedDatasetType:
        """Aggregates datasets from all domains into a unified domain-aware
        representation, ensuring compatibility with domain adaptation (DA)
        estimators.

        Parameters
        ----------
        as_sources : list
            List of domain names to be used as sources. An empty list
            indicates that no source domains are used.
        as_targets : list
            List of domain names to be used as targets. An empty list
            indicates that no target domains are used.
        mask_target_labels : bool
            This parameter should be set to True for training and False for
            testing. When set to True, masks labels for target domains with -1
            for classification tasks or nan for regression tasks, so they are
            not available at train time.
        return_X_y : bool, default=None
            [DEPRECATED] When given, overrides `return_type`: True returns a
            tuple (X, y, sample_domain) of numpy arrays, False returns a
            :class:`~sklearn.utils.Bunch` object with the structure described
            below.
        return_type : Literal["auto", "array", "tensor", "DeepDADataset", "Bunch"]
            The type of the returned data. If "auto", it will return tensors if
            the data is in tensor format, otherwise it will return numpy arrays.
            If "array", returns numpy arrays. If "tensor", returns torch tensors
            (numpy inputs are converted). If "DeepDADataset", returns a
            :class:`~skada.deep.base.DeepDADataset`. If "Bunch", returns a
            :class:`~sklearn.utils.Bunch` object.
        train : bool, default=None
            [DEPRECATED] Use `mask_target_labels` instead.
        mask : int or float, default=None
            Value to mask target labels with. If None, -1 is used for signed
            integer labels and nan for floating point labels; the same mask
            is then reused for all target domains.

        Returns
        -------
        data : :class:`~sklearn.utils.Bunch`
            Dictionary-like object, with the following attributes.

            X: ndarray
                Samples from all sources and all targets given.
            y : ndarray
                Labels from all sources and all targets.
            sample_domain : ndarray
                The integer label for domain the sample was taken from.
                By convention, source domains have non-negative labels,
                and target domain label is always < 0.
            domain_names : dict
                The names of domains and associated domain labels.

        (X, y, sample_domain) : tuple of ArrayLike if `return_type` is "array"
            or "tensor". Tuple of (data, target, sample_domain), see the
            description above.
        deep_da_dataset : DeepDADataset
            Compatible with torch: a torch Dataset extended with the
            sample_domain.

        Raises
        ------
        ValueError
            If `return_type` is invalid or no domain is given at all.
        ImportError
            If tensors are requested but torch (or skorch for DeepDADataset)
            is not installed.
        """
        if return_X_y is not None:
            warnings.warn(
                "The `return_X_y` parameter is deprecated and will be removed in "
                "future versions. Use `return_type` instead.",
                DeprecationWarning,
            )
            return_type = "array" if return_X_y else "Bunch"
        if train is not None:
            warnings.warn(
                "The `train` parameter is deprecated and will be removed in "
                "future versions. Use `mask_target_labels` instead.",
                DeprecationWarning,
            )
            mask_target_labels = train
        # Fail fast instead of after all the data has been gathered.
        if return_type not in ("auto", "array", "tensor", "DeepDADataset", "Bunch"):
            raise ValueError(
                "Invalid return_type. Expected one of 'auto', 'array', "
                "'tensor', 'DeepDADataset', or 'Bunch'."
            )

        Xs, ys, sample_domains = [], [], []
        domain_labels = {}

        # Sources keep their labels and get positive domain labels.
        # xxx(okachaiev): this is horribly inefficient, re-write when API is fixed
        for domain_name in as_sources:
            domain_id, X, y = self._get_domain_for_packing(domain_name)
            Xs.append(X)
            ys.append(y)
            sample_domains.append(np.full(np.shape(y), domain_id, dtype=int))
            domain_labels[domain_name] = domain_id

        # Targets get negative domain labels and, if requested, masked labels.
        dtype = None
        for domain_name in as_targets:
            domain_id, X, y = self._get_domain_for_packing(domain_name)
            if mask_target_labels:
                y, mask, dtype = self._mask_labels(y, X.shape[0], mask, dtype)
            Xs.append(X)
            ys.append(y)
            sample_domains.append(np.full(np.shape(y), -domain_id, dtype=int))
            domain_labels[domain_name] = -domain_id

        if not Xs:
            raise ValueError(
                "At least one source or target domain is required to pack data."
            )

        # xxx(okachaiev): so far this only works if source and target has the
        # same size
        if return_type == "auto":
            return_type = "tensor" if _is_tensor(Xs[0]) else "array"

        if return_type in ("array", "Bunch"):
            # For now Bunch is associated with numpy arrays
            X_all = np.concatenate(Xs)
            y_all = np.concatenate(ys)
            domain_all = np.concatenate(sample_domains)
            if return_type == "array":
                return (X_all, y_all, domain_all)
            return Bunch(
                X=X_all,
                y=y_all,
                sample_domain=domain_all,
                domain_names=domain_labels,
            )

        # From here on return_type is "tensor" or "DeepDADataset". Inputs may be
        # numpy arrays (or numpy mask arrays), so convert before concatenating.
        if return_type == "DeepDADataset" and not _IS_DEEPDADATASET_IMPORTED:
            raise ImportError(
                "torch and skorch are required to return data as DeepDADataset. "
                "Please install them to use this feature."
            )
        if not _IS_TORCH_IMPORTED:
            raise ImportError(
                "torch is required to return data as tensors. "
                "Please install torch to use this feature."
            )
        X_all = self._concat_tensors(Xs)
        y_all = self._concat_tensors(ys, device=X_all.device)
        domain_all = self._concat_tensors(sample_domains, device=X_all.device)
        if return_type == "tensor":
            return (X_all, y_all, domain_all)
        return DeepDADataset(X_all, y_all, domain_all, device=X_all.device)

    @deprecated()
    def pack_train(
        self,
        as_sources: List[str],
        as_targets: List[str],
        return_X_y: bool = True,
        mask: Union[None, int, float] = None,
    ) -> PackedDatasetType:
        """
        Aggregate source and target domains for training.

        .. warning::
            This method is deprecated and will be removed in future versions.
            Use :meth:`pack` with ``mask_target_labels=True`` instead.

        This method is equivalent to :meth:`pack` with
        ``mask_target_labels=True``. It masks the labels for target domains
        (with -1 or a custom mask value) so that they are not available during
        training, as required for domain adaptation scenarios.

        Parameters
        ----------
        as_sources : list of str
            List of domain names to be used as sources.
        as_targets : list of str
            List of domain names to be used as targets.
        return_X_y : bool, default=True
            If True, returns a tuple (X, y, sample_domain). Otherwise,
            returns a :class:`sklearn.utils.Bunch` object.
        mask : int or float, optional
            Value to mask labels at training time. If None, uses -1 for integers
            and np.nan for floats.

        Returns
        -------
        data : :class:`sklearn.utils.Bunch`
            Dictionary-like object with attributes X, y, sample_domain, domain_names.
        (X, y, sample_domain) : tuple if `return_X_y=True`
            Tuple of (data, target, sample_domain).
        """
        return self.pack(
            as_sources=as_sources,
            as_targets=as_targets,
            return_X_y=return_X_y,
            mask_target_labels=True,
            mask=mask,
        )

    @deprecated()
    def pack_test(
        self,
        as_targets: List[str],
        return_X_y: bool = True,
    ) -> PackedDatasetType:
        """
        Aggregate target domains for testing.

        .. warning::
            This method is deprecated and will be removed in future versions.
            Use :meth:`pack` with ``mask_target_labels=False`` instead.

        This method is equivalent to :meth:`pack` with only target domains
        and ``mask_target_labels=False``. Labels are not masked.

        Parameters
        ----------
        as_targets : list of str
            List of domain names to be used as targets.
        return_X_y : bool, default=True
            If True, returns a tuple (X, y, sample_domain). Otherwise,
            returns a :class:`sklearn.utils.Bunch` object.

        Returns
        -------
        data : :class:`sklearn.utils.Bunch`
            Dictionary-like object with attributes X, y, sample_domain, domain_names.
        (X, y, sample_domain) : tuple if `return_X_y=True`
            Tuple of (data, target, sample_domain).
        """
        return self.pack(
            as_sources=[],
            as_targets=as_targets,
            return_X_y=return_X_y,
            mask_target_labels=False,
        )

    def pack_lodo(
        self,
        return_X_y: bool = True,
        return_type: Literal[
            "auto", "array", "tensor", "DeepDADataset", "Bunch"
        ] = "auto",
    ) -> PackedDatasetType:
        """Packages all domains in a format compatible with the Leave-One-Domain-Out
        cross-validator (refer to :class:`~skada.model_selection.LeaveOneDomainOut`
        for more details). To enable the splitter's dynamic assignment of source
        and target domains, data from each domain is included in the output
        twice — once as a source and once as a target.

        Exercise caution when using this output for purposes other than its
        intended use, as this could lead to incorrect results and data leakage.

        Parameters
        ----------
        return_X_y : bool, default=True
            [DEPRECATED] Only used when `return_type` is "auto". When True,
            returns a tuple (X, y, sample_domain); otherwise returns a
            :class:`~sklearn.utils.Bunch` object with the structure described
            below.
        return_type : Literal["auto", "array", "tensor", "DeepDADataset", "Bunch"]
            The type of the returned data. If "auto", the deprecated
            `return_X_y` decides. If "array", returns numpy arrays. If "tensor",
            returns torch tensors. If "DeepDADataset", returns a
            :class:`~skada.deep.base.DeepDADataset`. If "Bunch", returns a
            :class:`~sklearn.utils.Bunch` object.

        Returns
        -------
        data : :class:`~sklearn.utils.Bunch`
            Dictionary-like object, with the following attributes.

            X: ArrayLike
                Samples from all sources and all targets given.
            y : ArrayLike
                Labels from all sources and all targets.
            sample_domain : ArrayLike
                The integer label for domain the sample was taken from.
                By convention, source domains have non-negative labels,
                and target domain label is always < 0.
            domain_names : dict
                The names of domains and associated domain labels.

        (X, y, sample_domain) : tuple if `return_X_y=True`
            Tuple of (data, target, sample_domain), see the description above.
        """
        domain_names = list(self.domain_names_)
        # `return_X_y` defaults to True and, when forwarded, overrides
        # `return_type` inside `pack`; an explicit `return_type` must win.
        return self.pack(
            as_sources=domain_names,
            as_targets=domain_names,
            return_X_y=return_X_y if return_type == "auto" else None,
            return_type=return_type,
            mask_target_labels=True,
        )

    def __str__(self) -> str:
        return f"DomainAwareDataset(domains={self._get_domain_representation()})"

    def __repr__(self) -> str:
        head = self.__str__()
        body = [f"Number of domains: {len(self.domains_)}"]
        body.append(f"Total size: {sum(len(tup[0]) for tup in self.domains_)}")
        output = "\n".join([head] + body)
        return output

    def _get_domain_representation(self, max_domains=5, max_length=50):
        """Return a short, possibly truncated, string listing the domain names."""
        domain_names = list(self.domain_names_.keys())
        if len(domain_names) <= max_domains:
            # If the number of domains is small, include all names
            domain_str = str(domain_names)
        else:
            # If the number of domains is large, truncate the list and add ellipsis
            truncated_domains = domain_names[:max_domains]
            domain_str = str(truncated_domains)[:-1] + ", ...]"
        # Truncate the string representation if it exceeds max_length
        if len(domain_str) > max_length:
            domain_str = domain_str[: max_length - 3] + "...]"
        return domain_str
# xxx(okachaiev): putting `domain_names` first argument
# so it's compatible with `partial`
def select_domain(
    domain_names: Dict[str, int],
    sample_domain: "ArrayLike",
    domains: Union[str, Iterable[str]],
) -> "ArrayLike":
    """Build a boolean mask selecting the samples of the given domain(s).

    Parameters
    ----------
    domain_names : dict
        Mapping from domain name to the integer label used in `sample_domain`.
    sample_domain : ArrayLike
        Integer domain label of each sample.
    domains : str or iterable of str
        Name(s) of the domain(s) to select.

    Returns
    -------
    mask : ArrayLike
        Boolean mask, True for samples whose label belongs to one of
        `domains`. All False when `domains` is empty.

    Raises
    ------
    KeyError
        If a name in `domains` is not a key of `domain_names`.
    """
    if isinstance(domains, str):
        domains = [domains]
    # xxx(okachaiev): this version is not the most efficient
    masks = [sample_domain == domain_names[domain] for domain in domains]
    if not masks:
        # `reduce` cannot fold an empty sequence: nothing is selected.
        return np.zeros(np.shape(sample_domain), dtype=bool)
    return reduce(np.logical_or, masks)