Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

import itertools 

import os 

from cosmoHammer import getLogger 

import time 

 

# If mpi4py is installed, import it. 

try: 

from mpi4py import MPI 

MPI = MPI 

except ImportError: 

MPI = None 

 

class MpiPool(object): 

""" 

Implementation of a mpi based pool. Currently it supports only the map function. 

 

:param mapFunction: the map function to apply on the mpi nodes 

 

""" 

def __init__(self, mapFunction): 

self.rank = MPI.COMM_WORLD.Get_rank() 

self.mapFunction = mapFunction 

 

def map(self, function, sequence): 

""" 

Emulates a pool map function using Mpi. 

Retrieves the number of mpi processes and splits the sequence of walker position  

in order to allow each process its block 

 

:param function: the function to apply on the items of the sequence 

:param sequence: a sequence of items 

 

:returns sequence: sequence of results 

""" 

 

(rank,size) = (MPI.COMM_WORLD.Get_rank(),MPI.COMM_WORLD.Get_size()) 

#sync 

sequence = mpiBCast(sequence) 

 

getLogger().debug("Rank: %s, pid: %s MpiPool: starts processing iteration" %(rank, os.getpid())) 

#split, process and merge the sequence 

mergedList = mergeList(MPI.COMM_WORLD.allgather( 

self.mapFunction(function, splitList(sequence,size)[rank]))) 

getLogger().debug("Rank: %s, pid: %s MpiPool: done processing iteration"%(rank, os.getpid())) 

# time.sleep(10) 

return mergedList 

 

def isMaster(self): 

""" 

Returns true if the rank is 0 

""" 

return (self.rank==0) 

 

def mpiBCast(value): 

""" 

Mpi bcasts the value and Returns the value from the master (rank = 0). 

""" 

getLogger().debug("Rank: %s, pid: %s MpiPool: bcast", MPI.COMM_WORLD.Get_rank(), os.getpid()) 

return MPI.COMM_WORLD.bcast(value) 

 

def splitList(list, n): 

""" 

Splits the list into block of equal sizes (listlength/n) 

 

:param list: a sequence of items 

:param n: the number of blocks to create 

 

:returns sequence: a list of blocks 

""" 

getLogger().debug("Rank: %s, pid: %s MpiPool: splitList", MPI.COMM_WORLD.Get_rank(), os.getpid()) 

blockLen = len(list) / float(n) 

return [list[int(round(blockLen * i)): int(round(blockLen * (i + 1)))] for i in range(n)] 

 

def mergeList(lists): 

""" 

Merges the lists into one single list 

 

:param lists: a list of lists 

 

:returns list: the merged list 

""" 

getLogger().debug("Rank: %s, pid: %s MpiPool: mergeList", MPI.COMM_WORLD.Get_rank(), os.getpid()) 

return list(itertools.chain(*lists))