# Copyright (C) 2024 ETH Zurich
# Institute for Particle Physics and Astrophysics
# Author: Silvan Fischbacher
# created: Sat Jan 27 2024


import os
import pickle

import joblib
import numpy as np
from cosmic_toolbox import arraytools as at
from cosmic_toolbox import file_utils, logger
from sklearn.utils.class_weight import compute_sample_weight

from edelweiss import clf_utils, reg_utils

LOGGER = logger.get_logger(__file__)


def load_regressor(path, name="regressor"):
    """
    Load a regressor from a given path.

    :param path: path to the folder containing the regressor
    :param name: the name of the regressor
    :return: the loaded regressor
    """
    with open(os.path.join(path, name + ".pkl"), "rb") as f:
        reg = pickle.load(f)
    reg.pipe = joblib.load(os.path.join(path, "model.pkl"))
    LOGGER.info(f"Regressor loaded from {path}")
    return reg
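
# A minimal loading sketch (the "runs/fit" path is hypothetical and is
# assumed to contain the "regressor.pkl" and "model.pkl" files written by
# Regressor.save):
#
#     reg = load_regressor("runs/fit")
#     y_pred = reg(X_new)  # __call__ is an alias for predict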


class Regressor:
    """
    Wrapper class for several regression models.

    :param scaler: the scaler to use for the regressor
    :param reg: the regressor to use
    :param cv: number of cross-validation folds; if 0, no cross-validation
        is performed (cross-validation is not implemented yet, see train)
    :param cv_scoring: the scoring method to use for cross-validation
    :param input_params: the names of the input parameters
    :param output_params: the names of the output parameters
    :param reg_kwargs: additional keyword arguments for the regressor
    """

    def __init__(
        self,
        scaler="standard",
        reg="linear",
        cv=0,
        cv_scoring="neg_mean_squared_error",
        input_params=None,
        output_params=None,
        **reg_kwargs,
    ):
        """
        Initialize the regressor.
        """
        self.scaler = scaler
        sc = clf_utils.get_scaler(scaler)
        self.y_scaler = clf_utils.get_scaler(scaler)
        self.reg = reg
        self.pipe = reg_utils.get_regressor(reg, sc, **reg_kwargs)
        self.cv = cv
        self.cv_scoring = cv_scoring
        self.reg_kwargs = reg_kwargs
        self.input_params = input_params
        self.output_params = output_params
        self._regressor = None
        self._scaler = None
        self.mad = None
        self.mse = None
        self.max_error = None
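
    # Construction sketch (only the defaults "standard" and "linear" are
    # confirmed by this module; other scaler and regressor names depend on
    # what clf_utils.get_scaler and reg_utils.get_regressor accept, and the
    # parameter names below are illustrative):
    #
    #     reg = Regressor(scaler="standard", reg="linear",
    #                     input_params=["mag", "size"],
    #                     output_params=["snr"])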

    def train(self, X, y, flat_param=None, **args):
        """
        Train the regressor.

        :param X: the training data
        :param y: the training labels
        :param flat_param: name of an input or output parameter whose
            distribution is flattened with balanced sample weights
        :param args: additional keyword arguments passed to the pipeline's
            fit method
        """
        X, y = self._check_if_recarray(X, y)

        if self.input_params is None:
            self.input_params = []
            for i in range(X.shape[1]):
                self.input_params.append(f"param_{i}")
        if self.output_params is None:
            self.output_params = []
            for i in range(y.shape[1]):
                self.output_params.append(f"param_{i}")

        LOGGER.info("Training regressor")
        LOGGER.info(f"Input parameters: {self.input_params}")
        LOGGER.info(f"Output parameters: {self.output_params}")
        LOGGER.info(f"Number of training samples: {X.shape[0]}")

        if self.cv > 1:
            LOGGER.error("Cross validation not implemented yet")

        y = self.y_scaler.fit_transform(y)

        sample_weight = None
        if flat_param is not None:
            flat_param_index = np.where(np.array(self.input_params) == flat_param)[0]
            if len(flat_param_index) == 0:
                # flat param is an output param
                flat_param_index = np.where(
                    np.array(self.output_params) == flat_param
                )[0]
                sample_weight = compute_sample_weight(
                    "balanced", y[:, flat_param_index]
                )
            else:
                # flat param is an input param
                flat_param_index = flat_param_index[0]
                sample_weight = compute_sample_weight(
                    "balanced", X[:, flat_param_index]
                )
            self.pipe.fit(X, y, reg__sample_weight=sample_weight, **args)
        else:
            self.pipe.fit(X, y, **args)

    fit = train
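
    # Training sketch (recarray or plain-array inputs both work via
    # _check_if_recarray; the flat_param name "mag" is illustrative and
    # triggers balanced sample weights so that parameter's distribution is
    # effectively flattened):
    #
    #     reg.train(X_train, y_train)                    # plain fit
    #     reg.train(X_train, y_train, flat_param="mag")  # weighted fit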

    def _predict(self, X):
        """
        Predict the output from the input.

        :param X: the input data
        :return: the predicted output as an array
        """
        X, _ = self._check_if_recarray(X, None)
        y = np.atleast_2d(self.pipe.predict(X))
        return self.y_scaler.inverse_transform(y)

    def predict(self, X):
        """
        Predict the output from the input.

        :param X: the input data
        :return: the predicted output as a recarray
        """
        y = self._predict(X)
        return at.arr2rec(y, names=self.output_params)

    __call__ = predict
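
    # Prediction sketch: both forms are equivalent and return a recarray
    # whose fields are the output parameter names:
    #
    #     y_pred = reg.predict(X_new)
    #     y_pred = reg(X_new)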

    def test(self, X, y):
        """
        Test the regressor and store the MAD, MSE and max error (absolute
        and relative) per output parameter as attributes.

        :param X: the test data
        :param y: the test labels
        """
        X, y = self._check_if_recarray(X, y)
        LOGGER.info("Testing regressor")
        LOGGER.info(f"Number of test samples: {X.shape[0]}")
        y_pred = self._predict(X)
        mad = np.mean(np.abs(y - y_pred), axis=0)
        mse = np.mean((y - y_pred) ** 2, axis=0)
        max_error = np.max(np.abs(y - y_pred), axis=0)
        self.mad = at.arr2rec(mad, names=self.output_params)
        self.mse = at.arr2rec(mse, names=self.output_params)
        self.max_error = at.arr2rec(max_error, names=self.output_params)
        LOGGER.info(f"max MAD: {max(mad)}")
        LOGGER.info(f"max MSE: {max(mse)}")
        LOGGER.info(f"max Max error: {max(max_error)}")

        relative_mad = mad / np.mean(np.abs(y), axis=0)
        relative_mse = mse / np.mean(y**2, axis=0)
        relative_max_error = max_error / np.max(np.abs(y), axis=0)
        self.relative_mad = at.arr2rec(relative_mad, names=self.output_params)
        self.relative_mse = at.arr2rec(relative_mse, names=self.output_params)
        self.relative_max_error = at.arr2rec(
            relative_max_error, names=self.output_params
        )
        LOGGER.info(f"max relative MAD: {max(relative_mad)}")
        LOGGER.info(f"max relative MSE: {max(relative_mse)}")
        LOGGER.info(f"max relative Max error: {max(relative_max_error)}")

    def save(self, path, name="regressor"):
        """
        Save the regressor to a given path.

        :param path: the path where to save the regressor
        :param name: the name of the regressor
        """
        file_utils.robust_makedirs(path)
        joblib.dump(self.pipe, os.path.join(path, "model.pkl"))
        with open(os.path.join(path, name + ".pkl"), "wb") as f:
            pickle.dump(self, f)
        LOGGER.info(f"Regressor saved to {path}")

    def _check_if_recarray(self, X, y=None):
        """
        Convert recarray inputs to plain arrays, storing or checking the
        field names against the input/output parameters; plain arrays are
        returned unchanged.
        """
        try:
            X, x_names = at.rec2arr(X, return_names=True)
            if self.input_params is None:
                self.input_params = x_names
            if y is not None:
                y, y_names = at.rec2arr(y, return_names=True)
                if self.output_params is None:
                    self.output_params = y_names
        except Exception:
            # is already a normal array
            return X, y
        assert self.input_params == x_names, "Input parameters do not match"
        if y is not None:
            assert self.output_params == y_names, "Output parameters do not match"
        return X, y