Coverage for src/edelweiss/reg_utils.py: 100%
44 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-18 17:09 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-18 17:09 +0000
1# Copyright (C) 2024 ETH Zurich
2# Institute for Particle Physics and Astrophysics
3# Author: Silvan Fischbacher
4# created: Sat Jan 27 2024
7from catboost import CatBoostRegressor
8from sklearn.ensemble import RandomForestRegressor
9from sklearn.gaussian_process import GaussianProcessRegressor
10from sklearn.gaussian_process.kernels import RBF
11from sklearn.linear_model import ElasticNet, Lasso, LinearRegression, Ridge
12from sklearn.neighbors import KNeighborsRegressor
13from sklearn.neural_network import MLPRegressor
14from sklearn.pipeline import Pipeline
15from sklearn.svm import SVR
16from sklearn.tree import DecisionTreeRegressor
17from xgboost import XGBRegressor
def get_regressor(regressor, scaler, **kwargs):
    """
    Build a pipeline that applies a scaler followed by the requested regressor.

    The regressor is selected by name and instantiated with ``kwargs``. For
    "GaussianProcess" an RBF kernel is used unless ``kwargs`` already contains
    a "kernel" entry.

    :param regressor: name of the regressor; one of "linear", "ridge", "lasso",
        "elasticnet", "knn", "svr", "RandomForest", "XGB", "CatBoost", "MLP",
        "DecisionTree", "NeuralNetwork" or "GaussianProcess" (case sensitive)
    :param scaler: scaler object, used as the first step of the pipeline
    :param kwargs: additional arguments passed to the regressor constructor
    :return: sklearn pipeline with the steps "scaler" and "reg"
    :raises: ValueError if regressor is not known
    """

    def neural_network():
        # Function-level import: the custom module is only imported when this
        # regressor is actually requested.
        from .custom_regs import NeuralNetworkRegressor

        return NeuralNetworkRegressor(**kwargs)

    def gaussian_process():
        # A kernel supplied by the caller wins; otherwise fall back to RBF.
        if "kernel" in kwargs:
            return GaussianProcessRegressor(**kwargs)
        return GaussianProcessRegressor(kernel=RBF(), **kwargs)

    # Zero-argument factories, so that no estimator is instantiated before the
    # name has been looked up successfully.
    factories = {
        "linear": lambda: LinearRegression(**kwargs),
        "ridge": lambda: Ridge(**kwargs),
        "lasso": lambda: Lasso(**kwargs),
        "elasticnet": lambda: ElasticNet(**kwargs),
        "knn": lambda: KNeighborsRegressor(**kwargs),
        "svr": lambda: SVR(**kwargs),
        "RandomForest": lambda: RandomForestRegressor(**kwargs),
        "XGB": lambda: XGBRegressor(**kwargs),
        "CatBoost": lambda: CatBoostRegressor(**kwargs),
        "MLP": lambda: MLPRegressor(**kwargs),
        "DecisionTree": lambda: DecisionTreeRegressor(**kwargs),
        "NeuralNetwork": neural_network,
        "GaussianProcess": gaussian_process,
    }

    build = factories.get(regressor)
    if build is None:
        raise ValueError(f"{regressor} not known")

    estimator = build()
    return Pipeline([("scaler", scaler), ("reg", estimator)])