From 8e8ae6977619dc4e017ba82b4ea25f9a33eefc7d Mon Sep 17 00:00:00 2001 From: coavins Date: Sun, 16 Jul 2023 12:42:08 -0400 Subject: [PATCH] Add LuaConnector for autotracking N64 games This commit adds a LuaConnector class which is capable of connecting to emulator Lua scripts for autotracking. Proof of concept was done using BizHawk 2.8 and the WarpWorld ConnectorLib lua script to power the autotracking in my fork of Hamsda's ZOOTR pack. This commit also adds an IAutotrackProvider interface which the autotracker uses to abstract away the implementation of the LuaConnector. Maybe the other trackers (snes, uat, ap) can be migrated over to this interface to help clean up the autotracker code. --- Makefile | 2 + schema/packs/manifest.json | 4 + src/core/autotracker.h | 151 ++++++++++++++++- src/core/autotrackprovider.h | 44 +++++ src/core/pack.cpp | 15 +- src/core/scripthost.cpp | 2 +- src/core/scripthost.h | 4 +- src/core/tsbuffer.h | 73 ++++++++ src/luaconnector/common.h | 29 ++++ src/luaconnector/connection.cpp | 251 ++++++++++++++++++++++++++++ src/luaconnector/connection.h | 77 +++++++++ src/luaconnector/luaconnector.cpp | 262 +++++++++++++++++++++++++++++ src/luaconnector/luaconnector.h | 70 ++++++++ src/luaconnector/message.h | 79 +++++++++ src/luaconnector/server.cpp | 268 ++++++++++++++++++++++++++++++ src/luaconnector/server.h | 133 +++++++++++++++ src/luaconnector/tsqueue.h | 96 +++++++++++ 17 files changed, 1553 insertions(+), 7 deletions(-) create mode 100644 src/core/autotrackprovider.h create mode 100644 src/core/tsbuffer.h create mode 100644 src/luaconnector/common.h create mode 100644 src/luaconnector/connection.cpp create mode 100644 src/luaconnector/connection.h create mode 100644 src/luaconnector/luaconnector.cpp create mode 100644 src/luaconnector/luaconnector.h create mode 100644 src/luaconnector/message.h create mode 100644 src/luaconnector/server.cpp create mode 100644 src/luaconnector/server.h create mode 100644 src/luaconnector/tsqueue.h diff --git a/Makefile b/Makefile index 58fca022..c35dfdb3 100644 --- a/Makefile +++ b/Makefile @@ -12,6 +12,7 @@ SRC = $(wildcard $(SRC_DIR)/*.cpp) \ $(wildcard $(SRC_DIR)/usb2snes/*.cpp) \ $(wildcard $(SRC_DIR)/uat/*.cpp) \ $(wildcard $(SRC_DIR)/ap/*.cpp) \ + $(wildcard $(SRC_DIR)/luaconnector/*.cpp) \ $(wildcard $(SRC_DIR)/http/*.cpp) \ $(wildcard $(SRC_DIR)/packmanager/*.cpp) #lib/gifdec/gifdec.c @@ -23,6 +24,7 @@ HDR = $(wildcard $(SRC_DIR)/*.h) \ $(wildcard $(SRC_DIR)/usb2snes/*.h) \ $(wildcard $(SRC_DIR)/uat/*.h) \ $(wildcard $(SRC_DIR)/ap/*.h) \ + $(wildcard $(SRC_DIR)/luaconnector/*.h) \ $(wildcard $(SRC_DIR)/http/*.h) \ $(wildcard $(SRC_DIR)/packmanager/*.h) INCLUDE_DIRS = -Ilib -Ilib/lua -Ilib/asio/include -DASIO_STANDALONE -Ilib/miniz -Ilib/json/include -Ilib/valijson/include -Ilib/tinyfiledialogs -Ilib/wswrap/include #-Ilib/gifdec diff --git a/schema/packs/manifest.json b/schema/packs/manifest.json index bb743334..23929a53 100644 --- a/schema/packs/manifest.json +++ b/schema/packs/manifest.json @@ -20,6 +20,10 @@ "description": "Platform the pack is used for. If platform is \"snes\", the snes autotracker is to be enabled.", "type": "string" }, + "platform_poptracker": { + "description": "Platform the pack is used for. If specified, PopTracker will use this instead of the value indicated in the 'platform' field.", + "type": "string" + }, "versions_url": { "description": "URL to versions.json of the pack. Can be used for automatic updates. Information from global packs.json takes precedence. See https://github.com/black-sliver/PopTracker/tree/packlist for more information.", "type": "string", diff --git a/src/core/autotracker.h b/src/core/autotracker.h index edee480b..0964cae6 100644 --- a/src/core/autotracker.h +++ b/src/core/autotracker.h @@ -4,6 +4,8 @@ #include "../usb2snes/usb2snes.h" #include "../uat/uatclient.h" #include "../ap/aptracker.h" +#include "../luaconnector/luaconnector.h" +#include "autotrackprovider.h" #include "signal.h" #include #include @@ -68,6 +70,13 @@ class AutoTracker final : public LuaInterface{ _snes->setMapping(USB2SNES::Mapping::EXHIROM); } } + if( strcasecmp(platform.c_str(), "n64") == 0 ) { + _provider = new LuaConnector::LuaConnector(_name); + _lastBackendIndex++; + _backendIndex[_provider] = _lastBackendIndex; + _state.push_back(State::Disabled); + _provider->setMapping(flags); + } if (flags.find("uat") != flags.end()) { _uat = new UATClient(); _lastBackendIndex++; @@ -123,6 +132,12 @@ class AutoTracker final : public LuaInterface{ if (_ap) delete _ap; _ap = nullptr; + if( _provider ) + { + delete _provider; + _provider = nullptr; + } + if (spawnedWorkers) { // wait a bit if we started a thread to increase readability of logs std::this_thread::sleep_for(std::chrono::milliseconds(21)); @@ -146,6 +161,7 @@ class AutoTracker final : public LuaInterface{ if (name == BACKEND_AP_NAME) return _ap ? getState(_backendIndex[_ap]) : State::Unavailable; if (name == BACKEND_UAT_NAME) return _uat ? getState(_backendIndex[_uat]) : State::Unavailable; if (name == BACKEND_SNES_NAME) return _snes ? getState(_backendIndex[_snes]) : State::Unavailable; + if (_provider && name == _provider->getName()) return _provider ? getState(_backendIndex[_provider]) : State::Unavailable; return State::Unavailable; } @@ -154,6 +170,7 @@ class AutoTracker final : public LuaInterface{ if (_ap && _backendIndex[_ap] == index) return BACKEND_AP_NAME; if (_uat && _backendIndex[_uat] == index) return BACKEND_UAT_NAME; if (_snes && _backendIndex[_snes] == index) return BACKEND_SNES_NAME; + if( _provider && _backendIndex[_provider] == index ) return _provider->getName(); return BACKEND_NONE_NAME; } @@ -243,6 +260,36 @@ class AutoTracker final : public LuaInterface{ res = true; } + if( _provider && backendEnabled(_provider) ) + { + int index = _backendIndex[_provider]; + State oldState = _state[index]; + bool isReady = _provider->isReady(); + bool gameConnected = isReady ? _provider->isConnected() : false; + + if( gameConnected ) { + _state[index] = State::ConsoleConnected; + } + else if( isReady ) { + _state[index] = State::BridgeConnected; + } + else { + _state[index] = State::Disconnected; + } + + if( _state[index] != oldState ) { + onStateChange.emit(this, index, _state[index]); + } + + if( _provider->update() ) { + if( _state[index] == State::ConsoleConnected ) { + onDataChange.emit(this); + } + + res = true; + } + } + return res; } @@ -253,6 +300,11 @@ class AutoTracker final : public LuaInterface{ _snes->addWatch((uint32_t)addr, len); return true; } + else if( _provider ) + { + _provider->addWatch((uint32_t)addr, len); + return true; + } return false; } @@ -263,12 +315,18 @@ class AutoTracker final : public LuaInterface{ _snes->removeWatch((uint32_t)addr, len); return true; } + else if( _provider ) { + _provider->removeWatch((uint32_t)addr, len); + return true; + } return false; } void setInterval(unsigned ms) { if (_snes) _snes->setUpdateInterval(ms); + if( _provider ) + _provider->setWatchUpdateInterval(ms); } void clearCache() { @@ -276,6 +334,8 @@ class AutoTracker final : public LuaInterface{ _snes->clearCache(); if (_uat) _uat->sync(_slot); + if( _provider ) + _provider->clearCache(); } // TODO: canRead(addr,len) to detect incomplete segment @@ -288,18 +348,65 @@ class AutoTracker final : public LuaInterface{ for (size_t i=0; ireadFromCache((uint32_t)addr, len, buf); + for( size_t i = 0; i < len; i++ ) + res.push_back(buf[i]); + } return res; } int ReadU8(int segment, int offset=0) { + if( _provider ) + { + // this is a live blocking call to read memory from the game + uint32_t address = segment; + uint32_t o = offset; + return _provider->readU8Live(address, o); + } + else // NOTE: this is AutoTracker:Read8. we only have 1 segment, that is AutoTracker return ReadUInt8(segment+offset); } - int ReadU16(int segment, int offset=0) { return ReadUInt16(segment+offset); } - int ReadU24(int segment, int offset=0) { return ReadUInt24(segment+offset); } - int ReadU32(int segment, int offset=0) { return ReadUInt32(segment+offset); } + int ReadU16(int segment, int offset=0) + { + if( _provider ) + { + // this is a live blocking call to read memory from the game + uint32_t address = segment; + uint32_t o = offset; + return _provider->readU16Live(address, o); + } + else + return ReadUInt16(segment+offset); + } + int ReadU24(int segment, int offset=0) + { + if( _provider ) + { + // this is a live blocking call to read memory from the game + uint32_t address = segment; + uint32_t o = offset; + return _provider->readU32Live(address, o) & 0xffffff; + } + else + return ReadUInt24(segment+offset); + } + int ReadU32(int segment, int offset=0) + { + if( _provider ) + { + // this is a live blocking call to read memory from the game + uint32_t address = segment; + uint32_t o = offset; + return _provider->readU32Live(address, o); + } + else + return ReadUInt32(segment+offset); + } int ReadUInt8(int addr) { @@ -311,6 +418,11 @@ class AutoTracker final : public LuaInterface{ //printf("$%06x = %02x\n", a, res); return res; } + else if( _provider ) { + auto res = _provider->readUInt8FromCache(addr); + //if( res == 0 ) _provider->addWatch(addr, 1); + return res; + } return 0; } @@ -321,6 +433,11 @@ class AutoTracker final : public LuaInterface{ if (res == 0) _snes->addWatch(addr,2); return res; } + else if( _provider ) { + auto res = _provider->readUInt16FromCache(addr); + //if( res == 0 ) _provider->addWatch(addr, 2); + return res; + } return 0; } @@ -331,6 +448,11 @@ class AutoTracker final : public LuaInterface{ if (res == 0) _snes->addWatch(addr,3); return res; } + else if( _provider ) { + auto res = _provider->readUInt32FromCache(addr) & 0xffffff; + //if( res == 0 ) _provider->addWatch(addr, 3); + return res; + } return 0; } @@ -341,6 +463,11 @@ class AutoTracker final : public LuaInterface{ if (res == 0) _snes->addWatch(addr,4); return res; } + else if( _provider ) { + auto res = _provider->readUInt32FromCache(addr); + //if( res == 0 ) _provider->addWatch(addr, 4); + return res; + } return 0; } @@ -377,6 +504,12 @@ class AutoTracker final : public LuaInterface{ return true; } } + else if ( _provider && _backendIndex[_provider] == index) { + _state[index] = State::Disconnected; + onStateChange.emit(this, index, _state[index]); + _provider->start(); + return true; + } return false; } @@ -411,6 +544,12 @@ class AutoTracker final : public LuaInterface{ } if (_ap && _backendIndex[_ap] == index) _ap->disconnect(); + if( _provider && _backendIndex[_provider] == index ) + { + _provider->stop(); + _provider->clearCache(); + } + onStateChange.emit(this, index, _state[index]); } } @@ -420,6 +559,11 @@ class AutoTracker final : public LuaInterface{ return _ap; } + IAutotrackProvider* getAutotrackProvider() const + { + return _provider; + } + void setSnesAddresses(const std::vector& addresses) { _snesAddresses = addresses; @@ -457,6 +601,7 @@ class AutoTracker final : public LuaInterface{ USB2SNES *_snes = nullptr; UATClient *_uat = nullptr; APTracker *_ap = nullptr; + IAutotrackProvider* _provider = nullptr; std::string _slot; // selected slot for UAT std::map _vars; // variable store for UAT std::string _name; diff --git a/src/core/autotrackprovider.h b/src/core/autotrackprovider.h new file mode 100644 index 00000000..fa6b7429 --- /dev/null +++ b/src/core/autotrackprovider.h @@ -0,0 +1,44 @@ +#ifndef _CORE_AUTOTRACK_PROVIDER_H +#define _CORE_AUTOTRACK_PROVIDER_H + +#include +#include +#include +#include + +class IAutotrackProvider { +public: + virtual ~IAutotrackProvider() = default; + + virtual const std::string& getName() = 0; + + virtual bool start() = 0; + virtual bool stop() = 0; + + // Returns true if cache was changed + virtual bool update() = 0; + + virtual bool isReady() = 0; + virtual bool isConnected() = 0; + + virtual void clearCache() = 0; + + virtual void addWatch(uint32_t address, unsigned int length) = 0; + virtual void removeWatch(uint32_t address, unsigned int length) = 0; + virtual void setWatchUpdateInterval(size_t interval) = 0; + + virtual void setMapping(const std::set& flags) = 0; + virtual uint32_t mapAddress(uint32_t address) = 0; + + virtual bool readFromCache(uint32_t address, unsigned int length, void* out) = 0; + virtual uint8_t readUInt8FromCache(uint32_t address, uint32_t offset = 0) { return 0; } + virtual uint16_t readUInt16FromCache(uint32_t address, uint32_t offset = 0) { return 0; } + virtual uint32_t readUInt32FromCache(uint32_t address, uint32_t offset = 0) { return 0; } + + virtual uint8_t readU8Live(uint32_t address, uint32_t offset = 0) { return 0; } + virtual uint16_t readU16Live(uint32_t address, uint32_t offset = 0) { return 0; } + virtual uint32_t readU32Live(uint32_t address, uint32_t offset = 0) { return 0; } +}; + +#endif /* _CORE_AUTOTRACK_PROVIDER_H */ + diff --git a/src/core/pack.cpp b/src/core/pack.cpp index 13e1fbc7..0cf3ba55 100644 --- a/src/core/pack.cpp +++ b/src/core/pack.cpp @@ -75,6 +75,12 @@ Pack::Info Pack::getInfo() const _minPopTrackerVersion, variants }; + + // Use the PopTracker platform override field + std::string platform_override = to_string(_manifest, "platform_poptracker", ""); + if( platform_override != "" ) + info.platform = platform_override; + return info; } @@ -186,7 +192,14 @@ void Pack::setVariant(const std::string& variant) std::string Pack::getPlatform() const { - return to_string(_manifest,"platform",""); + std::string platform = to_string(_manifest, "platform", ""); + + // Use the PopTracker platform override field + std::string platform_override = to_string(_manifest, "platform_poptracker", ""); + if( platform_override != "" ) + platform = platform_override; + + return platform; } std::string Pack::getVersion() const diff --git a/src/core/scripthost.cpp b/src/core/scripthost.cpp index 31124334..ea157add 100644 --- a/src/core/scripthost.cpp +++ b/src/core/scripthost.cpp @@ -198,7 +198,7 @@ LuaItem* ScriptHost::CreateLuaItem() return _tracker->CreateLuaItem(); } -std::string ScriptHost::AddMemoryWatch(const std::string& name, int addr, int len, LuaRef callback, int interval) +std::string ScriptHost::AddMemoryWatch(const std::string& name, unsigned int addr, int len, LuaRef callback, int interval) { if (interval==0) interval=500; /*orig:1000*/ // default diff --git a/src/core/scripthost.h b/src/core/scripthost.h index 49f4c039..0199645f 100644 --- a/src/core/scripthost.h +++ b/src/core/scripthost.h @@ -21,7 +21,7 @@ class ScriptHost : public LuaInterface { bool LoadScript(const std::string& file); LuaItem *CreateLuaItem(); - std::string AddMemoryWatch(const std::string& name, int addr, int len, LuaRef callback, int interval); + std::string AddMemoryWatch(const std::string& name, unsigned int addr, int len, LuaRef callback, int interval); bool RemoveMemoryWatch(const std::string& name); std::string AddWatchForCode(const std::string& name, const std::string& code, LuaRef callback); bool RemoveWatchForCode(const std::string& name); @@ -35,7 +35,7 @@ class ScriptHost : public LuaInterface { struct MemoryWatch { int callback; - int addr; + unsigned int addr; int len; int interval; std::string name; diff --git a/src/core/tsbuffer.h b/src/core/tsbuffer.h new file mode 100644 index 00000000..d4908c6b --- /dev/null +++ b/src/core/tsbuffer.h @@ -0,0 +1,73 @@ +#ifndef _CORE_TSBUFFER_H +#define _CORE_TSBUFFER_H + +#include +#include +#include + +/// +/// Threadsafe buffer class +/// +template +class tsbuffer { +public: + tsbuffer() = default; + tsbuffer(const tsbuffer&) = delete; + + T& operator[] (const uint32_t& k) + { + //std::scoped_lock lock(_mutex); + std::scoped_lock lock(_mutex); + return _data[k]; + } + + void read(uint32_t address, unsigned int length, void* out) + { + std::scoped_lock lock(_mutex); + uint8_t* dst = (uint8_t*)out; + for (size_t i = 0; i < length; i++) dst[i] = _data[address + i]; + } + + template + R readInt(uint32_t addr) + { + std::scoped_lock lock(_mutex); + R res = 0; + { + for (size_t n = 0; n < sizeof(R); n++) { + res <<= 8; + res += _data[addr + sizeof(R) - n - 1]; + } + } + return res; + } + + // Returns true if any data was changed. + bool write(uint32_t address, unsigned int length, const char* in) + { + std::scoped_lock lock(_mutex); + uint8_t* src = (uint8_t*)in; + bool bChanged = false; + + for (size_t i = 0; i < length; i++) { + if (!bChanged && _data[address + i] != src[i]) + bChanged = true; + + _data[address + i] = src[i]; + } + + return bChanged; + } + + void clear() + { + std::scoped_lock lock(_mutex); + _data.clear(); + } + +protected: + std::mutex _mutex; + std::map _data; +}; + +#endif /* _CORE_TSBUFFER_H */ diff --git a/src/luaconnector/common.h b/src/luaconnector/common.h new file mode 100644 index 00000000..9d78b797 --- /dev/null +++ b/src/luaconnector/common.h @@ -0,0 +1,29 @@ +#ifndef _LUACOMMON_H +#define _LUACOMMON_H + +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 +#ifndef _WIN32_WINNT +#define _WIN32_WINNT 0x0A00 +#endif +#endif + +#ifndef ASIO_STANDALONE +#define ASIO_STANDALONE +#endif + +#include +#include + +#include + +#undef LUACONNECTOR_ASYNC +#endif diff --git a/src/luaconnector/connection.cpp b/src/luaconnector/connection.cpp new file mode 100644 index 00000000..3f2e2f65 --- /dev/null +++ b/src/luaconnector/connection.cpp @@ -0,0 +1,251 @@ +#include "connection.h" + +namespace LuaConnector { + +namespace Net { + +#ifndef LUACONNECTOR_ASYNC +Connection::Connection(asio::io_context& context, asio::ip::tcp::socket socket) + : _context(context), _socket(std::move(socket)) +{ +} +#else +Connection::Connection(asio::io_context& context, asio::ip::tcp::socket socket, tsqueue& qIn) + : _context(context), _socket(std::move(socket)), _qMessagesIn(qIn) +{ +} +#endif + +Connection::~Connection() +{ + printf("LuaConnector: [%u] Disconnected\n", _uid); +} + +void Connection::ConnectToClient(uint32_t uid) +{ + try { + if (_socket.is_open()) { + _uid = uid; + + printf("LuaConnector: [%u] Connection established\n", _uid); + +#ifdef LUACONNECTOR_ASYNC + // kick off async task + ReadHeaderAsync(); +#endif + } + } + catch (const std::exception& e) { + printf("LuaConnector: [%u] Exception: %s\n", _uid, e.what()); + } +} + +bool Connection::Disconnect() +{ + if (IsConnected()) { + asio::post(_context, [this]() { _socket.close(); }); + } + + return false; +} + +bool Connection::IsConnected() const +{ + return _socket.is_open(); +} + +uint32_t Connection::GetID() const +{ + return _uid; +} + +#ifndef LUACONNECTOR_ASYNC +Message Connection::Send(const Message& msg) +{ + if (_socket.is_open()) { + //printf("LuaConnector: [%u] Send on main thread\n", _uid); + + try { + Message reply; + + WriteHeader(msg); + WriteBody(msg); + ReadHeader(reply); + ReadBody(reply); + + return reply; + } + catch (const std::exception&) { + printf("LuaConnector: [%u] Communication failure\n", _uid); + _socket.close(); + return Message(); + } + } + + return Message(); +} + +void Connection::ReadHeader(Message& msg) +{ + asio::read(_socket, asio::buffer(&msg.header, sizeof(MessageHeader))); + + msg.header.size = ntohl(msg.header.size); + + if (msg.header.size > 0) { + msg.body.resize(msg.header.size); + } +} + +void Connection::ReadBody(Message& msg) +{ + asio::read(_socket, asio::buffer(msg.body.data(), msg.header.size)); +} + +void Connection::WriteHeader(const Message& msg) +{ + asio::write(_socket, asio::buffer(&msg.header, sizeof(MessageHeader))); +} + +void Connection::WriteBody(const Message& msg) +{ + asio::write(_socket, asio::buffer(msg.body.data(), msg.body.size())); +} + +#else +void Connection::SendAsync(const Message& msg) +{ + //printf("LuaConnector: [%u] Sending Header: %u Body size: %u\n", _uid, ntohl(msg.header.size), msg.body.size()); + + asio::post(_context, + [this, msg]() + { + bool notWritingMessage = _qMessagesOut.push_back(msg); + + if (notWritingMessage) { + WriteHeaderAsync(); + } + }); +} + +void Connection::ReadHeaderAsync() +{ + try { + asio::async_read(_socket, asio::buffer(&_msgTemporaryIn.header, sizeof(MessageHeader)), + [this](std::error_code ec, std::size_t length) + { + if (!ec) { + // use network byte order + _msgTemporaryIn.header.size = ntohl(_msgTemporaryIn.header.size); + + if (_msgTemporaryIn.header.size > 0) { + // resize message body to make room + _msgTemporaryIn.body.resize(_msgTemporaryIn.header.size); + + // register next async task + ReadBodyAsync(); + } + else { + AddToIncomingMessageQueue(); + } + } + else { + //printf("LuaConnector: [%u] Read header failed: %s\n", _uid, ec.message().c_str()); + _socket.close(); + } + }); + } + catch (const std::exception& e) { + printf("LuaConnector: [%u] Register async read failed: %s\n", _uid, e.what()); + _socket.close(); + } +} + +void Connection::ReadBodyAsync() +{ + try { + asio::async_read(_socket, asio::buffer(_msgTemporaryIn.body.data(), _msgTemporaryIn.header.size), + [this](std::error_code ec, std::size_t length) + { + if (!ec) { + AddToIncomingMessageQueue(); + } + else { + //printf("LuaConnector: [%u] Read body failed: %s\n", _uid, ec.message().c_str()); + _socket.close(); + } + }); + } + catch (const std::exception& e) { + printf("LuaConnector: [%u] Register async read failed: %s\n", _uid, e.what()); + _socket.close(); + } +} + +void Connection::WriteHeaderAsync() +{ + try { + asio::async_write(_socket, asio::buffer(&_qMessagesOut.front().header, sizeof(MessageHeader)), + [this](std::error_code ec, std::size_t length) + { + if (!ec) { + if (_qMessagesOut.front().body.size() > 0) { + WriteBodyAsync(); + } + else { + bool isEmpty = _qMessagesOut.remove_front(); + + if (!isEmpty) { + WriteHeaderAsync(); + } + } + } + else { + //printf("LuaConnector: [%u] Write header failed: %s\n", _uid, ec.message().c_str()); + _socket.close(); + } + }); + } + catch (const asio::system_error& e) { + printf("LuaConnector: [%u] Register async write failed: %s\n", _uid, e.what()); + _socket.close(); + } +} + +void Connection::WriteBodyAsync() +{ + try { + asio::async_write(_socket, asio::buffer(_qMessagesOut.front().body.data(), _qMessagesOut.front().body.size()), + [this](std::error_code ec, std::size_t length) + { + if (!ec) { + bool isNowEmpty = _qMessagesOut.remove_front(); + + if (!isNowEmpty) { + WriteHeaderAsync(); + } + } + else { + //printf("LuaConnector: [%u] Write body failed: %s\n", _uid, ec.message().c_str()); + _socket.close(); + } + }); + } + catch (const std::exception& e) { + printf("LuaConnector: [%u] Register async write failed: %s\n", _uid, e.what()); + _socket.close(); + } +} + +void Connection::AddToIncomingMessageQueue() +{ + // Add complete message to server's incoming message queue + _qMessagesIn.push_back(_msgTemporaryIn); + + // register next async task + ReadHeaderAsync(); +} + +#endif +} + +} diff --git a/src/luaconnector/connection.h b/src/luaconnector/connection.h new file mode 100644 index 00000000..93331d30 --- /dev/null +++ b/src/luaconnector/connection.h @@ -0,0 +1,77 @@ +#ifndef _CONNECTION_H +#define _CONNECTION_H + +#include "common.h" +#include "tsqueue.h" +#include "message.h" + +using json = nlohmann::json; + +namespace LuaConnector { + +namespace Net { + +class Connection { +public: +#ifndef LUACONNECTOR_ASYNC + Connection(asio::io_context&, asio::ip::tcp::socket); +#else + Connection(asio::io_context&, asio::ip::tcp::socket, tsqueue& qIn); +#endif + ~Connection(); + + // Anything we should prepare when newly connected + void ConnectToClient(uint32_t uid = 0); + // Anything we should clean up when disconnected + bool Disconnect(); + + bool IsConnected() const; + + uint32_t GetID() const; + +#ifndef LUACONNECTOR_ASYNC + // Send a message to the client, and return its response. + // This function will block until the full response is received. + // Do not use with async alternative in the same application. + Message Send(const Message&); + +protected: + void ReadHeader(Message&); + void ReadBody(Message&); + void WriteHeader(const Message&); + void WriteBody(const Message&); + +#else + // Queue a message to be sent to the client. + // This function is non-blocking. + // Do not use with blocking alternative in the same application. + void SendAsync(const Message&); + +protected: + void ReadHeaderAsync(); + void ReadBodyAsync(); + void WriteHeaderAsync(); + void WriteBodyAsync(); + + void AddToIncomingMessageQueue(); + + tsqueue _qMessagesOut; + tsqueue& _qMessagesIn; + +#endif + +protected: + asio::io_context& _context; + asio::ip::tcp::socket _socket; + + Message _msgTemporaryIn; + + uint32_t _uid = 0; + uint32_t _bid = 0; +}; + +} + +} + +#endif diff --git a/src/luaconnector/luaconnector.cpp b/src/luaconnector/luaconnector.cpp new file mode 100644 index 00000000..b6a7518d --- /dev/null +++ b/src/luaconnector/luaconnector.cpp @@ -0,0 +1,262 @@ +#include "luaconnector.h" +#include "common.h" +#include "../core/util.h" +#include "server.h" +#include + +namespace LuaConnector { +LuaConnector::LuaConnector(const std::string& name) +{ + printf("LuaConnector(%s)\n", sanitize_print(name).c_str()); + _appname = name; +} + +LuaConnector::~LuaConnector() +{ + printf("~LuaConnector()\n"); + + stop(); +} + +const std::string& LuaConnector::getName() +{ + return _name; +} + +bool LuaConnector::start() +{ + try { + if (!_server) { + _server = new Net::Server(_data); + _server->Start(); + } + } + catch (const std::exception& e) { + printf("LuaConnector Exception: %s\n", e.what()); + return false; + } + + return true; +} + +bool LuaConnector::stop() +{ + try { + if (_server) { + delete _server; + _server = nullptr; + } + } + catch (const std::exception& e) { + printf("LuaConnector Exception: %s\n", e.what()); + return false; + } + + return false; +} + +bool LuaConnector::update() +{ + bool data_changed = false; + + if (_server) { + if (_recalculateWatches) { + recalculate_watches(); + _recalculateWatches = false; + } + + // check watches + if (_server->ClientIsConnected() + && _lastWatchCheck + std::chrono::milliseconds(_watchRefreshMilliseconds) < std::chrono::system_clock::now()) { + //printf("Checking watches\n"); + _lastWatchCheck = std::chrono::system_clock::now(); + for (auto& w : _combinedWatches) // use _combinedWatches in prod + { + data_changed |= _server->ReadBlockBuffered(w.first, w.second); + + if (!_server->ClientIsConnected()) + break; + } + } + + data_changed |= _server->Update(); + + //if( data_changed ) + //{ + // std::bitset<8> t = _data[mapaddr(0x8011AB07)]; + // printf("Mido's house: %s\n", t.to_string().c_str()); + //} + } + + return data_changed; +} + +bool LuaConnector::isReady() +{ + return _server && _server->IsListening(); +} + +bool LuaConnector::isConnected() +{ + return _server && _server->ClientIsConnected(); +} + +void LuaConnector::clearCache() +{ + _data.clear(); +} + +void LuaConnector::addWatch(uint32_t address, unsigned int length) +{ + //printf("addWatch\n"); + address = mapAddress(address); + + for (auto& w : _watches) { + if (w.first == address && w.second == length) + return; + } + + _watches.push_back(std::make_pair(address, length)); + _recalculateWatches = true; +} + +void LuaConnector::removeWatch(uint32_t address, unsigned int length) +{ + address = mapAddress(address); + _watches.erase(std::remove(_watches.begin(), _watches.end(), std::make_pair(address, length))); + _recalculateWatches = true; +} + +void LuaConnector::setWatchUpdateInterval(size_t interval) +{ + _watchRefreshMilliseconds = interval; +} + +void LuaConnector::setMapping(const std::set& flags) +{ +} + +uint32_t LuaConnector::mapAddress(uint32_t address) +{ + if (address >= 0x80000000) + return address - 0x80000000; + else + return address; +} + +bool LuaConnector::readFromCache(uint32_t address, unsigned int length, void* out) +{ + address = mapAddress(address); + _data.read(address, length, out); + return true; +} + +uint8_t LuaConnector::readUInt8FromCache(uint32_t address, uint32_t offset) +{ + address = mapAddress(address); + return _data[address]; +} + +uint16_t LuaConnector::readUInt16FromCache(uint32_t address, uint32_t offset) +{ + address = mapAddress(address); + return _data.readInt(address); +} + +uint32_t LuaConnector::readUInt32FromCache(uint32_t address, uint32_t offset) +{ + address = mapAddress(address); + return _data.readInt(address); +} + +// returns watches combined +std::vector combine_watches(std::vector src, bool& bChanged) +{ + bChanged = false; + uint32_t match_distance = 0xFF; + + Watchlist out; + + // for each watch + for (auto& w : src) { + bool matched = false; + + uint32_t w_begin = w.first; + uint32_t w_end = w.first + w.second; + + // find a combined watch that's close to it + for (auto& c : out) { + uint32_t c_begin = c.first; + uint32_t c_end = c.first + c.second; + + // we are near the end of combined watch + uint32_t gap = w_begin - c_end; + if (gap < match_distance) { + // extend end of combined watch to encompass us + c.second += gap + w.second; + + matched = true; + bChanged = true; + } + + // we are near the start of combined watch + gap = c_begin - w_end; + if (gap < match_distance) { + // extend beginning of combined watch to encompass us + c.first -= gap - w.second; + c.second += gap + w.second; + + matched = true; + bChanged = true; + } + } + + // did not match + if (!matched) { + out.push_back(w); + } + } + + return out; +} + +uint8_t LuaConnector::readU8Live(uint32_t address, uint32_t offset) +{ + address = mapAddress(address); + if (_server && _server->ClientIsConnected()) { + return _server->ReadU8Live(address); + } + return 0; +} + +uint16_t LuaConnector::readU16Live(uint32_t address, uint32_t offset) +{ + address = mapAddress(address); + if (_server && _server->ClientIsConnected()) { + uint16_t result = _server->ReadU16Live(address); + + // for backwards compatibility with EmoTracker, + // we reverse the byte order + uint16_t swapped = ((result & 0xff) << 8) | ((result & 0xff00) >> 8); + + return swapped; + } + return 0; +} + +void LuaConnector::recalculate_watches() +{ + printf("Recalculating watches\n"); + + _combinedWatches.clear(); + Watchlist w = _watches; + + bool changed = true; + while (changed) + w = combine_watches(w, changed); + + _combinedWatches = w; +} + +} + diff --git a/src/luaconnector/luaconnector.h b/src/luaconnector/luaconnector.h new file mode 100644 index 00000000..557b53fd --- /dev/null +++ b/src/luaconnector/luaconnector.h @@ -0,0 +1,70 @@ +#ifndef _LUACONNECTOR_H +#define _LUACONNECTOR_H + +#include "../core/autotrackprovider.h" +#include +#include "../core/tsbuffer.h" + +namespace LuaConnector { + +namespace Net { +class Server; +} + +typedef std::pair Watch; +typedef std::vector Watchlist; + +class LuaConnector : public IAutotrackProvider { +public: + LuaConnector(const std::string& appname); + ~LuaConnector(); + + const std::string& getName() override; + + bool start() override; + bool stop() override; + + bool update() override; + + bool isReady() override; + bool isConnected() override; + + void clearCache() override; + + void addWatch(uint32_t address, unsigned int length) override; + void removeWatch(uint32_t address, unsigned int length) override; + void setWatchUpdateInterval(size_t interval) override; + + void setMapping(const std::set& flags) override; + uint32_t mapAddress(uint32_t address) override; + + bool readFromCache(uint32_t address, unsigned int length, void* out) override; + uint8_t readUInt8FromCache(uint32_t address, uint32_t offset = 0) override; + uint16_t readUInt16FromCache(uint32_t address, uint32_t offset = 0) override; + uint32_t readUInt32FromCache(uint32_t address, uint32_t offset = 0) override; + + uint8_t readU8Live(uint32_t address, uint32_t offset = 0) override; + uint16_t readU16Live(uint32_t address, uint32_t offset = 0) override; + +private: + + void recalculate_watches(); + + Net::Server* _server = nullptr; + + std::string _appname; + const std::string _name = "Lua"; + + tsbuffer _data; + + std::chrono::time_point _lastWatchCheck; + Watchlist _watches; + Watchlist _combinedWatches; + bool _recalculateWatches = true; + + uint64_t _watchRefreshMilliseconds = 1500; +}; + +} + +#endif // _LUACONNECTOR_H diff --git a/src/luaconnector/message.h b/src/luaconnector/message.h new file mode 100644 index 00000000..eb74ce9d --- /dev/null +++ b/src/luaconnector/message.h @@ -0,0 +1,79 @@ +#ifndef _MESSAGE_H +#define _MESSAGE_H + +#include "common.h" + +using json = nlohmann::json; + +namespace LuaConnector +{ + namespace Net + { + struct MessageHeader + { + uint32_t size = 0; + }; + + struct Message + { + MessageHeader header{}; + std::vector body; + + Message() {} + + Message(const json& j) + { + // dump json to string + std::string s = j.dump(); + + //printf("New message contents: %s\n", s.c_str()); + + // get size of string + std::size_t size = s.size(); + + // resize body + body.resize(size); + + // copy string into message body + std::memcpy(body.data(), s.c_str(), size); + + // recalculate header + header.size = body.size(); + + // use network byte order + header.size = htonl(header.size); + } + + json GetJson() const + { + // initialize string + std::string s; + + // resize string to fit body + s.resize(body.size() + 1, '\0'); + + // copy message body into string + std::memcpy(&s[0], body.data(), body.size()); + + json j; + + try + { + j = json::parse(s); + } + catch( const json::parse_error& e ) + { + printf("Failed to parse json: %s\n", e.what()); + } + + return j; + } + + std::size_t size() const + { + return sizeof(MessageHeader) + body.size(); + } + }; + } +} +#endif diff --git a/src/luaconnector/server.cpp b/src/luaconnector/server.cpp new file mode 100644 index 00000000..cd9176a1 --- /dev/null +++ b/src/luaconnector/server.cpp @@ -0,0 +1,268 @@ +#include "server.h" +#include "../core/util.h" +#include "connection.h" +#include +#include + +using asio::ip::tcp; + +namespace LuaConnector { + +namespace Net { + +Server::Server(tsbuffer& data) : _data(data), _acceptor(_context, tcp::endpoint(tcp::v4(), _CONNECTORLIB_PORT)) +{ +} + +Server::~Server() +{ + Stop(); +} + +bool Server::Start() +{ + try { + // give the context work to do when it starts + waitForClientConnectionAsync(); + + // start the context in its thread + _threadContext = std::thread([this]() + { + _context.run(); + }); + + _isListening = true; + } + catch (const std::exception& e) { + printf("LuaConnector: Exception: %s\n", e.what()); + return false; + } + + printf("LuaConnector: Started\n"); + return true; +} + +void Server::Stop() +{ + // request the context to close + _context.stop(); + + // wait for thread to stop + if (_threadContext.joinable()) _threadContext.join(); + + _isListening = false; + + printf("LuaConnector: Stopped\n"); +} + +bool Server::IsListening() +{ + return _isListening; +} + +bool Server::ClientIsConnected() +{ + return _connection && _connection->IsConnected(); +} + +bool Server::Update() +{ + bool data_changed = false; + +#ifdef LUACONNECTOR_ASYNC + // loop inbound messages + while (!_qMessagesIn.empty()) { + auto msg = _qMessagesIn.pop_front(); + + json body = msg.GetJson(); + + data_changed |= updateBuffer(body); + } + +#endif + // send keepalive + if (_connection && _lastKeepalive + std::chrono::milliseconds(2000) < std::chrono::system_clock::now()) { + _lastKeepalive = std::chrono::system_clock::now(); + do_nothing(); + } + + return data_changed; +} + +bool Server::ReadByteBuffered(uint32_t address) +{ + json body; + + body["type"] = READ_BYTE; + body["address"] = address; + body["domain"] = "RDRAM"; + + json reply = sendJsonMessage(body); + return updateBuffer(reply); +} + +bool Server::ReadBlockBuffered(uint32_t address, uint32_t length) +{ + json body; + + body["type"] = READ_BLOCK; + body["address"] = address; + body["value"] = length; + body["domain"] = "RDRAM"; + + json reply = sendJsonMessage(body); + return updateBuffer(reply); +} + +void Server::do_nothing() +{ + json body; + + body["type"] = DO_NOTHING; + + sendJsonMessage(body); +} + +void Server::print_message(const std::string& messageText) +{ + json body; + + body["type"] = PRINT_MESSAGE; + body["message"] = messageText; + + sendJsonMessage(body); +} + +uint8_t Server::ReadU8Live(uint32_t address) +{ + json block; + + block["type"] = READ_BYTE; + block["address"] = address; + block["domain"] = "RDRAM"; + + json reply = sendJsonMessage(block); + + uint8_t result = 0; + + if (!reply.is_null()) + result = reply["value"]; + + return result; +} + +uint16_t Server::ReadU16Live(uint32_t address) +{ + json block; + + block["type"] = READ_USHORT; + block["address"] = address; + block["domain"] = "RDRAM"; + + json reply = sendJsonMessage(block); + + uint16_t result = 0; + + if (!reply.is_null()) + result = reply["value"]; + + return result; +} + +// async +void Server::waitForClientConnectionAsync() +{ + _acceptor.async_accept( + [this](std::error_code ec, tcp::socket socket) + { + // Incoming connection request + if (!ec) { + printf("LuaConnector: New connection: %s:%d\n", socket.remote_endpoint().address().to_string().c_str(), socket.remote_endpoint().port()); + + if (_connection == nullptr || !_connection->IsConnected()) { + try { + if (_connection) { + // disconnect the client + onClientDisconnect(); + _connection.reset(); + } + + _connection = std::make_unique(_context, std::move(socket)); + + _connection->ConnectToClient(_idCounter++); + + onClientConnect(); + } + catch (const std::exception& e) { + printf("LuaConnector: Exception: %s\n", e.what()); + } + } + else { + printf("LuaConnector: Rejected, connection in use\n"); + } + } + else { + printf("LuaConnector: New connection error: %s\n", ec.message().c_str()); + } + + // give the context more work - wait for another connection + waitForClientConnectionAsync(); + } + ); +} + +json Server::sendJsonMessage(const json& j) +{ + Message m(j); + + Message reply = messageClient(m); + + if (reply.body.size() > 0) + return reply.GetJson(); + else + return json(); +} + +Message Server::messageClient(const Message& msg) +{ + if (_connection && _connection->IsConnected()) { + return _connection->Send(msg); + } + else { + // disconnect the client + onClientDisconnect(); + _connection.reset(); + return Message(); + } +} + +void Server::onClientConnect() +{ + //print_message("Connected to PopTracker"); // say hello +} + +void Server::onClientDisconnect() +{ + +} + +bool Server::updateBuffer(const json& response) +{ + if (response.is_null()) + return false; + + if (response["type"] == READ_BLOCK) { + uint32_t address = response["address"]; + size_t length = response["value"]; + std::string encoded = response["block"]; + std::string buffer = websocketpp::base64_decode(encoded); + + return _data.write(address, length, buffer.c_str()); + } + + return false; +} + +} + +} diff --git a/src/luaconnector/server.h b/src/luaconnector/server.h new file mode 100644 index 00000000..ea22650b --- /dev/null +++ b/src/luaconnector/server.h @@ -0,0 +1,133 @@ +#ifndef _SERVER_H +#define _SERVER_H + +#include "common.h" +#include "message.h" +#include "connection.h" +#include "../core/tsbuffer.h" + +using json = nlohmann::json; + +namespace LuaConnector { + +namespace Net { + +enum MESSAGE_TYPES { + READ_BYTE = 0x00, + READ_USHORT = 0x01, + READ_UINT = 0x02, + READ_BLOCK = 0x0F, + WRITE_BYTE = 0x10, + WRITE_USHORT = 0x11, + WRITE_UINT = 0x12, + WRITE_BLOCK = 0x1F, + ATOMIC_BIT_FLIP = 0x20, + ATOMIC_BIT_UNFLIP = 0x21, + MEMORY_FREEZE_UNSIGNED = 0x30, + MEMORY_UNFREEZE = 0x3F, + LOAD_ROM = 0xE0, + UNLOAD_ROM = 0xE1, + GET_ROM_PATH = 0xE2, + GET_EMULATOR_CORE_ID = 0xE3, + CORESTATE_LOAD = 0xE4, + CORESTATE_SAVE = 0xE5, + CORESTATE_DELETE = 0xE6, + CORESTATE_RELOAD = 0xE7, + PRINT_MESSAGE = 0xF0, + DO_NOTHING = 0xFF +}; + +/// +/// asio server based on olc c++ networking tutorial +/// +class Server { +public: + Server(tsbuffer&); + ~Server(); + + bool Start(); + void Stop(); + + bool IsListening(); + bool ClientIsConnected(); + + // Main update loop + // Returns true if we changed the buffer. + bool Update(); + + // Client commands + // + + // Read data into the buffer. + // Returns true if buffer was changed. + bool ReadByteBuffered(uint32_t address); + + // Read data into the buffer. + // Returns true if buffer was changed. + bool ReadBlockBuffered(uint32_t address, uint32_t length); + + // Read data without updating the buffer. + // Returns the requested bytes. + uint8_t ReadU8Live(uint32_t address); + + // Read data without updating the buffer. + // Returns the requested bytes. + uint16_t ReadU16Live(uint32_t address); + + // Do nothing, useful for keepalive + void do_nothing(); + + // Print a message to the screen. + void print_message(const std::string&); + +private: + + // Starts an async accept task. + // It's safe to use blocking calls with this, I think. + void waitForClientConnectionAsync(); + + // Send a json payload to the client. + // This method blocks until the operation completes. + // Returns the json response. + json sendJsonMessage(const json&); + + // Send a message to the client. + // This method blocks until the operation completes. + // Returns the client's response. + Message messageClient(const Message&); + + // Anything server should prepare on client connect + void onClientConnect(); + // Anything server should clean up on client disconnect + void onClientDisconnect(); + + // Returns true if data was changed. + bool updateBuffer(const json&); + + bool _isListening = false; + + uint16_t _CONNECTORLIB_PORT = 43884; + + tsbuffer& _data; + + // Async message queue + //tsqueue _qMessagesIn; + + // order of declaration is important - it is also the order of initialization + asio::io_context _context; + std::thread _threadContext; + + asio::ip::tcp::acceptor _acceptor; + + uint32_t _idCounter = 10000; + + std::unique_ptr _connection; + + std::chrono::time_point _lastKeepalive; +}; + +} + +} + +#endif diff --git a/src/luaconnector/tsqueue.h b/src/luaconnector/tsqueue.h new file mode 100644 index 00000000..c44facd2 --- /dev/null +++ b/src/luaconnector/tsqueue.h @@ -0,0 +1,96 @@ +#ifndef _TSQUEUE_H +#define _TSQUEUE_H + +#include "common.h" + +namespace LuaConnector { + +namespace Net { + +template +class tsqueue { +public: + tsqueue() = default; + tsqueue(const tsqueue&) = delete; + virtual ~tsqueue() { clear(); } + + const T& front() + { + std::scoped_lock lock(muxQueue); + return deqQueue.front(); + } + + const T& back() + { + std::scoped_lock lock(muxQueue); + return deqQueue.back(); + } + + // Returns true if the queue was empty before you pushed this item + bool push_back(const T& item) + { + std::scoped_lock lock(muxQueue); + bool wasEmpty = deqQueue.empty(); + deqQueue.emplace_back(std::move(item)); + return wasEmpty; + } + + void push_front(const T& item) + { + std::scoped_lock lock(muxQueue); + return deqQueue.emplace_front(std::move(item)); + } + + bool empty() + { + std::scoped_lock lock(muxQueue); + return deqQueue.empty(); + } + + size_t count() + { + std::scoped_lock lock(muxQueue); + return deqQueue.count(); + } + + void clear() + { + std::scoped_lock lock(muxQueue); + deqQueue.clear(); + } + + // Returns true if the queue is now empty + bool remove_front() + { + std::scoped_lock lock(muxQueue); + deqQueue.pop_front(); + bool isNowEmpty = deqQueue.empty(); + return isNowEmpty; + } + + T pop_front() + { + std::scoped_lock lock(muxQueue); + auto t = std::move(deqQueue.front()); + deqQueue.pop_front(); + return t; + } + + T pop_back() + { + std::scoped_lock lock(muxQueue); + auto t = std::move(deqQueue.back()); + deqQueue.pop_back(); + return t; + } + +protected: + std::mutex muxQueue; + std::deque deqQueue; +}; + +} + +} + +#endif