Skip to content

Commit

Permalink
deal with closure in chain methods
Browse files Browse the repository at this point in the history
  • Loading branch information
lucarlig committed Mar 25, 2024
1 parent 53ab0e4 commit 8dd4768
Show file tree
Hide file tree
Showing 9 changed files with 128 additions and 95 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 5 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
[package]
name = "mate"
version = "0.1.0"
authors = ["Cameron Low <[email protected]>", "Luca Carlig <[email protected]"]
version = "0.1.1"
authors = [
"Cameron Low <[email protected]>",
"Luca Carlig <[email protected]",
]
description = "library of lints for automatic parallelization"
edition = "2021"
publish = false
Expand Down
2 changes: 1 addition & 1 deletion lints/par_iter/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "par_iter"
version = "0.1.0"
version = "0.1.1"
authors = ["authors go here"]
description = "description goes here"
edition = "2021"
Expand Down
20 changes: 13 additions & 7 deletions lints/par_iter/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ impl<'tcx> LateLintPass<'tcx> for ParIter {
if !par_iter_traits.is_empty() && is_type_valid(cx, ty) {
// TODO: issue with into_par_iter() need to check directly with
// parallel iterator
// let mut implemented_methods: Vec<&AssocItems> = Vec::new();

let mut allowed_methods: FxHashSet<&str> =
["into_iter", "iter", "iter_mut", "map_or"]
Expand All @@ -74,17 +73,24 @@ impl<'tcx> LateLintPass<'tcx> for ParIter {
let mut top_expr = *recv;

while let Some(parent_expr) = get_parent_expr(cx, top_expr) {
if let hir::ExprKind::MethodCall(method_name, _, _, _) = parent_expr.kind {
if !allowed_methods.contains(method_name.ident.as_str()) {
return;
match parent_expr.kind {
hir::ExprKind::MethodCall(method_name, _, _, _) => {
if !allowed_methods.contains(method_name.ident.as_str()) {
return;
}
top_expr = parent_expr;
}
hir::ExprKind::Closure(_) => {
top_expr = parent_expr;
}
_ => {
break;
}
top_expr = parent_expr;
} else {
break;
}
}

let ty: Ty<'_> = cx.typeck_results().expr_ty(top_expr);

// TODO: find a way to deal with iterators returns
if check_trait_impl(cx, ty, sym::Iterator) {
return;
Expand Down
37 changes: 36 additions & 1 deletion lints/par_iter/ui/main.fixed
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use core::ascii;
use futures::io::{self, AsyncWrite, IoSlice};
use futures::task::{Context, Poll};
use rayon::prelude::*;
use std::collections::LinkedList;
use std::collections::{HashMap, HashSet, LinkedList};
use std::ops::Range;
use std::pin::Pin;
use std::rc::Rc;
Expand Down Expand Up @@ -40,6 +40,23 @@ struct ApplicationState {

struct MyWriter;

#[derive(Hash, Eq, PartialEq, Clone)]
struct Id(String);

struct Cmd {
args: HashMap<Id, Arg>,
}

impl Cmd {
fn find(&self, key: &Id) -> Option<&Arg> {
self.args.get(key)
}
}

struct Arg {
requires: Vec<(String, Id)>,
}

fn main() {}

// should parallelize
Expand Down Expand Up @@ -388,3 +405,21 @@ impl AsyncWrite for MyWriter {
self.poll_write(cx, buf)
}
}

//should parallelize
fn nested_pars() {
let used_filtered: HashSet<Id> = HashSet::new();
let conflicting_keys: HashSet<Id> = HashSet::new();
let cmd = Cmd {
args: HashMap::new(),
};

let required: Vec<Id> = used_filtered
.par_iter()
.filter_map(|key| cmd.find(key))
.flat_map(|arg| arg.requires.par_iter().map(|item| &item.1))
.filter(|key| !used_filtered.contains(key) && !conflicting_keys.contains(key))
.chain(used_filtered.par_iter())
.cloned()
.collect();
}
37 changes: 36 additions & 1 deletion lints/par_iter/ui/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use core::ascii;
use futures::io::{self, AsyncWrite, IoSlice};
use futures::task::{Context, Poll};
use rayon::prelude::*;
use std::collections::LinkedList;
use std::collections::{HashMap, HashSet, LinkedList};
use std::ops::Range;
use std::pin::Pin;
use std::rc::Rc;
Expand Down Expand Up @@ -40,6 +40,23 @@ struct ApplicationState {

struct MyWriter;

#[derive(Hash, Eq, PartialEq, Clone)]
struct Id(String);

struct Cmd {
args: HashMap<Id, Arg>,
}

impl Cmd {
fn find(&self, key: &Id) -> Option<&Arg> {
self.args.get(key)
}
}

struct Arg {
requires: Vec<(String, Id)>,
}

fn main() {}

// should parallelize
Expand Down Expand Up @@ -388,3 +405,21 @@ impl AsyncWrite for MyWriter {
self.poll_write(cx, buf)
}
}

//should parallelize
fn nested_pars() {
let used_filtered: HashSet<Id> = HashSet::new();
let conflicting_keys: HashSet<Id> = HashSet::new();
let cmd = Cmd {
args: HashMap::new(),
};

let required: Vec<Id> = used_filtered
.iter()
.filter_map(|key| cmd.find(key))
.flat_map(|arg| arg.requires.iter().map(|item| &item.1))
.filter(|key| !used_filtered.contains(key) && !conflicting_keys.contains(key))
.chain(used_filtered.iter())
.cloned()
.collect();
}
44 changes: 35 additions & 9 deletions lints/par_iter/ui/main.stderr
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
warning: found iterator that can be parallelized
--> $DIR/main.rs:47:5
--> $DIR/main.rs:64:5
|
LL | (0..100).into_iter().for_each(|x| println!("{:?}", x));
| ^^^^^^^^^^^^^^^^^^^^ help: try using a parallel iterator: `(0..100).into_par_iter()`
|
= note: `#[warn(par_iter)]` on by default

warning: found iterator that can be parallelized
--> $DIR/main.rs:71:5
--> $DIR/main.rs:88:5
|
LL | / (0..100)
LL | | .into_iter()
Expand All @@ -20,37 +20,37 @@ LL + .into_par_iter()
|

warning: found iterator that can be parallelized
--> $DIR/main.rs:106:5
--> $DIR/main.rs:123:5
|
LL | list.into_iter().for_each(|x| println!("{:?}", x));
| ^^^^^^^^^^^^^^^^ help: try using a parallel iterator: `list.into_par_iter()`

warning: found iterator that can be parallelized
--> $DIR/main.rs:122:5
--> $DIR/main.rs:139:5
|
LL | (0..10).into_iter().for_each(|x| {
| ^^^^^^^^^^^^^^^^^^^ help: try using a parallel iterator: `(0..10).into_par_iter()`

warning: found iterator that can be parallelized
--> $DIR/main.rs:205:5
--> $DIR/main.rs:222:5
|
LL | data.iter()
| ^^^^^^^^^^^ help: try using a parallel iterator: `data.par_iter()`

warning: found iterator that can be parallelized
--> $DIR/main.rs:232:5
--> $DIR/main.rs:249:5
|
LL | numbers.iter().enumerate().for_each(|t| {
| ^^^^^^^^^^^^^^ help: try using a parallel iterator: `numbers.par_iter()`

warning: found iterator that can be parallelized
--> $DIR/main.rs:328:30
--> $DIR/main.rs:345:30
|
LL | let names: Vec<String> = people.iter().map(|p| p.name.clone()).collect();
| ^^^^^^^^^^^^^ help: try using a parallel iterator: `people.par_iter()`

warning: found iterator that can be parallelized
--> $DIR/main.rs:384:19
--> $DIR/main.rs:401:19
|
LL | let buf = bufs
| ___________________^
Expand All @@ -63,5 +63,31 @@ LL ~ let buf = bufs
LL + .par_iter()
|

warning: 8 warnings emitted
warning: found iterator that can be parallelized
--> $DIR/main.rs:417:29
|
LL | let required: Vec<Id> = used_filtered
| _____________________________^
LL | | .iter()
| |_______________^
|
help: try using a parallel iterator
|
LL ~ let required: Vec<Id> = used_filtered
LL + .par_iter()
|

warning: found iterator that can be parallelized
--> $DIR/main.rs:420:25
|
LL | .flat_map(|arg| arg.requires.iter().map(|item| &item.1))
| ^^^^^^^^^^^^^^^^^^^ help: try using a parallel iterator: `arg.requires.par_iter()`

warning: found iterator that can be parallelized
--> $DIR/main.rs:422:16
|
LL | .chain(used_filtered.iter())
| ^^^^^^^^^^^^^^^^^^^^ help: try using a parallel iterator: `used_filtered.par_iter()`

warning: 11 warnings emitted

36 changes: 0 additions & 36 deletions lints/par_iter/ui/main2.fixed
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

use core::ascii;
use rayon::prelude::*;
use std::collections::LinkedList;
use std::ops::Range;
use std::rc::Rc;

Expand Down Expand Up @@ -100,38 +99,3 @@ fn main() {}
// println!("{:?}", doubled_numbers);
// }
//

// #[derive(Hash, Eq, PartialEq, Clone)]
// struct Id(String);

// struct Cmd {
// args: HashMap<Id, Arg>,
// }

// impl Cmd {
// fn find(&self, key: &Id) -> Option<&Arg> {
// self.args.get(key)
// }
// }

// struct Arg {
// requires: Vec<(String, Id)>,
// }

// //should parallelize
// fn nested_pars() {
// let used_filtered: HashSet<Id> = HashSet::new();
// let conflicting_keys: HashSet<Id> = HashSet::new();
// let cmd = Cmd {
// args: HashMap::new(),
// };

// let required: Vec<Id> = used_filtered
// .iter()
// .filter_map(|key| cmd.find(key))
// .flat_map(|arg| arg.requires.iter().map(|item| &item.1))
// .filter(|key| !used_filtered.contains(key) && !conflicting_keys.contains(key))
// .chain(used_filtered.iter())
// .cloned()
// .collect();
// }
36 changes: 0 additions & 36 deletions lints/par_iter/ui/main2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

use core::ascii;
use rayon::prelude::*;
use std::collections::LinkedList;
use std::ops::Range;
use std::rc::Rc;

Expand Down Expand Up @@ -100,38 +99,3 @@ fn main() {}
// println!("{:?}", doubled_numbers);
// }
//

// #[derive(Hash, Eq, PartialEq, Clone)]
// struct Id(String);

// struct Cmd {
// args: HashMap<Id, Arg>,
// }

// impl Cmd {
// fn find(&self, key: &Id) -> Option<&Arg> {
// self.args.get(key)
// }
// }

// struct Arg {
// requires: Vec<(String, Id)>,
// }

// //should parallelize
// fn nested_pars() {
// let used_filtered: HashSet<Id> = HashSet::new();
// let conflicting_keys: HashSet<Id> = HashSet::new();
// let cmd = Cmd {
// args: HashMap::new(),
// };

// let required: Vec<Id> = used_filtered
// .iter()
// .filter_map(|key| cmd.find(key))
// .flat_map(|arg| arg.requires.iter().map(|item| &item.1))
// .filter(|key| !used_filtered.contains(key) && !conflicting_keys.contains(key))
// .chain(used_filtered.iter())
// .cloned()
// .collect();
// }

0 comments on commit 8dd4768

Please sign in to comment.