Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the automatic differentiation multibody solver based on JAX #305

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from

Conversation

ben-l-p
Copy link
Collaborator

@ben-l-p ben-l-p commented Nov 6, 2024

I have added a new solver NonlinearDynamicMultibodyJAX with the similar functionality as NonlinearDynamicMultibody, except it does not include any DynamicTrim routine (this can be added in due course when required). This new solver makes use of AD to calculate the Jacobians which leads to a much nicer and more compact definition of constraints, particularly for use of the penalty method. A lot of the key constraint types from the original solver are included here, such as:

  • Free hinge between beams with arbitrary axis line
  • Free hinge between beam and a fixed point in the inertial FoR
  • Spherical joint between beams
  • Spherical joints between beam and a fixed point in the inertial FoR
  • Fully constrained joint between beams
  • Fully constrained beam to a fixed point in the inertial FoR (this is effectively a less efficient way of created a “prescribed” beam, but was useful for testing)

The constraints numerics and the derivatives are defined in the sharpy/structure/utils/lagrangeconstraintsjax file. I have included new versions of all the other parts of the framework to prevent changes to one breaking the other solver etc., hence why sharpy/utils/multibodyjax and sharpy/solvers/timeintegratorsjax exist.

Also included is the ability for controlled actuation between beams, which I primarily implemented for testing variable sweep wings, however the formulation/implementation is general for any 3D rotation. These can be controlled with the new sharpy/controllers/multibodycontroller controller type, which takes a Cartesian rotation vector time series as input. To allow this to sweep the wing correctly, the ability to warp the aero grid has been added with a new aerogrid_warp_factor parameter in the multibody file. This allows for a gradual sweep around a kink, and should not effect for existing cases, as it has no effect is the parameter is not included in the H5 file.

A test case for this solver is included in the form of a flexible double pendulum comparison. A free double pendulum case is run, the angles from the two hinges extracted and applied onto a prescribed model, where both should yield the “same” result for structural deflections.

Also included is some documentation on the multibody case files (should be general for both multibody solvers). In recent testing I have found a bug in the StaticTrim routine which was not present in v2.0 and I am currently investigating, which will be another PR in due course, but it is not connected to the multibody implementation.

I have also done some code cleanups I found along the way, due to PyCharm largely doing it for me, but for files not related to the new solver the functionality should not have been changed. Lastly, I found one of the unit test cases was creating files which didn't get deleted at the end, which has also been fixed.

to the angle of attack (in radians) and then the ``C_L``, ``C_D`` and ``C_M``.
to the angle of attack (in radians) and then the ``C_L``, ``C_D`` and ``C_M``.

Multibody file
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New documentation - should be general for both multibody implementations.

@@ -145,7 +145,8 @@ def run(self):
"openpyxl>=3.0.10",
"lxml>=4.4.1",
"PySocks",
"PyYAML"
"PyYAML",
"jax",
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added JAX as a new dependency during install

