Skip to content

Commit

Permalink
Improve C++ SimpleMotorFeedforward unit type support
Browse files Browse the repository at this point in the history
Allow using non-base types
Allow using angles for serde
  • Loading branch information
KangarooKoala committed Nov 28, 2024
1 parent b6de7ac commit d5ca91f
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ namespace frc {
* permanent-magnet DC motor.
*/
template <class Distance>
requires std::same_as<units::meter, Distance> ||
std::same_as<units::radian, Distance>
requires units::length_unit<Distance> || units::angle_unit<Distance>
class SimpleMotorFeedforward {
public:
using Velocity =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
#include "wpimath/protobuf/controller.npb.h"

// Everything is converted into units for
// frc::SimpleMotorFeedforward<units::meters>
// frc::SimpleMotorFeedforward<units::meters> or
// frc::SimpleMotorFeedforward<units::radians>

template <class Distance>
struct wpi::Protobuf<frc::SimpleMotorFeedforward<Distance>> { // NOLINT
requires units::length_unit<Distance> || units::angle_unit<Distance>
struct wpi::Protobuf<frc::SimpleMotorFeedforward<Distance>> {
using MessageStruct = wpi_proto_ProtobufSimpleMotorFeedforward;
using InputStream =
wpi::ProtoInputStream<frc::SimpleMotorFeedforward<Distance>>;
Expand All @@ -24,33 +26,33 @@ struct wpi::Protobuf<frc::SimpleMotorFeedforward<Distance>> { // NOLINT

static std::optional<frc::SimpleMotorFeedforward<Distance>> Unpack(
InputStream& stream) {
using BaseUnit =
units::unit<std::ratio<1>, units::traits::base_unit_of<Distance>>;
using BaseFeedforward = frc::SimpleMotorFeedforward<BaseUnit>;
wpi_proto_ProtobufSimpleMotorFeedforward msg;
if (!stream.Decode(msg)) {
return {};
}

return frc::SimpleMotorFeedforward<Distance>{
units::volt_t{msg.ks},
units::unit_t<frc::SimpleMotorFeedforward<units::meters>::kv_unit>{
msg.kv},
units::unit_t<frc::SimpleMotorFeedforward<units::meters>::ka_unit>{
msg.ka},
units::unit_t<typename BaseFeedforward::kv_unit>{msg.kv},
units::unit_t<typename BaseFeedforward::ka_unit>{msg.ka},
units::second_t{msg.dt},
};
}

static bool Pack(OutputStream& stream,
const frc::SimpleMotorFeedforward<Distance>& value) {
using BaseUnit =
units::unit<std::ratio<1>, units::traits::base_unit_of<Distance>>;
using BaseFeedforward = frc::SimpleMotorFeedforward<BaseUnit>;
wpi_proto_ProtobufSimpleMotorFeedforward msg{
.ks = value.GetKs().value(),
.kv =
units::unit_t<frc::SimpleMotorFeedforward<units::meters>::kv_unit>{
value.GetKv()}
.value(),
.ka =
units::unit_t<frc::SimpleMotorFeedforward<units::meters>::ka_unit>{
value.GetKa()}
.value(),
.kv = units::unit_t<typename BaseFeedforward::kv_unit>{value.GetKv()}
.value(),
.ka = units::unit_t<typename BaseFeedforward::ka_unit>{value.GetKa()}
.value(),
.dt = units::second_t{value.GetDt()}.value(),
};
return stream.Encode(msg);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
// frc::SimpleMotorFeedforward<units::radians>

template <class Distance>
requires units::length_unit<Distance> || units::angle_unit<Distance>
struct wpi::Struct<frc::SimpleMotorFeedforward<Distance>> {
static constexpr std::string_view GetTypeName() {
return "SimpleMotorFeedforward";
Expand All @@ -25,40 +26,44 @@ struct wpi::Struct<frc::SimpleMotorFeedforward<Distance>> {

static frc::SimpleMotorFeedforward<Distance> Unpack(
std::span<const uint8_t> data) {
using BaseUnit =
units::unit<std::ratio<1>, units::traits::base_unit_of<Distance>>;
using BaseFeedforward = frc::SimpleMotorFeedforward<BaseUnit>;
constexpr size_t kKsOff = 0;
constexpr size_t kKvOff = kKsOff + 8;
constexpr size_t kKaOff = kKvOff + 8;
constexpr size_t kDtOff = kKaOff + 8;
return {units::volt_t{wpi::UnpackStruct<double, kKsOff>(data)},
units::unit_t<frc::SimpleMotorFeedforward<units::meters>::kv_unit>{
units::unit_t<typename BaseFeedforward::kv_unit>{
wpi::UnpackStruct<double, kKvOff>(data)},
units::unit_t<frc::SimpleMotorFeedforward<units::meters>::ka_unit>{
units::unit_t<typename BaseFeedforward::ka_unit>{
wpi::UnpackStruct<double, kKaOff>(data)},
units::second_t{wpi::UnpackStruct<double, kDtOff>(data)}};
}

static void Pack(std::span<uint8_t> data,
const frc::SimpleMotorFeedforward<Distance>& value) {
using BaseUnit =
units::unit<std::ratio<1>, units::traits::base_unit_of<Distance>>;
using BaseFeedforward = frc::SimpleMotorFeedforward<BaseUnit>;
constexpr size_t kKsOff = 0;
constexpr size_t kKvOff = kKsOff + 8;
constexpr size_t kKaOff = kKvOff + 8;
constexpr size_t kDtOff = kKaOff + 8;
wpi::PackStruct<kKsOff>(data, value.GetKs().value());
wpi::PackStruct<kKvOff>(
data,
units::unit_t<frc::SimpleMotorFeedforward<units::meters>::kv_unit>{
value.GetKv()}
.value());
data, units::unit_t<typename BaseFeedforward::kv_unit>{value.GetKv()}
.value());
wpi::PackStruct<kKaOff>(
data,
units::unit_t<frc::SimpleMotorFeedforward<units::meters>::ka_unit>{
value.GetKa()}
.value());
data, units::unit_t<typename BaseFeedforward::ka_unit>{value.GetKa()}
.value());
wpi::PackStruct<kDtOff>(data, units::second_t{value.GetDt()}.value());
}
};

static_assert(
wpi::StructSerializable<frc::SimpleMotorFeedforward<units::meters>>);
static_assert(
wpi::StructSerializable<frc::SimpleMotorFeedforward<units::feet>>);
static_assert(
wpi::StructSerializable<frc::SimpleMotorFeedforward<units::radians>>);
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@

using namespace frc;

template <typename T>
struct SimpleMotorFeedforwardProtoTestData {
using Type = SimpleMotorFeedforward<units::meters>;
using Type = SimpleMotorFeedforward<T>;

inline static const Type kTestData = {units::volt_t{0.4},
units::volt_t{4.0} / 1_mps,
units::volt_t{0.7} / 1_mps_sq, 25_ms};
inline static const Type kTestData = {
units::volt_t{0.4}, units::volt_t{4.0} / (units::unit_t<T>{1} / 1_s),
units::volt_t{0.7} / (units::unit_t<T>{1} / 1_s / 1_s), 25_ms};

static void CheckEq(const Type& testData, const Type& data) {
EXPECT_EQ(testData.GetKs().value(), data.GetKs().value());
Expand All @@ -27,5 +28,12 @@ struct SimpleMotorFeedforwardProtoTestData {
}
};

INSTANTIATE_TYPED_TEST_SUITE_P(SimpleMotorFeedforwardMeters, ProtoTest,
SimpleMotorFeedforwardProtoTestData);
INSTANTIATE_TYPED_TEST_SUITE_P(
SimpleMotorFeedforwardMeters, ProtoTest,
SimpleMotorFeedforwardProtoTestData<units::meters>);
INSTANTIATE_TYPED_TEST_SUITE_P(
SimpleMotorFeedforwardFeet, ProtoTest,
SimpleMotorFeedforwardProtoTestData<units::feet>);
INSTANTIATE_TYPED_TEST_SUITE_P(
SimpleMotorFeedforwardRadians, ProtoTest,
SimpleMotorFeedforwardProtoTestData<units::radians>);
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@

using namespace frc;

template <typename T>
struct SimpleMotorFeedforwardStructTestData {
using Type = SimpleMotorFeedforward<units::meters>;
using Type = SimpleMotorFeedforward<T>;

inline static const Type kTestData = {units::volt_t{0.4},
units::volt_t{4.0} / 1_mps,
units::volt_t{0.7} / 1_mps_sq, 25_ms};
inline static const Type kTestData = {
units::volt_t{0.4}, units::volt_t{4.0} / (units::unit_t<T>{1} / 1_s),
units::volt_t{0.7} / (units::unit_t<T>{1} / 1_s / 1_s), 25_ms};

static void CheckEq(const Type& testData, const Type& data) {
EXPECT_EQ(testData.GetKs().value(), data.GetKs().value());
Expand All @@ -27,5 +28,12 @@ struct SimpleMotorFeedforwardStructTestData {
}
};

INSTANTIATE_TYPED_TEST_SUITE_P(SimpleMotorFeedforwardMeters, StructTest,
SimpleMotorFeedforwardStructTestData);
INSTANTIATE_TYPED_TEST_SUITE_P(
SimpleMotorFeedforwardMeters, StructTest,
SimpleMotorFeedforwardStructTestData<units::meters>);
INSTANTIATE_TYPED_TEST_SUITE_P(
SimpleMotorFeedforwardFeet, StructTest,
SimpleMotorFeedforwardStructTestData<units::feet>);
INSTANTIATE_TYPED_TEST_SUITE_P(
SimpleMotorFeedforwardRadians, StructTest,
SimpleMotorFeedforwardStructTestData<units::radians>);

0 comments on commit d5ca91f

Please sign in to comment.