# encoding: utf-8
from __future__ import absolute_import, division, print_function
import functools
import glob
import itertools
import os
import random
import sys
import warnings
import dill
import numpy as np
from ._parallel import (
check_parallel,
get_pool_for,
get_rank,
runs_on_euler_node,
runs_with_mpi,
)
from ._utils import (
capture,
mute_stderr,
os_write_works,
with_progressbar,
with_simple_progress_report,
)
from .parameters import Parameters
from .persisting import _SamplerPersister, dont_persist
with mute_stderr():
    # might cause MPI related error message if run on one single node, even
    # without mpirun:
    import emcee
# major version number of the installed emcee package; the code below
# branches on it because attribute names differ between emcee 2 and emcee 3:
emcee_version = int(emcee.__version__.split(".")[0])
# result of runs_with_mpi(), evaluated once at import time:
RUNS_WITH_MPI = runs_with_mpi()
class _ParameterAdapter(object):
"""
Addapts parameters argument of emcee lnprob function to parameters of uhammer
lnprob function.
"""
def __init__(self, names, values):
self.names = names
self.values = values
# this sets every parameter as an attribute:
self.__dict__.update(dict(zip(names, values)))
def __str__(self):
assignments = [
"{}={:e}".format(name, value)
for name, value in zip(self.names, self.values)
]
return "Parameters({})".format(", ".join(assignments))
class _LnProbAdapter(object):
"""
Adapts API of uhammer lnprob function to emcee hammer API.
"""
def __init__(self, lnprob_uhammer, p, show_output, output_path_pattern):
self.lnprob_uhammer = lnprob_uhammer
self.p = p
self.show_output = show_output
self.output_path_pattern = output_path_pattern
def __call__(self, parameters, *args):
if self.output_path_pattern is not None:
rank = get_rank()
if rank is not None:
output_path = self.output_path_pattern.format(worker_id=rank)
else:
output_path = self.output_path_pattern
else:
output_path = None
with capture(output_path, self.show_output):
return self.lnprob(parameters, args)
def lnprob(self, parameters, args):
p = self.p
names = p.names
lnprob_prior = 0
for name, value in zip(names, parameters):
if not p._is_in_bounds(name, value):
return -np.inf
lnprob_prior -= np.log(p._range_length(name))
parameters_for_lnprob = _ParameterAdapter(names, parameters)
computed_lnprob = self.lnprob_uhammer(parameters_for_lnprob, *args)
if computed_lnprob > 0:
client_function_name = self.lnprob_uhammer.__name__
raise AssertionError(
"{} computed positive value {} for parameters {}".format(
client_function_name, computed_lnprob, parameters_for_lnprob
)
)
return computed_lnprob + lnprob_prior
def __getstate__(self):
# distributed computation requires pickling of the posterior
# function using pickle module from the python standard library.
# this can fail e.g. if nested functions or classes with
# classmethods are infolved.
# to fix that we implement here how to pickle this class using
# dill which is more versatile than pickle from the standard
# library:
return dill.dumps(self.__dict__)
def __setstate__(self, state):
# see __getstate__
# this method implements the unpickling part.
self.__dict__.update(dill.loads(state))
class PickableEmceeSampler(emcee.EnsembleSampler):
    """
    ``emcee.EnsembleSampler`` whose pickled state leaves out the ``pool``
    attribute. After unpickling ``pool`` is ``None`` unless the restored
    state sets it.
    """

    def __getstate__(self):
        # copy of the instance dict without the "pool" entry:
        return {key: value for key, value in self.__dict__.items() if key != "pool"}

    def __setstate__(self, dd):
        self.pool = None
        self.__dict__.update(dd)
        if emcee_version <= 2:
            return
        # emcee 3 and later: keep only the first ``iteration + 1`` entries of
        # the arrays stored in the backend.
        backend = self.backend
        backend.chain = backend.chain[: self.iteration + 1, :, :]
        backend.log_prob = backend.log_prob[: self.iteration + 1, :]
