Skip to content

Commit

Permalink
Implement partial responses support in MercuryManager (#171)
Browse files Browse the repository at this point in the history
* fixes for esp32

* fixes for esp32

* reuploaded sdkconfig.defaults

* added gzip compatible mercury decoder (failed if message were trucated) - updated from main-branch

* fix: clang-format

* fix: Minor style fixes

* fix: clear partials on sub res and reconnect

* fix: Restore bell version

* fix: Update bell

---------

Co-authored-by: Filip <[email protected]>
  • Loading branch information
tobiasguyer and feelfreelinux authored Jul 12, 2024
1 parent 1b07a9c commit 5c35a33
Show file tree
Hide file tree
Showing 11 changed files with 571 additions and 404 deletions.
13 changes: 8 additions & 5 deletions cspot/include/MercurySession.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,9 @@ class MercurySession : public bell::Task, public cspot::Session {

struct Response {
Header mercuryHeader;
uint8_t flags;
DataParts parts;
uint64_t sequenceId;
bool fail;
};

typedef std::function<void(Response&)> ResponseCallback;
typedef std::function<void(bool, const std::vector<uint8_t>&)>
AudioKeyCallback;
Expand All @@ -54,6 +51,11 @@ class MercurySession : public bell::Task, public cspot::Session {
COUNTRY_CODE_RESPONSE = 0x1B,
};

enum class ResponseFlag : uint8_t {
FINAL = 0x01,
PARTIAL = 0x02,
};

std::unordered_map<RequestType, std::string> RequestTypeMap = {
{RequestType::GET, "GET"},
{RequestType::SEND, "SEND"},
Expand Down Expand Up @@ -111,7 +113,8 @@ class MercurySession : public bell::Task, public cspot::Session {
void runTask() override;
void reconnect();

std::unordered_map<uint64_t, ResponseCallback> callbacks;
std::unordered_map<int64_t, ResponseCallback> callbacks;
std::unordered_map<int64_t, Response> partials;
std::unordered_map<std::string, ResponseCallback> subscriptions;
std::unordered_map<uint32_t, AudioKeyCallback> audioKeyCallbacks;

Expand All @@ -129,6 +132,6 @@ class MercurySession : public bell::Task, public cspot::Session {

void failAllPending();

Response decodeResponse(const std::vector<uint8_t>& data);
std::pair<int, int64_t> decodeResponse(const std::vector<uint8_t>& data);
};
} // namespace cspot
3 changes: 3 additions & 0 deletions cspot/protobuf/mercury.options
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
Header.uri max_size:256, fixed_length:false
Header.method max_size:64, fixed_length:false
UserField.key type:FT_POINTER
UserField.value type:FT_POINTER
Header.user_fields max_count:64, fixed_count:false
8 changes: 8 additions & 0 deletions cspot/protobuf/mercury.proto
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
message Header {
optional string uri = 0x01;
optional string content_type = 0x02;
optional string method = 0x03;
optional int32 status_code = 0x04;
repeated UserField user_fields = 0x06;
}

message UserField {
optional string key = 0x01;
optional string value = 0x02;
}
95 changes: 65 additions & 30 deletions cspot/src/MercurySession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ void MercurySession::reconnect() {
try {
this->conn = nullptr;
this->shanConn = nullptr;
this->partials.clear();

this->connectWithRandomAp();
this->authenticate(this->authBlob);
Expand Down Expand Up @@ -174,19 +175,26 @@ void MercurySession::handlePacket() {
CSPOT_LOG(debug, "Received mercury packet");

auto response = this->decodeResponse(packet.data);
if (this->callbacks.count(response.sequenceId) > 0) {
auto seqId = response.sequenceId;
this->callbacks[response.sequenceId](response);
this->callbacks.erase(this->callbacks.find(seqId));
if (response.first == static_cast<uint8_t>(ResponseFlag::FINAL)) {
auto partial = this->partials.find(response.second);
if (this->callbacks.count(response.second)) {
this->callbacks[response.second](partial->second);
this->callbacks.erase(this->callbacks.find(response.second));
}
this->partials.erase(partial);
}
break;
}
case RequestType::SUBRES: {
auto response = decodeResponse(packet.data);

auto uri = std::string(response.mercuryHeader.uri);
if (this->subscriptions.count(uri) > 0) {
this->subscriptions[uri](response);
if (response.first == static_cast<uint8_t>(ResponseFlag::FINAL)) {
auto partial = this->partials.find(response.second);
auto uri = std::string(partial->second.mercuryHeader.uri);
if (this->subscriptions.count(uri) > 0) {
this->subscriptions[uri](partial->second);
}
this->partials.erase(partial);
}
break;
}
Expand Down Expand Up @@ -214,33 +222,60 @@ void MercurySession::failAllPending() {
this->callbacks = {};
}

MercurySession::Response MercurySession::decodeResponse(
std::pair<int, int64_t> MercurySession::decodeResponse(
const std::vector<uint8_t>& data) {
Response response = {};
response.parts = {};

auto sequenceLength = ntohs(extract<uint16_t>(data, 0));
response.sequenceId = hton64(extract<uint64_t>(data, 2));

auto partsNumber = ntohs(extract<uint16_t>(data, 11));

auto headerSize = ntohs(extract<uint16_t>(data, 13));
auto headerBytes =
std::vector<uint8_t>(data.begin() + 15, data.begin() + 15 + headerSize);

auto pos = 15 + headerSize;
while (pos < data.size()) {
uint64_t sequenceId;
uint8_t flag;
if (sequenceLength == 2)
sequenceId = ntohs(extract<uint16_t>(data, 2));
else if (sequenceLength == 4)
sequenceId = ntohl(extract<uint32_t>(data, 2));
else if (sequenceLength == 8)
sequenceId = hton64(extract<uint64_t>(data, 2));
else
return std::make_pair(0, 0);

size_t pos = 2 + sequenceLength;
flag = (uint8_t)data[pos];
pos++;
auto parts = ntohs(extract<uint16_t>(data, pos));
pos += 2;
auto partial = partials.find(sequenceId);
if (partial == partials.end()) {
CSPOT_LOG(debug,
"Creating new Mercury Response, seq: %llu, flags: %i, parts: %i",
sequenceId, flag, parts);
partial = this->partials.insert({sequenceId, Response()}).first;
partial->second.parts = {};
partial->second.fail = false;
} else
CSPOT_LOG(debug,
"Adding to Mercury Response, seq: %llu, flags: %i, parts: %i",
sequenceId, flag, parts);
uint8_t index = 0;
while (parts) {
if (data.size() <= pos || partial->second.fail)
break;
auto partSize = ntohs(extract<uint16_t>(data, pos));

response.parts.push_back(std::vector<uint8_t>(
data.begin() + pos + 2, data.begin() + pos + 2 + partSize));
pos += 2 + partSize;
pos += 2;
if (!partial->second.mercuryHeader.has_uri) {
partial->second.fail = false;
auto headerBytes = std::vector<uint8_t>(data.begin() + pos,
data.begin() + pos + partSize);
pbDecode(partial->second.mercuryHeader, Header_fields, headerBytes);
} else {
if (index >= partial->second.parts.size())
partial->second.parts.push_back(std::vector<uint8_t>{});
partial->second.parts[index].insert(partial->second.parts[index].end(),
data.begin() + pos,
data.begin() + pos + partSize);
index++;
}
pos += partSize;
parts--;
}

pbDecode(response.mercuryHeader, Header_fields, headerBytes);
response.fail = false;

return response;
return std::make_pair(flag, sequenceId);
}

uint64_t MercurySession::executeSubscription(RequestType method,
Expand Down
22 changes: 12 additions & 10 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,7 @@

clang-tools = pkgs.clang-tools.override {llvmPackages = llvm;};

apps = {
};

packages = {
target-cli = llvm.stdenv.mkDerivation {
name = "cspotcli";
src = ./.;
cmakeFlags = ["-DCSPOT_TARGET_CLI=ON"];
nativeBuildInputs = with pkgs; [
cspot-pkgs = with pkgs; [
avahi
avahi-compat
cmake
Expand All @@ -50,6 +42,16 @@
portaudio
protobuf
];

apps = {
};

packages = {
target-cli = llvm.stdenv.mkDerivation {
name = "cspotcli";
src = ./.;
cmakeFlags = ["-DCSPOT_TARGET_CLI=ON"];
nativeBuildInputs = cspot-pkgs;
# Patch nanopb shebangs to refer to provided python
postPatch = ''
patchShebangs cspot/bell/external/nanopb/generator/*
Expand All @@ -60,7 +62,7 @@

devShells = {
default = pkgs.mkShell {
packages = with pkgs; [cmake unstable.mbedtls ninja python3] ++ [clang-tools llvm.clang];
packages = cspot-pkgs;
};
};
in {
Expand Down
2 changes: 0 additions & 2 deletions targets/cli/CliPlayer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,6 @@ void CliPlayer::runTask() {
continue;
} else {
if (lastHash != chunk->trackHash) {
std::cout << " Last hash " << lastHash << " new hash "
<< chunk->trackHash << std::endl;
lastHash = chunk->trackHash;
this->handler->notifyAudioReachedPlayback();
}
Expand Down
134 changes: 134 additions & 0 deletions targets/esp32/main/EspPlayer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
#include "EspPlayer.h"

#include <cstdint> // for uint8_t
#include <functional> // for __base
#include <iostream> // for operator<<, basic_ostream, endl, cout
#include <memory> // for shared_ptr, make_shared, make_unique
#include <mutex> // for scoped_lock
#include <string_view> // for hash, string_view
#include <type_traits> // for remove_extent_t
#include <utility> // for move
#include <variant> // for get
#include <vector> // for vector

#include "BellUtils.h" // for BELL_SLEEP_MS
#include "CircularBuffer.h"
#include "Logger.h"
#include "SpircHandler.h" // for SpircHandler, SpircHandler::EventType
#include "StreamInfo.h" // for BitWidth, BitWidth::BW_16
#include "TrackPlayer.h" // for TrackPlayer

EspPlayer::EspPlayer(std::unique_ptr<AudioSink> sink,
std::shared_ptr<cspot::SpircHandler> handler)
: bell::Task("player", 32 * 1024, 0, 1) {
this->handler = handler;
this->audioSink = std::move(sink);

this->circularBuffer = std::make_shared<bell::CircularBuffer>(1024 * 128);

auto hashFunc = std::hash<std::string_view>();

this->handler->getTrackPlayer()->setDataCallback(
[this, &hashFunc](uint8_t* data, size_t bytes, std::string_view trackId) {
auto hash = hashFunc(trackId);
this->feedData(data, bytes, hash);
return bytes;
});

this->isPaused = false;

this->handler->setEventHandler(
[this, &hashFunc](std::unique_ptr<cspot::SpircHandler::Event> event) {
switch (event->eventType) {
case cspot::SpircHandler::EventType::PLAY_PAUSE:
if (std::get<bool>(event->data)) {
this->pauseRequested = true;
} else {
this->isPaused = false;
this->pauseRequested = false;
}
break;
case cspot::SpircHandler::EventType::DISC:
this->circularBuffer->emptyBuffer();
break;
case cspot::SpircHandler::EventType::FLUSH:
this->circularBuffer->emptyBuffer();
break;
case cspot::SpircHandler::EventType::SEEK:
this->circularBuffer->emptyBuffer();
break;
case cspot::SpircHandler::EventType::PLAYBACK_START:
this->isPaused = true;
this->playlistEnd = false;
this->circularBuffer->emptyBuffer();
break;
case cspot::SpircHandler::EventType::DEPLETED:
this->playlistEnd = true;
break;
case cspot::SpircHandler::EventType::VOLUME: {
int volume = std::get<int>(event->data);
break;
}
default:
break;
}
});
startTask();
}

void EspPlayer::feedData(uint8_t* data, size_t len, size_t trackId) {
size_t toWrite = len;

while (toWrite > 0) {
this->current_hash = trackId;
size_t written =
this->circularBuffer->write(data + (len - toWrite), toWrite);
if (written == 0) {
BELL_SLEEP_MS(10);
}

toWrite -= written;
}
}

void EspPlayer::runTask() {
std::vector<uint8_t> outBuf = std::vector<uint8_t>(1024);

std::scoped_lock lock(runningMutex);

size_t lastHash = 0;

while (isRunning) {
if (!this->isPaused) {
size_t read = this->circularBuffer->read(outBuf.data(), outBuf.size());
if (this->pauseRequested) {
this->pauseRequested = false;
std::cout << "Pause requested!" << std::endl;
this->isPaused = true;
}

this->audioSink->feedPCMFrames(outBuf.data(), read);

if (read == 0) {
if (this->playlistEnd) {
this->handler->notifyAudioEnded();
this->playlistEnd = false;
}
BELL_SLEEP_MS(10);
continue;
} else {
if (lastHash != current_hash) {
lastHash = current_hash;
this->handler->notifyAudioReachedPlayback();
}
}
} else {
BELL_SLEEP_MS(100);
}
}
}

void EspPlayer::disconnect() {
isRunning = false;
std::scoped_lock lock(runningMutex);
}
Loading

0 comments on commit 5c35a33

Please sign in to comment.