Coverage for src/edelweiss/compatibility_utils.py: 100%
8 statements
coverage.py v7.10.1, created at 2025-07-31 10:21 +0000

# Copyright (C) 2025 ETH Zurich
# Institute for Particle Physics and Astrophysics
# Author: Silvan Fischbacher
# created: Wed Jul 30 2025

import contextlib

import numpy as np
import sklearn
from cosmic_toolbox import logger
from packaging import version
from sklearn.base import ClassifierMixin
from sklearn.base import is_classifier as original_is_classifier

LOGGER = logger.get_logger(__file__)


class CompatibleCalibratedClassifier:  # pragma: no cover
    """
    Compatibility wrapper for CalibratedClassifierCV to work with sklearn >= 1.6.

    This class bypasses sklearn's stricter validation while maintaining the same
    functionality.
    """

    def __init__(self, original_pipe):
        self.original_pipe = original_pipe
        self.classes_ = getattr(original_pipe, "classes_", np.array([0, 1]))
        self.n_features_in_ = getattr(original_pipe, "n_features_in_", None)
        # Forward the calibrated_classifiers_ attribute for TF model loading
        self.calibrated_classifiers_ = getattr(
            original_pipe, "calibrated_classifiers_", []
        )

    def predict(self, X):
        # Try the original pipeline first, fall back to manual aggregation
        try:
            return self.original_pipe.predict(X)
        except Exception:
            # Manual prediction by aggregating the calibrated classifiers
            predictions = []
            for cal_clf in self.original_pipe.calibrated_classifiers_:
                pred = cal_clf.estimator.predict(X)
                predictions.append(pred)
            # Use majority vote
            return np.array(
                [1 if np.mean(preds) > 0.5 else 0 for preds in zip(*predictions)]
            )

    def predict_proba(self, X):
        # Bypass sklearn validation by manually implementing calibrated prediction
        try:
            # Pre-allocate arrays for better performance
            n_samples = X.shape[0]
            n_calibrators = len(self.original_pipe.calibrated_classifiers_)
            predictions = np.zeros((n_calibrators, n_samples))

            for i, cal_clf in enumerate(self.original_pipe.calibrated_classifiers_):
                # Get raw predictions from the estimator pipeline
                if hasattr(cal_clf.estimator, "predict_proba"):
                    estimator_proba = cal_clf.estimator.predict_proba(X)
                    # Use the positive class probability for calibration
                    if estimator_proba.shape[1] == 2:
                        raw_pred = estimator_proba[:, 1]
                    else:
                        raw_pred = estimator_proba.flatten()
                else:
                    # Fall back to predict if predict_proba is not available
                    raw_pred = cal_clf.estimator.predict(X).astype(float)

                # Apply calibration - try different attribute names for different
                # sklearn versions
                if hasattr(cal_clf, "calibrator"):
                    calibrated_pred = cal_clf.calibrator.predict(
                        raw_pred.reshape(-1, 1)
                    ).flatten()
                elif hasattr(cal_clf, "calibrators"):
                    # For newer sklearn versions with calibrators (not calibrators_)
                    calibrated_pred = (
                        cal_clf.calibrators[0]
                        .predict(raw_pred.reshape(-1, 1))
                        .flatten()
                    )
                elif hasattr(cal_clf, "calibrators_"):
                    # For sklearn versions with calibrators_
                    calibrated_pred = (
                        cal_clf.calibrators_[0]
                        .predict(raw_pred.reshape(-1, 1))
                        .flatten()
                    )
                else:
                    # Fallback: use raw predictions
                    calibrated_pred = raw_pred

                predictions[i] = calibrated_pred

            # Average predictions across calibrated classifiers (vectorized)
            avg_pred = np.mean(predictions, axis=0)
            return np.column_stack([1 - avg_pred, avg_pred])

        except Exception as e:
            LOGGER.warning(f"Fallback prediction failed: {e}")
            # Ultimate fallback: return uniform probabilities
            return np.full((X.shape[0], 2), 0.5)

    def set_params(self, **params):
        return self.original_pipe.set_params(**params)

    def get_params(self, deep=True):
        return self.original_pipe.get_params(deep=deep)
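
# Example (minimal sketch; `calibrated_pipe` and `X_test` are placeholder names
# for an already fitted CalibratedClassifierCV-like object and a feature array):
#
#   wrapped = CompatibleCalibratedClassifier(calibrated_pipe)
#   proba = wrapped.predict_proba(X_test)   # shape (n_samples, 2)
#   labels = wrapped.predict(X_test)        # majority-vote fallback if needed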


def fix_calibrated_classifier_compatibility(pipe):  # pragma: no cover
    """
    Fix compatibility issues with CalibratedClassifierCV for scikit-learn >= 1.6.

    This function creates a compatibility wrapper that bypasses sklearn's stricter
    validation while maintaining the same functionality.

    :param pipe: The pipeline to fix
    :return: The original pipeline or a compatibility wrapper if needed
    """
    if not hasattr(pipe, "calibrated_classifiers_"):
        return pipe

    try:
        # For sklearn >= 1.6, only create the wrapper if we actually encounter
        # the error
        if version.parse(sklearn.__version__) >= version.parse("1.6"):
            # First, try to use the original pipeline to see if it works
            try:
                # Determine the number of features from the pipeline
                n_features = getattr(pipe, "n_features_in_", None)

                n_features_none = n_features is None
                has_calibrated_classifiers = (
                    hasattr(pipe, "calibrated_classifiers_")
                    and len(pipe.calibrated_classifiers_) > 0
                )
                if n_features_none and has_calibrated_classifiers:
                    first_estimator = pipe.calibrated_classifiers_[0].estimator
                    if hasattr(first_estimator, "named_steps"):
                        for _, step in first_estimator.named_steps.items():
                            if hasattr(step, "n_features_in_"):
                                n_features = step.n_features_in_
                                break

                # Test with a small dummy array to check if validation fails
                test_X = np.random.random((1, n_features))
                _ = pipe.predict_proba(test_X)
                # If we reach here, the original pipeline works fine
                LOGGER.debug(
                    "Original pipeline works with sklearn >= 1.6, no wrapper needed"
                )
                return pipe
            except ValueError as e:
                if "Pipeline should either be a classifier" in str(e):
                    LOGGER.info(
                        "Creating compatibility wrapper for sklearn >= 1.6 due to "
                        "validation error"
                    )
                    return CompatibleCalibratedClassifier(pipe)
                else:
                    # Different error, let it propagate
                    raise
            except Exception:
                # Other exceptions during testing, fall back to the wrapper
                LOGGER.info(
                    "Creating compatibility wrapper for sklearn >= 1.6 due to test "
                    "failure"
                )
                return CompatibleCalibratedClassifier(pipe)

        return pipe

    except Exception as e:
        LOGGER.warning(f"Could not create compatibility wrapper: {e}")
        return pipe
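
# Example (minimal sketch of a typical call site; `joblib`, "classifier.pkl" and
# `X_test` are placeholders, not part of this module):
#
#   pipe = joblib.load("classifier.pkl")
#   pipe = fix_calibrated_classifier_compatibility(pipe)
#   proba = pipe.predict_proba(X_test)  # original pipe or wrapper, same interface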


def patched_is_classifier(estimator):  # pragma: no cover
    """
    Enhanced is_classifier that works with custom classifiers in sklearn 1.6+.

    In sklearn 1.6+, is_classifier relies on the new tagging system which may not
    recognize custom classifiers properly. This function provides backward
    compatibility by falling back to ClassifierMixin detection.
    """
    # First try the original function
    result = original_is_classifier(estimator)
    if result:
        return True

    # If the original check fails, check for ClassifierMixin as a fallback
    if isinstance(estimator, ClassifierMixin):
        return True

    # For pipelines, check if the final step is a classifier
    if hasattr(estimator, "steps") and estimator.steps:
        final_step = estimator.steps[-1][1]
        if isinstance(final_step, ClassifierMixin):
            return True

    # For GridSearchCV, check the base estimator
    if hasattr(estimator, "estimator"):
        return patched_is_classifier(estimator.estimator)

    return False
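
# Example (minimal sketch; `MinimalClassifier` is purely illustrative): a bare
# ClassifierMixin subclass is accepted via the isinstance fallback even if the
# tag-based check in original_is_classifier does not flag it:
#
#   from sklearn.base import BaseEstimator
#
#   class MinimalClassifier(ClassifierMixin, BaseEstimator):
#       def fit(self, X, y):
#           self.classes_ = np.unique(y)
#           return self
#
#       def predict(self, X):
#           return np.zeros(len(X), dtype=int)
#
#   patched_is_classifier(MinimalClassifier())  # True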


def apply_sklearn_compatibility_patches():  # pragma: no cover
    """
    Apply compatibility patches for scikit-learn version differences.

    This should be called early in the import process to ensure compatibility
    across different sklearn versions.
    """
    sklearn_version = version.parse(sklearn.__version__)

    # Apply patches for sklearn 1.6+ where is_classifier behavior changed
    if sklearn_version >= version.parse("1.6.0"):
        LOGGER.info("Applying sklearn 1.6+ compatibility patches for is_classifier")

        # Patch the is_classifier function in relevant modules
        sklearn.base.is_classifier = patched_is_classifier

        # Try to patch sklearn.utils._response if it exists and is accessible
        try:
            sklearn.utils._response.is_classifier = patched_is_classifier
        except (ImportError, AttributeError):
            LOGGER.debug("sklearn.utils._response not available for patching")

        # Also patch in the calibration module if it exists
        with contextlib.suppress(ImportError, AttributeError):
            sklearn.calibration.is_classifier = patched_is_classifier
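
# Example (minimal sketch; the import path is assumed from the file location
# src/edelweiss/compatibility_utils.py, and `loaded_pipe` is a placeholder):
#
#   from edelweiss.compatibility_utils import (
#       apply_sklearn_compatibility_patches,
#       fix_calibrated_classifier_compatibility,
#   )
#
#   apply_sklearn_compatibility_patches()  # patch is_classifier before loading models
#   pipe = fix_calibrated_classifier_compatibility(loaded_pipe)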