# Source code for openmdao.util.concurrent

import traceback

def concurrent_eval_lb(func, cases, comm, broadcast=False):
    """
    Run a load-balanced version of the given function.

    The master rank (0) sends a new case to each worker rank as soon as that
    worker has finished its previous case.

    Args
    ----
    func : function
        The function to execute in workers.

    cases : collection of function args
        Entries are assumed to be of the form (args, kwargs) where kwargs are
        allowed to be None and args should be a list or tuple.

    comm : MPI communicator or None
        The MPI communicator that is shared between the master and workers.
        If None, or if it contains only a single process, the function will be
        executed serially in the calling process.

    broadcast : bool, optional
        If True, the results will be broadcast out to the worker procs so that
        the return value of concurrent_eval_lb will be the full result list in
        every process.

    Returns
    -------
    list of (retval, err), or None
        One ``(retval, err)`` tuple per case, in the same order as `cases`.
        ``err`` is None on success; otherwise ``retval`` is None and ``err``
        is the formatted traceback string. On worker ranks (rank != 0) the
        return value is None unless `broadcast` is True.
    """
    if comm is None or comm.size < 2:
        # There are no worker ranks to hand cases to (no communicator, or a
        # single-process one), so evaluate everything here.
        return _concurrent_eval_serial(func, cases)

    if comm.rank == 0:  # master rank
        results = _concurrent_eval_lb_master(cases, comm)
    else:
        results = _concurrent_eval_lb_worker(func, comm)

    if broadcast:
        results = comm.bcast(results, root=0)

    return results


def _eval_case(func, args, kwargs):
    """
    Call ``func(*args, **kwargs)`` and return ``(retval, err)``.

    On success ``err`` is None. On failure ``retval`` is None and ``err`` is
    the formatted traceback string.
    """
    try:
        if kwargs:
            retval = func(*args, **kwargs)
        else:
            retval = func(*args)
    except BaseException:
        # Deliberately as broad as a bare ``except:``. Any failure is
        # returned instead of escaping, so a worker always reports back to
        # the master, which sits in recv() waiting for it.
        return None, traceback.format_exc()
    return retval, None


def _concurrent_eval_serial(func, cases):
    """Evaluate every case in this process; results are in case order."""
    return [_eval_case(func, args, kwargs) for args, kwargs in cases]


def _concurrent_eval_lb_master(cases, comm):
    """
    Hand out cases to the workers and collect their results.

    This runs only on rank 0. It returns one (retval, err) tuple per case, in
    the same order as `cases`, regardless of the order in which workers finish.
    """
    results = []       # one slot per dispatched case, indexed in case order
    in_progress = {}   # worker rank -> index in `results` of its current case
    case_iter = iter(cases)

    def _dispatch(worker):
        """Send the next case to `worker`; return False if none are left."""
        try:
            case = next(case_iter)
        except StopIteration:
            return False
        in_progress[worker] = len(results)
        results.append(None)  # placeholder, filled when the worker reports
        comm.send(case, worker, tag=1)
        return True

    # seed the workers
    for worker in range(1, comm.size):
        if not _dispatch(worker):
            break

    # Keep going until every dispatched case has reported back. The next case
    # is dispatched *before* the loop condition is re-checked, so a single
    # worker still gets all of the cases.
    while in_progress:
        # wait for any worker to finish
        worker, retval, err = comm.recv(tag=2)
        results[in_progress.pop(worker)] = (retval, err)
        # send a new case (if any remain) to the worker that just finished
        _dispatch(worker)

    # tell all workers to stop
    for rank in range(1, comm.size):
        comm.send((None, None), rank, tag=1)

    return results


def _concurrent_eval_lb_worker(func, comm):
    """
    Evaluate cases received from the master until told to stop.

    This runs on every rank except 0. A message whose args are None is the
    stop signal. Returns None.
    """
    while True:
        # wait on a case from the master
        args, kwargs = comm.recv(source=0, tag=1)
        if args is None:  # stop signal: we're done
            break

        retval, err = _eval_case(func, args, kwargs)

        # tell the master we're done with that case
        comm.send((comm.rank, retval, err), 0, tag=2)