Coverage for src / cosmic_toolbox / MultiInterp.py: 60%

181 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-31 12:38 +0000

1""" 

2Multi-method interpolation framework. 

3 

4Provides flexible N-dimensional interpolation with multiple backend methods 

5including nearest neighbors, radial basis functions, linear interpolation, 

6and machine learning regressors. 

7 

8author: Tomasz Kacprzak 

9""" 

10 

11import sys 

12import warnings 

13 

14import numpy as np 

15from scipy.interpolate import LinearNDInterpolator, NearestNDInterpolator, Rbf 

16from sklearn.ensemble import RandomForestRegressor 

17from sklearn.neighbors import BallTree, KNeighborsRegressor, RadiusNeighborsRegressor 

18from sklearn.preprocessing import MinMaxScaler 

19 

20from cosmic_toolbox.logger import get_logger 

21 

# Silence noisy warnings from the numerical backends (scipy/sklearn emit
# DeprecationWarning/RuntimeWarning during fitting); show UserWarning once.
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("once", category=UserWarning)

# Module-level logger named after this file.
LOGGER = get_logger(__file__)

26 

27 

class Rbft:
    """
    Radial Basis Function interpolator with bounds checking.

    Wraps scipy's Rbf with automatic coordinate scaling and bounds checking.
    Points outside the training data bounds return -inf.

    :param points: Training points, shape (n_samples, n_dims).
    :type points: numpy.ndarray
    :param values: Training values, shape (n_samples,).
    :type values: numpy.ndarray
    :param kw_rbf: Additional keyword arguments passed to scipy.interpolate.Rbf.
    """

    def __init__(self, points, values, **kw_rbf):
        self.points = points
        self.values = values
        # bounds are kept in the original (unscaled) coordinates
        self.bounds = [np.min(points, axis=0), np.max(points, axis=0)]
        self.scaler = MinMaxScaler()
        self.scaler.fit(self.points)
        self.points = self.scaler.transform(self.points)
        # Rbf takes one positional array per coordinate, then the values
        self.interp = Rbf(*list(self.points.T), self.values, **kw_rbf)

    def __call__(self, points, **kw_rbf):
        """
        Evaluate the interpolator at given points.

        :param points: Points at which to interpolate, shape (n_points, n_dims).
        :type points: numpy.ndarray
        :param kw_rbf: Additional keyword arguments (unused).
        :return: Interpolated values. Out-of-bounds points return -inf.
        :rtype: numpy.ndarray
        """
        inside = self._in_bounds(points)
        result = np.full(len(points), -np.inf)
        scaled = self.scaler.transform(points)
        result[inside] = self.interp(*list(scaled[inside].T))
        return result

    def _in_bounds(self, x):
        """
        Check if points are strictly within training data bounds.

        :param x: Points to check, shape (n_points, n_dims).
        :type x: numpy.ndarray
        :return: Boolean array indicating which points are in bounds.
        :rtype: numpy.ndarray
        """
        lo, hi = self.bounds
        return np.all((x > lo) & (x < hi), axis=1)

79 

80 

def query_split(X, tree, k, n_proc, n_per_batch=100000):
    """
    Query BallTree in parallel using multiprocessing.

    Splits the query points across multiple processes for parallel execution.
    Each worker additionally processes its chunk in memory-limited batches
    via query_batch.

    :param X: Query points, shape (n_points, n_dims).
    :type X: numpy.ndarray
    :param tree: BallTree instance to query.
    :type tree: sklearn.neighbors.BallTree
    :param k: Number of nearest neighbors to find.
    :type k: int
    :param n_proc: Number of parallel processes to use.
    :type n_proc: int
    :param n_per_batch: Per-worker batch size passed to query_batch
        (controls peak memory). Defaults to 100000.
    :type n_per_batch: int
    :return: Tuple of (distances, indices) arrays.
    :rtype: tuple
    """
    from functools import partial
    from multiprocessing import Pool

    nx = X.shape[0]
    # points handed to each worker process
    n_per_proc = int(np.ceil(nx / n_proc))
    LOGGER.info(
        f"querying BallTree with a pool for n_grid={nx} "
        f"n_proc={n_proc} n_per_batch={n_per_proc} n_neighbors={k}"
    )
    X_chunks = [X[(i * n_per_proc) : (i + 1) * n_per_proc, :] for i in range(n_proc)]

    f = partial(query_batch, tree=tree, k=k, n_per_batch=n_per_batch)
    with Pool(n_proc) as pool:
        list_y = pool.map(f, X_chunks)

    # reassemble worker results in chunk order
    distances = np.concatenate([res[0] for res in list_y])
    indices = np.concatenate([res[1] for res in list_y])
    return distances, indices

