Skip to content

Commit

Permalink
refactor: simplify resolve api
Browse files Browse the repository at this point in the history
  • Loading branch information
baszalmstra committed Sep 28, 2023
1 parent d5153f1 commit 25acf0f
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 78 deletions.
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions crates/rattler_installs_packages/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ rust-version.workspace = true
default = ["native-tls"]
native-tls = ['reqwest/native-tls']
rustls-tls = ['reqwest/rustls-tls']
resolvo-pypi = []

[dependencies]
async-trait = "0.1.73"
Expand Down Expand Up @@ -50,7 +49,7 @@ tokio-util = { version = "0.7.9", features = ["compat"] }
tracing = { version = "0.1.37", default-features = false, features = ["attributes"] }
url = { version = "2.4.1", features = ["serde"] }
zip = "0.6.6"
resolvo = "0.1.0"
resolvo = { version = "0.1.0", optional = true }

[dev-dependencies]
criterion = "0.3"
Expand Down
7 changes: 5 additions & 2 deletions crates/rattler_installs_packages/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@ mod seek_slice;
mod specifier;
mod utils;

#[cfg(feature = "resolvo-pypi")]
pub mod resolvo_pypi;
#[cfg(feature = "resolvo")]
mod resolve;

#[cfg(feature = "resolvo")]
pub use resolve::resolve;

pub use file_store::{CacheKey, FileStore};
pub use package_database::PackageDb;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use crate::{
CompareOp, Extra, NormalizedPackageName, PackageDb, Requirement, Specifier, Specifiers,
Version, Wheel,
CompareOp, Extra, NormalizedPackageName, PackageDb, PackageName, PackageRequirement,
Requirement, Specifier, Specifiers, Version, Wheel,
};
use resolvo::{
Candidates, Dependencies, DependencyProvider, NameId, Pool, SolvableId, SolverCache, VersionSet,
Candidates, DefaultSolvableDisplay, Dependencies, DependencyProvider, NameId, Pool, SolvableId,
Solver, SolverCache, VersionSet,
};
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::fmt::{Display, Formatter};
use tokio::runtime::Handle;
use tokio::task;
Expand Down Expand Up @@ -81,21 +82,21 @@ impl Display for PypiPackageName {
}
}

pub struct PypiDependencyProvider {
pub struct PypiDependencyProvider<'db> {
pool: Pool<PypiVersionSet, PypiPackageName>,
package_db: PackageDb,
package_db: &'db PackageDb,
}

