Source code for tests.algorithms.test_stopping_criteria

# Copyright (c) 2020, Fabio Muratore, Honda Research Institute Europe GmbH, and
# Technical University of Darmstadt.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
# 3. Neither the name of Fabio Muratore, Honda Research Institute Europe GmbH,
#    or Technical University of Darmstadt, nor the names of its contributors may
#    be used to endorse or promote products derived from this software without
#    specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL FABIO MURATORE, HONDA RESEARCH INSTITUTE EUROPE GMBH,
# OR TECHNICAL UNIVERSITY OF DARMSTADT BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
# IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
from copy import deepcopy
from types import SimpleNamespace
from typing import List, Optional

import numpy as np
import pytest

import pyrado
from pyrado.algorithms.stopping_criteria.predefined_criteria import (
    AlwaysStopStoppingCriterion,
    CustomStoppingCriterion,
    IterCountStoppingCriterion,
    NeverStopStoppingCriterion,
    SampleCountStoppingCriterion,
    ToggleableStoppingCriterion,
)
from pyrado.algorithms.stopping_criteria.rollout_based_criteria import (
    ConvergenceStoppingCriterion,
    MinReturnStoppingCriterion,
    ReturnStatistic,
    ReturnStatisticBasedStoppingCriterion,
)
from pyrado.algorithms.stopping_criteria.stopping_criterion import _AndStoppingCriterion, _OrStoppingCriterion
from pyrado.algorithms.utils import RolloutSavingWrapper
from pyrado.environments.base import Env
from pyrado.policies.base import Policy
from pyrado.sampling.sampler import SamplerBase
from pyrado.sampling.step_sequence import StepSequence


class MockSampler(SamplerBase):
    """Sampler stand-in for the tests that returns a fixed, externally settable list of rollouts."""

    def __init__(self, step_sequences: Optional[List[StepSequence]] = None):
        """
        Constructor

        :param step_sequences: rollouts handed out by `sample()`; an empty list is used if `None` is passed
        """
        super().__init__(min_rollouts=0, min_steps=0)
        # Create a fresh list per instance instead of sharing a mutable default argument.
        self.step_sequences = step_sequences if step_sequences is not None else []

    def reinit(self, env: Optional[Env] = None, policy: Optional[Policy] = None):
        """Do nothing; both arguments are ignored."""

    def sample(self) -> List[StepSequence]:
        """Return a deep copy of the stored rollouts, so callers cannot mutate the stored originals."""
        rollouts = deepcopy(self.step_sequences)
        return rollouts
class ExposingReturnStatisticBasedStoppingCriterion(ReturnStatisticBasedStoppingCriterion):
    """Criterion that never reports being met and only records the return statistic it was handed."""

    def __init__(self, return_statistic: ReturnStatistic = ReturnStatistic.median, num_lookbacks: int = 1):
        """
        Constructor

        :param return_statistic: statistic forwarded to the base class
        :param num_lookbacks: number of lookbacks forwarded to the base class
        """
        super().__init__(return_statistic, num_lookbacks)
        # NaN marks "no statistic received yet".
        self.return_statistic_value = np.nan

    def _is_met_with_return_statistic(self, algo, sampler: RolloutSavingWrapper, return_statistic: float) -> bool:
        """Store the passed return statistic for later inspection and always return `False`."""
        self.return_statistic_value = return_statistic
        return False
# noinspection PyTypeChecker
def test_magic_function_implementation_and():
    """Check that the `&` operator builds an `_AndStoppingCriterion` with the expected string form."""
    crit_a = CustomStoppingCriterion(None, "A")
    crit_b = CustomStoppingCriterion(None, "B")

    # Pairs of (combined criterion, expected string representation).
    cases = [
        (crit_a & crit_a, "(A and A)"),
        (crit_b & crit_b, "(B and B)"),
        (crit_a & crit_b, "(A and B)"),
        (crit_b & crit_a, "(B and A)"),
    ]
    for combined, expected_repr in cases:
        assert isinstance(combined, _AndStoppingCriterion)
        assert str(combined) == expected_repr
