diff --git a/scripts/explore_parameters.py b/scripts/explore_parameters.py
index 6e34ab7..34afcf2 100755
--- a/scripts/explore_parameters.py
+++ b/scripts/explore_parameters.py
@@ -126,7 +126,9 @@ def step():
     global gridworld, parameter_values, env, agent, running, stepping, terminated, t, state, total, aleph, aleph0, delta, initialMu0, initialMu20, visited_state_alephs, visited_action_alephs
     print()
     env._fps = values['speed_slider']
-    action, aleph4action = agent.localPolicy(state, aleph).sample()[0]
+    # action, aleph4action = agent.localPolicy(state, aleph).sample()[0]
+    action = agent.act()
+    aleph4action = agent.last_aleph4action
     visited_state_alephs.add((state, aleph))
     visited_action_alephs.add((state, action, aleph4action))
     if values['lossCoeff4WassersteinTerminalState'] != 0:
@@ -144,8 +146,10 @@ def step():
     if parameter_values['verbose'] or parameter_values['debug']:
         print("t:", t, ", last delta:" ,delta, ", total:", total, ", s:", state, ", aleph4s:", aleph, ", a:", action, ", aleph4a:", aleph4action)
     nextState, delta, terminated, _, info = env.step(action)
-    total += delta
-    aleph = agent.propagateAspiration(state, action, aleph4action, delta, nextState)
+    agent.observe(nextState, delta, terminated)
+    total = agent.total # total += delta
+    aleph = agent.last_aleph4state # agent.propagateAspiration(state, action, aleph4action, delta, nextState)
+
     state = nextState
     if terminated:
         print("t:",t, ", last delta:",delta, ", final total:", total, ", final s:", state, ", aleph4s:", aleph)
@@ -183,6 +187,7 @@ def step():
     t += 1
     if stepping:
         stepping = False
+
def reset_env(start=False):
     # TODO: only regenerate env if different from before!
     global gridworld, parameter_values, env, agent, running, stepping, terminated, t, state, total, aleph, aleph0, delta, initialMu0, initialMu20, visited_state_alephs, visited_action_alephs
@@ -191,6 +196,7 @@ def reset_env(start=False):
     if gridworld != old_gridworld:
         env, aleph0 = make_simple_gridworld(gw=gridworld, render_mode="human", fps=values['speed_slider'])
         # env = env.get_prolonged_version(5)
+        agent = AgentMDPPlanning(world=env)
     if values['override_aleph_checkbox']:
         aleph = (values['aleph0_low'], values['aleph0_high'])
     else:
@@ -201,6 +207,7 @@ def reset_env(start=False):
     if parameter_values['lossTemperature'] == 0:
         parameter_values['lossTemperature'] = 1e-6
     parameter_values.update({
+        'initialAspiration': aleph,
         'verbose': values['verbose_checkbox'],
         'debug': values['debug_checkbox'],
         'allowNegativeCoeffs': True,
@@ -209,9 +216,10 @@ def reset_env(start=False):
         'wassersteinFromInitial': values['wasserstein_checkbox'],
     })
     print("\n\nRESTART gridworld", gridworld, parameter_values)
+    agent.reset(parameter_values)
     state, info = env.reset()
+    agent.observe(state)
     print("Initial state:", env.state_embedding(state), ", initial aleph:", aleph)
-    agent = AgentMDPPlanning(parameter_values, world=env)
     # agent.localPolicy(state, aleph) # call it once to precompute tables and save time for later
     initialMu0 = list(agent.ETerminalState_state(state, aleph, "default"))
     initialMu20 = list(agent.ETerminalState2_state(state, aleph, "default"))
diff --git a/src/satisfia/agents/makeMDPAgentSatisfia.py b/src/satisfia/agents/makeMDPAgentSatisfia.py
index 1e2a90e..f1a0588 100755
--- a/src/satisfia/agents/makeMDPAgentSatisfia.py
+++ b/src/satisfia/agents/makeMDPAgentSatisfia.py
@@ -23,7 +23,13 @@ class AspirationAgent(ABC):
     reachable_states = None
     default_transition = None
+
+    ### Methods for initialization, resetting, clearing caches:
+
     def __init__(self, params):
+        if params: self.reset(params)
+
+    def reset(self, params):
         """
         If world is provided, maxAdmissibleQ, minAdmissibleQ, Q, Q2, ..., Q6 are not needed because they are computed from the world.
         Otherwise, these functions must be provided, e.g. as learned using some reinforcement learning algorithm. Their signature is
         - maxAdmissibleQ|minAdmissibleQ: (state, action) -> float
@@ -41,7 +47,9 @@ def __init__(self, params):
 
         if lossCoeff4StateDistance > 0, referenceState must be provided
         """