115 

116 

def query_batch(X, tree, k=100, n_per_batch=10000):
    """
    Query BallTree in batches to manage memory usage.

    :param X: Query points, shape (n_points, n_dims).
    :type X: numpy.ndarray
    :param tree: BallTree instance to query.
    :type tree: sklearn.neighbors.BallTree
    :param k: Number of nearest neighbors to find. Defaults to 100.
    :type k: int
    :param n_per_batch: Number of points per batch. Defaults to 10000.
    :type n_per_batch: int
    :return: Tuple of (distances, indices) arrays.
    :rtype: tuple
    """
    n_total = X.shape[0]
    n_batches = int(np.ceil(n_total / n_per_batch))
    # preallocate the full output, filled batch by batch
    all_dist = np.zeros([n_total, k])
    all_ind = np.zeros([n_total, k], dtype=np.int64)
    for b in range(n_batches):
        start = b * n_per_batch
        stop = start + n_per_batch
        X_batch = X[start:stop, :]
        if len(X_batch) > 0:
            dist_b, ind_b = tree.query(X_batch, k=k)
            all_ind[start:stop, :] = ind_b
            all_dist[start:stop, :] = dist_b
        LOGGER.info(f"batch={b:>6}/{n_batches:>6}")

    return all_dist, all_ind

146 

147 

def predict_with_neighbours(y, ind, dist):
    """
    Predict values using inverse-distance weighted average of neighbors.

    Zero distances (query point coincides with a training point) are given
    a large finite weight (1e10) so the matching neighbor dominates the
    average without producing nan from an infinite weight.

    :param y: Training values, shape (n_samples,).
    :type y: numpy.ndarray
    :param ind: Neighbor indices, shape (n_points, n_neighbors).
    :type ind: numpy.ndarray
    :param dist: Neighbor distances, shape (n_points, n_neighbors).
    :type dist: numpy.ndarray
    :return: Predicted values, shape (n_points,).
    :rtype: numpy.ndarray
    """
    yn = y[ind]
    with np.errstate(divide="ignore"):
        wn = 1.0 / dist
    # BUGFIX: a zero distance makes 1/dist infinite and np.average return
    # nan; the previous guard tested the weight (wn == 0) instead of the
    # distance so it never fired. Cap non-finite weights at a large value.
    wn[~np.isfinite(wn)] = 1e10
    yi = np.average(yn, weights=wn, axis=1)

    return yi

167 

168 

def predict_knn_balltree(Xi, X, y, n_neighbors, tree):
    """
    Predict using k-nearest neighbors with BallTree.

    :param Xi: Query points, shape (n_points, n_dims).
    :type Xi: numpy.ndarray
    :param X: Training points (unused, neighbors come from the tree).
    :type X: numpy.ndarray
    :param y: Training values, shape (n_samples,).
    :type y: numpy.ndarray
    :param n_neighbors: Number of neighbors to use.
    :type n_neighbors: int
    :param tree: BallTree instance for neighbor queries.
    :type tree: sklearn.neighbors.BallTree
    :return: Predicted values, shape (n_points,).
    :rtype: numpy.ndarray
    """
    n_points, n_dims = Xi.shape
    distances, indices = tree.query(Xi, k=n_neighbors)
    return predict_with_neighbours(y, indices, distances)

190 

191 

def predict_knn_linear(Xi, X, y, n_neighbors, tree):
    """
    Predict using local linear interpolation of nearest neighbors.

    For each query point, finds k-nearest neighbors and fits a local
    linear interpolator using those neighbors.

    :param Xi: Query points, shape (n_points, n_dims).
    :type Xi: numpy.ndarray
    :param X: Training points, shape (n_samples, n_dims).
    :type X: numpy.ndarray
    :param y: Training values, shape (n_samples,).
    :type y: numpy.ndarray
    :param n_neighbors: Number of neighbors to use for local interpolation.
    :type n_neighbors: int
    :param tree: BallTree instance for neighbor queries.
    :type tree: sklearn.neighbors.BallTree
    :return: Predicted values, shape (n_points,). Points outside the local
        convex hull of their neighbors return -inf.
    :rtype: numpy.ndarray
    """
    nx, nd = Xi.shape
    dist, ind = tree.query(Xi, k=n_neighbors)
    X_nearest = X[ind, :]
    y_nearest = y[ind]
    yi = np.full(nx, -np.inf)
    n_nan = 0

    for i in range(nx):
        # fit a linear interpolator on this point's neighborhood only
        interp = LinearNDInterpolator(X_nearest[i, :], y_nearest[i, :])
        val = np.ravel(interp(Xi[i, :]))[0]
        if np.isfinite(val):
            yi[i] = val
        else:
            # BUGFIX: previously the nan from scipy overwrote the -inf
            # initialization, breaking the documented -inf contract for
            # out-of-hull points; keep -inf and only count the failure
            n_nan += 1
        sys.stdout.write(f"\r{i}/{nx} {n_nan}")

    return yi