# noinspection PyTypeChecker
def test_magic_function_implementation_or():
    """Check that the `|` operator builds an `_OrStoppingCriterion` with the expected string form."""
    crit_a = CustomStoppingCriterion(None, "A")
    crit_b = CustomStoppingCriterion(None, "B")

    # Pairs of (combined criterion, expected string representation).
    cases = [
        (crit_a | crit_a, "(A or A)"),
        (crit_b | crit_b, "(B or B)"),
        (crit_a | crit_b, "(A or B)"),
        (crit_b | crit_a, "(B or A)"),
    ]
    for combined, expected_repr in cases:
        assert isinstance(combined, _OrStoppingCriterion)
        assert str(combined) == expected_repr
def test_criterion_combination_and():
    """Check the string form and the truth table of `_AndStoppingCriterion`."""
    always = AlwaysStopStoppingCriterion()
    never = NeverStopStoppingCriterion()

    operand_pairs = [(always, always), (never, never), (always, never), (never, always)]
    combined = [_AndStoppingCriterion(lhs, rhs) for lhs, rhs in operand_pairs]

    expected_strs = ["(True and True)", "(False and False)", "(True and False)", "(False and True)"]
    expected_met = [True, False, False, False]

    # First verify all string representations, then all evaluation results.
    for criterion, expected_str in zip(combined, expected_strs):
        assert str(criterion) == expected_str
    for criterion, should_stop in zip(combined, expected_met):
        assert bool(criterion.is_met(None)) is should_stop
def test_criterion_combination_or():
    """Check the string form and the truth table of `_OrStoppingCriterion`."""
    always = AlwaysStopStoppingCriterion()
    never = NeverStopStoppingCriterion()

    operand_pairs = [(always, always), (never, never), (always, never), (never, always)]
    combined = [_OrStoppingCriterion(lhs, rhs) for lhs, rhs in operand_pairs]

    expected_strs = ["(True or True)", "(False or False)", "(True or False)", "(False or True)"]
    expected_met = [True, False, True, True]

    # First verify all string representations, then all evaluation results.
    for criterion, expected_str in zip(combined, expected_strs):
        assert str(criterion) == expected_str
    for criterion, should_stop in zip(combined, expected_met):
        assert bool(criterion.is_met(None)) is should_stop
def test_criterion_always():
    """`AlwaysStopStoppingCriterion` must report being met."""
    criterion = AlwaysStopStoppingCriterion()
    stop = criterion.is_met(None)
    assert stop
def test_criterion_never():
    """`NeverStopStoppingCriterion` must not report being met."""
    criterion = NeverStopStoppingCriterion()
    stop = criterion.is_met(None)
    assert not stop
def test_criterion_toggleable_init_default():
    """A `ToggleableStoppingCriterion` built without arguments must start out as not met."""
    toggle = ToggleableStoppingCriterion()
    stop = toggle.is_met(None)
    assert not stop
def test_criterion_toggleable_init_met():
    """A `ToggleableStoppingCriterion` built with `met=True` must report being met."""
    toggle = ToggleableStoppingCriterion(met=True)
    stop = toggle.is_met(None)
    assert stop
def test_criterion_toggleable_init_not_met():
    """A `ToggleableStoppingCriterion` built with `met=False` must not report being met."""
    toggle = ToggleableStoppingCriterion(met=False)
    stop = toggle.is_met(None)
    assert not stop
def test_criterion_toggleable_set_on_off_init_met():
    """Toggle a criterion that starts as met: on, on again, then off."""
    toggle = ToggleableStoppingCriterion(met=True)
    assert toggle.is_met(None)

    # Switching on twice in a row must leave the criterion met both times.
    for _ in range(2):
        toggle.on()
        assert toggle.is_met(None)

    toggle.off()
    assert not toggle.is_met(None)
def test_criterion_toggleable_set_on_off_init_not_met():
    """Toggle a criterion that starts as not met: on, on again, then off."""
    toggle = ToggleableStoppingCriterion(met=False)
    assert not toggle.is_met(None)

    # Switching on twice in a row must leave the criterion met both times.
    for _ in range(2):
        toggle.on()
        assert toggle.is_met(None)

    toggle.off()
    assert not toggle.is_met(None)