impl PypiDependencyProvider {
pub fn new(package_db: PackageDb) -> Self {
impl<'db> PypiDependencyProvider<'db> {
pub fn new(package_db: &'db PackageDb) -> Self {
Self {
pool: Pool::new(),
package_db,
}
}
}

impl DependencyProvider<PypiVersionSet, PypiPackageName> for PypiDependencyProvider {
impl<'db> DependencyProvider<PypiVersionSet, PypiPackageName> for PypiDependencyProvider<'db> {
fn pool(&self) -> &Pool<PypiVersionSet, PypiPackageName> {
&self.pool
}
Expand Down Expand Up @@ -296,3 +297,75 @@ impl DependencyProvider<PypiVersionSet, PypiPackageName> for PypiDependencyProvi
dependencies
}
}

/// Resolves an environment that contains the given requirements and all dependencies of those
/// requirements.
pub async fn resolve(
package_db: &PackageDb,
requirements: impl IntoIterator<Item = &PackageRequirement>,
) -> Result<HashMap<PackageName, (Version, HashSet<Extra>)>, String> {
// Construct a provider
let provider = PypiDependencyProvider::new(package_db);
let pool = provider.pool();

let requirements = requirements.into_iter();

// Construct the root requirements from the requirements requested by the user.
let requirement_count = requirements.size_hint();
let mut root_requirements =
Vec::with_capacity(requirement_count.1.unwrap_or(requirement_count.0));
for Requirement {
name,
specifiers,
extras,
..
} in requirements.map(PackageRequirement::as_inner)
{
let dependency_package_name =
pool.intern_package_name(PypiPackageName::Base(name.clone().into()));
let version_set_id =
pool.intern_version_set(dependency_package_name, specifiers.clone().into());
root_requirements.push(version_set_id);

for extra in extras {
let dependency_package_name = pool
.intern_package_name(PypiPackageName::Extra(name.clone().into(), extra.clone()));
let version_set_id =
pool.intern_version_set(dependency_package_name, specifiers.clone().into());
root_requirements.push(version_set_id);
}
}

// Invoke the solver to get a solution to the requirements
let mut solver = Solver::new(provider);
let result = solver.solve(root_requirements);

match result {
Ok(solvables) => {
let mut result = HashMap::default();
for solvable in solvables {
let pool = solver.pool();
let solvable = pool.resolve_solvable(solvable);
let name = pool.resolve_package_name(solvable.name_id());
let version = solvable.inner();
match name {
PypiPackageName::Base(name) => {
result
.entry(name.clone().into())
.or_insert((version.0.clone(), HashSet::new()));
}
PypiPackageName::Extra(name, extra) => {
let (_, extras) = result
.entry(name.clone().into())
.or_insert((version.0.clone(), HashSet::new()));
extras.insert(extra.clone());
}
}
}
Ok(result)
}
Err(e) => Err(e
.display_user_friendly(&solver, &DefaultSolvableDisplay)
.to_string()),
}
}
11 changes: 6 additions & 5 deletions crates/rip_bin/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ homepage.workspace = true
repository.workspace = true
license.workspace = true
readme.workspace = true
default-run = "rip_bin"
default-run = "rip"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[[bin]]
name = "rip"
path = "src/main.rs"

[features]
default = ["native-tls"]
Expand All @@ -26,12 +28,11 @@ indexmap = "2.0.0"
indicatif = "0.17.6"
itertools = "0.11.0"
miette = { version = "5.10.0", features = ["fancy"] }
rattler_installs_packages = { path = "../rattler_installs_packages", default-features = false, features = ["resolvo-pypi"] }
resolvo = "0.1.0"
rattler_installs_packages = { path = "../rattler_installs_packages", default-features = false, features = ["resolvo"] }
tabwriter = { version = "1.2.1", features = ["ansi_formatting"] }
tokio = { version = "1.29.1", features = ["rt", "macros", "rt-multi-thread"] }
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["env-filter"]}
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
url = "2.4.0"
rand = "0.8.4"
serde = "1.0.188"
Expand Down
77 changes: 18 additions & 59 deletions crates/rip_bin/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,12 @@ use rip_bin::{global_multi_progress, IndicatifWriter};
use std::io::Write;

use clap::Parser;
use itertools::Itertools;
use miette::IntoDiagnostic;
use resolvo::{DefaultSolvableDisplay, DependencyProvider, Solver};
use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
use url::Url;

use rattler_installs_packages::{
normalize_index_url,
resolvo_pypi::{PypiDependencyProvider, PypiPackageName},
};
use rattler_installs_packages::{PackageRequirement, Requirement};
use rattler_installs_packages::{normalize_index_url, resolve, PackageRequirement};

#[derive(Parser)]
#[command(author, version, about, long_about = None)]
Expand Down Expand Up @@ -48,56 +44,10 @@ async fn actual_main() -> miette::Result<()> {
)
.into_diagnostic()?;

let provider = PypiDependencyProvider::new(package_db);

// Create a task to solve the specs passed on the command line.
let mut root_requirements = Vec::with_capacity(args.specs.len());
for Requirement {
name,
specifiers,
extras,
..
} in args.specs.iter().map(PackageRequirement::as_inner)
{
let dependency_package_name = provider
.pool()
.intern_package_name(PypiPackageName::Base(name.clone().into()));
let version_set_id = provider
.pool()
.intern_version_set(dependency_package_name, specifiers.clone().into());
root_requirements.push(version_set_id);

for extra in extras {
let dependency_package_name = provider
.pool()
.intern_package_name(PypiPackageName::Extra(name.clone().into(), extra.clone()));
let version_set_id = provider
.pool()
.intern_version_set(dependency_package_name, specifiers.clone().into());
root_requirements.push(version_set_id);
}
}

// Solve the jobs
let mut solver = Solver::new(provider);
let result = solver.solve(root_requirements);
let artifacts = match result {
Err(e) => {
eprintln!(
"Could not solve:\n{}",
e.display_user_friendly(&solver, &DefaultSolvableDisplay)
);
return Ok(());
}
Ok(transaction) => transaction
.into_iter()
.map(|result| {
let pool = solver.pool();
let solvable = pool.resolve_solvable(result);
let name = pool.resolve_package_name(solvable.name_id());
(name.clone(), solvable.inner().0.clone())
})
.collect::<Vec<_>>(),
// Solve the environment
let blueprint = match resolve(&package_db, &args.specs).await {
Ok(blueprint) => blueprint,
Err(err) => miette::bail!("Could not solve for the requested requirements:\n{err}"),
};

// Output the selected versions
Expand All @@ -115,10 +65,19 @@ async fn actual_main() -> miette::Result<()> {
console::style("Version").bold()
)
.into_diagnostic()?;
for (name, artifact) in artifacts {
writeln!(tabbed_stdout, "{name}\t{artifact}").into_diagnostic()?;
for (name, (version, extras)) in blueprint.into_iter().sorted_by(|(a, _), (b, _)| a.cmp(b)) {
write!(tabbed_stdout, "{name}", name = name.as_str()).into_diagnostic()?;
if !extras.is_empty() {
write!(
tabbed_stdout,
"[{}]",
extras.iter().map(|e| e.as_str()).join(",")
)
.into_diagnostic()?;
}
writeln!(tabbed_stdout, "\t{version}").into_diagnostic()?;
}
tabbed_stdout.flush().unwrap();
tabbed_stdout.flush().into_diagnostic()?;

Ok(())
}
Expand Down

0 comments on commit 25acf0f

Please sign in to comment.