# Copyright (c) 2020, Fabio Muratore, Honda Research Institute Europe GmbH, and
# Technical University of Darmstadt.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# 3. Neither the name of Fabio Muratore, Honda Research Institute Europe GmbH,
# or Technical University of Darmstadt, nor the names of its contributors may
# be used to endorse or promote products derived from this software without
# specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL FABIO MURATORE, HONDA RESEARCH INSTITUTE EUROPE GMBH,
# OR TECHNICAL UNIVERSITY OF DARMSTADT BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
# IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
from random import randint
from typing import List, Mapping, Optional, Tuple, Union
import numpy as np
from init_args_serializer import Serializable
import pyrado
from pyrado.domain_randomization.domain_randomizer import DomainRandomizer
from pyrado.environment_wrappers.base import EnvWrapper
from pyrado.environment_wrappers.utils import all_envs, inner_env, remove_env, typed_env
from pyrado.environments.base import Env
from pyrado.environments.sim_base import SimEnv
from pyrado.utils.input_output import completion_context
class DomainRandWrapper(EnvWrapper, Serializable):
    """Common parent of all environment wrappers which use a `DomainRandomizer` to randomize the domain parameters"""

    def __init__(self, wrapped_env: Union[SimEnv, EnvWrapper], randomizer: Optional[DomainRandomizer]):
        """
        Constructor

        :param wrapped_env: environment to wrap, `inner_env(wrapped_env)` must yield a `SimEnv`
        :param randomizer: `DomainRandomizer` object holding the probability distribution of all randomizable
                           domain parameters, pass `None` if you want to subclass wrapping another
                           `DomainRandWrapper` and use its randomizer
        """
        # Validate the inputs before anything is stored
        if not isinstance(inner_env(wrapped_env), SimEnv):
            raise pyrado.TypeErr(given=wrapped_env, expected_type=SimEnv)
        if randomizer is not None and not isinstance(randomizer, DomainRandomizer):
            raise pyrado.TypeErr(given=randomizer, expected_type=DomainRandomizer)

        # Hand the constructor arguments over to Serializable
        Serializable._init(self, locals())

        # Let the EnvWrapper store the wrapped environment, then keep the randomizer
        super().__init__(wrapped_env)
        self._randomizer = randomizer

    @property
    def randomizer(self) -> DomainRandomizer:
        """Get the randomizer, can be `None` if none was passed to the constructor."""
        return self._randomizer

    @randomizer.setter
    def randomizer(self, randomizer: DomainRandomizer):
        """Set the randomizer, in contrast to the constructor `None` is rejected here."""
        if isinstance(randomizer, DomainRandomizer):
            self._randomizer = randomizer
        else:
            raise pyrado.TypeErr(given=randomizer, expected_type=DomainRandomizer)
class MetaDomainRandWrapper(DomainRandWrapper, Serializable):
    """
    Domain randomization wrapper which wraps another `DomainRandWrapper` to adapt its parameters,
    called domain distribution parameters.
    """

    def __init__(self, wrapped_rand_env: DomainRandWrapper, dp_mapping: Mapping[int, Tuple[str, str]]):
        """
        Constructor

        :param wrapped_rand_env: randomized environment to wrap
        :param dp_mapping: mapping from index of the numpy array (coming from the algorithm) to domain parameter name
                           (e.g. mass, length) and the domain distribution parameter (e.g. mean, std)

        .. code-block:: python

            # For the mapping arg use this dict constructor
            m = {0: ('name1', 'parameter_type1'), 1: ('name2', 'parameter_type2')}
        """
        if not typed_env(wrapped_rand_env, DomainRandWrapper):
            raise pyrado.TypeErr(given=wrapped_rand_env, expected_type=DomainRandWrapper)

        Serializable._init(self, locals())

        # Invoke the DomainRandWrapper's constructor, this wrapper holds no randomizer of its own
        super().__init__(wrapped_rand_env, None)

        self.dp_mapping = dp_mapping

    @property
    def randomizer(self) -> DomainRandomizer:
        """Get the randomizer of the wrapped `DomainRandWrapper`."""
        return self._wrapped_env.randomizer

    @randomizer.setter
    def randomizer(self, dr: DomainRandomizer):
        """Set the randomizer of the wrapped `DomainRandWrapper`."""
        self._wrapped_env.randomizer = dr

    def adapt_randomizer(self, domain_distr_param_values: np.ndarray):
        """
        Write new domain distribution parameter values into the wrapped environment's randomizer.

        :param domain_distr_param_values: 1-dim array (2-dim arrays are flattened) whose i-th entry is the new value
                                          of the domain distribution parameter stored at key i of `dp_mapping`
        :raises pyrado.ShapeErr: if the array is neither 1- nor 2-dimensional
        :raises KeyError: if `dp_mapping` has no entry for one of the array's indices
        """
        # Check the input dimension and flatten if necessary
        if domain_distr_param_values.ndim == 2:
            domain_distr_param_values = domain_distr_param_values.ravel()
        elif domain_distr_param_values.ndim != 1:
            raise pyrado.ShapeErr(given=domain_distr_param_values, expected_match=(1,))

        # Check all indices up front, such that the randomizer is never left half-adapted and the user gets a
        # descriptive error instead of a failed tuple unpacking of None
        missing = [i for i in range(len(domain_distr_param_values)) if i not in self.dp_mapping]
        if missing:
            raise KeyError(
                f"The dp_mapping has no entry for the indices {missing}, but {len(domain_distr_param_values)} "
                f"domain distribution parameter values were given!"
            )

        # Reconfigure the wrapped environment's DomainRandomizer
        for i, value in enumerate(domain_distr_param_values):
            dp_name, ddp_name = self.dp_mapping[i]
            self._wrapped_env.randomizer.adapt_one_distr_param(dp_name, ddp_name, value)
