Coverage for src/edelweiss/compatibility_utils.py: 100%

8 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2025-07-31 10:21 +0000

1# Copyright (C) 2025 ETH Zurich 

2# Institute for Particle Physics and Astrophysics 

3# Author: Silvan Fischbacher 

4# created: Wed Jul 30 2025 

5 

6import contextlib 

7 

8import numpy as np 

9import sklearn 

10from cosmic_toolbox import logger 

11from packaging import version 

12from sklearn.base import ClassifierMixin 

13from sklearn.base import is_classifier as original_is_classifier 

14 

# Module-wide logger obtained from cosmic_toolbox; keyed on __file__.
LOGGER = logger.get_logger(__file__)

16 

17 

class CompatibleCalibratedClassifier:  # pragma: no cover
    """
    Compatibility wrapper for CalibratedClassifierCV to work with sklearn >= 1.6.

    Bypasses sklearn's stricter validation while keeping the wrapped
    pipeline's behavior: it delegates to the original pipeline where
    possible and otherwise aggregates the calibrated classifiers manually.
    """

    def __init__(self, original_pipe):
        # Keep a handle on the wrapped pipeline and mirror the attributes
        # that downstream code (e.g. TF model loading) expects to find.
        self.original_pipe = original_pipe
        self.classes_ = getattr(original_pipe, "classes_", np.array([0, 1]))
        self.n_features_in_ = getattr(original_pipe, "n_features_in_", None)
        self.calibrated_classifiers_ = getattr(
            original_pipe, "calibrated_classifiers_", []
        )

    def predict(self, X):
        """Predict class labels; fall back to a majority vote on failure."""
        try:
            return self.original_pipe.predict(X)
        except Exception:
            # Delegation failed: take the per-sample mean of the hard
            # predictions of the underlying estimators and threshold at 0.5.
            per_clf = [
                clf.estimator.predict(X)
                for clf in self.original_pipe.calibrated_classifiers_
            ]
            return np.array(
                [int(np.mean(sample) > 0.5) for sample in zip(*per_clf)]
            )

    def predict_proba(self, X):
        """Return calibrated class probabilities, bypassing sklearn validation."""
        try:
            cal_clfs = self.original_pipe.calibrated_classifiers_
            # Pre-allocate one row of positive-class scores per calibrator.
            scores = np.zeros((len(cal_clfs), X.shape[0]))
            for row, cal_clf in enumerate(cal_clfs):
                raw = self._raw_score(cal_clf.estimator, X)
                scores[row] = self._calibrate(cal_clf, raw)
            # Average across calibrated classifiers (vectorized).
            positive = np.mean(scores, axis=0)
            return np.column_stack([1 - positive, positive])
        except Exception as e:
            LOGGER.warning(f"Fallback prediction failed: {e}")
            # Ultimate fallback: uniform probabilities for both classes.
            return np.full((X.shape[0], 2), 0.5)

    @staticmethod
    def _raw_score(estimator, X):
        # Positive-class score from the estimator pipeline; hard labels as
        # a last resort when predict_proba is unavailable.
        if hasattr(estimator, "predict_proba"):
            proba = estimator.predict_proba(X)
            if proba.shape[1] == 2:
                return proba[:, 1]
            return proba.flatten()
        return estimator.predict(X).astype(float)

    @staticmethod
    def _calibrate(cal_clf, raw):
        # The calibrator attribute name differs across sklearn versions;
        # probe them in order and pass raw scores through if none exists.
        if hasattr(cal_clf, "calibrator"):
            calibrator = cal_clf.calibrator
        elif hasattr(cal_clf, "calibrators"):
            calibrator = cal_clf.calibrators[0]
        elif hasattr(cal_clf, "calibrators_"):
            calibrator = cal_clf.calibrators_[0]
        else:
            return raw
        return calibrator.predict(raw.reshape(-1, 1)).flatten()

    def set_params(self, **params):
        """Delegate parameter setting to the wrapped pipeline."""
        return self.original_pipe.set_params(**params)

    def get_params(self, deep=True):
        """Delegate parameter retrieval to the wrapped pipeline."""
        return self.original_pipe.get_params(deep=deep)

110 

111 