@pytest.mark.parametrize("is_met_expected", [True, False])
def test_criterion_custom(is_met_expected):
    """
    Check that `CustomStoppingCriterion` calls the supplied function with the algorithm and returns its result.

    :param is_met_expected: value the custom criterion function returns, and hence the expected `is_met()` result
    """
    was_called = False
    algo_expected = "ABC"

    def criterion_fn(algo):
        # `nonlocal` rebinds the flag in the enclosing test function instead of creating a new local variable.
        nonlocal was_called
        was_called = True
        assert algo == algo_expected
        return is_met_expected

    criterion = CustomStoppingCriterion(criterion_fn, "Name")

    assert str(criterion) == "Name"
    assert criterion.is_met(algo_expected) == is_met_expected
    # The criterion must have delegated to the custom function.
    assert was_called
def test_criterion_iter_count_lower():
    """With the current iteration (1) below `max_iter` (2), the criterion must not be met."""
    criterion = IterCountStoppingCriterion(max_iter=2)
    algo_stub = SimpleNamespace(curr_iter=1)
    stop = criterion.is_met(algo_stub)
    assert not stop
def test_criterion_iter_count_higher():
    """With the current iteration (3) above `max_iter` (2), the criterion must be met."""
    criterion = IterCountStoppingCriterion(max_iter=2)
    algo_stub = SimpleNamespace(curr_iter=3)
    stop = criterion.is_met(algo_stub)
    assert stop
def test_criterion_iter_count_equal():
    """With the current iteration equal to `max_iter` (both 2), the criterion must be met."""
    criterion = IterCountStoppingCriterion(max_iter=2)
    algo_stub = SimpleNamespace(curr_iter=2)
    stop = criterion.is_met(algo_stub)
    assert stop
def test_criterion_sample_count_lower():
    """With the sample count (1) below `max_sample_count` (2), the criterion must not be met."""
    criterion = SampleCountStoppingCriterion(max_sample_count=2)
    algo_stub = SimpleNamespace(sample_count=1)
    stop = criterion.is_met(algo_stub)
    assert not stop
def test_criterion_sample_count_higher():
    """With the sample count (3) above `max_sample_count` (2), the criterion must be met."""
    criterion = SampleCountStoppingCriterion(max_sample_count=2)
    algo_stub = SimpleNamespace(sample_count=3)
    stop = criterion.is_met(algo_stub)
    assert stop
def test_criterion_sample_count_equal():
    """With the sample count equal to `max_sample_count` (both 2), the criterion must be met."""
    criterion = SampleCountStoppingCriterion(max_sample_count=2)
    algo_stub = SimpleNamespace(sample_count=2)
    stop = criterion.is_met(algo_stub)
    assert stop
# noinspection PyTypeChecker
def test_criterion_rollout_based_no_sampler():
    """Evaluating a rollout-based criterion on an algorithm without a `sampler` must raise `pyrado.ValueErr`."""
    criterion = MinReturnStoppingCriterion(return_threshold=None)
    # The algorithm stub deliberately has no `sampler` attribute.
    algo_without_sampler = SimpleNamespace()
    with pytest.raises(pyrado.ValueErr):
        criterion.is_met(algo_without_sampler)
# noinspection PyTypeChecker
def test_criterion_rollout_based_wrong_sampler():
    """Evaluating a rollout-based criterion with a sampler of an unexpected type must raise `pyrado.TypeErr`."""
    criterion = MinReturnStoppingCriterion(return_threshold=None)
    # A bare namespace stands in for a sampler of the wrong type.
    algo_with_bad_sampler = SimpleNamespace(sampler=SimpleNamespace())
    with pytest.raises(pyrado.TypeErr):
        criterion.is_met(algo_with_bad_sampler)
