# Copyright (C) 2023 ETH Zurich
# Institute for Particle Physics and Astrophysics
# Author: Silvan Fischbacher
import itertools
from copy import deepcopy
import matplotlib.pyplot as plt
import numpy as np
from cosmic_toolbox import logger
from trianglechain.BaseChain import BaseChain
from trianglechain.make_subplots import contour_cl, density_image, scatter_density
from trianglechain.params import ensure_rec
from trianglechain.utils_plots import (
add_colorbar,
delete_all_ticks,
get_labels,
get_lines_and_labels,
get_n_points_for_scatter,
get_old_lims,
rasterize_density_images,
set_limits,
update_current_ranges,
update_current_ticks,
)
# Module-level logger obtained from cosmic_toolbox, keyed on this file's path.
LOGGER = logger.get_logger(__file__)
class RectangleChain(BaseChain):
    """
    Class to produce rectangular plots: each parameter in ``params_x`` is plotted
    on the x axis against each parameter in ``params_y`` on the y axis, with one
    panel per pair.

    Parameters defined for this class are used for all plots that are added to the
    figure. If you want to change the parameters for a specific plot, you can do so
    by passing the parameters to the plotting function.

    :param fig: matplotlib figure, default: None
    :param size: size of the panels, if one number is given, the panels are
        rectangular with the y axis being 70% of the x axis, if two numbers are
        given, the first number is the width of the panels and the second number
        is the height of the panels, default: 4
    :param params_x: parameters to plot on the x axis
    :param params_y: parameters to plot on the y axis
    :param names: names of parameters (when data is np array), default: None
    :param ranges: dictionary with the ranges for the parameters
    :param labels_x: list of labels (e.g. latex style) for the parameters plotted
        on the x axis
    :param labels_y: list of labels (e.g. latex style) for the parameters plotted
        on the y axis
    :param fill: if the contours should be filled, default: False
    :param grid: if the grid should be plotted, default: False
    :param n_ticks: number of ticks on the axes, default: 3
    :param ticks: dict specifying the ticks for a parameter
    :param tick_length: length of the ticks, default: 3
    :param n_bins: number of bins (used e.g. for scatter_density), default: 100
    :param density_estimation_method: method for density estimation. Available
        options:

        - smoothing (default):
          First create a histogram of samples and then smooth it with a Gaussian
          kernel corresponding to the variance of the 20% of the smallest
          eigenvalue of the 2D distribution (smoothing scale can be adapted using
          the smoothing parameter in de_kwargs).
        - gaussian_mixture:
          Use Gaussian mixture to fit the 2D samples.
        - median_filter:
          Use median filter on the 2D histogram.
        - kde:
          Use TreeKDE, may be slow.
        - hist:
          Simple 2D histogram.

    :param cmap: colormap, default: "viridis"
    :param colorbar: if a colorbar should be plotted, default: False
    :param colorbar_label: label for the colorbar, default: None
    :param colorbar_ax: axis for the colorbar, default: [0.93, 0.1, 0.03, 0.8]
    :param cmap_vmin: minimum value for the colormap, default: 0
    :param cmap_vmax: maximum value for the colormap, default: None
    :param show_legend: if a legend should be shown, default: False
    :param alpha: alpha for the 2D histograms, default: 1
    :param alpha_for_low_density: if low density areas should fade to transparent
    :param alpha_threshold: threshold from where the fading to transparent should
        start, default: 0
    :param n_points_scatter: number of points to use for scatter plots,
        default: -1 (all)
    :param label_fontsize: fontsize of the labels, default: 24
    :param de_kwargs: density estimation kwargs, dictionary with keys:

        - n_points:
          number of bins for 2d histograms used to create contours etc.,
          default: n_bins
        - levels:
          density levels for contours, the contours will enclose this
          level of probability, default: [0.68, 0.95]
        - n_levels_check:
          number of levels to check when looking for density levels.
          More levels is more accurate, but slower, default: 2000
        - smoothing_parameter1D:
          smoothing scale for the 1D histograms, default: 0.1
        - smoothing_parameter2D:
          smoothing scale for the 2D histograms, default: 0.2

    :param grid_kwargs: kwargs for ax.grid, passed to ax.grid
    :param labels_kwargs: kwargs for the x and y labels, passed to ax.set_xlabel and
        ax.set_ylabel
    :param line_kwargs: kwargs for the lines, passed to plt.contour and plt.contourf
    :param ticks_kwargs: kwargs for the ticks, passed to ax.set_xticklabels and
        ax.set_yticklabels
    :param scatter_kwargs: kwargs for the scatter plot, passed to plt.scatter
    :param subplots_kwargs: kwargs for the subplots, passed to plt.subplots
    :param axlines_kwargs: kwargs for the axlines, passed to ax.axhline and ax.axvline

    Basic usage::

        rec = RectangleChain(params_x=["par1", "par2"], params_y=["par2", "par3"])

        # plot contours at given confidence levels
        rec.contour_cl(samples)

        # plot PDF density image
        rec.density_image(samples)

        # simple scatter plot
        rec.scatter(samples)

        # scatter plot, with probability for each sample provided
        rec.scatter_prob(samples, prob=prob)

        # scatter plot, color corresponds to probability
        rec.scatter_density(samples)
    """

    def __init__(self, params_x, params_y, fig=None, size=4, **kwargs):
        # Use a rectangle-specific default colorbar position unless the caller
        # provided one explicitly.
        kwargs.setdefault("colorbar_ax", [0.93, 0.1, 0.03, 0.8])
        super().__init__(fig=fig, size=size, **kwargs)
        self.params_x = params_x
        self.params_y = params_y
        self.add_plotting_functions(self.add_plot)

    def add_plot(
        self,
        data,
        plottype,
        prob=None,
        color=None,
        label=None,
        **kwargs,
    ):
        """
        Plotting function for the rectangle chain class. Parameters that are passed
        to this function overwrite the default parameters of the class.

        :param data: data to plot, can be recarray, array, pandas dataframe or dict
        :param plottype: type of plot, forwarded as ``func`` to plot_rec_marginals
            (contour_cl, density_image, scatter, scatter_prob, scatter_density,
            axlines)
        :param prob: probability for each sample, required for scatter_prob,
            default: None
        :param color: color for the plot, default: None
        :param label: label for the plot, default: None
        :param names: list of names of the parameters, only used when input is
            unstructured array
        :param fill: if the contours should be filled, default: False
        :param grid: if the grid should be plotted, default: False
        :param n_bins: number of bins (used e.g. for scatter_density), default: 100
        :param density_estimation_method: method for density estimation. Available
            options:

            - smoothing (default):
              First create a histogram of samples and then smooth it with a
              Gaussian kernel corresponding to the variance of the 20% of the
              smallest eigenvalue of the 2D distribution (smoothing scale can be
              adapted using the smoothing parameter in de_kwargs).
            - gaussian_mixture:
              Use Gaussian mixture to fit the 2D samples.
            - median_filter:
              Use median filter on the 2D histogram.
            - kde:
              Use TreeKDE, may be slow.
            - hist:
              Simple 2D histogram.

        :param cmap: colormap, default: "viridis"
        :param colorbar: if a colorbar should be plotted, default: False
        :param colorbar_label: label for the colorbar, default: None
        :param colorbar_ax: axis for the colorbar, default: the value set on the
            class
        :param cmap_vmin: minimum value for the colormap, default: 0
        :param cmap_vmax: maximum value for the colormap, default: None
        :param show_legend: if a legend should be shown, default: False
        :param alpha: alpha for the 2D histograms, default: 1
        :param normalize_prob2D: if the 2D histograms should be normalized for
            scatter_prob, default: True
        :param alpha_for_low_density: if low density areas should fade to
            transparent
        :param alpha_threshold: threshold from where the fading to transparent
            should start, default: 0
        :param n_points_scatter: number of points to use for scatter plots,
            default: -1 (all)
        :param de_kwargs: density estimation kwargs, dictionary with keys:

            - n_points:
              number of bins for 2d histograms used to create contours etc.,
              default: n_bins
            - levels:
              density levels for contours, the contours will enclose this
              level of probability, default: [0.68, 0.95]
            - n_levels_check:
              number of levels to check when looking for density levels.
              More levels is more accurate, but slower, default: 2000
            - smoothing_parameter1D:
              smoothing scale for the 1D histograms, default: 0.1
            - smoothing_parameter2D:
              smoothing scale for the 2D histograms, default: 0.2

        :param grid_kwargs: kwargs for ax.grid, passed to ax.grid
        :param ticks_kwargs: kwargs for the ticks, passed to ax.set_xticklabels and
            ax.set_yticklabels
        :param labels_kwargs: kwargs for the x and y labels, passed to
            ax.set_xlabel and ax.set_ylabel
        :param line_kwargs: kwargs for the lines, passed to plt.contour and
            plt.contourf
        :param scatter_kwargs: kwargs for the scatter plot, passed to plt.scatter
        :param subplots_kwargs: kwargs for the subplots, passed to plt.subplots
        :return: the matplotlib figure (also stored as ``self.fig``)
        :raises ValueError: if plottype is "scatter_prob" and prob is None
        """
        # check if all kwargs are valid trianglechain arguments
        self._check_unexpected_kwargs(kwargs)
        # Class-level settings, overridden by the per-call kwargs. The deepcopy
        # keeps per-call overrides from leaking back into self.kwargs.
        kwargs_copy = deepcopy(self.kwargs)
        kwargs_copy.update(kwargs)
        if plottype == "scatter_prob" and prob is None:
            raise ValueError("prob needs to be defined for scatter_prob")
        color = self.setup_color(color)
        self.fig = plot_rec_marginals(
            params_x=self.params_x,
            params_y=self.params_y,
            fig=self.fig,
            size=self.size,
            func=plottype,
            data=data,
            prob=prob,
            color=color,
            label=label,
            **kwargs_copy,
        )
        return self.fig
