Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

105

106

107

108

109

110

111

112

113

114

115

116

117

118

119

120

121

122

123

124

125

126

127

128

129

130

131

132

133

134

135

136

137

138

139

140

141

142

143

144

145

146

147

148

149

150

151

152

153

154

155

156

157

158

159

160

161

162

163

164

165

166

167

168

169

170

171

172

173

174

175

176

177

178

179

180

181

182

183

184

185

186

187

188

189

190

191

192

193

194

195

196

197

198

199

200

201

202

203

204

205

206

207

208

209

210

211

212

213

214

215

216

217

218

219

220

221

222

223

224

225

226

227

228

229

230

231

232

233

234

235

236

237

238

239

240

241

242

243

244

245

246

247

248

249

250

251

252

253

254

255

256

257

258

259

260

261

262

263

264

265

266

267

268

269

270

271

272

273

274

275

276

277

278

279

280

281

282

283

284

285

286

287

288

289

290

291

292

293

294

295

296

297

298

299

300

301

302

303

304

305

306

307

308

309

310

311

312

313

314

315

316

317

318

319

320

321

322

from __future__ import print_function, division, absolute_import, unicode_literals 

 

import emcee 

import numpy as np 

import logging 

import time 

 

import cosmoHammer 

import cosmoHammer.Constants as c 

 

from cosmoHammer import getLogger 

from cosmoHammer.util import SampleFileUtil 

from cosmoHammer.util import SampleBallPositionGenerator 

from cosmoHammer.util.IterationStopCriteriaStrategy import IterationStopCriteriaStrategy 

 

 

 

