From 3558ef1cc6a3be16d765b8f9c9816d7010a60db1 Mon Sep 17 00:00:00 2001
From: Mark Rousskov <mark.simulacrum@gmail.com>
Date: Sun, 8 Dec 2024 15:25:10 -0500
Subject: [PATCH] Add github-action PR open/closer

---
 Cargo.lock                        | 17 ++++----
 Cargo.toml                        |  1 +
 src/config.rs                     |  7 ++++
 src/github.rs                     | 27 ++++++++++++-
 src/handlers.rs                   | 11 ++++++
 src/handlers/bot_pull_requests.rs | 66 +++++++++++++++++++++++++++++++
 6 files changed, 120 insertions(+), 9 deletions(-)
 create mode 100644 src/handlers/bot_pull_requests.rs

diff --git a/Cargo.lock b/Cargo.lock
index 7680e03b..50dd6747 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1,6 +1,6 @@
 # This file is automatically @generated by Cargo.
 # It is not intended for manual editing.
-version = 3
+version = 4
 
 [[package]]
 name = "Inflector"
@@ -877,9 +877,9 @@ dependencies = [
 
 [[package]]
 name = "hashbrown"
-version = "0.14.3"
+version = "0.15.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604"
+checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289"
 
 [[package]]
 name = "heck"
@@ -1094,12 +1094,12 @@ dependencies = [
 
 [[package]]
 name = "indexmap"
-version = "2.1.0"
+version = "2.7.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d530e1a18b1cb4c484e6e34556a0d948706958449fca0cab753d649f2bce3d1f"
+checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f"
 dependencies = [
  "equivalent",
- "hashbrown 0.14.3",
+ "hashbrown 0.15.2",
  "serde",
 ]
 
@@ -1968,7 +1968,7 @@ name = "rust_team_data"
 version = "1.0.0"
 source = "git+https://github.com/rust-lang/team#1ff0fa95e5ead9fbbb4be3975cac8ede35b9d3d5"
 dependencies = [
- "indexmap 2.1.0",
+ "indexmap 2.7.0",
  "serde",
 ]
 
@@ -2641,7 +2641,7 @@ version = "0.21.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d34d383cd00a163b4a5b85053df514d45bc330f6de7737edfe0a93311d1eaa03"
 dependencies = [
- "indexmap 2.1.0",
+ "indexmap 2.7.0",
  "serde",
  "serde_spanned",
  "toml_datetime",
@@ -2777,6 +2777,7 @@ dependencies = [
  "hex",
  "hyper",
  "ignore",
+ "indexmap 2.7.0",
  "itertools",
  "lazy_static",
  "native-tls",
diff --git a/Cargo.toml b/Cargo.toml
index fcf96ffa..3d6876dc 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -47,6 +47,7 @@ postgres-types = { version = "0.2.4", features = ["derive"] }
 cron = { version = "0.12.0" }
 bytes = "1.1.0"
 structopt = "0.3.26"
+indexmap = "2.7.0"
 
 [dependencies.serde]
 version = "1"
diff --git a/src/config.rs b/src/config.rs
index 5057baf1..c7679084 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -46,6 +46,7 @@ pub(crate) struct Config {
     pub(crate) pr_tracking: Option<ReviewPrefsConfig>,
     pub(crate) transfer: Option<TransferConfig>,
     pub(crate) merge_conflicts: Option<MergeConflictConfig>,
+    pub(crate) bot_pull_requests: Option<BotPullRequests>,
 }
 
 #[derive(PartialEq, Eq, Debug, serde::Deserialize)]
@@ -363,6 +364,11 @@ pub(crate) struct MergeConflictConfig {
     pub unless: HashSet<String>,
 }
 
+#[derive(PartialEq, Eq, Debug, serde::Deserialize)]
+#[serde(rename_all = "kebab-case")]
+#[serde(deny_unknown_fields)]
+pub(crate) struct BotPullRequests {}
+
 fn get_cached_config(repo: &str) -> Option<Result<Arc<Config>, ConfigurationError>> {
     let cache = CONFIG_CACHE.read().unwrap();
     cache.get(repo).and_then(|(config, fetch_time)| {
@@ -541,6 +547,7 @@ mod tests {
                 pr_tracking: None,
                 transfer: None,
                 merge_conflicts: None,
+                bot_pull_requests: None,
             }
         );
     }
diff --git a/src/github.rs b/src/github.rs
index e05e0f31..ea7b247c 100644
--- a/src/github.rs
+++ b/src/github.rs
@@ -189,6 +189,31 @@ impl GithubClient {
         .await
         .context("failed to create issue")
     }
+
+    pub(crate) async fn set_pr_status(
+        &self,
+        repo: &IssueRepository,
+        number: u64,
+        status: PrStatus,
+    ) -> anyhow::Result<()> {
+        #[derive(serde::Serialize)]
+        struct Update {
+            status: PrStatus,
+        }
+        let url = format!("{}/pulls/{number}", repo.url(&self));
+        self.send_req(self.post(&url).json(&Update { status }))
+            .await
+            .context("failed to update pr state")?;
+        Ok(())
+    }
+}
+
+#[derive(Debug, serde::Serialize)]
+pub(crate) enum PrStatus {
+    #[serde(rename = "open")]
+    Open,
+    #[serde(rename = "closed")]
+    Closed,
 }
 
 #[derive(Debug, serde::Deserialize)]
@@ -463,7 +488,7 @@ impl fmt::Display for AssignmentError {
 
 impl std::error::Error for AssignmentError {}
 
-#[derive(Debug, Clone, PartialEq, Eq)]
+#[derive(Hash, Debug, Clone, PartialEq, Eq)]
 pub struct IssueRepository {
     pub organization: String,
     pub repository: String,
diff --git a/src/handlers.rs b/src/handlers.rs
index 759c5441..9218a210 100644
--- a/src/handlers.rs
+++ b/src/handlers.rs
@@ -25,6 +25,7 @@ impl fmt::Display for HandlerError {
 
 mod assign;
 mod autolabel;
+mod bot_pull_requests;
 mod close;
 pub mod docs_update;
 mod github_releases;
@@ -117,6 +118,16 @@ pub async fn handle(ctx: &Context, event: &Event) -> Vec<HandlerError> {
         );
     }
 
+    if config.as_ref().is_ok_and(|c| c.bot_pull_requests.is_some()) {
+        if let Err(e) = bot_pull_requests::handle(ctx, event).await {
+            log::error!(
+                "failed to process event {:?} with bot_pull_requests handler: {:?}",
+                event,
+                e
+            )
+        }
+    }
+
     if let Some(config) = config
         .as_ref()
         .ok()
diff --git a/src/handlers/bot_pull_requests.rs b/src/handlers/bot_pull_requests.rs
new file mode 100644
index 00000000..94f1fa52
--- /dev/null
+++ b/src/handlers/bot_pull_requests.rs
@@ -0,0 +1,66 @@
+use indexmap::IndexSet;
+use std::sync::atomic::AtomicBool;
+use std::sync::{LazyLock, Mutex};
+
+use crate::github::{IssueRepository, IssuesAction, PrStatus};
+use crate::{github::Event, handlers::Context};
+
+pub(crate) async fn handle(ctx: &Context, event: &Event) -> anyhow::Result<()> {
+    let Event::Issue(event) = event else {
+        return Ok(());
+    };
+    if event.action != IssuesAction::Opened {
+        return Ok(());
+    }
+    if !event.issue.is_pr() {
+        return Ok(());
+    }
+
+    // avoid acting on our own open events, otherwise we'll infinitely loop
+    if event.sender.login == ctx.username {
+        return Ok(());
+    }
+
+    // If it's not the github-actions bot, we don't expect this handler to be needed. Skip the
+    // event.
+    if event.sender.login != "github-actions" {
+        return Ok(());
+    }
+
+    if DISABLE.load(std::sync::atomic::Ordering::Relaxed) {
+        tracing::warn!("skipping bot_pull_requests handler due to previous disable",);
+        return Ok(());
+    }
+
+    // Sanity check that our logic above doesn't cause us to act on PRs in a loop, by
+    // tracking a window of PRs we've acted on. We can probably drop this if we don't see problems
+    // in the first few days/weeks of deployment.
+    {
+        let mut touched = TOUCHED_PRS.lock().unwrap();
+        if !touched.insert((event.issue.repository().clone(), event.issue.number)) {
+            tracing::warn!("touching same PR twice despite username check: {:?}", event);
+            DISABLE.store(true, std::sync::atomic::Ordering::Relaxed);
+            return Ok(());
+        }
+        if touched.len() > 300 {
+            touched.drain(..150);
+        }
+    }
+
+    ctx.github
+        .set_pr_status(
+            event.issue.repository(),
+            event.issue.number,
+            PrStatus::Closed,
+        )
+        .await?;
+    ctx.github
+        .set_pr_status(event.issue.repository(), event.issue.number, PrStatus::Open)
+        .await?;
+
+    Ok(())
+}
+
+static TOUCHED_PRS: LazyLock<Mutex<IndexSet<(IssueRepository, u64)>>> =
+    LazyLock::new(|| std::sync::Mutex::new(IndexSet::new()));
+static DISABLE: AtomicBool = AtomicBool::new(false);