Skip to content

Commit

Permalink
added assertion to runner
Browse files Browse the repository at this point in the history
  • Loading branch information
david-istvan committed Jun 18, 2024
1 parent d9da52f commit 16867a0
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
2 changes: 1 addition & 1 deletion input/opinions-test.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
3
[1,1], -2
[1,1], 2
4 changes: 4 additions & 0 deletions src/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,12 @@ def discrete_policy_grad(self, max_episodes, advice=None, is_random=False):
logging.debug('Generating default policy')
policy = self.get_default_policy(environment)
if advice:
original_policy = policy
logging.info(f'\t\t\t Shaping policy with human input at u={advice.u}')
policy = self.shape_policy(policy, advice)
if advice.u==1.0:
assert np.array_equal(original_policy, policy)
print(np.array_equal(original_policy, policy))

#logging.debug('Initial policy:')
#logging.debug(policy)
Expand Down
11 changes: 7 additions & 4 deletions tests/runner_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,27 @@ class RunnerTests(unittest.TestCase):
def testShaping(self):
#3x3 grid

default_policy = default_policy = np.full((9, 4), 0.25)
print(default_policy)
policy = np.full((9, 4), 0.25)
p1 = policy

print(p1)

#advice about the middle cell
file = os.path.abspath(f'input/opinions-test.txt')
opinion_parser = OpinionParser()

human_input = opinion_parser.parse(file)
#advice = Advice(human_input, 0)
advice = Advice(human_input, 1)
advice = Advice(human_input, 0.99)

print(f'{advice} @u={advice.u}.')

r = Runner(12, 63, 2, 10)
policy = r.shape_policy(default_policy, advice)
policy = r.shape_policy(policy, advice)

print(policy)

print(np.array_equal(p1, policy))

if __name__ == "__main__":
unittest.main()

0 comments on commit 16867a0

Please sign in to comment.