@@ -270,15 +265,28 @@ def generate_zeta_timestep_info(self, structure_tstep, aero_tstep, beam, setting
raise NotImplementedError(str(self.data_dict['control_surface_type'][i_control_surface]) +
' control surfaces are not yet implemented')


# add sweep for aerogrid warping in constraint defintition
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New code for dynamically sweeping the aero grid. This shouldn't impact any existing cases as it will ignore it if the warp factor parameter does not exist.

@@ -62,7 +61,7 @@ def get_coefs(self, aoa_deg):
cd = self.cd_interp(aoa_deg)
cm = self.cm_interp(aoa_deg)

return cl, cd, cm
return cl[0], cd[0], cm[0]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed one of the pesky errors converting an array to a scalar during the unit test - this seems to work fine.



@controller_interface.controller
class MultibodyController(controller_interface.BaseController):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New multibody controller for setting angle between beams.

@@ -0,0 +1,368 @@
import numpy as np
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New solver, very similar to the existing multibody solver in its inputs.

@@ -323,11 +312,10 @@ def run(self, **kwargs):
"""

aero_tstep = settings_utils.set_value_or_default(kwargs, 'aero_step', self.data.aero.timestep_info[-1])
structure_tstep = settings_utils.set_value_or_default(kwargs, 'structural_step', self.data.structure.timestep_info[-1])
convect_wake = settings_utils.set_value_or_default(kwargs, 'convect_wake', False)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I cut some parameters here as they aren't actually referenced again

@@ -71,6 +71,8 @@ class NewmarkBeta(_BaseTimeIntegrator):

def __init__(self):

self.sys_size = None
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some variables used later were not declared in init

@@ -0,0 +1,209 @@
import numpy as np
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New time integrator for the JAX solver, should have the same numerics as the existing implementation.

@@ -544,7 +476,8 @@ def get_body(self, ibody):
int_list_nodes = np.arange(0, ibody_beam.num_node, 1)
for ielem in range(ibody_beam.num_elem):
for inode_in_elem in range(ibody_beam.num_node_elem):
ibody_beam.connectivities[ielem, inode_in_elem] = int_list_nodes[ibody_nodes == ibody_beam.connectivities[ielem, inode_in_elem]]
ibody_beam.connectivities[ielem, inode_in_elem] = int_list_nodes[
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another fix for converting arrays to scalars which creates a warning during unit test.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good spot and fix!

@@ -0,0 +1,498 @@
from typing import Callable, Any, Optional, Type, cast
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is where the multibody constraint numerics take place

@@ -1989,13 +1991,15 @@ def check(self):
raise RuntimeError(("'behaviour' parameter is required in '%s' lagrange constraint" % self.behaviour))


def generate_multibody_file(list_LagrangeConstraints, list_Bodies, route, case_name):
def generate_multibody_file(list_LagrangeConstraints, list_Bodies, route, case_name, use_jax=False):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to add this additional setting as the generate case class checks for constraints in the existing solver; as I have new constraints which aren't implemented there, it fails when this is False.

@@ -0,0 +1,325 @@
"""
Multibody library for the NonlinearDynamicMultibodyJAX solver
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy of the existing multibody utils library, with some things cut.

@@ -322,7 +301,7 @@ def test_doublependulum_hinge_slanted_lateralrot(self):
def tearDown(self):
solver_path = os.path.abspath(os.path.dirname(os.path.realpath(__file__)))
solver_path += '/'
for name in [name_hinge_slanted, name_hinge_slanted_pen, name_hinge_slanted_lateralrot]:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed an issue with the double angled pendulum test case where the names weren't correct

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ahh sorry! my bad😂😂

@@ -0,0 +1,368 @@
import numpy as np
import typing
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to add typing as a dependency?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typing is the built-in Python type hinting, so this doesn't require any new packages

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ahhhhhh ok! thanks

@@ -317,17 +317,12 @@ def trim_algorithm(self):

def evaluate(self, alpha, deflection_gamma, thrust):
if not np.isfinite(alpha):
import pdb; pdb.set_trace()
raise ValueError
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here I think the reason behind breaking and not terminating the run is because staticcoupled has no post-processing - so rather than killing it (and leaving no trace to what happened) they preferred to leave a break point here. If we haven't got plans to introduce post-processing for staticcoupled iterations, is there any neater alternative to just throwing an error here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was having the trace cause me issues when running on the HPC a while ago, I'm not sure if it's good practice to have traces like this in “production” code (although I may be wrong). This error is thrown when one of the trim gradients is zero, which is currently occurring for some of my cases and I believe is a bug, that I'm currently looking into. If static coupled fails, the code won't get this far, as the FORTRAN will instead throw a singular matrix error.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that makes sense - I meant more in the off chance that staticcoupled did return eg. same total forces for a parametrised run with different geometry (perhaps the staticcoupled get total forces function working on the timestep[-1] time information, and somehow got contaminated? just a wild guess) - that would kill the trim routine like what you've seen. but yeah I agree it is probably best to clean up production code for an HPC environment assuming no user input is possible further

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a message to the ValueError so that it is possible to trace where it is coming from?

@@ -208,7 +196,7 @@ def generate_aero_file():

working_elem = 0
working_node = 0
# right wing (surface 0, beam 0)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my take is to leave these in - it helps with understanding the model generation procedure which if now there's no simplification coming up the pipeline seems to have a big learning curve.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good spot, must have gotten carried away with deleting commented code at some point! Will add these comments back in

@@ -207,7 +193,7 @@ def generate_aero_file():

working_elem = 0
working_node = 0
# right wing (surface 0, beam 0)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

likewise, I suggest leaving them in, happy to have a discussion

Copy link
Collaborator

@kccwing kccwing left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice work, @ben-l-p! Thanks for building a parallel for most files used by the new jax routines - that should hopefully make the work of migrating over in the future much easier. Happy to merge given this passes all tests and that you've been using it day-to-day already.

@kccwing kccwing requested a review from wong-hl November 11, 2024 21:37
Comment on lines 274 to 281
try:
cst_name = f"constraint_{i_constraint:02d}"
ctrl_id = structure_tstep.mb_dict[cst_name]['controller_id'].decode('UTF-8')
f_warp = structure_tstep.mb_dict[cst_name]['aerogrid_warp_factor'][i_elem, i_local_node]
ang_z = structure_tstep.mb_prescribed_dict[ctrl_id]['delta_psi'][2]
ang_warp += f_warp * ang_z
except KeyError:
continue
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of curiosity, why is a KeyError ok?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The constraints here are input as dictionaries and have different key-value pairs depending on their functionality (hinge axis, node number etc.). Warping the aerodynamic grid will occur here if a constraint has both controller_id and aerogrid_warp_factor entries, with the intended behavior to skip this code if a constraint is missing one. A key error will occur if it's a constraint that is missing either of these entries (and therefore is not a constraint which requires the warping). Of course, it could technically fail if I had coded it wrong and delta_psi is not defined, and this would incorrectly ignore this code, but I'm pretty sure that can't happen.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Is the logic of it along this line?

if 'controller_id' in structure_tstep.mb_dict[cst_name] and 'aerogrid_warp_factor' in structure_tstep.mb_dict[cst_name]:
    f_warp = structure_tstep.mb_dict[cst_name]['aerogrid_warp_factor'][i_elem, i_local_node]
    ctrl_id = structure_tstep.mb_dict[cst_name]['controller_id'].decode('UTF-8')
    ang_z = structure_tstep.mb_prescribed_dict[ctrl_id]['delta_psi'][2]
    ang_warp += f_warp * ang_z

)

