Skip to content

Commit

Permalink
add shard pub/sub support
Browse files Browse the repository at this point in the history
  • Loading branch information
sewenew committed Dec 31, 2023
1 parent 9e650b1 commit ad6baa1
Show file tree
Hide file tree
Showing 15 changed files with 359 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/sw/redis++/async_redis.h
Original file line number Diff line number Diff line change
Expand Up @@ -1616,6 +1616,15 @@ class AsyncRedis {
_callback_fmt_command<long long>(std::forward<Callback>(cb), fmt::publish, channel, message);
}

Future<long long> spublish(const StringView &channel, const StringView &message) {
return _command<long long>(fmt::spublish, channel, message);
}

template <typename Callback>
void spublish(const StringView &channel, const StringView &message, Callback &&cb) {
_callback_fmt_command<long long>(std::forward<Callback>(cb), fmt::spublish, channel, message);
}

// co_command* are used internally. DO NOT use them.

template <typename Result, typename Callback>
Expand Down
13 changes: 13 additions & 0 deletions src/sw/redis++/async_redis_cluster.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,19 @@ AsyncSubscriber AsyncRedisCluster::subscriber() {
return AsyncSubscriber(_loop, std::move(connection));
}

AsyncSubscriber AsyncRedisCluster::subscriber(const StringView &hash_tag) {
assert(_pool);

_pool->update();

auto opts = _pool->connection_options(hash_tag);

auto connection = std::make_shared<AsyncConnection>(opts, _loop.get());
connection->set_subscriber_mode();

return AsyncSubscriber(_loop, std::move(connection));
}

}

}
11 changes: 11 additions & 0 deletions src/sw/redis++/async_redis_cluster.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class AsyncRedisCluster {

AsyncSubscriber subscriber();

AsyncSubscriber subscriber(const StringView &hash_tag);

template <typename Result, typename ...Args>
auto command(const StringView &cmd_name, const StringView &key, Args &&...args)
-> typename std::enable_if<!IsInvocable<typename LastType<Args...>::type,
Expand Down Expand Up @@ -1070,6 +1072,15 @@ class AsyncRedisCluster {
_callback_fmt_command<long long>(std::forward<Callback>(cb), fmt::publish, channel, message);
}

Future<long long> spublish(const StringView &channel, const StringView &message) {
return _command<long long>(fmt::spublish, channel, message);
}

template <typename Callback>
void spublish(const StringView &channel, const StringView &message, Callback &&cb) {
_callback_fmt_command<long long>(std::forward<Callback>(cb), fmt::spublish, channel, message);
}

// co_command* are used internally. DO NOT use them.

template <typename Result, typename Callback>
Expand Down
18 changes: 18 additions & 0 deletions src/sw/redis++/async_subscriber.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,24 @@ Future<void> AsyncSubscriber::punsubscribe(const StringView &channel) {
return _send(SubscribeEventUPtr(new SubscribeEvent(fmt::punsubscribe(channel))));
}

Future<void> AsyncSubscriber::ssubscribe(const StringView &channel) {
_check_connection();

return _send(SubscribeEventUPtr(new SubscribeEvent(fmt::ssubscribe(channel))));
}

Future<void> AsyncSubscriber::sunsubscribe() {
_check_connection();

return _send(SubscribeEventUPtr(new SubscribeEvent(fmt::sunsubscribe())));
}

Future<void> AsyncSubscriber::sunsubscribe(const StringView &channel) {
_check_connection();

return _send(SubscribeEventUPtr(new SubscribeEvent(fmt::sunsubscribe(channel))));
}

void AsyncSubscriber::_check_connection() {
if (!_connection || _connection->broken()) {
throw Error("Connection is broken");
Expand Down
50 changes: 50 additions & 0 deletions src/sw/redis++/async_subscriber.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ class AsyncSubscriber {
template <typename PMsgCb>
void on_pmessage(PMsgCb &&pmsg_callback);

template <typename SMsgCb>
void on_smessage(SMsgCb &&smsg_callback);

template <typename MetaCb>
void on_meta(MetaCb &&meta_callback);

Expand Down Expand Up @@ -104,6 +107,28 @@ class AsyncSubscriber {
return punsubscribe(channels.begin(), channels.end());
}

Future<void> ssubscribe(const StringView &channel);

template <typename Input>
Future<void> ssubscribe(Input first, Input last);

template <typename T>
Future<void> ssubscribe(std::initializer_list<T> channels) {
return ssubscribe(channels.begin(), channels.end());
}

Future<void> sunsubscribe();

Future<void> sunsubscribe(const StringView &channel);

template <typename Input>
Future<void> sunsubscribe(Input first, Input last);

template <typename T>
Future<void> sunsubscribe(std::initializer_list<T> channels) {
return sunsubscribe(channels.begin(), channels.end());
}

private:
friend class AsyncRedis;

Expand Down Expand Up @@ -134,6 +159,13 @@ void AsyncSubscriber::on_pmessage(PMsgCb &&pmsg_callback) {
_connection->subscriber().on_pmessage(std::forward<PMsgCb>(pmsg_callback));
}

template <typename SMsgCb>
void AsyncSubscriber::on_smessage(SMsgCb &&smsg_callback) {
_check_connection();

_connection->subscriber().on_smessage(std::forward<SMsgCb>(smsg_callback));
}

template <typename MetaCb>
void AsyncSubscriber::on_meta(MetaCb &&meta_callback) {
_check_connection();
Expand Down Expand Up @@ -184,6 +216,24 @@ Future<void> AsyncSubscriber::punsubscribe(Input first, Input last) {
return _send(SubscribeEventUPtr(new SubscribeEvent(fmt::punsubscribe_range(first, last))));
}

template <typename Input>
Future<void> AsyncSubscriber::ssubscribe(Input first, Input last) {
range_check("ssubscribe", first, last);

_check_connection();

return _send(SubscribeEventUPtr(new SubscribeEvent(fmt::ssubscribe_range(first, last))));
}

template <typename Input>
Future<void> AsyncSubscriber::sunsubscribe(Input first, Input last) {
range_check("sunsubscribe", first, last);

_check_connection();

return _send(SubscribeEventUPtr(new SubscribeEvent(fmt::sunsubscribe_range(first, last))));
}

}

}
Expand Down
38 changes: 38 additions & 0 deletions src/sw/redis++/async_subscriber_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,16 @@ void AsyncSubscriberImpl::_run_callback(redisReply &reply) {
_handle_pmessage(reply);
break;

case Subscriber::MsgType::SMESSAGE:
_handle_smessage(reply);
break;

case Subscriber::MsgType::SUBSCRIBE:
case Subscriber::MsgType::UNSUBSCRIBE:
case Subscriber::MsgType::PSUBSCRIBE:
case Subscriber::MsgType::PUNSUBSCRIBE:
case Subscriber::MsgType::SSUBSCRIBE:
case Subscriber::MsgType::SUNSUBSCRIBE:
_handle_meta(type, reply);
break;

Expand All @@ -93,6 +99,8 @@ Subscriber::MsgType AsyncSubscriberImpl::_msg_type(const std::string &type) cons
return Subscriber::MsgType::MESSAGE;
} else if ("pmessage" == type) {
return Subscriber::MsgType::PMESSAGE;
} else if ("smessage" == type) {
return Subscriber::MsgType::SMESSAGE;
} else if ("subscribe" == type) {
return Subscriber::MsgType::SUBSCRIBE;
} else if ("unsubscribe" == type) {
Expand All @@ -101,6 +109,10 @@ Subscriber::MsgType AsyncSubscriberImpl::_msg_type(const std::string &type) cons
return Subscriber::MsgType::PSUBSCRIBE;
} else if ("punsubscribe" == type) {
return Subscriber::MsgType::PUNSUBSCRIBE;
} else if ("ssubscribe" == type) {
return Subscriber::MsgType::SSUBSCRIBE;
} else if ("sunsubscribe" == type) {
return Subscriber::MsgType::SUNSUBSCRIBE;
} else {
return Subscriber::MsgType::UNKNOWN;
}
Expand Down Expand Up @@ -164,6 +176,32 @@ void AsyncSubscriberImpl::_handle_pmessage(redisReply &reply) {
_pmsg_callback(std::move(pattern), std::move(channel), std::move(msg));
}

void AsyncSubscriberImpl::_handle_smessage(redisReply &reply) {
if (_smsg_callback == nullptr) {
return;
}

if (reply.elements != 3) {
throw ProtoError("Expect 3 sub replies");
}

assert(reply.element != nullptr);

auto *channel_reply = reply.element[1];
if (channel_reply == nullptr) {
throw ProtoError("Null channel reply");
}
auto channel = reply::parse<std::string>(*channel_reply);

auto *msg_reply = reply.element[2];
if (msg_reply == nullptr) {
throw ProtoError("Null message reply");
}
auto msg = reply::parse<std::string>(*msg_reply);

_smsg_callback(std::move(channel), std::move(msg));
}

void AsyncSubscriberImpl::_handle_meta(Subscriber::MsgType type, redisReply &reply) {
if (_meta_callback == nullptr) {
return;
Expand Down
9 changes: 9 additions & 0 deletions src/sw/redis++/async_subscriber_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ class AsyncSubscriberImpl {
_pmsg_callback = std::forward<PMsgCb>(pmsg_callback);
}

template <typename SMsgCb>
void on_smessage(SMsgCb &&smsg_callback) {
_smsg_callback = std::forward<SMsgCb>(smsg_callback);
}

template <typename MetaCb>
void on_meta(MetaCb &&meta_callback) {
_meta_callback = std::forward<MetaCb>(meta_callback);
Expand All @@ -63,13 +68,17 @@ class AsyncSubscriberImpl {

void _handle_pmessage(redisReply &reply);

void _handle_smessage(redisReply &reply);

void _handle_meta(Subscriber::MsgType type, redisReply &reply);

std::function<void (std::string channel, std::string msg)> _msg_callback;

std::function<void (std::string pattern, std::string channel,
std::string msg)> _pmsg_callback;

std::function<void (std::string channel, std::string msg)> _smsg_callback;

std::function<void (Subscriber::MsgType type, OptionalString channel,
long long num)> _meta_callback;

Expand Down
38 changes: 38 additions & 0 deletions src/sw/redis++/cmd_formatter.h
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,12 @@ inline FormattedCommand publish(const StringView &channel, const StringView &mes
message.data(), message.size());
}

inline FormattedCommand spublish(const StringView &channel, const StringView &message) {
return format_cmd("SPUBLISH %b %b",
channel.data(), channel.size(),
message.data(), message.size());
}

inline FormattedCommand punsubscribe() {
return format_cmd("PUNSUBSCRIBE");
}
Expand Down Expand Up @@ -863,6 +869,38 @@ inline FormattedCommand unsubscribe_range(Input first, Input last) {
return format_cmd(args);
}

inline FormattedCommand ssubscribe(const StringView &channel) {
return format_cmd("SSUBSCRIBE %b", channel.data(), channel.size());
}

template <typename Input>
inline FormattedCommand ssubscribe_range(Input first, Input last) {
assert(first != last);

CmdArgs args;
args << "SSUBSCRIBE" << std::make_pair(first, last);

return format_cmd(args);
}

inline FormattedCommand sunsubscribe() {
return format_cmd("SUNSUBSCRIBE");
}

inline FormattedCommand sunsubscribe(const StringView &channel) {
return format_cmd("SUNSUBSCRIBE %b", channel.data(), channel.size());
}

template <typename Input>
inline FormattedCommand sunsubscribe_range(Input first, Input last) {
assert(first != last);

CmdArgs args;
args << "SUNSUBSCRIBE" << std::make_pair(first, last);

return format_cmd(args);
}

}

}
Expand Down
44 changes: 44 additions & 0 deletions src/sw/redis++/command.h
Original file line number Diff line number Diff line change
Expand Up @@ -1594,6 +1594,14 @@ inline void publish(Connection &connection,
message.data(), message.size());
}

inline void spublish(Connection &connection,
const StringView &channel,
const StringView &message) {
connection.send("SPUBLISH %b %b",
channel.data(), channel.size(),
message.data(), message.size());
}

inline void punsubscribe(Connection &connection) {
connection.send("PUNSUBSCRIBE");
}
Expand Down Expand Up @@ -1650,6 +1658,42 @@ inline void unsubscribe_range(Connection &connection, Input first, Input last) {
connection.send(args);
}

inline void ssubscribe(Connection &connection, const StringView &channel) {
connection.send("SSUBSCRIBE %b", channel.data(), channel.size());
}

template <typename Input>
inline void ssubscribe_range(Connection &connection, Input first, Input last) {
if (first == last) {
throw Error("SSUBSCRIBE: no key specified");
}

CmdArgs args;
args << "SSUBSCRIBE" << std::make_pair(first, last);

connection.send(args);
}

inline void sunsubscribe(Connection &connection) {
connection.send("SUNSUBSCRIBE");
}

inline void sunsubscribe(Connection &connection, const StringView &channel) {
connection.send("SUNSUBSCRIBE %b", channel.data(), channel.size());
}

template <typename Input>
inline void sunsubscribe_range(Connection &connection, Input first, Input last) {
if (first == last) {
throw Error("SUNSUBSCRIBE: no key specified");
}

CmdArgs args;
args << "SUNSUBSCRIBE" << std::make_pair(first, last);

connection.send(args);
}

// Transaction commands.

inline void discard(Connection &connection) {
Expand Down
6 changes: 6 additions & 0 deletions src/sw/redis++/redis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,12 @@ long long Redis::publish(const StringView &channel, const StringView &message) {
return reply::parse<long long>(*reply);
}

long long Redis::spublish(const StringView &channel, const StringView &message) {
auto reply = command(cmd::spublish, channel, message);

return reply::parse<long long>(*reply);
}

// Transaction commands.

void Redis::watch(const StringView &key) {
Expand Down
2 changes: 2 additions & 0 deletions src/sw/redis++/redis.h
Original file line number Diff line number Diff line change
Expand Up @@ -3252,6 +3252,8 @@ class Redis {

long long publish(const StringView &channel, const StringView &message);

long long spublish(const StringView &channel, const StringView &message);

// Transaction commands.
void watch(const StringView &key);

Expand Down
Loading

0 comments on commit ad6baa1

Please sign in to comment.