Skip to content

Commit

Permalink
Added tests and par_fold_hashmap lint
Browse files Browse the repository at this point in the history
  • Loading branch information
Cameron-Low committed Mar 12, 2024
1 parent 6d628e1 commit 1cb81a8
Show file tree
Hide file tree
Showing 9 changed files with 247 additions and 6 deletions.
20 changes: 20 additions & 0 deletions lints/fold/ui/main.fixed
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// run-rustfix
fn main() {
warn_fold_simple();
warn_fold_vec();
warn_fold_hashmap();
get_upload_file_total_size();
}

Expand All @@ -13,6 +15,24 @@ fn warn_fold_simple() {
println!("Sum: {}", sum);
}

fn warn_fold_vec() {
let mut data = vec![];
let numbers = vec![1, 2, 3, 4, 5];
data = numbers.iter().fold(data, |mut data, &num| { data.push(num * 3); data });

println!("Data: {:?}", data);
}

fn warn_fold_hashmap() {
use std::collections::HashMap;

let mut data = HashMap::new();
let numbers = vec![1, 2, 3, 4, 5];
data = numbers.iter().fold(data, |mut data, &num| { data.insert(num, num.to_string()); data });

println!("Data: {:?}", data);
}

fn get_upload_file_total_size() -> u64 {
let some_num = vec![0; 10];
let mut file_total_size = 0;
Expand Down
24 changes: 24 additions & 0 deletions lints/fold/ui/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// run-rustfix
fn main() {
warn_fold_simple();
warn_fold_vec();
warn_fold_hashmap();
get_upload_file_total_size();
}

Expand All @@ -15,6 +17,28 @@ fn warn_fold_simple() {
println!("Sum: {}", sum);
}

fn warn_fold_vec() {
let mut data = vec![];
let numbers = vec![1, 2, 3, 4, 5];
numbers.iter().for_each(|&num| {
data.push(num * 3);
});

println!("Data: {:?}", data);
}

fn warn_fold_hashmap() {
use std::collections::HashMap;

let mut data = HashMap::new();
let numbers = vec![1, 2, 3, 4, 5];
numbers.iter().for_each(|&num| {
data.insert(num, num.to_string());
});

println!("Data: {:?}", data);
}

fn get_upload_file_total_size() -> u64 {
let some_num = vec![0; 10];
let mut file_total_size = 0;
Expand Down
28 changes: 25 additions & 3 deletions lints/fold/ui/main.stderr
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
error: implicit fold
--> $DIR/main.rs:11:5
--> $DIR/main.rs:13:5
|
LL | / numbers.iter().for_each(|&num| {
LL | | sum += num;
Expand All @@ -10,13 +10,35 @@ LL | | });
= help: to override `-D warnings` add `#[allow(fold_simple)]`

error: implicit fold
--> $DIR/main.rs:21:5
--> $DIR/main.rs:23:5
|
LL | / numbers.iter().for_each(|&num| {
LL | | data.push(num * 3);
LL | | });
| |______^ help: try using `fold` instead: `data = numbers.iter().fold(data, |mut data, &num| { data.push(num * 3); data })`
|
= note: `-D fold-vec` implied by `-D warnings`
= help: to override `-D warnings` add `#[allow(fold_vec)]`

error: implicit fold
--> $DIR/main.rs:35:5
|
LL | / numbers.iter().for_each(|&num| {
LL | | data.insert(num, num.to_string());
LL | | });
| |______^ help: try using `fold` instead: `data = numbers.iter().fold(data, |mut data, &num| { data.insert(num, num.to_string()); data })`
|
= note: `-D fold-hashmap` implied by `-D warnings`
= help: to override `-D warnings` add `#[allow(fold_hashmap)]`

error: implicit fold
--> $DIR/main.rs:45:5
|
LL | / (0..some_num.len()).into_iter().for_each(|_| {
LL | | let (_, upload_size) = (true, 99);
LL | | file_total_size += upload_size;
LL | | });
| |______^ help: try using `fold` instead: `file_total_size += (0..some_num.len()).into_iter().map(|_| {let (_, upload_size) = (true, 99); upload_size}).fold(0, |mut file_total_size, v| { file_total_size += v; file_total_size })`

error: aborting due to 2 previous errors
error: aborting due to 4 previous errors

81 changes: 81 additions & 0 deletions lints/par_fold/src/hashmap.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
use rustc_errors::Applicability;
use rustc_hir::{Expr, ExprKind, StmtKind};
use rustc_lint::{LateContext, LateLintPass, LintContext};
use rustc_session::{declare_lint, declare_lint_pass};
use rustc_span::{sym, Symbol};
use utils::span_to_snippet_macro;

declare_lint! {
pub WARN_PAR_FOLD_HASHMAP,
Warn,
"suggest using parallel fold"
}

declare_lint_pass!(ParFoldHashMap => [WARN_PAR_FOLD_HASHMAP]);
impl<'tcx> LateLintPass<'tcx> for ParFoldHashMap {
fn check_expr(&mut self, cx: &LateContext<'tcx>, expr: &'tcx Expr<'_>) {
if let ExprKind::MethodCall(path, recv, args, _span) = &expr.kind
&& path.ident.name == Symbol::intern("fold")
{
assert_eq!(args.len(), 2);
let id_expr = args[0];
let op_expr = args[1];

// Check the penultimate statement of the fold for a `c.push(v)`
// Quite a specific target, can we be more general?
let ExprKind::Closure(op_cls) = op_expr.kind else {
return;
};
let hir_map = cx.tcx.hir();
let cls_body = hir_map.body(op_cls.body);

let Ok(StmtKind::Semi(fold_op)) =
utils::get_penult_stmt(cls_body.value).map(|s| s.kind)
else {
return;
};

let ExprKind::MethodCall(path, _, _, _) = fold_op.kind else {
return;
};
if path.ident.name != Symbol::intern("insert") {
return;
}

// Check that this method is on a hashmap
let base_ty = cx
.tcx
.typeck(expr.hir_id.owner.def_id)
.node_type(id_expr.hir_id);
let Some(adt) = base_ty.ty_adt_def() else {
return;
};
if !cx.tcx.is_diagnostic_item(sym::HashMap, adt.did()) {
return;
}

// Assume that if we make it here, we can apply the pattern.
let src_map = cx.sess().source_map();
let cls_snip = span_to_snippet_macro(src_map, op_expr.span);
let recv_snip = span_to_snippet_macro(src_map, recv.span);
let id_snip = span_to_snippet_macro(src_map, id_expr.span);

let fold_snip = format!("fold(|| HashMap::new(), {cls_snip})");
let reduce_snip = "reduce(|| HashMap::new(), |mut a, b| { a.extend(b); a })";
let mut extend_snip =
format!("{{ {id_snip}.extend({recv_snip}.{fold_snip}.{reduce_snip}); {id_snip} }}");
extend_snip = extend_snip.replace(".iter()", ".par_iter()");
extend_snip = extend_snip.replace(".iter_mut()", ".par_iter_mut()");
extend_snip = extend_snip.replace(".into_iter()", ".into_par_iter()");

cx.span_lint(WARN_PAR_FOLD_HASHMAP, expr.span, "sequential fold", |diag| {
diag.span_suggestion(
expr.span,
"try using a parallel fold on the iterator",
extend_snip,
Applicability::MachineApplicable,
);
});
}
}
}
2 changes: 2 additions & 0 deletions lints/par_fold/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ extern crate rustc_span;

mod par_fold_simple;
mod vec;
mod hashmap;

#[allow(clippy::no_mangle_with_rust_abi)]
#[cfg_attr(not(feature = "rlib"), no_mangle)]
pub fn register_lints(_sess: &rustc_session::Session, lint_store: &mut rustc_lint::LintStore) {
lint_store.register_late_pass(|_| Box::new(par_fold_simple::ParFoldSimple));
lint_store.register_late_pass(|_| Box::new(vec::ParFoldVec));
lint_store.register_late_pass(|_| Box::new(hashmap::ParFoldHashMap));
}

#[test]
Expand Down
26 changes: 26 additions & 0 deletions lints/par_fold/ui/par_fold_simple.fixed
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use rayon::prelude::*;

fn main() {
warn_fold_simple();
warn_fold_vec();
warn_fold_hashmap();
}

fn warn_fold_simple() {
Expand All @@ -17,3 +19,27 @@ fn warn_fold_simple() {

println!("Sum: {}", sum);
}

fn warn_fold_vec() {
let mut data = vec![];
let numbers = vec![1, 2, 3, 4, 5];
data = { data.extend(numbers.par_iter().fold(|| Vec::new(), |mut data, &num| {
data.push(num * 3);
data
}).reduce(|| Vec::new(), |mut a, b| { a.extend(b); a })); data };

println!("Data: {:?}", data);
}

fn warn_fold_hashmap() {
use std::collections::HashMap;

let mut data = HashMap::new();
let numbers = vec![1, 2, 3, 4, 5];
data = { data.extend(numbers.par_iter().fold(|| HashMap::new(), |mut data, &num| {
data.insert(num, num.to_string());
data
}).reduce(|| HashMap::new(), |mut a, b| { a.extend(b); a })); data };

println!("Data: {:?}", data);
}
26 changes: 26 additions & 0 deletions lints/par_fold/ui/par_fold_simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use rayon::prelude::*;

fn main() {
warn_fold_simple();
warn_fold_vec();
warn_fold_hashmap();
}

fn warn_fold_simple() {
Expand All @@ -17,3 +19,27 @@ fn warn_fold_simple() {

println!("Sum: {}", sum);
}

fn warn_fold_vec() {
let mut data = vec![];
let numbers = vec![1, 2, 3, 4, 5];
data = numbers.iter().fold(data, |mut data, &num| {
data.push(num * 3);
data
});

println!("Data: {:?}", data);
}

fn warn_fold_hashmap() {
use std::collections::HashMap;

let mut data = HashMap::new();
let numbers = vec![1, 2, 3, 4, 5];
data = numbers.iter().fold(data, |mut data, &num| {
data.insert(num, num.to_string());
data
});

println!("Data: {:?}", data);
}
44 changes: 42 additions & 2 deletions lints/par_fold/ui/par_fold_simple.stderr
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
error: sequential fold
--> $DIR/par_fold_simple.rs:13:12
--> $DIR/par_fold_simple.rs:15:12
|
LL | sum += numbers.iter().map(|&num| num).fold(0, |mut sum, v| {
| ____________^
Expand All @@ -15,5 +15,45 @@ help: try using a parallel fold on the iterator
LL | sum += numbers.par_iter().map(|&num| num).reduce(|| 0, |mut sum, v| {
| ~~~~~~~~ ~~~~~~ ~~~~

error: aborting due to 1 previous error
error: sequential fold
--> $DIR/par_fold_simple.rs:26:12
|
LL | data = numbers.iter().fold(data, |mut data, &num| {
| ____________^
LL | | data.push(num * 3);
LL | | data
LL | | });
| |______^
|
= note: `-D warn-par-fold-vec` implied by `-D warnings`
= help: to override `-D warnings` add `#[allow(warn_par_fold_vec)]`
help: try using a parallel fold on the iterator
|
LL ~ data = { data.extend(numbers.par_iter().fold(|| Vec::new(), |mut data, &num| {
LL + data.push(num * 3);
LL + data
LL ~ }).reduce(|| Vec::new(), |mut a, b| { a.extend(b); a })); data };
|

error: sequential fold
--> $DIR/par_fold_simple.rs:39:12
|
LL | data = numbers.iter().fold(data, |mut data, &num| {
| ____________^
LL | | data.insert(num, num.to_string());
LL | | data
LL | | });
| |______^
|
= note: `-D warn-par-fold-hashmap` implied by `-D warnings`
= help: to override `-D warnings` add `#[allow(warn_par_fold_hashmap)]`
help: try using a parallel fold on the iterator
|
LL ~ data = { data.extend(numbers.par_iter().fold(|| HashMap::new(), |mut data, &num| {
LL + data.insert(num, num.to_string());
LL + data
LL ~ }).reduce(|| HashMap::new(), |mut a, b| { a.extend(b); a })); data };
|

error: aborting due to 3 previous errors

2 changes: 1 addition & 1 deletion rust-toolchain
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[toolchain]
channel = "nightly-2024-02-22"
components = ["llvm-tools-preview", "rustc-dev"]
components = ["llvm-tools-preview", "rustc-dev"]

0 comments on commit 1cb81a8

Please sign in to comment.