From 5214475f7c8ebe7ab57e1b45af075ab1dc9fe879 Mon Sep 17 00:00:00 2001 From: Fredrik Ekre Date: Tue, 27 Aug 2024 11:53:38 +0200 Subject: [PATCH] Disable explicit return inside unknown macros --- src/Runic.jl | 10 +++++++++- src/chisels.jl | 47 +++++++++++++++++++++++++++++++++++++++++++++++ src/runestone.jl | 3 +++ test/runtests.jl | 11 ++++++++++- 4 files changed, 69 insertions(+), 2 deletions(-) diff --git a/src/Runic.jl b/src/Runic.jl index aa9384f..7fe1b48 100644 --- a/src/Runic.jl +++ b/src/Runic.jl @@ -138,6 +138,7 @@ mutable struct Context next_sibling::Union{Node, Nothing} # parent::Union{Node, Nothing} lineage_kinds::Vector{JuliaSyntax.Kind} + lineage_macros::Vector{String} end function Context( @@ -168,11 +169,12 @@ function Context( call_depth = 0 prev_sibling = next_sibling = nothing lineage_kinds = JuliaSyntax.Kind[] + lineage_macros = String[] format_on = true return Context( src_str, src_tree, src_io, fmt_io, fmt_tree, quiet, verbose, assert, debug, check, diff, filemode, indent_level, call_depth, format_on, prev_sibling, next_sibling, - lineage_kinds + lineage_kinds, lineage_macros ) end @@ -307,6 +309,9 @@ function format_node_with_kids!(ctx::Context, node::Node) ctx.prev_sibling = nothing ctx.next_sibling = nothing push!(ctx.lineage_kinds, kind(node)) + if kind(node) === K"macrocall" + push!(ctx.lineage_macros, macrocall_name(ctx, node)) + end # The new node parts. `kids′` aliases `kids` and only copied below if any of the # nodes change ("copy-on-write"). @@ -382,6 +387,9 @@ function format_node_with_kids!(ctx::Context, node::Node) ctx.prev_sibling = prev_sibling ctx.next_sibling = next_sibling pop!(ctx.lineage_kinds) + if kind(node) === K"macrocall" + pop!(ctx.lineage_macros) + end ctx.call_depth -= 1 # Return a new node if any of the kids changed if any_kid_changed diff --git a/src/chisels.jl b/src/chisels.jl index 6dad995..6f12c7e 100644 --- a/src/chisels.jl +++ b/src/chisels.jl @@ -746,6 +746,53 @@ function kmatch(kids, kinds, i = firstindex(kids)) return true end +# Extract the macro name as written in the source. +function macrocall_name(ctx, node) + @assert kind(node) === K"macrocall" + kids = verified_kids(node) + pred = x -> kind(x) in KSet"MacroName StringMacroName CmdMacroName core_@cmd" + mkind = kind(first_leaf_predicate(node, pred)::Node) + if kmatch(kids, KSet"@ MacroName") + p = position(ctx.fmt_io) + bytes = read(ctx.fmt_io, span(kids[1]) + span(kids[2])) + seek(ctx.fmt_io, p) + return String(bytes) + elseif kmatch(kids, KSet".") || kmatch(kids, KSet"CmdMacroName") || + kmatch(kids, KSet"StringMacroName") + bytes = read_bytes(ctx, kids[1]) + if mkind === K"CmdMacroName" + append!(bytes, "_cmd") + elseif mkind === K"StringMacroName" + append!(bytes, "_str") + end + return String(bytes) + elseif kmatch(kids, KSet"core_@cmd") + bytes = read_bytes(ctx, kids[1]) + @assert length(bytes) == 0 + return "core_@cmd" + else + # Don't bother looking in more complex expressions, just return unknown + return "" + end +end + +# Inserting `return` modifies the AST in a way that is visible to macros.. In general it is +# never safe to change the AST inside a macro, but we make an exception for some common +# "known" macros in order to be able to format functions that e.g. have an `@inline` +# annotation in front. +const MACROS_SAFE_TO_INSERT_RETURN = let set = Set{String}() + for m in ("inline", "noinline", "propagate_inbounds", "generated", "eval", "assume_effects") + push!(set, "@$m", "Base.@$m", "@Base.$m") + end + set +end +function safe_to_insert_return(ctx, node) + for m in ctx.lineage_macros + m in MACROS_SAFE_TO_INSERT_RETURN || return false + end + return true +end + ########################## # Utilities for IOBuffer # ########################## diff --git a/src/runestone.jl b/src/runestone.jl index ac0ba82..05f42b6 100644 --- a/src/runestone.jl +++ b/src/runestone.jl @@ -3513,6 +3513,9 @@ function explicit_return(ctx::Context, node::Node) if !(!is_leaf(node) && kind(node) in KSet"function macro do") return nothing end + if !safe_to_insert_return(ctx, node) + return nothing + end kids = verified_kids(node) pos = position(ctx.fmt_io) block_idx = findlast(x -> kind(x) === K"block", verified_kids(node)) diff --git a/test/runtests.jl b/test/runtests.jl index f334e37..afc6605 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1261,6 +1261,15 @@ end for r in ("throw(ArgumentError())", "error(\"foo\")", "rethrow()") @test format_string("$f\n $r\nend") == "$f\n $r\nend" end + # Safe/known macros + @test format_string("@inline $f\n x\nend") == + "@inline $f\n return x\nend" + @test format_string("Base.@noinline $f\n x\nend") == + "Base.@noinline $f\n return x\nend" + @test format_string("@Base.eval $f\n x\nend") == + "@Base.eval $f\n return x\nend" + # Unsafe/unknown macros + @test format_string("@kernel $f\n x\nend") == "@kernel $f\n x\nend" # `for` and `while` append `return nothing` to the end for r in ("for i in I\n end", "while i in I\n end") @test format_string("$f\n $r\nend") == "$f\n $r\n return nothing\nend" @@ -1374,7 +1383,7 @@ if Sys.isunix() && isdir(share_julia) @warn "JuliaSyntax.ParseError for $path" err @test_broken false else - @error "Error when formatting file $path" + @error "Error when formatting file $path" err @test false end end