Coverage for src/edelweiss/reg_utils.py: 100%

44 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-18 17:09 +0000

1# Copyright (C) 2024 ETH Zurich 

2# Institute for Particle Physics and Astrophysics 

3# Author: Silvan Fischbacher 

4# created: Sat Jan 27 2024 

5 

6 

7from catboost import CatBoostRegressor 

8from sklearn.ensemble import RandomForestRegressor 

9from sklearn.gaussian_process import GaussianProcessRegressor 

10from sklearn.gaussian_process.kernels import RBF 

11from sklearn.linear_model import ElasticNet, Lasso, LinearRegression, Ridge 

12from sklearn.neighbors import KNeighborsRegressor 

13from sklearn.neural_network import MLPRegressor 

14from sklearn.pipeline import Pipeline 

15from sklearn.svm import SVR 

16from sklearn.tree import DecisionTreeRegressor 

17from xgboost import XGBRegressor 

18 

19 

def get_regressor(regressor, scaler, **kwargs):
    """
    Build a scikit-learn pipeline that scales the input and then applies a regressor.

    Accepted regressor names (matched exactly, case-sensitive): "linear", "ridge",
    "lasso", "elasticnet", "knn", "svr", "RandomForest", "XGB", "CatBoost", "MLP",
    "DecisionTree", "NeuralNetwork" and "GaussianProcess".

    :param regressor: name of the regressor (see list above)
    :param scaler: scaler object, used as the "scaler" step of the pipeline
    :param kwargs: keyword arguments forwarded to the regressor constructor;
        for "GaussianProcess" an RBF kernel is used unless "kernel" is given
    :return: sklearn Pipeline with steps ("scaler", scaler) and ("reg", regressor)
    :raises ValueError: if the regressor name is not known
    """

    def _neural_network(**params):
        # imported here so .custom_regs is only loaded when this regressor is requested
        from .custom_regs import NeuralNetworkRegressor

        return NeuralNetworkRegressor(**params)

    def _gaussian_process(**params):
        # respect a caller-supplied kernel, otherwise fall back to a default RBF()
        if "kernel" in params:
            return GaussianProcessRegressor(**params)
        return GaussianProcessRegressor(kernel=RBF(), **params)

    # ordered (name, constructor) pairs; scanned in order with ``regressor == name``
    constructors = (
        ("linear", LinearRegression),
        ("ridge", Ridge),
        ("lasso", Lasso),
        ("elasticnet", ElasticNet),
        ("knn", KNeighborsRegressor),
        ("svr", SVR),
        ("RandomForest", RandomForestRegressor),
        ("XGB", XGBRegressor),
        ("CatBoost", CatBoostRegressor),
        ("MLP", MLPRegressor),
        ("DecisionTree", DecisionTreeRegressor),
        ("NeuralNetwork", _neural_network),
        ("GaussianProcess", _gaussian_process),
    )

    for name, constructor in constructors:
        if regressor == name:
            model = constructor(**kwargs)
            break
    else:
        raise ValueError(f"{regressor} not known")

    return Pipeline([("scaler", scaler), ("reg", model)])