class CosmoHammerSampler(object):
    """
    A complete sampler implementation taking care of correct setup, chain burn in and sampling.

    :param params: the parameter of the priors. Indexed as a 2D array: column 0 is
        read as the parameter values and column 3 as the parameter widths
    :param likelihoodComputationChain: the callable computation chain
    :param filePrefix: the prefix for the log and output files
    :param walkersRatio: the ratio of walkers and the count of sampled parameters
    :param burninIterations: number of iteration for burn in
    :param sampleIterations: number of iteration to sample
    :param stopCriteriaStrategy: the strategy to stop the sampling.
        Default is None and then IterationStopCriteriaStrategy is used
    :param initPositionGenerator: the generator for the init walker position.
        Default is None and then SampleBallPositionGenerator is used
    :param storageUtil: util used to store the results.
        Default is None and then SampleFileUtil is used
    :param threadCount: The count of threads to be used for the computation. Default is 1
    :param reuseBurnin: Flag if the burn in should be reused.
        If true the values will be read from the file System. Default is False
    :param logLevel: level applied to the logger and to the log file handler.
        Default is logging.INFO
    :param pool: forwarded as ``pool`` keyword to the emcee sampler factory. Default is None
    """

    def __init__(self, params, likelihoodComputationChain, filePrefix, walkersRatio, burninIterations,
                 sampleIterations, stopCriteriaStrategy=None, initPositionGenerator=None,
                 storageUtil=None, threadCount=1, reuseBurnin=False, logLevel=logging.INFO, pool=None):
        """
        CosmoHammer sampler implementation.

        Stores the configuration, configures logging, creates the emcee sampler and
        wires up the storage util, stop criteria and init position generator
        (falling back to the factory methods of this class when None is passed).
        """
        self.params = params
        self.likelihoodComputationChain = likelihoodComputationChain
        self.walkersRatio = walkersRatio
        self.reuseBurnin = reuseBurnin
        self.filePrefix = filePrefix
        self.threadCount = threadCount
        self.paramCount = len(self.paramValues)
        self.nwalkers = self.paramCount * walkersRatio
        self.burninIterations = burninIterations
        self.sampleIterations = sampleIterations

        assert likelihoodComputationChain is not None, "The sampler needs a chain"
        assert sampleIterations > 0, "CosmoHammer needs to sample for at least one iterations"

        # Make the parameters available on the chain unless it already carries its own
        if not hasattr(self.likelihoodComputationChain, "params"):
            self.likelihoodComputationChain.params = params

        # setting up the logging
        self._configureLogging(filePrefix + c.LOG_FILE_SUFFIX, logLevel)

        if self.isMaster():
            self.log("Using CosmoHammer " + str(cosmoHammer.__version__))

        # The sampler object
        self._sampler = self.createEmceeSampler(likelihoodComputationChain, pool=pool)

        if storageUtil is None:
            storageUtil = self.createSampleFileUtil()
        self.storageUtil = storageUtil

        if stopCriteriaStrategy is None:
            stopCriteriaStrategy = self.createStopCriteriaStrategy()
        # setup() is called before the strategy is stored on the sampler
        stopCriteriaStrategy.setup(self)
        self.stopCriteriaStrategy = stopCriteriaStrategy

        if initPositionGenerator is None:
            initPositionGenerator = self.createInitPositionGenerator()
        initPositionGenerator.setup(self)
        self.initPositionGenerator = initPositionGenerator

    def _configureLogging(self, filename, logLevel):
        """
        Configures the logger returned by ``getLogger()``.

        Every handler currently attached to the logger is closed and removed, then a
        file handler (writing to ``filename`` in mode "w", at ``logLevel``) and a
        console handler (at ERROR level) are attached.

        :param filename: path of the log file
        :param logLevel: level for the logger and the file handler
        """
        logger = getLogger()
        logger.setLevel(logLevel)

        fh = logging.FileHandler(filename, "w")
        fh.setLevel(logLevel)

        # create console handler with a higher log level
        ch = logging.StreamHandler()
        ch.setLevel(logging.ERROR)

        # create formatter and add it to the handlers
        formatter = logging.Formatter('%(asctime)s %(levelname)s:%(message)s')
        fh.setFormatter(formatter)
        ch.setFormatter(formatter)

        # drop the handlers left over from a previous configuration;
        # iterate over a copy because removeHandler mutates logger.handlers
        for handler in logger.handlers[:]:
            try:
                handler.close()
            except AttributeError:
                pass
            logger.removeHandler(handler)

        # add the new handlers to the logger
        logger.addHandler(fh)
        logger.addHandler(ch)

    def createStopCriteriaStrategy(self):
        """
        Returns a new instance of a stop criteria strategy
        """
        return IterationStopCriteriaStrategy()

    def createSampleFileUtil(self):
        """
        Returns a new instance of a File Util
        """
        return SampleFileUtil(self.filePrefix, reuseBurnin=self.reuseBurnin)

    def createInitPositionGenerator(self):
        """
        Returns a new instance of a Init Position Generator
        """
        return SampleBallPositionGenerator()

    @property
    def paramValues(self):
        """
        Column 0 of ``params``
        """
        return self.params[:, 0]

    @property
    def paramWidths(self):
        """
        Column 3 of ``params``
        """
        return self.params[:, 3]

    def startSampling(self):
        """
        Launches the sampling.

        Runs (or reloads) the burn in when ``burninIterations`` is greater than zero,
        otherwise starts from freshly generated init positions, then samples.
        The pool of the emcee sampler (if any) and the storage util are closed
        afterwards, also when the sampling raises.
        """
        try:
            if self.isMaster():
                self.log(self.__str__())

            if self.burninIterations > 0:
                if self.reuseBurnin:
                    pos, prob, rstate = self.loadBurnin()
                    # no blobs are stored with the burn in; one placeholder per position
                    datas = [None] * len(pos)
                else:
                    pos, prob, rstate, datas = self.startSampleBurnin()
            else:
                pos = self.createInitPos()
                prob = None
                rstate = None
                datas = None

            # Starting from the final position in the burn-in chain (or from the
            # init positions when no burn in is configured), sample for up to
            # sampleIterations steps.
            self.log("start sampling after burn in")
            start = time.time()
            self.sample(pos, prob, rstate, datas)
            end = time.time()
            self.log("sampling done! Took: " + str(round(end - start, 4)) + "s")

            # Print out the mean acceptance fraction. In general, acceptance_fraction
            # has an entry for each walker
            self.log("Mean acceptance fraction:" + str(round(np.mean(self._sampler.acceptance_fraction), 4)))
        finally:
            pool = self._sampler.pool
            if pool is not None:
                try:
                    pool.close()
                except AttributeError:
                    # best effort: pool implementations without close() are tolerated
                    pass

            # The storage util is released regardless of whether a pool is in use,
            # otherwise runs without a pool would never close their output.
            try:
                self.storageUtil.close()
            except AttributeError:
                # best effort: storage utils without close() are tolerated
                pass

    def loadBurnin(self):
        """
        Loads the burn in from the file system.

        :return: tuple ``(pos, prob, rstate)`` where pos and prob are the last
            ``nwalkers`` entries of the stored burn in files
        """
        self.log("reusing previous burn in")

        pos = self.storageUtil.importFromFile(self.filePrefix + c.BURNIN_SUFFIX)[-self.nwalkers:]
        prob = self.storageUtil.importFromFile(self.filePrefix + c.BURNIN_PROB_SUFFIX)[-self.nwalkers:]
        rstate = self.storageUtil.importRandomState(self.filePrefix + c.BURNIN_STATE_SUFFIX)

        self.log("loading done")
        return pos, prob, rstate

    def startSampleBurnin(self):
        """
        Runs the sampler for the burn in, logs timing and acceptance fraction and
        resets the sampler afterwards.

        :return: tuple ``(pos, prob, rstate, data)`` of the last burn in iteration
        """
        self.log("start burn in")
        start = time.time()
        p0 = self.createInitPos()
        pos, prob, rstate, data = self.sampleBurnin(p0)
        end = time.time()
        self.log("burn in sampling done! Took: " + str(round(end - start, 4)) + "s")
        self.log("Mean acceptance fraction for burn in:" + str(round(np.mean(self._sampler.acceptance_fraction), 4)))

        self.resetSampler()

        return pos, prob, rstate, data

    def resetSampler(self):
        """
        Resets the emcee sampler in the master node
        """
        if self.isMaster():
            self.log("Reseting emcee sampler")
            # Reset the chain to remove the burn-in samples.
            self._sampler.reset()

    def sampleBurnin(self, p0):
        """
        Run the emcee sampler for the burnin to create walker which are independent from their starting position.

        On the master every iteration is persisted through the storage util and the
        random state of the last iteration is stored at the end.
        The sampler has to yield at least one iteration, otherwise there is no
        result to return.

        :param p0: the initial walker positions
        :return: tuple ``(pos, prob, rstate, datas)`` of the last iteration
        """
        iterations = self._sampler.sample(p0, iterations=self.burninIterations)
        for counter, (pos, prob, rstate, datas) in enumerate(iterations, 1):
            if self.isMaster():
                self.storageUtil.persistBurninValues(pos, prob, datas)
                if counter % 10 == 0:
                    self.log("Iteration finished:" + str(counter))

        if self.isMaster():
            self.log("storing random state")
            self.storageUtil.storeRandomState(self.filePrefix + c.BURNIN_STATE_SUFFIX, rstate)

        return pos, prob, rstate, datas

    def sample(self, burninPos, burninProb=None, burninRstate=None, datas=None):
        """
        Starts the sampling process.

        On the master every iteration is persisted through the storage util and the
        loop ends early as soon as the stop criteria strategy reports completion.

        :param burninPos: the walker positions to start from
        :param burninProb: passed to the emcee sampler as ``lnprob0``
        :param burninRstate: passed to the emcee sampler as ``rstate0``
        :param datas: passed to the emcee sampler as ``blobs0``
        """
        iterations = self._sampler.sample(burninPos, lnprob0=burninProb, rstate0=burninRstate,
                                          blobs0=datas, iterations=self.sampleIterations)
        for counter, (pos, prob, _, datas) in enumerate(iterations, 1):
            if self.isMaster():
                self.log("Iteration done. Persisting", logging.DEBUG)
                self.storageUtil.persistSamplingValues(pos, prob, datas)

                if self.stopCriteriaStrategy.hasFinished():
                    break

                if counter % 10 == 0:
                    self.log("Iteration finished:" + str(counter))

    def isMaster(self):
        """
        Returns True. Can be overridden for multitasking i.e. with MPI
        """
        return True

    def log(self, message, level=logging.INFO):
        """
        Logs a message to the logfile

        :param message: the message to log
        :param level: the logging level. Default is logging.INFO
        """
        getLogger().log(level, message)

    def createEmceeSampler(self, callable, **kwargs):
        """
        Factory method to create the emcee sampler

        :param callable: the function handed to emcee for the likelihood evaluation
            (the name shadows the builtin but is kept for keyword callers)
        :param kwargs: additional keyword arguments forwarded to emcee.EnsembleSampler
        """
        if self.isMaster():
            self.log("Using emcee " + str(emcee.__version__))
        return emcee.EnsembleSampler(self.nwalkers,
                                     self.paramCount,
                                     callable,
                                     threads=self.threadCount,
                                     **kwargs)

    def createInitPos(self):
        """
        Factory method to create initial positions
        """
        return self.initPositionGenerator.generate()

    def getChain(self):
        """
        Returns the sample chain
        """
        return self._sampler.chain

    def __str__(self, *args, **kwargs):
        """
        Returns the string representation of the sampler config
        """
        lines = [
            "Sampler: " + str(type(self)),
            "configuration: ",
            " Params: " + str(self.paramValues),
            " Burnin iterations: " + str(self.burninIterations),
            " Samples iterations: " + str(self.sampleIterations),
            " Walkers ratio: " + str(self.walkersRatio),
            " Reusing burn in: " + str(self.reuseBurnin),
            " init pos generator: " + str(self.initPositionGenerator),
            " stop criteria: " + str(self.stopCriteriaStrategy),
            " storage util: " + str(self.storageUtil),
            "likelihoodComputationChain: ",
            str(self.likelihoodComputationChain),
        ]
        # every entry, including the last one, is terminated by a newline
        return "\n".join(lines) + "\n"