skada.JDOTClassifier
- skada.JDOTClassifier(base_estimator=None, alpha=0.5, metric='multinomial', n_iter_max=100, tol=1e-05, verbose=False, thr_weights=1e-06, **kwargs)
Joint Distribution Optimal Transport (JDOT) classifier proposed in [10].
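For orientation, the optimization problem from [10] can be written as below, in the paper's notation. Here γ ranges over the couplings Δ between the source and target empirical distributions, d is a distance in feature space, L is the label loss (the one selected by `metric`), and Ω(f) is a regularizer on the classifier f weighted by λ. This is the paper's formulation. In the paper, α weights only the feature term, and skada's `alpha` parameter may be parametrized differently.

```latex
\min_{f \in \mathcal{H},\ \gamma \in \Delta}\;
  \sum_{i,j} \gamma_{ij}\,\Big( \alpha\, d\big(x_i^s, x_j^t\big)
  + \mathcal{L}\big(y_i^s, f(x_j^t)\big) \Big) \;+\; \lambda\, \Omega(f)
```

The problem is solved by alternating two steps: with f fixed, solve an optimal transport problem for γ; with γ fixed, refit f on the target samples using the transported source labels.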
Warning
This estimator assumes that the loss function optimized by the base estimator is compatible with the given metric. For instance, if the metric is 'multinomial', the base estimator should optimize a cross-entropy loss (e.g. LogisticRegression with multi_class='multinomial'). If the metric is 'hinge', it should optimize a hinge loss (e.g. SVC with kernel='linear' and a one-versus-rest strategy). Any estimator that provides the necessary prediction functions can be used, but convergence of the fixed-point iterations is then not guaranteed and the behavior can be unpredictable.
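To make the compatibility requirement concrete, here is a minimal NumPy sketch of the two label costs. It is an illustration, not skada's code. The helper names `one_hot`, `multinomial_cost` and `hinge_cost` are hypothetical, and the hinge form is an assumption: a plain (non-squared) one-vs-rest hinge with ±1 label coding. Each entry C[i, j] compares the label of source sample i with the base estimator's output on target sample j. The estimator's training loss should therefore match the cost used to build the coupling.

```python
import numpy as np


def one_hot(y, n_classes):
    """Encode integer labels as one-hot rows."""
    Y = np.zeros((len(y), n_classes))
    Y[np.arange(len(y)), y] = 1.0
    return Y


def multinomial_cost(y_src, proba_tgt, eps=1e-16):
    """Cross-entropy cost C[i, j] = -sum_k Y_s[i, k] * log(P_t[j, k]).

    proba_tgt holds predict_proba-style outputs for the target samples.
    """
    Y = one_hot(y_src, proba_tgt.shape[1])
    return -Y @ np.log(proba_tgt + eps).T


def hinge_cost(y_src, scores_tgt):
    """One-vs-rest hinge cost C[i, j] = sum_k max(0, 1 - Y_s[i, k] * F_t[j, k]).

    Source labels are coded +1 for the true class and -1 otherwise (assumed);
    scores_tgt holds decision_function-style outputs for the target samples.
    """
    Y = 2.0 * one_hot(y_src, scores_tgt.shape[1]) - 1.0
    # Broadcast to (n_source, n_target, n_classes), then sum over classes.
    margins = 1.0 - Y[:, None, :] * scores_tgt[None, :, :]
    return np.maximum(margins, 0.0).sum(axis=2)


y_src = np.array([0, 1])
proba_tgt = np.array([[0.9, 0.1], [0.2, 0.8]])
scores_tgt = np.array([[2.0, -2.0], [-0.5, 0.5]])
C_ce = multinomial_cost(y_src, proba_tgt)
C_hinge = hinge_cost(y_src, scores_tgt)  # [[0, 3], [6, 1]]
```

A probabilistic classifier trained with cross-entropy produces exactly the quantity that `multinomial_cost` penalizes. A margin-based classifier trained with a hinge loss produces the scores that `hinge_cost` penalizes.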
- Parameters:
- base_estimator : object, default=None
The base estimator used for the classification task. It should solve a classification problem so that it matches the JDOT theoretical classification setting. Other estimators can be used, at the risk that the fixed-point iterations do not converge. If None, scikit-learn's LogisticRegression() is used.
- alpha : float, default=0.5
The trade-off parameter between the feature loss and the label loss in the OT cost.
- metric : str, default='multinomial'
The label loss used in the cost matrix: 'multinomial' for a cross-entropy cost (multinomial logistic regression) or 'hinge' for a hinge cost (SVM/SVC).
- n_iter_max : int, default=100
Maximum number of iterations of the JDOT alternating optimization.
- tol : float > 0, default=1e-5
Tolerance on the variation of the losses (OT and target label losses) used to stop the iterations.
- verbose : bool, default=False
If True, print the losses at each iteration.
- thr_weights : float, default=1e-6
The relative threshold for the weights.
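To show how alpha, n_iter_max and tol interact, here is a minimal NumPy sketch of the alternating scheme from [10]. It alternates two steps: fix the classifier and solve an OT problem on the joint cost, then fix the coupling and refit the classifier on transported labels. This is not skada's implementation, and several pieces are assumptions:

* entropic OT via Sinkhorn iterations stands in for exact OT;
* a hand-written weighted softmax regression stands in for the scikit-learn base estimator;
* the cost weighting alpha * feature + (1 - alpha) * label is assumed;
* the relative-change stopping rule is assumed;
* thr_weights is not modeled.

All function names are hypothetical.

```python
import numpy as np


def sinkhorn(a, b, M, reg=0.1, n_iter=500):
    """Entropic OT coupling between weights a and b (stand-in for exact OT)."""
    K = np.exp(-M / reg)
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)
        u = a / (K @ v)
    return u[:, None] * K * v[None, :]


def softmax(Z):
    Z = Z - Z.max(axis=1, keepdims=True)  # numerical stability
    E = np.exp(Z)
    return E / E.sum(axis=1, keepdims=True)


def add_bias(X):
    return np.hstack([X, np.ones((len(X), 1))])


def fit_softmax(X, Y_soft, w, n_steps=200, lr=0.1):
    """Weighted multinomial logistic regression on soft labels (gradient descent)."""
    Xb = add_bias(X)
    W = np.zeros((Xb.shape[1], Y_soft.shape[1]))
    for _ in range(n_steps):
        P = softmax(Xb @ W)
        # Gradient of the weighted mean cross-entropy with respect to W.
        W -= lr * Xb.T @ (w[:, None] * (P - Y_soft)) / w.sum()
    return W


def predict_proba(W, X):
    return softmax(add_bias(X) @ W)


def jdot_sketch(Xs, ys, Xt, n_classes, alpha=0.5, n_iter_max=100, tol=1e-5):
    ns, nt = len(Xs), len(Xt)
    a, b = np.full(ns, 1.0 / ns), np.full(nt, 1.0 / nt)  # uniform sample weights
    Ys = np.eye(n_classes)[ys]  # one-hot source labels
    # Squared Euclidean feature cost, rescaled to [0, 1].
    Mf = ((Xs[:, None, :] - Xt[None, :, :]) ** 2).sum(axis=2)
    Mf = Mf / Mf.max()
    W = fit_softmax(Xs, Ys, np.ones(ns))  # initialize on the labeled source
    losses = []
    for it in range(n_iter_max):
        # Step 1: classifier fixed -> joint cost and OT coupling.
        P = predict_proba(W, Xt)
        Ml = -Ys @ np.log(P + 1e-16).T  # cross-entropy label cost
        M = alpha * Mf + (1.0 - alpha) * Ml  # ASSUMED convex weighting
        G = sinkhorn(a, b, M / M.max())
        losses.append(float((G * M).sum()))
        # Step 2: coupling fixed -> refit on transported (barycentric) soft labels.
        wt = G.sum(axis=0)
        Yt = (G.T @ Ys) / wt[:, None]
        W = fit_softmax(Xt, Yt, wt)
        # ASSUMED stopping rule: small relative change of the OT loss.
        if it > 0 and abs(losses[-2] - losses[-1]) <= tol * abs(losses[-2]):
            break
    return W, losses
```

Because the cross-entropy is linear in the label vector, fitting on all pairs (x_j^t, y_i^s) with weights γ_ij is equivalent to fitting on the barycentric soft labels used here, weighted by the target marginal of the coupling. This keeps the refit step a single weighted classification problem.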
- Attributes:
- estimator_ : object
The fitted estimator.
- lst_loss_ot_ : list
The list of OT losses at each iteration.
- lst_loss_tgt_labels_ : list
The list of target labels losses at each iteration.
- sol_ : object
The solution of the OT problem.
References
- [10] N. Courty, R. Flamary, A. Habrard, A. Rakotomamonjy, Joint Distribution Optimal Transportation for Domain Adaptation, Neural Information Processing Systems (NIPS), 2017.