Coverage for src/ufig/plugins/write_catalog_for_emu.py: 95%

233 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-12 19:08 +0000

1# Copyright (C) 2019 ETH Zurich, Institute for Particle Physics and Astrophysics 

2 

3""" 

4Created on Aug 2021 

5author: Silvan Fischbacher, Tomasz Kacprzak 

6""" 

7 

8import itertools 

9 

10import numpy as np 

11from cosmic_toolbox import arraytools as at 

12from cosmic_toolbox import file_utils, logger 

13from cosmic_toolbox.utils import is_between 

14from ivy.plugin.base_plugin import BasePlugin 

15 

16from ufig.plugins.write_catalog import catalog_to_rec 

17 

# Module-level logger shared by all functions and the Plugin in this file.
LOGGER = logger.get_logger(__file__)

19 

20 

def ensure_valid_cats(cat_gals, cat_stars):
    """
    Make sure both catalogs are arrays so that they can be concatenated.

    A missing (``None``) catalog is replaced by an empty array that has the
    dtype of the other catalog.

    :param cat_gals: galaxy catalog or None
    :param cat_stars: star catalog or None
    :return: tuple (cat_gals, cat_stars), neither of which is None
    :raises ValueError: if both catalogs are None
    """
    gals_missing = cat_gals is None
    stars_missing = cat_stars is None
    if gals_missing and stars_missing:
        raise ValueError("No catalogs provided")
    if gals_missing:
        return np.empty(0, dtype=cat_stars.dtype), cat_stars
    if stars_missing:
        return cat_gals, np.empty(0, dtype=cat_gals.dtype)
    return cat_gals, cat_stars

29 

30 

def get_elliptical_indices(
    x,
    y,
    r50,
    e1,
    e2,
    imshape=(4200, 4200),
    n_half_light_radius=5,
    pre_selected_indices=None,
):
    """
    Find the pixels/points inside an elliptical aperture and their distances
    from its center.

    The aperture is centred on (x, y), its size is
    ``n_half_light_radius * r50`` and its shape and orientation follow from the
    ellipticity components (e1, e2). Distances are the elliptical radius scaled
    back to pixel units: ``sqrt(d) * n_half_light_radius * r50`` where ``d`` is
    the normalized squared elliptical radius (``d <= 1`` inside the ellipse).

    The function has two modes with different return values:

    * ``pre_selected_indices is None``: all pixels of an image of shape
      ``imshape`` within the ellipse's bounding box are tested. Returns
      ``(x_indices, y_indices, distances)``. If no pixel lies inside the
      ellipse, the pixel(s) with the smallest elliptical distance are returned.
    * otherwise: only the points ``(x_points, y_points)`` are tested and
      ``imshape`` is not used. Returns ``(mask, distances)``, where ``mask`` is a
      boolean array over the points and ``distances`` only contains the values
      of the selected points. No fallback is applied in this mode.

    With ``r50 == 0`` the normalized distances are NaN/inf, so nothing is
    selected in either mode.

    :param x: x coordinate of the center of the ellipse
    :param y: y coordinate of the center of the ellipse
    :param r50: (half light) radius of the ellipse
    :param e1: ellipticity component 1
    :param e2: ellipticity component 2
    :param imshape: shape of the image (rows, columns)
    :param n_half_light_radius: number of half light radii to consider
    :param pre_selected_indices: optional tuple of (x, y) index arrays at which
        the distance should be evaluated
    :return: see the two modes above
    """
    # size of the aperture in pixels
    aperture = n_half_light_radius * r50

    # semi-major and semi-minor axes from the absolute ellipticity
    e_abs = np.sqrt(e1**2 + e2**2)
    a = np.sqrt(1 / (1 - e_abs)) * aperture
    b = np.sqrt(1 / (1 + e_abs)) * aperture

    # orientation of the major axis
    theta = 0.5 * np.arctan2(e2, e1)
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)

    def _scaled_distance_sq(dx, dy):
        # rotate offsets into the ellipse frame; result is <= 1 inside
        u = dx * cos_t + dy * sin_t
        v = -dx * sin_t + dy * cos_t
        return (u**2 / a**2) + (v**2 / b**2)

    if pre_selected_indices is not None:
        d2 = _scaled_distance_sq(
            pre_selected_indices[0] - x, pre_selected_indices[1] - y
        )
        inside = d2 <= 1
        return inside, np.sqrt(d2[inside]) * aperture

    n_rows = imshape[0]
    n_cols = imshape[1]

    # bounding box of the rotated ellipse, clipped to the image
    half_w = a * np.abs(cos_t) + b * np.abs(sin_t)
    half_h = a * np.abs(sin_t) + b * np.abs(cos_t)
    x_lo = max(0, int(x - half_w))
    x_hi = min(n_cols - 1, int(x + half_w))
    y_lo = max(0, int(y - half_h))
    y_hi = min(n_rows - 1, int(y + half_h))

    yy, xx = np.indices((y_hi - y_lo + 1, x_hi - x_lo + 1))
    yy += y_lo
    xx += x_lo

    d2 = _scaled_distance_sq(xx - x, yy - y)
    keep = d2 <= 1
    if not keep.any():
        # ellipse smaller than a pixel: fall back to the closest pixel(s)
        keep = d2 <= np.min(d2)
    return xx[keep], yy[keep], np.sqrt(d2[keep]) * aperture

