Coverage for src/galsbi/ucat/galaxy_population_models/galaxy_sed.py: 98%
55 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-13 03:24 +0000
1# Copyright (C) 2018 ETH Zurich, Institute for Particle Physics and Astrophysics
3"""
4Created 2021
5author: Tomasz Kacprzak
6"""
8from collections import OrderedDict
10import numpy as np
12##########################################################
13#
14# Sampling templates
15#
16##########################################################
def dirichlet_alpha_ev(z, alpha0, alpha1, z1):
    """
    Geometrically interpolate a parameter in redshift.

    The fraction ``z / z1`` is used as a log-linear interpolation weight, so the
    result equals ``alpha0`` at z=0 and ``alpha1`` at z=z1 (and extrapolates
    beyond that range). A trailing axis is appended to ``z`` so that per-template
    parameter vectors broadcast to one row per redshift.

    :param z: redshift(s)
    :param alpha0: parameter value at z=0 (scalar or array broadcastable along
        the appended trailing axis)
    :param alpha1: parameter value at z=z1
    :param z1: pivot redshift at which ``alpha1`` is reached
    :return: ``alpha0 ** (1 - z/z1) * alpha1 ** (z/z1)`` with ``z`` expanded by
        one trailing axis
    """
    frac = np.expand_dims(z, axis=-1) / z1
    low_end_factor = alpha0 ** (1 - frac)
    high_end_factor = alpha1**frac
    return low_end_factor * high_end_factor
def draw_dirichlet_add_weight(alpha, weight):
    """
    Draw one Dirichlet sample per row of ``alpha`` and scale it by ``weight``.

    Each row of ``alpha`` holds the concentration parameters of one Dirichlet
    distribution. The sample is built from independent unit-scale gamma draws,
    normalised so that each row sums to one, and then multiplied by ``weight``.

    :param alpha: 2D array of concentration parameters, one row per object
        (columns are indexed as ``alpha[:, i]``). It is not modified.
    :param weight: scale factor applied to the normalised samples; must
        broadcast against ``alpha``.
    :return: new array with the same shape as ``alpha``; with a scalar
        ``weight`` each row sums to ``weight``.
    """
    alpha = np.asarray(alpha)

    # Write into a fresh buffer rather than into ``alpha`` so the caller's
    # array is left untouched. Floating inputs keep their dtype; integer
    # inputs get a float buffer so the gamma draws are not truncated.
    if np.issubdtype(alpha.dtype, np.floating):
        out_dtype = alpha.dtype
    else:
        out_dtype = float
    samples = np.empty(alpha.shape, dtype=out_dtype)

    # Draw column by column so the random stream is consumed in the same
    # order as before (results are reproducible with a fixed seed).
    for i in range(alpha.shape[1]):
        samples[:, i] = np.random.gamma(alpha[:, i], scale=1)

    samples /= np.sum(samples, axis=1, keepdims=True)
    samples *= weight

    return samples
def sample_template_coeff_dirichlet__alpha_mode(
    z, amode0, amode1, z1, weight, alpha0_std, alpha1_std
):
    """
    Samples template coefficients from redshift-dependent Dirichlet distributions
    that are parameterised by their mode and a standard deviation.
    See also docs/jupyter_notebooks/coeff_distribution_dirichlet.ipynb.

    The Dirichlet mode is interpolated in redshift between ``amode0`` (z=0) and
    ``amode1`` (z=z1) and renormalised to sum to one. The total concentration is
    chosen such that, if all mode components were equal, each component would
    have standard deviation ``alpha_std``. ``alpha_std`` is itself interpolated
    between ``alpha0_std`` and ``alpha1_std`` in the same way as the mode.

    :param z: redshifts; one sample is drawn per entry (assumes 1D z — confirm
        against callers)
    :param amode0: Dirichlet mode at z=0; must sum to 1
    :param amode1: Dirichlet mode at z=z1; must sum to 1
    :param z1: pivot redshift at which ``amode1`` / ``alpha1_std`` apply
    :param weight: scale factor applied to the normalised samples
    :param alpha0_std: standard deviation at z=0, must lie in (0, max_sig)
    :param alpha1_std: standard deviation at z=z1, must lie in (0, max_sig)
    :return: array with one row of template coefficients per redshift
    :raises AssertionError: if a mode does not sum to 1 or a standard deviation
        is outside (0, max_sig), where max_sig depends on the number of templates
    """

    def dir_mode_to_alpha(amode, sigma):
        # alpha_i = 1 + mode_i * (sum(alpha) - K), with sum(alpha) chosen such
        # that a symmetric K-component Dirichlet has per-component std sigma.
        assert np.allclose(np.sum(amode, axis=1), 1), "dirichlet amode must sum to 1"
        K = amode.shape[1]
        ms = 1 / sigma**2 * (1 / K - 1 / K**2) - K - 1
        alpha = 1 + amode * ms
        return alpha

    def get_max_sigma(k):
        # Largest std for which sum(alpha) - K > 0, i.e. every alpha_i >= 1
        # and the mode parameterisation stays valid.
        return np.sqrt((1 / k - 1 / k**2) / (k + 1))

    # check alpha mode and interpolate in redshift
    for amode_ in (amode0, amode1):
        assert np.allclose(
            np.sum(amode_), 1
        ), "dirichlet amode coefficients must sum to 1"
    amode = dirichlet_alpha_ev(z, amode0, amode1, z1)
    # geometric interpolation does not preserve the sum, so renormalise
    amode /= np.sum(amode, axis=1, keepdims=True)

    # get maximum allowed sigma for given number of dimensions
    K = amode.shape[1]
    max_sig = get_max_sigma(K)

    # check standard deviation of alpha (naming the offending parameter)
    for std_name, alpha_std_ in (("alpha0_std", alpha0_std), ("alpha1_std", alpha1_std)):
        assert alpha_std_ > 0, f"dirichlet {std_name} must be >0"
        assert (
            alpha_std_ < max_sig
        ), f"dirichlet {std_name} must be < max_sig={max_sig:2.3e} for K={K}"
    alpha_std = dirichlet_alpha_ev(z, alpha0_std, alpha1_std, z1)
    # extrapolation beyond z1 can leave the valid range; keep it strictly inside
    alpha_std = np.clip(alpha_std, a_min=1e-4, a_max=max_sig - 1e-4)

    # convert to Dirichlet alpha
    alpha = dir_mode_to_alpha(amode, alpha_std)

    # finally, draw the samples
    samples = draw_dirichlet_add_weight(alpha, weight)

    return samples
