# Source code for openmdao.recorders.recording_manager

import itertools
import time
from openmdao.core.mpi_wrap import MPI

class RecordingManager(object):
    """Holds a list of case recorders and forwards startup/record calls to them.

    When running under MPI, variables needed by recorders that do not
    support parallel recording (``recorder._parallel`` is false) are
    gathered onto rank 0, and those recorders are only invoked on rank 0.
    """

    def __init__(self, *args, **kargs):
        super(RecordingManager, self).__init__(*args, **kargs)

        # Union of the variable names selected by every recorder's filter;
        # filled in by startup() and used by record() to decide which
        # variables to gather onto rank 0 under MPI.
        self._vars_to_record = {
            'pnames': set(),
            'unames': set(),
            'rnames': set(),
        }

        self._recorders = []
        # Set to True by startup() when any recorder has ``_parallel`` false.
        self.__has_serial_recorders = False

    def append(self, recorder):
        """Add a recorder to the list of managed recorders.

        Args
        ----
        recorder : object
            Recorder to manage; must provide ``startup``, ``record``,
            ``_parallel`` and ``_filtered``, as used below.
        """
        self._recorders.append(recorder)

    def __getitem__(self, index):
        return self._recorders[index]

    def __iter__(self):
        return iter(self._recorders)

    def _local_vars(self, root, vec, varnames):
        """Return ``(name, vec[name])`` pairs for the names owned by this rank.

        A name is kept when ``root._owning_ranks[name]`` equals this
        process's ``root.comm.rank``; names owned elsewhere are skipped.
        """
        rank = root.comm.rank
        owning_ranks = root._owning_ranks
        return [(name, vec[name]) for name in varnames
                if owning_ranks[name] == rank]

    def _gather_vars(self, root, local_vars):
        """Gather the ``(name, value)`` pairs in `local_vars` onto rank 0.

        This performs a collective ``root.comm.gather``, so every rank must
        call it.

        Returns
        -------
        dict or None
            On rank 0, a dict merging the pairs from all ranks; on every
            other rank, None.
        """
        all_vars = root.comm.gather(local_vars, root=0)
        if root.comm.rank == 0:
            return dict(itertools.chain.from_iterable(all_vars))
        return None

    def startup(self, root):
        """Start up every recorder and collect the variable names they record.

        Args
        ----
        root : object
            Root system; its ``pathname`` is used to look up each recorder's
            filtered ``(pnames, unames, rnames)`` in ``recorder._filtered``.
        """
        for recorder in self._recorders:
            recorder.startup(root)

            if not recorder._parallel:
                self.__has_serial_recorders = True

            pnames, unames, rnames = recorder._filtered[root.pathname]
            self._vars_to_record['pnames'].update(pnames)
            self._vars_to_record['unames'].update(unames)
            self._vars_to_record['rnames'].update(rnames)

    def record(self, root, metadata):
        ''' Gathers variables for non-parallel case recorders and
        calls record for all recorders

        Args
        ----
        root : object
            Root system providing ``comm``, ``params``, ``unknowns`` and
            ``resids``.

        metadata: `dict`
            Metadata for iteration coordinate; a ``'timestamp'`` entry is
            added to it in place.
        '''
        metadata['timestamp'] = time.time()

        params = root.params
        unknowns = root.unknowns
        resids = root.resids

        # Parallel recorders always see the rank-local vectors. Serial
        # recorders see the same vectors unless we are under MPI, in which
        # case they see the data gathered onto rank 0 (None elsewhere, but
        # serial recorders are never called off rank 0).
        serial_params, serial_unknowns, serial_resids = params, unknowns, resids

        if MPI and self.__has_serial_recorders:
            pnames = self._vars_to_record['pnames']
            unames = self._vars_to_record['unames']
            rnames = self._vars_to_record['rnames']

            # gather is collective: every rank must take part, even though
            # only rank 0 receives the result.
            serial_params = self._gather_vars(
                root, self._local_vars(root, params, pnames))
            serial_unknowns = self._gather_vars(
                root, self._local_vars(root, unknowns, unames))
            serial_resids = self._gather_vars(
                root, self._local_vars(root, resids, rnames))

        for recorder in self._recorders:
            if recorder._parallel:
                recorder.record(params, unknowns, resids, metadata)
            elif root.comm.rank == 0:
                # A recorder that does not support parallel recording
                # must only record on rank 0.
                recorder.record(serial_params, serial_unknowns,
                                serial_resids, metadata)