skada.model_selection.StratifiedDomainShuffleSplit
- class skada.model_selection.StratifiedDomainShuffleSplit(n_splits=10, *, test_size=None, train_size=None, random_state=None)
Stratified-Domain-Shuffle-Split cross-validator.
This cross-validation object returns stratified randomized folds. The folds are made by preserving the percentage of samples for each class and for each sample domain.
- Parameters:
- n_splits : int, default=10
Number of folds. Must be at least 2.
Examples
>>> import numpy as np
>>> from skada.model_selection import StratifiedDomainShuffleSplit
>>> X = np.ones((10, 2))
>>> y = np.array([-1, 0, 1, -1, 0, 1, -1, 0, 1, -1])
>>> sample_domain = np.array([-2, 1, 1, -2, 1, 1, -2, 1, 1, -2])
>>> da_shufflesplit = StratifiedDomainShuffleSplit(n_splits=2,
...     random_state=0, test_size=0.5)
>>> da_shufflesplit.get_n_splits(X, y, sample_domain)
2
>>> print(da_shufflesplit)
StratifiedDomainShuffleSplit(n_splits=2, random_state=0, test_size=0.5, train_size=None)
>>> for i, (train_index, test_index) in enumerate(
...     da_shufflesplit.split(X, y, sample_domain)
... ):
...     print(f"Fold {i}:")
...     print(f" Train: index={train_index}, "
...     f'''group={[[b.item(), a.item()]
...     for a, b in zip(y[train_index], sample_domain[train_index])
...     ]}''')
...     print(f" Test: index={test_index}, "
...     f'''group={[[b.item(), a.item()]
...     for a, b in zip(y[test_index], sample_domain[test_index])
...     ]}''')
Fold 0:
 Train: index=[0 6 1 8 2], group=[[-2, -1], [-2, -1], [1, 0], [1, 1], [1, 1]]
 Test: index=[4 9 7 5 3], group=[[1, 0], [-2, -1], [1, 0], [1, 1], [-2, -1]]
Fold 1:
 Train: index=[1 2 8 0 3], group=[[1, 0], [1, 1], [1, 1], [-2, -1], [-2, -1]]
 Test: index=[7 5 9 4 6], group=[[1, 0], [1, 1], [-2, -1], [1, 0], [-2, -1]]
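To illustrate what "preserving the percentage of samples for each class and for each sample domain" means, the sketch below reproduces the idea with scikit-learn alone: each (domain, class) pair is encoded as a single label, and `StratifiedShuffleSplit` stratifies on that label. This is an assumption about how such a split can be emulated, not skada's actual implementation, and the variable names `pairs`, `groups`, and `folds` are introduced here for illustration only.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

X = np.ones((10, 2))
y = np.array([-1, 0, 1, -1, 0, 1, -1, 0, 1, -1])
sample_domain = np.array([-2, 1, 1, -2, 1, 1, -2, 1, 1, -2])

# Assumption: treat every distinct (domain, class) pair as one stratum.
pairs = np.stack([sample_domain, y], axis=1)
_, groups = np.unique(pairs, axis=0, return_inverse=True)
groups = groups.ravel()  # flatten, since the inverse shape varies across NumPy versions

splitter = StratifiedShuffleSplit(n_splits=2, test_size=0.5, random_state=0)

# Stratify on the combined label rather than on y alone.
folds = list(splitter.split(X, groups))
for i, (train_index, test_index) in enumerate(folds):
    print(f"Fold {i}:")
    print(f" Train strata: {sorted(groups[train_index].tolist())}")
    print(f" Test strata: {sorted(groups[test_index].tolist())}")
```

Because the strata are the (domain, class) pairs, every pair present in the data shows up on both sides of each split here, which is the property the cross-validator above is described as preserving.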
- set_split_request(*, sample_domain: bool | None | str = '$UNCHANGED$') -> StratifiedDomainShuffleSplit
Configure whether metadata should be requested to be passed to the split method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
- True: metadata is requested, and passed to split if provided. The request is ignored if metadata is not provided.
- False: metadata is not requested and the meta-estimator will not pass it to split.
- None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
- str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
- Parameters:
- sample_domain : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED
Metadata routing for the sample_domain parameter in split.
- Returns:
- self : object
The updated object.