def _setup_random_state(seed, sampler):
rstate0 = sampler._random
rstate0.seed(seed)
random.seed(seed)
np.random.seed(seed)
def sample(
    lnprob,
    p,
    args=None,
    n_walkers_per_param=10,
    seed=None,
    n_samples=1000,
    parallel=False,
    show_progress=False,
    show_output=True,
    output_prefix=None,
    persist=dont_persist(),
    verbose=True,
):
    """
    Runs emcee sampler to generate samples from the declared
    distribution.

    :param lnprob: Function for computing the natural logarithm of the
                   distribution to sample from. First argument is an
                   object with the current values of the declared
                   parameters as attributes. The function may accept
                   additional arguments for auxiliary data.

    :param p: Instance of the ``Parameters`` class.

    :param args: List with auxiliary values. Can also be ``None`` or
                 ``[]``.

    :param n_walkers_per_param: Number of walkers per parameter.
                                Results in ``len(p) *
                                n_walkers_per_param`` walkers in total.

    :param seed: Used random seed in case you want reproducible
                 samples. Skip this argument or use ``None`` else.

    :param n_samples: Number of samples to generate. So the return value
                      is a ``numpy`` array of shape ``(n_samples, len(p))``.

    :param parallel: Run in sequential or parallel mode? Automatically
                     detects number of cores, also detects if code is
                     run using ``mpirun``.

    :param show_progress: Show progress of sampling.

    :param show_output: set to ``True`` if you want to see the output
                        from the workers on the terminal, else set it to
                        ``False``.

    :param output_prefix: prefix for file path to record output from the
                          workers. Default is ``None`` which means no
                          recording of output.

    :param persist: persisting strategy, one of the functions in
                    :py:mod:`~uhammer.persisting`.

    :param verbose: boolean value. If set to ``True``
                    :py:func:`~uhammer.sampler.sample` prints extra
                    information.

    This function returns two values:

    - computed samples as a ``numpy`` array of shape ``(n_samples, len(p))``
    - computed log probabilities as ``numpy`` array of shape ``(n_samples,)``.

    Samples are arranged as follows, log probabilities in similar order:

    .. table::

        +----------------------------------------+
        | samples                                |
        +========================================+
        | p-dimensional sample from walker 0     |
        +----------------------------------------+
        | p-dimensional sample from walker 1     |
        +----------------------------------------+
        | ...                                    |
        +----------------------------------------+
        | p-dimensional sample from walker n - 1 |
        +----------------------------------------+
        | p-dimensional sample from walker 0     |
        +----------------------------------------+
        | p-dimensional sample from walker 1     |
        +----------------------------------------+
        | ...                                    |
        +----------------------------------------+
    """
    if args is None:
        args = []

    assert isinstance(p, Parameters), p
    assert len(p) > 0, "need parameters"
    assert isinstance(args, (list, tuple)), args
    assert isinstance(n_walkers_per_param, int), n_walkers_per_param
    assert n_walkers_per_param > 1, "need at least 2 walkers per parameter"
    if seed is not None:
        assert isinstance(seed, int)
    assert isinstance(n_samples, int), n_samples
    assert isinstance(parallel, bool)
    assert isinstance(show_progress, bool), show_progress
    assert isinstance(verbose, bool), verbose
    assert isinstance(show_output, bool), show_output
    if output_prefix is not None:
        assert isinstance(output_prefix, str), output_prefix

    check_parallel(parallel)

    if output_prefix is not None:
        output_path_pattern = capture_output_path(output_prefix, parallel)
    else:
        output_path_pattern = None

    show_progress = _check_and_fix_progress_settings(show_output, show_progress)

    # callable handed to emcee; it holds the parameter declaration and the
    # output settings used when lnprob is evaluated:
    _lnprob = _LnProbAdapter(lnprob, p, show_output, output_path_pattern)

    n_walkers = n_walkers_per_param * len(p)

    # might be global including startup of workers, workers hang there:
    # workers are stopped at exit:
    pool = get_pool_for(n_walkers) if parallel else None
    if pool is not None and verbose:
        print("uhammer: started {} with {} workers".format(pool, pool.size))

    sampler = PickableEmceeSampler(n_walkers, len(p), _lnprob, args=args, pool=pool)
    # remembered on the sampler because _sample and continue_sampling read it:
    sampler.n_walkers = n_walkers

    # only touch the random generators if the caller asked for
    # reproducible samples:
    if seed is not None:
        _setup_random_state(seed, sampler)

    # every walker starts at the declared start values:
    p0 = [[p.start_value(name) for name in p.names] for _ in range(n_walkers)]

    return _sample(p0, sampler, n_samples, show_progress, persist, verbose=verbose)
