From aa3d86a9e4cfbf3fe9e64009900f725a5084b2c7 Mon Sep 17 00:00:00 2001 From: Eric Huss Date: Mon, 20 Feb 2023 13:10:26 -0800 Subject: [PATCH] Add support for using SQLite. --- Cargo.lock | 190 ++++++-- Cargo.toml | 5 +- README.md | 130 ++++-- src/db.rs | 382 +++++++-------- src/db/issue_data.rs | 56 +-- src/db/jobs.rs | 152 ++---- src/db/notifications.rs | 345 +------------- src/db/postgres.rs | 841 ++++++++++++++++++++++++++++++++++ src/db/rustc_commits.rs | 79 ---- src/db/sqlite.rs | 742 ++++++++++++++++++++++++++++++ src/handlers.rs | 2 +- src/handlers/mentions.rs | 6 +- src/handlers/no_merges.rs | 8 +- src/handlers/notification.rs | 19 +- src/handlers/rustc_commits.rs | 22 +- src/main.rs | 27 +- src/notification_listing.rs | 6 +- src/zulip.rs | 28 +- tests/db/issue_data.rs | 30 ++ tests/db/jobs.rs | 112 +++++ tests/db/mod.rs | 208 +++++++++ tests/db/notification.rs | 260 +++++++++++ tests/db/rustc_commits.rs | 86 ++++ tests/server_test/mod.rs | 176 ++----- tests/testsuite.rs | 35 +- 25 files changed, 2889 insertions(+), 1058 deletions(-) create mode 100644 src/db/postgres.rs delete mode 100644 src/db/rustc_commits.rs create mode 100644 src/db/sqlite.rs create mode 100644 tests/db/issue_data.rs create mode 100644 tests/db/jobs.rs create mode 100644 tests/db/mod.rs create mode 100644 tests/db/notification.rs create mode 100644 tests/db/rustc_commits.rs diff --git a/Cargo.lock b/Cargo.lock index 6cf4662d..883566ca 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,17 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "ahash" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47" +dependencies = [ + "getrandom", + "once_cell", + "version_check", +] + [[package]] name = "aho-corasick" version = "0.7.18" @@ -467,6 +478,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7" +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + [[package]] name = "fastrand" version = "1.7.0" @@ -704,7 +721,7 @@ dependencies = [ "indexmap", "slab", "tokio", - "tokio-util 0.7.1", + "tokio-util", "tracing", ] @@ -714,6 +731,24 @@ version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" +[[package]] +name = "hashbrown" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +dependencies = [ + "ahash", +] + +[[package]] +name = "hashlink" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69fe1fcf8b4278d860ad0548329f892a3631fb63f82574df68275f34cdbe0ffa" +dependencies = [ + "hashbrown 0.12.3", +] + [[package]] name = "hermit-abi" version = "0.1.19" @@ -867,7 +902,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0f647032dfaa1f8b6dc29bd3edb7bbef4861b8b8007ebb118d6db284fd59f6ee" dependencies = [ "autocfg", - "hashbrown", + "hashbrown 0.11.2", "serde", ] @@ -928,6 +963,17 @@ version = "0.2.123" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "cb691a747a7ab48abc15c5b42066eaafde10dc427e3b6ee2a1cf43db04c763bd" +[[package]] +name = "libsqlite3-sys" +version = "0.25.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29f835d03d717946d28b1d1ed632eb6f0e24a299388ee623d0c23118d3e8a7fa" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "lock_api" version = "0.4.7" @@ -1171,27 +1217,25 @@ dependencies = [ [[package]] name = "parking_lot" -version = "0.11.2" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" +checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" dependencies = [ - "instant", "lock_api", "parking_lot_core", ] [[package]] name = "parking_lot_core" -version = "0.8.5" +version = "0.9.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d76e8e1493bcac0d2766c42737f34458f1c8c50c0d23bcb24ea953affb273216" +checksum = "9069cbb9f99e3a5083476ccb29ceb1de18b9118cafa53e90c9551235de2b9521" dependencies = [ "cfg-if", - "instant", "libc", "redox_syscall", "smallvec", - "winapi", + "windows-sys", ] [[package]] @@ -1254,18 +1298,18 @@ dependencies = [ [[package]] name = "phf" -version = "0.10.1" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fabbf1ead8a5bcbc20f5f8b939ee3f5b0f6f281b6ad3468b84656b658b455259" +checksum = "928c6535de93548188ef63bb7c4036bd415cd8f36ad25af44b9789b2ee72a48c" dependencies = [ "phf_shared", ] [[package]] name = "phf_shared" -version = "0.10.0" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6796ad771acdc0123d2a88dc428b5e38ef24456743ddb1744ed628f9815c096" +checksum = "e1fb5f6f826b772a8d4c0394209441e7d37cbbb967ae9c7e0e8134365c9ee676" dependencies = [ "siphasher", ] @@ -1363,7 +1407,8 @@ dependencies = [ "postgres-protocol", "serde", "serde_json", - "uuid", + "uuid 0.8.2", + "uuid 1.3.0", ] [[package]] @@ -1529,6 +1574,23 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "afab94fb28594581f62d981211a9a4d53cc8130bbcbbb89a0440d9b8e81a7746" +[[package]] +name = "rusqlite" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01e213bc3ecb39ac32e81e51ebe31fd888a940515173e3a18a35f8c6e896422a" +dependencies = [ + "bitflags", + "chrono", + "fallible-iterator", + "fallible-streaming-iterator", + "hashlink", + "libsqlite3-sys", + "serde_json", + "smallvec", + "uuid 1.3.0", +] + [[package]] name = "rust_team_data" version = "1.0.0" @@ -1917,15 +1979,16 @@ dependencies = [ [[package]] name = "tokio-postgres" -version = "0.7.5" +version = "0.7.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b6c8b33df661b548dcd8f9bf87debb8c56c05657ed291122e1188698c2ece95" +checksum = "29a12c1b3e0704ae7dfc25562629798b29c72e6b1d0a681b6f29ab4ae5e7f7bf" dependencies = [ "async-trait", "byteorder", "bytes", "fallible-iterator", - "futures", + "futures-channel", + "futures-util", "log", "parking_lot", "percent-encoding", @@ -1935,21 +1998,7 @@ dependencies = [ "postgres-types", "socket2", "tokio", - "tokio-util 0.6.9", -] - -[[package]] -name = "tokio-util" -version = "0.6.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e99e1983e5d376cd8eb4b66604d2e99e79f5bd988c3055891dcd8c9e2604cc0" -dependencies = [ - "bytes", - 
"futures-core", - "futures-sink", - "log", - "pin-project-lite", - "tokio", + "tokio-util", ] [[package]] @@ -1986,7 +2035,7 @@ dependencies = [ "pin-project", "pin-project-lite", "tokio", - "tokio-util 0.7.1", + "tokio-util", "tower-layer", "tower-service", "tracing", @@ -2098,6 +2147,7 @@ dependencies = [ "regex", "reqwest", "route-recognizer", + "rusqlite", "rust_team_data", "serde", "serde_json", @@ -2110,7 +2160,7 @@ dependencies = [ "tracing", "tracing-subscriber", "url", - "uuid", + "uuid 1.3.0", ] [[package]] @@ -2272,6 +2322,12 @@ name = "uuid" version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bc5cf98d8186244414c848017f0e2676b3fcb46807f6668a97dfe67359a3c4b7" + +[[package]] +name = "uuid" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1674845326ee10d37ca60470760d4288a6f80f304007d92e5c53bab78c9cfd79" dependencies = [ "getrandom", "serde", @@ -2447,6 +2503,72 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.42.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e2522491fbfcd58cc84d47aeb2958948c4b8982e9a2d8a2a35bbaed431390e7" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c9864e83243fdec7fc9c5444389dcbbfd258f745e7853198f365e3c4968a608" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c8b1b673ffc16c47a9ff48570a9d85e25d265735c503681332589af6253c6c7" + +[[package]] +name = "windows_i686_gnu" +version = "0.42.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de3887528ad530ba7bdbb1faa8275ec7a1155a45ffa57c37993960277145d640" + +[[package]] +name = "windows_i686_msvc" +version = "0.42.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf4d1122317eddd6ff351aa852118a2418ad4214e6613a50e0191f7004372605" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1040f221285e17ebccbc2591ffdc2d44ee1f9186324dd3e84e99ac68d699c45" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "628bfdf232daa22b0d64fdb62b09fcc36bb01f05a3939e20ab73aaf9470d0463" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "447660ad36a13288b1db4d4248e857b510e8c3a225c822ba4fb748c0aafecffd" + [[package]] name = "winreg" version = "0.10.1" diff --git a/Cargo.toml b/Cargo.toml index 30b92a8e..15decd9a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ hyper = { version = "0.14.4", features = ["server", "stream"]} tokio = { version = "1.7.1", features = ["macros", "time", "rt"] 
} futures = { version = "0.3", default-features = false, features = ["std"] } async-trait = "0.1.31" -uuid = { version = "0.8", features = ["v4", "serde"] } +uuid = { version = "1.3", features = ["v4", "serde"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } url = "2.1.0" @@ -42,9 +42,10 @@ tower = { version = "0.4.13", features = ["util", "limit", "buffer", "load-shed" github-graphql = { path = "github-graphql" } rand = "0.8.5" ignore = "0.4.18" -postgres-types = { version = "0.2.4", features = ["derive"] } +postgres-types = { version = "0.2.4", features = ["derive", "with-uuid-1"] } cron = { version = "0.12.0" } bytes = "1.1.0" +rusqlite = { version = "0.28.0", features = ["bundled", "chrono", "serde_json", "uuid"] } [dependencies.serde] version = "1" diff --git a/README.md b/README.md index eb4e6f3c..8d12aefb 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ The Triagebot webserver also includes several other endpoints intended for users Triagebot uses a Postgres database to retain some state. In production, it uses [RDS](https://aws.amazon.com/rds/). +For local testing you can use SQLite (see below). The server at https://triage.rust-lang.org/ runs on ECS and is configured via [Terraform](https://github.com/rust-lang/simpleinfra/blob/master/terraform/shared/services/triagebot/main.tf#L8). Updates are automatically deployed when merged to master. @@ -34,62 +35,113 @@ Some developers may settle with testing in production as the risks tend to be lo The general overview of what you will need to do: -1. Install Postgres. Look online for any help with installing and setting up Postgres (particularly if you need to create a user and set up permissions). -2. Create a database: `createdb triagebot` -3. Provide a way for GitHub to access the Triagebot webserver. - There are various ways to do this (such as placing it behind a proxy, or poking holes in your firewall). - Or, you can use a service such as https://ngrok.com/ to access on your local dev machine via localhost. - Installation is fairly simple, though requires setting up a (free) account. - Run the command `ngrok http 8000` to forward to port 8000 on localhost. - - > Note: GitHub has a webhook forwarding service available in beta. - > See [cli/gh-webhook](https://docs.github.com/en/developers/webhooks-and-events/webhooks/receiving-webhooks-with-the-github-cli) for more information. - > This is super easy to use, and doesn't require manually configuring webhook settings. - > The command to run looks something like: - > - > ```sh - > gh webhook forward --repo=ehuss/triagebot-test --events=* \ - > --url=http://127.0.0.1:8000/github-hook --secret somelongsekrit - > ``` - > - > Where the value in `--secret` is the secret value you place in `GITHUB_WEBHOOK_SECRET` described below, and `--repo` is the repo you want to test against. - -4. Create a GitHub repo to run some tests on. -5. Configure the webhook in your GitHub repo. - I recommend at least skimming the [GitHub webhook documentation](https://docs.github.com/en/developers/webhooks-and-events/webhooks/about-webhooks) if you are not familiar with webhooks. In short: - - 1. Go to the settings page. - 2. Go to the webhook section. - 3. Click "Add webhook" - 4. Include the settings: - - - Payload URL: This is the URL to your Triagebot server, for example http://7e9ea9dc.ngrok.io/github-hook. This URL is displayed when you ran the `ngrok` command above. 
-   - Content type: application/json
-   - Secret: Enter a shared secret (some longish random text)
-   - Events: "Send me everything"
-6. Configure the `.env` file:
+1. Create a repo on GitHub to run tests on.
+2. [Configure a database](#configure-a-database)
+3. [Configure webhook forwarding](#configure-webhook-forwarding)
+4. Configure the `.env` file:
    1. Copy `.env.sample` to `.env`
    2. `GITHUB_API_TOKEN`: This is a token needed for Triagebot to send requests to GitHub. Go to GitHub Settings > Developer Settings > Personal Access Token, and create a new token. The `repo` permission should be sufficient. If this is not set, Triagebot will also look in `~/.gitconfig` in the `github.oauth-token` setting.
-   3. `DATABASE_URL`: This is the URL to the Postgres database. Something like `postgres://eric@localhost/triagebot` should work, replacing `eric` with your username.
+   3. `DATABASE_URL`: This is the URL to the database. See [Configure a database](#configure-a-database).
    4. `GITHUB_WEBHOOK_SECRET`: Enter the secret you entered in the webhook above.
    5. `RUST_LOG`: Set this to `debug`.
-7. Run `cargo run --bin triagebot`. This starts the http server listening on port 8000.
-8. Add a `triagebot.toml` file to the main branch of your GitHub repo with whichever services you want to try out.
-9. Try interacting with your repo, such as issuing `@rustbot` commands or interacting with PRs and issues (depending on which services you enabled in `triagebot.toml`). Watch the logs from the server to see what's going on.
+5. Run `cargo run --bin triagebot`. This starts the http server listening for webhooks on port 8000.
+6. Add a `triagebot.toml` file to the main branch of your GitHub repo with whichever services you want to try out.
+7. Try interacting with your repo, such as issuing `@rustbot` commands or interacting with PRs and issues (depending on which services you enabled in `triagebot.toml`). Watch the logs from the server to see what's going on.
+
+### Configure a database
+
+For testing, it is probably easiest to use SQLite.
+If you want something closer to production, then you might want to set up Postgres.
+
+#### SQLite
+
+To use SQLite, all you need to do is set `DATABASE_URL` in the `.env` file to a file path:
+
+```bash
+DATABASE_URL=db/triagebot.sqlite
+```
+
+If you have the [`sqlite3` CLI program](https://sqlite.org/cli.html) installed, you can use that to interactively run queries against the database with `sqlite3 db/triagebot.sqlite`.
+
+#### Postgres
+
+To use Postgres, you will need to install it and configure it:
+
+1. Install Postgres. Look online for any help with installing and setting up Postgres (particularly if you need to create a user and set up permissions).
+2. Create a database: `createdb triagebot`
+3. In the `.env` file, set the `DATABASE_URL`:
+
+   ```sh
+   DATABASE_URL=postgres://eric@localhost/triagebot
+   ```
+
+   replacing `eric` with the username on your local system.
+
+### Configure webhook forwarding
+
+I recommend at least skimming the [GitHub webhook documentation](https://docs.github.com/en/developers/webhooks-and-events/webhooks/about-webhooks) if you are not familiar with webhooks.
+In order for GitHub's webhooks to reach your triagebot server, you'll need to figure out some way to route them to your machine.
+There are various options on how to do this.
+You can poke holes into your firewall or use a proxy, but you shouldn't expose your machine to the internet.
+There are various services which help with this problem.
+These generally involve running a program on your machine that connects to an external server which relays the hooks into your machine.
+There are several to choose from:
+
+* [gh webhook](#gh-webhook) — This is a GitHub-native service, but it is currently in beta (getting access is easy, though). This is the easiest to use.
+* [ngrok](#ngrok) — This is pretty easy to use, but requires setting up a free account.
+* — This is another service recommended by GitHub.
+* — This is another service recommended by GitHub.
+
+#### gh webhook
+
+The [`gh` CLI](https://github.com/cli/cli) is the official CLI tool which I highly recommend getting familiar with.
+There is an official extension which provides webhook forwarding and also takes care of all the configuration.
+See [cli/gh-webhook](https://docs.github.com/en/developers/webhooks-and-events/webhooks/receiving-webhooks-with-the-github-cli) for more information on installing it.
+
+This is super easy to use, and doesn't require manually configuring webhook settings.
+The command to run looks something like:
+
+```sh
+gh webhook forward --repo=ehuss/triagebot-test --events=* \
+  --url=http://127.0.0.1:8000/github-hook --secret somelongsekrit
+```
+
+Where the value in `--secret` is the secret value you place in `GITHUB_WEBHOOK_SECRET` in the `.env` file, and `--repo` is the repo you want to test against.
+
+#### ngrok
+
+The following is an example of using [ngrok](https://ngrok.com/) to provide webhook forwarding.
+You need to sign up for a free account, and also deal with configuring the GitHub webhook settings.
+
+1. Install ngrok.
+2. Run `ngrok http 8000`. This will forward webhook events to localhost on port 8000.
+3. Configure GitHub webhooks in the test repo you created.
+   In short:
+
+   1. Go to the settings page for your GitHub repo.
+   2. Go to the webhook section.
+   3. Click "Add webhook"
+   4. Include the settings:
+
+      * Payload URL: This is the URL to your Triagebot server, for example http://7e9ea9dc.ngrok.io/github-hook. This URL is displayed when you ran the `ngrok` command above.
+      * Content type: application/json
+      * Secret: Enter a shared secret (some longish random text)
+      * Events: "Send me everything"
 
 ## Tests
 
 When possible, writing unit tests is very helpful and one of the easiest ways to test.
 For more advanced testing, there is an integration test called `testsuite` which provides an end-to-end service for testing triagebot.
-There are two parts to it:
+There are several parts to it:
 
 * [`github_client`](tests/github_client/mod.rs) — Tests specifically targeting `GithubClient`. This sets up an HTTP server that mimics api.github.com and verifies the client's behavior.
 * [`server_test`](tests/server_test/mod.rs) — This tests the `triagebot` server itself and its behavior when it receives a webhook. This launches the `triagebot` server, sets up HTTP servers to intercept api.github.com requests, launches PostgreSQL in a sandbox, and then injects webhook events into the `triagebot` server and validates its response.
+* [`db`](tests/db/mod.rs) — These are tests for the database API.
 
 The real GitHub API responses are recorded in JSON files that the tests can later replay to verify the behavior of triagebot.
 These recordings are enabled with the `TRIAGEBOT_TEST_RECORD` environment variable.
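As an editorial aside before the `src/db.rs` changes below: the patch replaces direct `tokio_postgres` usage with a `Pool` plus a backend-agnostic `Connection` trait. The following is a minimal sketch (not part of the patch) of how a caller is meant to use that API. The method names (`Pool::open`, `connection`, `record_username`, `get_notifications`) come from the trait and `Pool` definitions added in this patch; the `triagebot::db` import path and the `main` scaffolding are assumptions for illustration only.

```rust
// Illustrative only: exercises the `Pool` / `Connection` abstraction added by
// this patch. Assumes the crate exposes `triagebot::db` (as the new
// `tests/db/` integration tests do) and that tokio's "rt" and "macros"
// features are enabled, as in Cargo.toml.
use triagebot::db::{Connection as _, Pool};

#[tokio::main(flavor = "current_thread")]
async fn main() -> anyhow::Result<()> {
    // A plain file path selects the SQLite backend; a postgres:// URL
    // selects Postgres. Here we assume a local `db/` directory exists.
    let pool = Pool::open("db/triagebot.sqlite");

    // Every caller gets a boxed `dyn Connection`, regardless of backend.
    let mut conn = pool.connection().await;
    conn.record_username(42, "some-user".to_string()).await?;

    // State is read back through the same trait.
    let notifications = conn.get_notifications("some-user").await?;
    println!("{} pending notifications", notifications.len());
    Ok(())
}
```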
diff --git a/src/db.rs b/src/db.rs
index a696be99..f6a9afd7 100644
--- a/src/db.rs
+++ b/src/db.rs
@@ -1,276 +1,204 @@
-use crate::handlers::jobs::handle_job;
-use crate::{db::jobs::*, handlers::Context};
-use anyhow::Context as _;
-use chrono::Utc;
-use native_tls::{Certificate, TlsConnector};
-use postgres_native_tls::MakeTlsConnector;
+use self::jobs::Job;
+use self::notifications::{Identifier, Notification, NotificationData};
+use anyhow::Result;
+use chrono::{DateTime, FixedOffset, Utc};
+use serde::Serialize;
 use std::sync::{Arc, Mutex};
 use tokio::sync::{OwnedSemaphorePermit, Semaphore};
-use tokio_postgres::Client as DbClient;
+use uuid::Uuid;
 
 pub mod issue_data;
 pub mod jobs;
 pub mod notifications;
-pub mod rustc_commits;
+pub mod postgres;
+pub mod sqlite;
+
+/// A bors merge commit.
+#[derive(Debug, Serialize)]
+pub struct Commit {
+    pub sha: String,
+    pub parent_sha: String,
+    pub time: DateTime<FixedOffset>,
+    pub pr: Option<u32>,
+}
+
+#[async_trait::async_trait]
+pub trait Connection: Send + Sync {
+    async fn transaction(&mut self) -> Box<dyn Transaction + '_>;
+
+    // Pings
+    async fn record_username(&mut self, user_id: i64, username: String) -> Result<()>;
+    async fn record_ping(&mut self, notification: &Notification) -> Result<()>;
+
+    // Rustc commits
+    async fn get_missing_commits(&mut self) -> Result<Vec<String>>;
+    async fn record_commit(&mut self, commit: &Commit) -> Result<()>;
+    async fn has_commit(&mut self, sha: &str) -> Result<bool>;
+    async fn get_commits_with_artifacts(&mut self) -> Result<Vec<Commit>>;
+
+    // Notifications
+    async fn get_notifications(&mut self, username: &str) -> Result<Vec<NotificationData>>;
+    async fn delete_ping(
+        &mut self,
+        user_id: i64,
+        identifier: Identifier<'_>,
+    ) -> Result<Vec<NotificationData>>;
+    async fn add_metadata(
+        &mut self,
+        user_id: i64,
+        idx: usize,
+        metadata: Option<&str>,
+    ) -> Result<()>;
+    async fn move_indices(&mut self, user_id: i64, from: usize, to: usize) -> Result<()>;
+
+    // Jobs
+    async fn insert_job(
+        &mut self,
+        name: &str,
+        scheduled_at: &DateTime<Utc>,
+        metadata: &serde_json::Value,
+    ) -> Result<()>;
+    async fn delete_job(&mut self, id: &Uuid) -> Result<()>;
+    async fn update_job_error_message(&mut self, id: &Uuid, message: &str) -> Result<()>;
+    async fn update_job_executed_at(&mut self, id: &Uuid) -> Result<()>;
+    async fn get_job_by_name_and_scheduled_at(
+        &mut self,
+        name: &str,
+        scheduled_at: &DateTime<Utc>,
+    ) -> Result<Job>;
+    async fn get_jobs_to_execute(&mut self) -> Result<Vec<Job>>;
+
+    // Issue data
+    async fn lock_and_load_issue_data(
+        &mut self,
+        repo: &str,
+        issue_number: i32,
+        key: &str,
+    ) -> Result<(Box<dyn Transaction + '_>, Option<serde_json::Value>)>;
+    async fn save_issue_data(
+        &mut self,
+        repo: &str,
+        issue_number: i32,
+        key: &str,
+        data: &serde_json::Value,
+    ) -> Result<()>;
+}
-const CERT_URL: &str = "https://s3.amazonaws.com/rds-downloads/rds-ca-2019-root.pem";
+#[async_trait::async_trait]
+pub trait Transaction: Send + Sync {
+    fn conn(&mut self) -> &mut dyn Connection;
+    fn conn_ref(&self) -> &dyn Connection;
+
+    async fn commit(self: Box<Self>) -> Result<(), anyhow::Error>;
+    async fn finish(self: Box<Self>) -> Result<(), anyhow::Error>;
+}
-lazy_static::lazy_static!
{ - static ref CERTIFICATE_PEM: Vec = { - let client = reqwest::blocking::Client::new(); - let resp = client - .get(CERT_URL) - .send() - .expect("failed to get RDS cert"); - resp.bytes().expect("failed to get RDS cert body").to_vec() - }; +#[async_trait::async_trait] +pub trait ConnectionManager { + type Connection; + async fn open(&self) -> Self::Connection; + async fn is_valid(&self, c: &mut Self::Connection) -> bool; } -pub struct ClientPool { - connections: Arc>>, +pub struct ConnectionPool { + connections: Arc>>, permits: Arc, + manager: M, } -pub struct PooledClient { - client: Option, - #[allow(unused)] // only used for drop impl +pub struct ManagedConnection { + conn: Option, + connections: Arc>>, + #[allow(unused)] permit: OwnedSemaphorePermit, - pool: Arc>>, } -impl Drop for PooledClient { - fn drop(&mut self) { - let mut clients = self.pool.lock().unwrap_or_else(|e| e.into_inner()); - clients.push(self.client.take().unwrap()); +impl std::ops::Deref for ManagedConnection { + type Target = T; + fn deref(&self) -> &Self::Target { + self.conn.as_ref().unwrap() } } - -impl std::ops::Deref for PooledClient { - type Target = tokio_postgres::Client; - - fn deref(&self) -> &Self::Target { - self.client.as_ref().unwrap() +impl std::ops::DerefMut for ManagedConnection { + fn deref_mut(&mut self) -> &mut Self::Target { + self.conn.as_mut().unwrap() } } -impl std::ops::DerefMut for PooledClient { - fn deref_mut(&mut self) -> &mut Self::Target { - self.client.as_mut().unwrap() +impl Drop for ManagedConnection { + fn drop(&mut self) { + let conn = self.conn.take().unwrap(); + self.connections + .lock() + .unwrap_or_else(|e| e.into_inner()) + .push(conn); } } -impl ClientPool { - pub fn new() -> ClientPool { - ClientPool { +impl ConnectionPool +where + T: Send, + M: ConnectionManager, +{ + fn new(manager: M) -> Self { + ConnectionPool { connections: Arc::new(Mutex::new(Vec::with_capacity(16))), permits: Arc::new(Semaphore::new(16)), + manager, } } - pub async fn get(&self) -> PooledClient { + pub fn raw(&mut self) -> &mut M { + &mut self.manager + } + + async fn get(&self) -> ManagedConnection { let permit = self.permits.clone().acquire_owned().await.unwrap(); - { + let conn = { let mut slots = self.connections.lock().unwrap_or_else(|e| e.into_inner()); - // Pop connections until we hit a non-closed connection (or there are no - // "possibly open" connections left). 
- while let Some(c) = slots.pop() { - if !c.is_closed() { - return PooledClient { - client: Some(c), - permit, - pool: self.connections.clone(), - }; - } + slots.pop() + }; + if let Some(mut c) = conn { + if self.manager.is_valid(&mut c).await { + return ManagedConnection { + conn: Some(c), + permit, + connections: self.connections.clone(), + }; } } - PooledClient { - client: Some(make_client().await.unwrap()), + let conn = self.manager.open().await; + ManagedConnection { + conn: Some(conn), + connections: self.connections.clone(), permit, - pool: self.connections.clone(), } } } -async fn make_client() -> anyhow::Result { - let db_url = std::env::var("DATABASE_URL").expect("needs DATABASE_URL"); - if db_url.contains("rds.amazonaws.com") { - let cert = &CERTIFICATE_PEM[..]; - let cert = Certificate::from_pem(&cert).context("made certificate")?; - let connector = TlsConnector::builder() - .add_root_certificate(cert) - .build() - .context("built TlsConnector")?; - let connector = MakeTlsConnector::new(connector); - - let (db_client, connection) = match tokio_postgres::connect(&db_url, connector).await { - Ok(v) => v, - Err(e) => { - anyhow::bail!("failed to connect to DB: {}", e); - } - }; - tokio::task::spawn(async move { - if let Err(e) = connection.await { - eprintln!("database connection error: {}", e); - } - }); - - Ok(db_client) - } else { - eprintln!("Warning: Non-TLS connection to non-RDS DB"); - let (db_client, connection) = - match tokio_postgres::connect(&db_url, tokio_postgres::NoTls).await { - Ok(v) => v, - Err(e) => { - anyhow::bail!("failed to connect to DB: {}", e); - } - }; - tokio::spawn(async move { - if let Err(e) = connection.await { - eprintln!("database connection error: {}", e); - } - }); - - Ok(db_client) - } +pub enum Pool { + Sqlite(ConnectionPool), + Postgres(ConnectionPool), } -pub async fn run_migrations(client: &DbClient) -> anyhow::Result<()> { - client - .execute( - "CREATE TABLE IF NOT EXISTS database_versions ( - zero INTEGER PRIMARY KEY, - migration_counter INTEGER - );", - &[], - ) - .await - .context("creating database versioning table")?; - - client - .execute( - "INSERT INTO database_versions (zero, migration_counter) - VALUES (0, 0) - ON CONFLICT DO NOTHING", - &[], - ) - .await - .context("inserting initial database_versions")?; - - let migration_idx: i32 = client - .query_one("SELECT migration_counter FROM database_versions", &[]) - .await - .context("getting migration counter")? 
- .get(0); - let migration_idx = migration_idx as usize; - - for (idx, migration) in MIGRATIONS.iter().enumerate() { - if idx >= migration_idx { - client - .execute(*migration, &[]) - .await - .with_context(|| format!("executing {}th migration", idx))?; - client - .execute( - "UPDATE database_versions SET migration_counter = $1", - &[&(idx as i32 + 1)], - ) - .await - .with_context(|| format!("updating migration counter to {}", idx))?; +impl Pool { + pub async fn connection(&self) -> Box { + match self { + Pool::Sqlite(p) => Box::new(sqlite::SqliteConnection::new(p.get().await)), + Pool::Postgres(p) => Box::new(p.get().await), } } - Ok(()) -} - -pub async fn schedule_jobs(db: &DbClient, jobs: Vec) -> anyhow::Result<()> { - for job in jobs { - let mut upcoming = job.schedule.upcoming(Utc).take(1); - - if let Some(scheduled_at) = upcoming.next() { - if let Err(_) = get_job_by_name_and_scheduled_at(&db, &job.name, &scheduled_at).await { - // mean there's no job already in the db with that name and scheduled_at - insert_job(&db, &job.name, &scheduled_at, &job.metadata).await?; - } + pub fn open(uri: &str) -> Pool { + if uri.starts_with("postgres") { + Pool::Postgres(ConnectionPool::new(postgres::Postgres::new(uri.into()))) + } else { + Pool::Sqlite(ConnectionPool::new(sqlite::Sqlite::new(uri.into()))) } } - Ok(()) -} - -pub async fn run_scheduled_jobs(ctx: &Context, db: &DbClient) -> anyhow::Result<()> { - let jobs = get_jobs_to_execute(&db).await.unwrap(); - tracing::trace!("jobs to execute: {:#?}", jobs); - - for job in jobs.iter() { - update_job_executed_at(&db, &job.id).await?; - - match handle_job(&ctx, &job.name, &job.metadata).await { - Ok(_) => { - tracing::trace!("job successfully executed (id={})", job.id); - delete_job(&db, &job.id).await?; - } - Err(e) => { - tracing::error!("job failed on execution (id={:?}, error={:?})", job.id, e); - update_job_error_message(&db, &job.id, &e.to_string()).await?; - } - } + pub fn new_from_env() -> Pool { + Self::open(&std::env::var("DATABASE_URL").expect("needs DATABASE_URL")) } - - Ok(()) } - -static MIGRATIONS: &[&str] = &[ - " -CREATE TABLE notifications ( - notification_id BIGSERIAL PRIMARY KEY, - user_id BIGINT, - origin_url TEXT NOT NULL, - origin_html TEXT, - time TIMESTAMP WITH TIME ZONE -); -", - " -CREATE TABLE users ( - user_id BIGINT PRIMARY KEY, - username TEXT NOT NULL -); -", - "ALTER TABLE notifications ADD COLUMN short_description TEXT;", - "ALTER TABLE notifications ADD COLUMN team_name TEXT;", - "ALTER TABLE notifications ADD COLUMN idx INTEGER;", - "ALTER TABLE notifications ADD COLUMN metadata TEXT;", - " -CREATE TABLE rustc_commits ( - sha TEXT PRIMARY KEY, - parent_sha TEXT NOT NULL, - time TIMESTAMP WITH TIME ZONE -); -", - "ALTER TABLE rustc_commits ADD COLUMN pr INTEGER;", - " -CREATE TABLE issue_data ( - repo TEXT, - issue_number INTEGER, - key TEXT, - data JSONB, - PRIMARY KEY (repo, issue_number, key) -); -", - " -CREATE TABLE jobs ( - id UUID DEFAULT gen_random_uuid() PRIMARY KEY, - name TEXT NOT NULL, - scheduled_at TIMESTAMP WITH TIME ZONE NOT NULL, - metadata JSONB, - executed_at TIMESTAMP WITH TIME ZONE, - error_message TEXT -); -", - " -CREATE UNIQUE INDEX jobs_name_scheduled_at_unique_index - ON jobs ( - name, scheduled_at - ); -", -]; diff --git a/src/db/issue_data.rs b/src/db/issue_data.rs index 4f2d43a0..bad50181 100644 --- a/src/db/issue_data.rs +++ b/src/db/issue_data.rs @@ -8,17 +8,15 @@ //! Note that this uses crude locking, so try to keep the duration between //! loading and saving to a minimum. 
-use crate::github::Issue; -use anyhow::{Context, Result}; +use crate::db; +use anyhow::Result; use serde::{Deserialize, Serialize}; -use tokio_postgres::types::Json; -use tokio_postgres::{Client as DbClient, Transaction}; pub struct IssueData<'db, T> where T: for<'a> Deserialize<'a> + Serialize + Default + std::fmt::Debug + Sync, { - transaction: Transaction<'db>, + transaction: Box, repo: String, issue_number: i32, key: String, @@ -30,27 +28,18 @@ where T: for<'a> Deserialize<'a> + Serialize + Default + std::fmt::Debug + Sync, { pub async fn load( - db: &'db mut DbClient, - issue: &Issue, + connection: &'db mut dyn db::Connection, + repo: String, + issue_number: i32, key: &str, ) -> Result> { - let repo = issue.repository().to_string(); - let issue_number = issue.number as i32; - let transaction = db.transaction().await?; - transaction - .execute("LOCK TABLE issue_data", &[]) - .await - .context("locking issue data")?; - let data = transaction - .query_opt( - "SELECT data FROM issue_data WHERE \ - repo = $1 AND issue_number = $2 AND key = $3", - &[&repo, &issue_number, &key], - ) - .await - .context("selecting issue data")? - .map(|row| row.get::>(0).0) - .unwrap_or_default(); + let (transaction, raw) = connection + .lock_and_load_issue_data(&repo, issue_number, key) + .await?; + let data = match raw { + Some(raw) => T::deserialize(raw)?, + None => T::default(), + }; Ok(IssueData { transaction, repo, @@ -60,20 +49,13 @@ where }) } - pub async fn save(self) -> Result<()> { + pub async fn save(mut self) -> Result<()> { + let raw_data = serde_json::to_value(self.data)?; self.transaction - .execute( - "INSERT INTO issue_data (repo, issue_number, key, data) \ - VALUES ($1, $2, $3, $4) \ - ON CONFLICT (repo, issue_number, key) DO UPDATE SET data=EXCLUDED.data", - &[&self.repo, &self.issue_number, &self.key, &Json(&self.data)], - ) - .await - .context("inserting issue data")?; - self.transaction - .commit() - .await - .context("committing issue data")?; + .conn() + .save_issue_data(&self.repo, self.issue_number, &self.key, &raw_data) + .await?; + self.transaction.commit().await?; Ok(()) } } diff --git a/src/db/jobs.rs b/src/db/jobs.rs index 5db66b0f..51621d3e 100644 --- a/src/db/jobs.rs +++ b/src/db/jobs.rs @@ -1,9 +1,13 @@ //! 
The `jobs` table provides a way to have scheduled jobs -use anyhow::{Context as _, Result}; -use chrono::{DateTime, Utc}; + +use crate::db::Connection; +use crate::handlers::jobs::handle_job; +use crate::handlers::Context; +use anyhow::Result; +use chrono::DateTime; +use chrono::Utc; use cron::Schedule; use serde::{Deserialize, Serialize}; -use tokio_postgres::Client as DbClient; use uuid::Uuid; pub struct JobSchedule { @@ -22,116 +26,46 @@ pub struct Job { pub error_message: Option, } -pub async fn insert_job( - db: &DbClient, - name: &String, - scheduled_at: &DateTime, - metadata: &serde_json::Value, -) -> Result<()> { - tracing::trace!("insert_job(name={})", name); - - db.execute( - "INSERT INTO jobs (name, scheduled_at, metadata) VALUES ($1, $2, $3) - ON CONFLICT (name, scheduled_at) DO UPDATE SET metadata = EXCLUDED.metadata", - &[&name, &scheduled_at, &metadata], - ) - .await - .context("Inserting job")?; - - Ok(()) -} - -pub async fn delete_job(db: &DbClient, id: &Uuid) -> Result<()> { - tracing::trace!("delete_job(id={})", id); - - db.execute("DELETE FROM jobs WHERE id = $1", &[&id]) - .await - .context("Deleting job")?; - - Ok(()) -} - -pub async fn update_job_error_message(db: &DbClient, id: &Uuid, message: &String) -> Result<()> { - tracing::trace!("update_job_error_message(id={})", id); - - db.execute( - "UPDATE jobs SET error_message = $2 WHERE id = $1", - &[&id, &message], - ) - .await - .context("Updating job error message")?; - - Ok(()) -} - -pub async fn update_job_executed_at(db: &DbClient, id: &Uuid) -> Result<()> { - tracing::trace!("update_job_executed_at(id={})", id); - - db.execute("UPDATE jobs SET executed_at = now() WHERE id = $1", &[&id]) - .await - .context("Updating job executed at")?; - - Ok(()) -} - -pub async fn get_job_by_name_and_scheduled_at( - db: &DbClient, - name: &String, - scheduled_at: &DateTime, -) -> Result { - tracing::trace!( - "get_job_by_name_and_scheduled_at(name={}, scheduled_at={})", - name, - scheduled_at - ); - - let job = db - .query_one( - "SELECT * FROM jobs WHERE name = $1 AND scheduled_at = $2", - &[&name, &scheduled_at], - ) - .await - .context("Select job by name and scheduled at")?; - - deserialize_job(&job) -} - -// Selects all jobs with: -// - scheduled_at in the past -// - error_message is null or executed_at is at least 60 minutes ago (intended to make repeat executions rare enough) -pub async fn get_jobs_to_execute(db: &DbClient) -> Result> { - let jobs = db - .query( - " - SELECT * FROM jobs WHERE scheduled_at <= now() AND (error_message IS NULL OR executed_at <= now() - INTERVAL '60 minutes')", - &[], - ) - .await - .context("Getting jobs data")?; - - let mut data = Vec::with_capacity(jobs.len()); +pub async fn schedule_jobs(connection: &mut dyn Connection, jobs: Vec) -> Result<()> { for job in jobs { - let serialized_job = deserialize_job(&job); - data.push(serialized_job.unwrap()); + let mut upcoming = job.schedule.upcoming(Utc).take(1); + + if let Some(scheduled_at) = upcoming.next() { + if let Err(_) = connection + .get_job_by_name_and_scheduled_at(&job.name, &scheduled_at) + .await + { + // mean there's no job already in the db with that name and scheduled_at + connection + .insert_job(&job.name, &scheduled_at, &job.metadata) + .await?; + } + } } - Ok(data) + Ok(()) } -fn deserialize_job(row: &tokio_postgres::row::Row) -> Result { - let id: Uuid = row.try_get(0)?; - let name: String = row.try_get(1)?; - let scheduled_at: DateTime = row.try_get(2)?; - let metadata: serde_json::Value = row.try_get(3)?; - let executed_at: 
Option> = row.try_get(4)?; - let error_message: Option = row.try_get(5)?; +pub async fn run_scheduled_jobs(ctx: &Context, connection: &mut dyn Connection) -> Result<()> { + let jobs = connection.get_jobs_to_execute().await.unwrap(); + tracing::trace!("jobs to execute: {:#?}", jobs); + + for job in jobs.iter() { + connection.update_job_executed_at(&job.id).await?; + + match handle_job(&ctx, &job.name, &job.metadata).await { + Ok(_) => { + tracing::trace!("job successfully executed (id={})", job.id); + connection.delete_job(&job.id).await?; + } + Err(e) => { + tracing::error!("job failed on execution (id={:?}, error={:?})", job.id, e); + connection + .update_job_error_message(&job.id, &e.to_string()) + .await?; + } + } + } - Ok(Job { - id, - name, - scheduled_at, - metadata, - executed_at, - error_message, - }) + Ok(()) } diff --git a/src/db/notifications.rs b/src/db/notifications.rs index 5b185793..a675f089 100644 --- a/src/db/notifications.rs +++ b/src/db/notifications.rs @@ -1,8 +1,10 @@ -use anyhow::Context as _; +//! Database support for the notifications feature for tracking `@` mentions. +//! +//! See + use chrono::{DateTime, FixedOffset}; -use tokio_postgres::Client as DbClient; -use tracing as log; +/// Tracking `@` mentions for users in issues/PRs. pub struct Notification { pub user_id: i64, pub origin_url: String, @@ -15,160 +17,7 @@ pub struct Notification { pub team_name: Option, } -pub async fn record_username(db: &DbClient, user_id: i64, username: String) -> anyhow::Result<()> { - db.execute( - "INSERT INTO users (user_id, username) VALUES ($1, $2) ON CONFLICT DO NOTHING", - &[&user_id, &username], - ) - .await - .context("inserting user id / username")?; - Ok(()) -} - -pub async fn record_ping(db: &DbClient, notification: &Notification) -> anyhow::Result<()> { - db.execute("INSERT INTO notifications (user_id, origin_url, origin_html, time, short_description, team_name, idx) - VALUES ( - $1, $2, $3, $4, $5, $6, - (SELECT max(notifications.idx) + 1 from notifications where notifications.user_id = $1) - )", - &[¬ification.user_id, ¬ification.origin_url, ¬ification.origin_html, ¬ification.time, ¬ification.short_description, ¬ification.team_name], - ).await.context("inserting notification")?; - - Ok(()) -} - -#[derive(Copy, Clone)] -pub enum Identifier<'a> { - Url(&'a str), - Index(std::num::NonZeroUsize), - /// Glob identifier (`all` or `*`). 
- All, -} - -pub async fn delete_ping( - db: &mut DbClient, - user_id: i64, - identifier: Identifier<'_>, -) -> anyhow::Result> { - match identifier { - Identifier::Url(origin_url) => { - let rows = db - .query( - "DELETE FROM notifications WHERE user_id = $1 and origin_url = $2 - RETURNING origin_html, time, short_description, metadata", - &[&user_id, &origin_url], - ) - .await - .context("delete notification query")?; - Ok(rows - .into_iter() - .map(|row| { - let origin_text: String = row.get(0); - let time: DateTime = row.get(1); - let short_description: Option = row.get(2); - let metadata: Option = row.get(3); - NotificationData { - origin_url: origin_url.to_owned(), - origin_text, - time, - short_description, - metadata, - } - }) - .collect()) - } - Identifier::Index(idx) => loop { - let t = db - .build_transaction() - .isolation_level(tokio_postgres::IsolationLevel::Serializable) - .start() - .await - .context("begin transaction")?; - - let notifications = t - .query( - "select notification_id, idx, user_id - from notifications - where user_id = $1 - order by idx asc nulls last;", - &[&user_id], - ) - .await - .context("failed to get ordering")?; - - let notification_id: i64 = notifications - .get(idx.get() - 1) - .ok_or_else(|| anyhow::anyhow!("No such notification with index {}", idx.get()))? - .get(0); - - let row = t - .query_one( - "DELETE FROM notifications WHERE notification_id = $1 - RETURNING origin_url, origin_html, time, short_description, metadata", - &[¬ification_id], - ) - .await - .context(format!( - "Failed to delete notification with id {}", - notification_id - ))?; - - let origin_url: String = row.get(0); - let origin_text: String = row.get(1); - let time: DateTime = row.get(2); - let short_description: Option = row.get(3); - let metadata: Option = row.get(4); - let deleted_notification = NotificationData { - origin_url, - origin_text, - time, - short_description, - metadata, - }; - - if let Err(e) = t.commit().await { - if e.code().map_or(false, |c| { - *c == tokio_postgres::error::SqlState::T_R_SERIALIZATION_FAILURE - }) { - log::trace!("serialization failure, restarting deletion"); - continue; - } else { - return Err(e).context("transaction commit failure"); - } - } else { - return Ok(vec![deleted_notification]); - } - }, - Identifier::All => { - let rows = db - .query( - "DELETE FROM notifications WHERE user_id = $1 - RETURNING origin_url, origin_html, time, short_description, metadata", - &[&user_id], - ) - .await - .context("delete all notifications query")?; - Ok(rows - .into_iter() - .map(|row| { - let origin_url: String = row.get(0); - let origin_text: String = row.get(1); - let time: DateTime = row.get(2); - let short_description: Option = row.get(3); - let metadata: Option = row.get(4); - NotificationData { - origin_url, - origin_text, - time, - short_description, - metadata, - } - }) - .collect()) - } - } -} - +/// Metadata associated with an `@` notification that the user can set via Zulip. 
#[derive(Debug)] pub struct NotificationData { pub origin_url: String, @@ -178,179 +27,11 @@ pub struct NotificationData { pub metadata: Option, } -pub async fn move_indices( - db: &mut DbClient, - user_id: i64, - from: usize, - to: usize, -) -> anyhow::Result<()> { - loop { - let t = db - .build_transaction() - .isolation_level(tokio_postgres::IsolationLevel::Serializable) - .start() - .await - .context("begin transaction")?; - - let notifications = t - .query( - "select notification_id, idx, user_id - from notifications - where user_id = $1 - order by idx asc nulls last;", - &[&user_id], - ) - .await - .context("failed to get initial ordering")?; - - let mut notifications = notifications - .into_iter() - .map(|n| n.get(0)) - .collect::>(); - - if notifications.get(from).is_none() { - anyhow::bail!( - "`from` index not present, must be less than {}", - notifications.len() - ); - } - - if notifications.get(to).is_none() { - anyhow::bail!( - "`to` index not present, must be less than {}", - notifications.len() - ); - } - - if from < to { - notifications[from..=to].rotate_left(1); - } else if to < from { - notifications[to..=from].rotate_right(1); - } - - for (idx, id) in notifications.into_iter().enumerate() { - t.execute( - "update notifications SET idx = $2 - where notification_id = $1", - &[&id, &(idx as i32)], - ) - .await - .context("update notification id")?; - } - - if let Err(e) = t.commit().await { - if e.code().map_or(false, |c| { - *c == tokio_postgres::error::SqlState::T_R_SERIALIZATION_FAILURE - }) { - log::trace!("serialization failure, restarting index movement"); - continue; - } else { - return Err(e).context("transaction commit failure"); - } - } else { - break; - } - } - - Ok(()) -} - -pub async fn add_metadata( - db: &mut DbClient, - user_id: i64, - idx: usize, - metadata: Option<&str>, -) -> anyhow::Result<()> { - loop { - let t = db - .build_transaction() - .isolation_level(tokio_postgres::IsolationLevel::Serializable) - .start() - .await - .context("begin transaction")?; - - let notifications = t - .query( - "select notification_id, idx, user_id - from notifications - where user_id = $1 - order by idx asc nulls last;", - &[&user_id], - ) - .await - .context("failed to get initial ordering")?; - - let notifications = notifications - .into_iter() - .map(|n| n.get(0)) - .collect::>(); - - match notifications.get(idx) { - None => anyhow::bail!( - "index not present, must be less than {}", - notifications.len() - ), - Some(id) => { - t.execute( - "update notifications SET metadata = $2 - where notification_id = $1", - &[&id, &metadata], - ) - .await - .context("update notification id")?; - } - } - - if let Err(e) = t.commit().await { - if e.code().map_or(false, |c| { - *c == tokio_postgres::error::SqlState::T_R_SERIALIZATION_FAILURE - }) { - log::trace!("serialization failure, restarting index movement"); - continue; - } else { - return Err(e).context("transaction commit failure"); - } - } else { - break; - } - } - - Ok(()) -} - -pub async fn get_notifications( - db: &DbClient, - username: &str, -) -> anyhow::Result> { - let notifications = db - .query( - " - select username, origin_url, origin_html, time, short_description, idx, metadata - from notifications - join users on notifications.user_id = users.user_id - where username = $1 - order by notifications.idx asc nulls last;", - &[&username], - ) - .await - .context("Getting notification data")?; - - let mut data = Vec::new(); - for notification in notifications { - let origin_url: String = notification.get(1); - let 
origin_text: String = notification.get(2); - let time: DateTime = notification.get(3); - let short_description: Option = notification.get(4); - let metadata: Option = notification.get(6); - - data.push(NotificationData { - origin_url, - origin_text, - short_description, - time, - metadata, - }); - } - - Ok(data) +/// Selector for deleting `@` notifications. +#[derive(Copy, Clone)] +pub enum Identifier<'a> { + Url(&'a str), + Index(std::num::NonZeroUsize), + /// Glob identifier (`all` or `*`). + All, } diff --git a/src/db/postgres.rs b/src/db/postgres.rs new file mode 100644 index 00000000..b9b96d8c --- /dev/null +++ b/src/db/postgres.rs @@ -0,0 +1,841 @@ +use super::{Commit, Identifier, Job, Notification, NotificationData}; +use crate::db::{Connection, ConnectionManager, ManagedConnection, Transaction}; +use anyhow::Context as _; +use anyhow::Result; +use chrono::Utc; +use chrono::{DateTime, FixedOffset}; +use native_tls::{Certificate, TlsConnector}; +use postgres_native_tls::MakeTlsConnector; +use tokio_postgres::types::Json; +use tokio_postgres::{GenericClient, TransactionBuilder}; +use tracing::trace; +use uuid::Uuid; + +pub struct Postgres(String, std::sync::Once); + +impl Postgres { + pub fn new(url: String) -> Self { + Postgres(url, std::sync::Once::new()) + } +} + +const CERT_URL: &str = "https://s3.amazonaws.com/rds-downloads/rds-ca-2019-root.pem"; + +lazy_static::lazy_static! { + static ref CERTIFICATE_PEM: Vec = { + let client = reqwest::blocking::Client::new(); + let resp = client + .get(CERT_URL) + .send() + .expect("failed to get RDS cert"); + resp.bytes().expect("failed to get RDS cert body").to_vec() + }; +} + +async fn make_client(db_url: &str) -> Result { + if db_url.contains("rds.amazonaws.com") { + let cert = &CERTIFICATE_PEM[..]; + let cert = Certificate::from_pem(&cert).context("made certificate")?; + let connector = TlsConnector::builder() + .add_root_certificate(cert) + .build() + .context("built TlsConnector")?; + let connector = MakeTlsConnector::new(connector); + + let (db_client, connection) = match tokio_postgres::connect(&db_url, connector).await { + Ok(v) => v, + Err(e) => { + anyhow::bail!("failed to connect to DB: {}", e); + } + }; + tokio::spawn(async move { + if let Err(e) = connection.await { + eprintln!("database connection error: {}", e); + } + }); + + Ok(db_client) + } else { + eprintln!("Warning: Non-TLS connection to non-RDS DB"); + let (db_client, connection) = + match tokio_postgres::connect(&db_url, tokio_postgres::NoTls).await { + Ok(v) => v, + Err(e) => { + anyhow::bail!("failed to connect to DB: {}", e); + } + }; + tokio::spawn(async move { + if let Err(e) = connection.await { + eprintln!("database connection error: {}", e); + } + }); + + Ok(db_client) + } +} + +static MIGRATIONS: &[&str] = &[ + " +CREATE TABLE notifications ( + notification_id BIGSERIAL PRIMARY KEY, + user_id BIGINT, + origin_url TEXT NOT NULL, + origin_html TEXT, + time TIMESTAMP WITH TIME ZONE +); +", + " +CREATE TABLE users ( + user_id BIGINT PRIMARY KEY, + username TEXT NOT NULL +); +", + "ALTER TABLE notifications ADD COLUMN short_description TEXT;", + "ALTER TABLE notifications ADD COLUMN team_name TEXT;", + "ALTER TABLE notifications ADD COLUMN idx INTEGER;", + "ALTER TABLE notifications ADD COLUMN metadata TEXT;", + " +CREATE TABLE rustc_commits ( + sha TEXT PRIMARY KEY, + parent_sha TEXT NOT NULL, + time TIMESTAMP WITH TIME ZONE +); +", + "ALTER TABLE rustc_commits ADD COLUMN pr INTEGER;", + " +CREATE TABLE issue_data ( + repo TEXT, + issue_number INTEGER, + key TEXT, 
+ data JSONB, + PRIMARY KEY (repo, issue_number, key) +); +", + " +CREATE TABLE jobs ( + id UUID DEFAULT gen_random_uuid() PRIMARY KEY, + name TEXT NOT NULL, + scheduled_at TIMESTAMP WITH TIME ZONE NOT NULL, + metadata JSONB, + executed_at TIMESTAMP WITH TIME ZONE, + error_message TEXT +); +", + " +CREATE UNIQUE INDEX jobs_name_scheduled_at_unique_index + ON jobs ( + name, scheduled_at + ); +", +]; + +#[async_trait::async_trait] +impl ConnectionManager for Postgres { + type Connection = PostgresConnection; + async fn open(&self) -> Self::Connection { + let client = make_client(&self.0).await.unwrap(); + let mut should_init = false; + self.1.call_once(|| { + should_init = true; + }); + if should_init { + run_migrations(&client).await.unwrap(); + } + PostgresConnection::new(client).await + } + async fn is_valid(&self, conn: &mut Self::Connection) -> bool { + !conn.conn.is_closed() + } +} + +pub async fn run_migrations(client: &tokio_postgres::Client) -> Result<()> { + client + .execute( + "CREATE TABLE IF NOT EXISTS database_versions ( + zero INTEGER PRIMARY KEY, + migration_counter INTEGER + );", + &[], + ) + .await + .context("creating database versioning table")?; + + client + .execute( + "INSERT INTO database_versions (zero, migration_counter) + VALUES (0, 0) + ON CONFLICT DO NOTHING", + &[], + ) + .await + .context("inserting initial database_versions")?; + + let migration_idx: i32 = client + .query_one("SELECT migration_counter FROM database_versions", &[]) + .await + .context("getting migration counter")? + .get(0); + let migration_idx = migration_idx as usize; + + for (idx, migration) in MIGRATIONS.iter().enumerate() { + if idx >= migration_idx { + client + .execute(*migration, &[]) + .await + .with_context(|| format!("executing {}th migration", idx))?; + client + .execute( + "UPDATE database_versions SET migration_counter = $1", + &[&(idx as i32 + 1)], + ) + .await + .with_context(|| format!("updating migration counter to {}", idx))?; + } + } + + Ok(()) +} + +#[async_trait::async_trait] +impl<'a> Transaction for PostgresTransaction<'a> { + async fn commit(self: Box) -> Result<(), anyhow::Error> { + Ok(self.conn.commit().await?) + } + async fn finish(self: Box) -> Result<(), anyhow::Error> { + Ok(self.conn.rollback().await?) 
+ } + fn conn(&mut self) -> &mut dyn Connection { + self + } + fn conn_ref(&self) -> &dyn Connection { + self + } +} + +pub struct PostgresTransaction<'a> { + conn: tokio_postgres::Transaction<'a>, +} + +pub struct PostgresConnection { + conn: tokio_postgres::Client, +} + +impl Into for PostgresConnection { + fn into(self) -> tokio_postgres::Client { + self.conn + } +} + +pub trait PClient { + type Client: Send + Sync + tokio_postgres::GenericClient; + fn conn(&self) -> &Self::Client; + fn conn_mut(&mut self) -> &mut Self::Client; + fn build_transaction(&mut self) -> TransactionBuilder; +} + +impl<'a> PClient for PostgresTransaction<'a> { + type Client = tokio_postgres::Transaction<'a>; + fn conn(&self) -> &Self::Client { + &self.conn + } + fn conn_mut(&mut self) -> &mut Self::Client { + &mut self.conn + } + fn build_transaction(&mut self) -> TransactionBuilder { + panic!("nested transactions not supported"); + } +} + +impl PClient for ManagedConnection { + type Client = tokio_postgres::Client; + fn conn(&self) -> &Self::Client { + &(&**self).conn + } + fn conn_mut(&mut self) -> &mut Self::Client { + &mut (&mut **self).conn + } + fn build_transaction(&mut self) -> TransactionBuilder { + self.conn_mut().build_transaction() + } +} + +impl PostgresConnection { + pub async fn new(conn: tokio_postgres::Client) -> Self { + PostgresConnection { conn } + } +} + +#[async_trait::async_trait] +impl
<P>
Connection for P +where + P: Send + Sync + PClient, +{ + async fn transaction(&mut self) -> Box { + let tx = self.conn_mut().transaction().await.unwrap(); + Box::new(PostgresTransaction { conn: tx }) + } + + async fn record_username(&mut self, user_id: i64, username: String) -> Result<()> { + self.conn() + .execute( + "INSERT INTO users (user_id, username) VALUES ($1, $2) ON CONFLICT DO NOTHING", + &[&user_id, &username], + ) + .await + .context("inserting user id / username")?; + Ok(()) + } + + async fn record_ping(&mut self, notification: &Notification) -> Result<()> { + self.conn() + .execute( + "INSERT INTO notifications + (user_id, origin_url, origin_html, time, short_description, team_name, idx) + VALUES ( + $1, $2, $3, $4, $5, $6, + (SELECT max(notifications.idx) + 1 from notifications + where notifications.user_id = $1) + )", + &[ + ¬ification.user_id, + ¬ification.origin_url, + ¬ification.origin_html, + ¬ification.time, + ¬ification.short_description, + ¬ification.team_name, + ], + ) + .await + .context("inserting notification")?; + + Ok(()) + } + + async fn get_missing_commits(&mut self) -> Result> { + let missing = self + .conn() + .query( + " + SELECT parent_sha + FROM rustc_commits + WHERE parent_sha NOT IN ( + SELECT sha + FROM rustc_commits + )", + &[], + ) + .await + .context("fetching missing commits")?; + Ok(missing.into_iter().map(|row| row.get(0)).collect()) + } + + async fn record_commit(&mut self, commit: &Commit) -> Result<()> { + trace!("record_commit(sha={})", commit.sha); + let pr = commit.pr.expect("commit has pr"); + self.conn() + .execute( + "INSERT INTO rustc_commits (sha, parent_sha, time, pr) + VALUES ($1, $2, $3, $4) ON CONFLICT DO NOTHING", + &[&commit.sha, &commit.parent_sha, &commit.time, &(pr as i32)], + ) + .await + .context("inserting commit")?; + Ok(()) + } + + async fn has_commit(&mut self, sha: &str) -> Result { + self.conn() + .query("SELECT 1 FROM rustc_commits WHERE sha = $1", &[&sha]) + .await + .context("selecting from rustc_commits") + .map(|commits| !commits.is_empty()) + } + + async fn get_commits_with_artifacts(&mut self) -> Result> { + let commits = self + .conn() + .query( + " + select sha, parent_sha, time, pr + from rustc_commits + where time >= current_date - interval '168 days' + order by time desc;", + &[], + ) + .await + .context("Getting commit data")?; + + let mut data = Vec::with_capacity(commits.len()); + for commit in commits { + let sha: String = commit.get(0); + let parent_sha: String = commit.get(1); + let time: DateTime = commit.get(2); + let pr: Option = commit.get(3); + + data.push(Commit { + sha, + parent_sha, + time, + pr: pr.map(|n| n as u32), + }); + } + + Ok(data) + } + + async fn get_notifications(&mut self, username: &str) -> Result> { + let notifications = self + .conn() + .query( + "SELECT username, origin_url, origin_html, time, short_description, idx, metadata + FROM notifications + JOIN users ON notifications.user_id = users.user_id + WHERE username = $1 + ORDER BY notifications.idx ASC NULLS LAST;", + &[&username], + ) + .await + .context("Getting notification data")?; + + let mut data = Vec::new(); + for notification in notifications { + let origin_url: String = notification.get(1); + let origin_text: String = notification.get(2); + let time: DateTime = notification.get(3); + let short_description: Option = notification.get(4); + let metadata: Option = notification.get(6); + + data.push(NotificationData { + origin_url, + origin_text, + short_description, + time, + metadata, + }); + } + + Ok(data) + } + + 
async fn delete_ping( + &mut self, + user_id: i64, + identifier: Identifier<'_>, + ) -> Result> { + match identifier { + Identifier::Url(origin_url) => { + let rows = self + .conn() + .query( + "DELETE FROM notifications WHERE user_id = $1 and origin_url = $2 + RETURNING origin_html, time, short_description, metadata", + &[&user_id, &origin_url], + ) + .await + .context("delete notification query")?; + Ok(rows + .into_iter() + .map(|row| { + let origin_text: String = row.get(0); + let time: DateTime = row.get(1); + let short_description: Option = row.get(2); + let metadata: Option = row.get(3); + NotificationData { + origin_url: origin_url.to_owned(), + origin_text, + time, + short_description, + metadata, + } + }) + .collect()) + } + Identifier::Index(idx) => loop { + let t = self + .build_transaction() + .isolation_level(tokio_postgres::IsolationLevel::Serializable) + .start() + .await + .context("begin transaction")?; + + let notifications = t + .query( + "SELECT notification_id, idx, user_id + FROM notifications + WHERE user_id = $1 + ORDER BY idx ASC NULLS LAST;", + &[&user_id], + ) + .await + .context("failed to get ordering")?; + + let notification_id: i64 = notifications + .get(idx.get() - 1) + .ok_or_else(|| { + anyhow::anyhow!("No such notification with index {}", idx.get()) + })? + .get(0); + + let row = t + .query_one( + "DELETE FROM notifications WHERE notification_id = $1 + RETURNING origin_url, origin_html, time, short_description, metadata", + &[¬ification_id], + ) + .await + .context(format!( + "Failed to delete notification with id {}", + notification_id + ))?; + + let origin_url: String = row.get(0); + let origin_text: String = row.get(1); + let time: DateTime = row.get(2); + let short_description: Option = row.get(3); + let metadata: Option = row.get(4); + let deleted_notification = NotificationData { + origin_url, + origin_text, + time, + short_description, + metadata, + }; + + if let Err(e) = t.commit().await { + if e.code().map_or(false, |c| { + *c == tokio_postgres::error::SqlState::T_R_SERIALIZATION_FAILURE + }) { + trace!("serialization failure, restarting deletion"); + continue; + } else { + return Err(e).context("transaction commit failure"); + } + } else { + return Ok(vec![deleted_notification]); + } + }, + Identifier::All => { + let rows = self + .conn() + .query( + "DELETE FROM notifications WHERE user_id = $1 + RETURNING origin_url, origin_html, time, short_description, metadata", + &[&user_id], + ) + .await + .context("delete all notifications query")?; + Ok(rows + .into_iter() + .map(|row| { + let origin_url: String = row.get(0); + let origin_text: String = row.get(1); + let time: DateTime = row.get(2); + let short_description: Option = row.get(3); + let metadata: Option = row.get(4); + NotificationData { + origin_url, + origin_text, + time, + short_description, + metadata, + } + }) + .collect()) + } + } + } + + async fn add_metadata( + &mut self, + user_id: i64, + idx: usize, + metadata: Option<&str>, + ) -> Result<()> { + loop { + let t = self + .build_transaction() + .isolation_level(tokio_postgres::IsolationLevel::Serializable) + .start() + .await + .context("begin transaction")?; + + let notifications = t + .query( + "SELECT notification_id, idx, user_id + FROM notifications + WHERE user_id = $1 + ORDER BY idx ASC NULLS LAST;", + &[&user_id], + ) + .await + .context("failed to get initial ordering")?; + + let notifications = notifications + .into_iter() + .map(|n| n.get(0)) + .collect::>(); + + match notifications.get(idx) { + None => anyhow::bail!( + 
"index not present, must be less than {}", + notifications.len() + ), + Some(id) => { + t.execute( + "UPDATE notifications SET metadata = $2 + WHERE notification_id = $1", + &[&id, &metadata], + ) + .await + .context("update notification id")?; + } + } + + if let Err(e) = t.commit().await { + if e.code().map_or(false, |c| { + *c == tokio_postgres::error::SqlState::T_R_SERIALIZATION_FAILURE + }) { + trace!("serialization failure, restarting index movement"); + continue; + } else { + return Err(e).context("transaction commit failure"); + } + } else { + break; + } + } + + Ok(()) + } + + async fn move_indices(&mut self, user_id: i64, from: usize, to: usize) -> Result<()> { + loop { + let t = self + .build_transaction() + .isolation_level(tokio_postgres::IsolationLevel::Serializable) + .start() + .await + .context("begin transaction")?; + + let notifications = t + .query( + "SELECT notification_id, idx, user_id + FROM notifications + WHERE user_id = $1 + ORDER BY idx ASC NULLS LAST;", + &[&user_id], + ) + .await + .context("failed to get initial ordering")?; + + let mut notifications = notifications + .into_iter() + .map(|n| n.get(0)) + .collect::>(); + + if notifications.get(from).is_none() { + anyhow::bail!( + "`from` index not present, must be less than {}", + notifications.len() + ); + } + + if notifications.get(to).is_none() { + anyhow::bail!( + "`to` index not present, must be less than {}", + notifications.len() + ); + } + + if from < to { + notifications[from..=to].rotate_left(1); + } else if to < from { + notifications[to..=from].rotate_right(1); + } + + for (idx, id) in notifications.into_iter().enumerate() { + t.execute( + "UPDATE notifications SET idx = $2 + WHERE notification_id = $1", + &[&id, &(idx as i32)], + ) + .await + .context("update notification id")?; + } + + if let Err(e) = t.commit().await { + if e.code().map_or(false, |c| { + *c == tokio_postgres::error::SqlState::T_R_SERIALIZATION_FAILURE + }) { + trace!("serialization failure, restarting index movement"); + continue; + } else { + return Err(e).context("transaction commit failure"); + } + } else { + break; + } + } + + Ok(()) + } + + async fn insert_job( + &mut self, + name: &str, + scheduled_at: &DateTime, + metadata: &serde_json::Value, + ) -> Result<()> { + tracing::trace!("insert_job(name={})", name); + + self.conn() + .execute( + "INSERT INTO jobs (name, scheduled_at, metadata) VALUES ($1, $2, $3) + ON CONFLICT (name, scheduled_at) DO UPDATE SET metadata = EXCLUDED.metadata", + &[&name, &scheduled_at, &metadata], + ) + .await + .context("Inserting job")?; + + Ok(()) + } + + async fn delete_job(&mut self, id: &Uuid) -> Result<()> { + tracing::trace!("delete_job(id={})", id); + + self.conn() + .execute("DELETE FROM jobs WHERE id = $1", &[id]) + .await + .context("Deleting job")?; + + Ok(()) + } + + async fn update_job_error_message(&mut self, id: &Uuid, message: &str) -> Result<()> { + tracing::trace!("update_job_error_message(id={})", id); + + self.conn() + .execute( + "UPDATE jobs SET error_message = $2 WHERE id = $1", + &[&id, &message], + ) + .await + .context("Updating job error message")?; + + Ok(()) + } + + async fn update_job_executed_at(&mut self, id: &Uuid) -> Result<()> { + tracing::trace!("update_job_executed_at(id={})", id); + + self.conn() + .execute("UPDATE jobs SET executed_at = now() WHERE id = $1", &[&id]) + .await + .context("Updating job executed at")?; + + Ok(()) + } + + async fn get_job_by_name_and_scheduled_at( + &mut self, + name: &str, + scheduled_at: &DateTime, + ) -> Result { + 
tracing::trace!( + "get_job_by_name_and_scheduled_at(name={}, scheduled_at={})", + name, + scheduled_at + ); + + let job = self + .conn() + .query_one( + "SELECT * FROM jobs WHERE name = $1 AND scheduled_at = $2", + &[&name, &scheduled_at], + ) + .await + .context("Select job by name and scheduled at")?; + + deserialize_job(&job) + } + + async fn get_jobs_to_execute(&mut self) -> Result> { + let jobs = self + .conn() + .query( + "SELECT * FROM jobs WHERE scheduled_at <= now() + AND (error_message IS NULL OR executed_at <= now() - INTERVAL '60 minutes')", + &[], + ) + .await + .context("Getting jobs data")?; + + let mut data = Vec::with_capacity(jobs.len()); + for job in jobs { + let serialized_job = deserialize_job(&job); + data.push(serialized_job.unwrap()); + } + + Ok(data) + } + + async fn lock_and_load_issue_data( + &mut self, + repo: &str, + issue_number: i32, + key: &str, + ) -> Result<(Box, Option)> { + let transaction = self.conn_mut().transaction().await?; + transaction + .execute("LOCK TABLE issue_data", &[]) + .await + .context("locking issue data")?; + let data = transaction + .query_opt( + "SELECT data FROM issue_data WHERE \ + repo = $1 AND issue_number = $2 AND key = $3", + &[&repo, &issue_number, &key], + ) + .await + .context("selecting issue data")? + .map(|row| row.get::>(0).0); + Ok((Box::new(PostgresTransaction { conn: transaction }), data)) + } + + async fn save_issue_data( + &mut self, + repo: &str, + issue_number: i32, + key: &str, + data: &serde_json::Value, + ) -> Result<()> { + self.conn() + .execute( + "INSERT INTO issue_data (repo, issue_number, key, data) \ + VALUES ($1, $2, $3, $4) \ + ON CONFLICT (repo, issue_number, key) DO UPDATE SET data=EXCLUDED.data", + &[&repo, &issue_number, &key, &Json(&data)], + ) + .await + .context("inserting issue data")?; + Ok(()) + } +} + +fn deserialize_job(row: &tokio_postgres::row::Row) -> Result { + let id: Uuid = row.try_get(0)?; + let name: String = row.try_get(1)?; + let scheduled_at: DateTime = row.try_get(2)?; + let metadata: serde_json::Value = row.try_get(3)?; + let executed_at: Option> = row.try_get(4)?; + let error_message: Option = row.try_get(5)?; + + Ok(Job { + id, + name, + scheduled_at, + metadata, + executed_at, + error_message, + }) +} diff --git a/src/db/rustc_commits.rs b/src/db/rustc_commits.rs deleted file mode 100644 index b272d44c..00000000 --- a/src/db/rustc_commits.rs +++ /dev/null @@ -1,79 +0,0 @@ -use anyhow::Context as _; -use chrono::{DateTime, FixedOffset}; -use tokio_postgres::Client as DbClient; - -/// A bors merge commit. 
-#[derive(Debug, serde::Serialize)] -pub struct Commit { - pub sha: String, - pub parent_sha: String, - pub time: DateTime, - pub pr: Option, -} - -pub async fn record_commit(db: &DbClient, commit: Commit) -> anyhow::Result<()> { - tracing::trace!("record_commit(sha={})", commit.sha); - let pr = commit.pr.expect("commit has pr"); - db.execute( - "INSERT INTO rustc_commits (sha, parent_sha, time, pr) VALUES ($1, $2, $3, $4) ON CONFLICT DO NOTHING", - &[&commit.sha, &commit.parent_sha, &commit.time, &(pr as i32)], - ) - .await - .context("inserting commit")?; - Ok(()) -} - -pub async fn has_commit(db: &DbClient, sha: &str) -> bool { - !db.query("SELECT 1 FROM rustc_commits WHERE sha = $1", &[&sha]) - .await - .unwrap() - .is_empty() -} - -pub async fn get_missing_commits(db: &DbClient) -> Vec { - let missing = db - .query( - " - SELECT parent_sha - FROM rustc_commits - WHERE parent_sha NOT IN ( - SELECT sha - FROM rustc_commits - )", - &[], - ) - .await - .unwrap(); - missing.into_iter().map(|row| row.get(0)).collect() -} - -pub async fn get_commits_with_artifacts(db: &DbClient) -> anyhow::Result> { - let commits = db - .query( - " - select sha, parent_sha, time, pr - from rustc_commits - where time >= current_date - interval '168 days' - order by time desc;", - &[], - ) - .await - .context("Getting commit data")?; - - let mut data = Vec::with_capacity(commits.len()); - for commit in commits { - let sha: String = commit.get(0); - let parent_sha: String = commit.get(1); - let time: DateTime = commit.get(2); - let pr: Option = commit.get(3); - - data.push(Commit { - sha, - parent_sha, - time, - pr: pr.map(|n| n as u32), - }); - } - - Ok(data) -} diff --git a/src/db/sqlite.rs b/src/db/sqlite.rs new file mode 100644 index 00000000..77267ac1 --- /dev/null +++ b/src/db/sqlite.rs @@ -0,0 +1,742 @@ +use super::{Commit, Identifier, Notification, NotificationData}; +use crate::db::{Connection, ConnectionManager, Job, ManagedConnection, Transaction}; +use anyhow::{Context, Result}; +use chrono::DateTime; +use chrono::Utc; +use rusqlite::params; +use std::path::PathBuf; +use std::sync::Mutex; +use std::sync::Once; +use uuid::Uuid; + +pub struct SqliteTransaction<'a> { + conn: &'a mut SqliteConnection, + finished: bool, +} + +#[async_trait::async_trait] +impl<'a> Transaction for SqliteTransaction<'a> { + async fn commit(mut self: Box) -> Result<(), anyhow::Error> { + self.finished = true; + Ok(self.conn.raw().execute_batch("COMMIT")?) + } + + async fn finish(mut self: Box) -> Result<(), anyhow::Error> { + self.finished = true; + Ok(self.conn.raw().execute_batch("ROLLBACK")?) + } + fn conn(&mut self) -> &mut dyn Connection { + &mut *self.conn + } + fn conn_ref(&self) -> &dyn Connection { + &*self.conn + } +} + +impl std::ops::Deref for SqliteTransaction<'_> { + type Target = dyn Connection; + fn deref(&self) -> &Self::Target { + &*self.conn + } +} + +impl std::ops::DerefMut for SqliteTransaction<'_> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.conn + } +} + +impl Drop for SqliteTransaction<'_> { + fn drop(&mut self) { + if !self.finished { + self.conn.raw().execute_batch("ROLLBACK").unwrap(); + } + } +} + +pub struct Sqlite(PathBuf, Once); + +impl Sqlite { + pub fn new(path: PathBuf) -> Self { + if let Some(parent) = path.parent() { + if !parent.exists() { + std::fs::create_dir_all(parent).unwrap(); + } + } + Sqlite(path, Once::new()) + } +} + +struct Migration { + /// One or more SQL statements, each terminated by a semicolon. 
+ sql: &'static str, + + /// If false, indicates that foreign key checking should be delayed until after execution of + /// the migration SQL, and foreign key `ON UPDATE` and `ON DELETE` actions disabled completely. + foreign_key_constraints_enabled: bool, +} + +impl Migration { + /// Returns a `Migration` with foreign key constraints enabled during execution. + const fn new(sql: &'static str) -> Migration { + Migration { + sql, + foreign_key_constraints_enabled: true, + } + } + + /// Returns a `Migration` with foreign key checking delayed until after execution, and foreign + /// key `ON UPDATE` and `ON DELETE` actions disabled completely. + /// + /// SQLite has limited `ALTER TABLE` capabilities, so some schema alterations require the + /// approach of replacing a table with a new one having the desired schema. Because there might + /// be other tables with foreign key constraints on the table, these constraints need to be + /// disabled during execution of such migration SQL, and reenabled after. Otherwise, dropping + /// the old table may trigger `ON DELETE` actions in the referencing tables. See [SQLite + /// documentation](https://www.sqlite.org/lang_altertable.html) for more information. + #[allow(dead_code)] + const fn without_foreign_key_constraints(sql: &'static str) -> Migration { + Migration { + sql, + foreign_key_constraints_enabled: false, + } + } + + fn execute(&self, conn: &mut rusqlite::Connection, migration_id: i32) { + if self.foreign_key_constraints_enabled { + let tx = conn.transaction().unwrap(); + tx.execute_batch(&self.sql).unwrap(); + tx.pragma_update(None, "user_version", &migration_id) + .unwrap(); + tx.commit().unwrap(); + return; + } + + // The following steps are reproduced from https://www.sqlite.org/lang_altertable.html, + // from the section titled, "Making Other Kinds Of Table Schema Changes". + + // 1. If foreign key constraints are enabled, disable them using PRAGMA foreign_keys=OFF. + conn.pragma_update(None, "foreign_keys", &"OFF").unwrap(); + + // 2. Start a transaction. + let tx = conn.transaction().unwrap(); + + // The migration SQL is responsible for steps 3 through 9. + + // 3. Remember the format of all indexes, triggers, and views associated with table X. + // This information will be needed in step 8 below. One way to do this is to run a + // query like the following: SELECT type, sql FROM sqlite_schema WHERE tbl_name='X'. + // + // 4. Use CREATE TABLE to construct a new table "new_X" that is in the desired revised + // format of table X. Make sure that the name "new_X" does not collide with any + // existing table name, of course. + // + // 5. Transfer content from X into new_X using a statement like: INSERT INTO new_X SELECT + // ... FROM X. + // + // 6. Drop the old table X: DROP TABLE X. + // + // 7. Change the name of new_X to X using: ALTER TABLE new_X RENAME TO X. + // + // 8. Use CREATE INDEX, CREATE TRIGGER, and CREATE VIEW to reconstruct indexes, triggers, + // and views associated with table X. Perhaps use the old format of the triggers, + // indexes, and views saved from step 3 above as a guide, making changes as appropriate + // for the alteration. + // + // 9. If any views refer to table X in a way that is affected by the schema change, then + // drop those views using DROP VIEW and recreate them with whatever changes are + // necessary to accommodate the schema change using CREATE VIEW. + tx.execute_batch(&self.sql).unwrap(); + + // 10. 
If foreign key constraints were originally enabled then run PRAGMA foreign_key_check + // to verify that the schema change did not break any foreign key constraints. + tx.pragma_query(None, "foreign_key_check", |row| { + let table: String = row.get_unwrap(0); + let row_id: Option = row.get_unwrap(1); + let foreign_table: String = row.get_unwrap(2); + let fk_idx: i64 = row.get_unwrap(3); + + tx.query_row::<(), _, _>( + "select * from pragma_foreign_key_list(?) where id = ?", + params![&table, &fk_idx], + |row| { + let col: String = row.get_unwrap(3); + let foreign_col: String = row.get_unwrap(4); + panic!( + "Foreign key violation encountered during migration\n\ + table: {},\n\ + column: {},\n\ + row_id: {:?},\n\ + foreign table: {},\n\ + foreign column: {}\n\ + migration ID: {}\n", + table, col, row_id, foreign_table, foreign_col, migration_id, + ); + }, + ) + .unwrap(); + Ok(()) + }) + .unwrap(); + + tx.pragma_update(None, "user_version", &migration_id) + .unwrap(); + + // 11. Commit the transaction started in step 2. + tx.commit().unwrap(); + + // 12. If foreign keys constraints were originally enabled, reenable them now. + conn.pragma_update(None, "foreign_keys", &"ON").unwrap(); + } +} + +static MIGRATIONS: &[Migration] = &[ + Migration::new(""), + Migration::new( + r#" +CREATE TABLE notifications ( + notification_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + user_id BIGINT, + origin_url TEXT NOT NULL, + origin_html TEXT, + time TEXT NOT NULL, + short_description TEXT, + team_name TEXT, + idx INTEGER, + metadata TEXT +); + "#, + ), + Migration::new( + r#" +CREATE TABLE users ( + user_id BIGINT PRIMARY KEY, + username TEXT NOT NULL +); + "#, + ), + Migration::new( + r#" +CREATE TABLE rustc_commits ( + sha TEXT PRIMARY KEY, + parent_sha TEXT NOT NULL, + time TEXT NOT NULL, + pr INTEGER +); + "#, + ), + Migration::new( + r#" +CREATE TABLE issue_data ( + repo TEXT, + issue_number INTEGER, + key TEXT, + data JSONB, + PRIMARY KEY (repo, issue_number, key) +); + "#, + ), + Migration::new( + r#" +CREATE TABLE jobs ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + scheduled_at TIMESTAMP WITH TIME ZONE NOT NULL, + metadata JSONB, + executed_at TIMESTAMP WITH TIME ZONE, + error_message TEXT +); + "#, + ), + Migration::new( + r#" +CREATE UNIQUE INDEX jobs_name_scheduled_at_unique_index + ON jobs ( + name, scheduled_at + ); + "#, + ), +]; + +#[async_trait::async_trait] +impl ConnectionManager for Sqlite { + type Connection = Mutex; + async fn open(&self) -> Self::Connection { + let mut conn = rusqlite::Connection::open(&self.0).unwrap(); + conn.pragma_update(None, "cache_size", &-128000).unwrap(); + conn.pragma_update(None, "journal_mode", &"WAL").unwrap(); + conn.pragma_update(None, "foreign_keys", &"ON").unwrap(); + + self.1.call_once(|| { + let version: i32 = conn + .query_row( + "select user_version from pragma_user_version;", + params![], + |row| row.get(0), + ) + .unwrap(); + for mid in (version as usize + 1)..MIGRATIONS.len() { + MIGRATIONS[mid].execute(&mut conn, mid as i32); + } + }); + + Mutex::new(conn) + } + async fn is_valid(&self, conn: &mut Self::Connection) -> bool { + conn.get_mut() + .unwrap_or_else(|e| e.into_inner()) + .execute_batch("") + .is_ok() + } +} + +pub struct SqliteConnection { + conn: ManagedConnection>, +} + +#[async_trait::async_trait] +impl Connection for SqliteConnection { + async fn transaction(&mut self) -> Box { + Box::new(self.raw_transaction()) + } + + async fn record_username(&mut self, user_id: i64, username: String) -> Result<()> { + self.raw().execute( 
+ "INSERT INTO users (user_id, username) VALUES (?, ?) ON CONFLICT DO NOTHING", + params![user_id, username], + )?; + Ok(()) + } + + async fn record_ping(&mut self, notification: &Notification) -> Result<()> { + self.raw().execute( + "INSERT INTO notifications + (user_id, origin_url, origin_html, time, short_description, team_name, idx) + VALUES ( + ?, ?, ?, ?, ?, ?, + (SELECT ifnull(max(notifications.idx), 0) + 1 from notifications + where notifications.user_id = ?) + )", + params![ + notification.user_id, + notification.origin_url, + notification.origin_html, + notification.time, + notification.short_description, + notification.team_name, + notification.user_id, + ], + )?; + Ok(()) + } + + async fn get_missing_commits(&mut self) -> Result> { + let commits = self + .raw() + .prepare( + " + SELECT parent_sha + FROM rustc_commits + WHERE parent_sha NOT IN ( + SELECT sha + FROM rustc_commits + )", + )? + .query_map([], |row| row.get(0))? + .collect::>()?; + Ok(commits) + } + + async fn record_commit(&mut self, commit: &Commit) -> Result<()> { + let pr = commit.pr.expect("commit has pr"); + // let time = commit.time.timestamp(); + self.raw().execute( + "INSERT INTO rustc_commits (sha, parent_sha, time, pr) \ + VALUES (?, ?, ?, ?) ON CONFLICT DO NOTHING", + params![commit.sha, commit.parent_sha, commit.time, pr], + )?; + Ok(()) + } + + async fn has_commit(&mut self, sha: &str) -> Result { + Ok(self + .raw() + .prepare("SELECT 1 FROM rustc_commits WHERE sha = ?")? + .query([sha])? + .next()? + .is_some()) + } + + async fn get_commits_with_artifacts(&mut self) -> Result> { + let commits = self + .raw() + .prepare( + "SELECT sha, parent_sha, time, pr + FROM rustc_commits + WHERE time >= datetime('now', '-168 days') + ORDER BY time DESC;", + )? + .query_map([], |row| { + let c = Commit { + sha: row.get(0)?, + parent_sha: row.get(1)?, + time: row.get(2)?, + pr: row.get(3)?, + }; + Ok(c) + })? + .collect::>()?; + Ok(commits) + } + + async fn get_notifications(&mut self, username: &str) -> Result> { + let notifications = self + .raw() + .prepare( + "SELECT username, origin_url, origin_html, time, short_description, idx, metadata + FROM notifications + JOIN users ON notifications.user_id = users.user_id + WHERE username = ? + ORDER BY notifications.idx ASC NULLS LAST;", + )? + .query_map([username], |row| { + let n = NotificationData { + origin_url: row.get(1)?, + origin_text: row.get(2)?, + time: row.get(3)?, + short_description: row.get(4)?, + metadata: row.get(6)?, + }; + Ok(n) + })? + .collect::>()?; + Ok(notifications) + } + + async fn delete_ping( + &mut self, + user_id: i64, + identifier: Identifier<'_>, + ) -> Result> { + match identifier { + Identifier::Url(origin_url) => { + let rows = self + .raw() + .prepare( + "DELETE FROM notifications WHERE user_id = ? and origin_url = ? + RETURNING origin_html, time, short_description, metadata", + )? + .query_map(params![user_id, origin_url], |row| { + let n = NotificationData { + origin_url: origin_url.to_string(), + origin_text: row.get(0)?, + time: row.get(1)?, + short_description: row.get(2)?, + metadata: row.get(3)?, + }; + Ok(n) + })? + .collect::>()?; + Ok(rows) + } + Identifier::Index(idx) => { + let deleted_notifications: Vec<_> = self + .raw() + .prepare( + "DELETE FROM notifications WHERE notification_id = ( + SELECT notification_id FROM notifications + WHERE user_id = ? + ORDER BY idx ASC NULLS LAST + LIMIT 1 OFFSET ? + ) + RETURNING origin_url, origin_html, time, short_description, metadata", + )? 
+ .query_map(params![user_id, idx.get() - 1], |row| { + let n = NotificationData { + origin_url: row.get(0)?, + origin_text: row.get(1)?, + time: row.get(2)?, + short_description: row.get(3)?, + metadata: row.get(4)?, + }; + Ok(n) + })? + .collect::>()?; + if deleted_notifications.is_empty() { + anyhow::bail!("No such notification with index {}", idx.get()); + } + return Ok(deleted_notifications); + } + Identifier::All => { + let rows = self + .raw() + .prepare( + "DELETE FROM notifications WHERE user_id = ? + RETURNING origin_url, origin_html, time, short_description, metadata", + )? + .query_map([&user_id], |row| { + let n = NotificationData { + origin_url: row.get(0)?, + origin_text: row.get(1)?, + time: row.get(2)?, + short_description: row.get(3)?, + metadata: row.get(4)?, + }; + Ok(n) + })? + .collect::>()?; + Ok(rows) + } + } + } + + async fn add_metadata( + &mut self, + user_id: i64, + idx: usize, + metadata: Option<&str>, + ) -> Result<()> { + let t = self.raw().transaction()?; + + let notifications = t + .prepare( + "SELECT notification_id + FROM notifications + WHERE user_id = ? + ORDER BY idx ASC NULLS LAST", + )? + .query_map([user_id], |row| row.get(0)) + .context("failed to get initial ordering")? + .collect::, rusqlite::Error>>()?; + + match notifications.get(idx) { + None => anyhow::bail!( + "index not present, must be less than {}", + notifications.len() + ), + Some(id) => { + t.prepare( + "UPDATE notifications SET metadata = ? + WHERE notification_id = ?", + )? + .execute(params![metadata, id]) + .context("update notification id")?; + } + } + t.commit()?; + + Ok(()) + } + + async fn move_indices(&mut self, user_id: i64, from: usize, to: usize) -> Result<()> { + let t = self.raw().transaction()?; + + let mut notifications = t + .prepare( + "SELECT notification_id + FROM notifications + WHERE user_id = ? + ORDER BY idx ASC NULLS LAST;", + )? + .query_map([user_id], |row| row.get(0)) + .context("failed to get initial ordering")? + .collect::, rusqlite::Error>>()?; + + if notifications.get(from).is_none() { + anyhow::bail!( + "`from` index not present, must be less than {}", + notifications.len() + ); + } + + if notifications.get(to).is_none() { + anyhow::bail!( + "`to` index not present, must be less than {}", + notifications.len() + ); + } + + if from < to { + notifications[from..=to].rotate_left(1); + } else if to < from { + notifications[to..=from].rotate_right(1); + } + + for (idx, id) in notifications.into_iter().enumerate() { + t.prepare( + "UPDATE notifications SET idx = ? + WHERE notification_id = ?", + )? + .execute(params![idx, id]) + .context("update notification id")?; + } + t.commit()?; + + Ok(()) + } + + async fn insert_job( + &mut self, + name: &str, + scheduled_at: &DateTime, + metadata: &serde_json::Value, + ) -> Result<()> { + tracing::trace!("insert_job(name={})", name); + + let id = Uuid::new_v4(); + self.raw() + .execute( + "INSERT INTO jobs (id, name, scheduled_at, metadata) VALUES (?, ?, ?, ?) 
+ ON CONFLICT (name, scheduled_at) DO UPDATE SET metadata = EXCLUDED.metadata", + params![id, name, scheduled_at, metadata], + ) + .context("Inserting job")?; + + Ok(()) + } + + async fn delete_job(&mut self, id: &Uuid) -> Result<()> { + tracing::trace!("delete_job(id={})", id); + + self.raw() + .execute("DELETE FROM jobs WHERE id = ?", [id]) + .context("Deleting job")?; + + Ok(()) + } + + async fn update_job_error_message(&mut self, id: &Uuid, message: &str) -> Result<()> { + tracing::trace!("update_job_error_message(id={})", id); + + self.raw() + .execute( + "UPDATE jobs SET error_message = ? WHERE id = ?", + params![message, id], + ) + .context("Updating job error message")?; + + Ok(()) + } + + async fn update_job_executed_at(&mut self, id: &Uuid) -> Result<()> { + tracing::trace!("update_job_executed_at(id={})", id); + + self.raw() + .execute( + "UPDATE jobs SET executed_at = datetime('now') WHERE id = ?", + [id], + ) + .context("Updating job executed at")?; + + Ok(()) + } + + async fn get_job_by_name_and_scheduled_at( + &mut self, + name: &str, + scheduled_at: &DateTime, + ) -> Result { + tracing::trace!( + "get_job_by_name_and_scheduled_at(name={}, scheduled_at={})", + name, + scheduled_at + ); + + let job = self + .raw() + .query_row( + "SELECT * FROM jobs WHERE name = ? AND scheduled_at = ?", + params![name, scheduled_at], + |row| deserialize_job(row), + ) + .context("Select job by name and scheduled at")?; + Ok(job) + } + + async fn get_jobs_to_execute(&mut self) -> Result> { + let jobs = self + .raw() + .prepare( + "SELECT * FROM jobs WHERE scheduled_at <= datetime('now') + AND (error_message IS NULL OR executed_at <= datetime('now', '-60 minutes'))", + )? + .query_map([], |row| deserialize_job(row))? + .collect::>()?; + Ok(jobs) + } + + async fn lock_and_load_issue_data( + &mut self, + repo: &str, + issue_number: i32, + key: &str, + ) -> Result<(Box, Option)> { + let transaction = self.raw_transaction(); + let data = match transaction + .conn + .raw() + .prepare( + "SELECT data FROM issue_data WHERE \ + repo = ? AND issue_number = ? AND key = ?", + )? + .query_row(params![repo, issue_number, key], |row| row.get(0)) + { + Err(rusqlite::Error::QueryReturnedNoRows) => None, + Err(e) => return Err(e.into()), + Ok(d) => Some(d), + }; + Ok((Box::new(transaction), data)) + } + + async fn save_issue_data( + &mut self, + repo: &str, + issue_number: i32, + key: &str, + data: &serde_json::Value, + ) -> Result<()> { + self.raw() + .execute( + "INSERT INTO issue_data (repo, issue_number, key, data) \ + VALUES (?, ?, ?, ?) 
\ + ON CONFLICT (repo, issue_number, key) DO UPDATE SET data=EXCLUDED.data", + params![repo, issue_number, key, data], + ) + .context("inserting issue data")?; + Ok(()) + } +} + +fn assert_sync() {} + +impl SqliteConnection { + pub fn new(conn: ManagedConnection>) -> Self { + assert_sync::(); + Self { conn } + } + + pub fn raw(&mut self) -> &mut rusqlite::Connection { + self.conn.get_mut().unwrap_or_else(|e| e.into_inner()) + } + pub fn raw_ref(&self) -> std::sync::MutexGuard { + self.conn.lock().unwrap_or_else(|e| e.into_inner()) + } + fn raw_transaction(&mut self) -> SqliteTransaction<'_> { + self.raw().execute_batch("BEGIN DEFERRED").unwrap(); + SqliteTransaction { + conn: self, + finished: false, + } + } +} + +fn deserialize_job(row: &rusqlite::Row<'_>) -> std::result::Result { + Ok(Job { + id: row.get(0)?, + name: row.get(1)?, + scheduled_at: row.get(2)?, + metadata: row.get(3)?, + executed_at: row.get(4)?, + error_message: row.get(5)?, + }) +} diff --git a/src/handlers.rs b/src/handlers.rs index da39943b..5578b63b 100644 --- a/src/handlers.rs +++ b/src/handlers.rs @@ -279,7 +279,7 @@ command_handlers! { pub struct Context { pub github: GithubClient, - pub db: crate::db::ClientPool, + pub db: crate::db::Pool, pub username: String, pub octocrab: Octocrab, } diff --git a/src/handlers/mentions.rs b/src/handlers/mentions.rs index 1c888bea..0e77b0f4 100644 --- a/src/handlers/mentions.rs +++ b/src/handlers/mentions.rs @@ -90,9 +90,11 @@ pub(super) async fn handle_input( event: &IssuesEvent, input: MentionsInput, ) -> anyhow::Result<()> { - let mut client = ctx.db.get().await; + let mut connection = ctx.db.connection().await; + let repo = event.issue.repository().to_string(); + let issue_number = event.issue.number as i32; let mut state: IssueData<'_, MentionState> = - IssueData::load(&mut client, &event.issue, MENTIONS_KEY).await?; + IssueData::load(&mut *connection, repo, issue_number, MENTIONS_KEY).await?; // Build the message to post to the issue. let mut result = String::new(); for to_mention in &input.paths { diff --git a/src/handlers/no_merges.rs b/src/handlers/no_merges.rs index b9ed89ff..14f2b41e 100644 --- a/src/handlers/no_merges.rs +++ b/src/handlers/no_merges.rs @@ -77,9 +77,11 @@ pub(super) async fn handle_input( event: &IssuesEvent, input: NoMergesInput, ) -> anyhow::Result<()> { - let mut client = ctx.db.get().await; - let mut state: IssueData<'_, NoMergesState> = - IssueData::load(&mut client, &event.issue, NO_MERGES_KEY).await?; + let mut connection = ctx.db.connection().await; + let repo = event.issue.repository().to_string(); + let issue_number = event.issue.number as i32; + let mut state: IssueData = + IssueData::load(&mut *connection, repo, issue_number, NO_MERGES_KEY).await?; let since_last_posted = if state.data.mentioned_merge_commits.is_empty() { "" diff --git a/src/handlers/notification.rs b/src/handlers/notification.rs index 718b2b24..87fbd080 100644 --- a/src/handlers/notification.rs +++ b/src/handlers/notification.rs @@ -4,8 +4,8 @@ //! //! Parsing is done in the `parser::command::ping` module. -use crate::db::notifications; use crate::{ + db::notifications::Notification, github::{self, Event}, handlers::Context, }; @@ -93,33 +93,32 @@ pub async fn handle(ctx: &Context, event: &Event) -> anyhow::Result<()> { } }; - let client = ctx.db.get().await; + let mut connection = ctx.db.connection().await; for user in users { if !users_notified.insert(user.id.unwrap()) { // Skip users already associated with this event. 
continue; } - if let Err(err) = notifications::record_username(&client, user.id.unwrap(), user.login) + if let Err(err) = connection + .record_username(user.id.unwrap(), user.login) .await .context("failed to record username") { log::error!("record username: {:?}", err); } - if let Err(err) = notifications::record_ping( - &client, - ¬ifications::Notification { + if let Err(err) = connection + .record_ping(&Notification { user_id: user.id.unwrap(), origin_url: event.html_url().unwrap().to_owned(), origin_html: body.to_owned(), time: event.time().unwrap(), short_description: Some(short_description.clone()), team_name: team_name.clone(), - }, - ) - .await - .context("failed to record ping") + }) + .await + .context("failed to record ping") { log::error!("record ping: {:?}", err); } diff --git a/src/handlers/rustc_commits.rs b/src/handlers/rustc_commits.rs index 4724bb2b..91c59386 100644 --- a/src/handlers/rustc_commits.rs +++ b/src/handlers/rustc_commits.rs @@ -1,6 +1,5 @@ use crate::db::jobs::JobSchedule; -use crate::db::rustc_commits; -use crate::db::rustc_commits::get_missing_commits; +use crate::db::Commit; use crate::{ github::{self, Event}, handlers::Context, @@ -87,7 +86,7 @@ async fn synchronize_commits(ctx: &Context, sha: &str, pr: u32) { } pub async fn synchronize_commits_inner(ctx: &Context, starter: Option<(String, u32)>) { - let db = ctx.db.get().await; + let mut connection = ctx.db.connection().await; // List of roots to be resolved. Each root and its parents will be recursively resolved // until an existing commit is found. @@ -96,14 +95,15 @@ pub async fn synchronize_commits_inner(ctx: &Context, starter: Option<(String, u to_be_resolved.push_back((sha.to_string(), Some(pr))); } to_be_resolved.extend( - get_missing_commits(&db) + connection + .get_missing_commits() .await + .unwrap() .into_iter() .map(|c| (c, None::)), ); log::info!("synchronize_commits for {:?}", to_be_resolved); - let db = ctx.db.get().await; while let Some((sha, mut pr)) = to_be_resolved.pop_front() { let mut gc = match ctx.github.rust_commit(&sha).await { Some(c) => c, @@ -132,19 +132,17 @@ pub async fn synchronize_commits_inner(ctx: &Context, starter: Option<(String, u } }; - let res = rustc_commits::record_commit( - &db, - rustc_commits::Commit { + let res = connection + .record_commit(&Commit { sha: gc.sha, parent_sha: parent_sha.clone(), time: gc.commit.author.date, pr: Some(pr), - }, - ) - .await; + }) + .await; match res { Ok(()) => { - if !rustc_commits::has_commit(&db, &parent_sha).await { + if !connection.has_commit(&parent_sha).await.unwrap() { to_be_resolved.push_back((parent_sha, None)) } } diff --git a/src/main.rs b/src/main.rs index a1438c6b..8c026bc8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -82,7 +82,8 @@ async fn serve_req( .unwrap()); } if req.uri.path() == "/bors-commit-list" { - let res = db::rustc_commits::get_commits_with_artifacts(&*ctx.db.get().await).await; + let mut connection = ctx.db.connection().await; + let res = connection.get_commits_with_artifacts().await; let res = match res { Ok(r) => r, Err(e) => { @@ -102,10 +103,11 @@ async fn serve_req( if let Some(query) = req.uri.query() { let user = url::form_urlencoded::parse(query.as_bytes()).find(|(k, _)| k == "user"); if let Some((_, name)) = user { + let mut connection = ctx.db.connection().await; return Ok(Response::builder() .status(StatusCode::OK) .body(Body::from( - notification_listing::render(&ctx.db.get().await, &*name).await, + notification_listing::render(&mut *connection, &*name).await, )) .unwrap()); } @@ -234,10 
+236,7 @@ async fn serve_req( } async fn run_server(addr: SocketAddr) -> anyhow::Result<()> { - let pool = db::ClientPool::new(); - db::run_migrations(&*pool.get().await) - .await - .context("database migrations")?; + let pool = db::Pool::new_from_env(); let gh = github::GithubClient::new_from_env(); let oc = octocrab::OctocrabBuilder::new() @@ -307,13 +306,14 @@ fn spawn_job_scheduler() { task::spawn(async move { loop { let res = task::spawn(async move { - let pool = db::ClientPool::new(); + let pool = db::Pool::new_from_env(); let mut interval = time::interval(time::Duration::from_secs(JOB_SCHEDULING_CADENCE_IN_SECS)); loop { interval.tick().await; - db::schedule_jobs(&*pool.get().await, jobs()) + let mut connection = pool.connection().await; + db::jobs::schedule_jobs(&mut *connection, jobs()) .await .context("database schedule jobs") .unwrap(); @@ -323,7 +323,8 @@ fn spawn_job_scheduler() { match res.await { Err(err) if err.is_panic() => { /* handle panic in above task, re-launching */ - tracing::trace!("schedule_jobs task died (error={})", err); + tracing::error!("schedule_jobs task died (error={})", err); + tokio::time::sleep(std::time::Duration::new(5, 0)).await; } _ => unreachable!(), } @@ -342,13 +343,14 @@ fn spawn_job_runner(ctx: Arc) { loop { let ctx = ctx.clone(); let res = task::spawn(async move { - let pool = db::ClientPool::new(); + let pool = db::Pool::new_from_env(); let mut interval = time::interval(time::Duration::from_secs(JOB_PROCESSING_CADENCE_IN_SECS)); loop { interval.tick().await; - db::run_scheduled_jobs(&ctx, &*pool.get().await) + let mut connection = pool.connection().await; + db::jobs::run_scheduled_jobs(&ctx, &mut *connection) .await .context("run database scheduled jobs") .unwrap(); @@ -358,7 +360,8 @@ fn spawn_job_runner(ctx: Arc) { match res.await { Err(err) if err.is_panic() => { /* handle panic in above task, re-launching */ - tracing::trace!("run_scheduled_jobs task died (error={})", err); + tracing::error!("run_scheduled_jobs task died (error={})", err); + tokio::time::sleep(std::time::Duration::new(5, 0)).await; } _ => unreachable!(), } diff --git a/src/notification_listing.rs b/src/notification_listing.rs index 033fc9b3..d827a8b2 100644 --- a/src/notification_listing.rs +++ b/src/notification_listing.rs @@ -1,7 +1,5 @@ -use crate::db::notifications::get_notifications; - -pub async fn render(db: &crate::db::PooledClient, user: &str) -> String { - let notifications = match get_notifications(db, user).await { +pub async fn render(connection: &mut dyn crate::db::Connection, user: &str) -> String { + let notifications = match connection.get_notifications(user).await { Ok(n) => n, Err(e) => { return format!("{:?}", e.context("getting notifications")); diff --git a/src/zulip.rs b/src/zulip.rs index 5f943982..c2ef621b 100644 --- a/src/zulip.rs +++ b/src/zulip.rs @@ -1,5 +1,4 @@ -use crate::db::notifications::add_metadata; -use crate::db::notifications::{self, delete_ping, move_indices, record_ping, Identifier}; +use crate::db::notifications::{Identifier, Notification}; use crate::github::{self, GithubClient}; use crate::handlers::Context; use anyhow::Context as _; @@ -548,8 +547,8 @@ async fn acknowledge( } else { Identifier::Url(filter) }; - let mut db = ctx.db.get().await; - match delete_ping(&mut *db, gh_id, ident).await { + let mut connection = ctx.db.connection().await; + match connection.delete_ping(gh_id, ident).await { Ok(deleted) => { let resp = if deleted.is_empty() { format!( @@ -603,18 +602,17 @@ async fn add_notification( 
assert_eq!(description.pop(), Some(' ')); // pop trailing space Some(description) }; - match record_ping( - &*ctx.db.get().await, - ¬ifications::Notification { + let mut connection = ctx.db.connection().await; + match connection + .record_ping(&Notification { user_id: gh_id, origin_url: url.to_owned(), origin_html: String::new(), short_description: description, time: chrono::Utc::now().into(), team_name: None, - }, - ) - .await + }) + .await { Ok(()) => Ok(serde_json::to_string(&Response { content: "Created!", @@ -652,8 +650,11 @@ async fn add_meta_notification( assert_eq!(description.pop(), Some(' ')); // pop trailing space Some(description) }; - let mut db = ctx.db.get().await; - match add_metadata(&mut db, gh_id, idx, description.as_deref()).await { + let mut connection = ctx.db.connection().await; + match connection + .add_metadata(gh_id, idx, description.as_deref()) + .await + { Ok(()) => Ok(serde_json::to_string(&Response { content: "Added metadata!", }) @@ -688,7 +689,8 @@ async fn move_notification( .context("to index")? .checked_sub(1) .ok_or_else(|| anyhow::anyhow!("1-based indexes"))?; - match move_indices(&mut *ctx.db.get().await, gh_id, from, to).await { + let mut connection = ctx.db.connection().await; + match connection.move_indices(gh_id, from, to).await { Ok(()) => Ok(serde_json::to_string(&Response { // to 1-base indices content: &format!("Moved {} to {}.", from + 1, to + 1), diff --git a/tests/db/issue_data.rs b/tests/db/issue_data.rs new file mode 100644 index 00000000..26e47099 --- /dev/null +++ b/tests/db/issue_data.rs @@ -0,0 +1,30 @@ +use super::run_test; +use serde::{Deserialize, Serialize}; +use triagebot::db::issue_data::IssueData; + +#[derive(Serialize, Deserialize, Default, Debug)] +struct MyData { + f1: String, + f2: u32, +} + +#[test] +fn issue_data() { + run_test(|mut connection| async move { + let repo = "rust-lang/rust".to_string(); + let mut id: IssueData = + IssueData::load(&mut *connection, repo.clone(), 1234, "test") + .await + .unwrap(); + assert_eq!(id.data.f1, ""); + assert_eq!(id.data.f2, 0); + id.data.f1 = "new data".to_string(); + id.data.f2 = 1; + id.save().await.unwrap(); + let id: IssueData = IssueData::load(&mut *connection, repo.clone(), 1234, "test") + .await + .unwrap(); + assert_eq!(id.data.f1, "new data"); + assert_eq!(id.data.f2, 1); + }); +} diff --git a/tests/db/jobs.rs b/tests/db/jobs.rs new file mode 100644 index 00000000..824b1212 --- /dev/null +++ b/tests/db/jobs.rs @@ -0,0 +1,112 @@ +use super::run_test; +use serde_json::json; + +#[test] +fn jobs() { + run_test(|mut connection| async move { + // Create some jobs and check that ones scheduled in the past are returned. + let past = chrono::Utc::now() - chrono::Duration::minutes(5); + let future = chrono::Utc::now() + chrono::Duration::hours(1); + connection + .insert_job("sample_job1", &past, &json! {{"foo": 123}}) + .await + .unwrap(); + connection + .insert_job("sample_job2", &past, &json! {{}}) + .await + .unwrap(); + connection + .insert_job("sample_job1", &future, &json! {{}}) + .await + .unwrap(); + let jobs = connection.get_jobs_to_execute().await.unwrap(); + assert_eq!(jobs.len(), 2); + assert_eq!(jobs[0].name, "sample_job1"); + assert_eq!(jobs[0].scheduled_at, past); + assert_eq!(jobs[0].metadata, json! {{"foo": 123}}); + assert_eq!(jobs[0].executed_at, None); + assert_eq!(jobs[0].error_message, None); + + assert_eq!(jobs[1].name, "sample_job2"); + assert_eq!(jobs[1].scheduled_at, past); + assert_eq!(jobs[1].metadata, json! 
{{}}); + assert_eq!(jobs[1].executed_at, None); + assert_eq!(jobs[1].error_message, None); + + // Get job by name + let job = connection + .get_job_by_name_and_scheduled_at("sample_job1", &future) + .await + .unwrap(); + assert_eq!(job.metadata, json! {{}}); + assert_eq!(job.error_message, None); + + // Update error message + connection + .update_job_error_message(&job.id, "an error") + .await + .unwrap(); + let job = connection + .get_job_by_name_and_scheduled_at("sample_job1", &future) + .await + .unwrap(); + assert_eq!(job.error_message.as_deref(), Some("an error")); + + // Delete job + let job = connection + .get_job_by_name_and_scheduled_at("sample_job1", &past) + .await + .unwrap(); + connection.delete_job(&job.id).await.unwrap(); + let jobs = connection.get_jobs_to_execute().await.unwrap(); + assert_eq!(jobs.len(), 1); + assert_eq!(jobs[0].name, "sample_job2"); + }); +} + +#[test] +fn on_conflict() { + // Verify that inserting a job with different data updates the data. + run_test(|mut connection| async move { + let past = chrono::Utc::now() - chrono::Duration::minutes(5); + connection + .insert_job("sample_job1", &past, &json! {{"foo": 123}}) + .await + .unwrap(); + connection + .insert_job("sample_job1", &past, &json! {{"foo": 456}}) + .await + .unwrap(); + let job = connection + .get_job_by_name_and_scheduled_at("sample_job1", &past) + .await + .unwrap(); + assert_eq!(job.metadata, json! {{"foo": 456}}); + }); +} + +#[test] +fn update_job_executed_at() { + run_test(|mut connection| async move { + let now = chrono::Utc::now(); + let past = now - chrono::Duration::minutes(5); + connection + .insert_job("sample_job1", &past, &json! {{"foo": 123}}) + .await + .unwrap(); + let jobs = connection.get_jobs_to_execute().await.unwrap(); + assert_eq!(jobs.len(), 1); + assert_eq!(jobs[0].executed_at, None); + connection + .update_job_executed_at(&jobs[0].id) + .await + .unwrap(); + let jobs = connection.get_jobs_to_execute().await.unwrap(); + assert_eq!(jobs.len(), 1); + let executed_at = jobs[0].executed_at.expect("executed_at should be set"); + // The timestamp should be approximately "now". + if executed_at - now > chrono::Duration::minutes(1) { + panic!("executed_at timestamp unexpected {executed_at:?} vs {now:?}"); + } + }); +} diff --git a/tests/db/mod.rs b/tests/db/mod.rs new file mode 100644 index 00000000..ef2dffd0 --- /dev/null +++ b/tests/db/mod.rs @@ -0,0 +1,208 @@ +//! Tests for the database API. +//! +//! These tests help verify the database interaction. The [`run_test`] +//! function helps set up the database and gives your callback a connection to +//! interact with. The general form of a test is: +//! +//! ```rust +//! #[test] +//! fn example() { +//! run_test(|mut connection| async move { +//! // Call methods on `connection` and verify its behavior. +//! }); +//! } +//! ``` +//! +//! The `run_test` function will run your test on both SQLite and Postgres (if +//! it is installed). 
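For a concrete instance of the shape described in the module docs, here is a hedged sketch of such a test; the user id and name are placeholders, and it only exercises methods introduced elsewhere in this patch:

```rust
// Editor's sketch of a test in the form described above; `run_test` runs it
// against both SQLite and Postgres (when available) with placeholder data.
#[test]
fn username_roundtrip() {
    run_test(|mut connection| async move {
        // Store a username, then confirm the user starts with no notifications.
        connection
            .record_username(1, "octocat".to_string())
            .await
            .unwrap();
        let notifications = connection.get_notifications("octocat").await.unwrap();
        assert_eq!(notifications.len(), 0);
    });
}
```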
+ +use futures::Future; +use std::path::{Path, PathBuf}; +use std::process::Command; +use triagebot::db::{Connection, Pool}; + +mod issue_data; +mod jobs; +mod notification; +mod rustc_commits; + +struct PgContext { + db_dir: PathBuf, + pool: Pool, +} + +impl PgContext { + fn new(db_dir: PathBuf) -> PgContext { + let database_url = postgres_database_url(&db_dir); + let pool = Pool::open(&database_url); + PgContext { db_dir, pool } + } +} + +impl Drop for PgContext { + fn drop(&mut self) { + stop_postgres(&self.db_dir); + } +} + +struct SqliteContext { + pool: Pool, +} + +impl SqliteContext { + fn new() -> SqliteContext { + let db_path = super::test_dir().join("triagebot.sqlite3"); + let pool = Pool::open(db_path.to_str().unwrap()); + SqliteContext { pool } + } +} + +fn run_test(f: F) +where + F: Fn(Box) -> Fut + Send + Sync + 'static, + Fut: Future + Send, +{ + // Only run postgres if postgres can be found or on CI. + if let Some(db_dir) = setup_postgres() { + eprintln!("testing Postgres"); + let ctx = PgContext::new(db_dir); + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap() + .block_on(async { f(ctx.pool.connection().await).await }); + } else if std::env::var_os("CI").is_some() { + panic!("postgres must be installed in CI"); + } + + eprintln!("\n\ntesting Sqlite"); + let ctx = SqliteContext::new(); + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap() + .block_on(async { f(ctx.pool.connection().await).await }); +} + +pub fn postgres_database_url(db_dir: &PathBuf) -> String { + format!( + "postgres:///triagebot?user=triagebot&host={}", + db_dir.display() + ) +} + +pub fn setup_postgres() -> Option { + let pg_dir = find_postgres()?; + // Set up a directory where this test can store all its stuff. + let test_dir = super::test_dir(); + let db_dir = test_dir.join("db"); + + std::fs::create_dir(&db_dir).unwrap(); + let db_dir_str = db_dir.to_str().unwrap(); + run_command( + &pg_dir.join("initdb"), + &["--auth=trust", "--username=triagebot", "-D", db_dir_str], + &db_dir, + ); + run_command( + &pg_dir.join("pg_ctl"), + &[ + // -h '' tells it to not listen on TCP + // -k tells it where to place the unix-domain socket + "-o", + &format!("-h '' -k {db_dir_str}"), + // -D is the data dir where everything is stored + "-D", + db_dir_str, + // -l enables logging to a file instead of stdout + "-l", + db_dir.join("postgres.log").to_str().unwrap(), + "start", + ], + &db_dir, + ); + run_command( + &pg_dir.join("createdb"), + &["--user", "triagebot", "-h", db_dir_str, "triagebot"], + &db_dir, + ); + Some(db_dir) +} + +pub fn stop_postgres(db_dir: &Path) { + // Shut down postgres. + let pg_dir = find_postgres().unwrap(); + match Command::new(pg_dir.join("pg_ctl")) + .args(&["-D", db_dir.to_str().unwrap(), "stop"]) + .output() + { + Ok(output) => { + if !output.status.success() { + eprintln!( + "failed to stop postgres:\n\ + ---stdout\n\ + {}\n\ + ---stderr\n\ + {}\n\ + ", + std::str::from_utf8(&output.stdout).unwrap(), + std::str::from_utf8(&output.stderr).unwrap() + ); + } + } + Err(e) => eprintln!("could not run pg_ctl to stop: {e}"), + } +} + +/// Finds the root for PostgreSQL commands. +/// +/// For various reasons, some Linux distros hide some postgres commands and +/// don't put them on PATH, making them difficult to access. +fn find_postgres() -> Option { + // Check if on PATH first. 
+ if let Ok(o) = Command::new("initdb").arg("-V").output() { + if o.status.success() { + return Some(PathBuf::new()); + } + } + if let Ok(dirs) = std::fs::read_dir("/usr/lib/postgresql") { + let mut versions: Vec<_> = dirs + .filter_map(|entry| { + let entry = entry.unwrap(); + // Versions are generally of the form 9.3 or 14, but this + // might be broken if other forms are used. + if let Ok(n) = entry.file_name().to_str().unwrap().parse::() { + Some((n, entry.path())) + } else { + None + } + }) + .collect(); + if !versions.is_empty() { + versions.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); + return Some(versions.last().unwrap().1.join("bin")); + } + } + None +} + +fn run_command(command: &Path, args: &[&str], cwd: &Path) { + eprintln!("running {command:?}: {args:?}"); + let output = Command::new(command) + .args(args) + .current_dir(cwd) + .output() + .unwrap_or_else(|e| panic!("`{command:?}` failed to run: {e}")); + if !output.status.success() { + panic!( + "{command:?} failed:\n\ + ---stdout\n\ + {}\n\ + ---stderr\n\ + {}\n\ + ", + std::str::from_utf8(&output.stdout).unwrap(), + std::str::from_utf8(&output.stderr).unwrap() + ); + } +} diff --git a/tests/db/notification.rs b/tests/db/notification.rs new file mode 100644 index 00000000..b54e39ae --- /dev/null +++ b/tests/db/notification.rs @@ -0,0 +1,260 @@ +use super::run_test; +use std::num::NonZeroUsize; +use triagebot::db::notifications::{Identifier, Notification}; + +#[test] +fn notification() { + run_test(|mut connection| async move { + let now = chrono::Utc::now(); + connection + .record_username(43198, "ehuss".to_string()) + .await + .unwrap(); + connection + .record_username(14314532, "weihanglo".to_string()) + .await + .unwrap(); + connection + .record_ping(&Notification { + user_id: 43198, + origin_url: "https://github.com/rust-lang/rust/issues/1".to_string(), + origin_html: "This comment mentions @ehuss.".to_string(), + short_description: Some("Comment on some issue".to_string()), + time: now.into(), + team_name: None, + }) + .await + .unwrap(); + + connection + .record_ping(&Notification { + user_id: 43198, + origin_url: "https://github.com/rust-lang/rust/issues/2".to_string(), + origin_html: "This comment mentions @rust-lang/cargo.".to_string(), + short_description: Some("Comment on some issue".to_string()), + time: now.into(), + team_name: Some("cargo".to_string()), + }) + .await + .unwrap(); + connection + .record_ping(&Notification { + user_id: 14314532, + origin_url: "https://github.com/rust-lang/rust/issues/2".to_string(), + origin_html: "This comment mentions @rust-lang/cargo.".to_string(), + short_description: Some("Comment on some issue".to_string()), + time: now.into(), + team_name: Some("cargo".to_string()), + }) + .await + .unwrap(); + + let notifications = connection.get_notifications("ehuss").await.unwrap(); + assert_eq!(notifications.len(), 2); + assert_eq!( + notifications[0].origin_url, + "https://github.com/rust-lang/rust/issues/1" + ); + assert_eq!( + notifications[0].origin_text, + "This comment mentions @ehuss." + ); + assert_eq!( + notifications[0].short_description.as_deref(), + Some("Comment on some issue") + ); + assert_eq!(notifications[0].time, now); + assert_eq!(notifications[0].metadata, None); + + assert_eq!( + notifications[1].origin_url, + "https://github.com/rust-lang/rust/issues/2" + ); + assert_eq!( + notifications[1].origin_text, + "This comment mentions @rust-lang/cargo." 
+ ); + assert_eq!( + notifications[1].short_description.as_deref(), + Some("Comment on some issue") + ); + assert_eq!(notifications[1].time, now); + assert_eq!(notifications[1].metadata, None); + + let notifications = connection.get_notifications("weihanglo").await.unwrap(); + assert_eq!(notifications.len(), 1); + assert_eq!( + notifications[0].origin_url, + "https://github.com/rust-lang/rust/issues/2" + ); + assert_eq!( + notifications[0].origin_text, + "This comment mentions @rust-lang/cargo." + ); + assert_eq!( + notifications[0].short_description.as_deref(), + Some("Comment on some issue") + ); + assert_eq!(notifications[0].time, now); + assert_eq!(notifications[0].metadata, None); + + let notifications = connection.get_notifications("octocat").await.unwrap(); + assert_eq!(notifications.len(), 0); + }); +} + +#[test] +fn delete_ping() { + run_test(|mut connection| async move { + connection + .record_username(43198, "ehuss".to_string()) + .await + .unwrap(); + let now = chrono::Utc::now(); + for x in 1..4 { + connection + .record_ping(&Notification { + user_id: 43198, + origin_url: x.to_string(), + origin_html: "@ehuss {n}".to_string(), + short_description: Some("Comment on some issue".to_string()), + time: now.into(), + team_name: None, + }) + .await + .unwrap(); + } + + let notifications = connection.get_notifications("ehuss").await.unwrap(); + assert_eq!(notifications.len(), 3); + assert_eq!(notifications[0].origin_url, "1"); + assert_eq!(notifications[1].origin_url, "2"); + assert_eq!(notifications[2].origin_url, "3"); + + match connection + .delete_ping(43198, Identifier::Index(NonZeroUsize::new(5).unwrap())) + .await + { + Err(e) => assert_eq!(e.to_string(), "No such notification with index 5"), + Ok(deleted) => panic!("did not expect success {deleted:?}"), + } + + let deleted = connection + .delete_ping(43198, Identifier::Index(NonZeroUsize::new(2).unwrap())) + .await + .unwrap(); + assert_eq!(deleted.len(), 1); + assert_eq!(deleted[0].origin_url, "2"); + let notifications = connection.get_notifications("ehuss").await.unwrap(); + assert_eq!(notifications.len(), 2); + assert_eq!(notifications[0].origin_url, "1"); + assert_eq!(notifications[1].origin_url, "3"); + + let deleted = connection + .delete_ping(43198, Identifier::Url("1")) + .await + .unwrap(); + assert_eq!(deleted.len(), 1); + assert_eq!(deleted[0].origin_url, "1"); + let notifications = connection.get_notifications("ehuss").await.unwrap(); + assert_eq!(notifications.len(), 1); + assert_eq!(notifications[0].origin_url, "3"); + + for x in 4..6 { + connection + .record_ping(&Notification { + user_id: 43198, + origin_url: x.to_string(), + origin_html: "@ehuss {n}".to_string(), + short_description: Some("Comment on some issue".to_string()), + time: now.into(), + team_name: None, + }) + .await + .unwrap(); + } + + let notifications = connection.get_notifications("ehuss").await.unwrap(); + assert_eq!(notifications.len(), 3); + assert_eq!(notifications[0].origin_url, "3"); + assert_eq!(notifications[1].origin_url, "4"); + assert_eq!(notifications[2].origin_url, "5"); + + let deleted = connection + .delete_ping(43198, Identifier::Index(NonZeroUsize::new(2).unwrap())) + .await + .unwrap(); + assert_eq!(deleted.len(), 1); + assert_eq!(deleted[0].origin_url, "4"); + + let deleted = connection + .delete_ping(43198, Identifier::All) + .await + .unwrap(); + assert_eq!(deleted.len(), 2); + assert_eq!(deleted[0].origin_url, "3"); + assert_eq!(deleted[1].origin_url, "5"); + + let notifications = 
connection.get_notifications("ehuss").await.unwrap(); + assert_eq!(notifications.len(), 0); + }); +} + +#[test] +fn meta_notification() { + run_test(|mut connection| async move { + let now = chrono::Utc::now(); + connection + .record_username(43198, "ehuss".to_string()) + .await + .unwrap(); + connection + .record_ping(&Notification { + user_id: 43198, + origin_url: "1".to_string(), + origin_html: "This comment mentions @ehuss.".to_string(), + short_description: Some("Comment on some issue".to_string()), + time: now.into(), + team_name: None, + }) + .await + .unwrap(); + connection + .add_metadata(43198, 0, Some("metadata 1")) + .await + .unwrap(); + let notifications = connection.get_notifications("ehuss").await.unwrap(); + assert_eq!(notifications.len(), 1); + assert_eq!(notifications[0].metadata.as_deref(), Some("metadata 1")); + }); +} + +#[test] +fn move_indices() { + run_test(|mut connection| async move { + let now = chrono::Utc::now(); + connection + .record_username(43198, "ehuss".to_string()) + .await + .unwrap(); + for x in 1..4 { + connection + .record_ping(&Notification { + user_id: 43198, + origin_url: x.to_string(), + origin_html: "@ehuss {n}".to_string(), + short_description: Some("Comment on some issue".to_string()), + time: now.into(), + team_name: None, + }) + .await + .unwrap(); + } + connection.move_indices(43198, 1, 0).await.unwrap(); + let notifications = connection.get_notifications("ehuss").await.unwrap(); + assert_eq!(notifications.len(), 3); + assert_eq!(notifications[0].origin_url, "2"); + assert_eq!(notifications[1].origin_url, "1"); + assert_eq!(notifications[2].origin_url, "3"); + }); +} diff --git a/tests/db/rustc_commits.rs b/tests/db/rustc_commits.rs new file mode 100644 index 00000000..b1b307ab --- /dev/null +++ b/tests/db/rustc_commits.rs @@ -0,0 +1,86 @@ +use super::run_test; +use triagebot::db::Commit; + +#[test] +fn rustc_commits() { + run_test(|mut connection| async move { + // Using current time since `get_commits_with_artifacts` is relative to the current time. + let now = chrono::offset::Utc::now(); + connection + .record_commit(&Commit { + sha: "eebdfb55fce148676c24555505aebf648123b2de".to_string(), + parent_sha: "73f40197ecabf77ed59028af61739404eb60dd2e".to_string(), + time: now.into(), + pr: Some(108228), + }) + .await + .unwrap(); + + // A little older to ensure sorting is consistent. + let now3 = now - chrono::Duration::hours(3); + connection + .record_commit(&Commit { + sha: "73f40197ecabf77ed59028af61739404eb60dd2e".to_string(), + parent_sha: "fcdbd1c07f0b6c8e7d8bbd727c6ca69a1af8c7e9".to_string(), + time: now3.into(), + pr: Some(107772), + }) + .await + .unwrap(); + + // In the distant past, won't show up in get_commits_with_artifacts. 
+ connection + .record_commit(&Commit { + sha: "26904687275a55864f32f3a7ba87b7711d063fd5".to_string(), + parent_sha: "3b348d932aa5c9884310d025cf7c516023fd0d9a".to_string(), + time: "2022-02-19T23:25:06Z".parse().unwrap(), + pr: Some(92911), + }) + .await + .unwrap(); + + assert!(connection + .has_commit("eebdfb55fce148676c24555505aebf648123b2de") + .await + .unwrap()); + assert!(connection + .has_commit("73f40197ecabf77ed59028af61739404eb60dd2e") + .await + .unwrap()); + assert!(connection + .has_commit("26904687275a55864f32f3a7ba87b7711d063fd5") + .await + .unwrap()); + assert!(!connection + .has_commit("fcdbd1c07f0b6c8e7d8bbd727c6ca69a1af8c7e9") + .await + .unwrap()); + + let missing = connection.get_missing_commits().await.unwrap(); + assert_eq!( + &missing[..], + [ + "fcdbd1c07f0b6c8e7d8bbd727c6ca69a1af8c7e9", + "3b348d932aa5c9884310d025cf7c516023fd0d9a" + ] + ); + + let commits = connection.get_commits_with_artifacts().await.unwrap(); + assert_eq!(commits.len(), 2); + assert_eq!(commits[0].sha, "eebdfb55fce148676c24555505aebf648123b2de"); + assert_eq!( + commits[0].parent_sha, + "73f40197ecabf77ed59028af61739404eb60dd2e" + ); + assert_eq!(commits[0].time, now); + assert_eq!(commits[0].pr, Some(108228)); + + assert_eq!(commits[1].sha, "73f40197ecabf77ed59028af61739404eb60dd2e"); + assert_eq!( + commits[1].parent_sha, + "fcdbd1c07f0b6c8e7d8bbd727c6ca69a1af8c7e9" + ); + assert_eq!(commits[1].time, now3); + assert_eq!(commits[1].pr, Some(107772)); + }); +} diff --git a/tests/server_test/mod.rs b/tests/server_test/mod.rs index 21f84042..f03da5e4 100644 --- a/tests/server_test/mod.rs +++ b/tests/server_test/mod.rs @@ -27,8 +27,8 @@ //! ``` //! //! Look at `README.md` for instructions for running triagebot against the -//! live GitHub site. You'll need to have webhook forwarding and Postgres -//! running in the background. +//! live GitHub site. You'll need to have webhook forwarding running in the +//! background. //! //! 3. Perform the action you want to test on GitHub. For example, post a //! comment with `@rustbot ready` to test the "ready" command. @@ -50,6 +50,13 @@ //! //! with the name of your test. //! +//! ## Databases +//! +//! By default, the server tests will use Postgres if it is installed. If it +//! doesn't appear to be installed, then it will use SQLite instead. If you +//! want to force it to use SQLite, you can set the +//! TRIAGEBOT_TEST_FORCE_SQLITE environment variable. +//! //! ## Scheduled Jobs //! //! Scheduled jobs get automatically disabled when recording or running tests @@ -63,17 +70,14 @@ mod shortcut; use super::{HttpServer, HttpServerHandle}; use std::io::Read; use std::net::{SocketAddr, TcpListener}; -use std::path::{Path, PathBuf}; +use std::path::PathBuf; use std::process::{Child, Command, Stdio}; -use std::sync::atomic::{AtomicU16, AtomicU32}; +use std::sync::atomic::AtomicU16; use std::sync::{Arc, Mutex}; use std::thread; use std::time::Duration; use triagebot::test_record::Activity; -/// Counter used to give each test a unique sandbox directory in the -/// `target/tmp` directory. -static TEST_COUNTER: AtomicU32 = AtomicU32::new(1); /// The webhook secret used to validate that the webhook events are coming /// from the expected source. const WEBHOOK_SECRET: &str = "secret"; @@ -89,7 +93,9 @@ struct ServerTestCtx { /// Stderr received from triagebot, used for debugging. stderr: Arc>>, /// Directory for the temporary Postgres database. - db_dir: PathBuf, + /// + /// `None` if using sqlite. 
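+    ///
+    /// There is no directory to track for SQLite: the database is a plain
+    /// file under the test directory, with no server process to shut down.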
+    db_dir: Option<PathBuf>,
     /// The address for sending webhooks into the triagebot binary.
     triagebot_addr: SocketAddr,
     /// The handle to the mock server which simulates GitHub.
@@ -134,17 +140,25 @@ fn run_test(test_name: &str) {
 }
 
 fn build(activities: Vec<Activity>) -> ServerTestCtx {
-    // Set up a directory where this test can store all its stuff.
-    let tmp_dir = Path::new(env!("CARGO_TARGET_TMPDIR")).join("local");
-    let test_num = TEST_COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
-    let test_dir = tmp_dir.join(format!("t{test_num}"));
-    if test_dir.exists() {
-        std::fs::remove_dir_all(&test_dir).unwrap();
-    }
-    std::fs::create_dir_all(&test_dir).unwrap();
-
-    let db_dir = test_dir.join("db");
-    setup_postgres(&db_dir);
+    let db_sqlite = || {
+        crate::test_dir()
+            .join("triagebot.sqlite3")
+            .to_str()
+            .unwrap()
+            .to_string()
+    };
+    let (db_dir, database_url) = if std::env::var_os("TRIAGEBOT_TEST_FORCE_SQLITE").is_some() {
+        (None, db_sqlite())
+    } else {
+        match crate::db::setup_postgres() {
+            Some(db_dir) => {
+                let database_url = crate::db::postgres_database_url(&db_dir);
+                (Some(db_dir), database_url)
+            }
+            None if std::env::var_os("CI").is_some() => panic!("expected postgres in CI"),
+            None => (None, db_sqlite()),
+        }
+    };
 
     let server = HttpServer::new(activities);
     let triagebot_port = next_triagebot_port();
@@ -154,13 +168,7 @@ fn build(activities: Vec<Activity>) -> ServerTestCtx {
             "ghp_123456789012345678901234567890123456",
         )
         .env("GITHUB_WEBHOOK_SECRET", WEBHOOK_SECRET)
-        .env(
-            "DATABASE_URL",
-            format!(
-                "postgres:///triagebot?user=triagebot&host={}",
-                db_dir.display()
-            ),
-        )
+        .env("DATABASE_URL", database_url)
         .env("PORT", triagebot_port.to_string())
         .env("GITHUB_API_URL", format!("http://{}", server.addr))
         .env(
@@ -240,29 +248,9 @@ impl ServerTestCtx {
 
 impl Drop for ServerTestCtx {
     fn drop(&mut self) {
-        // Shut down postgres.
-        let pg_dir = find_postgres();
-        match Command::new(pg_dir.join("pg_ctl"))
-            .args(&["-D", self.db_dir.to_str().unwrap(), "stop"])
-            .output()
-        {
-            Ok(output) => {
-                if !output.status.success() {
-                    eprintln!(
-                        "failed to stop postgres:\n\
-                        ---stdout\n\
-                        {}\n\
-                        ---stderr\n\
-                        {}\n\
-                        ",
-                        std::str::from_utf8(&output.stdout).unwrap(),
-                        std::str::from_utf8(&output.stderr).unwrap()
-                    );
-                }
-            }
-            Err(e) => eprintln!("could not run pg_ctl to stop: {e}"),
+        if let Some(db_dir) = &self.db_dir {
+            crate::db::stop_postgres(db_dir);
         }
-        // Shut down triagebot.
         let _ = self.child.kill();
 
         // Display triagebot's output for debugging.
@@ -279,98 +267,6 @@ impl Drop for ServerTestCtx { } } -fn run_command(command: &Path, args: &[&str], cwd: &Path) { - eprintln!("running {command:?}: {args:?}"); - let output = Command::new(command) - .args(args) - .current_dir(cwd) - .output() - .unwrap_or_else(|e| panic!("`{command:?}` failed to run: {e}")); - if !output.status.success() { - panic!( - "{command:?} failed:\n\ - ---stdout\n\ - {}\n\ - ---stderr\n\ - {}\n\ - ", - std::str::from_utf8(&output.stdout).unwrap(), - std::str::from_utf8(&output.stderr).unwrap() - ); - } -} - -fn setup_postgres(db_dir: &Path) { - std::fs::create_dir(&db_dir).unwrap(); - let db_dir_str = db_dir.to_str().unwrap(); - let pg_dir = find_postgres(); - run_command( - &pg_dir.join("initdb"), - &["--auth=trust", "--username=triagebot", "-D", db_dir_str], - db_dir, - ); - run_command( - &pg_dir.join("pg_ctl"), - &[ - // -h '' tells it to not listen on TCP - // -k tells it where to place the unix-domain socket - "-o", - &format!("-h '' -k {db_dir_str}"), - // -D is the data dir where everything is stored - "-D", - db_dir_str, - // -l enables logging to a file instead of stdout - "-l", - db_dir.join("postgres.log").to_str().unwrap(), - "start", - ], - db_dir, - ); - run_command( - &pg_dir.join("createdb"), - &["--user", "triagebot", "-h", db_dir_str, "triagebot"], - db_dir, - ); -} - -/// Finds the root for PostgreSQL commands. -/// -/// For various reasons, some Linux distros hide some postgres commands and -/// don't put them on PATH, making them difficult to access. -fn find_postgres() -> PathBuf { - // Check if on PATH first. - if let Ok(o) = Command::new("initdb").arg("-V").output() { - if o.status.success() { - return PathBuf::new(); - } - } - if let Ok(dirs) = std::fs::read_dir("/usr/lib/postgresql") { - let mut versions: Vec<_> = dirs - .filter_map(|entry| { - let entry = entry.unwrap(); - // Versions are generally of the form 9.3 or 14, but this - // might be broken if other forms are used. - if let Ok(n) = entry.file_name().to_str().unwrap().parse::() { - Some((n, entry.path())) - } else { - None - } - }) - .collect(); - if !versions.is_empty() { - versions.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); - return versions.last().unwrap().1.join("bin"); - } - } - panic!( - "Could not find PostgreSQL binaries.\n\ - Make sure to install PostgreSQL.\n\ - If PostgreSQL is installed, update this function to match where they \ - are located on your system.\n\ - Or, add them to your PATH." - ); -} - /// Returns a free port for the next triagebot process to use. fn next_triagebot_port() -> u16 { static NEXT_TCP_PORT: AtomicU16 = AtomicU16::new(50000); diff --git a/tests/testsuite.rs b/tests/testsuite.rs index 01fdafaf..0e7d38f1 100644 --- a/tests/testsuite.rs +++ b/tests/testsuite.rs @@ -12,6 +12,7 @@ //! requests that would normally go to external sites like //! https://api.github.com. +mod db; mod github_client; mod server_test; @@ -19,12 +20,42 @@ use std::collections::HashMap; use std::io::{BufRead, BufReader, Read, Write}; use std::net::TcpStream; use std::net::{SocketAddr, TcpListener}; -use std::path::PathBuf; -use std::sync::atomic::{AtomicBool, Ordering}; +use std::path::{Path, PathBuf}; +use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; use std::sync::mpsc; +use std::sync::Mutex; use triagebot::test_record::{self, Activity}; use url::Url; +pub fn test_dir() -> PathBuf { + thread_local! 
{
+        static TEST_ID: Mutex<Option<u32>> = Mutex::new(None);
+    }
+    static NEXT_ID: AtomicU32 = AtomicU32::new(0);
+    let path_from_id = |id| {
+        let tmp_dir = Path::new(env!("CARGO_TARGET_TMPDIR")).join("testsuite");
+        tmp_dir.join(format!("t{id}"))
+    };
+
+    let this_id = TEST_ID.with(|n| {
+        let mut v = n.lock().unwrap();
+        match *v {
+            None => {
+                let test_id = NEXT_ID.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
+                let test_dir = path_from_id(test_id);
+                if test_dir.exists() {
+                    std::fs::remove_dir_all(&test_dir).unwrap();
+                }
+                std::fs::create_dir_all(&test_dir).unwrap();
+                *v = Some(test_id);
+                test_id
+            }
+            Some(id) => id,
+        }
+    });
+    path_from_id(this_id)
+}
+
 /// A request received on the HTTP server.
 #[derive(Clone, Debug)]
 pub struct Request {