Source code for cosmoHammer.util.MpiUtil
import itertools
import os
from cosmoHammer import getLogger
import time
# If mpi4py is installed, import it.
try:
from mpi4py import MPI
MPI = MPI
except ImportError:
MPI = None
[docs]class MpiPool(object):
"""
Implementation of a mpi based pool. Currently it supports only the map function.
:param mapFunction: the map function to apply on the mpi nodes
"""
def __init__(self, mapFunction):
self.rank = MPI.COMM_WORLD.Get_rank()
self.mapFunction = mapFunction
[docs] def map(self, function, sequence):
"""
Emulates a pool map function using Mpi.
Retrieves the number of mpi processes and splits the sequence of walker position
in order to allow each process its block
:param function: the function to apply on the items of the sequence
:param sequence: a sequence of items
:returns sequence: sequence of results
"""
(rank,size) = (MPI.COMM_WORLD.Get_rank(),MPI.COMM_WORLD.Get_size())
#sync
sequence = mpiBCast(sequence)
getLogger().debug("Rank: %s, pid: %s MpiPool: starts processing iteration" %(rank, os.getpid()))
#split, process and merge the sequence
mergedList = mergeList(MPI.COMM_WORLD.allgather(
self.mapFunction(function, splitList(sequence,size)[rank])))
getLogger().debug("Rank: %s, pid: %s MpiPool: done processing iteration"%(rank, os.getpid()))
# time.sleep(10)
return mergedList
[docs] def isMaster(self):
"""
Returns true if the rank is 0
"""
return (self.rank==0)
[docs]def mpiBCast(value):
"""
Mpi bcasts the value and Returns the value from the master (rank = 0).
"""
getLogger().debug("Rank: %s, pid: %s MpiPool: bcast", MPI.COMM_WORLD.Get_rank(), os.getpid())
return MPI.COMM_WORLD.bcast(value)
[docs]def splitList(list, n):
"""
Splits the list into block of equal sizes (listlength/n)
:param list: a sequence of items
:param n: the number of blocks to create
:returns sequence: a list of blocks
"""
getLogger().debug("Rank: %s, pid: %s MpiPool: splitList", MPI.COMM_WORLD.Get_rank(), os.getpid())
blockLen = len(list) / float(n)
return [list[int(round(blockLen * i)): int(round(blockLen * (i + 1)))] for i in range(n)]
[docs]def mergeList(lists):
"""
Merges the lists into one single list
:param lists: a list of lists
:returns list: the merged list
"""
getLogger().debug("Rank: %s, pid: %s MpiPool: mergeList", MPI.COMM_WORLD.Get_rank(), os.getpid())
return list(itertools.chain(*lists))