diff --git a/Cargo.lock b/Cargo.lock index 6042830b..074e77ac 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -545,6 +545,15 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" +[[package]] +name = "concurrent-queue" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d16048cd947b08fa32c24458a22f5dc5e835264f689f4f5653210c69fd107363" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "configparser" version = "3.0.4" @@ -837,6 +846,17 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "event-listener" +version = "5.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b72557800024fabbaa2449dd4bf24e37b93702d457a4d4f2b0dd1f0f039f20c1" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + [[package]] name = "fastrand" version = "2.0.1" @@ -1522,15 +1542,6 @@ dependencies = [ "either", ] -[[package]] -name = "itertools" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" -dependencies = [ - "either", -] - [[package]] name = "itertools" version = "0.12.1" @@ -2466,14 +2477,17 @@ dependencies = [ [[package]] name = "resolvo" -version = "0.3.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acd163bc7df01195423c83a7a391fecf319ff41d3de899694a9ccb698e790b29" +checksum = "2016584c3fd9df0fd859a7dcbc7fafdc7fdd2d87b53a576e8e63e62fad140e33" dependencies = [ "bitvec", "elsa", - "itertools 0.11.0", + "event-listener", + "futures", + "itertools 0.12.1", "petgraph", + "tokio", "tracing", ] diff --git a/crates/rattler_installs_packages/Cargo.toml b/crates/rattler_installs_packages/Cargo.toml index 177121e2..95dfe4f6 100644 --- a/crates/rattler_installs_packages/Cargo.toml +++ b/crates/rattler_installs_packages/Cargo.toml @@ -58,7 +58,7 @@ tokio-util = { version = "0.7.10", features = ["compat"] } tracing = { version = "0.1.40", default-features = false, features = ["attributes"] } url = { version = "2.5.0", features = ["serde"] } zip = "0.6.6" -resolvo = { version = "0.3.0", default-features = false } +resolvo = { version = "0.4.0", default-features = false, features = ["tokio"] } pathdiff = "0.2.1" async_zip = { version = "0.0.16", features = ["tokio", "deflate"] } tar = "0.4.40" diff --git a/crates/rattler_installs_packages/src/artifacts/sdist.rs b/crates/rattler_installs_packages/src/artifacts/sdist.rs index 08643ace..566aaac3 100644 --- a/crates/rattler_installs_packages/src/artifacts/sdist.rs +++ b/crates/rattler_installs_packages/src/artifacts/sdist.rs @@ -646,7 +646,7 @@ mod tests { .available_artifacts(ArtifactRequest::DirectUrl { name: norm_name.into(), url: url.clone(), - wheel_builder: &wheel_builder, + wheel_builder: Arc::new(wheel_builder), }) .await .unwrap(); @@ -679,7 +679,7 @@ mod tests { .available_artifacts(ArtifactRequest::DirectUrl { name: norm_name.into(), url: url.clone(), - wheel_builder: &wheel_builder, + wheel_builder: Arc::new(wheel_builder), }) .await .unwrap(); @@ -714,7 +714,7 @@ mod tests { .available_artifacts(ArtifactRequest::DirectUrl { name: norm_name.into(), url: url.clone(), - wheel_builder: &wheel_builder, + wheel_builder: Arc::new(wheel_builder), }) .await .unwrap(); @@ -745,7 +745,7 @@ mod tests { .available_artifacts(ArtifactRequest::DirectUrl { name: norm_name.into(), url: url.clone(), - wheel_builder: &wheel_builder, + wheel_builder: Arc::new(wheel_builder), }) .await .unwrap(); @@ -789,7 +789,7 @@ mod tests { .available_artifacts(ArtifactRequest::DirectUrl { name: norm_name.into(), url: url.clone(), - wheel_builder: &wheel_builder, + wheel_builder: Arc::new(wheel_builder), }) .await .unwrap(); @@ -820,7 +820,7 @@ mod tests { .available_artifacts(ArtifactRequest::DirectUrl { name: norm_name.into(), url: url.clone(), - wheel_builder: &wheel_builder, + wheel_builder: Arc::new(wheel_builder), }) .await .unwrap(); diff --git a/crates/rattler_installs_packages/src/index/package_database.rs b/crates/rattler_installs_packages/src/index/package_database.rs index 4aa381ce..bd411476 100644 --- a/crates/rattler_installs_packages/src/index/package_database.rs +++ b/crates/rattler_installs_packages/src/index/package_database.rs @@ -27,6 +27,8 @@ use std::borrow::Borrow; use std::path::PathBuf; +use itertools::Itertools; +use std::ops::Deref; use std::sync::Arc; use std::{fmt::Display, io::Read, path::Path}; @@ -54,7 +56,7 @@ pub struct PackageDb { } /// Type of request to get from the `available_artifacts` function. -pub enum ArtifactRequest<'wb> { +pub enum ArtifactRequest { /// Get the available artifacts from the index. FromIndex(NormalizedPackageName), /// Get the artifact from a direct URL. @@ -64,7 +66,7 @@ pub enum ArtifactRequest<'wb> { /// The URL of the artifact url: Url, /// The wheel builder to use to build the artifact if its an SDist or STree - wheel_builder: &'wb WheelBuilder, + wheel_builder: Arc, }, } @@ -113,7 +115,7 @@ impl PackageDb { /// Downloads and caches information about available artifacts of a package from the index. pub async fn available_artifacts<'wb>( &self, - request: ArtifactRequest<'wb>, + request: ArtifactRequest, ) -> miette::Result<&IndexMap>>> { match request { ArtifactRequest::FromIndex(p) => { @@ -124,8 +126,11 @@ impl PackageDb { let http = self.http.clone(); let index_urls = self.sources.index_url(&p); - let request_iter = stream::iter(&index_urls) + let urls = index_urls + .into_iter() .map(|url| url.join(&format!("{}/", p.as_str())).expect("invalid url")) + .collect_vec(); + let request_iter = stream::iter(urls) .map(|url| fetch_simple_api(&http, url)) .buffer_unordered(10) .filter_map(|result| async { result.transpose() }); @@ -165,7 +170,7 @@ impl PackageDb { url, wheel_builder, } => { - self.get_artifact_by_direct_url(name, url, wheel_builder) + self.get_artifact_by_direct_url(name, url, wheel_builder.deref()) .await } } diff --git a/crates/rattler_installs_packages/src/resolve/dependency_provider.rs b/crates/rattler_installs_packages/src/resolve/dependency_provider.rs index 5912764a..3def4270 100644 --- a/crates/rattler_installs_packages/src/resolve/dependency_provider.rs +++ b/crates/rattler_installs_packages/src/resolve/dependency_provider.rs @@ -1,14 +1,17 @@ -use crate::artifacts::SDist; -use crate::artifacts::Wheel; -use crate::index::{ArtifactRequest, PackageDb}; -use crate::python_env::WheelTags; -use crate::resolve::solve_options::SDistResolution; -use crate::resolve::solve_options::{PreReleaseResolution, ResolveOptions}; -use crate::resolve::PinnedPackage; -use crate::types::{ - ArtifactFromBytes, ArtifactInfo, ArtifactName, Extra, NormalizedPackageName, PackageName, +use super::{ + pypi_version_types::PypiPackageName, + solve_options::{PreReleaseResolution, ResolveOptions, SDistResolution}, + PinnedPackage, PypiVersion, PypiVersionSet, +}; +use crate::{ + artifacts::{SDist, Wheel}, + index::{ArtifactRequest, PackageDb}, + python_env::WheelTags, + types::{ + ArtifactFromBytes, ArtifactInfo, ArtifactName, Extra, NormalizedPackageName, PackageName, + }, + wheel_builder::WheelBuilder, }; -use crate::wheel_builder::WheelBuilder; use elsa::FrozenMap; use itertools::Itertools; use miette::{Diagnostic, IntoDiagnostic, MietteDiagnostic}; @@ -19,22 +22,15 @@ use resolvo::{ Candidates, Dependencies, DependencyProvider, KnownDependencies, NameId, Pool, SolvableId, SolverCache, }; -use std::any::Any; -use std::borrow::Borrow; -use std::cmp::Ordering; -use std::collections::HashMap; - -use crate::resolve::pypi_version_types::{PypiPackageName, PypiVersion, PypiVersionSet}; -use std::str::FromStr; -use std::sync::Arc; +use std::{ + any::Any, borrow::Borrow, cmp::Ordering, collections::HashMap, rc::Rc, str::FromStr, sync::Arc, +}; use thiserror::Error; -use tokio::runtime::Handle; -use tokio::task; use url::Url; /// This is a [`DependencyProvider`] for PyPI packages pub(crate) struct PypiDependencyProvider { - pub pool: Pool, + pub pool: Rc>, package_db: Arc, wheel_builder: Arc, markers: Arc, @@ -77,7 +73,7 @@ impl PypiDependencyProvider { ); Ok(Self { - pool, + pool: Rc::new(pool), package_db, wheel_builder, markers, @@ -233,8 +229,8 @@ pub(crate) enum MetadataError { } impl<'p> DependencyProvider for &'p PypiDependencyProvider { - fn pool(&self) -> &Pool { - &self.pool + fn pool(&self) -> Rc> { + self.pool.clone() } fn should_cancel_with_value(&self) -> Option> { @@ -245,9 +241,9 @@ impl<'p> DependencyProvider for &'p PypiDepende .map(|s| Box::new(s.clone()) as Box) } - fn sort_candidates( + async fn sort_candidates( &self, - solver: &SolverCache, + _: &SolverCache, solvables: &mut [SolvableId], ) { solvables.sort_by(|&a, &b| { @@ -272,8 +268,8 @@ impl<'p> DependencyProvider for &'p PypiDepende } } - let solvable_a = solver.pool().resolve_solvable(a); - let solvable_b = solver.pool().resolve_solvable(b); + let solvable_a = self.pool.resolve_solvable(a); + let solvable_b = self.pool.resolve_solvable(b); match (&solvable_a.inner(), &solvable_b.inner()) { // Sort Urls alphabetically @@ -293,7 +289,7 @@ impl<'p> DependencyProvider for &'p PypiDepende }) } - fn get_candidates(&self, name: NameId) -> Option { + async fn get_candidates(&self, name: NameId) -> Option { let package_name = self.pool.resolve_package_name(name); tracing::info!("collecting {}", package_name); @@ -304,14 +300,18 @@ impl<'p> DependencyProvider for &'p PypiDepende ArtifactRequest::DirectUrl { name: package_name.base().clone(), url: Url::from_str(url).expect("cannot parse back url"), - wheel_builder: &self.wheel_builder, + wheel_builder: self.wheel_builder.clone(), } } else { ArtifactRequest::FromIndex(package_name.base().clone()) }; - let result = task::block_in_place(move || { - Handle::current().block_on(self.package_db.available_artifacts(request)) - }); + + let result: Result<_, miette::Report> = tokio::spawn({ + let package_db = self.package_db.clone(); + async move { Ok(package_db.available_artifacts(request).await?.clone()) } + }) + .await + .expect("cancelled"); let artifacts = match result { Ok(artifacts) => artifacts, @@ -427,7 +427,7 @@ impl<'p> DependencyProvider for &'p PypiDepende Some(candidates) } - fn get_dependencies(&self, solvable_id: SolvableId) -> Dependencies { + async fn get_dependencies(&self, solvable_id: SolvableId) -> Dependencies { let solvable = self.pool.resolve_solvable(solvable_id); let package_name = self.pool.resolve_package_name(solvable.name_id()); let package_version = solvable.inner(); @@ -498,13 +498,23 @@ impl<'p> DependencyProvider for &'p PypiDepende return Dependencies::Unknown(error); } - // Retrieve the metadata for the artifacts - let result = task::block_in_place(|| { - Handle::current().block_on( - self.package_db - .get_metadata(artifacts, Some(&self.wheel_builder)), - ) - }); + let result: miette::Result<_> = tokio::spawn({ + let package_db = self.package_db.clone(); + let wheel_builder = self.wheel_builder.clone(); + let artifacts = artifacts.to_vec(); + async move { + if let Some((ai, metadata)) = package_db + .get_metadata(&artifacts, Some(&wheel_builder)) + .await? + { + Ok(Some((ai.clone(), metadata))) + } else { + Ok(None) + } + } + }) + .await + .expect("cancelled"); let metadata = match result { // We have retrieved a value without error diff --git a/crates/rattler_installs_packages/src/resolve/solve.rs b/crates/rattler_installs_packages/src/resolve/solve.rs index d8b75b87..8b5f482f 100644 --- a/crates/rattler_installs_packages/src/resolve/solve.rs +++ b/crates/rattler_installs_packages/src/resolve/solve.rs @@ -15,6 +15,7 @@ use url::Url; use crate::resolve::pypi_version_types::{PypiPackageName, PypiVersionSet}; use crate::resolve::solve_options::ResolveOptions; use std::collections::HashSet; +use std::convert::identity; use std::ops::Deref; use std::sync::Arc; @@ -61,6 +62,40 @@ pub async fn resolve( favored_packages: HashMap, options: ResolveOptions, env_variables: HashMap, +) -> miette::Result> { + let requirements: Vec<_> = requirements.into_iter().cloned().collect(); + tokio::task::spawn_blocking(move || { + resolve_inner( + package_db, + &requirements, + env_markers, + compatible_tags, + locked_packages, + favored_packages, + options, + env_variables, + ) + }) + .await + .map_or_else( + |e| match e.try_into_panic() { + Ok(panic) => std::panic::resume_unwind(panic), + Err(_) => Err(miette::miette!("the operation was cancelled")), + }, + identity, + ) +} + +#[allow(clippy::too_many_arguments)] +fn resolve_inner<'r>( + package_db: Arc, + requirements: impl IntoIterator, + env_markers: Arc, + compatible_tags: Option>, + locked_packages: HashMap, + favored_packages: HashMap, + options: ResolveOptions, + env_variables: HashMap, ) -> miette::Result> { // Construct the pool let pool = Pool::new(); @@ -120,7 +155,7 @@ pub async fn resolve( )?; // Invoke the solver to get a solution to the requirements - let mut solver = Solver::new(&provider); + let mut solver = Solver::new(&provider).with_runtime(tokio::runtime::Handle::current()); let solvables = match solver.solve(root_requirements) { Ok(solvables) => solvables, Err(e) => { @@ -128,7 +163,11 @@ pub async fn resolve( UnsolvableOrCancelled::Unsolvable(problem) => Err(miette::miette!( "{}", problem - .display_user_friendly(&solver, &DefaultSolvableDisplay) + .display_user_friendly( + &solver, + solver.pool.clone(), + &DefaultSolvableDisplay + ) .to_string() .trim() )), @@ -142,9 +181,8 @@ pub async fn resolve( }; let mut result: HashMap = HashMap::new(); for solvable_id in solvables { - let pool = solver.pool(); - let solvable = pool.resolve_solvable(solvable_id); - let name = pool.resolve_package_name(solvable.name_id()); + let solvable = solver.pool.resolve_solvable(solvable_id); + let name = solver.pool.resolve_package_name(solvable.name_id()); let version = solvable.inner(); let artifacts: Vec<_> = provider diff --git a/crates/rattler_installs_packages/src/wheel_builder/mod.rs b/crates/rattler_installs_packages/src/wheel_builder/mod.rs index 22f68b9f..471a5466 100644 --- a/crates/rattler_installs_packages/src/wheel_builder/mod.rs +++ b/crates/rattler_installs_packages/src/wheel_builder/mod.rs @@ -294,6 +294,7 @@ mod tests { use crate::resolve::solve_options::ResolveOptions; use crate::wheel_builder::wheel_cache::WheelCacheKey; use crate::wheel_builder::WheelBuilder; + use futures::future::TryJoinAll; use reqwest::Client; use reqwest_middleware::ClientWithMiddleware; use std::collections::HashMap; @@ -392,4 +393,52 @@ mod tests { // Check if the build env is there assert!(path.exists()); } + + // Skipped for now will fix this in a later PR + #[tokio::test(flavor = "multi_thread")] + #[ignore] + pub async fn build_sdist_metadata_concurrently() { + let path = + Path::new(env!("CARGO_MANIFEST_DIR")).join("../../test-data/sdists/rich-13.6.0.tar.gz"); + + let package_db = get_package_db(); + let env_markers = Arc::new(Pep508EnvMakers::from_env().await.unwrap().0); + + let wheel_builder = Arc::new( + WheelBuilder::new( + package_db.0, + env_markers, + None, + ResolveOptions::default(), + Default::default(), + ) + .unwrap(), + ); + + let mut handles = vec![]; + + for _ in 0..10 { + let sdist = SDist::from_path(&path, &"rich".parse().unwrap()).unwrap(); + let wheel_builder = wheel_builder.clone(); + handles.push(tokio::spawn(async move { + wheel_builder.get_sdist_metadata(&sdist).await + })); + } + + let result = handles.into_iter().collect::>().await; + match result { + Ok(results) => { + for result in results { + assert!( + result.is_ok(), + "error during concurrent wheel build: {:?}", + result.err() + ); + } + } + Err(e) => { + panic!("Failed to build wheels concurrently: {}", e); + } + } + } } diff --git a/rust-toolchain b/rust-toolchain index cc31fcd4..07cde984 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -1.72 +1.75