def __init__(self):
self.in_dict = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's in_dict? Is it an input dictionary or a boolean about whether something is in a dictionary

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have included this to be consistent with the other controllers available, it's the settings dictionary, and I can add a comment to state this. I do not like how the controller code (for all of them) is structured, but I would rather they're all the same at least.

Comment on lines 158 to 167
if controlled_state["structural"].mb_prescribed_dict is None:
controlled_state["structural"].mb_prescribed_dict = dict()
controlled_state["structural"].mb_prescribed_dict[self.controller_id] = {
"psi": control_command,
"psi_dot": psi_dot,
}
controlled_state["structural"].mb_prescribed_dict[self.controller_id].update(
{"delta_psi": control_command - self.prescribed_ang_time_history[0, :]}
)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there isn't a need to declare the dict() and use update() by simplifying it to,

        controlled_state["structural"].mb_prescribed_dict[self.controller_id] = {
            "psi": control_command,
            "psi_dot": psi_dot,
            "delta_psi": control_command - self.prescribed_ang_time_history[0, :]
        }

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good spot!

@@ -317,17 +317,12 @@ def trim_algorithm(self):

def evaluate(self, alpha, deflection_gamma, thrust):
if not np.isfinite(alpha):
import pdb; pdb.set_trace()
raise ValueError
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a message to the ValueError so that it is possible to trace where it is coming from?

Comment on lines 2049 to 2053
try:
constraint_id.create_dataset("rot_axisA2",
data=getattr(constraint, "rot_axisA2"))
except:
pass
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this is existing code. But, why are there try, except blocks that don't catch any error? Should they be like the newly introduced code where it catches an AttributeError?

lc_settings: list[dict] = []
self.num_lm_tot = 0
for i in range(self.data.structure.ini_mb_dict['num_constraints']):
lc_settings.append(self.data.structure.ini_mb_dict[f'constraint_{i:02d}'])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small style comment - If i is always going to be an int as it comes from range(), then there shouldn't be a need to specify the format as f"{i:0d}", f"{i}" should be enough

(there are a couple more of this below)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The number needs to be formatted to have two digits, i.e. 00, 01, 02. Not a fan of this method as it means this has to happen in a few places in the code, but unless if I overhaul the multibody then it's a necessity for backwards compatibility.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see! That makes sense. tbf I didn't realise that it had to have two digits so that's my bad

@wong-hl
Copy link
Contributor

wong-hl commented Nov 13, 2024

Btw I really like your PR message, it's super clear about what's in the PR etc

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants