From 25acf0f7f4b8bc32e9e5668ef53d64768d01ed41 Mon Sep 17 00:00:00 2001 From: Bas Zalmstra Date: Thu, 28 Sep 2023 11:33:44 +0200 Subject: [PATCH] refactor: simplify resolve api --- Cargo.lock | 1 - crates/rattler_installs_packages/Cargo.toml | 3 +- crates/rattler_installs_packages/src/lib.rs | 7 +- .../src/{resolvo_pypi.rs => resolve.rs} | 91 +++++++++++++++++-- crates/rip_bin/Cargo.toml | 11 ++- crates/rip_bin/src/main.rs | 77 ++++------------ 6 files changed, 112 insertions(+), 78 deletions(-) rename crates/rattler_installs_packages/src/{resolvo_pypi.rs => resolve.rs} (73%) diff --git a/Cargo.lock b/Cargo.lock index 413dd567..676b797f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1995,7 +1995,6 @@ dependencies = [ "miette", "rand", "rattler_installs_packages", - "resolvo", "serde", "serde_json", "tabwriter", diff --git a/crates/rattler_installs_packages/Cargo.toml b/crates/rattler_installs_packages/Cargo.toml index e69cc431..aac6ea24 100644 --- a/crates/rattler_installs_packages/Cargo.toml +++ b/crates/rattler_installs_packages/Cargo.toml @@ -15,7 +15,6 @@ rust-version.workspace = true default = ["native-tls"] native-tls = ['reqwest/native-tls'] rustls-tls = ['reqwest/rustls-tls'] -resolvo-pypi = [] [dependencies] async-trait = "0.1.73" @@ -50,7 +49,7 @@ tokio-util = { version = "0.7.9", features = ["compat"] } tracing = { version = "0.1.37", default-features = false, features = ["attributes"] } url = { version = "2.4.1", features = ["serde"] } zip = "0.6.6" -resolvo = "0.1.0" +resolvo = { version = "0.1.0", optional = true } [dev-dependencies] criterion = "0.3" diff --git a/crates/rattler_installs_packages/src/lib.rs b/crates/rattler_installs_packages/src/lib.rs index 8c225d7a..1ad24422 100644 --- a/crates/rattler_installs_packages/src/lib.rs +++ b/crates/rattler_installs_packages/src/lib.rs @@ -15,8 +15,11 @@ mod seek_slice; mod specifier; mod utils; -#[cfg(feature = "resolvo-pypi")] -pub mod resolvo_pypi; +#[cfg(feature = "resolvo")] +mod resolve; + +#[cfg(feature = "resolvo")] +pub use resolve::resolve; pub use file_store::{CacheKey, FileStore}; pub use package_database::PackageDb; diff --git a/crates/rattler_installs_packages/src/resolvo_pypi.rs b/crates/rattler_installs_packages/src/resolve.rs similarity index 73% rename from crates/rattler_installs_packages/src/resolvo_pypi.rs rename to crates/rattler_installs_packages/src/resolve.rs index cf579c38..85a1fca0 100644 --- a/crates/rattler_installs_packages/src/resolvo_pypi.rs +++ b/crates/rattler_installs_packages/src/resolve.rs @@ -1,11 +1,12 @@ use crate::{ - CompareOp, Extra, NormalizedPackageName, PackageDb, Requirement, Specifier, Specifiers, - Version, Wheel, + CompareOp, Extra, NormalizedPackageName, PackageDb, PackageName, PackageRequirement, + Requirement, Specifier, Specifiers, Version, Wheel, }; use resolvo::{ - Candidates, Dependencies, DependencyProvider, NameId, Pool, SolvableId, SolverCache, VersionSet, + Candidates, DefaultSolvableDisplay, Dependencies, DependencyProvider, NameId, Pool, SolvableId, + Solver, SolverCache, VersionSet, }; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::fmt::{Display, Formatter}; use tokio::runtime::Handle; use tokio::task; @@ -81,13 +82,13 @@ impl Display for PypiPackageName { } } -pub struct PypiDependencyProvider { +pub struct PypiDependencyProvider<'db> { pool: Pool, - package_db: PackageDb, + package_db: &'db PackageDb, } -impl PypiDependencyProvider { - pub fn new(package_db: PackageDb) -> Self { +impl<'db> PypiDependencyProvider<'db> { + pub fn new(package_db: &'db PackageDb) -> Self { Self { pool: Pool::new(), package_db, @@ -95,7 +96,7 @@ impl PypiDependencyProvider { } } -impl DependencyProvider for PypiDependencyProvider { +impl<'db> DependencyProvider for PypiDependencyProvider<'db> { fn pool(&self) -> &Pool { &self.pool } @@ -296,3 +297,75 @@ impl DependencyProvider for PypiDependencyProvi dependencies } } + +/// Resolves an environment that contains the given requirements and all dependencies of those +/// requirements. +pub async fn resolve( + package_db: &PackageDb, + requirements: impl IntoIterator, +) -> Result)>, String> { + // Construct a provider + let provider = PypiDependencyProvider::new(package_db); + let pool = provider.pool(); + + let requirements = requirements.into_iter(); + + // Construct the root requirements from the requirements requested by the user. + let requirement_count = requirements.size_hint(); + let mut root_requirements = + Vec::with_capacity(requirement_count.1.unwrap_or(requirement_count.0)); + for Requirement { + name, + specifiers, + extras, + .. + } in requirements.map(PackageRequirement::as_inner) + { + let dependency_package_name = + pool.intern_package_name(PypiPackageName::Base(name.clone().into())); + let version_set_id = + pool.intern_version_set(dependency_package_name, specifiers.clone().into()); + root_requirements.push(version_set_id); + + for extra in extras { + let dependency_package_name = pool + .intern_package_name(PypiPackageName::Extra(name.clone().into(), extra.clone())); + let version_set_id = + pool.intern_version_set(dependency_package_name, specifiers.clone().into()); + root_requirements.push(version_set_id); + } + } + + // Invoke the solver to get a solution to the requirements + let mut solver = Solver::new(provider); + let result = solver.solve(root_requirements); + + match result { + Ok(solvables) => { + let mut result = HashMap::default(); + for solvable in solvables { + let pool = solver.pool(); + let solvable = pool.resolve_solvable(solvable); + let name = pool.resolve_package_name(solvable.name_id()); + let version = solvable.inner(); + match name { + PypiPackageName::Base(name) => { + result + .entry(name.clone().into()) + .or_insert((version.0.clone(), HashSet::new())); + } + PypiPackageName::Extra(name, extra) => { + let (_, extras) = result + .entry(name.clone().into()) + .or_insert((version.0.clone(), HashSet::new())); + extras.insert(extra.clone()); + } + } + } + Ok(result) + } + Err(e) => Err(e + .display_user_friendly(&solver, &DefaultSolvableDisplay) + .to_string()), + } +} diff --git a/crates/rip_bin/Cargo.toml b/crates/rip_bin/Cargo.toml index 9d516bcd..13af1089 100644 --- a/crates/rip_bin/Cargo.toml +++ b/crates/rip_bin/Cargo.toml @@ -9,9 +9,11 @@ homepage.workspace = true repository.workspace = true license.workspace = true readme.workspace = true -default-run = "rip_bin" +default-run = "rip" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[[bin]] +name = "rip" +path = "src/main.rs" [features] default = ["native-tls"] @@ -26,12 +28,11 @@ indexmap = "2.0.0" indicatif = "0.17.6" itertools = "0.11.0" miette = { version = "5.10.0", features = ["fancy"] } -rattler_installs_packages = { path = "../rattler_installs_packages", default-features = false, features = ["resolvo-pypi"] } -resolvo = "0.1.0" +rattler_installs_packages = { path = "../rattler_installs_packages", default-features = false, features = ["resolvo"] } tabwriter = { version = "1.2.1", features = ["ansi_formatting"] } tokio = { version = "1.29.1", features = ["rt", "macros", "rt-multi-thread"] } tracing = "0.1.37" -tracing-subscriber = { version = "0.3.17", features = ["env-filter"]} +tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } url = "2.4.0" rand = "0.8.4" serde = "1.0.188" diff --git a/crates/rip_bin/src/main.rs b/crates/rip_bin/src/main.rs index bb3db76f..8b886b57 100644 --- a/crates/rip_bin/src/main.rs +++ b/crates/rip_bin/src/main.rs @@ -2,16 +2,12 @@ use rip_bin::{global_multi_progress, IndicatifWriter}; use std::io::Write; use clap::Parser; +use itertools::Itertools; use miette::IntoDiagnostic; -use resolvo::{DefaultSolvableDisplay, DependencyProvider, Solver}; use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; use url::Url; -use rattler_installs_packages::{ - normalize_index_url, - resolvo_pypi::{PypiDependencyProvider, PypiPackageName}, -}; -use rattler_installs_packages::{PackageRequirement, Requirement}; +use rattler_installs_packages::{normalize_index_url, resolve, PackageRequirement}; #[derive(Parser)] #[command(author, version, about, long_about = None)] @@ -48,56 +44,10 @@ async fn actual_main() -> miette::Result<()> { ) .into_diagnostic()?; - let provider = PypiDependencyProvider::new(package_db); - - // Create a task to solve the specs passed on the command line. - let mut root_requirements = Vec::with_capacity(args.specs.len()); - for Requirement { - name, - specifiers, - extras, - .. - } in args.specs.iter().map(PackageRequirement::as_inner) - { - let dependency_package_name = provider - .pool() - .intern_package_name(PypiPackageName::Base(name.clone().into())); - let version_set_id = provider - .pool() - .intern_version_set(dependency_package_name, specifiers.clone().into()); - root_requirements.push(version_set_id); - - for extra in extras { - let dependency_package_name = provider - .pool() - .intern_package_name(PypiPackageName::Extra(name.clone().into(), extra.clone())); - let version_set_id = provider - .pool() - .intern_version_set(dependency_package_name, specifiers.clone().into()); - root_requirements.push(version_set_id); - } - } - - // Solve the jobs - let mut solver = Solver::new(provider); - let result = solver.solve(root_requirements); - let artifacts = match result { - Err(e) => { - eprintln!( - "Could not solve:\n{}", - e.display_user_friendly(&solver, &DefaultSolvableDisplay) - ); - return Ok(()); - } - Ok(transaction) => transaction - .into_iter() - .map(|result| { - let pool = solver.pool(); - let solvable = pool.resolve_solvable(result); - let name = pool.resolve_package_name(solvable.name_id()); - (name.clone(), solvable.inner().0.clone()) - }) - .collect::>(), + // Solve the environment + let blueprint = match resolve(&package_db, &args.specs).await { + Ok(blueprint) => blueprint, + Err(err) => miette::bail!("Could not solve for the requested requirements:\n{err}"), }; // Output the selected versions @@ -115,10 +65,19 @@ async fn actual_main() -> miette::Result<()> { console::style("Version").bold() ) .into_diagnostic()?; - for (name, artifact) in artifacts { - writeln!(tabbed_stdout, "{name}\t{artifact}").into_diagnostic()?; + for (name, (version, extras)) in blueprint.into_iter().sorted_by(|(a, _), (b, _)| a.cmp(b)) { + write!(tabbed_stdout, "{name}", name = name.as_str()).into_diagnostic()?; + if !extras.is_empty() { + write!( + tabbed_stdout, + "[{}]", + extras.iter().map(|e| e.as_str()).join(",") + ) + .into_diagnostic()?; + } + writeln!(tabbed_stdout, "\t{version}").into_diagnostic()?; } - tabbed_stdout.flush().unwrap(); + tabbed_stdout.flush().into_diagnostic()?; Ok(()) }