def continue_sampling(
    persistence_file_path,
    n_samples,
    parallel=False,
    show_progress=False,
    show_output=True,
    output_prefix=None,
    persist=dont_persist(),
    verbose=True,
):
    """
    Resumes sampling with a sampler that was persisted earlier.

    Settings made in the original call of
    :py:func:`~uhammer.sampler.sample` such as ``lnprob``,
    ``n_walkers_per_param``, ``seed``, etc. come from the persisted
    sampler and can not be changed here.

    :param persistence_file_path: Path to a previously persisted
                                  sampler. This includes ``lnprob``,
                                  parameters and number of walkers.

    :param n_samples: Number of samples to generate. So the return value
                      is a ``numpy`` array of shape ``(n_samples, len(p))``.

    :param parallel: Run in sequential or parallel mode? Automatic
                     detection of number of cores, also auto detects if
                     code is run using ``mpirun``.

    :param show_progress: Show progress of sampling.

    :param show_output: set to ``True`` if you want to see the output
                        from the workers on the terminal, else set it to
                        ``False``.

    :param output_prefix: prefix for file path to record output from the
                          workers. Default is ``None`` which means no
                          recording of output.

    :param persist: persisting strategy, one of the functions in
                    :py:mod:`~uhammer.persisting`.

    :param verbose: boolean value, if ``True``
                    :py:func:`~uhammer.sampler.continue_sampling`
                    prints some information.

    This function returns two values:

    - computed samples as a ``numpy`` array of shape ``(n_samples, len(p))``
    - computed log probabilities as ``numpy`` array of shape ``(n_samples,)``.

    Samples are arranged as follows, log probabilities in similar order:

    .. table::

        +----------------------------------------+
        | samples                                |
        +========================================+
        | p-dimensional sample from walker 0     |
        +----------------------------------------+
        | p-dimensional sample from walker 1     |
        +----------------------------------------+
        | ...                                    |
        +----------------------------------------+
        | p-dimensional sample from walker n - 1 |
        +----------------------------------------+
        | p-dimensional sample from walker 0     |
        +----------------------------------------+
        | p-dimensional sample from walker 1     |
        +----------------------------------------+
        | ...                                    |
        +----------------------------------------+
    """
    # argument validation
    assert isinstance(persistence_file_path, str)
    assert os.path.exists(persistence_file_path), "file {} does not exist".format(
        persistence_file_path
    )
    assert isinstance(n_samples, int), n_samples
    assert isinstance(parallel, bool)
    assert isinstance(show_progress, bool), show_progress
    assert isinstance(show_output, bool), show_output
    if output_prefix is not None:
        assert isinstance(output_prefix, str), output_prefix

    check_parallel(parallel)
    show_progress = _check_and_fix_progress_settings(show_output, show_progress)

    output_path_pattern = (
        None if output_prefix is None else capture_output_path(output_prefix, parallel)
    )

    restored = _SamplerPersister.restore_sampler(persistence_file_path)
    restored.pool = get_pool_for(restored.n_walkers) if parallel else None

    # the attribute holding the wrapped log probability function is named
    # differently in emcee 2 and emcee 3; its ``f`` is our adapter object
    # whose output settings are refreshed for this run:
    wrapper_name = "lnprobfn" if emcee_version < 3 else "log_prob_fn"
    adapter = getattr(restored, wrapper_name).f
    adapter.show_output = show_output
    adapter.output_path_pattern = output_path_pattern

    # state recorded by _sample after the last successful step:
    start_positions, start_lnprob, start_rstate = restored._last_run_mcmc_result
    return _sample(
        start_positions,
        restored,
        n_samples,
        show_progress,
        persist,
        start_lnprob,
        start_rstate,
        verbose,
    )
def capture_output_path(prefix, parallel):
    """
    Computes the path (pattern) of the output file(s) for the next run.

    Existing files named ``<prefix>_run_<id>.txt`` or
    ``<prefix>_run_<id>_<anything>.txt`` are inspected; the new run id is
    the largest id found plus one, or 0 if there is no such file.

    :param prefix: path prefix of the output files, may include directories
                   and underscores.
    :param parallel: if ``True`` the result contains a ``{worker_id}``
                     placeholder to be filled in per worker.
    :returns: ``<prefix>_run_<id>.txt`` for sequential mode, else
              ``<prefix>_run_<id>_worker_{worker_id}.txt``; ``<id>`` is zero
              padded to four digits.
    """
    # every file name of interest starts with this, independent of the
    # number of underscores in the prefix itself:
    head = os.path.basename(prefix) + "_run_"
    suffix = ".txt"

    run_ids = []
    # the pattern covers files from sequential runs ("_run_0000.txt") as well
    # as files from parallel runs ("_run_0000_worker_3.txt"):
    for path in glob.glob(prefix + "_run_*" + suffix):
        file_name = os.path.basename(path)
        if not file_name.startswith(head):
            continue
        run_id = file_name[len(head) : -len(suffix)].split("_")[0]
        # skip unrelated files which happen to match the glob pattern:
        if run_id.isdigit():
            run_ids.append(int(run_id))

    next_run_id = max(run_ids) + 1 if run_ids else 0

    if not parallel:
        return "{prefix}_run_{run_id:04d}.txt".format(prefix=prefix, run_id=next_run_id)
    return "{prefix}_run_{run_id:04d}_worker_{{worker_id}}.txt".format(
        prefix=prefix, run_id=next_run_id
    )