123 

124 

def sersic_brightness(magnitude, r50, n, r):
    """
    Evaluate a Sersic surface-brightness profile at radius ``r``.

    The profile is normalised so that its value at ``r == r50`` is
    ``exp(-0.4 * magnitude)``::

        I(r) = exp(-0.4 * magnitude) * exp(-b_n * ((r / r50) ** (1 / n) - 1))

    with ``b_n`` from :func:`sersic_b`. Arguments may be scalars or
    broadcast-compatible numpy arrays.

    NOTE(review): the amplitude uses the natural exponent ``exp(-0.4 * m)``
    rather than ``10 ** (-0.4 * m)``; presumably this is only used as a
    relative brightness proxy - confirm before comparing to physical fluxes.

    :param magnitude: magnitude of the object
    :param r50: half light radius
    :param n: Sersic index
    :param r: radius at which the profile is evaluated
    :return: surface brightness at radius r
    """
    amplitude = np.exp(-0.4 * magnitude)
    exponent = -sersic_b(n) * ((r / r50) ** (1 / n) - 1)
    return amplitude * np.exp(exponent)

140 

141 

def sersic_b(n):
    """
    Return the Sersic ``b_n`` coefficient from the linear approximation
    ``b_n = 2 * n - 0.324``.

    :param n: Sersic index (scalar or numpy array)
    :return: b parameter with the same shape as ``n``
    """
    slope, offset = 2, 0.324
    return slope * n - offset

150 

151 

def estimate_flux_of_points(
    cat, imshape=(4200, 4200), max_mag=26, n_half_light_radius=5
):
    """
    Estimate the summed Sersic surface brightness at the (integer) pixel
    position of every catalog entry.

    Every object brighter than ``max_mag`` contributes its profile to all
    catalog positions inside its elliptical aperture of
    ``n_half_light_radius * r50``. The object's own position is included
    (distance 0), so callers have to subtract its central brightness
    themselves. Entries with ``r50 == 0`` select no points and add nothing.

    :param cat: catalog with ``x``, ``y``, ``mag``, ``r50``, ``e1``, ``e2``
        and ``sersic_n`` columns
    :param imshape: shape of the image (forwarded to get_elliptical_indices,
        which does not use it when points are pre-selected)
    :param max_mag: only objects with ``mag < max_mag`` contribute flux
    :param n_half_light_radius: aperture size in units of r50
    :return: array of length ``len(cat)`` with the estimated flux per entry
    """
    bright = cat[cat["mag"] < max_mag]

    # integer pixel position of every catalog entry = the evaluated points
    points = (cat["x"].astype(int), cat["y"].astype(int))

    flux = np.zeros(len(cat))
    for source in bright:
        half_light = source["r50"]
        in_aperture, dist = get_elliptical_indices(
            int(source["x"]),
            int(source["y"]),
            half_light,
            source["e1"],
            source["e2"],
            imshape=imshape,
            n_half_light_radius=n_half_light_radius,
            pre_selected_indices=points,
        )
        flux[in_aperture] += sersic_brightness(
            source["mag"], half_light, source["sersic_n"], dist
        )
    return flux

194 

195 