def fix_calibrated_classifier_compatibility(pipe):  # pragma: no cover
    """
    Fix compatibility issues with CalibratedClassifierCV for scikit-learn >= 1.6.

    This function creates a compatibility wrapper that bypasses sklearn's stricter
    validation while maintaining the same functionality.

    :param pipe: The pipeline to fix
    :return: The original pipeline or a compatibility wrapper if needed
    """
    if not hasattr(pipe, "calibrated_classifiers_"):
        return pipe

    try:
        # For sklearn >= 1.6, only create the wrapper if we actually encounter the error
        if version.parse(sklearn.__version__) >= version.parse("1.6"):
            # First, try to use the original pipeline to see if it works
            try:
                # Determine number of features from the pipeline
                n_features = getattr(pipe, "n_features_in_", None)

                # Fall back to the first calibrated classifier's pipeline steps
                # when the pipe itself does not expose n_features_in_.
                # (was: bitwise `&` on booleans; `and` is the correct operator)
                if n_features is None and len(pipe.calibrated_classifiers_) > 0:
                    first_estimator = pipe.calibrated_classifiers_[0].estimator
                    if hasattr(first_estimator, "named_steps"):
                        for step in first_estimator.named_steps.values():
                            if hasattr(step, "n_features_in_"):
                                n_features = step.n_features_in_
                                break

                if n_features is None:
                    # No usable feature count -> cannot build a probe input.
                    # Be conservative and wrap (same net outcome as the
                    # previous implicit TypeError path, but explicit).
                    LOGGER.info(
                        "Creating compatibility wrapper for sklearn >= 1.6 due to test "
                        "failure"
                    )
                    return CompatibleCalibratedClassifier(pipe)

                # Test with a small dummy array to check if validation fails
                test_X = np.random.random((1, n_features))
                _ = pipe.predict_proba(test_X)
                # If we reach here, the original pipeline works fine
                LOGGER.debug(
                    "Original pipeline works with sklearn >= 1.6, no wrapper needed"
                )
                return pipe
            except ValueError as e:
                if "Pipeline should either be a classifier" in str(e):
                    LOGGER.info(
                        "Creating compatibility wrapper for sklearn >= 1.6 due to "
                        "validation error"
                    )
                    return CompatibleCalibratedClassifier(pipe)
                else:
                    # Different error, let it propagate (caught by outer handler)
                    raise
            except Exception:
                # Other exceptions during testing, fall back to wrapper
                LOGGER.info(
                    "Creating compatibility wrapper for sklearn >= 1.6 due to test "
                    "failure"
                )
                return CompatibleCalibratedClassifier(pipe)

        return pipe

    except Exception as e:
        LOGGER.warning(f"Could not create compatibility wrapper: {e}")
        return pipe

177 

178 

def patched_is_classifier(estimator):  # pragma: no cover
    """
    Enhanced is_classifier that works with custom classifiers in sklearn 1.6+.

    In sklearn 1.6+, is_classifier relies on the new tagging system which may not
    recognize custom classifiers properly. This function provides backward
    compatibility by falling back to ClassifierMixin detection.
    """
    # The stock sklearn check takes precedence.
    if original_is_classifier(estimator):
        return True

    # Fallback: anything inheriting ClassifierMixin counts as a classifier.
    if isinstance(estimator, ClassifierMixin):
        return True

    # Pipelines: inspect the final step.
    steps = getattr(estimator, "steps", None)
    if steps and isinstance(steps[-1][1], ClassifierMixin):
        return True

    # Meta-estimators (e.g. GridSearchCV): recurse into the base estimator.
    if hasattr(estimator, "estimator"):
        return patched_is_classifier(estimator.estimator)

    return False

207 

208 

209def apply_sklearn_compatibility_patches(): # pragma: no cover 

210 """ 

211 Apply compatibility patches for scikit-learn version differences. 

212 

213 This should be called early in the import process to ensure compatibility 

214 across different sklearn versions. 

215 """ 

216 sklearn_version = version.parse(sklearn.__version__) 

217 

218 # Apply patches for sklearn 1.6+ where is_classifier behavior changed 

219 if sklearn_version >= version.parse("1.6.0"): 

220 LOGGER.info("Applying sklearn 1.6+ compatibility patches for is_classifier") 

221 

222 # Patch the is_classifier function in relevant modules 

223 sklearn.base.is_classifier = patched_is_classifier 

224 

225 # Try to patch sklearn.utils._response if it exists and is accessible 

226 try: 

227 sklearn.utils._response.is_classifier = patched_is_classifier 

228 except (ImportError, AttributeError): 

229 LOGGER.debug("sklearn.utils._response not available for patching") 

230 

231 # Also patch in calibration module if it exists 

232 with contextlib.suppress(ImportError, AttributeError): 

233 sklearn.calibration.is_classifier = patched_is_classifier