from __future__ import print_function, division, absolute_import, unicode_literals
import emcee
import numpy as np
import logging
import time
import cosmoHammer
import cosmoHammer.Constants as c
from cosmoHammer import getLogger
from cosmoHammer.util import SampleFileUtil
from cosmoHammer.util import SampleBallPositionGenerator
from cosmoHammer.util.IterationStopCriteriaStrategy import IterationStopCriteriaStrategy
[docs]class CosmoHammerSampler(object):
"""
A complete sampler implementation taking care of correct setup, chain burn in and sampling.
:param params: the parameter of the priors
:param likelihoodComputationChain: the callable computation chain
:param filePrefix: the prefix for the log and output files
:param walkersRatio: the ratio of walkers and the count of sampled parameters
:param burninIterations: number of iteration for burn in
:param sampleIterations: number of iteration to sample
:param stopCriteriaStrategy: the strategy to stop the sampling.
Default is None an then IterationStopCriteriaStrategy is used
:param initPositionGenerator: the generator for the init walker position.
Default is None an then SampleBallPositionGenerator is used
:param storageUtil: util used to store the results
:param threadCount: The count of threads to be used for the computation. Default is 1
:param reuseBurnin: Flag if the burn in should be reused.
If true the values will be read from the file System. Default is False
"""
def __init__(self, params, likelihoodComputationChain, filePrefix, walkersRatio, burninIterations,
sampleIterations, stopCriteriaStrategy=None, initPositionGenerator=None,
storageUtil=None, threadCount=1, reuseBurnin=False, logLevel=logging.INFO, pool=None):
"""
CosmoHammer sampler implementation
"""
self.params = params
self.likelihoodComputationChain = likelihoodComputationChain
self.walkersRatio = walkersRatio
self.reuseBurnin = reuseBurnin
self.filePrefix = filePrefix
self.threadCount = threadCount
self.paramCount = len(self.paramValues)
self.nwalkers = self.paramCount*walkersRatio
self.burninIterations = burninIterations
self.sampleIterations = sampleIterations
assert likelihoodComputationChain is not None, "The sampler needs a chain"
assert sampleIterations > 0, "CosmoHammer needs to sample for at least one iterations"
if not hasattr(self.likelihoodComputationChain, "params"):
self.likelihoodComputationChain.params = params
# setting up the logging
self._configureLogging(filePrefix+c.LOG_FILE_SUFFIX, logLevel)
if self.isMaster(): self.log("Using CosmoHammer "+str(cosmoHammer.__version__))
# The sampler object
self._sampler = self.createEmceeSampler(likelihoodComputationChain, pool=pool)
if(storageUtil is None):
storageUtil = self.createSampleFileUtil()
self.storageUtil = storageUtil
if(stopCriteriaStrategy is None):
stopCriteriaStrategy = self.createStopCriteriaStrategy()
stopCriteriaStrategy.setup(self)
self.stopCriteriaStrategy = stopCriteriaStrategy
if(initPositionGenerator is None):
initPositionGenerator = self.createInitPositionGenerator()
initPositionGenerator.setup(self)
self.initPositionGenerator = initPositionGenerator
def _configureLogging(self, filename, logLevel):
logger = getLogger()
logger.setLevel(logLevel)
fh = logging.FileHandler(filename, "w")
fh.setLevel(logLevel)
# create console handler with a higher log level
ch = logging.StreamHandler()
ch.setLevel(logging.ERROR)
# create formatter and add it to the handlers
formatter = logging.Formatter('%(asctime)s %(levelname)s:%(message)s')
fh.setFormatter(formatter)
ch.setFormatter(formatter)
# add the handlers to the logger
for handler in logger.handlers[:]:
try:
handler.close()
except AttributeError:
pass
logger.removeHandler(handler)
logger.addHandler(fh)
logger.addHandler(ch)
# logging.basicConfig(format='%(asctime)s %(levelname)s:%(message)s',
# filename=filename, filemode='w', level=logLevel)
[docs] def createStopCriteriaStrategy(self):
"""
Returns a new instance of a stop criteria stategy
"""
return IterationStopCriteriaStrategy()
[docs] def createSampleFileUtil(self):
"""
Returns a new instance of a File Util
"""
return SampleFileUtil(self.filePrefix, reuseBurnin=self.reuseBurnin)
[docs] def createInitPositionGenerator(self):
"""
Returns a new instance of a Init Position Generator
"""
return SampleBallPositionGenerator()
@property
def paramValues(self):
return self.params[:,0]
@property
def paramWidths(self):
return self.params[:,3]
[docs] def startSampling(self):
"""
Launches the sampling
"""
try:
if self.isMaster(): self.log(self.__str__())
if(self.burninIterations>0):
if(self.reuseBurnin):
pos, prob, rstate = self.loadBurnin()
datas = [None]*len(pos)
else:
pos, prob, rstate, datas = self.startSampleBurnin()
else:
pos = self.createInitPos()
prob = None
rstate = None
datas = None
# Starting from the final position in the burn-in chain, sample for 1000
# steps.
self.log("start sampling after burn in")
start = time.time()
self.sample(pos, prob, rstate, datas)
end = time.time()
self.log("sampling done! Took: " + str(round(end-start,4))+"s")
# Print out the mean acceptance fraction. In general, acceptance_fraction
# has an entry for each walker
self.log("Mean acceptance fraction:"+ str(round(np.mean(self._sampler.acceptance_fraction), 4)))
finally:
if self._sampler.pool is not None:
try:
self._sampler.pool.close()
except AttributeError:
pass
try:
self.storageUtil.close()
except AttributeError:
pass
[docs] def loadBurnin(self):
"""
loads the burn in form the file system
"""
self.log("reusing previous burn in")
pos = self.storageUtil.importFromFile(self.filePrefix+c.BURNIN_SUFFIX)[-self.nwalkers:]
prob = self.storageUtil.importFromFile(self.filePrefix+c.BURNIN_PROB_SUFFIX)[-self.nwalkers:]
rstate= self.storageUtil.importRandomState(self.filePrefix+c.BURNIN_STATE_SUFFIX)
self.log("loading done")
return pos, prob, rstate
[docs] def startSampleBurnin(self):
"""
Runs the sampler for the burn in
"""
self.log("start burn in")
start = time.time()
p0 = self.createInitPos()
pos, prob, rstate, data = self.sampleBurnin(p0)
end = time.time()
self.log("burn in sampling done! Took: " + str(round(end-start,4))+"s")
self.log("Mean acceptance fraction for burn in:" + str(round(np.mean(self._sampler.acceptance_fraction), 4)))
self.resetSampler()
return pos, prob, rstate, data
[docs] def resetSampler(self):
"""
Resets the emcee sampler in the master node
"""
if self.isMaster():
self.log("Reseting emcee sampler")
# Reset the chain to remove the burn-in samples.
self._sampler.reset()
[docs] def sampleBurnin(self, p0):
"""
Run the emcee sampler for the burnin to create walker which are independent form their starting position
"""
counter = 1
for pos, prob, rstate, datas in self._sampler.sample(p0, iterations=self.burninIterations):
if self.isMaster():
self.storageUtil.persistBurninValues(pos, prob, datas)
if(counter%10==0):
self.log("Iteration finished:" + str(counter))
counter = counter + 1
if self.isMaster():
self.log("storing random state")
self.storageUtil.storeRandomState(self.filePrefix+c.BURNIN_STATE_SUFFIX, rstate)
return pos, prob, rstate, datas
[docs] def sample(self, burninPos, burninProb=None, burninRstate=None, datas=None):
"""
Starts the sampling process
"""
counter = 1
for pos, prob, _, datas in self._sampler.sample(burninPos, lnprob0=burninProb, rstate0=burninRstate,
blobs0=datas, iterations=self.sampleIterations):
if self.isMaster():
self.log("Iteration done. Persisting", logging.DEBUG)
self.storageUtil.persistSamplingValues(pos, prob, datas)
if(self.stopCriteriaStrategy.hasFinished()):
break
if(counter%10==0):
self.log("Iteration finished:" + str(counter))
counter = counter + 1
[docs] def isMaster(self):
"""
Returns True. Can be overridden for multitasking i.e. with MPI
"""
return True
[docs] def log(self, message, level=logging.INFO):
"""
Logs a message to the logfile
"""
getLogger().log(level, message)
[docs] def createEmceeSampler(self, callable, **kwargs):
"""
Factory method to create the emcee sampler
"""
if self.isMaster(): self.log("Using emcee "+str(emcee.__version__))
return emcee.EnsembleSampler(self.nwalkers,
self.paramCount,
callable,
threads=self.threadCount,
**kwargs)
[docs] def createInitPos(self):
"""
Factory method to create initial positions
"""
return self.initPositionGenerator.generate()
[docs] def getChain(self):
"""
Returns the sample chain
"""
return self._sampler.chain
def __str__(self, *args, **kwargs):
"""
Returns the string representation of the sampler config
"""
desc = "Sampler: " + str(type(self))+"\n" \
"configuration: \n" \
" Params: " +str(self.paramValues)+"\n" \
" Burnin iterations: " +str(self.burninIterations)+"\n" \
" Samples iterations: " +str(self.sampleIterations)+"\n" \
" Walkers ratio: " +str(self.walkersRatio)+"\n" \
" Reusing burn in: " +str(self.reuseBurnin)+"\n" \
" init pos generator: " +str(self.initPositionGenerator)+"\n" \
" stop criteria: " +str(self.stopCriteriaStrategy)+"\n" \
" storage util: " +str(self.storageUtil)+"\n" \
"likelihoodComputationChain: \n" + str(self.likelihoodComputationChain) \
+"\n"
return desc