diff --git a/Cargo.toml b/Cargo.toml index ac88f53..124fda0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,23 +28,35 @@ log4rs = { version = "1.3.0", features = [ num-bigint = "0.4.5" num-traits = "0.2.19" openssl = "0.10.64" -poem = { version = "3.0.1", features = ["websocket"] } +poem = "3.0.1" utoipa = { version = "5.0.0-alpha.0", features = [] } rand = "0.8.5" regex = "1.10.4" -reqwest = "0.12.4" +reqwest = { version = "0.12.5", default-features = false, features = [ + "http2", + "macos-system-configuration", + "charset", + "rustls-tls-webpki-roots", +] } serde = { version = "1.0.203", features = ["derive"] } serde_json = { version = "1.0.117", features = ["raw_value"] } sqlx = { version = "0.7.4", features = [ "json", "chrono", "ipnetwork", - "runtime-tokio-native-tls", + "runtime-tokio-rustls", "any", ] } thiserror = "1.0.61" tokio = { version = "1.38.0", features = ["full"] } -sentry = "0.33.0" +sentry = { version = "0.34.0", default-features = false, features = [ + "backtrace", + "contexts", + "debug-images", + "panic", + "reqwest", + "rustls", +] } clap = { version = "4.5.4", features = ["derive"] } chorus = { git = "http://github.com/polyphony-chat/chorus", rev = "d591616", features = [ @@ -54,3 +66,8 @@ serde_path_to_error = "0.1.16" percent-encoding = "2.3.1" hex = "0.4.3" itertools = "0.13.0" +tokio-tungstenite = { version = "0.23.1", features = [ + "rustls-tls-webpki-roots", +] } +pubserve = { version = "1.1.0", features = ["async", "send"] } +parking_lot = { version = "0.12.3", features = ["deadlock_detection"] } diff --git a/src/api/mod.rs b/src/api/mod.rs index 138049b..ad85afa 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,12 +1,13 @@ use poem::{ - EndpointExt, - IntoResponse, listener::TcpListener, - middleware::{NormalizePath, TrailingSlash}, Route, Server, web::Json, + middleware::{NormalizePath, TrailingSlash}, + web::Json, + EndpointExt, IntoResponse, Route, Server, }; use serde_json::json; use sqlx::MySqlPool; +use crate::SharedEventPublisherMap; use crate::{ api::{ middleware::{ @@ -21,8 +22,9 @@ use crate::{ mod middleware; mod routes; -pub async fn start_api(db: MySqlPool) -> Result<(), Error> { +pub async fn start_api(db: MySqlPool, publisher_map: SharedEventPublisherMap) -> Result<(), Error> { log::info!(target: "symfonia::api::cfg", "Loading configuration"); + let config = Config::init(&db).await?; if config.sentry.enabled { @@ -69,6 +71,7 @@ pub async fn start_api(db: MySqlPool) -> Result<(), Error> { .nest("/api/v9", routes) .data(db) .data(config) + .data(publisher_map) .with(NormalizePath::new(TrailingSlash::Trim)) .catch_all_error(custom_error); diff --git a/src/api/routes/guilds/id/channels.rs b/src/api/routes/guilds/id/channels.rs index 8d46ef3..ef8f6bc 100644 --- a/src/api/routes/guilds/id/channels.rs +++ b/src/api/routes/guilds/id/channels.rs @@ -1,10 +1,10 @@ use chorus::types::{ - ChannelModifySchema, ChannelType, jwt::Claims, ModifyChannelPositionsSchema, Snowflake, + jwt::Claims, ChannelModifySchema, ChannelType, ModifyChannelPositionsSchema, Snowflake, }; use poem::{ handler, - IntoResponse, - Response, web::{Data, Json, Path}, + web::{Data, Json, Path}, + IntoResponse, Response, }; use reqwest::StatusCode; use sqlx::MySqlPool; @@ -139,6 +139,7 @@ mod tests { position: Some(0), ..Default::default() }, + ..Default::default() }, Channel { inner: chorus::types::Channel { @@ -146,6 +147,7 @@ mod tests { position: Some(1), ..Default::default() }, + ..Default::default() }, Channel { inner: chorus::types::Channel { @@ -153,6 +155,7 @@ mod tests { position: Some(2), ..Default::default() }, + ..Default::default() }, Channel { inner: chorus::types::Channel { @@ -160,6 +163,7 @@ mod tests { position: Some(3), ..Default::default() }, + ..Default::default() }, Channel { inner: chorus::types::Channel { @@ -167,6 +171,7 @@ mod tests { position: Some(4), ..Default::default() }, + ..Default::default() }, Channel { inner: chorus::types::Channel { @@ -174,6 +179,7 @@ mod tests { position: Some(5), ..Default::default() }, + ..Default::default() }, ]; diff --git a/src/database/entities/application.rs b/src/database/entities/application.rs index 0e05cc7..591e972 100644 --- a/src/database/entities/application.rs +++ b/src/database/entities/application.rs @@ -1,12 +1,15 @@ +use super::*; + use std::ops::{Deref, DerefMut}; +use std::sync::Arc; -use bitflags::Flags; use chorus::types::{ApplicationFlags, Snowflake}; +use parking_lot::RwLock; use serde::{Deserialize, Serialize}; -use sqlx::{FromRow, MySqlPool}; +use sqlx::MySqlPool; use crate::{ - database::entities::{Config, user::User}, + database::entities::{user::User, Config}, errors::Error, }; @@ -17,6 +20,9 @@ pub struct Application { pub owner_id: Snowflake, pub bot_user_id: Option, pub team_id: Option, + #[sqlx(skip)] + #[serde(skip)] + pub publisher: SharedEventPublisher, } impl Deref for Application { @@ -62,6 +68,7 @@ impl Application { owner_id: owner_id.to_owned(), bot_user_id, team_id: None, + publisher: Arc::new(RwLock::new(pubserve::Publisher::new())), }; let _res = sqlx::query("INSERT INTO applications (id, name, summary, hook, bot_public, verify_key, owner_id, flags, integration_public, discoverability_state, discovery_eligibility_flags) VALUES (?, ?, ?, true, true, ?, ?, ?, true, 1, 2240)") diff --git a/src/database/entities/channel.rs b/src/database/entities/channel.rs index b6c0c8f..f01ada1 100644 --- a/src/database/entities/channel.rs +++ b/src/database/entities/channel.rs @@ -1,25 +1,29 @@ -use std::ops::{Deref, DerefMut}; - +use super::*; use chorus::types::{ - ChannelMessagesAnchor, ChannelModifySchema, ChannelType, CreateChannelInviteSchema, InviteType, - MessageSendSchema, PermissionOverwrite, Snowflake, + ChannelDelete, ChannelMessagesAnchor, ChannelModifySchema, ChannelType, ChannelUpdate, + CreateChannelInviteSchema, InviteType, MessageSendSchema, PermissionOverwrite, Snowflake, }; use itertools::Itertools; +use pubserve::Publisher; use serde::{Deserialize, Serialize}; -use sqlx::{MySqlPool, types::Json}; +use sqlx::{types::Json, MySqlPool}; +use std::ops::{Deref, DerefMut}; use crate::{ database::entities::{ - GuildMember, invite::Invite, message::Message, read_state::ReadState, recipient::Recipient, + invite::Invite, message::Message, read_state::ReadState, recipient::Recipient, GuildMember, User, Webhook, }, errors::{ChannelError, Error, GuildError, UserError}, }; -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, sqlx::FromRow)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, sqlx::FromRow, Default)] pub struct Channel { #[sqlx(flatten)] pub(crate) inner: chorus::types::Channel, + #[sqlx(skip)] + #[serde(skip)] + pub publisher: SharedEventPublisher, } impl Deref for Channel { @@ -87,6 +91,7 @@ impl Channel { guild_id, ..Default::default() }, + ..Default::default() }; sqlx::query("INSERT INTO channels (id, type, name, nsfw, guild_id, parent_id, flags, permission_overwrites, default_thread_rate_limit_per_user, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, NOW())") diff --git a/src/database/entities/guild.rs b/src/database/entities/guild.rs index 1ebe129..978905c 100644 --- a/src/database/entities/guild.rs +++ b/src/database/entities/guild.rs @@ -1,11 +1,10 @@ -use std::{ - ops::{Deref, DerefMut}, - sync::{Arc, RwLock}, -}; +use super::*; + +use std::ops::{Deref, DerefMut}; use chorus::types::{ - ChannelType, NSFWLevel, PermissionFlags, PremiumTier, - PublicUser, Snowflake, SystemChannelFlags, types::guild_configuration::GuildFeaturesList, WelcomeScreenObject, + types::guild_configuration::GuildFeaturesList, ChannelType, NSFWLevel, PermissionFlags, + PremiumTier, PublicUser, Snowflake, SystemChannelFlags, WelcomeScreenObject, }; use serde::{Deserialize, Serialize}; use sqlx::{FromRow, MySqlPool, QueryBuilder, Row}; @@ -31,6 +30,9 @@ pub struct Guild { pub parent: Option, pub template_id: Option, pub nsfw: bool, + #[sqlx(skip)] + #[serde(skip)] + pub publisher: SharedEventPublisher, } impl Deref for Guild { diff --git a/src/database/entities/mod.rs b/src/database/entities/mod.rs index fdb7c32..896fc70 100644 --- a/src/database/entities/mod.rs +++ b/src/database/entities/mod.rs @@ -16,6 +16,8 @@ pub use user_settings::*; pub use voice_state::*; pub use webhook::*; +use crate::SharedEventPublisher; + mod application; mod attachment; mod audit_log; diff --git a/src/database/entities/role.rs b/src/database/entities/role.rs index 183defc..9de6b27 100644 --- a/src/database/entities/role.rs +++ b/src/database/entities/role.rs @@ -1,3 +1,5 @@ +use super::*; + use std::ops::{Deref, DerefMut}; use chorus::types::{PermissionFlags, Snowflake}; @@ -11,6 +13,9 @@ pub struct Role { #[sqlx(flatten)] inner: chorus::types::RoleObject, pub guild_id: Snowflake, + #[sqlx(skip)] + #[serde(skip)] + pub publisher: SharedEventPublisher, } impl Deref for Role { diff --git a/src/database/entities/user.rs b/src/database/entities/user.rs index dae2ed1..a20cf23 100644 --- a/src/database/entities/user.rs +++ b/src/database/entities/user.rs @@ -1,3 +1,5 @@ +use super::*; + use std::{ default::Default, ops::{Deref, DerefMut}, @@ -28,6 +30,9 @@ pub struct User { #[sqlx(skip)] pub settings: UserSettings, pub extended_settings: sqlx::types::Json, + #[sqlx(skip)] + #[serde(skip)] + pub publisher: SharedEventPublisher, } impl Deref for User { diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index e69de29..f7f1656 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -0,0 +1,20 @@ +mod types; + +use log::info; +use sqlx::MySqlPool; +pub use types::*; + +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +use crate::errors::Error; +use crate::SharedEventPublisherMap; + +pub async fn start_gateway( + db: MySqlPool, + publisher_map: SharedEventPublisherMap, +) -> Result<(), Error> { + info!(target: "symfonia::gateway", "Starting gateway server"); + // `publishers` will live for the lifetime of the gateway server, in the main gateway thread + Ok(()) +} diff --git a/src/gateway/types.rs b/src/gateway/types.rs new file mode 100644 index 0000000..5e647bb --- /dev/null +++ b/src/gateway/types.rs @@ -0,0 +1,128 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use ::serde::{Deserialize, Serialize}; +use chorus::types::*; + +#[derive( + Debug, + ::serde::Deserialize, + ::serde::Serialize, + Clone, + PartialEq, + PartialOrd, + Eq, + Ord, + Copy, + Hash, +)] +/// Enum representing all possible* event types that can be received from or sent to the gateway. +/// +/// TODO: This is only temporary. Replace with this enum from chorus, when it is ready. +pub enum EventType { + Hello, + Ready, + Resumed, + InvalidSession, + ChannelCreate, + ChannelUpdate, + ChannelDelete, + ChannelPinsUpdate, + ThreadCreate, + ThreadUpdate, + ThreadDelete, + ThreadListSync, + ThreadMemberUpdate, + ThreadMembersUpdate, + GuildCreate, + GuildUpdate, + GuildDelete, + GuildBanAdd, + GuildBanRemove, + GuildEmojisUpdate, + GuildIntegrationsUpdate, + GuildMemberAdd, + GuildMemberRemove, + GuildMemberUpdate, + GuildMembersChunk, + GuildRoleCreate, + GuildRoleUpdate, + GuildRoleDelete, + IntegrationCreate, + IntegrationUpdate, + IntegrationDelete, + InteractionCreate, + InviteCreate, + InviteDelete, + MessageCreate, + MessageUpdate, + MessageDelete, + MessageDeleteBulk, + MessageReactionAdd, + MessageReactionRemove, + MessageReactionRemoveAll, + MessageReactionRemoveEmoji, + PresenceUpdate, + TypingStart, + UserUpdate, + VoiceStateUpdate, + VoiceServerUpdate, + WebhooksUpdate, + StageInstanceCreate, + StageInstanceUpdate, + StageInstanceDelete, + RequestMembers, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +/// Enum representing all possible* events that can be received from or sent to the gateway. +/// +/// TODO: This is only temporary. Replace with this enum from chorus, when it is ready. +#[serde(rename_all = "PascalCase")] +pub enum Event { + Hello(GatewayHello), + Ready(GatewayReady), + Resumed(GatewayResume), + InvalidSession(GatewayInvalidSession), + ChannelCreate(ChannelCreate), + ChannelUpdate(ChannelUpdate), + ChannelDelete(ChannelDelete), + ThreadCreate(ThreadCreate), + ThreadUpdate(ThreadUpdate), + ThreadDelete(ThreadDelete), + ThreadListSync(ThreadListSync), + ThreadMemberUpdate(ThreadMemberUpdate), + ThreadMembersUpdate(ThreadMembersUpdate), + GuildCreate(GuildCreate), + GuildUpdate(GuildUpdate), + GuildDelete(GuildDelete), + GuildBanAdd(GuildBanAdd), + GuildBanRemove(GuildBanRemove), + GuildEmojisUpdate(GuildEmojisUpdate), + GuildIntegrationsUpdate(GuildIntegrationsUpdate), + GuildMemberAdd(GuildMemberAdd), + GuildMemberRemove(GuildMemberRemove), + GuildMemberUpdate(GuildMemberUpdate), + GuildMembersChunk(GuildMembersChunk), + InteractionCreate(InteractionCreate), + InviteCreate(InviteCreate), + InviteDelete(InviteDelete), + MessageCreate(MessageCreate), + MessageUpdate(MessageUpdate), + MessageDelete(MessageDelete), + MessageDeleteBulk(MessageDeleteBulk), + MessageReactionAdd(MessageReactionAdd), + MessageReactionRemove(MessageReactionRemove), + MessageReactionRemoveAll(MessageReactionRemoveAll), + MessageReactionRemoveEmoji(MessageReactionRemoveEmoji), + PresenceUpdate(PresenceUpdate), + TypingStart(TypingStartEvent), + UserUpdate(UserUpdate), + VoiceStateUpdate(VoiceStateUpdate), + VoiceServerUpdate(VoiceServerUpdate), + WebhooksUpdate(WebhooksUpdate), + StageInstanceCreate(StageInstanceCreate), + StageInstanceUpdate(StageInstanceUpdate), + StageInstanceDelete(StageInstanceDelete), +} diff --git a/src/main.rs b/src/main.rs index 3da7acd..16dce4d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,20 +1,28 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use chorus::types::Snowflake; use clap::Parser; + +use gateway::Event; +use log::LevelFilter; use log4rs::{ append::{ console::{ConsoleAppender, Target}, rolling_file::{ policy::compound::{ - CompoundPolicy, roll::delete::DeleteRoller, trigger::size::SizeTrigger, + roll::delete::DeleteRoller, trigger::size::SizeTrigger, CompoundPolicy, }, RollingFileAppender, }, }, config::{Appender, Logger, Root}, - Config, encode::pattern::PatternEncoder, filter::Filter, + Config, }; -use log::LevelFilter; +use parking_lot::RwLock; +use pubserve::Publisher; mod api; mod cdn; @@ -23,6 +31,10 @@ mod errors; mod gateway; mod util; +pub type SharedEventPublisher = Arc>>; +pub type EventPublisherMap = HashMap; +pub type SharedEventPublisherMap = Arc>; + #[derive(Debug)] struct LogFilter; @@ -180,6 +192,11 @@ async fn main() { .await .expect("Failed to seed config"); } - - api::start_api(db).await.unwrap(); + let shared_publisher_map = Arc::new(RwLock::new(HashMap::new())); + api::start_api(db.clone(), shared_publisher_map.clone()) + .await + .unwrap(); + gateway::start_gateway(db.clone(), shared_publisher_map.clone()) + .await + .unwrap(); }