+
         defaults = {
+            "initialAspiration": None,
             # admissibility parameters:
             "maxLambda": 1, # upper bound on local relative aspiration in each step (must be minLambda...1) # TODO: rename to lambdaHi
             "minLambda": 0, # lower bound on local relative aspiration in each step (must be 0...maxLambda) # TODO: rename to lambdaLo
@@ -103,6 +111,15 @@ def __init__(self, params):
 
         self.params.update(params)
         # TODO do I need to add params_.options
+        self.clear_caches()
+        self.last_state = None
+        self.last_aleph4state = self.params["initialAspiration"]
+        self.last_action = None
+        self.last_aleph4action = None
+        self.last_delta = None
+        self.terminated = False
+        self.total = None
+
         self.stateActionPairsSet = set()
 
         assert self.params["lossTemperature"] > 0, "lossTemperature must be > 0"
@@ -216,9 +233,23 @@ def deltaVar(s, a, al4s, al4a, p):
           → aspiration4state
           → simulate (RECURSION)"""
 
+    def clear_caches(self):
+        """Clear all function caches (called by reset())"""
+        # loop through all parent classes:
+        for cls in self.__class__.__mro__:
+            # loop through all attributes of the class:
+            for key, value in cls.__dict__.items():
+                # check if the attribute is a cached function:
+                if callable(value):
+                    if hasattr(value, "cache_clear"):
+                        value.cache_clear()
+
     def __getitem__(self, name):
         return self.params[name]
 
+
+    ### Methods for computing feasibility sets / reference simplices:
+
     @cache
     def maxAdmissibleV(self, state): # recursive
         if self.verbose or self.debug:
@@ -259,6 +290,9 @@ def admissibility4state(self, state):
     def admissibility4action(self, state, action):
         return self.minAdmissibleQ(state, action), self.maxAdmissibleQ(state, action)
 
+
+    ### Methods for computing aspirations:
+
     # When in state, we can get any expected total in the interval
     # [minAdmissibleV(state), maxAdmissibleV(state)].
     # So when having aspiration aleph, we can still fulfill it in expectation if it lies in the interval.
@@ -324,6 +358,9 @@ def aspiration4action(self, state, action, aleph4state):
             print(pad(state),"| | ╰ aspiration4action, state",prettyState(state),"action",action,"aleph4state",aleph4state,":",res,"(steadfast)")
         return res
 
+
+    ### Methods for computing loss components independent of actual policy:
+
     @cache
     def disorderingPotential_state(self, state): # recursive
         if self.debug or self.verbose:
@@ -375,6 +412,9 @@ def X(other_state):
             print(pad(state),"| | | | ╰ agency_state", prettyState(state), ":", res)
         return res
 
+
+    ### Methods for computing the policy, propagating aspirations, acting, and observing:
+
     # Based on the admissibility information computed above, we can now construct the policy,
     # which is a mapping taking a state and an aspiration interval as input and returning
     # a categorical distribution over (action, aleph4action) pairs.
@@ -538,6 +578,39 @@ def propagateAspiration(self, state, action, aleph4action, Edel, nextState):
         including Edel in the formula.
         """
 
+    def observe(self, state, delta=None, terminated=False):
+        """Called after env.reset() or env.step()"""
+        self.last_delta = delta
+        if delta is not None:
+            if self.total is None:
+                self.total = delta
+            else:
+                self.total += delta
+        self.terminated = terminated
+        if not terminated:
+            if self.last_state is not None:
+                # propagate the aspiration:
+                self.last_aleph4state = self.propagateAspiration(self.last_state, self.last_action, self.last_aleph4action, delta, state)
+            # otherwise it was set in reset()
+            self.last_state = state
+        if self.verbose or self.debug:
+            print("observed state", prettyState(state), ", delta", delta, " (terminated", terminated, "); resulting total", self.total, ", aleph4state", self.last_aleph4state)
+
+    def act(self):
+        """Choose an action based on the current state and aspiration"""
+        assert not self.terminated, "cannot act after termination"
+        state, aleph4state = self.last_state, self.last_aleph4state
+        assert state is not None, "cannot act without having observed a state"
+        action, aleph4action = self.localPolicy(state, aleph4state).sample()[0]
+        # TODO later: potentially consult with the principal and change the aleph4state, action, and/or aleph4action
+        self.last_action, self.last_aleph4action = action, aleph4action
+        if self.verbose or self.debug:
+            print("acting in state", prettyState(state), ", choosing action", action, ", aleph4action", aleph4action)
+        return action
+
+
+    ### Methods for computing loss components dependent on actual policy:
+
     @cache
     def V(self, state, aleph4state): # recursive
         if self.debug:
@@ -858,6 +931,9 @@ def X(actionAndAleph):
     def randomTieBreaker(self, state, action):
         return random.random()
 
+
+    ### Methods for computing overall safety loss:
+
     # now we can combine all of the above quantities to a combined (safety) loss function:
 
     # state, action, aleph4state, aleph4action, estActionProbability
@@ -882,6 +958,9 @@ def getData(self): # FIXME: still needed?
             "locs": [state.loc for state in states],
         }
 
+
+    ### Abstract methods that need to be implemented by subclasses:
+
     @abstractmethod
     def maxAdmissibleQ(self, state, action): pass
     @abstractmethod
@@ -979,7 +1058,7 @@ def __init__(self, params, maxAdmissibleQ=None, minAdmissibleQ=None,
         self.possible_actions = possible_actions
 
 
 class AgentMDPPlanning(AspirationAgent):
-    def __init__(self, params, world=None):
+    def __init__(self, params=None, world=None):
         self.world = world
         super().__init__(params)