Skip to content

Commit

Permalink
allow user-defined functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Prevter committed Oct 26, 2024
1 parent b8b31d6 commit ec777a9
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 73 deletions.
61 changes: 61 additions & 0 deletions include/rift/config.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#pragma once
#include <functional>
#include <optional>
#include <unordered_map>
#include <span>

#include "value.hpp"

/// @brief The namespace for configuring the Rift interpreter
namespace rift::config {

using RuntimeFunc = std::function<Value(std::span<Value>)>; // Function signature for runtime functions
using RuntimeFuncMap = std::unordered_map<std::string, RuntimeFunc>; // Map of runtime functions

/// @brief Get the map of runtime functions
const RuntimeFuncMap& getRuntimeFunctions();

/// @brief Add a runtime function to the map
void addRuntimeFunction(const std::string& name, const RuntimeFunc& func);

}

namespace rift {
/// @brief Get an argument from a span of values
template <typename T>
std::optional<T> getArgument(std::span<Value> args, size_t index) {
if (args.size() <= index) {
return std::nullopt;
}

if constexpr (std::is_same_v<T, std::string>) {
if (!args[index].isString()) {
return std::nullopt;
}
return args[index].toString();
} else if constexpr (std::is_same_v<T, int>) {
if (args[index].isInteger()) {
return args[index].getInteger();
}
if (args[index].isFloat()) {
return static_cast<int>(args[index].getFloat());
}
return std::nullopt;
} else if constexpr (std::is_same_v<T, float>) {
if (args[index].isFloat()) {
return args[index].getFloat();
}
if (args[index].isInteger()) {
return static_cast<float>(args[index].getInteger());
}
return std::nullopt;
} else if constexpr (std::is_same_v<T, bool>) {
if (!args[index].isBoolean()) {
return std::nullopt;
}
return args[index].getBoolean();
}

return std::nullopt;
}
}
16 changes: 16 additions & 0 deletions src/config.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#include <rift/config.hpp>

namespace rift::config {

static RuntimeFuncMap runtimeFunctions;

const RuntimeFuncMap& getRuntimeFunctions() {
return runtimeFunctions;
}

void addRuntimeFunction(const std::string& name, const RuntimeFunc& func) {
runtimeFunctions[name] = func;
}
}


114 changes: 41 additions & 73 deletions src/nodes/functioncall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,47 +9,11 @@
#include <cctype>
#include <optional>
#include <span>
#include <rift/config.hpp>

