Skip to content

Commit

Permalink
simplify resolve api (#21)
Browse files Browse the repository at this point in the history
Simpify resolve API for users. Hides all the details of how the provider
is created and stuff. Caching is provider through the `PackageDb`.
  • Loading branch information
baszalmstra authored Sep 28, 2023
1 parent 4fd4511 commit 861e6c9
Show file tree
Hide file tree
Showing 8 changed files with 179 additions and 84 deletions.
7 changes: 4 additions & 3 deletions .github/workflows/rust-compile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rust-lang/setup-rust-toolchain@v1
- run: |
for package in $(cargo metadata --no-deps --format-version=1 | jq -r '.packages[] | .name'); do
cargo rustdoc -p "$package" --all-features -- -D warnings -W unreachable-pub
done
cargo rustdoc -p rattler_installs_packages --all-features -- -D warnings -W unreachable-pub
cargo rustdoc -p index --all-features -- -D warnings -W unreachable-pub
cargo rustdoc -p rip_bin --bin rip --all-features -- -D warnings -W unreachable-pub
cargo rustdoc -p rip_bin --lib --all-features -- -D warnings -W unreachable-pub
format_and_lint:
name: Format and Lint
Expand Down
49 changes: 48 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 4 additions & 3 deletions crates/rattler_installs_packages/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ rust-version.workspace = true
default = ["native-tls"]
native-tls = ['reqwest/native-tls']
rustls-tls = ['reqwest/rustls-tls']
resolvo-pypi = []

[dependencies]
async-trait = "0.1.73"
Expand Down Expand Up @@ -50,14 +49,16 @@ tokio-util = { version = "0.7.9", features = ["compat"] }
tracing = { version = "0.1.37", default-features = false, features = ["attributes"] }
url = { version = "2.4.1", features = ["serde"] }
zip = "0.6.6"
resolvo = "0.1.0"
resolvo = { version = "0.1.0", optional = true }
dirs = "5.0.1"

[dev-dependencies]
criterion = "0.3"
insta = { version = "1.32.0", features = ["ron"] }
miette = { version = "5.10.0", features = ["fancy"] }
once_cell = "1.18.0"
tokio = { version = "1.32.0", features = ["rt", "macros"] }
tokio = { version = "1.32.0", features = ["rt", "macros", "rt-multi-thread"] }
tokio-test = "0.4.3"

[[bench]]
name = "html"
Expand Down
7 changes: 5 additions & 2 deletions crates/rattler_installs_packages/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@ mod seek_slice;
mod specifier;
mod utils;

#[cfg(feature = "resolvo-pypi")]
pub mod resolvo_pypi;
#[cfg(feature = "resolvo")]
mod resolve;

#[cfg(feature = "resolvo")]
pub use resolve::resolve;

pub use file_store::{CacheKey, FileStore};
pub use package_database::PackageDb;
Expand Down
10 changes: 10 additions & 0 deletions crates/rattler_installs_packages/src/requirement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,16 @@ impl Deref for PackageRequirement {
#[derive(Debug, Clone, PartialEq, Eq, DeserializeFromStr, SerializeDisplay)]
pub struct UserRequirement(Requirement);

impl UserRequirement {
pub fn into_inner(self) -> Requirement {
self.0
}

pub fn as_inner(&self) -> &Requirement {
&self.0
}
}

impl Display for UserRequirement {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use crate::{
CompareOp, Extra, NormalizedPackageName, PackageDb, Requirement, Specifier, Specifiers,
Version, Wheel,
CompareOp, Extra, NormalizedPackageName, PackageDb, PackageName, Requirement, Specifier,
Specifiers, UserRequirement, Version, Wheel,
};
use resolvo::{
Candidates, Dependencies, DependencyProvider, NameId, Pool, SolvableId, SolverCache, VersionSet,
Candidates, DefaultSolvableDisplay, Dependencies, DependencyProvider, NameId, Pool, SolvableId,
Solver, SolverCache, VersionSet,
};
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::fmt::{Display, Formatter};
use tokio::runtime::Handle;
use tokio::task;
Expand Down Expand Up @@ -81,21 +82,21 @@ impl Display for PypiPackageName {
}
}

pub struct PypiDependencyProvider {
pub struct PypiDependencyProvider<'db> {
pool: Pool<PypiVersionSet, PypiPackageName>,
package_db: PackageDb,
package_db: &'db PackageDb,
}

impl PypiDependencyProvider {
pub fn new(package_db: PackageDb) -> Self {
impl<'db> PypiDependencyProvider<'db> {
pub fn new(package_db: &'db PackageDb) -> Self {
Self {
pool: Pool::new(),
package_db,
}
}
}

impl DependencyProvider<PypiVersionSet, PypiPackageName> for PypiDependencyProvider {
impl<'db> DependencyProvider<PypiVersionSet, PypiPackageName> for PypiDependencyProvider<'db> {
fn pool(&self) -> &Pool<PypiVersionSet, PypiPackageName> {
&self.pool
}
Expand Down Expand Up @@ -296,3 +297,75 @@ impl DependencyProvider<PypiVersionSet, PypiPackageName> for PypiDependencyProvi
dependencies
}
}

/// Resolves an environment that contains the given requirements and all dependencies of those
/// requirements.
pub async fn resolve(
package_db: &PackageDb,
requirements: impl IntoIterator<Item = &UserRequirement>,
) -> Result<HashMap<PackageName, (Version, HashSet<Extra>)>, String> {
// Construct a provider
let provider = PypiDependencyProvider::new(package_db);
let pool = provider.pool();

let requirements = requirements.into_iter();

// Construct the root requirements from the requirements requested by the user.
let requirement_count = requirements.size_hint();
let mut root_requirements =
Vec::with_capacity(requirement_count.1.unwrap_or(requirement_count.0));
for Requirement {
name,
specifiers,
extras,
..
} in requirements.map(UserRequirement::as_inner)
{
let dependency_package_name =
pool.intern_package_name(PypiPackageName::Base(name.clone().into()));
let version_set_id =
pool.intern_version_set(dependency_package_name, specifiers.clone().into());
root_requirements.push(version_set_id);

for extra in extras {
let dependency_package_name = pool
.intern_package_name(PypiPackageName::Extra(name.clone().into(), extra.clone()));
let version_set_id =
pool.intern_version_set(dependency_package_name, specifiers.clone().into());
root_requirements.push(version_set_id);
}
}

// Invoke the solver to get a solution to the requirements
let mut solver = Solver::new(provider);
let result = solver.solve(root_requirements);

match result {
Ok(solvables) => {
let mut result = HashMap::default();
for solvable in solvables {
let pool = solver.pool();
let solvable = pool.resolve_solvable(solvable);
let name = pool.resolve_package_name(solvable.name_id());
let version = solvable.inner();
match name {
PypiPackageName::Base(name) => {
result
.entry(name.clone().into())
.or_insert((version.0.clone(), HashSet::new()));
}
PypiPackageName::Extra(name, extra) => {
let (_, extras) = result
.entry(name.clone().into())
.or_insert((version.0.clone(), HashSet::new()));
extras.insert(extra.clone());
}
}
}
Ok(result)
}
Err(e) => Err(e
.display_user_friendly(&solver, &DefaultSolvableDisplay)
.to_string()),
}
}
11 changes: 6 additions & 5 deletions crates/rip_bin/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ homepage.workspace = true
repository.workspace = true
license.workspace = true
readme.workspace = true
default-run = "rip_bin"
default-run = "rip"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[[bin]]
name = "rip"
path = "src/main.rs"

[features]
default = ["native-tls"]
Expand All @@ -26,12 +28,11 @@ indexmap = "2.0.0"
indicatif = "0.17.6"
itertools = "0.11.0"
miette = { version = "5.10.0", features = ["fancy"] }
rattler_installs_packages = { path = "../rattler_installs_packages", default-features = false, features = ["resolvo-pypi"] }
resolvo = "0.1.0"
rattler_installs_packages = { path = "../rattler_installs_packages", default-features = false, features = ["resolvo"] }
tabwriter = { version = "1.2.1", features = ["ansi_formatting"] }
tokio = { version = "1.29.1", features = ["rt", "macros", "rt-multi-thread"] }
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["env-filter"]}
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
url = "2.4.0"
rand = "0.8.4"
serde = "1.0.188"
Expand Down
Loading

0 comments on commit 861e6c9

Please sign in to comment.