def _check_and_fix_progress_settings(show_output, show_progress):
if show_output and show_progress:
warnings.warn(
"not supressing output might corrupt the progressbar. "
"Better set show_output to False and set an appropriabe "
"value for output_prefix."
)
if show_progress and not os_write_works() and not runs_on_euler_node():
show_progress = False
warnings.warn("disabled progressbar, can not write to fid 1")
return show_progress
def _sample(
    p0,
    sampler,
    n_samples,
    show_progress,
    persist,
    lnprob0=None,
    rstate0=None,
    verbose=True,
):
    """
    Drives the emcee sampler step by step and collects the results.

    :param p0: start positions of the walkers, passed to ``sampler.sample``.
    :param sampler: sampler object; must offer ``n_walkers``, ``sample``,
                    ``chain`` and ``lnprobability``.
    :param n_samples: number of samples to return.
    :param show_progress: wrap the iteration into a progress display.
    :param persist: callable ``persist(sampler, step, n_steps, exception)``
                    invoked after every step with ``None`` as last argument,
                    and with the exception if a step failed. A true return
                    value is reported when ``verbose`` is set.
    :param lnprob0: accepted for the call made by ``continue_sampling``;
                    currently not used in this function.
    :param rstate0: accepted for the call made by ``continue_sampling``;
                    currently not used in this function.
    :param verbose: print status messages to ``sys.__stdout__``.
    :returns: tuple ``(samples, lnprobs)``.
    :raises Exception: any exception raised during a sampling step is
                       re-raised after ``persist`` was given the chance to
                       save the sampler.
    """
    # write to the original stdout, whatever sys.stdout currently is:
    print_ = functools.partial(print, file=sys.__stdout__)

    # every step yields one sample per walker, so round up
    # (the module enables true division via ``from __future__``):
    n_steps = int(np.ceil(n_samples / sampler.n_walkers))

    if verbose:
        print_("uhammer: perform {} steps of emcee sampler".format(n_steps))

    # iterator provided by emcee, yields the sampler state after each step:
    sample_iter = sampler.sample(p0, iterations=n_steps)

    if show_progress:
        if runs_on_euler_node():
            sample_iter = with_simple_progress_report(n_steps, sample_iter)
        else:
            sample_iter = with_progressbar(n_steps, sample_iter)

    for step in itertools.count(1):
        try:
            state = next(sample_iter)
            if emcee_version < 3:
                # only the first three entries are needed to continue later:
                state = state[:3]
            sampler._last_run_mcmc_result = state
        except StopIteration:
            break
        except Exception as e:
            persisted = persist(sampler, step, n_steps, e)
            if verbose and persisted:
                print_("uhammer: persist sampler due to exception")
            # bare raise keeps the original traceback, also on Python 2:
            raise

        persisted = persist(sampler, step, n_steps, None)
        if verbose and persisted:
            print_("uhammer: persisted sampler after iteration {}".format(step))

    n_param = sampler.chain.shape[2]
    samples = sampler.chain.swapaxes(0, 1).reshape(-1, n_param)[:n_samples, :]

    if emcee_version < 3:
        lnprobs = sampler.lnprobability.swapaxes(0, 1).flatten()[:n_samples]
    else:
        # this change is a preliminary workaround for a bug in emcee.
        # this is already fixed in emcee git repository, but not in the package
        # NOTE(review): samples above keep the first n_samples rows whereas
        # this keeps the last n_samples values - confirm that both refer to
        # the same draws when more than n_samples values are available.
        lnprobs = sampler.lnprobability.flatten()[-n_samples:]

    return samples, lnprobs