Coverage for src / cosmic_toolbox / MultiInterp.py: 60%
181 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-31 12:38 +0000
1"""
2Multi-method interpolation framework.
4Provides flexible N-dimensional interpolation with multiple backend methods
5including nearest neighbors, radial basis functions, linear interpolation,
6and machine learning regressors.
8author: Tomasz Kacprzak
9"""
11import sys
12import warnings
14import numpy as np
15from scipy.interpolate import LinearNDInterpolator, NearestNDInterpolator, Rbf
16from sklearn.ensemble import RandomForestRegressor
17from sklearn.neighbors import BallTree, KNeighborsRegressor, RadiusNeighborsRegressor
18from sklearn.preprocessing import MinMaxScaler
20from cosmic_toolbox.logger import get_logger
# Silence noisy third-party warnings; deprecation and runtime warnings are
# dropped entirely, while UserWarnings are shown once per unique location.
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("once", category=UserWarning)

# Module-level logger named after this file.
LOGGER = get_logger(__file__)
class Rbft:
    """
    Radial Basis Function interpolator with bounds checking.

    Wraps scipy's Rbf with automatic coordinate scaling and bounds checking.
    Query points outside the training-data bounds evaluate to -inf.

    :param points: Training points, shape (n_samples, n_dims).
    :type points: numpy.ndarray
    :param values: Training values, shape (n_samples,).
    :type values: numpy.ndarray
    :param kw_rbf: Additional keyword arguments passed to scipy.interpolate.Rbf.
    """

    def __init__(self, points, values, **kw_rbf):
        self.points = points
        self.values = values
        # bounds are stored in the original (unscaled) coordinate frame,
        # since __call__ checks bounds before scaling
        self.bounds = [np.min(points, axis=0), np.max(points, axis=0)]
        self.scaler = MinMaxScaler()
        # rescale coordinates to [0, 1] per dimension before building the RBF
        self.points = self.scaler.fit_transform(self.points)
        self.interp = Rbf(*list(self.points.T), self.values, **kw_rbf)

    def __call__(self, points, **kw_rbf):
        """
        Evaluate the interpolator at given points.

        :param points: Points at which to interpolate, shape (n_points, n_dims).
        :type points: numpy.ndarray
        :param kw_rbf: Additional keyword arguments (unused).
        :return: Interpolated values. Out-of-bounds points return -inf.
        :rtype: numpy.ndarray
        """
        inside = self._in_bounds(points)
        # start from -inf everywhere; in-bounds entries are overwritten below
        predicted = np.full(len(points), -np.inf)
        scaled = self.scaler.transform(points)
        predicted[inside] = self.interp(*list(scaled[inside].T))
        return predicted

    def _in_bounds(self, x):
        """
        Check which points lie strictly inside the training-data bounds.

        :param x: Points to check, shape (n_points, n_dims).
        :type x: numpy.ndarray
        :return: Boolean array indicating which points are in bounds.
        :rtype: numpy.ndarray
        """
        lower, upper = self.bounds
        return np.all(x > lower, axis=1) & np.all(x < upper, axis=1)
def query_split(X, tree, k, n_proc):
    """
    Query BallTree in parallel using multiprocessing.

    Splits the query points across multiple processes for parallel execution.

    :param X: Query points, shape (n_points, n_dims).
    :type X: numpy.ndarray
    :param tree: BallTree instance to query.
    :type tree: sklearn.neighbors.BallTree
    :param k: Number of nearest neighbors to find.
    :type k: int
    :param n_proc: Number of parallel processes to use.
    :type n_proc: int
    :return: Tuple of (distances, indices) arrays.
    :rtype: tuple
    """
    from functools import partial
    from multiprocessing import Pool

    n_points = X.shape[0]
    # ceil division so the chunks cover all rows even when not divisible
    chunk_size = int(np.ceil(n_points / n_proc))
    LOGGER.info(
        f"querying BallTree with a pool for n_grid={n_points} "
        f"n_proc={n_proc} n_per_batch={chunk_size} n_neighbors={k}"
    )
    chunks = [X[(i * chunk_size) : (i + 1) * chunk_size, :] for i in range(n_proc)]

    # each worker runs query_batch on its chunk with internal sub-batching
    worker = partial(query_batch, tree=tree, k=k, n_per_batch=100000)
    with Pool(n_proc) as pool:
        results = pool.map(worker, chunks)

    distances = np.concatenate([d for d, _ in results])
    indices = np.concatenate([i for _, i in results])
    return distances, indices