def estimate_flux_full_image(
    cat, imshape=(4200, 4200), max_mag=26, n_half_light_radius=5
):
    """
    Render an approximate flux image from the catalog.

    Every object brighter than ``max_mag`` adds its Sersic profile to all
    pixels inside its elliptical aperture of ``n_half_light_radius * r50``
    (or to its closest pixel if the aperture covers none).

    :param cat: catalog with ``x``, ``y``, ``mag``, ``r50``, ``e1``, ``e2``
        and ``sersic_n`` columns
    :param imshape: shape of the image (rows, columns)
    :param max_mag: only objects with ``mag < max_mag`` are considered for blending
    :param n_half_light_radius: number of half light radii to consider for each object
    :return: image of shape ``imshape`` with the estimated flux
    """
    image = np.zeros(imshape)

    for source in cat[cat["mag"] < max_mag]:
        half_light = source["r50"]
        cols, rows, dist = get_elliptical_indices(
            int(source["x"]),
            int(source["y"]),
            half_light,
            source["e1"],
            source["e2"],
            imshape=imshape,
            n_half_light_radius=n_half_light_radius,
        )
        profile = sersic_brightness(
            source["mag"], half_light, source["sersic_n"], dist
        )
        image[rows, cols] += profile

    return image

233 

234 

def add_blending_points(cat, par):
    """
    Add blending information to the catalog, estimating the flux only at the
    positions of the objects.

    The new column ``estimated_flux`` holds, for each object, the summed Sersic
    surface brightness of all objects with ``mag < par.mag_for_scaling`` at its
    integer pixel position. For objects that passed this cut, their own
    central brightness is subtracted again.

    Stars (entries with ``r50 == 0``) are modelled as circular exponential
    profiles (``sersic_n = 1``) with ``r50 = psf_fwhm / 2``. These values are
    written into the ``r50`` and ``sersic_n`` columns of the returned catalog.

    :param cat: catalog with ``x``, ``y``, ``mag``, ``r50``, ``sersic_n``,
        ``e1``, ``e2`` and ``psf_fwhm`` columns
    :param par: context parameters (``size_x``, ``size_y``, ``mag_for_scaling``,
        ``n_r50_for_flux_estimation``, ``catalog_precision``)
    :return: catalog with blending information
    """
    imshape = (par.size_y, par.size_x)
    max_mag = par.mag_for_scaling
    n_half_light_radius = par.n_r50_for_flux_estimation

    cat = at.add_cols(cat, ["estimated_flux"], dtype=par.catalog_precision)

    # Stars are stored with r50 == 0, which makes the estimator select no points
    # for them. Give them their profile *before* estimating, so the flux they
    # contribute matches the self-flux subtracted below (otherwise stars only
    # get their own brightness subtracted and end up negative).
    r50 = cat["r50"]
    sersic_n = cat["sersic_n"]
    stars = r50 == 0
    r50[stars] = cat["psf_fwhm"][stars] / 2
    sersic_n[stars] = 1
    cat["r50"] = r50
    cat["sersic_n"] = sersic_n

    flux = estimate_flux_of_points(
        cat,
        imshape=imshape,
        max_mag=max_mag,
        n_half_light_radius=n_half_light_radius,
    )

    # Subtract the object's own contribution (its profile at r = 0). Use the
    # same strict cut as the estimator so that only added objects are removed.
    mag = cat["mag"]
    select = mag < max_mag
    object_flux = sersic_brightness(mag[select], r50[select], sersic_n[select], 0)
    flux[select] = flux[select] - object_flux
    cat["estimated_flux"] = flux
    return cat

269 

270 

def add_blending_full_image(cat, par):
    """
    Add blending information to the catalog by rendering an approximate flux
    image of all objects and reading it at each object's position.

    The new column ``estimated_flux`` is the rendered flux at the integer pixel
    position of each object. For objects with ``mag < par.mag_for_scaling``,
    i.e. those that were rendered, their own central brightness is subtracted.

    Stars (entries with ``r50 == 0``) are modelled as circular exponential
    profiles (``sersic_n = 1``) with ``r50 = psf_fwhm / 2``. These values are
    written into the ``r50`` and ``sersic_n`` columns of the returned catalog.

    :param cat: catalog with ``x``, ``y``, ``mag``, ``r50``, ``sersic_n``,
        ``e1``, ``e2`` and ``psf_fwhm`` columns
    :param par: context parameters (``size_x``, ``size_y``, ``mag_for_scaling``,
        ``n_r50_for_flux_estimation``, ``catalog_precision``)
    :return: catalog with blending information
    """
    imshape = (par.size_y, par.size_x)
    max_mag = par.mag_for_scaling
    n_half_light_radius = par.n_r50_for_flux_estimation

    cat = at.add_cols(cat, ["estimated_flux"], dtype=par.catalog_precision)

    # Stars are stored with r50 == 0, which makes the renderer draw nothing for
    # them. Give them their profile *before* rendering, so the flux they add
    # matches the self-flux subtracted below (otherwise stars only get their
    # own brightness subtracted and end up negative).
    r50 = cat["r50"]
    sersic_n = cat["sersic_n"]
    stars = r50 == 0
    r50[stars] = cat["psf_fwhm"][stars] / 2
    sersic_n[stars] = 1
    cat["r50"] = r50
    cat["sersic_n"] = sersic_n

    estimated_flux = estimate_flux_full_image(
        cat,
        imshape=imshape,
        max_mag=max_mag,
        n_half_light_radius=n_half_light_radius,
    )

    mag = cat["mag"]
    flux = estimated_flux[cat["y"].astype(int), cat["x"].astype(int)]

    # Subtract the object's own contribution (its profile at r = 0). Use the
    # same strict cut as the renderer so that only rendered objects are removed.
    select = mag < max_mag
    object_flux = sersic_brightness(mag[select], r50[select], sersic_n[select], 0)
    flux[select] = flux[select] - object_flux
    cat["estimated_flux"] = flux
    return cat