namespace rift {

namespace builtins {
template <typename T>
std::optional<T> getArgument(std::span<Value> args, size_t index) {
if (args.size() <= index) {
return std::nullopt;
}

if constexpr (std::is_same_v<T, std::string>) {
if (!args[index].isString()) {
return std::nullopt;
}
return args[index].toString();
} else if constexpr (std::is_same_v<T, int>) {
if (args[index].isInteger()) {
return args[index].getInteger();
}
if (args[index].isFloat()) {
return static_cast<int>(args[index].getFloat());
}
return std::nullopt;
} else if constexpr (std::is_same_v<T, float>) {
if (args[index].isFloat()) {
return args[index].getFloat();
}
if (args[index].isInteger()) {
return static_cast<float>(args[index].getInteger());
}
return std::nullopt;
} else if constexpr (std::is_same_v<T, bool>) {
if (!args[index].isBoolean()) {
return std::nullopt;
}
return args[index].getBoolean();
}

return std::nullopt;
}

Value len(std::span<Value> args) {
if (args.size() != 1) {
return Value::string("<error: len requires 1 argument>");
Expand Down Expand Up @@ -520,43 +484,47 @@ namespace rift {
}

std::function<Value(std::span<Value>)> findFunction(std::string_view name) {
static const std::unordered_map<std::string_view, std::function<Value(std::span<Value>)>> functions = {
{"len", builtins::len},
{"substr", builtins::substr},
{"toUpper", builtins::toUpper},
{"toLower", builtins::toLower},
{"trim", builtins::trim},
{"replace", builtins::replace},
{"random", builtins::random},
{"round", builtins::round},
{"floor", builtins::floor},
{"ceil", builtins::ceil},
{"abs", builtins::abs},
{"min", builtins::min},
{"max", builtins::max},
{"sum", builtins::sum},
{"avg", builtins::avg},
{"sqrt", builtins::sqrt},
{"pow", builtins::pow},
{"sin", builtins::sin},
{"cos", builtins::cos},
{"tan", builtins::tan},
{"precision", builtins::precision},
{"leftPad", builtins::leftPad},
{"rightPad", builtins::rightPad},
{"middlePad", builtins::middlePad},
{"ordinal", builtins::ordinal},
{"duration", builtins::duration},
static bool initialized = false;
if (!initialized) {
// Initialize the function map
initialized = true;
config::addRuntimeFunction("len", builtins::len);
config::addRuntimeFunction("substr", builtins::substr);
config::addRuntimeFunction("toUpper", builtins::toUpper);
config::addRuntimeFunction("toLower", builtins::toLower);
config::addRuntimeFunction("trim", builtins::trim);
config::addRuntimeFunction("replace", builtins::replace);
config::addRuntimeFunction("random", builtins::random);
config::addRuntimeFunction("round", builtins::round);
config::addRuntimeFunction("floor", builtins::floor);
config::addRuntimeFunction("ceil", builtins::ceil);
config::addRuntimeFunction("abs", builtins::abs);
config::addRuntimeFunction("min", builtins::min);
config::addRuntimeFunction("max", builtins::max);
config::addRuntimeFunction("sum", builtins::sum);
config::addRuntimeFunction("avg", builtins::avg);
config::addRuntimeFunction("sqrt", builtins::sqrt);
config::addRuntimeFunction("pow", builtins::pow);
config::addRuntimeFunction("sin", builtins::sin);
config::addRuntimeFunction("cos", builtins::cos);
config::addRuntimeFunction("tan", builtins::tan);
config::addRuntimeFunction("precision", builtins::precision);
config::addRuntimeFunction("leftPad", builtins::leftPad);
config::addRuntimeFunction("rightPad", builtins::rightPad);
config::addRuntimeFunction("middlePad", builtins::middlePad);
config::addRuntimeFunction("ordinal", builtins::ordinal);
config::addRuntimeFunction("duration", builtins::duration);
// Aliases
{"lpad", builtins::leftPad},
{"rpad", builtins::rightPad},
{"mpad", builtins::middlePad},
{"ord", builtins::ordinal},
{"prec", builtins::precision},
{"rand", builtins::random}
};

auto it = functions.find(name);
config::addRuntimeFunction("lpad", builtins::leftPad);
config::addRuntimeFunction("rpad", builtins::rightPad);
config::addRuntimeFunction("mpad", builtins::middlePad);
config::addRuntimeFunction("ord", builtins::ordinal);
config::addRuntimeFunction("prec", builtins::precision);
config::addRuntimeFunction("rand", builtins::random);
}

auto& functions = config::getRuntimeFunctions();
auto it = functions.find(std::string(name));
if (it == functions.end()) {
return nullptr;
}
Expand Down
13 changes: 13 additions & 0 deletions test/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,18 @@ static void RIFT_TEST_CASE(const std::string& input, const std::string& expected
}
}

#include <rift/config.hpp>
rift::Value myCustomFunc(std::span<rift::Value> args) {
auto arg = rift::getArgument<std::string>(args, 0);
if (!arg) {
return rift::Value::from("Invalid argument");
}
return rift::Value::from("Hello, " + *arg + "!");
}

int main() {
rift::config::addRuntimeFunction("myCustomFunc", myCustomFunc);

std::cout << "Running tests..." << std::endl;
{
RIFT_TEST_CASE("Hello, {name}!", "Hello, World!", { VALUE("name", "World") });
Expand Down Expand Up @@ -78,6 +89,8 @@ int main() {
RIFT_TEST_CASE("{(true || null) == true}", "true");
RIFT_TEST_CASE("{true ?? 'cool'}{false ?? 'not cool'}", "cool");
RIFT_TEST_CASE("{duration(123.456)}", "2:03.456");

RIFT_TEST_CASE("{myCustomFunc('World')}", "Hello, World!");
}

// Show the results
Expand Down

0 comments on commit ec777a9

Please sign in to comment.