# noinspection PyTypeChecker
@pytest.mark.parametrize(
    ["statistic", "expected"],
    [
        (ReturnStatistic.min, 1),
        (ReturnStatistic.max, 6),
        (ReturnStatistic.median, 2),
        (ReturnStatistic.mean, 3),
        (ReturnStatistic.variance, 14 / 3),
    ],
)
def test_criterion_return_statistic_based_check_min(statistic, expected):
    """
    Check the statistic value computed from three rollouts with the undiscounted returns 6, 2, and 1.

    :param statistic: return statistic to be evaluated by the criterion
    :param expected: expected value of that statistic for the returns 6, 2, and 1
    """

    def constant_return(value):
        # Build a zero-argument callable mimicking a rollout's `undiscounted_return()` method.
        return lambda: value

    rollouts = [SimpleNamespace(undiscounted_return=constant_return(ret)) for ret in (6, 2, 1)]
    sampler = RolloutSavingWrapper(MockSampler(rollouts))
    sampler.sample()
    algo_stub = SimpleNamespace(sampler=sampler)

    criterion = ExposingReturnStatisticBasedStoppingCriterion(return_statistic=statistic)
    criterion.is_met(algo_stub)

    observed = criterion.return_statistic_value
    assert np.isclose(observed, expected)
# noinspection PyTypeChecker
def test_criterion_rollout_based_min_min_return_lower():
    """A single rollout with return 1 and a threshold of 2 must not meet `MinReturnStoppingCriterion`."""
    criterion = MinReturnStoppingCriterion(return_threshold=2)

    single_rollout = SimpleNamespace(undiscounted_return=lambda: 1)
    sampler = RolloutSavingWrapper(MockSampler([single_rollout]))
    sampler.sample()

    stop = criterion.is_met(SimpleNamespace(sampler=sampler))
    assert not stop
# noinspection PyTypeChecker
def test_criterion_rollout_based_min_min_return_higher():
    """A single rollout with return 3 and a threshold of 2 must meet `MinReturnStoppingCriterion`."""
    criterion = MinReturnStoppingCriterion(return_threshold=2)

    single_rollout = SimpleNamespace(undiscounted_return=lambda: 3)
    sampler = RolloutSavingWrapper(MockSampler([single_rollout]))
    sampler.sample()

    stop = criterion.is_met(SimpleNamespace(sampler=sampler))
    assert stop
# noinspection PyTypeChecker
def test_criterion_rollout_based_min_min_return_equal():
    """A single rollout with return 2 and a threshold of 2 must meet `MinReturnStoppingCriterion`."""
    criterion = MinReturnStoppingCriterion(return_threshold=2)

    single_rollout = SimpleNamespace(undiscounted_return=lambda: 2)
    sampler = RolloutSavingWrapper(MockSampler([single_rollout]))
    sampler.sample()

    stop = criterion.is_met(SimpleNamespace(sampler=sampler))
    assert stop
@pytest.mark.parametrize(
    ["num_iter", "expected"],
    [
        (None, [1, 2, 3]),
        (1, [3]),
        (2, [2, 3]),
        (3, [1, 2, 3]),
        (4, None),
    ],
)
def test_criterion_rollout_based_convergence_subset(num_iter, expected):
    """
    Check which part of a three-element statistic history is selected for different `num_iter` settings.

    :param num_iter: value passed as `num_iter` to the `ConvergenceStoppingCriterion`
    :param expected: expected subset of the history `[1, 2, 3]`, or `None`
    """
    criterion = ConvergenceStoppingCriterion(num_iter=num_iter)
    # Inject the history directly instead of filling it through sampling.
    criterion._return_statistic_history = list(range(1, 4))

    subset = criterion._get_relevant_return_statistic_subset()
    assert subset == expected
# noinspection PyTypeChecker
def test_criterion_rollout_based_convergence_history_filling():
    """Check that every `is_met()` call appends the statistic of the newest rollout to the history."""

    def constant_return(value):
        # Bind the value now, so each rollout keeps its own return.
        return lambda: value

    mock_sampler = MockSampler()
    sampler = RolloutSavingWrapper(mock_sampler)
    algo_stub = SimpleNamespace(sampler=sampler)
    criterion = ConvergenceStoppingCriterion()

    for ret in range(10):
        # Hand out exactly one rollout whose return equals the iteration index.
        mock_sampler.step_sequences = [SimpleNamespace(undiscounted_return=constant_return(ret))]
        sampler.sample()
        criterion.is_met(algo_stub)

        expected_history = list(range(ret + 1))
        assert criterion._return_statistic_history == expected_history