229 

230 

class MultiInterp:
    """
    Multi-method N-dimensional interpolator.

    Provides a unified interface to multiple interpolation methods with
    automatic coordinate scaling and bounds checking.

    :param X: Training points, shape (n_samples, n_dims).
    :type X: numpy.ndarray
    :param y: Training values, shape (n_samples,).
    :type y: numpy.ndarray
    :param method: Interpolation method. Options:
        - 'nn': Nearest neighbor (scipy NearestNDInterpolator)
        - 'linear': Linear interpolation (scipy LinearNDInterpolator)
        - 'rbf': Radial basis function (scipy Rbf)
        - 'rbft': Radial basis function with bounds (Rbft class)
        - 'knn' or 'knn_regression': K-neighbors regression (sklearn)
        - 'rnn_regression': Radius neighbors regression (sklearn)
        - 'knn_linear': Local linear interpolation
        - 'knn_balltree': K-neighbors with BallTree
        - 'random_forest': Random forest regressor (sklearn)
    :type method: str
    :param kw: Additional keyword arguments passed to the underlying interpolator.

    Example
    -------
    >>> import numpy as np
    >>> from cosmic_toolbox.MultiInterp import MultiInterp
    >>> X = np.random.rand(100, 2)
    >>> y = np.sin(X[:, 0]) + np.cos(X[:, 1])
    >>> interp = MultiInterp(X, y, method='nn')
    >>> Xi = np.random.rand(10, 2)
    >>> yi = interp(Xi)
    """

    def __init__(self, X, y, method="nn", **kw):
        self.X = X.copy()
        self.y = y.copy()
        # bounds are stored in the original (unscaled) coordinates; the
        # scaler maps training data to the unit cube for the backends
        self.bounds = [np.min(X, axis=0), np.max(X, axis=0)]
        self.scaler = MinMaxScaler()
        self.scaler.fit(self.X)
        self.X = self.scaler.transform(self.X)
        self.interp = None
        self.method = method
        self.kw = kw
        self.init_interp(**kw)

    def __call__(self, Xi, **kw):
        """
        Evaluate the interpolator at given points.

        :param Xi: Points at which to interpolate, shape (n_points, n_dims).
        :type Xi: numpy.ndarray
        :param kw: Additional keyword arguments passed to underlying interpolator.
        :return: Interpolated values. Out-of-bounds points return -inf.
        :rtype: numpy.ndarray
        :raises ValueError: If the method is unknown.
        """
        yi = np.zeros(len(Xi))
        select = self._in_bounds(Xi)
        yi[~select] = -np.inf

        Xi = self.scaler.transform(Xi.copy())

        method = self.method.lower()

        if method == "rbf":
            yi[select] = self.interp(*list(Xi[select, :].T), **kw)

        elif method == "rbft":
            yi[select] = self.interp(Xi[select, :], **kw)

        elif method in ["knn", "knn_regression", "rnn_regression"]:
            # BUGFIX: 'knn' is accepted as an alias by init_interp but was
            # previously rejected here, so method='knn' instances raised on
            # every evaluation
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                yi[select] = self.interp.predict(Xi[select, :], **kw)

        elif method in ["nn", "linear"]:
            yi[select] = self.interp(Xi[select, :], **kw)

        elif method == "knn_linear":
            yi[select] = predict_knn_linear(
                Xi[select, :],
                self.X,
                self.y,
                self.n_neighbors,
                self.interp,
                **kw,
            )

        elif method == "random_forest":
            yi[select] = self.interp.predict(Xi[select, :], **kw)

        elif method == "knn_balltree":
            yi[select] = predict_knn_balltree(
                Xi[select, :],
                self.X,
                self.y,
                self.n_neighbors,
                self.interp,
                **kw,
            )

        else:
            raise ValueError(f"unknown interp method {self.method}")

        return yi

    def _in_bounds(self, X):
        """
        Check if points are strictly within training data bounds.

        :param X: Points to check, shape (n_points, n_dims).
        :type X: numpy.ndarray
        :return: Boolean array indicating which points are in bounds.
        :rtype: numpy.ndarray
        """
        return np.all(self.bounds[0] < X, axis=1) & np.all(self.bounds[1] > X, axis=1)

    def init_interp(self, **kw):
        """
        Initialize the underlying interpolator based on method.

        :param kw: Keyword arguments passed to the underlying interpolator.
        :raises ValueError: If method is unknown.
        """
        method = self.method.lower()

        if method == "rbf":
            self.interp = Rbf(*list(self.X.T), self.y, **kw)

        elif method == "rbft":
            self.interp = Rbft(points=self.X, values=self.y, **kw)

        elif method in ("knn", "knn_regression"):
            self.interp = KNeighborsRegressor(**kw)
            self.interp.fit(self.X, self.y)

        elif method == "rnn_regression":
            # data is scaled to the unit cube, so a fixed default radius
            # of 0.1 is meaningful regardless of the original units
            kw.setdefault("radius", 0.1)
            self.interp = RadiusNeighborsRegressor(**kw)
            self.interp.fit(self.X, self.y)

        elif method == "nn":
            self.interp = NearestNDInterpolator(self.X, self.y, **kw)

        elif method == "linear":
            self.interp = LinearNDInterpolator(self.X, self.y, **kw)

        elif method in ("knn_linear", "knn_balltree"):
            # default neighborhood size: 3 neighbors per dimension
            self.n_neighbors = kw.pop("n_neighbors", self.X.shape[1] * 3)
            self.interp = BallTree(self.X, **kw)

        elif method == "random_forest":
            self.interp = RandomForestRegressor(**kw)
            self.interp.fit(self.X, self.y)

        else:
            raise ValueError(f"unknown interp method {self.method}")

    def precompute_grid_neighbors(self, Xn, n_neighbors=100, n_proc=1):
        """
        Precompute neighbors for a grid of query points.

        Useful for repeated interpolation on the same grid with different
        training values (e.g., in MCMC sampling).

        :param Xn: Grid points, shape (n_grid, n_dims).
        :type Xn: numpy.ndarray
        :param n_neighbors: Number of neighbors to precompute. Defaults to 100.
        :type n_neighbors: int
        :param n_proc: Number of parallel processes. Defaults to 1.
        :type n_proc: int
        :raises AssertionError: If method is not 'knn_balltree'.
        """
        assert self.method == "knn_balltree"

        Xn = self.scaler.transform(Xn.copy())
        dist, ind = query_split(tree=self.interp, X=Xn, k=n_neighbors, n_proc=n_proc)

        # stored as float32 to keep large precomputed grids memory-friendly
        self.neighbors_ind = ind
        self.neighbors_dist = dist.astype(np.float32)
        self.neighbors_Xn = Xn.astype(np.float32)

    def interpolate_grid_neighbours(self, y, n_neighbors=None):
        """
        Interpolate using precomputed grid neighbors.

        Requires prior call to precompute_grid_neighbors.

        :param y: Training values, shape (n_samples,).
        :type y: numpy.ndarray
        :param n_neighbors: Number of neighbors to use. Defaults to self.n_neighbors.
        :type n_neighbors: int or None
        :return: Interpolated values at grid points.
        :rtype: numpy.ndarray
        :raises ValueError: If requested neighbors exceeds precomputed neighbors.
        """
        assert len(y) == len(self.X)
        if n_neighbors is None:
            n_neighbors = self.n_neighbors

        nn = self.neighbors_ind.shape[1]
        if n_neighbors > nn:
            raise ValueError(
                f"requested n_neighbors={n_neighbors} exceeds "
                f"number of available neighbors {nn}"
            )
        yi = predict_with_neighbours(
            y,
            self.neighbors_ind[:, :n_neighbors],
            self.neighbors_dist[:, :n_neighbors],
        )

        return yi

    def slice_linear_upsampling(self, n_repeat=1, n_neighbors=1):
        """
        Upsample training data by adding midpoints between neighbors.

        Experimental method to densify training data for better interpolation.
        Each iteration appends, for each of the n_neighbors nearest neighbors
        of every training point, the midpoint of the pair (in both X and y).

        :param n_repeat: Number of upsampling iterations. Defaults to 1.
        :type n_repeat: int
        :param n_neighbors: Number of neighbors to use for midpoint generation.
            Defaults to 1.
        :type n_neighbors: int
        """
        for _ in range(n_repeat):
            bt = BallTree(self.X)
            # k includes the point itself (column 0), so query one extra
            dist, ids = bt.query(self.X, k=n_neighbors + 1)
            list_Xn = [self.X]
            list_yn = [self.y]
            for j in range(1, n_neighbors + 1):
                Xn = (self.X + self.X[ids[:, j], :]) / 2.0
                yn = (self.y + self.y[ids[:, j]]) / 2.0
                list_Xn += [Xn]
                list_yn += [yn]

            Xn = np.concatenate(list_Xn, axis=0)
            yn = np.concatenate(list_yn, axis=0)
            self.X = Xn
            self.y = yn
            LOGGER.info(f"upsampled training data to X={Xn.shape} y={yn.shape}")