307 

308 

def add_blending_integrated(cat, par):
    """
    Estimate the blending risk from image-wide totals of magnitude- and
    size-weighted object counts. Every object gets the same value.

    * ``density_mag_weighted`` = sum over objects of
      ``exp(-0.4 * mag) / exp(-0.4 * par.mag_for_scaling)``
    * ``density_size_weighted`` = sum over objects of
      ``r50 / par.n_r50_for_flux_estimation``

    :param cat: catalog with ``mag`` and ``r50`` columns
    :param par: context parameters
    :return: catalog with additional columns
        "density_mag_weighted" and "density_size_weighted"
    """
    cat = at.add_cols(
        cat,
        ["density_mag_weighted", "density_size_weighted"],
        dtype=par.catalog_precision,
    )

    reference_flux = np.exp(-0.4 * par.mag_for_scaling)
    # NOTE(review): sizes are scaled by ``n_r50_for_flux_estimation`` (a count
    # of half-light radii elsewhere in this file) - confirm this is intended.
    size_scale = par.n_r50_for_flux_estimation

    total_flux = np.sum(np.exp(-0.4 * cat["mag"]) / reference_flux)
    total_size = np.sum(cat["r50"] / size_scale)

    # scalar totals broadcast to every row
    cat["density_mag_weighted"] = total_flux
    cat["density_size_weighted"] = total_size
    return cat

335 

336 

def add_blending_binned_integrated(cat, par):
    """
    Estimate the blending risk from magnitude- and size-weighted object counts
    computed per spatial bin.

    The image is divided into ``n_bins x n_bins`` rectangular bins
    (``par.n_bins_for_flux_estimation`` per axis). Every object in a bin gets
    that bin's sum of the scaled flux ``exp(-0.4 * mag) / exp(-0.4 *
    par.mag_for_scaling)`` and of the scaled size
    ``r50 / par.n_r50_for_flux_estimation``. Bin membership is decided by
    ``is_between`` on the ``x``/``y`` positions.

    :param cat: catalog with ``x``, ``y``, ``mag`` and ``r50`` columns
    :param par: context parameters
    :return: catalog with additional columns
        "density_mag_weighted" and "density_size_weighted"
    """
    n_bins = par.n_bins_for_flux_estimation

    cat = at.add_cols(
        cat,
        ["density_mag_weighted", "density_size_weighted"],
        dtype=par.catalog_precision,
    )

    # per-object weights, scaled to the reference magnitude / size
    scaled_flux = np.exp(-0.4 * cat["mag"]) / np.exp(-0.4 * par.mag_for_scaling)
    scaled_size = cat["r50"] / par.n_r50_for_flux_estimation

    bin_w = int(np.ceil(par.size_x / n_bins))
    bin_h = int(np.ceil(par.size_y / n_bins))

    for ix, iy in itertools.product(range(n_bins), range(n_bins)):
        in_bin = is_between(cat["x"], ix * bin_w, (ix + 1) * bin_w) & is_between(
            cat["y"], iy * bin_h, (iy + 1) * bin_h
        )
        cat["density_mag_weighted"][in_bin] = np.sum(scaled_flux[in_bin])
        cat["density_size_weighted"][in_bin] = np.sum(scaled_size[in_bin])

    return cat

372 

373 

