# Source code for cosmoHammer.CosmoHammerSampler

from __future__ import print_function, division, absolute_import, unicode_literals

import emcee
import numpy as np
import logging
import time

import cosmoHammer
import cosmoHammer.Constants as c

from cosmoHammer import getLogger
from cosmoHammer.util import SampleFileUtil
from cosmoHammer.util import SampleBallPositionGenerator
from cosmoHammer.util.IterationStopCriteriaStrategy import IterationStopCriteriaStrategy



class CosmoHammerSampler(object):
    """
    A complete sampler implementation taking care of correct setup, chain burn in and sampling.

    :param params: the parameters of the priors, indexed as a 2-D array; column 0 is read
        as the parameter values and column 3 as the parameter widths (see ``paramValues``
        and ``paramWidths``)
    :param likelihoodComputationChain: the callable computation chain; must not be None
    :param filePrefix: the prefix for the log and output files
    :param walkersRatio: the ratio between the number of walkers and the number of sampled
        parameters (``nwalkers = len(paramValues) * walkersRatio``)
    :param burninIterations: number of iterations for burn in; 0 skips the burn in
    :param sampleIterations: number of iterations to sample; must be greater than 0
    :param stopCriteriaStrategy: the strategy to stop the sampling. Default is None and then
        IterationStopCriteriaStrategy is used
    :param initPositionGenerator: the generator for the initial walker positions. Default is
        None and then SampleBallPositionGenerator is used
    :param storageUtil: util used to store the results. Default is None and then a
        SampleFileUtil is created
    :param threadCount: the thread count passed to emcee (``threads=``). Default is 1
    :param reuseBurnin: flag if the burn in should be reused. If true the values will be
        read from the file system. Default is False
    :param logLevel: level applied to the logger and its file handler. Default is logging.INFO
    :param pool: optional pool forwarded to the emcee sampler. Default is None
    """

    def __init__(self, params, likelihoodComputationChain, filePrefix, walkersRatio,
                 burninIterations, sampleIterations, stopCriteriaStrategy=None,
                 initPositionGenerator=None, storageUtil=None, threadCount=1,
                 reuseBurnin=False, logLevel=logging.INFO, pool=None):
        """
        CosmoHammer sampler implementation
        """
        self.params = params
        self.likelihoodComputationChain = likelihoodComputationChain
        self.walkersRatio = walkersRatio
        self.reuseBurnin = reuseBurnin
        self.filePrefix = filePrefix
        self.threadCount = threadCount
        self.paramCount = len(self.paramValues)
        self.nwalkers = self.paramCount*walkersRatio
        self.burninIterations = burninIterations
        self.sampleIterations = sampleIterations

        assert likelihoodComputationChain is not None, "The sampler needs a chain"
        assert sampleIterations > 0, "CosmoHammer needs to sample for at least one iterations"

        # Give the chain access to the parameters unless it already has its own.
        if not hasattr(self.likelihoodComputationChain, "params"):
            self.likelihoodComputationChain.params = params

        # setting up the logging
        self._configureLogging(filePrefix+c.LOG_FILE_SUFFIX, logLevel)

        if self.isMaster():
            self.log("Using CosmoHammer "+str(cosmoHammer.__version__))

        # The sampler object
        self._sampler = self.createEmceeSampler(likelihoodComputationChain, pool=pool)

        if storageUtil is None:
            storageUtil = self.createSampleFileUtil()
        self.storageUtil = storageUtil

        if stopCriteriaStrategy is None:
            stopCriteriaStrategy = self.createStopCriteriaStrategy()
        stopCriteriaStrategy.setup(self)
        self.stopCriteriaStrategy = stopCriteriaStrategy

        if initPositionGenerator is None:
            initPositionGenerator = self.createInitPositionGenerator()
        initPositionGenerator.setup(self)
        self.initPositionGenerator = initPositionGenerator

    def _configureLogging(self, filename, logLevel):
        """
        Replaces all handlers of the package logger with a file handler writing to
        ``filename`` (truncated, mode "w") at ``logLevel``. It also adds a console handler
        that only shows ERROR and above.
        """
        logger = getLogger()
        logger.setLevel(logLevel)
        fh = logging.FileHandler(filename, "w")
        fh.setLevel(logLevel)
        # create console handler with a higher log level
        ch = logging.StreamHandler()
        ch.setLevel(logging.ERROR)
        # create formatter and add it to the handlers
        formatter = logging.Formatter('%(asctime)s %(levelname)s:%(message)s')
        fh.setFormatter(formatter)
        ch.setFormatter(formatter)
        # Drop (and close) handlers from a previous configuration so repeated sampler
        # construction does not duplicate output or keep old log files open.
        for handler in logger.handlers[:]:
            try:
                handler.close()
            except AttributeError:
                pass
            logger.removeHandler(handler)
        logger.addHandler(fh)
        logger.addHandler(ch)

    def createStopCriteriaStrategy(self):
        """
        Returns a new instance of a stop criteria strategy
        """
        return IterationStopCriteriaStrategy()

    def createSampleFileUtil(self):
        """
        Returns a new instance of a File Util
        """
        return SampleFileUtil(self.filePrefix, reuseBurnin=self.reuseBurnin)

    def createInitPositionGenerator(self):
        """
        Returns a new instance of an Init Position Generator
        """
        return SampleBallPositionGenerator()

    @property
    def paramValues(self):
        """Column 0 of ``params`` (the parameter values)."""
        return self.params[:, 0]

    @property
    def paramWidths(self):
        """Column 3 of ``params`` (the parameter widths)."""
        return self.params[:, 3]

    def startSampling(self):
        """
        Launches the sampling.

        Runs (or reloads) the burn in when ``burninIterations > 0``. Otherwise it starts from
        freshly generated initial positions. It then samples and logs the timing and the mean
        acceptance fraction. The sampler's pool (if any) and the storage util are always
        closed afterwards, even when sampling or closing the pool fails.
        """
        try:
            if self.isMaster():
                self.log(self.__str__())
            if self.burninIterations > 0:
                if self.reuseBurnin:
                    pos, prob, rstate = self.loadBurnin()
                    # No blobs are persisted for a reused burn in; emcee gets placeholders.
                    datas = [None]*len(pos)
                else:
                    pos, prob, rstate, datas = self.startSampleBurnin()
            else:
                pos = self.createInitPos()
                prob = None
                rstate = None
                datas = None
            # Starting from the final position in the burn-in chain, sample for
            # sampleIterations steps.
            self.log("start sampling after burn in")
            start = time.time()
            self.sample(pos, prob, rstate, datas)
            end = time.time()
            self.log("sampling done! Took: " + str(round(end-start,4))+"s")

            # Print out the mean acceptance fraction. In general, acceptance_fraction
            # has an entry for each walker
            self.log("Mean acceptance fraction:"+ str(round(np.mean(self._sampler.acceptance_fraction), 4)))
        finally:
            # Close the pool and the storage independently. A failure while closing the
            # pool must not leave the output files open; such an error still propagates
            # once the storage util has been closed.
            try:
                if self._sampler.pool is not None:
                    try:
                        self._sampler.pool.close()
                    except AttributeError:
                        pass
            finally:
                try:
                    self.storageUtil.close()
                except AttributeError:
                    pass

    def loadBurnin(self):
        """
        Loads the burn in from the file system.

        :return: tuple ``(pos, prob, rstate)``. ``pos`` and ``prob`` are the last
            ``nwalkers`` rows of the stored burn-in files. ``rstate`` is the stored
            random state.
        """
        self.log("reusing previous burn in")

        pos = self.storageUtil.importFromFile(self.filePrefix+c.BURNIN_SUFFIX)[-self.nwalkers:]

        prob = self.storageUtil.importFromFile(self.filePrefix+c.BURNIN_PROB_SUFFIX)[-self.nwalkers:]

        rstate= self.storageUtil.importRandomState(self.filePrefix+c.BURNIN_STATE_SUFFIX)

        self.log("loading done")
        return pos, prob, rstate

    def startSampleBurnin(self):
        """
        Runs the sampler for the burn in, logs timing and acceptance, then resets the sampler.

        :return: tuple ``(pos, prob, rstate, data)`` of the final burn-in iteration
        """
        self.log("start burn in")
        start = time.time()

        p0 = self.createInitPos()
        pos, prob, rstate, data = self.sampleBurnin(p0)
        end = time.time()
        self.log("burn in sampling done! Took: " + str(round(end-start,4))+"s")
        self.log("Mean acceptance fraction for burn in:" + str(round(np.mean(self._sampler.acceptance_fraction), 4)))

        self.resetSampler()

        return pos, prob, rstate, data

    def resetSampler(self):
        """
        Resets the emcee sampler in the master node
        """
        # NOTE(review): nesting reconstructed from whitespace-collapsed source; the
        # reset is placed under isMaster() per the docstring — confirm against upstream.
        if self.isMaster():
            self.log("Reseting emcee sampler")
            # Reset the chain to remove the burn-in samples.
            self._sampler.reset()

    def sampleBurnin(self, p0):
        """
        Runs the emcee sampler for the burn in, so that the walkers become independent
        of their starting position.

        Each iteration is persisted on the master node. The final random state is stored
        once the loop ends. If emcee yields no iteration, the loop variables are never
        bound and this raises; ``startSampling`` only calls it when
        ``burninIterations > 0``.

        :param p0: initial walker positions
        :return: tuple ``(pos, prob, rstate, datas)`` of the last iteration
        """
        counter = 1
        for pos, prob, rstate, datas in self._sampler.sample(p0, iterations=self.burninIterations):
            if self.isMaster():
                self.storageUtil.persistBurninValues(pos, prob, datas)
                if counter % 10 == 0:
                    self.log("Iteration finished:" + str(counter))

            counter = counter + 1

        if self.isMaster():
            self.log("storing random state")
            self.storageUtil.storeRandomState(self.filePrefix+c.BURNIN_STATE_SUFFIX, rstate)

        return pos, prob, rstate, datas

    def sample(self, burninPos, burninProb=None, burninRstate=None, datas=None):
        """
        Starts the sampling process.

        Runs at most ``sampleIterations`` emcee iterations. On the master node each
        iteration is persisted, and the loop stops early once
        ``stopCriteriaStrategy.hasFinished()`` returns True.

        :param burninPos: starting walker positions
        :param burninProb: optional log-probabilities at ``burninPos``
        :param burninRstate: optional random state to resume from
        :param datas: optional blobs at ``burninPos``
        """
        # NOTE(review): whether hasFinished() is checked only on the master node was
        # reconstructed from whitespace-collapsed source — confirm against upstream.
        counter = 1
        for pos, prob, _, datas in self._sampler.sample(burninPos, lnprob0=burninProb, rstate0=burninRstate, blobs0=datas, iterations=self.sampleIterations):
            if self.isMaster():
                self.log("Iteration done. Persisting", logging.DEBUG)
                self.storageUtil.persistSamplingValues(pos, prob, datas)

                if self.stopCriteriaStrategy.hasFinished():
                    break

            if counter % 10 == 0:
                self.log("Iteration finished:" + str(counter))

            counter = counter + 1

    def isMaster(self):
        """
        Returns True. Can be overridden for multitasking i.e. with MPI
        """
        return True

    def log(self, message, level=logging.INFO):
        """
        Logs a message to the logfile
        """
        getLogger().log(level, message)

    def createEmceeSampler(self, callable, **kwargs):
        """
        Factory method to create the emcee sampler.

        :param callable: the log-probability function handed to emcee (the name shadows
            the builtin but is kept for keyword compatibility)
        :param kwargs: extra keyword arguments forwarded to ``emcee.EnsembleSampler``
        """
        if self.isMaster():
            self.log("Using emcee "+str(emcee.__version__))
        return emcee.EnsembleSampler(self.nwalkers, self.paramCount, callable, threads=self.threadCount, **kwargs)

    def createInitPos(self):
        """
        Factory method to create initial positions
        """
        return self.initPositionGenerator.generate()

    def getChain(self):
        """
        Returns the sample chain
        """
        return self._sampler.chain

    def __str__(self, *args, **kwargs):
        """
        Returns the string representation of the sampler config
        """
        desc = "Sampler: " + str(type(self))+"\n" \
               "configuration: \n" \
               " Params: " +str(self.paramValues)+"\n" \
               " Burnin iterations: " +str(self.burninIterations)+"\n" \
               " Samples iterations: " +str(self.sampleIterations)+"\n" \
               " Walkers ratio: " +str(self.walkersRatio)+"\n" \
               " Reusing burn in: " +str(self.reuseBurnin)+"\n" \
               " init pos generator: " +str(self.initPositionGenerator)+"\n" \
               " stop criteria: " +str(self.stopCriteriaStrategy)+"\n" \
               " storage util: " +str(self.storageUtil)+"\n" \
               "likelihoodComputationChain: \n" + str(self.likelihoodComputationChain) \
               +"\n"

        return desc