From 16867a0131970a7f7d4cd476b1a432bc4c601f1e Mon Sep 17 00:00:00 2001 From: david-istvan Date: Tue, 18 Jun 2024 23:59:58 +0200 Subject: [PATCH] added assertion to runner --- input/opinions-test.txt | 2 +- src/runner.py | 4 ++++ tests/runner_tests.py | 11 +++++++---- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/input/opinions-test.txt b/input/opinions-test.txt index e67bb1d..146f855 100644 --- a/input/opinions-test.txt +++ b/input/opinions-test.txt @@ -1,2 +1,2 @@ 3 -[1,1], -2 \ No newline at end of file +[1,1], 2 \ No newline at end of file diff --git a/src/runner.py b/src/runner.py index b5da9cd..5ee787e 100644 --- a/src/runner.py +++ b/src/runner.py @@ -171,8 +171,12 @@ def discrete_policy_grad(self, max_episodes, advice=None, is_random=False): logging.debug('Generating default policy') policy = self.get_default_policy(environment) if advice: + original_policy = policy logging.info(f'\t\t\t Shaping policy with human input at u={advice.u}') policy = self.shape_policy(policy, advice) + if advice.u==1.0: + assert np.array_equal(original_policy, policy) + print(np.array_equal(original_policy, policy)) #logging.debug('Initial policy:') #logging.debug(policy) diff --git a/tests/runner_tests.py b/tests/runner_tests.py index 1884a6d..e31c685 100644 --- a/tests/runner_tests.py +++ b/tests/runner_tests.py @@ -10,8 +10,10 @@ class RunnerTests(unittest.TestCase): def testShaping(self): #3x3 grid - default_policy = default_policy = np.full((9, 4), 0.25) - print(default_policy) + policy = np.full((9, 4), 0.25) + p1 = policy + + print(p1) #advice about the middle cell file = os.path.abspath(f'input/opinions-test.txt') @@ -19,15 +21,16 @@ def testShaping(self): human_input = opinion_parser.parse(file) #advice = Advice(human_input, 0) - advice = Advice(human_input, 1) + advice = Advice(human_input, 0.99) print(f'{advice} @u={advice.u}.') r = Runner(12, 63, 2, 10) - policy = r.shape_policy(default_policy, advice) + policy = r.shape_policy(policy, advice) print(policy) + print(np.array_equal(p1, policy)) if __name__ == "__main__": unittest.main()