def plot_rec_marginals(
    params_x,
    params_y,
    data,
    prob=None,
    func="contour_cl",
    color="#0063B9",
    label=None,
    fig=None,
    size=4,
    ranges={},
    labels_x=None,
    labels_y=None,
    names=None,
    fill=False,
    grid=False,
    n_ticks=3,
    ticks={},
    tick_length=3,
    n_bins=100,
    density_estimation_method="smoothing",
    cmap=plt.cm.viridis,
    cmap_vmin=0,
    cmap_vmax=None,
    colorbar=False,
    colorbar_label=None,
    colorbar_ax=[0.735, 0.5, 0.03, 0.25],
    show_legend=False,
    normalize_prob2D=True,
    alpha=None,
    alpha2D=1,
    alpha_for_low_density=False,
    alpha_threshold=0,
    n_points_scatter=-1,
    label_fontsize=24,
    tick_fontsize=14,
    legend_fontsize=24,
    de_kwargs={},
    grid_kwargs={},
    labels_kwargs={},
    line_kwargs={},
    scatter_kwargs={},
    subplots_kwargs={},
    axlines_kwargs={},
    ticks_kwargs={},
    **kwargs,
):
    """
    Plot the 2D marginals of a chain on a rectangular grid of panels: panel
    (row i, column j) shows ``params_x[j]`` on the x axis against ``params_y[i]``
    on the y axis.

    :param params_x: parameters to plot on the x axis, one panel column each
    :param params_y: parameters to plot on the y axis, one panel row each
    :param data: rec array, array, dict or pd dataframe
        data to plot
    :param prob: probability for each sample, required for func="scatter_prob",
        default: None
    :param func: function to use for plotting
        options: contour_cl, density_image, scatter_density, scatter_prob, scatter,
        axlines
        default: contour_cl
    :param color: color of the plot, default: "#0063B9"
    :param label: label for the plot, default: None
    :param fig: figure to plot on; if None, a new figure with
        len(params_y) x len(params_x) panels is created, default: None
    :param size: size of a panel; a single number gives width ``size`` and height
        ``0.7 * size``, a (width, height) list/tuple sets both, default: 4
    :param ranges: dictionary with ranges for each parameter, default: {}
    :param labels_x: labels for the x-axis parameters, default: None
        (if None, derived from params_x via get_labels)
    :param labels_y: labels for the y-axis parameters, default: None
        (if None, derived from params_y via get_labels)
    :param names: names of parameters (when data is np array), default: None
    :param fill: fill the area of the contours, default: False
    :param grid: show grid, default: False
    :param n_ticks: number of ticks for each parameter, default: 3
    :param ticks: dictionary with ticks for each parameter, default: {}
    :param tick_length: length of ticks, default: 3
    :param n_bins: number of bins for scatter_density, default: 100
    :param density_estimation_method: method to use for density estimation
        options: smoothing, hist, kde, gaussian_mixture, median_filter
        default: smoothing
    :param cmap: colormap for 2D plots, default: plt.cm.viridis
    :param cmap_vmin: minimum value for colormap, default: 0
    :param cmap_vmax: maximum value for colormap, default: None
    :param colorbar: show colorbar, default: False
    :param colorbar_label: label for colorbar, default: None
    :param colorbar_ax: position of colorbar, default: [0.735, 0.5, 0.03, 0.25]
    :param show_legend: show legend, default: False
    :param normalize_prob2D: normalize probability for 2D plots, default: True.
        If False, prob is used as-is (e.g. to color by an additional parameter).
    :param alpha: alpha value for the plot, default: None (falls back to alpha2D)
    :param alpha2D: alpha value for 2D plots, used when alpha is None, default: 1
    :param alpha_for_low_density: use alpha for low density regions,
        default: False
    :param alpha_threshold: threshold for alpha, default: 0
    :param n_points_scatter: number of points for scatter plots, default: -1 (all)
    :param label_fontsize: fontsize of the axis labels, default: 24
    :param tick_fontsize: fontsize of the ticks, default: 14
    :param legend_fontsize: fontsize of the legend, default: 24
    :param de_kwargs: kwargs for density estimation, default: {}
    :param grid_kwargs: kwargs for ax.grid, default: {}
    :param labels_kwargs: kwargs for ax.set_xlabel / ax.set_ylabel, default: {}
    :param line_kwargs: kwargs for line plots, default: {}
    :param scatter_kwargs: kwargs for scatter plots, default: {}
    :param subplots_kwargs: kwargs for plt.subplots, default: {}
    :param axlines_kwargs: kwargs for axhline / axvline, default: {}
    :param ticks_kwargs: kwargs for the tick labels, default: {}
    :param kwargs: accepted but not used by this function
    :return: the matplotlib figure
    :raises ValueError: if func is "scatter_prob" and prob is None, or if func is
        "axlines" and data contains more than one point
    """
    # NOTE(review): several defaults above are mutable ({} / list). This function
    # only reads them, so sharing them across calls is harmless as long as the
    # helpers they are forwarded to do not mutate them either.

    # An explicit alpha wins over alpha2D (whose default is 1).
    if alpha is None:
        alpha = alpha2D
    elif alpha2D != 1:
        LOGGER.warning("parameters alpha and alpha2D are both set, using alpha")

    # Fail early (before a figure is created) instead of hitting an undefined
    # probability array further down.
    if func == "scatter_prob" and prob is None:
        raise ValueError("prob needs to be defined for scatter_prob")

    ###############################
    # prepare data and setup plot #
    ###############################
    data = ensure_rec(data, names=names)
    # needed for plotting chains with different automatic limits
    current_ranges = {}
    current_ticks = {}
    labels_x = get_labels(labels_x, params_x)
    labels_y = get_labels(labels_y, params_y)
    # x parameters occupy columns[:n_cols], y parameters columns[n_cols:];
    # list() so that tuples (or mixed sequence types) concatenate correctly
    columns = list(params_x) + list(params_y)

    # Setup the probabilities for possible plots
    prob_label = None
    prob2D = None
    if prob is not None:
        # shift so that all probabilities are non-negative before normalizing
        if np.min(prob) < 0:
            prob_offset = -np.min(prob)
        else:
            prob_offset = 0
        if normalize_prob2D:
            prob2D = (prob + prob_offset) / np.sum(prob + prob_offset)
        else:
            # for example to plot an additional parameter in parameter space
            prob_label = prob

    # Setup the figure orientation
    n_cols = len(params_x)
    n_rows = len(params_y)
    n_panels = n_rows * n_cols

    # Setup the figure size
    if isinstance(size, (list, tuple)):
        x_size, y_size = size[0], size[1]
    else:
        x_size, y_size = size, size * 0.7

    # Setup the figure
    if fig is None:
        fig, _ = plt.subplots(
            nrows=n_rows,
            ncols=n_cols,
            figsize=(n_cols * x_size, n_rows * y_size),
            **subplots_kwargs,
        )
    # The panel grid is the first n_panels axes of the figure (row-major, as
    # created by plt.subplots). Any axes attached to the figure afterwards
    # (presumably e.g. a colorbar axis from a previous call) are ignored so they
    # do not break the reshape.
    ax = np.array(fig.get_axes()[:n_panels]).reshape(n_rows, n_cols)

    # get ranges for each parameter (if not specified, max/min of data is used)
    update_current_ranges(current_ranges, ranges, columns, data)

    #################
    # 2D histograms #
    #################
    for i, j in itertools.product(range(n_rows), range(n_cols)):
        # j corresponds to the x axis (params_x[j]), i to the y axis
        # (params_y[i]); j_rec / i_rec are the matching indices into `columns`
        j_rec = j
        i_rec = i + n_cols
        x_col = columns[j_rec]
        y_col = columns[i_rec]
        axc = ax[i, j]
        old_xlims, old_ylims = get_old_lims(axc)
        if func == "contour_cl":
            contour_cl(
                axc,
                data=data,
                ranges=current_ranges,
                columns=columns,
                i=i_rec,
                j=j_rec,
                fill=fill,
                color=color,
                de_kwargs=de_kwargs,
                line_kwargs=line_kwargs,
                prob=prob,
                density_estimation_method=density_estimation_method,
                label=label,
                alpha=alpha,
            )
        elif func == "density_image":
            density_image(
                axc,
                data=data,
                ranges=current_ranges,
                columns=columns,
                i=i_rec,
                j=j_rec,
                cmap=cmap,
                de_kwargs=de_kwargs,
                vmin=cmap_vmin,
                vmax=cmap_vmax,
                prob=prob,
                density_estimation_method=density_estimation_method,
                label=label,
                alpha=alpha,
                alpha_for_low_density=alpha_for_low_density,
                alpha_threshold=alpha_threshold,
            )
        elif func == "scatter":
            x, y = get_n_points_for_scatter(
                data[x_col],
                data[y_col],
                n_points_scatter=n_points_scatter,
            )
            axc.scatter(
                x,
                y,
                c=color,
                label=label,
                alpha=alpha,
                **scatter_kwargs,
            )
        elif func == "scatter_prob":
            _prob = prob2D if normalize_prob2D else prob_label
            x, y, _prob = get_n_points_for_scatter(
                data[x_col],
                data[y_col],
                prob=_prob,
                n_points_scatter=n_points_scatter,
            )
            # draw low-probability points first so high-probability ones end
            # up on top
            sorting = np.argsort(_prob)
            axc.scatter(
                x[sorting],
                y[sorting],
                c=_prob[sorting],
                label=label,
                cmap=cmap,
                alpha=alpha,
                **scatter_kwargs,
            )
        elif func == "scatter_density":
            scatter_density(
                axc,
                points1=data[x_col],
                points2=data[y_col],
                n_bins=n_bins,
                lim1=current_ranges[x_col],
                lim2=current_ranges[y_col],
                n_points_scatter=n_points_scatter,
                cmap=cmap,
                vmin=cmap_vmin,
                vmax=cmap_vmax,
                label=label,
                alpha=alpha,
            )
        elif func == "axlines":
            if len(data[y_col]) > 1:
                raise ValueError(
                    "axlines can only be used with one point, not with {} points".format(
                        len(data[y_col])
                    )
                )
            axc.axhline(
                y=data[y_col],
                color=color,
                label=label,
                alpha=alpha,
                **axlines_kwargs,
            )
            axc.axvline(
                x=data[x_col],
                color=color,
                label=label,
                alpha=alpha,
                **axlines_kwargs,
            )
        set_limits(
            axc,
            ranges,
            current_ranges,
            y_col,
            x_col,
            old_xlims,
            old_ylims,
        )

    # grid
    if grid:
        for axc in ax.flatten():
            axc.grid(zorder=0, **grid_kwargs)

    #########
    # ticks #
    #########
    def get_ticks(i):
        # user-provided ticks take precedence over the automatic ones;
        # TypeError covers ticks=None
        try:
            return ticks[columns[i]]
        except (KeyError, TypeError):
            return current_ticks[columns[i]]

    def plot_yticks(axc, i, length=10, direction="in"):
        axc.yaxis.set_ticks_position("both")
        axc.set_yticks(get_ticks(i))
        axc.tick_params(direction=direction, length=length)

    def plot_xticks(axc, i, length=10, direction="in"):
        axc.xaxis.set_ticks_position("both")
        axc.set_xticks(get_ticks(i))
        axc.tick_params(direction=direction, length=length)

    def plot_tick_labels(axc, xy, i, ticks_kwargs):
        ticklabels = list(get_ticks(i))
        if xy == "y":
            axc.set_yticklabels(
                ticklabels,
                rotation=0,
                fontsize=tick_fontsize,
                **ticks_kwargs,
            )
        else:
            # xy == "x":
            axc.set_xticklabels(
                ticklabels,
                rotation=90,
                fontsize=tick_fontsize,
                **ticks_kwargs,
            )

    delete_all_ticks(ax)
    update_current_ticks(current_ticks, columns, ranges, current_ranges, n_ticks)
    for i, j in itertools.product(range(n_rows), range(n_cols)):
        axc = ax[i, j]
        plot_xticks(axc, j, tick_length)
        plot_yticks(axc, i + n_cols, tick_length)

    # tick labels only on the outer panels: x on the bottom row, y on the
    # first column
    for i, j in itertools.product(range(n_rows), range(n_cols)):
        axc = ax[i, j]
        if i == n_rows - 1:
            plot_tick_labels(axc, "x", j, ticks_kwargs)
        if j == 0:
            plot_tick_labels(axc, "y", i + n_cols, ticks_kwargs)

    # legends
    legend_lines, legend_labels = get_lines_and_labels(ax)

    # axis labels, also only on the outer panels
    labelpad = 10
    for i, j in itertools.product(range(n_rows), range(n_cols)):
        axc = ax[i, j]
        if j == 0:
            axc.set_ylabel(
                labels_y[i],
                **labels_kwargs,
                rotation=90,
                labelpad=labelpad,
                fontsize=label_fontsize,
            )
            axc.yaxis.set_label_position("left")
        if i == n_rows - 1:
            axc.set_xlabel(
                labels_x[j],
                **labels_kwargs,
                rotation=0,
                labelpad=labelpad,
                fontsize=label_fontsize,
            )
            axc.xaxis.set_label_position("bottom")

    if legend_lines and show_legend:
        # only print legend when there are labels for it
        fig.legend(
            legend_lines,
            legend_labels,
            bbox_to_anchor=(1, 1),
            bbox_transform=ax[0, -1].transAxes,
            fontsize=legend_fontsize,
        )
    if colorbar:
        add_colorbar(
            fig,
            cmap_vmin,
            cmap_vmax,
            cmap,
            colorbar_ax,
            colorbar_label,
            legend_fontsize,
            tick_fontsize,
            prob_label,
        )
    rasterize_density_images(ax)
    # Adjust *this* figure: plt.subplots_adjust acts on the current figure,
    # which need not be `fig` when a figure was passed in.
    fig.subplots_adjust(hspace=0, wspace=0)
    fig.align_ylabels()
    fig.align_xlabels()
    return fig