Draft: Tobi Dance Branch #48

Draft
wants to merge 116 commits into
base: master
116 commits
cc12775
Starting template for integrating vector of target positions
marcinpaluch1994 Nov 24, 2022
c383c1c
added parameter descriptions
tobidelbruck Nov 25, 2022
50d71a5
REF TO Control_Toolkit
marcinpaluch1994 Nov 25, 2022
8fe670f
added parameter descriptions
tobidelbruck Nov 25, 2022
824ae35
Merge remote-tracking branch 'origin/Tobi_Dance' into Tobi_Dance
tobidelbruck Nov 25, 2022
89e2528
set hanging mode True to see whole state of pole.
tobidelbruck Nov 25, 2022
2dad5ae
fixed README header and add link to global compile file
tobidelbruck Nov 25, 2022
63590e1
more notes added
tobidelbruck Nov 26, 2022
d441109
document what JIT means
tobidelbruck Nov 26, 2022
a326093
added some docstrings
tobidelbruck Nov 26, 2022
cc33e60
added in utility to reload config if it has been modified on disk sin…
tobidelbruck Nov 26, 2022
3cb32d2
fixed logging function to be colored cyan for info and to generate py…
tobidelbruck Nov 27, 2022
4954af2
to fix tensorflow JIT compile case insensitivity problem, renamed cart …
tobidelbruck Nov 27, 2022
7a754f3
generate pycharm link in log output.
tobidelbruck Nov 27, 2022
4d36e70
improved load_or_reload_config_if_modified to return dict of changed …
tobidelbruck Nov 28, 2022
747628c
added dict_differ library
tobidelbruck Nov 28, 2022
ff46345
resize window smaller so we can see IDE.
tobidelbruck Nov 28, 2022
b3326f3
fix typing
tobidelbruck Nov 28, 2022
38c45b0
added section for cartpole_trajectory_cost.py
tobidelbruck Nov 28, 2022
c08c2fa
rename config_cost_function.yml to config_cost_functions.yml for cons…
tobidelbruck Nov 28, 2022
1767461
finally the dynamically modifiable control cost parameters are workin…
tobidelbruck Dec 11, 2022
a94082c
now spin and balance both work! and so does changing the policy and …
tobidelbruck Dec 12, 2022
39fcecd
got basic shimmy movement to work now. added helper vars to access co…
tobidelbruck Dec 12, 2022
61537d7
added cartonly trajectory and fixed bug that erased the target positi…
tobidelbruck Dec 12, 2022
bf83557
added check for incorrect float as int in config file if existing att…
tobidelbruck Dec 12, 2022
2ffa4ed
passing current state to cartpole_trajectory_generator.py so it can e…
tobidelbruck Dec 12, 2022
8336079
added MPPI papers to docstring
tobidelbruck Dec 13, 2022
f5c733e
added MPPI papers to docstring
tobidelbruck Dec 13, 2022
1fdbe41
Rename tensorflow compilation flags
frehe Dec 15, 2022
9958310
Update control toolkit
frehe Dec 15, 2022
6352890
Finish renaming of num_rollouts -> batch_size
frehe Dec 15, 2022
9ff1277
fix setting of correct type of variable, fix comments and docstrings
tobidelbruck Dec 18, 2022
4e6e17d
add pypref to store cartpole GUI preferences
tobidelbruck Dec 18, 2022
544e4c3
add preferences for CSV file
tobidelbruck Dec 18, 2022
9965f9b
local changes, all minor except for trajectory cost that is in flux
tobidelbruck Dec 18, 2022
2c59279
Merge remote-tracking branch 'origin/Tobi_Dance' into Tobi_Dance
tobidelbruck Dec 18, 2022
0972952
renamed s to state for clarity in many of the classes.
Dec 24, 2022
5ef8c49
added dancer that reads CSV file to specify sequence of 'steps' (beha…
Dec 26, 2022
f7d016f
added the first demo dance sequence file
Dec 26, 2022
284cfa3
got shimmy linear ramp from one freq/amp to another one working
Dec 28, 2022
4f8c691
added beep to signal change of step.
Dec 30, 2022
3cb2509
got cart bounce working for MPPI rollouts and added modeling of absor…
Jan 5, 2023
95ecbf6
add docstring for set_status_text method
Jan 5, 2023
00277fd
added comments about edge_bounce
Jan 5, 2023
bed7eee
docstrings
Jan 5, 2023
e02eb42
docstrings
Jan 5, 2023
06ef6ce
added policy number to dance step policy so that we can branch on it …
Jan 16, 2023
75bee6b
finally got spinning to work robustly by crafting a spin cost functio…
tobidelbruck Jan 27, 2023
ae2fbea
added vlc music player, start on start of dance, stop not implemented…
tobidelbruck Jan 27, 2023
3dd3fe4
added signal to CartPoleMainWindow that emits signal when simulation …
tobidelbruck Jan 28, 2023
acf3231
fixed logic of starting dance. changed tensorflow to fix version 2.10…
Jan 29, 2023
12a6e66
added config values for track edge barrier and track barrier length
Jan 29, 2023
4e64f1e
changed dance to absolute song time to make it easier to synchronize.
tobidelbruck Jan 29, 2023
4d2b903
moved the cartpole_dancer.py song and csv to config_cost_functions.yml.
tobidelbruck Jan 30, 2023
86fd933
moved the cartpole_dancer.py song and csv to config_cost_functions.yml.
tobidelbruck Jan 31, 2023
784ef6e
Merge remote-tracking branch 'origin/master' into Tobi_Dance
tobidelbruck Jan 31, 2023
f7974ff
merged from master, commented out dynamic song rate changes since it …
tobidelbruck Jan 31, 2023
72b244b
Merge remote-tracking branch 'origin/main' into Tobi_Dance
tobidelbruck Jan 31, 2023
fa170c1
only add signal if GUI exists
tobidelbruck Feb 1, 2023
34d778b
fixed logging at start of GPU. Fixed print_help().
tobidelbruck Feb 2, 2023
d3e38b6
typos, comments, and Merge remote-tracking branch 'origin/main' into …
tobidelbruck Feb 2, 2023
6b6d66e
move preferences to own file in others prefs.py
tobidelbruck Feb 2, 2023
427e618
Merge branch 'master' into Tobi_Dance
tobidelbruck Feb 2, 2023
f972fb3
merged from Tobi_Dance and added some loggers
tobidelbruck Feb 3, 2023
e66ac87
moved get_logger to own file in SI_Toolkit
tobidelbruck Feb 6, 2023
3bfe82a
moved get_logger to own file in SI_Toolkit
tobidelbruck Feb 6, 2023
fee814c
added search path for running from physical-cartpole.
tobidelbruck Feb 7, 2023
fb49139
update path to config_cost_functions.yml
tobidelbruck Feb 7, 2023
13ea6dc
move get_logger.py to Control_Toolkit so that it can be used by physi…
tobidelbruck Feb 8, 2023
9616633
cartpole_dancer.py starts to work. Music starts and stops, some steps…
tobidelbruck Feb 10, 2023
06f1947
improved control slightly by adding back more cost terms to provide s…
tobidelbruck Feb 11, 2023
872ec3c
added some docstrings, but they are not very informative
tobidelbruck Feb 12, 2023
0770b98
added primitive ability to record the predictor_ODE_tf.py predictions…
tobidelbruck Feb 13, 2023
6b1f3f0
added prediction and target trajectory to logging to allow model mism…
tobidelbruck Feb 13, 2023
ead9c79
updated config_cost_functions.yml to match physical-cartpole better, …
tobidelbruck Feb 14, 2023
42aebb8
added NaturalPeriod to cartpole constants, it is .73 seconds using th…
tobidelbruck Feb 14, 2023
ae7035b
add computation of pole natural frequency to p_globals.py.
tobidelbruck Feb 14, 2023
4e7d284
fixed spin cost to count only terminal state of pole at end of horizo…
tobidelbruck Feb 15, 2023
48d665b
cleanup comments
tobidelbruck Feb 15, 2023
10809ba
updated dance, dance to satisfaction finally starting to work. New spi…
tobidelbruck Feb 15, 2023
bebba78
Merge remote-tracking branch 'origin/Tobi_Dance' into Tobi_Dance
tobidelbruck Feb 15, 2023
d517d7d
changed shimmy target trajectory to keep pole pointed at correct angl…
tobidelbruck Feb 15, 2023
d42307f
corrected formula for natural frequency based on error pointed out by…
tobidelbruck Feb 15, 2023
5abab34
fixed cartpole_trajectory_generator.py for shimmy to make pole angle …
tobidelbruck Feb 15, 2023
41ad200
fixed angle target for shimmy to be correct, shimmy works well in sim…
tobidelbruck Feb 15, 2023
b9bbb40
added 'cartwheel' step to cartpole_trajectory_generator.py.
tobidelbruck Feb 16, 2023
195f669
modified cartwheel step to be a state machine with transitions betwee…
tobidelbruck Feb 17, 2023
23e03b7
fixed logic for multiple cartwheels, seems to work fine now in simula…
tobidelbruck Feb 18, 2023
1562d7a
small changes to weights to make balancing work better
tobidelbruck Feb 18, 2023
bab5a18
Merge remote-tracking branch 'origin/Tobi_Dance' into Tobi_Dance
tobidelbruck Feb 18, 2023
4e8eac0
added comments to headers of some config files, not done yet.
tobidelbruck Feb 19, 2023
6744441
Merge remote-tracking branch 'origin/Tobi_Dance' into Tobi_Dance
tobidelbruck Feb 19, 2023
0fe4481
working demo of dancer, fixed inits of vars, updated costs to work be…
tobidelbruck Feb 19, 2023
921968c
fixed some logic and reduced some loggers to debug level
tobidelbruck Feb 19, 2023
02aded4
Merge remote-tracking branch 'origin/Tobi_Dance' into Tobi_Dance
tobidelbruck Feb 19, 2023
60fd44e
added explanatory comments to config files
Feb 20, 2023
dab0b32
added more explanatory comments to config files
Feb 20, 2023
b6b071f
Merge remote-tracking branch 'origin/Tobi_Dance' into Tobi_Dance
tobidelbruck Feb 20, 2023
5118bf2
fixed logic for shimmy start time, will test in simulation
tobidelbruck Feb 20, 2023
92417f2
fixed shimmy math.
tobidelbruck Feb 20, 2023
1b5be4f
working dancer
tobidelbruck Feb 20, 2023
6b33317
improved console reporting of current objective and logging output so…
tobidelbruck Feb 21, 2023
0e8a9f5
improved logging output to make debug logger light gray, include file…
tobidelbruck Feb 21, 2023
ee2fe88
improved cartonly and integrated to start of satisfaction dance
tobidelbruck Feb 21, 2023
318f641
reverted sookie sookie with corrupted binary
tobidelbruck Jan 28, 2023
3666578
fixed get_logger.py that now uses a single logger name to only add th…
tobidelbruck Feb 22, 2023
5424c72
fixed logic to test for changed cartpole cycles and casting to python…
tobidelbruck Feb 22, 2023
c2cf1d9
added to and fro cartwheel step
tobidelbruck Feb 22, 2023
b85dfdf
improved console output.
tobidelbruck Feb 22, 2023
9da6d58
slightly changed pole length based on careful measurement
tobidelbruck Feb 22, 2023
fa8c24f
reduced chatter in logging
tobidelbruck Feb 23, 2023
455f52b
updated config_cost_functions.yml from physical-cartpole
tobidelbruck Feb 23, 2023
dfef006
major changes to cartpole_dancer_cost and cartpole_trajectory_generat…
tobidelbruck Feb 28, 2023
2a37817
latest physical-cartpole dancer, works with RPGD
tobidelbruck Feb 28, 2023
ae92487
added jtag connection photo.
tobidelbruck May 11, 2023
36ad4c5
initial commit of Shreyan's code for energy-based controller for cart…
tobidelbruck May 13, 2023
56 changes: 45 additions & 11 deletions CartPole/__init__.py
@@ -10,9 +10,12 @@
import csv
# Import module to interact with OS
import os
import sys
import traceback
# Import module to get a current time and date used to name the files containing the history of simulations
from datetime import datetime
from typing import Optional

# To detect the latest csv file

import numpy as np
@@ -22,9 +25,9 @@
from Control_Toolkit.others.globals_and_utils import (
get_available_controller_names, get_available_optimizer_names, get_controller_name, get_optimizer_name, import_controller_by_name)
from others.globals_and_utils import MockSpace, create_rng, load_config
from others.p_globals import (P_GLOBALS, J_fric, L, m_cart, M_fric, TrackHalfLength,
controlBias, controlDisturbance, export_globals,
g, k, m_pole, u_max, v_max)
from others.p_globals import (J_fric, L, m_cart, M_fric, TrackHalfLength,
controlBias, controlDisturbance, g, k, m_pole, u_max, v_max, cart_bounce_factor, NaturalPeriod,
export_globals)
# Interpolate function to create smooth random track
from scipy.interpolate import BPoly, interp1d
# Run range() automatically adding progress bar in terminal
@@ -40,6 +43,9 @@
from CartPole.state_utilities import (ANGLE_COS_IDX, ANGLE_IDX, ANGLE_SIN_IDX,
ANGLED_IDX, POSITION_IDX, POSITIOND_IDX)

from Control_Toolkit.others.get_logger import get_logger
log = get_logger(__name__)

# region Imported modules

try:
@@ -143,7 +149,7 @@ def __init__(self, initial_state=s0, path_to_experiment_recordings=None):
# region Variables controlling operation of the program - should not be modified directly
self.save_flag = False # Signalizes that the current time step should be saved
self.csv_filepath = None # Where to save the experiment history.
self.controller = None # Placeholder for the currently used controller function
self.controller:template_controller = Optional[None] # Placeholder for the currently used controller function
frehe (Member, Author) commented on Dec 14, 2022:

Suggested change
self.controller:template_controller = Optional[None] # Placeholder for the currently used controller function
self.controller: "Optional[template_controller]" = None # Placeholder for the currently used controller function
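
The difference between the two lines in the suggestion can be seen in a small standalone sketch. `TemplateController` below is a hypothetical stand-in for `template_controller`, and its `step` signature is an assumption made only for illustration. The point is that `Optional[None]` is a typing expression that evaluates to the `NoneType` class, so assigning it as a value leaves the attribute holding a class rather than `None`.

```python
from typing import Optional


class TemplateController:
    """Hypothetical stand-in for template_controller (signature assumed)."""

    def step(self, state, time=None, updated_attributes=None):
        return 0.0


class Before:
    def __init__(self):
        # Original line: the annotation names the controller type, but the
        # assigned value is the typing expression Optional[None] (NoneType).
        self.controller: TemplateController = Optional[None]


class After:
    def __init__(self):
        # Suggested line: the annotation carries the Optional, the value is None.
        self.controller: "Optional[TemplateController]" = None


before = Before()
after = After()

# A later "has a controller been set?" check behaves differently in each case.
before_is_none = before.controller is None
after_is_none = after.controller is None
print(before_is_none, after_is_none)
```

With the original form, a guard such as `if self.controller is None:` would not trigger, because the attribute holds `NoneType` itself; with the suggested form it does.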

self.controller_name = '' # Placeholder for the currently used controller name
self.optimizer_name = '' # Placeholder for the currently used optimizer name
self.controller_idx = None # Placeholder for the currently used controller index
@@ -199,7 +205,7 @@ def __init__(self, initial_state=s0, path_to_experiment_recordings=None):
self.slider_max = 1.0
self.slider_value = 0.0

self.show_hanging_pole = False
self.show_hanging_pole = True

self.physical_to_graphics = None
self.graphics_to_physical = None
@@ -446,6 +452,7 @@ def cartpole_integration(self):

def edge_bounce(self):
# Elastic collision at edges
# TODO should be semielastic
self.s[ANGLE_IDX], self.s[ANGLED_IDX], self.s[POSITION_IDX], self.s[POSITIOND_IDX] = edge_bounce_numba(
self.s[ANGLE_IDX],
np.cos(self.s[ANGLE_IDX]),
@@ -460,6 +467,10 @@ def edge_bounce(self):
# This function should be called for the first time to calculate 0th time step
# Otherwise it goes out of sync with saving
def Update_Q(self):
""" Determine the dimensionless [-1,1] value of the motor power Q
This function should be called for the first time to calculate 0th time step
Otherwise it goes out of sync with saving,
"""
# Calculate time steps from last update
# The counter should be initialized at max-1 to start with a control input update
self.dt_controller_steps_counter += 1
@@ -471,7 +482,7 @@ def Update_Q(self):
# in this case slider corresponds already to the power of the motor
self.Q = self.slider_value
else: # in this case slider gives a target position, lqr regulator
self.Q = self.controller.step(self.s_with_noise_and_latency, self.time, {"target_position": self.target_position, "target_equilibrium": self.target_equilibrium})
self.Q = self.controller.step(self.s_with_noise_and_latency, self.time, updated_attributes= {"target_position": self.target_position, "target_equilibrium": self.target_equilibrium})
Review comment (Member, Author):

Suggested change
self.Q = self.controller.step(self.s_with_noise_and_latency, self.time, updated_attributes= {"target_position": self.target_position, "target_equilibrium": self.target_equilibrium})
self.Q = self.controller.step(self.s_with_noise_and_latency, time=self.time, updated_attributes= {"target_position": self.target_position, "target_equilibrium": self.target_equilibrium})

Good catch. Technically, we should also pass the time as keyword-arg, because it is also optional.
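
A short sketch of why the keyword form is the safer call. The `step` signatures below are assumptions for illustration, not the actual Control_Toolkit signature, and `ControllerV1` and `ControllerV2` are hypothetical classes. If an optional parameter is later inserted or reordered, a positional call binds the argument to the wrong parameter without raising, while the keyword call keeps working.

```python
class ControllerV1:
    """Hypothetical controller whose step() takes time as the second argument."""

    def step(self, s, time=None, updated_attributes=None):
        return {"time": time, "updated_attributes": updated_attributes}


class ControllerV2:
    """Hypothetical later version in which the optional parameters swapped places."""

    def step(self, s, updated_attributes=None, time=None):
        return {"time": time, "updated_attributes": updated_attributes}


state = [0.0, 0.0, 0.0, 0.0]
targets = {"target_position": 0.1, "target_equilibrium": 1}

# Positional time: correct for V1, silently wrong for V2.
positional_v1 = ControllerV1().step(state, 0.02, updated_attributes=targets)
positional_v2 = ControllerV2().step(state, 0.02)  # 0.02 lands in updated_attributes

# Keyword arguments: the same result for both versions.
keyword_v1 = ControllerV1().step(state, time=0.02, updated_attributes=targets)
keyword_v2 = ControllerV2().step(state, time=0.02, updated_attributes=targets)
```
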


self.dt_controller_steps_counter = 0

@@ -488,6 +499,11 @@ def update_parameters(self):

# This method saves the dictionary keeping the history of simulation to a .csv file
def save_history_csv(self, csv_name=None, mode='init', length_of_experiment='unknown'):
""" Saves history of cartpole state and control
:param csv_name: the filename base, .csv is appended if it is not there. self.path_to_experiment_recordings is prepended for path
:param mode: the mode for saving, default is 'init' which makes the timestamped folder etc TODO what are these modes???????
:param length_of_experiment: the duration of this experiment in seconds, written to header of CSV
"""

if mode == 'init':

@@ -550,8 +566,9 @@ def save_history_csv(self, csv_name=None, mode='init', length_of_experiment='unk

writer.writerow(['#'])
writer.writerow(['# Parameters:'])
for k in P_GLOBALS.__dict__:
writer.writerow(['# ' + k + ': ' + str(getattr(P_GLOBALS, k))])
c = load_config("config.yml")
for k,v in c.items():
writer.writerow(['# ' + k + ': ' + str(v)])
writer.writerow(['#'])

writer.writerow(['# Data:'])
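
The new header loop writes one `# key: value` row per top-level config entry. A minimal sketch of that behavior follows, using a hypothetical in-memory dict in place of `load_config("config.yml")` and a `StringIO` buffer in place of the CSV file; the keys and values are invented for illustration.

```python
import csv
import io

# Hypothetical config contents; in the diff this dict comes from load_config("config.yml").
config = {
    "cartpole": {"seed": 1873, "mode": "stabilization"},
    "dt_simulation": 0.002,
}

buffer = io.StringIO()
writer = csv.writer(buffer)

writer.writerow(["#"])
writer.writerow(["# Parameters:"])
for key, value in config.items():
    # Nested sections are passed through str(), so they appear as a dict literal.
    writer.writerow(["# " + key + ": " + str(value)])
writer.writerow(["#"])

header_lines = buffer.getvalue().splitlines()
for line in header_lines:
    print(line)
```

One thing this sketch makes visible: a nested section stringifies to text containing commas, so `csv.writer` wraps that field in double quotes and the line then begins with `"` rather than `#`. A reader that detects comment lines by a leading `#` may not skip such a row; flattening nested keys before writing would be one way to avoid that.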
@@ -900,8 +917,9 @@ def set_controller(self, controller_name=None, controller_idx=None):
)

else:
log.debug(f'configuring controller "{self.controller}"')
self.controller.configure()


# Set the maximal allowed value of the slider - relevant only for GUI
if self.controller_name == 'manual-stabilization':
@@ -932,8 +950,10 @@ def set_cartpole_state_at_t0(self, reset_mode=1, s=None, target_position=None, r
pass

# reset global variables
global k, m_cart, m_pole, g, J_fric, M_fric, L, v_max, u_max, controlDisturbance, controlBias, TrackHalfLength
k[...], m_cart[...], m_pole[...], g[...], J_fric[...], M_fric[...], L[...], v_max[...], u_max[...], controlDisturbance[...], controlBias[...], TrackHalfLength[...] = export_globals()
global k, m_cart, m_pole, g, J_fric, M_fric, L, v_max, u_max, controlDisturbance, controlBias, TrackHalfLength, cart_bounce_factor, NaturalPeriod
# TODO why is ellipsis object used here? https://stackoverflow.com/questions/772124/what-does-the-ellipsis-object-do
# these outputs of export_globals are numpy scalar arrays, i.e. each constant is a np.array with a single element
k[...], m_cart[...], m_pole[...], g[...], J_fric[...], M_fric[...], L[...], v_max[...], u_max[...], controlDisturbance[...], controlBias[...], TrackHalfLength[...], cart_bounce_factor[...], NaturalPeriod[...] = export_globals()

self.time = 0.0
if reset_mode == 0: # Don't change it
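
On the TODO about the `k[...], m_cart[...], ... = export_globals()` assignment in this hunk: a likely reason for the ellipsis is that `name[...] = value` writes into the existing NumPy array instead of rebinding the name, so every module that imported the same array object sees the new value. The sketch below assumes each constant is a 0-d array; the names and values are hypothetical.

```python
import numpy as np

# Module-level constant stored as a 0-d NumPy array (hypothetical value).
L = np.array(0.40)

# Stand-in for another module that did `from others.p_globals import L`:
# it holds a reference to the same array object.
L_seen_elsewhere = L

# In-place write: the array object stays the same, only its content changes.
L[...] = 0.395
in_place_visible = float(L_seen_elsewhere)

# Plain rebinding creates a new object; the other reference keeps the old value.
M = np.array(1.0)
M_seen_elsewhere = M
M = np.array(2.0)
rebinding_visible = float(M_seen_elsewhere)

print(in_place_visible, rebinding_visible)
```
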
@@ -1297,3 +1317,17 @@ def animationManage(i):
return anim

# endregion


def is_physical_cartpole_running_and_control_enabled():
""" super hack to determine if we are running physical cartpole and control is turned on"""
if 'DriverFunctions' in sys.modules: # if this module exists in sys.modules, we can deduce that physical-cartpole is running
try:
physical_cartpole_instance = sys.modules[
'DriverFunctions'].PhysicalCartPoleDriver.PhysicalCartPoleDriver.PhysicalCartPoleDriverInstance
if getattr(physical_cartpole_instance, 'controlEnabled') == True:
log.debug(f'physical cartpole present and control enabled')
return True
except Exception as e:
log.warning(f'Could not determine if control is enabled: {e}')
return False
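
The detection trick used by `is_physical_cartpole_running_and_control_enabled` relies on `sys.modules`, the dict of modules already imported in the current process. A minimal sketch of the same idea follows; the module name `FakeDriver`, the `instance` attribute, and both helper functions are hypothetical, and `getattr` defaults are used here instead of a `try`/`except`.

```python
import sys
import types


def is_module_loaded(name: str) -> bool:
    """True if `name` has already been imported somewhere in this process."""
    return name in sys.modules


def control_enabled(module_name: str = "FakeDriver") -> bool:
    """Hypothetical helper: read a flag from an already-loaded module, defaulting to False."""
    module = sys.modules.get(module_name)
    instance = getattr(module, "instance", None)
    return bool(getattr(instance, "controlEnabled", False))


loaded_before = is_module_loaded("FakeDriver")
enabled_before = control_enabled()

# Simulate the driver having been imported by registering a fake module.
fake = types.ModuleType("FakeDriver")
fake.instance = types.SimpleNamespace(controlEnabled=True)
sys.modules["FakeDriver"] = fake

loaded_after = is_module_loaded("FakeDriver")
enabled_after = control_enabled()

del sys.modules["FakeDriver"]  # clean up the fake entry
```

Checking `sys.modules` avoids importing the driver package from the simulator, at the cost of coupling to its module and attribute names.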
28 changes: 20 additions & 8 deletions CartPole/cartpole_model.py
@@ -3,7 +3,7 @@
import numpy as np
from others.globals_and_utils import create_rng, load_config
from others.p_globals import (J_fric, L, m_cart, M_fric, TrackHalfLength,
controlBias, controlDisturbance, g, k, m_pole, u_max, v_max)
controlBias, controlDisturbance, g, k, m_pole, u_max, v_max,cart_bounce_factor)

from CartPole.state_utilities import (ANGLE_COS_IDX, ANGLE_IDX, ANGLE_SIN_IDX,
ANGLED_IDX, POSITION_IDX, POSITIOND_IDX,
@@ -30,8 +30,8 @@
Should be the same up to the angle-direction-convention and notation changes.

The convention:
Pole upright position defines 0 angle
Cart movement to the right is positive
Pole upright position defines 0 angle, angle units are radians
Cart movement to the right is positive, units are meters
Clockwise angle rotation is defined as negative

Required angle convention for CartPole GUI: CLOCK-NEG
@@ -146,12 +146,24 @@ def cartpole_ode(s: np.ndarray, u: float,

return angleDD, positionDD

def edge_bounce(angle, angle_cos, angleD, position, positionD, t_step, L=L):
def edge_bounce(angle, angle_cos, angleD, position, positionD, t_step, L=L, cart_bounce_factor=cart_bounce_factor):
""" Models bounce at edge of cart track. Very simple complete elastic bounce currently.

:param angle:
:param angle_cos:
:param angleD:
:param position:
:param positionD:
:param t_step: the timestep in seconds
:param L: the pole length

:returns: angle, angleD, position, positionD
"""
if position >= TrackHalfLength or -position >= TrackHalfLength: # Without abs to compile with tensorflow
angleD -= 2 * (positionD * angle_cos) / L
angle += angleD * t_step
positionD = -positionD
position += positionD * t_step
angleD -= 2 * (positionD * angle_cos) / L # TODO why this formula???
# angle += angleD * t_step # update angle according to new derivative of angle
positionD = -cart_bounce_factor*positionD # bounce with absorption; perfect only if cart_bounce_factor is 1
# position += positionD * t_step # step back the amount of bounce
return angle, angleD, position, positionD


49 changes: 40 additions & 9 deletions CartPole/cartpole_model_tf.py
@@ -3,7 +3,7 @@
import numpy as np
import tensorflow as tf
from others.globals_and_utils import create_rng, load_config
from others.p_globals import (J_fric, L, m_cart, M_fric, TrackHalfLength,
from others.p_globals import (J_fric, L, m_cart, M_fric, TrackHalfLength, cart_bounce_factor,
controlBias, controlDisturbance, g, k, m_pole, u_max,
v_max)
from SI_Toolkit.Functions.TF.Compile import CompileTF
@@ -26,6 +26,7 @@
controlDisturbance = tf.convert_to_tensor(controlDisturbance)
controlBias = tf.convert_to_tensor(controlBias)
TrackHalfLength = tf.convert_to_tensor(TrackHalfLength)
cart_bounce_factor = tf.convert_to_tensor(cart_bounce_factor)


rng = create_rng(__name__, config["cartpole"]["seed"])
@@ -71,7 +72,10 @@ def _cartpole_ode(ca, sa, angleD, positionD, u,
Calculates current values of second derivative of angle and position
from current value of angle and position, and their first derivatives

:param angle, angleD, position, positionD: Essential state information of cart
:param angle, angleD, position, positionD:
Pole angle in radians. 0 means pole is upright. Clockwise angle rotation is defined as negative.
Cart position is in meters, 0 at middle of track, positive rightwards.
Essential state information of cart
:param u: Force applied on cart in unnormalized range

:returns: angular acceleration, horizontal acceleration
@@ -135,19 +139,46 @@ def cartpole_ode(s: np.ndarray, u: float,
)
return angleDD, positionDD

def edge_bounce(angle, angle_cos, angleD, position, positionD, t_step, L=L):
if position >= TrackHalfLength or -position >= TrackHalfLength: # Without abs to compile with tensorflow
angleD -= 2 * (positionD * angle_cos) / L
angle += angleD * t_step
positionD = -positionD
position += positionD * t_step
# @tf.function
def edge_bounce(angle, angle_cos, angleD, position, positionD, t_step, L=L, cart_bounce_factor=cart_bounce_factor):
""" Models bounce at edge of cart track. Very simple elastic bounce currently.

:param angle: Pole angle in radians. 0 means pole is upright. Clockwise angle rotation is defined as negative.
:param angle_cos:
:param angleD: rad/s. Positive means CCW rotation
:param position: meters, 0 at middle of table, positive rightwards
:param positionD: m/s, positive rightwards
:param t_step: the timestep in seconds
:param L: the pole length in meters
:param cart_bounce_factor: fraction of cart speed after bounce from edge

:returns: angle, angleD, position, positionD
"""

i=tf.greater_equal(tf.abs(position),TrackHalfLength) # find those rollouts that go past edge of track

# for those that do, update the swing according to this dynamics
angleD = tf.where(i,angleD-2 * (positionD * angle_cos) / L, angleD) # TODO why this formula???
# don't update angle since the euler step will already do it
# angle = angle+angleD * t_step # update angle according to new derivative of angle
# and the cart velocity is reversed with some absorption
positionD = tf.where(i,-cart_bounce_factor*positionD,positionD) # imperfect bounce
# don't update position since Euler step will do it
# position = position+positionD * t_step # step back the amount of bounce

# following is old serial code
# if position >= TrackHalfLength or -position >= TrackHalfLength: # Without abs to compile with tensorflow
# angleD -= 2 * (positionD * angle_cos) / L # TODO why this formula???
# angle += angleD * t_step # update angle according to new derivative of angle
# positionD = -cart_bounce_factor*positionD # perfect bounce
# position += positionD * t_step # step back the amount of bounce
return angle, angleD, position, positionD
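
The change from a per-rollout `if` to `tf.where` is what lets the bounce run over a whole batch of rollouts at once. The sketch below illustrates the same masking pattern with NumPy as a stand-in for TensorFlow (`np.where` in the role of `tf.where`) and checks it against a serial loop. The constants and function names are hypothetical, and only the two velocity updates that remain active in the diff are modeled.

```python
import numpy as np

TRACK_HALF_LENGTH = 0.2   # hypothetical value, meters
POLE_LENGTH = 0.4         # hypothetical value, meters
BOUNCE_FACTOR = 0.5       # hypothetical fraction of cart speed kept after the bounce


def edge_bounce_scalar(angle_cos, angleD, position, positionD):
    """Serial version: branch on one rollout at a time."""
    if abs(position) >= TRACK_HALF_LENGTH:
        angleD = angleD - 2 * (positionD * angle_cos) / POLE_LENGTH
        positionD = -BOUNCE_FACTOR * positionD
    return angleD, positionD


def edge_bounce_batched(angle_cos, angleD, position, positionD):
    """Batched version: compute both outcomes, then select per rollout with a mask."""
    past_edge = np.abs(position) >= TRACK_HALF_LENGTH
    angleD = np.where(past_edge, angleD - 2 * (positionD * angle_cos) / POLE_LENGTH, angleD)
    positionD = np.where(past_edge, -BOUNCE_FACTOR * positionD, positionD)
    return angleD, positionD


# Three rollouts: inside the track, past the right edge, exactly at the left edge.
angle_cos = np.array([1.0, 0.5, -1.0])
angleD = np.array([0.0, 1.0, -2.0])
position = np.array([0.0, 0.25, -0.2])
positionD = np.array([0.3, 0.4, -0.1])

batched_angleD, batched_positionD = edge_bounce_batched(angle_cos, angleD, position, positionD)

serial = [edge_bounce_scalar(c, a, p, v)
          for c, a, p, v in zip(angle_cos, angleD, position, positionD)]
serial_angleD = np.array([s[0] for s in serial])
serial_positionD = np.array([s[1] for s in serial])
```

Because both branches are evaluated for every element and the mask only selects between them, there is no data-dependent control flow left, which is generally what a graph compiler needs.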


def edge_bounce_wrapper(angle, angle_cos, angleD, position, positionD, t_step, L=L):
for i in range(position.size):
angle[i], angleD[i], position[i], positionD[i] = edge_bounce(angle[i], angle_cos[i], angleD[i], position[i], positionD[i],
t_step, L)
t_step, L) # see cartpole_tf.py; note we no longer need the edge_bounce wrapper because edge bounce compiles as tf code
return angle, angleD, position, positionD


19 changes: 12 additions & 7 deletions CartPole/cartpole_tf.py
@@ -1,6 +1,6 @@
import tensorflow as tf
from others.globals_and_utils import create_rng, load_config
from others.p_globals import (J_fric, L, m_cart, M_fric, TrackHalfLength,
from others.p_globals import (J_fric, L, m_cart, M_fric, TrackHalfLength,cart_bounce_factor,
controlBias, controlDisturbance, g, k, m_pole, u_max,
v_max)
from SI_Toolkit.Functions.TF.Compile import CompileTF
@@ -25,11 +25,12 @@
controlDisturbance = tf.convert_to_tensor(controlDisturbance)
controlBias = tf.convert_to_tensor(controlBias)
TrackHalfLength = tf.convert_to_tensor(TrackHalfLength)
cart_bounce_factor = tf.convert_to_tensor(cart_bounce_factor)

rng = create_rng(__name__, config["cartpole"]["seed"])

###
# FIXME: Currently tf predictor is not modeling edge bounce!
# TODO: Currently tf predictor is not modeling edge bounce!
###


@@ -56,7 +57,7 @@ def wrap_angle_rad(sin, cos):


@CompileTF
def edge_bounce_wrapper(angle, angle_cos, angleD, position, positionD, t_step, L=L):
def edge_bounce_wrapper(angle, angle_cos, angleD, position, positionD, t_step, L=L, cart_bounce_factor=cart_bounce_factor):
angle_bounced = tf.TensorArray(tf.float32, size=tf.size(angle), dynamic_size=False)
angleD_bounced = tf.TensorArray(tf.float32, size=tf.size(angleD), dynamic_size=False)
position_bounced = tf.TensorArray(tf.float32, size=tf.size(position), dynamic_size=False)
@@ -65,7 +66,7 @@ def edge_bounce_wrapper(angle, angle_cos, angleD, position, positionD, t_step, L
for i in tf.range(tf.size(position)):
angle_i, angleD_i, position_i, positionD_i = edge_bounce_tf(angle[i], angle_cos[i], angleD[i], position[i],
positionD[i],
t_step, L)
t_step, L, cart_bounce_factor=cart_bounce_factor)
angle_bounced = angle_bounced.write(i, angle_i)
angleD_bounced = angleD_bounced.write(i, angleD_i)
position_bounced = position_bounced.write(i, position_i)
@@ -127,19 +128,23 @@ def _cartpole_fine_integration_tf(angle, angleD,
positionDD, t_step, )

# The edge bounce calculation seems to be too much for a GPU to tackle
# angle_cos = tf.cos(angle)
# angle, angleD, position, positionD = edge_bounce_wrapper(angle, angle_cos, angleD, position, positionD, t_step, L)

# TODO it is currently commented out in master branch
angle_cos = tf.cos(angle)
angle, angleD, position, positionD = edge_bounce(angle, angle_cos, angleD, position, positionD, t_step, L, cart_bounce_factor)
# # note we no longer need the edge_bounce wrapper because edge_bounce compiles as tf code

# angle_cos = tf.cos(angle)
angle_sin = tf.sin(angle)

angle = wrap_angle_rad(angle_sin, angle_cos)
#print('test 7')
return angle, angleD, position, positionD, angle_cos, angle_sin


@CompileTF
def cartpole_fine_integration_tf(s, u, t_step, intermediate_steps,
k=k, m_cart=m_cart, m_pole=m_pole, g=g, J_fric=J_fric, M_fric=M_fric, L=L):
#print('test 5')
"""
Calculates current values of second derivative of angle and position
from current value of angle and position, and their first derivatives
9 changes: 8 additions & 1 deletion CartPole/load.py
@@ -3,6 +3,8 @@

import os
import glob
from typing import Optional

import pandas as pd

def get_full_paths_to_csvs(default_locations='', csv_names=None):
@@ -87,7 +89,12 @@ def get_full_paths_to_csvs(default_locations='', csv_names=None):


# load csv file with experiment recording (e.g. for replay)
def load_csv_recording(file_path):
def load_csv_recording(file_path:str)->pd.DataFrame:
""" Loads the recording CSV file
:param file_path: path to CSV including full filename with suffix

:returns: False if file not found or pd.DataFrame if found
"""
if isinstance(file_path, list):
file_path = file_path[0]