def sample_template_coeff_dirichlet(z, alpha0, alpha1, z1, weight):
    """
    Draw template coefficients from redshift-dependent Dirichlet distributions.

    The concentration parameters are interpolated geometrically in redshift,
    from ``alpha0`` at z=0 to ``alpha1`` at z=z1. One Dirichlet sample is then
    drawn per redshift and scaled by ``weight``.
    See also docs/jupyter_notebooks/coeff_distribution_dirichlet.ipynb.
    """
    concentration = dirichlet_alpha_ev(z, alpha0, alpha1, z1)
    return draw_dirichlet_add_weight(concentration, weight)
def sample_template_coeff_lumfuncs(par, redshift_z, n_templates):
    """
    Sample template coefficients for every population in ``redshift_z``.

    :param par: parameter container. For each population ``g`` it must provide
        ``template_coeff_alpha0_{g}_{i}`` and ``template_coeff_alpha1_{g}_{i}``
        for ``i in range(n_templates)``, plus ``template_coeff_z1_{g}`` and
        ``template_coeff_weight_{g}``. ``template_coeff_sampler`` must be either
        "dirichlet" or "dirichlet_alpha_mode". The latter additionally reads
        ``template_coeff_alpha0_{g}_std`` and ``template_coeff_alpha1_{g}_std``.
    :param redshift_z: mapping from population name to its redshifts
    :param n_templates: number of templates per population
    :return: OrderedDict mapping population name to sampled coefficients, in
        the iteration order of ``redshift_z``
    :raises ValueError: if ``par.template_coeff_sampler`` is not a known sampler
    """

    def coeff_array(kind, g):
        # Gather the per-template parameters of population g into one array.
        keys = (f"template_coeff_{kind}_{g}_{i}" for i in range(n_templates))
        return np.array([getattr(par, key) for key in keys])

    template_coeffs = OrderedDict()
    for g in redshift_z:
        z = redshift_z[g]
        alpha0 = coeff_array("alpha0", g)
        alpha1 = coeff_array("alpha1", g)
        z1 = getattr(par, f"template_coeff_z1_{g}")
        weight = getattr(par, f"template_coeff_weight_{g}")

        if par.template_coeff_sampler == "dirichlet":
            template_coeffs[g] = sample_template_coeff_dirichlet(
                z, alpha0, alpha1, z1, weight
            )

        elif par.template_coeff_sampler == "dirichlet_alpha_mode":
            template_coeffs[g] = sample_template_coeff_dirichlet__alpha_mode(
                z,
                alpha0,
                alpha1,
                z1,
                weight,
                alpha0_std=getattr(par, f"template_coeff_alpha0_{g}_std"),
                alpha1_std=getattr(par, f"template_coeff_alpha1_{g}_std"),
            )

        else:
            # Fail loudly: silently skipping would drop population g from the
            # result and only surface later as a missing key.
            raise ValueError(
                f"unknown template_coeff_sampler {par.template_coeff_sampler!r}; "
                "expected 'dirichlet' or 'dirichlet_alpha_mode'"
            )

    return template_coeffs