Source code for pyrado.sampling.sampler_pool

# Copyright (c) 2020, Fabio Muratore, Honda Research Institute Europe GmbH, and
# Technical University of Darmstadt.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
# 3. Neither the name of Fabio Muratore, Honda Research Institute Europe GmbH,
#    or Technical University of Darmstadt, nor the names of its contributors may
#    be used to endorse or promote products derived from this software without
#    specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL FABIO MURATORE, HONDA RESEARCH INSTITUTE EUROPE GMBH,
# OR TECHNICAL UNIVERSITY OF DARMSTADT BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
# IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
import os
import traceback
import warnings
from copy import deepcopy
from enum import Enum, auto
from queue import Empty
from typing import List, Tuple

import torch.multiprocessing as mp
from tqdm import tqdm

import pyrado
from pyrado.sampling.step_sequence import StepSequence


class GlobalNamespace:
    """Type of the worker's global namespace"""

    pass
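# Illustrative sketch (not part of the original module): the worker-local namespace is typically
# populated once per worker via `SamplerPool.invoke_all()` and then reused by later map calls,
# avoiding repeated transmission of expensive objects. `make_env`, `env_spec`, and the rollout call
# below are hypothetical names:
#
#     def _setup_worker(G, env_spec):
#         G.env = make_env(env_spec)  # created once in each worker process
#
#     def _sample_one(G, policy_param):
#         return G.env.rollout(policy_param)  # reuses the cached environment
#
#     pool.invoke_all(_setup_worker, env_spec)
#     results = pool.run_map(_sample_one, policy_params)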
_CMD_STOP = "stop"
_RES_SUCCESS = "success"
_RES_ERROR = "error"
_RES_FATAL = "fatal"

ENABLE_SINGLE_WORKER_OPTIMIZATION = os.getenv("ENABLE_SINGLE_WORKER_OPTIMIZATION", None) is not None


def _pool_worker(from_master, to_master):
    """
    Main loop of a pool worker, communicating with the master via two queues.

    :param from_master: queue delivering `(func, args, kwargs)` tuples or the special `_CMD_STOP` command
    :param to_master: queue receiving `(status, obj)` tuples, where `obj` is the result on success and a
                      traceback string on error
    """
    # Use a custom global namespace. This 'trick' gives every worker a private object to store state between calls.
    G = GlobalNamespace()

    while True:
        # Get command from master, should be a tuple
        cmd = from_master.get()
        if cmd == _CMD_STOP:
            # Terminate, no need to send a message
            break

        # CMD is a tuple of (func, args, kwargs)
        func, args, kwargs = cmd

        # Invoke func
        try:
            res = func(G, *args, **kwargs)
        except Exception:
            msg = traceback.format_exc()
            to_master.put((_RES_ERROR, msg))
        except:
            # Generic exception: still report, but also terminate the worker
            msg = traceback.format_exc()
            to_master.put((_RES_FATAL, msg))
            raise
        else:
            to_master.put((_RES_SUCCESS, res))


class _OpState(Enum):
    PENDING = auto()
    ERRORED = auto()
    DONE = auto()


class _WorkerInfo:
    """Internal class managing a single worker process in the sampler pool."""

    def __init__(self, num):
        self._to_slave = mp.Queue()
        self._from_slave = mp.Queue()

        # Create the process
        self._process = mp.Process(
            target=_pool_worker,
            args=(self._to_slave, self._from_slave),
            name=f"Sampler-Worker-{num}",
        )
        self._process.daemon = True

        # Start it
        self._process.start()

        # Track pending invocations
        self._pending = False
        # Stores early retrieved result
        self._result = None

    def invoke_start(self, func, *args, **kwargs):
        if not self._process.is_alive():
            raise RuntimeError("Worker has terminated!")
        if self._pending:
            raise RuntimeError("There is still a pending call waiting for completion.")

        # Send to slave
        self._to_slave.put((func, args, kwargs))
        self._pending = True

    def operation_state(self):
        """
        Try to retrieve the result.

        :return: the current `_OpState`; `_OpState.ERRORED` means the loop should be cancelled
        """
        if not self._process.is_alive():
            raise RuntimeError("Worker has terminated")
        if not self._pending:
            return _OpState.PENDING

        if self._result is None:
            # Try to pull
            try:
                self._result = self._from_slave.get(block=False)
            except Empty:
                # Nothing there yet
                return _OpState.PENDING

        # Check result
        if self._result[0] == _RES_SUCCESS:
            return _OpState.DONE
        else:
            return _OpState.ERRORED

    def invoke_wait(self):
        if not self._process.is_alive():
            raise RuntimeError("Worker has terminated")
        if not self._pending:
            raise RuntimeError("There is no pending call.")

        if self._result is None:
            # Await result if not done yet
            res = self._from_slave.get()
        else:
            # Clear stored result
            res = self._result
            self._result = None
        self._pending = False

        # Interpret result
        stat, value = res
        if stat == _RES_SUCCESS:
            return value
        elif stat == _RES_ERROR:
            raise RuntimeError(f"Caught error in {self._process.name}:\n{value}")
        elif stat == _RES_FATAL:
            # Mark as failed. For now, there is no way to recover, we just get an error on the next start method.
            raise RuntimeError(f"Fatal error in {self._process.name}:\n{value}")
        raise pyrado.ValueErr(given=stat, eq_constraint="_RES_SUCCESS, _RES_ERROR, or _RES_FATAL")

    def stop(self):
        if self._pending:
            warnings.warn("There is still a pending call waiting for completion.", UserWarning)

        # Send stop signal
        self._to_slave.put(_CMD_STOP)

        # Wait a bit for the process to die
        self._process.join(1)

        # Check if stopped
        if self._process.is_alive():
            # Send SIGTERM in case there was a problem with the graceful stop
            self._process.terminate()
            # Wait for the process to die from the SIGTERM
            self._process.join(1)

            # Check if stopped
            if self._process.is_alive():
                # SIGTERM didn't work, so bring out the big guns with SIGKILL
                self._process.kill()
                # Wait for the process to die from the SIGKILL
                self._process.join(1)


def _run_set_seed(G, seed):
    """Ignore the global namespace, and forward to `pyrado.set_seed()`"""
    pyrado.set_seed(seed)


def _run_collect(
    G, counter, run_counter, lock, n, min_runs, func, args, kwargs
) -> List[Tuple[int, List[StepSequence], int]]:
    """Worker function for `SamplerPool.run_collect()`"""
    result = []
    while True:
        with lock:
            # Increment the run counter here to get the id for the next seed
            run_counter.value += 1
            run_num = run_counter.value

        # Invoke once
        res, n_done = func(G, run_num, *args, **kwargs)

        # Add to result and record increment
        result.append((run_num, res, n_done))
        with lock:
            # Increment done counter
            counter.value += n_done

            # Check if done
            if counter.value >= n and (min_runs is None or run_counter.value >= min_runs):
                break

    return result


def _run_map(G, func, argqueue):
    """Worker function for `SamplerPool.run_map()`"""
    result = []
    while True:
        try:
            index, arg = argqueue.get(block=False)
        except Empty:
            break
        result.append((index, func(G, arg)))
    return result
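# Illustrative note (not part of the original module): both worker functions above tag their results
# with an index so that the pool can restore a deterministic order after the non-deterministic
# parallel execution. `_run_map` accumulates pairs like `(3, func(G, arglist[3]))`, and
# `_run_collect` accumulates triples like `(run_num, rollout, n_samples)`; `SamplerPool.run_map()`
# and `SamplerPool.run_collect()` below merge these per-worker lists, sort them by the leading
# index, and strip it off again.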
class SamplerPool:
    """
    A process pool capable of executing operations in parallel. This differs from the multiprocessing.Pool class
    in that it explicitly incorporates process-local state.

    Every parallel function gets a GlobalNamespace object as first argument, which can hold arbitrary worker-local
    state. This allows for certain optimizations. For example, when the parallel operation requires an object that
    is expensive to transmit, we can create this object once in each process, store it in the namespace, and then
    use it in every map function call.

    This class also contains additional methods to call a function exactly once in each worker, to set up
    worker-local state.
    """

    def __init__(self, num_threads: int):
        if not isinstance(num_threads, int):
            raise pyrado.TypeErr(given=num_threads, expected_type=int)
        if num_threads < 1:
            raise pyrado.ValueErr(given=num_threads, ge_constraint="1")
        self._num_threads = num_threads

        if not ENABLE_SINGLE_WORKER_OPTIMIZATION or num_threads > 1:
            # Create workers
            self._workers = [_WorkerInfo(i + 1) for i in range(self._num_threads)]
            self._manager = mp.Manager()

        # Namespace for the local (single-worker) execution path
        self._G = GlobalNamespace()
    def stop(self):
        """Terminate all workers."""
        if not ENABLE_SINGLE_WORKER_OPTIMIZATION or self._num_threads > 1:
            for w in self._workers:
                w.stop()
    def _start(self, func, *args, **kwargs):
        # Start invocation
        for w in self._workers:
            w.invoke_start(func, *args, **kwargs)

    def _await_result(self):
        return [w.invoke_wait() for w in self._workers]

    def _operation_in_progress(self):
        done = True
        for w in self._workers:
            s = w.operation_state()
            if s == _OpState.ERRORED:
                return False
            done = done and s == _OpState.DONE
        return not done
    def invoke_all(self, func, *args, **kwargs):
        """
        Invoke func on all workers using the same argument values.
        The return values are collected into a list.

        :param func: the first argument of func will be a worker-local namespace
        """
        if ENABLE_SINGLE_WORKER_OPTIMIZATION and self._num_threads == 1:
            return [func(self._G, *args, **kwargs)]
        else:
            # Start invocation
            for w in self._workers:
                w.invoke_start(func, *args, **kwargs)
            # Await results
            return self._await_result()
    def invoke_all_map(self, func, arglist):
        """
        Invoke func(arg) on all workers using one argument from the list for each ordered worker.
        The length of the argument list must match the number of workers.
        The first argument of func will be a worker-local namespace. The return values are collected into a list.
        """
        assert self._num_threads == len(arglist)

        if ENABLE_SINGLE_WORKER_OPTIMIZATION and self._num_threads == 1:
            return [func(self._G, arglist[0])]

        # Start invocation
        for w, arg in zip(self._workers, arglist):
            w.invoke_start(func, arg)
        # Await results
        return self._await_result()
    def run_map(self, func, arglist: list, progressbar: tqdm = None):
        """
        A parallel version of `[func(G, arg) for arg in arglist]`.
        There is no deterministic assignment of workers to arglist elements. Optionally runs with a progress bar.

        :param func: mapper function, must be pickleable
        :param arglist: list of function args
        :param progressbar: optional progress bar from the `tqdm` library
        :return: list of results
        """
        # Set max on progress bar
        if progressbar is not None:
            progressbar.total = len(arglist)

        # Single thread optimization
        if ENABLE_SINGLE_WORKER_OPTIMIZATION and self._num_threads == 1:
            res = []
            for arg in arglist:
                res.append(func(self._G, deepcopy(arg)))  # numpy arrays and others are passed by reference
                if progressbar:
                    progressbar.update(1)
            return res

        # Put args into a parallel queue
        argqueue = self._manager.Queue(maxsize=len(arglist))

        # Fill the queue, must be done first to avoid race conditions
        # Add the original argument index to be able to restore it later
        for indexedArg in enumerate(arglist):
            argqueue.put(indexedArg)

        # Start workers
        self._start(_run_map, func, argqueue)

        # Show progress bar if any
        if progressbar is not None:
            while self._operation_in_progress():
                # Retrieve number of remaining jobs
                remaining = argqueue.qsize()
                if remaining == 0:
                    break
                done = len(arglist) - remaining
                # Update progress (need to subtract since it's incremental)
                progressbar.update(done - progressbar.n)

        # Collect results in one list
        allres = self._await_result()
        result = [item for res in allres for item in res]

        # Sort results by index to ensure consistent order with args
        result.sort(key=lambda t: t[0])
        return [item for _, item in result]
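    # Usage sketch (hypothetical, not part of the original module): the mapper must be a picklable,
    # module-level callable taking the worker namespace first, e.g.
    #
    #     def _square(G, x):
    #         return x**2
    #
    #     pool = SamplerPool(4)
    #     squares = pool.run_map(_square, list(range(100)))  # order of `arglist` is preserved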
    def run_collect(self, n, func, *args, collect_progressbar: tqdm = None, min_runs=None, **kwargs) -> tuple:
        """
        Collect at least n samples from func, where the number of samples per run can vary.
        This is done by calling `res, ns = func(G, *args, **kwargs)` until the sum of ns exceeds n.

        This is intended for situations like reinforcement learning runs. If the environment ends up in an error
        state, you get fewer samples per run. To ensure a stable learning behaviour, you can specify the minimum
        amount of samples to collect before returning. Since the workers can only check the amount of samples
        before starting a run, you will likely get more samples than the minimum.

        No samples that are part of a rollout are dropped. However, if more rollouts were sampled than necessary,
        the surplus rollouts are dropped to get seed-determinism across different numbers of workers.

        :param n: minimum number of samples to collect
        :param func: sampler function, must be pickleable
        :param args: remaining positional args are passed to the function
        :param collect_progressbar: `tqdm` progress bar to use; default None
        :param min_runs: optionally specify a minimum amount of runs to be executed before returning
        :param kwargs: remaining keyword args are passed to the function
        :return: list of results and total number of samples
        """
        # Set total on progress bar
        if collect_progressbar is not None:
            collect_progressbar.total = n

        if ENABLE_SINGLE_WORKER_OPTIMIZATION and self._num_threads == 1:
            # Do locally
            result = []
            counter = 0
            while counter < n or (min_runs is not None and len(result) < min_runs):
                # Invoke once
                res, n_done = func(self._G, num=len(result) + 1, *args, **kwargs)

                # Add to result and record increment
                result.append(res)
                counter += n_done
                if collect_progressbar is not None:
                    collect_progressbar.update(n_done)

            # Return result and total
            return result, counter

        # Create counter + counter-lock as shared vars (the counter has no lock of its own since an explicit one is used)
        counter = self._manager.Value("i", 0)
        run_counter = self._manager.Value("i", 0)
        lock = self._manager.RLock()

        # Start async computation
        self._start(_run_collect, counter, run_counter, lock, n, min_runs, func, args, kwargs)

        # Show progress bar
        if collect_progressbar is not None:
            while self._operation_in_progress():
                # Retrieve current counter value
                with lock:
                    cnt = counter.value
                if cnt >= n:
                    break
                # Update progress (need to subtract since it's incremental)
                collect_progressbar.update(cnt - collect_progressbar.n)

        # Collect results in one list
        allres = self._await_result()
        result = [item for res in allres for item in res]

        # Sort results by index to ensure consistent order
        result.sort(key=lambda t: t[0])

        result_filtered = []
        n_total = 0
        for i, (_, item, n_of_item) in enumerate(result):
            n_total += n_of_item
            result_filtered.append(item)
            if n_total >= n and (min_runs is None or i + 1 >= min_runs):
                break

        return result_filtered, counter.value
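    # Usage sketch (hypothetical names, not part of the original module): the sampler function gets the
    # worker namespace and the run number, and reports how many samples its run produced, e.g.
    #
    #     def _sample_rollout(G, num):
    #         ro = G.sampler.do_rollout(seed=num)  # assumes a sampler was stored via invoke_all()
    #         return ro, ro.length
    #
    #     rollouts, n_samples = pool.run_collect(1000, _sample_rollout, min_runs=8)
    #
    # collects rollouts until at least 1000 steps and at least 8 runs were gathered.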
    def set_seed(self, seed):
        """
        Set a deterministic seed on all workers.

        .. note::
            This is intended to only be used in **legacy** evaluation scripts!
            For new code and everything that should really be reproducible, pass the seed to the `sample()`
            method of a `ParallelRolloutSampler`.

        :param seed: seed value for the random number generators
        """
        self.invoke_all_map(_run_set_seed, [seed + i for i in range(self._num_threads)])
    def __reduce__(self):
        # We cannot really pickle this object since it has a lot of hidden state in the worker processes
        raise RuntimeError("The sampler pool is not serializable!")

    def __del__(self):
        # Stop the workers as soon as the pool is not referenced anymore
        self.stop()
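# --- Illustrative usage sketch (not part of the original module) -------------------------------------
# A minimal end-to-end demonstration of the public API defined above. `_square_example` is a
# hypothetical mapper added only for this sketch; real code would pass its own picklable functions.
def _square_example(G, x):
    # The worker-local namespace `G` is unused in this trivial example
    return x**2


if __name__ == "__main__":
    pool = SamplerPool(num_threads=2)
    pool.set_seed(42)  # legacy helper: seeds every worker deterministically
    print(pool.run_map(_square_example, list(range(10))))  # results keep the argument order
    pool.stop()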