# Source code for openmdao.core.mpi_wrap

""" A bunch of MPI utilities."""

import os
import sys
import io
from contextlib import contextmanager
import traceback

import numpy
import six
from six import PY3

# Truthy when OPENMDAO_TRACE is set in the environment; gates the optional
# debug() tracing calls elsewhere in this module.
trace = os.environ.get('OPENMDAO_TRACE')

def _redirect_streams(to_fd):
    """
    Redirect stdout/stderr to the given file descriptor.

    Parameters
    ----------
    to_fd : int
        File descriptor that both stdout and stderr are redirected to.

    Based on: http://eli.thegreenplace.net/2015/redirecting-all-kinds-of-stdout-in-python/.
    """

    original_stdout_fd = sys.stdout.fileno()
    original_stderr_fd = sys.stderr.fileno()

    # Flush and close sys.stdout/err - also closes the file descriptors (fd)
    sys.stdout.close()
    sys.stderr.close()

    # Make original_stdout_fd/original_stderr_fd point to the same file as to_fd
    os.dup2(to_fd, original_stdout_fd)
    os.dup2(to_fd, original_stderr_fd)

    # Create new sys.stdout/sys.stderr objects that point to the redirected fds
    if PY3:
        sys.stdout = io.TextIOWrapper(os.fdopen(original_stdout_fd, 'wb'))
        # BUG FIX: this previously wrapped original_stdout_fd a second time,
        # leaving the duplicated stderr fd unmanaged; stderr must wrap its own fd.
        sys.stderr = io.TextIOWrapper(os.fdopen(original_stderr_fd, 'wb'))
    else:
        sys.stdout = os.fdopen(original_stdout_fd, 'wb', 0) # 0 makes them unbuffered
        sys.stderr = os.fdopen(original_stderr_fd, 'wb', 0)

def use_proc_files():
    """Calling this will cause stdout/stderr from each MPI process to be
    written to a separate file in the current directory named <rank>.out.
    """
    # No-op when we're not running under MPI.
    if MPI is None:
        return
    fname = "%s.out" % MPI.COMM_WORLD.rank
    ofile = open(fname, 'wb')
    _redirect_streams(ofile.fileno())
def under_mpirun():
    """Return True if we're being executed under mpirun."""
    # This is a bit of a hack: there is no consistent set of environment
    # variables across MPI implementations, so sniff for markers that the
    # common ones (OpenMPI, MPICH, mpiexec) are known to set.
    return any(
        var in ('OMPI_COMM_WORLD_RANK', 'MPIEXEC_HOSTNAME')
        or var.startswith(('MPIR_', 'MPICH_'))
        for var in os.environ
    )
if under_mpirun():
    from mpi4py import MPI

    def debug(*msg):  # pragma: no cover
        """Write a rank-prefixed trace message to stdout and flush it."""
        parts = ["%d: " % MPI.COMM_WORLD.rank] + list(msg)
        for part in parts:
            sys.stdout.write("%s " % part)
        sys.stdout.write('\n')
        sys.stdout.flush()
else:
    # Not running under mpirun: no MPI module, and debug() has no rank prefix.
    MPI = None

    def debug(*msg):  # pragma: no cover
        """Write a trace message to stdout (serial, no rank prefix)."""
        for part in msg:
            sys.stdout.write("%s " % str(part))
        sys.stdout.write('\n')
class FakeComm(object):
    """Minimal stand-in for an MPI communicator when running serially."""

    def __init__(self):
        # A serial "communicator" is always rank 0 of a size-1 world.
        self.rank = 0
        self.size = 1
@contextmanager
def MultiProcFailCheck(comm):
    """ Wrap this around code that you want to globally fail if it fails
    on any MPI process in comm.  If not running under MPI, don't
    handle any exceptions.
    """
    # Serial case: run the body as-is and let exceptions propagate untouched.
    if MPI is None:
        yield
        return

    try:
        yield
    except:
        # Share this rank's traceback with every process in the comm.
        gathered = comm.allgather(traceback.format_exc())
    else:
        gathered = comm.allgather('')

    # If any rank recorded a traceback, fail globally and report the first one.
    for rank, tb in enumerate(gathered):
        if tb:
            raise RuntimeError("a test failed in (at least) rank %d: traceback follows\n%s" % (rank, tb))
def any_proc_is_true(comm, expr):
    """Returns True if expr is True in any proc in the given comm."""
    local_flag = numpy.array(1 if expr else 0, dtype=int)
    summed = numpy.array(0, dtype=int)

    if trace:
        debug("Allreduce for any_proc_is_true")

    # some mpi versions don't support Allreduce with boolean types
    # and logical operators, so just use ints and MPI.SUM instead.
    comm.Allreduce(local_flag, summed, op=MPI.SUM)

    if trace:
        debug("Allreduce DONE")

    return summed > 0
# Opt-in at import time: when USE_PROC_FILES is set in the environment,
# send each MPI rank's stdout/stderr to its own <rank>.out file.
if os.environ.get('USE_PROC_FILES'):
    use_proc_files()