def add_blending_ngal(cat, par):
    """
    Compute the number of objects brighter than one or more magnitude cuts.
    This can later be used to estimate the blending risk.

    ``par.mag_for_scaling`` may be a single number or a list/tuple/numpy array
    of cuts. For every cut a column ``ngal_{cut}`` is added that holds, for all
    rows, the number of objects with ``mag < cut``.

    :param cat: catalog with a ``mag`` column
    :param par: context parameters
    :return: catalog with additional column(s) "ngal_{}".format(mag_cut)
    """
    mag_cuts = par.mag_for_scaling
    # Previously only lists were recognised; a tuple/array of cuts was treated
    # as one cut, giving a bogus column name and a broadcasting error.
    if isinstance(mag_cuts, np.ndarray):
        mag_cuts = mag_cuts.ravel().tolist()
    elif isinstance(mag_cuts, (list, tuple)):
        mag_cuts = list(mag_cuts)
    else:
        mag_cuts = [mag_cuts]

    new_names = [f"ngal_{cut}" for cut in mag_cuts]
    cat = at.add_cols(cat, new_names, dtype=par.catalog_precision)

    for name, cut in zip(new_names, mag_cuts):
        cat[name] = np.sum(cat["mag"] < cut)

    return cat

397 

398 

def add_no_blending(cat, par):
    """
    Blending estimator that adds nothing: the catalog is returned unchanged.

    It exists so that every flux estimator shares the ``(cat, par)`` signature.

    :param cat: catalog
    :param par: context parameters (unused)
    :return: the same catalog object
    """
    del par  # unused, kept for the common estimator signature
    return cat

408 

409 

def enrich_star_catalog(cat, par):
    """
    Add the galaxy-only columns to the star catalog, filled with star values,
    so that it can be used the same way as the galaxy catalog.

    Stars get ``r50 = sersic_n = e1 = e2 = z = excess_b_v = 0`` and
    ``galaxy_type = -1``. The catalog is then passed through
    :func:`enrich_catalog`.

    :param cat: catalog of stars, or None to only obtain the column names
    :param par: ctx parameters (uses ``catalog_precision`` and whatever
        ``enrich_catalog`` reads)
    :return: tuple of the enriched catalog (None if ``cat`` is None) and the
        list of the star-specific column names added here
    """
    star_defaults = {
        "r50": 0,
        "sersic_n": 0,
        "e1": 0,
        "e2": 0,
        "z": 0,
        "galaxy_type": -1,
        "excess_b_v": 0,
    }
    new_names = list(star_defaults)

    if cat is None:
        # caller only needs the column names
        return cat, new_names

    cat = at.add_cols(cat, new_names, dtype=par.catalog_precision)
    for name, value in star_defaults.items():
        cat[name] = value

    cat, _ = enrich_catalog(cat, par)
    return cat, new_names

437 

438 

def enrich_catalog(cat, par):
    """
    Enrich the catalog with computed columns: absolute ellipticity and, unless
    ``par.emu_mini`` is set, background noise levels.

    Added columns:

    * ``e_abs``: ``sqrt(e1**2 + e2**2)``
    * ``bkg_noise_amp`` (if not emu_mini): ``par.bkg_noise_amp`` for every row
    * ``bkg_noise_std`` (if not emu_mini and the catalog has ``x``/``y``
      columns): ``par.bkg_noise_std`` sampled at the integer pixel position

    :param cat: catalog (structured array with at least ``e1`` and ``e2``)
    :param par: ctx parameters (``catalog_precision``, ``emu_mini``,
        ``bkg_noise_amp``, ``bkg_noise_std``)
    :return: tuple of the enriched catalog and the list of added column names
    """
    cat = at.add_cols(
        cat,
        ["e_abs"],
        data=np.sqrt(cat["e1"] ** 2 + cat["e2"] ** 2),
        dtype=par.catalog_precision,
    )
    added = ["e_abs"]

    # noise levels are only added if used in the emulator
    if par.emu_mini:
        return cat, added

    cat = at.add_cols(
        cat,
        ["bkg_noise_amp"],
        data=(np.ones(len(cat)) * par.bkg_noise_amp),
        dtype=par.catalog_precision,
    )
    added.append("bkg_noise_amp")

    try:
        y = cat["y"].astype(int)
        x = cat["x"].astype(int)
    except ValueError:
        # numpy raises ValueError for a missing field (e.g. catalogs with only
        # ra/dec); without pixel positions the local noise cannot be looked up
        LOGGER.debug("No pixel positions in catalog, skipping bkg_noise_std")
        return cat, added

    cat = at.add_cols(
        cat,
        ["bkg_noise_std"],
        data=par.bkg_noise_std[y, x],
        dtype=par.catalog_precision,
    )
    added.append("bkg_noise_std")
    return cat, added