def query_batch(X, tree, k=100, n_per_batch=10000):
    """
    Query BallTree in batches to manage memory usage.

    :param X: Query points, shape (n_points, n_dims).
    :type X: numpy.ndarray
    :param tree: BallTree instance to query.
    :type tree: sklearn.neighbors.BallTree
    :param k: Number of nearest neighbors to find. Defaults to 100.
    :type k: int
    :param n_per_batch: Number of points per batch. Defaults to 10000.
    :type n_per_batch: int
    :return: Tuple of (distances, indices) arrays.
    :rtype: tuple
    """
    n_points = X.shape[0]
    n_batches = int(np.ceil(n_points / n_per_batch))
    # preallocate the output arrays and fill them batch by batch
    distances = np.zeros([n_points, k])
    indices = np.zeros([n_points, k], dtype=np.int64)
    for b in range(n_batches):
        start = b * n_per_batch
        stop = (b + 1) * n_per_batch
        Xq = X[start:stop, :]
        if len(Xq) > 0:
            dist, ind = tree.query(Xq, k=k)
            indices[start:stop, :] = ind
            distances[start:stop, :] = dist
        LOGGER.info(f"batch={b:>6}/{n_batches:>6}")

    return distances, indices
def predict_with_neighbours(y, ind, dist):
    """
    Predict values using inverse-distance weighted average of neighbors.

    Exact matches (zero distance) would produce infinite weights and make
    ``np.average`` return NaN; their weights are capped at a large finite
    value instead so the exact match dominates the average.

    :param y: Training values, shape (n_samples,).
    :type y: numpy.ndarray
    :param ind: Neighbor indices, shape (n_points, n_neighbors).
    :type ind: numpy.ndarray
    :param dist: Neighbor distances, shape (n_points, n_neighbors).
    :type dist: numpy.ndarray
    :return: Predicted values, shape (n_points,).
    :rtype: numpy.ndarray
    """
    yn = y[ind]
    with np.errstate(divide="ignore"):
        wn = 1.0 / dist  # overflows to inf where dist == 0
    # bugfix: the original set *zero* weights to 1e10, which left the inf
    # weights from zero distances in place and yielded NaN predictions for
    # exact-match query points; cap non-finite weights instead.
    wn[~np.isfinite(wn)] = 1e10
    yi = np.average(yn, weights=wn, axis=1)

    return yi
def predict_knn_balltree(Xi, X, y, n_neighbors, tree):
    """
    Predict using k-nearest neighbors with BallTree.

    :param Xi: Query points, shape (n_points, n_dims).
    :type Xi: numpy.ndarray
    :param X: Training points (unused, neighbors from tree).
    :type X: numpy.ndarray
    :param y: Training values, shape (n_samples,).
    :type y: numpy.ndarray
    :param n_neighbors: Number of neighbors to use.
    :type n_neighbors: int
    :param tree: BallTree instance for neighbor queries.
    :type tree: sklearn.neighbors.BallTree
    :return: Predicted values, shape (n_points,).
    :rtype: numpy.ndarray
    """
    # unpack validates that Xi is 2-dimensional (n_dim itself is unused)
    n_query, n_dim = Xi.shape
    distances, neighbor_ids = tree.query(Xi, k=n_neighbors)
    predictions = predict_with_neighbours(y, neighbor_ids, distances)
    return predictions
def predict_knn_linear(Xi, X, y, n_neighbors, tree):
    """
    Predict using local linear interpolation of nearest neighbors.

    For each query point, finds k-nearest neighbors and fits a local
    linear interpolator using those neighbors.

    :param Xi: Query points, shape (n_points, n_dims).
    :type Xi: numpy.ndarray
    :param X: Training points, shape (n_samples, n_dims).
    :type X: numpy.ndarray
    :param y: Training values, shape (n_samples,).
    :type y: numpy.ndarray
    :param n_neighbors: Number of neighbors to use for local interpolation.
    :type n_neighbors: int
    :param tree: BallTree instance for neighbor queries.
    :type tree: sklearn.neighbors.BallTree
    :return: Predicted values, shape (n_points,). Out-of-hull points return -inf.
    :rtype: numpy.ndarray
    """
    nx, nd = Xi.shape
    dist, ind = tree.query(Xi, k=n_neighbors)
    X_nearest = X[ind, :]
    y_nearest = y[ind]
    yi = np.full(nx, -np.inf)
    n_nan = 0

    for i in range(nx):
        Xn = X_nearest[i, :]
        yn = y_nearest[i, :]
        # fit a fresh linear interpolator on the local neighborhood only
        interp = LinearNDInterpolator(Xn, yn)
        yi[i] = interp(Xi[i, :])
        if not np.isfinite(yi[i]):
            # bugfix: LinearNDInterpolator returns NaN outside the local
            # convex hull, which used to overwrite the -inf prefill; map it
            # back to -inf so the documented contract holds.
            yi[i] = -np.inf
            n_nan += 1
        # lightweight progress indicator (overwrites the same terminal line)
        sys.stdout.write(f"\r{i}/{nx} {n_nan}")

    return yi
