Skip to content

Commit

Permalink
fixed refactoring bug
Browse files Browse the repository at this point in the history
  • Loading branch information
hill-a committed Jul 4, 2018
1 parent 950c2df commit 0b9d89c
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ install:

script:
- flake8 --select=F baselines/common
- docker run baselines-test sh -c 'pytest --cov-report term --cov-report xml --cov=. && python-codacy-coverage -r coverage.xml --token=$CODACY_PROJECT_TOKEN'
- docker run baselines-test --env CODACY_PROJECT_TOKEN=$CODACY_PROJECT_TOKEN sh -c 'pytest --cov-report term --cov-report xml --cov=. && python-codacy-coverage -r coverage.xml --token=$CODACY_PROJECT_TOKEN'
12 changes: 6 additions & 6 deletions baselines/a2c/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,12 +161,12 @@ def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, nlstm=256, reuse=Fa
self.initial_state = None
self.vf = vf

def step(self, obs, state, mask):
def step(self, obs, *args, **kwargs):
a, v, neglogp = self.sess.run([self.a0, self.vf, self.neglogp0], {self.obs_ph: obs})
return a, v, self.initial_state, neglogp

def value(self, obs, state, mask):
return self.sess.run(self.vf, {self.obs_ph: obs, self.states_ph: state, self.masks_ph: mask})
def value(self, obs, *args, **kwargs):
return self.sess.run(self.vf, {self.obs_ph: obs})


class MlpPolicy(A2CPolicy):
Expand All @@ -188,10 +188,10 @@ def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, nlstm=256, reuse=Fa
self.initial_state = None
self.vf = vf

def step(self, obs, state, mask):
def step(self, obs, *args, **kwargs):
a, v, neglogp = self.sess.run([self.a0, self.vf, self.neglogp0], {self.obs_ph: obs})
return a, v, self.initial_state, neglogp

def value(self, obs, state, mask):
return self.sess.run(self.vf, {self.obs_ph: obs, self.states_ph: state, self.masks_ph: mask})
def value(self, obs, *args, **kwargs):
return self.sess.run(self.vf, {self.obs_ph: obs})

0 comments on commit 0b9d89c

Please sign in to comment.