Usage: Nflow
from cosmic_toolbox import arraytools as at, file_utils, colors
import numpy as np
from sklearn.model_selection import train_test_split
from trianglechain import TriangleChain
from edelweiss.nflow import (
Nflow, load_nflow
)
colors.set_cycle()
np.random.seed(1996)
# generate distribution
def get_data(n_samples=50000):
conditional_params = np.random.randn(n_samples, 2)
output_params = np.zeros((n_samples, 2))
output_params[:, 0] = (
np.sin(conditional_params[:, 1]) + np.cos(conditional_params[:, 1]) + np.random.randn(n_samples) * 0.2
)
output_params[:, 1] = conditional_params[:, 0] ** 3 + np.random.randn(n_samples) * 0.2
dataset = np.hstack((conditional_params, output_params))
names = [f"p_{i}" for i in range(4)]
input = names[:5]
output = names[5:]
data = at.arr2rec(dataset, names)
return data
X = get_data()
X_train, X_test = train_test_split(
X, test_size=0.3, random_state=42
)
names = X.dtype.names
Conditional normalizing flow
# The normalizing flow can be used in conditional mode with input and output parameters
nflow_cond = Nflow(input=names[:2], output=names[2:], scaler="quantile")
nflow_cond.train(X, verbose=False, progress_bar=True)
24-07-25 16:05:00 nflow.py INF ==============================
24-07-25 16:05:00 nflow.py INF Training normalizing flow with
24-07-25 16:05:00 nflow.py INF 50000 samples and
24-07-25 16:05:00 nflow.py INF conditional parameters: ['p_0' 'p_1']
24-07-25 16:05:00 nflow.py INF other parameters: ['p_2' 'p_3']
24-07-25 16:05:00 nflow.py INF ==============================
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:41<00:00, 2.42it/s]
24-07-25 16:05:42 nflow.py INF Training completed with best loss at epoch 97/100 with loss -3.10
# the sampled distribution is therefore identical for the input parameters
X_sampled = nflow_cond(X_test[nflow_cond.input])
# group the triangle by conditional and non conditional parameters
# non-conditional parameters show slight disagreement, could be improved with more epochs
ranges = {"p_3": [-10, 10]}
tri = TriangleChain(size=3, grouping_kwargs = dict(n_per_group=(2, 2)), ranges=ranges)
tri.contour_cl(X_test, label="truth");
tri.contour_cl(X_sampled, label="nflow", show_legend=True);
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 49.71it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 48.53it/s]
Normalizing flow without conditions
# without conditions, all parameters are considered output parameters
nflow_nocond = Nflow(scaler="quantile")
nflow_nocond.train(X, verbose=False, progress_bar=True, epochs=100)
24-07-25 16:05:44 nflow.py INF ==============================
24-07-25 16:05:44 nflow.py INF Training normalizing flow with
24-07-25 16:05:44 nflow.py INF 50000 samples and
24-07-25 16:05:44 nflow.py INF conditional parameters: None
24-07-25 16:05:44 nflow.py INF other parameters: None
24-07-25 16:05:44 nflow.py INF ==============================
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [01:45<00:00, 1.06s/it]
24-07-25 16:07:31 nflow.py INF Training completed with best loss at epoch 89/100 with loss -3.15
# the sampled distribution is different for all parameters
X_sampled = nflow_nocond(n_samples=len(X_test))
tri = TriangleChain(size=3, ranges=ranges)
tri.contour_cl(X_test, label="truth");
tri.contour_cl(X_sampled, label="nflow", show_legend=True);
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:00<00:00, 31.92it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:00<00:00, 28.73it/s]