class MultiInterp:
    """
    Multi-method N-dimensional interpolator.

    Provides a unified interface to multiple interpolation methods with
    automatic coordinate scaling and bounds checking.

    :param X: Training points, shape (n_samples, n_dims).
    :type X: numpy.ndarray
    :param y: Training values, shape (n_samples,).
    :type y: numpy.ndarray
    :param method: Interpolation method. Options:
        - 'nn': Nearest neighbor (scipy NearestNDInterpolator)
        - 'linear': Linear interpolation (scipy LinearNDInterpolator)
        - 'rbf': Radial basis function (scipy Rbf)
        - 'rbft': Radial basis function with bounds (Rbft class)
        - 'knn_regression' (alias 'knn'): K-neighbors regression (sklearn)
        - 'rnn_regression': Radius neighbors regression (sklearn)
        - 'knn_linear': Local linear interpolation
        - 'knn_balltree': K-neighbors with BallTree
        - 'random_forest': Random forest regressor (sklearn)
    :type method: str
    :param kw: Additional keyword arguments passed to the underlying interpolator.

    Example
    -------
    >>> import numpy as np
    >>> from cosmic_toolbox.MultiInterp import MultiInterp
    >>> X = np.random.rand(100, 2)
    >>> y = np.sin(X[:, 0]) + np.cos(X[:, 1])
    >>> interp = MultiInterp(X, y, method='nn')
    >>> Xi = np.random.rand(10, 2)
    >>> yi = interp(Xi)
    """

    def __init__(self, X, y, method="nn", **kw):
        # keep copies so scaling/upsampling never mutate the caller's arrays
        self.X = X.copy()
        self.y = y.copy()
        # bounds are stored in the original (unscaled) coordinate frame,
        # since __call__ checks bounds before scaling
        self.bounds = [np.min(X, axis=0), np.max(X, axis=0)]
        self.scaler = MinMaxScaler()
        self.scaler.fit(self.X)
        self.X = self.scaler.transform(self.X)
        self.interp = None
        self.method = method
        self.kw = kw
        self.init_interp(**kw)

    def __call__(self, Xi, **kw):
        """
        Evaluate the interpolator at given points.

        :param Xi: Points at which to interpolate, shape (n_points, n_dims).
        :type Xi: numpy.ndarray
        :param kw: Additional keyword arguments passed to underlying interpolator.
        :return: Interpolated values. Out-of-bounds points return -inf.
        :rtype: numpy.ndarray
        :raises Exception: If method is unknown.
        """
        yi = np.zeros(len(Xi))
        select = self._in_bounds(Xi)
        yi[~select] = -np.inf

        Xi = self.scaler.transform(Xi.copy())
        method = self.method.lower()

        if method == "rbf":
            yi[select] = self.interp(*list(Xi[select, :].T), **kw)

        elif method == "rbft":
            yi[select] = self.interp(Xi[select, :], **kw)

        elif method in ["knn", "knn_regression", "rnn_regression"]:
            # bugfix: 'knn' is accepted by init_interp as an alias of
            # 'knn_regression' but was previously rejected here
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                yi[select] = self.interp.predict(Xi[select, :], **kw)

        elif method in ["nn", "linear"]:
            yi[select] = self.interp(Xi[select, :], **kw)

        elif method == "knn_linear":
            yi[select] = predict_knn_linear(
                Xi[select, :],
                self.X,
                self.y,
                self.n_neighbors,
                self.interp,
                **kw,
            )

        elif method == "random_forest":
            yi[select] = self.interp.predict(Xi[select, :], **kw)

        elif method == "knn_balltree":
            yi[select] = predict_knn_balltree(
                Xi[select, :],
                self.X,
                self.y,
                self.n_neighbors,
                self.interp,
                **kw,
            )

        else:
            raise Exception(f"unknown interp method {self.method}")

        return yi

    def _in_bounds(self, X):
        """
        Check if points are strictly within training data bounds.

        :param X: Points to check, shape (n_points, n_dims).
        :type X: numpy.ndarray
        :return: Boolean array indicating which points are in bounds.
        :rtype: numpy.ndarray
        """
        return np.all(self.bounds[0] < X, axis=1) & np.all(self.bounds[1] > X, axis=1)

    def init_interp(self, **kw):
        """
        Initialize the underlying interpolator based on method.

        :param kw: Keyword arguments passed to the underlying interpolator.
        :raises Exception: If method is unknown.
        """
        method = self.method.lower()

        if method == "rbf":
            self.interp = Rbf(*list(self.X.T), self.y, **kw)

        elif method == "rbft":
            self.interp = Rbft(points=self.X, values=self.y, **kw)

        elif method in ["knn", "knn_regression"]:
            self.interp = KNeighborsRegressor(**kw)
            self.interp.fit(self.X, self.y)

        elif method == "rnn_regression":
            # radius in scaled [0, 1] coordinates
            kw.setdefault("radius", 0.1)
            self.interp = RadiusNeighborsRegressor(**kw)
            self.interp.fit(self.X, self.y)

        elif method == "nn":
            self.interp = NearestNDInterpolator(self.X, self.y, **kw)

        elif method == "linear":
            self.interp = LinearNDInterpolator(self.X, self.y, **kw)

        elif method in ["knn_linear", "knn_balltree"]:
            # both methods share the same BallTree setup; default
            # neighborhood size scales with dimensionality
            n_neighbors = kw.pop("n_neighbors", self.X.shape[1] * 3)
            self.interp = BallTree(self.X, **kw)
            self.n_neighbors = n_neighbors

        elif method == "random_forest":
            self.interp = RandomForestRegressor(**kw)
            self.interp.fit(self.X, self.y)

        else:
            raise Exception(f"unknown interp method {self.method}")

    def precompute_grid_neighbors(self, Xn, n_neighbors=100, n_proc=1):
        """
        Precompute neighbors for a grid of query points.

        Useful for repeated interpolation on the same grid with different
        training values (e.g., in MCMC sampling).

        :param Xn: Grid points, shape (n_grid, n_dims).
        :type Xn: numpy.ndarray
        :param n_neighbors: Number of neighbors to precompute. Defaults to 100.
        :type n_neighbors: int
        :param n_proc: Number of parallel processes. Defaults to 1.
        :type n_proc: int
        :raises AssertionError: If method is not 'knn_balltree'.
        """
        assert self.method == "knn_balltree"

        Xn = self.scaler.transform(Xn.copy())
        dist, ind = query_split(tree=self.interp, X=Xn, k=n_neighbors, n_proc=n_proc)

        self.neighbors_ind = ind
        # float32 halves the memory footprint of the cached grid data
        self.neighbors_dist = dist.astype(np.float32)
        self.neighbors_Xn = Xn.astype(np.float32)

    def interpolate_grid_neighbours(self, y, n_neighbors=None):
        """
        Interpolate using precomputed grid neighbors.

        Requires prior call to precompute_grid_neighbors.

        :param y: Training values, shape (n_samples,).
        :type y: numpy.ndarray
        :param n_neighbors: Number of neighbors to use. Defaults to self.n_neighbors.
        :type n_neighbors: int or None
        :return: Interpolated values at grid points.
        :rtype: numpy.ndarray
        :raises Exception: If requested neighbors exceeds precomputed neighbors.
        """
        assert len(y) == len(self.X)
        if n_neighbors is None:
            n_neighbors = self.n_neighbors

        nn = self.neighbors_ind.shape[1]
        if n_neighbors > nn:
            # bugfix: the original message omitted the requested count and
            # read as an incomplete sentence
            raise Exception(
                f"requested n_neighbors={n_neighbors} exceeds the "
                f"number of available precomputed neighbors {nn}"
            )
        yi = predict_with_neighbours(
            y,
            self.neighbors_ind[:, :n_neighbors],
            self.neighbors_dist[:, :n_neighbors],
        )

        return yi

    def slice_linear_upsampling(self, n_repeat=1, n_neighbors=1):
        """
        Upsample training data by adding midpoints between neighbors.

        Experimental method to densify training data for better interpolation.

        :param n_repeat: Number of upsampling iterations. Defaults to 1.
        :type n_repeat: int
        :param n_neighbors: Number of neighbors to use for midpoint generation.
            Defaults to 1.
        :type n_neighbors: int
        """
        for _ in range(n_repeat):
            bt = BallTree(self.X)
            # k+1 because column 0 of the result is each point itself
            dist, ids = bt.query(self.X, k=n_neighbors + 1)
            list_Xn = [self.X]
            list_yn = [self.y]
            for j in range(1, n_neighbors + 1):
                # midpoints between each point and its j-th nearest neighbor
                Xn = (self.X + self.X[ids[:, j], :]) / 2.0
                yn = (self.y + self.y[ids[:, j]]) / 2.0
                list_Xn += [Xn]
                list_yn += [yn]

            self.X = np.concatenate(list_Xn, axis=0)
            self.y = np.concatenate(list_yn, axis=0)
            # use the module logger instead of a bare print for consistency
            LOGGER.info(f"upsampled training data to X={self.X.shape} y={self.y.shape}")