def test_criterion_rollout_based_convergence_regress_constant_zero():
    """A history of ten zeros must give a convergence probability of (approximately) 1."""
    criterion = ConvergenceStoppingCriterion()
    criterion._return_statistic_history = [0.0] * 10

    probability = criterion._compute_convergence_probability()
    assert np.isclose(probability, 1.0)
def test_criterion_rollout_based_convergence_regress_constant_one():
    """A history of ten ones must give a convergence probability of (approximately) 1."""
    criterion = ConvergenceStoppingCriterion()
    criterion._return_statistic_history = [1.0] * 10

    probability = criterion._compute_convergence_probability()
    assert np.isclose(probability, 1.0)
def test_criterion_rollout_based_convergence_regress_not_constant():
    """A strictly increasing history 0, 1, ..., 9 must give a convergence probability of (approximately) 0."""
    criterion = ConvergenceStoppingCriterion()
    criterion._return_statistic_history = list(range(10))

    probability = criterion._compute_convergence_probability()
    assert np.isclose(probability, 0.0)
def test_criterion_rollout_based_convergence_regress_random():
    """Zero-mean, low-variance noise without a trend must give a convergence probability above 0.9."""
    criterion = ConvergenceStoppingCriterion()

    # Fixed seed keeps the test deterministic.
    rng = np.random.default_rng(seed=5)
    noise = rng.normal(loc=0.0, scale=0.001, size=10000)

    # Store a plain list, like the other tests in this module do, rather than a raw ndarray.
    criterion._return_statistic_history = noise.tolist()

    assert criterion._compute_convergence_probability() > 0.9
# noinspection PyTypeChecker
def test_criterion_rollout_based_convergence_lower():
    """A convergence probability (0.0) below the threshold (0.5) must not meet the criterion."""
    criterion = ConvergenceStoppingCriterion(convergence_probability_threshold=0.5)
    # Replace the probability computation with a fixed value, so only the threshold comparison is tested.
    criterion._compute_convergence_probability = lambda: 0.0

    single_rollout = SimpleNamespace(undiscounted_return=lambda: 0)
    sampler = RolloutSavingWrapper(MockSampler([single_rollout]))
    sampler.sample()

    stop = criterion.is_met(SimpleNamespace(sampler=sampler))
    assert not stop
# noinspection PyTypeChecker
def test_criterion_rollout_based_convergence_higher():
    """A convergence probability (1.0) above the threshold (0.5) must meet the criterion."""
    criterion = ConvergenceStoppingCriterion(convergence_probability_threshold=0.5)
    # Replace the probability computation with a fixed value, so only the threshold comparison is tested.
    criterion._compute_convergence_probability = lambda: 1.0

    single_rollout = SimpleNamespace(undiscounted_return=lambda: 0)
    sampler = RolloutSavingWrapper(MockSampler([single_rollout]))
    sampler.sample()

    stop = criterion.is_met(SimpleNamespace(sampler=sampler))
    assert stop
# noinspection PyTypeChecker
def test_criterion_rollout_based_convergence_equal():
    """A convergence probability equal to the threshold (both 0.5) must meet the criterion."""
    criterion = ConvergenceStoppingCriterion(convergence_probability_threshold=0.5)
    # Replace the probability computation with a fixed value, so only the threshold comparison is tested.
    criterion._compute_convergence_probability = lambda: 0.5

    single_rollout = SimpleNamespace(undiscounted_return=lambda: 0)
    sampler = RolloutSavingWrapper(MockSampler([single_rollout]))
    sampler.sample()

    stop = criterion.is_met(SimpleNamespace(sampler=sampler))
    assert stop
# noinspection PyTypeChecker
def test_criterion_rollout_based_convergence_none():
    """If the convergence probability is `None`, the criterion must not be met."""
    criterion = ConvergenceStoppingCriterion(convergence_probability_threshold=0.5)
    # Replace the probability computation, so that it reports no probability at all.
    criterion._compute_convergence_probability = lambda: None

    single_rollout = SimpleNamespace(undiscounted_return=lambda: 0)
    sampler = RolloutSavingWrapper(MockSampler([single_rollout]))
    sampler.sample()

    stop = criterion.is_met(SimpleNamespace(sampler=sampler))
    assert not stop