# Copyright (c) 2020, Fabio Muratore, Honda Research Institute Europe GmbH, and
# Technical University of Darmstadt.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# 3. Neither the name of Fabio Muratore, Honda Research Institute Europe GmbH,
# or Technical University of Darmstadt, nor the names of its contributors may
# be used to endorse or promote products derived from this software without
# specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL FABIO MURATORE, HONDA RESEARCH INSTITUTE EUROPE GMBH,
# OR TECHNICAL UNIVERSITY OF DARMSTADT BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
# IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
from typing import Any, Callable, NoReturn, Optional
from pyrado.algorithms.stopping_criteria.stopping_criterion import StoppingCriterion
class AlwaysStopStoppingCriterion(StoppingCriterion):
    """Stopping criterion that is satisfied unconditionally, i.e. `is_met` yields `True` for any algorithm."""

    def __repr__(self) -> str:
        return "AlwaysStopStoppingCriterion"

    def __str__(self) -> str:
        # The short form reads like the boolean constant this criterion evaluates to
        return str(True)

    def is_met(self, algo) -> bool:
        """
        Evaluate the criterion. The result does not depend on the given algorithm.

        :param algo: algorithm to check, ignored here
        :return: always `True`
        """
        return True
class NeverStopStoppingCriterion(StoppingCriterion):
    """Stopping criterion that is never satisfied, i.e. `is_met` yields `False` for any algorithm."""

    def __repr__(self) -> str:
        return "NeverStopStoppingCriterion"

    def __str__(self) -> str:
        # The short form reads like the boolean constant this criterion evaluates to
        return str(False)

    def is_met(self, algo) -> bool:
        """
        Evaluate the criterion. The result does not depend on the given algorithm.

        :param algo: algorithm to check, ignored here
        :return: always `False`
        """
        return False
class ToggleableStoppingCriterion(StoppingCriterion):
    """Stopping criterion that can be turned on/off from the outside."""

    def __init__(self, met: bool = False):
        """
        Constructor.

        :param met: initialization of the return value of `is_met`
        """
        super().__init__()
        self._met = met

    def __repr__(self) -> str:
        return f"ToggleableStoppingCriterion[met={self._met}]"

    def __str__(self) -> str:
        return str(self._met)

    def is_met(self, algo=None) -> bool:
        """
        Return the currently stored state of the criterion.

        :param algo: algorithm to check, not used by this criterion and therefore optional
        :return: the stored flag, as set by the constructor, `on`, `off`, or `toggle`
        """
        return self._met

    # `on` and `off` return normally, so they are annotated with `None`; `NoReturn` is reserved for
    # functions that never return to the caller (e.g. because they always raise).
    def on(self) -> None:
        """Switch the criterion on, such that `is_met` returns `True` afterwards."""
        self._met = True

    def off(self) -> None:
        """Switch the criterion off, such that `is_met` returns `False` afterwards."""
        self._met = False

    def toggle(self) -> bool:
        """
        Invert the stored state of the criterion.

        :return: the new state, i.e. the value `is_met` returns after the call
        """
        self._met = not self._met
        return self._met
class CustomStoppingCriterion(StoppingCriterion):
    """Stopping criterion that delegates the decision to an arbitrary user-provided callable."""

    def __init__(self, criterion_fn: Callable[[Any], bool], name: Optional[str] = None):
        """
        Constructor.

        :param criterion_fn: signature `[Algorithm] -> bool`; gets evaluated when `is_met` is called; allows for custom
                             functionality, e.g. if an algorithm requires special treatment; the given algorithm is the
                             same that was passed to the `is_met` method
        :param name: name of the stopping criterion, used for `str(..)` and `repr(..)`
        """
        super().__init__()
        self._name = name
        self._criterion_fn = criterion_fn

    def __repr__(self) -> str:
        return f"CustomStoppingCriterion[_criterion_fn={self._criterion_fn!r}; name={self._name}]"

    def __str__(self) -> str:
        # Fall back to a generic label if the criterion was not given a name
        if self._name is None:
            return "Custom"
        return self._name

    def is_met(self, algo) -> bool:
        """
        Evaluate the stored callable on the given algorithm.

        :param algo: algorithm to check, forwarded unchanged to the callable
        :return: whatever the callable returns for this algorithm
        """
        return self._criterion_fn(algo)
class IterCountStoppingCriterion(StoppingCriterion):
    """Uses the iteration number as a stopping criterion, i.e. sets a maximum number of iterations."""

    def __init__(self, max_iter: int):
        """
        Constructor.

        :param max_iter: maximum number of iterations
        """
        super().__init__()
        self._max_iter = max_iter

    # String representations in the same format as the other predefined criteria in this module
    def __repr__(self) -> str:
        return f"IterCountStoppingCriterion[max_iter={self._max_iter}]"

    def __str__(self) -> str:
        return f"(iter >= {self._max_iter})"

    def is_met(self, algo) -> bool:
        """
        Check if the algorithm has reached the maximum number of iterations.

        :param algo: algorithm to check, must provide the attribute `curr_iter`
        :return: `True` if `algo.curr_iter` is greater than or equal to the maximum number of iterations
        """
        return algo.curr_iter >= self._max_iter
class SampleCountStoppingCriterion(StoppingCriterion):
    """Uses the sample count as a stopping criterion, i.e. sets a maximum number of samples."""

    def __init__(self, max_sample_count: int):
        """
        Constructor.

        :param max_sample_count: maximum sample count
        """
        super().__init__()
        self._max_sample_count = max_sample_count

    # String representations in the same format as the other predefined criteria in this module
    def __repr__(self) -> str:
        return f"SampleCountStoppingCriterion[max_sample_count={self._max_sample_count}]"

    def __str__(self) -> str:
        return f"(sample_count >= {self._max_sample_count})"

    def is_met(self, algo) -> bool:
        """
        Check if the algorithm has reached the maximum sample count.

        :param algo: algorithm to check, must provide the attribute `sample_count`
        :return: `True` if `algo.sample_count` is greater than or equal to the maximum sample count
        """
        return algo.sample_count >= self._max_sample_count