Skip to content

Commit

Permalink
accommodating slack variables
Browse files Browse the repository at this point in the history
  • Loading branch information
aarontrowbridge committed Nov 6, 2023
1 parent 981945b commit a543ea7
Showing 1 changed file with 21 additions and 5 deletions.
26 changes: 21 additions & 5 deletions src/methods_named_trajectory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ function add_component!(
# check data against existing data
@assert size(data, 2) == traj.T
@assert name keys(traj.components)
@assert type (:state, :control)
@assert type (:state, :control, :slack)


# update components
Expand All @@ -126,8 +126,14 @@ function add_component!(

if type == :state
comp_dict[:states] = vcat(comp_dict[:states], comp_dict[name])
else
elseif type == :control
comp_dict[:controls] = vcat(comp_dict[:controls], comp_dict[name])
else
if :slacks keys(comp_dict)
comp_dict[:slacks] = comp_dict[name]
else
comp_dict[:slacks] = vcat(comp_dict[:slacks], comp_dict[name])
end
end

traj.components = NamedTuple(comp_dict)
Expand All @@ -143,9 +149,15 @@ function add_component!(

if type == :state
dim_dict[:states] += dim
else
elseif type == :control
traj.control_names = (traj.control_names..., name)
dim_dict[:controls] += dim
else
if :slacks keys(dim_dict)
dim_dict[:slacks] = dim
else
dim_dict[:slacks] += dim
end
end

traj.dims = NamedTuple(dim_dict)
Expand All @@ -172,7 +184,9 @@ Remove a component from the trajectory.
"""
function remove_component(traj::NamedTrajectory, name::Symbol)
@assert name traj.names
comps = NamedTuple([(key => data) for (key, data) pairs(components(traj)) if key != name])
comps = NamedTuple([
(key => data) for (key, data) pairs(components(traj)) if key != name
])
return NamedTrajectory(comps, traj)
end

Expand All @@ -183,7 +197,9 @@ Remove a set of components from the trajectory.
"""
function remove_components(traj::NamedTrajectory, names::Vector{Symbol})
@assert all([name traj.names for name names])
comps = NamedTuple([(key => data) for (key, data) pairs(components(traj)) if !(key names)])
comps = NamedTuple([
(key => data) for (key, data) pairs(components(traj)) if !(key names)
])
return NamedTrajectory(comps, traj)
end

Expand Down

0 comments on commit a543ea7

Please sign in to comment.