475 

476 

# Dispatch table: value of ``par.flux_estimation_type`` -> function that adds
# the corresponding blending columns to the detection catalog. Every entry
# takes ``(cat, par)`` and returns the (possibly extended) catalog.
FLUX_ESTIMATOR = {
    "full_image": add_blending_full_image,
    "points": add_blending_points,
    "integrated": add_blending_integrated,
    "binned_integrated": add_blending_binned_integrated,
    "ngal": add_blending_ngal,
    "none": add_no_blending,
}

485 

486 

class Plugin(BasePlugin):
    """
    Write the catalogs of the current filter in the format used by the emulators.

    Galaxies and stars from the context are converted to record arrays and
    enriched with derived columns. They are then merged into one detection
    catalog, with blending columns chosen by ``par.flux_estimation_type``, and
    written to ``par.det_clf_catalog_name_dict[f]``. The galaxy and star parts
    of that catalog are also written to ``par.galaxy_catalog_name_dict[f]``
    and ``par.star_catalog_name_dict[f]``.
    """

    def __call__(self):
        par = self.ctx.parameters

        f = self.ctx.current_filter

        LOGGER.info(f"Writing catalog for filter {f}")

        # write the classic catalogs
        cat_gals = catalog_to_rec(self.ctx.galaxies) if "galaxies" in self.ctx else None
        cat_stars = catalog_to_rec(self.ctx.stars) if "stars" in self.ctx else None
        # Create the catalogs for the emulators
        filepath_det = par.det_clf_catalog_name_dict[f]
        conf = par.emu_conf
        cat_stars, _ = enrich_star_catalog(cat_stars, par)
        # enrich_catalog cannot handle a missing catalog (it indexes cat["e1"]);
        # a missing galaxy catalog is replaced by an empty one with the star
        # dtype in ensure_valid_cats instead.
        if cat_gals is not None:
            cat_gals, _ = enrich_catalog(cat_gals, par)

        cat_gals, cat_stars = ensure_valid_cats(cat_gals, cat_stars)

        def merged(name):
            # galaxies first, then stars
            return np.concatenate([cat_gals[name], cat_stars[name]], axis=0)

        # parameters that may legitimately be missing here because they are
        # only created later by the flux estimator
        acceptable_params = (
            "ngal",
            "estimated_flux",
            "density_mag_weighted",
            "density_size_weighted",
        )

        cat = {}
        for p in conf["input_band_dep"]:
            try:
                cat[p] = merged(p).astype(par.catalog_precision)
            except Exception:
                # best effort: log and continue without this parameter
                if p.startswith(acceptable_params):
                    LOGGER.debug(
                        f"Could not concatenate param {p}"
                        " (if this is a flux estimator, this is expected)"
                    )
                    continue
                LOGGER.warning(f"Could not concatenate param {p}")
        for p in conf["input_band_indep"]:
            cat[p] = merged(p).astype(par.catalog_precision)

        if "x" in cat_gals.dtype.names:
            params = ["x", "y", "id"]
        else:
            params = ["ra", "dec", "id"]
        for p in params:
            cat[p] = merged(p)

        # add parameters that are not part of the emulator but are needed
        # for the sample selection and matching, mainly used for training data
        if "psf_fwhm" not in cat and not par.emu_mini:
            cat["psf_fwhm"] = merged("psf_fwhm")

        if par.flux_estimation_type in ("full_image", "points"):
            # add ellipticities and sersic param needed by the flux estimators
            for p in ["e1", "e2", "sersic_n"]:
                if p not in cat:
                    cat[p] = merged(p)

        cat = at.dict2rec(cat)

        # add blending risk parameters
        add_blending = FLUX_ESTIMATOR[par.flux_estimation_type]
        cat = add_blending(cat, par)
        file_utils.write_to_hdf(filepath_det, cat)

        # save galaxy and star catalog
        gals = cat["galaxy_type"] != -1
        stars = cat["galaxy_type"] == -1
        file_utils.write_to_hdf(par.galaxy_catalog_name_dict[f], cat[gals])
        file_utils.write_to_hdf(par.star_catalog_name_dict[f], cat[stars])

    def __str__(self):
        return "write emu-opt ucat catalog"