class DomainRandWrapperLive(DomainRandWrapper, Serializable):
    """
    Domain randomization wrapper which draws new domain parameters for the wrapped env at every reset,
    unless the caller passes the domain parameters explicitly.
    """

    def reset(self, init_state: np.ndarray = None, domain_param: dict = None) -> np.ndarray:
        """
        Reset the environment, sampling one fresh set of domain parameters if none is given.

        :param init_state: forwarded unchanged to the parent class' `reset()`
        :param domain_param: explicit domain parameters, pass `None` to let the randomizer draw them
        :return: return value of the parent class' `reset()`
        """
        if domain_param is not None:
            # The caller specified the domain parameters, so the randomizer is not involved
            return super().reset(init_state=init_state, domain_param=domain_param)

        # Draw exactly one sample and fetch it as a dict
        self._randomizer.randomize(num_samples=1)
        sampled_param = self._randomizer.get_params(fmt="dict", dtype="numpy")

        return super().reset(init_state=init_state, domain_param=sampled_param)
class DomainRandWrapperBuffer(DomainRandWrapper, Serializable):
    """
    Domain randomization wrapper which randomizes the wrapped env using a buffer of domain parameter sets.
    At every call of the reset method this wrapper selects a set from that buffer, either cyclically or at random.
    """

    def __init__(self, wrapped_env, randomizer: Optional[DomainRandomizer], selection: Optional[str] = "cyclic"):
        """
        Constructor

        :param wrapped_env: environment to wrap around
        :param randomizer: `DomainRandomizer` object that manages the randomization. If `None`, the user has to set the
                           buffer manually, the circular reset however works the same way
        :param selection: method to draw samples from the buffer, either cyclic or random
        """
        if selection not in ["cyclic", "random"]:
            raise pyrado.ValueErr(given=selection, eq_constraint="cyclic or random")

        Serializable._init(self, locals())

        # Invoke the DomainRandWrapper's constructor
        super().__init__(wrapped_env, randomizer)

        # Both stay None until fill_buffer() is called or the buffer is set manually
        self._ring_idx = None
        self._buffer = None
        self.selection = selection

    @property
    def ring_idx(self) -> int:
        """Get the buffer's index."""
        return self._ring_idx

    @ring_idx.setter
    def ring_idx(self, idx: int):
        """
        Set the buffer's index.

        :param idx: non-negative integer, must be smaller than the buffer length if the buffer is a list
        """
        if isinstance(self._buffer, list):
            num_domains = len(self._buffer)
            if not isinstance(idx, int) or not 0 <= idx < num_domains:
                raise pyrado.ValueErr(given=idx, ge_constraint="0 (int)", l_constraint=num_domains)
        elif not isinstance(idx, int) or idx < 0:
            # Without a list buffer there is no upper bound to check against
            raise pyrado.ValueErr(given=idx, ge_constraint="0 (int)")
        self._ring_idx = idx

    @property
    def selection(self) -> str:
        """Get the selection method."""
        return self._selection

    @selection.setter
    def selection(self, selection: str):
        """Set the selection method, either cyclic or random."""
        if selection not in ["cyclic", "random"]:
            raise pyrado.ValueErr(given=selection, eq_constraint="cyclic or random")
        self._selection = selection

    def fill_buffer(self, num_domains: int):
        """
        Fill the internal buffer with domains and restart from the first one.

        :param num_domains: number of randomized domain parameter sets to store in the buffer, must be positive
        :raises pyrado.TypeErr: if this wrapper has no randomizer
        :raises pyrado.ValueErr: if `num_domains` is not a positive integer
        """
        if self._randomizer is None:
            raise pyrado.TypeErr(msg="The randomizer must not be None to call fill_buffer()!")
        # Zero is rejected as well, since reset() can not select a domain from an empty buffer
        if not isinstance(num_domains, int) or num_domains <= 0:
            raise pyrado.ValueErr(given=num_domains, g_constraint="0 (int)")

        self._randomizer.randomize(num_domains)
        self._buffer = self._randomizer.get_params(-1, fmt="list", dtype="numpy")
        self._ring_idx = 0

    @property
    def buffer(self):
        """Get the domain parameter buffer."""
        return self._buffer

    @buffer.setter
    def buffer(self, buffer: Union[List[dict], dict]):
        """
        Set the domain parameter buffer.
        Depends on the way the buffer has been saved, see the `DomainRandomizer.get_params()` arguments.

        :param buffer: list of dicts, each describing a domain, or just one dict for one domain
        """
        if not (isinstance(buffer, list) or isinstance(buffer, dict)):
            raise pyrado.TypeErr(given=buffer, expected_type=[list, dict])
        self._buffer = buffer

        # Start from the first entry if the index has never been set or does not fit the new list, otherwise the
        # next reset() would fail when indexing the buffer
        if isinstance(buffer, list):
            idx_is_valid = isinstance(self._ring_idx, int) and 0 <= self._ring_idx < len(buffer)
            if not idx_is_valid:
                self._ring_idx = 0

    def reset(self, init_state: np.ndarray = None, domain_param: dict = None) -> np.ndarray:
        """
        Reset the environment, taking the domain parameters from the buffer if none are given.

        :param init_state: forwarded unchanged to the parent class' `reset()`
        :param domain_param: explicit domain parameters, pass `None` to select them from the buffer
        :return: return value of the parent class' `reset()`
        :raises pyrado.TypeErr: if the domain parameters are needed but the buffer is neither a dict nor a list
        """
        if domain_param is None:
            # No explicit specification of domain parameters, so the buffer is requested
            if isinstance(self._buffer, dict):
                # The buffer consists of one domain parameter set
                domain_param = self._buffer

            elif isinstance(self._buffer, list):
                # The buffer consists of a list of domain parameter sets
                domain_param = self._buffer[self._ring_idx]  # first selection will be index 0

                # Determine the index for the next reset
                if self._selection == "cyclic":
                    self._ring_idx = (self._ring_idx + 1) % len(self._buffer)
                elif self._selection == "random":
                    self._ring_idx = randint(0, len(self._buffer) - 1)

            else:
                raise pyrado.TypeErr(given=self._buffer, expected_type=[dict, list])

        # Forward to EnvWrapper, which delegates to self._wrapped_env
        return super().reset(init_state=init_state, domain_param=domain_param)

    def _get_state(self, state_dict: dict):
        """Add the buffer and the buffer index to the state dict filled by the parent class."""
        super()._get_state(state_dict)
        state_dict["buffer"] = self._buffer
        state_dict["ring_idx"] = self._ring_idx

    def _set_state(self, state_dict: dict, copying: bool = False):
        """Restore the buffer and the buffer index from the state dict, without validating them."""
        super()._set_state(state_dict, copying)
        self._buffer = state_dict["buffer"]
        self._ring_idx = state_dict["ring_idx"]
def remove_all_dr_wrappers(env: Env, verbose: bool = False):
    """
    Go through the environment chain and remove all wrappers of type `DomainRandWrapper` (and subclasses).

    :param env: env chain with domain randomization wrappers
    :param verbose: choose if status messages should be printed
    :return: env chain without domain randomization wrappers
    """
    while True:
        # Look up the first domain randomization wrapper in the chain, such that the status message reports the
        # type of an actual domain randomization wrapper and not the type of the outermost environment
        # NOTE(review): presumably this is the wrapper `remove_env()` strips next, confirm against its implementation
        dr_wrapper = next((subenv for subenv in all_envs(env) if isinstance(subenv, DomainRandWrapper)), None)
        if dr_wrapper is None:
            # Nothing left to remove
            return env

        if verbose:
            with completion_context(
                f"Found domain randomization wrapper of type {type(dr_wrapper).__name__}. Removing it now",
                color="y",
                bright=True,
            ):
                env = remove_env(env, DomainRandWrapper)
        else:
            env = remove_env(env, DomainRandWrapper)