# Source code for openmdao.core.mpi_wrap
""" A bunch of MPI utilities."""
import os
import sys
import io
from contextlib import contextmanager
import traceback
import numpy
import six
from six import PY3
# Value of the OPENMDAO_TRACE environment variable (None when unset); when
# truthy, collective helpers below emit debug messages around MPI calls.
trace = os.getenv('OPENMDAO_TRACE')
def _redirect_streams(to_fd):
    """
    Redirect stdout/stderr to the given file descriptor.

    The current ``sys.stdout``/``sys.stderr`` objects are closed, the file
    descriptors they were using are re-pointed at ``to_fd`` with ``os.dup2``,
    and fresh stream objects are created on top of those same descriptors.

    Based on: http://eli.thegreenplace.net/2015/redirecting-all-kinds-of-stdout-in-python/.

    Parameters
    ----------
    to_fd : int
        File descriptor that both stdout and stderr should write to.
    """
    original_stdout_fd = sys.stdout.fileno()
    original_stderr_fd = sys.stderr.fileno()

    # Flush and close the Python-level stream objects before swapping the
    # descriptors underneath them.
    sys.stdout.close()
    sys.stderr.close()

    # Make the original stdout/stderr descriptors refer to the same file as to_fd.
    os.dup2(to_fd, original_stdout_fd)
    os.dup2(to_fd, original_stderr_fd)

    # Create new stream objects on the redirected descriptors. Each stream
    # must wrap its OWN descriptor; wrapping the stdout fd twice would leave
    # the stderr fd without a Python object and let one wrapper close the
    # other's descriptor.
    if PY3:
        sys.stdout = io.TextIOWrapper(os.fdopen(original_stdout_fd, 'wb'))
        sys.stderr = io.TextIOWrapper(os.fdopen(original_stderr_fd, 'wb'))
    else:
        sys.stdout = os.fdopen(original_stdout_fd, 'wb', 0)  # 0 makes them unbuffered
        sys.stderr = os.fdopen(original_stderr_fd, 'wb', 0)
def use_proc_files():
    """Calling this will cause stdout/stderr from each MPI process to be written
    to a separate file in the current directory named <rank>.out.

    Does nothing when MPI is not in use (``MPI is None``).
    """
    if MPI is None:
        return

    rank = MPI.COMM_WORLD.rank
    sname = "%s.out" % rank
    # _redirect_streams dup2's this descriptor onto stdout/stderr, so the
    # temporary file object can be closed deterministically afterwards
    # instead of waiting for garbage collection.
    with open(sname, 'wb') as ofile:
        _redirect_streams(ofile.fileno())
def under_mpirun():
    """Return True if we're being executed under mpirun.

    MPI implementations do not share a common set of environment variables,
    so this is a heuristic: it reports True when any environment variable is
    named OMPI_COMM_WORLD_RANK or MPIEXEC_HOSTNAME, or starts with MPIR_ or
    MPICH_.
    """
    exact_names = ('OMPI_COMM_WORLD_RANK', 'MPIEXEC_HOSTNAME')
    name_prefixes = ('MPIR_', 'MPICH_')
    return any(var in exact_names or var.startswith(name_prefixes)
               for var in os.environ)
if under_mpirun():
    from mpi4py import MPI

    def debug(*msg):  # pragma: no cover
        """Write the items of msg to stdout, space separated, prefixed by
        this process's COMM_WORLD rank, then flush."""
        newmsg = ["%d: " % MPI.COMM_WORLD.rank] + list(msg)
        for m in newmsg:
            # str(m) first: "%s " % m raises TypeError (or drops items)
            # when m itself is a tuple.
            sys.stdout.write("%s " % str(m))
        sys.stdout.write('\n')
        sys.stdout.flush()
else:
    MPI = None

    def debug(*msg):  # pragma: no cover
        """Write the items of msg to stdout, space separated, then flush."""
        for m in msg:
            sys.stdout.write("%s " % str(m))
        sys.stdout.write('\n')
        # Flush like the MPI variant so both behave the same.
        sys.stdout.flush()
class FakeComm(object):
    """Stand-in communicator for serial runs.

    Provides only ``rank`` (always 0) and ``size`` (always 1), describing a
    single-process communicator.
    """

    def __init__(self):
        self.rank, self.size = 0, 1
@contextmanager
def MultiProcFailCheck(comm):
    """Make a failure of the wrapped code on any process in comm fail everywhere.

    Under MPI, each process's outcome (its formatted traceback, or '' on
    success) is gathered with ``comm.allgather``. If any entry is non-empty,
    every process raises a RuntimeError that includes the lowest failing
    rank and its traceback. When not running under MPI, exceptions are not
    handled at all.
    """
    if MPI is None:
        # Serial run: let any exception propagate untouched.
        yield
        return

    try:
        yield
    except BaseException:
        # Equivalent to a bare ``except``: whatever went wrong, this process
        # still takes part in the allgather below.
        outcomes = comm.allgather(traceback.format_exc())
    else:
        outcomes = comm.allgather('')

    failures = [(rank, tb) for rank, tb in enumerate(outcomes) if tb]
    if failures:
        first_rank, first_tb = failures[0]
        raise RuntimeError("a test failed in (at least) rank %d: traceback follows\n%s"
                           % (first_rank, first_tb))
def any_proc_is_true(comm, expr):
    """Returns True if expr is True in any proc in the given comm.

    Parameters
    ----------
    comm : communicator
        Communicator providing ``Allreduce`` (mpi4py-style). Not used when
        MPI is unavailable.
    expr : object
        Value whose truthiness is tested on this process.

    Returns
    -------
    bool
        True if ``expr`` is truthy on at least one process.
    """
    if MPI is None:
        # Serial run: this is the only process, and there is no MPI.SUM to
        # reduce with (FakeComm has no Allreduce either).
        return bool(expr)

    any_true = numpy.array(0, dtype=int)
    if trace:
        debug("Allreduce for any_proc_is_true")
    # some mpi versions don't support Allreduce with boolean types
    # and logical operators, so just use ints and MPI.SUM instead.
    comm.Allreduce(numpy.array(1 if expr else 0, dtype=int),
                   any_true, op=MPI.SUM)
    if trace:
        debug("Allreduce DONE")
    # Return a plain Python bool rather than a numpy.bool_.
    return bool(any_true > 0)
# Opt-in at import time: a non-empty USE_PROC_FILES environment variable makes
# use_proc_files() redirect this process's stdout/stderr to <rank>.out.
if os.environ.get('USE_PROC_FILES') not in (None, ''):
    use_proc_files()