From 043959a629881049ca02c8a6f3f915fbc784fb9d Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Wed, 10 Apr 2024 18:51:02 +0900 Subject: [PATCH] lazy loading of the JET integration JET sometimes ends up being incompatible with the latest version of Julia, and it can also prevent SnoopCompile.jl from loading. To get around this, this commit makes the JET integration lazy-loaded, so that it does not prevent SnoopCompile.jl from loading. --- .gitignore | 1 + Project.toml | 11 +++-- docs/src/jet.md | 51 ++++++++++++----------- ext/JETExt.jl | 86 +++++++++++++++++++++++++++++++++++++++ src/SnoopCompile.jl | 3 ++ src/parcel_snoopi_deep.jl | 77 +++-------------------------------- test/snoopi_deep.jl | 10 +++-- 7 files changed, 137 insertions(+), 102 deletions(-) create mode 100644 ext/JETExt.jl diff --git a/.gitignore b/.gitignore index 3f02ca741..327cf35ca 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ *.jl.*.cov *.jl.mem Manifest.toml +Manifest-*.toml diff --git a/Project.toml b/Project.toml index 4c63a86b0..35f0f0bd0 100644 --- a/Project.toml +++ b/Project.toml @@ -8,7 +8,6 @@ AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Cthulhu = "f68482b8-f384-11e8-15f7-abe071a5a75f" FlameGraphs = "08572546-2f56-4bcf-ba4e-bab62c3a3f89" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" -JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" @@ -18,11 +17,16 @@ Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" SnoopCompileCore = "e2b509da-e806-4183-be48-004708413034" YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6" +[weakdeps] +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" + +[extensions] +JETExt = "JET" + [compat] AbstractTrees = "0.3, 0.4" Cthulhu = "1.5, 2" FlameGraphs = "0.2, 1" -JET = "0.0, 0.4, 0.5, 0.6, 0.7, 0.8" OrderedCollections = "1" Requires = "1" SnoopCompileCore = "~2.10.0" @@ -33,6 +37,7 @@ julia = "1" ColorTypes = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FixedPointNumbers = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" MethodAnalysis = "85b6ec6f-f7df-4429-9514-a64bcd9ee824" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" @@ -40,4 +45,4 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["ColorTypes", "PrettyTables", "Documenter", "FixedPointNumbers", "MethodAnalysis", "Pkg", "Random", "Test"] +test = ["ColorTypes", "PrettyTables", "Documenter", "FixedPointNumbers", "JET", "MethodAnalysis", "Pkg", "Random", "Test"] diff --git a/docs/src/jet.md b/docs/src/jet.md index c065dfc1a..ba0a09c79 100644 --- a/docs/src/jet.md +++ b/docs/src/jet.md @@ -90,15 +90,16 @@ The key reason is that SnoopCompile is a dynamic analyzer, and is capable of bri As always, you need to do the data collection in a fresh session where the calls have not previously been inferred. After restarting Julia, we can do this: -``` +```julia julia> using SnoopCompile +julia> using JET # this is necessary to enable the integration + julia> list = Any[1,2,3]; julia> lc = Any[list]; # "hide" `list` inside a Vector{Any} -julia> callsum(listcontainer) = sum(listcontainer[1]) -callsum (generic function with 1 method) +julia> callsum(listcontainer) = sum(listcontainer[1]); julia> tinf = @snoopi_deep callsum(lc) InferenceTimingNode: 0.039239/0.046793 on Core.Compiler.Timings.ROOT() with 2 direct children @@ -109,28 +110,32 @@ julia> tinf.children InferenceTimingNode: 0.000196/0.006685 on sum(::Vector{Any}) with 1 direct children julia> report_callees(inference_triggers(tinf)) -1-element Vector{Pair{InferenceTrigger, JET.JETCallResult{JET.JETAnalyzer{JET.BasicPass{typeof(JET.basic_function_filter)}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}}: +1-element Vector{Pair{InferenceTrigger, JET.JETCallResult{JET.JETAnalyzer{JET.BasicPass}, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}}}}: Inference triggered to call sum(::Vector{Any}) from callsum (./REPL[5]:1) with specialization callsum(::Vector{Any}) => ═════ 1 possible error found ═════ -┌ @ reducedim.jl:889 Base.#sum#732(Base.:, Base.pairs(Core.NamedTuple()), #self#, a) -│┌ @ reducedim.jl:889 Base._sum(a, dims) -││┌ @ reducedim.jl:893 Base.#_sum#734(Base.pairs(Core.NamedTuple()), #self#, a, _3) -│││┌ @ reducedim.jl:893 Base._sum(Base.identity, a, Base.:) -││││┌ @ reducedim.jl:894 Base.#_sum#735(Base.pairs(Core.NamedTuple()), #self#, f, a, _4) -│││││┌ @ reducedim.jl:894 Base.mapreduce(f, Base.add_sum, a) -││││││┌ @ reducedim.jl:322 Base.#mapreduce#725(Base.:, Base._InitialValue(), #self#, f, op, A) -│││││││┌ @ reducedim.jl:322 Base._mapreduce_dim(f, op, init, A, dims) -││││││││┌ @ reducedim.jl:330 Base._mapreduce(f, op, Base.IndexStyle(A), A) -│││││││││┌ @ reduce.jl:402 Base.mapreduce_empty_iter(f, op, A, Base.IteratorEltype(A)) -││││││││││┌ @ reduce.jl:353 Base.reduce_empty_iter(Base.MappingRF(f, op), itr, ItrEltype) -│││││││││││┌ @ reduce.jl:357 Base.reduce_empty(op, Base.eltype(itr)) -││││││││││││┌ @ reduce.jl:331 Base.mapreduce_empty(Base.getproperty(op, :f), Base.getproperty(op, :rf), _) -│││││││││││││┌ @ reduce.jl:345 Base.reduce_empty(op, T) -││││││││││││││┌ @ reduce.jl:322 Base.reduce_empty(Base.+, _) -│││││││││││││││┌ @ reduce.jl:313 Base.zero(_) -││││││││││││││││┌ @ missing.jl:106 Base.throw(Base.MethodError(Base.zero, Core.tuple(Base.Any))) -│││││││││││││││││ MethodError: no method matching zero(::Type{Any}) -││││││││││││││││└────────────────── +┌ sum(a::Vector{Any}) @ Base ./reducedim.jl:1010 +│┌ sum(a::Vector{Any}; dims::Colon, kw::@Kwargs{}) @ Base ./reducedim.jl:1010 +││┌ _sum(a::Vector{Any}, ::Colon) @ Base ./reducedim.jl:1014 +│││┌ _sum(a::Vector{Any}, ::Colon; kw::@Kwargs{}) @ Base ./reducedim.jl:1014 +││││┌ _sum(f::typeof(identity), a::Vector{Any}, ::Colon) @ Base ./reducedim.jl:1015 +│││││┌ _sum(f::typeof(identity), a::Vector{Any}, ::Colon; kw::@Kwargs{}) @ Base ./reducedim.jl:1015 +││││││┌ mapreduce(f::typeof(identity), op::typeof(Base.add_sum), A::Vector{Any}) @ Base ./reducedim.jl:357 +│││││││┌ mapreduce(f::typeof(identity), op::typeof(Base.add_sum), A::Vector{Any}; dims::Colon, init::Base._InitialValue) @ Base ./reducedim.jl:357 +││││││││┌ _mapreduce_dim(f::typeof(identity), op::typeof(Base.add_sum), ::Base._InitialValue, A::Vector{Any}, ::Colon) @ Base ./reducedim.jl:365 +│││││││││┌ _mapreduce(f::typeof(identity), op::typeof(Base.add_sum), ::IndexLinear, A::Vector{Any}) @ Base ./reduce.jl:432 +││││││││││┌ mapreduce_empty_iter(f::typeof(identity), op::typeof(Base.add_sum), itr::Vector{Any}, ItrEltype::Base.HasEltype) @ Base ./reduce.jl:380 +│││││││││││┌ reduce_empty_iter(op::Base.MappingRF{typeof(identity), typeof(Base.add_sum)}, itr::Vector{Any}, ::Base.HasEltype) @ Base ./reduce.jl:384 +││││││││││││┌ reduce_empty(op::Base.MappingRF{typeof(identity), typeof(Base.add_sum)}, ::Type{Any}) @ Base ./reduce.jl:361 +│││││││││││││┌ mapreduce_empty(::typeof(identity), op::typeof(Base.add_sum), T::Type{Any}) @ Base ./reduce.jl:372 +││││││││││││││┌ reduce_empty(::typeof(Base.add_sum), ::Type{Any}) @ Base ./reduce.jl:352 +│││││││││││││││┌ reduce_empty(::typeof(+), ::Type{Any}) @ Base ./reduce.jl:343 +││││││││││││││││┌ zero(::Type{Any}) @ Base ./missing.jl:106 +│││││││││││││││││ MethodError: no method matching zero(::Type{Any}): Base.throw(Base.MethodError(zero, tuple(Base.Any)::Tuple{DataType})::MethodError) +││││││││││││││││└──────────────────── ``` Because SnoopCompile collected the runtime-dispatched `sum` call, we can pass it to JET. `report_callees` filters those calls which generate JET reports, allowing you to focus on potential errors. + +!!! note + JET integration is enabled only if JET.jl has been loaded into your main session. + This is why there's the `using JET` statement included in the example given. diff --git a/ext/JETExt.jl b/ext/JETExt.jl new file mode 100644 index 000000000..09ccd2d1a --- /dev/null +++ b/ext/JETExt.jl @@ -0,0 +1,86 @@ +module JETExt + +@static if isdefined(Base, :get_extension) + import SnoopCompile: report_callee, report_caller, report_callees + using SnoopCompile: SnoopCompile, InferenceTrigger, callerinstance + using SnoopCompile.Cthulhu: specTypes + using JET: report_call, get_reports +else + import ..SnoopCompile: report_callee, report_caller, report_callees + using ..SnoopCompile: SnoopCompile, InferenceTrigger, callerinstance + using ..SnoopCompile.Cthulhu: specTypes + using ..JET: report_call, get_reports +end + +""" + report_callee(itrig::InferenceTrigger) + +Return the `JET.report_call` for the callee in `itrig`. +""" +SnoopCompile.report_callee(itrig::InferenceTrigger; jetconfigs...) = report_call(specTypes(itrig); jetconfigs...) + +""" + report_caller(itrig::InferenceTrigger) + +Return the `JET.report_call` for the caller in `itrig`. +""" +SnoopCompile.report_caller(itrig::InferenceTrigger; jetconfigs...) = report_call(specTypes(callerinstance(itrig)); jetconfigs...) + +""" + report_callees(itrigs) + +Filter `itrigs` for those with a non-passing `JET` report, returning the list of `itrig => report` pairs. + +# Examples + +```jldoctest jetfib; setup=(using SnoopCompile, JET), filter=[r"\\d direct children", r"[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?/[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?"] +julia> fib(n::Integer) = n ≤ 2 ? n : fib(n-1) + fib(n-2); + +julia> function fib(str::String) + n = length(str) + return fib(m) # error is here + end +fib (generic function with 2 methods) + +julia> fib(::Dict) = 0; fib(::Vector) = 0; + +julia> list = [5, "hello"]; + +julia> mapfib(list) = map(fib, list) +mapfib (generic function with 1 method) + +julia> tinf = @snoopi_deep try mapfib(list) catch end +InferenceTimingNode: 0.049825/0.071476 on Core.Compiler.Timings.ROOT() with 5 direct children + +julia> @report_call mapfib(list) +No errors detected +``` + +JET did not catch the error because the call to `fib` is hidden behind runtime dispatch. +However, when captured by `@snoopi_deep`, we get + +```jldoctest jetfib; filter=[r"@ .*", r"REPL\\[\\d+\\]|none"] +julia> report_callees(inference_triggers(tinf)) +1-element Vector{Pair{InferenceTrigger, JET.JETCallResult{JET.JETAnalyzer{JET.BasicPass{typeof(JET.basic_function_filter)}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}}: + Inference triggered to call fib(::String) from iterate (./generator.jl:47) inlined into Base.collect_to!(::Vector{Int64}, ::Base.Generator{Vector{Any}, typeof(fib)}, ::Int64, ::Int64) (./array.jl:782) => ═════ 1 possible error found ═════ +┌ @ none:3 fib(m) +│ variable `m` is not defined +└────────── +``` +""" +function SnoopCompile.report_callees(itrigs; jetconfigs...) + function rr(itrig) + rpt = try + report_callee(itrig; jetconfigs...) + catch err + @warn "skipping $itrig due to report_callee error" exception=err + nothing + end + return itrig => rpt + end + hasreport((itrig, report)) = report !== nothing && !isempty(get_reports(report)) + + return [itrigrpt for itrigrpt in map(rr, itrigs) if hasreport(itrigrpt)] +end + +end # module JETExt diff --git a/src/SnoopCompile.jl b/src/SnoopCompile.jl index ed83b6b75..32011555e 100644 --- a/src/SnoopCompile.jl +++ b/src/SnoopCompile.jl @@ -128,6 +128,9 @@ function __init__() if isdefined(SnoopCompileCore, Symbol("@snoopr")) @require PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" include("report_invalidations.jl") end + if isdefined(SnoopCompile, :report_callee) && !isdefined(Base, :get_extension) + @require JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" include("../ext/JETExt.jl") + end return nothing end diff --git a/src/parcel_snoopi_deep.jl b/src/parcel_snoopi_deep.jl index 20a9c2c04..49f5bc8bd 100644 --- a/src/parcel_snoopi_deep.jl +++ b/src/parcel_snoopi_deep.jl @@ -7,7 +7,6 @@ using Core.Compiler.Timings: InferenceFrameInfo using SnoopCompileCore: InferenceTiming, InferenceTimingNode, inclusive, exclusive using Profile using Cthulhu -using JET const InferenceNode = Union{InferenceFrameInfo,InferenceTiming,InferenceTimingNode} @@ -897,76 +896,10 @@ Cthulhu.specTypes(itrig::InferenceTrigger) = Cthulhu.specTypes(Cthulhu.instance( Cthulhu.backedges(itrig::InferenceTrigger) = (itrig.callerframes,) Cthulhu.nextnode(itrig::InferenceTrigger, edge) = (ret = callingframe(itrig); return isempty(ret.callerframes) ? nothing : ret) -""" - report_callee(itrig::InferenceTrigger) - -Return the `JET.report_call` for the callee in `itrig`. -""" -report_callee(itrig::InferenceTrigger; jetconfigs...) = report_call(Cthulhu.specTypes(itrig); jetconfigs...) - -""" - report_caller(itrig::InferenceTrigger) - -Return the `JET.report_call` for the caller in `itrig`. -""" -report_caller(itrig::InferenceTrigger; jetconfigs...) = report_call(Cthulhu.specTypes(callerinstance(itrig)); jetconfigs...) - -""" - report_callees(itrigs) - -Filter `itrigs` for those with a non-passing `JET` report, returning the list of `itrig => report` pairs. - -# Examples - -```jldoctest jetfib; setup=(using SnoopCompile, JET), filter=[r"\\d direct children", r"[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?/[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?"] -julia> fib(n::Integer) = n ≤ 2 ? n : fib(n-1) + fib(n-2); - -julia> function fib(str::String) - n = length(str) - return fib(m) # error is here - end -fib (generic function with 2 methods) - -julia> fib(::Dict) = 0; fib(::Vector) = 0; - -julia> list = [5, "hello"]; - -julia> mapfib(list) = map(fib, list) -mapfib (generic function with 1 method) - -julia> tinf = @snoopi_deep try mapfib(list) catch end -InferenceTimingNode: 0.049825/0.071476 on Core.Compiler.Timings.ROOT() with 5 direct children - -julia> @report_call mapfib(list) -No errors detected -``` - -JET did not catch the error because the call to `fib` is hidden behind runtime dispatch. -However, when captured by `@snoopi_deep`, we get - -```jldoctest jetfib; filter=[r"@ .*", r"REPL\\[\\d+\\]|none"] -julia> report_callees(inference_triggers(tinf)) -1-element Vector{Pair{InferenceTrigger, JET.JETCallResult{JET.JETAnalyzer{JET.BasicPass{typeof(JET.basic_function_filter)}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}}: - Inference triggered to call fib(::String) from iterate (./generator.jl:47) inlined into Base.collect_to!(::Vector{Int64}, ::Base.Generator{Vector{Any}, typeof(fib)}, ::Int64, ::Int64) (./array.jl:782) => ═════ 1 possible error found ═════ -┌ @ none:3 fib(m) -│ variable `m` is not defined -└────────── -``` -""" -function report_callees(itrigs; jetconfigs...) - function rr(itrig) - rpt = try - report_callee(itrig; jetconfigs...) - catch err - @warn "skipping $itrig due to report_callee error" exception=err - nothing - end - return itrig => rpt - end - hasreport((itrig, report)) = report !== nothing && !isempty(JET.get_reports(report)) - - return [itrigrpt for itrigrpt in map(rr, itrigs) if hasreport(itrigrpt)] -end +# JET integrations are implemented lazily +function report_callee end +function report_caller end +function report_callees end filtermod(mod::Module, itrigs::AbstractVector{InferenceTrigger}) = filter(==(mod) ∘ callermodule, itrigs) @@ -1581,7 +1514,7 @@ function unwrapconst(@nospecialize(arg)) return arg.val elseif isa(arg, Core.PartialStruct) return arg.typ - elseif isa(arg, Core.Compiler.MaybeUndef) + elseif @static isdefined(Core.Compiler, :MaybeUndef) ? isa(arg, Core.Compiler.MaybeUndef) : false return arg.typ end return arg diff --git a/test/snoopi_deep.jl b/test/snoopi_deep.jl index e3d023c93..6e4ae3ba9 100644 --- a/test/snoopi_deep.jl +++ b/test/snoopi_deep.jl @@ -960,6 +960,8 @@ end end if Base.VERSION >= v"1.7" + using JET: report_call, get_reports + @testset "JET integration" begin function mysum(c) # vendor a simple version of `sum` isempty(c) && return zero(eltype(c)) @@ -973,12 +975,12 @@ if Base.VERSION >= v"1.7" cc = Any[Any[1,2,3]] tinf = @snoopi_deep call_mysum(cc) - rpt = SnoopCompile.JET.@report_call call_mysum(cc) - @test isempty(SnoopCompile.JET.get_reports(rpt)) + rpt = report_call(call_mysum, (Vector{Any},)) + @test isempty(get_reports(rpt)) itrigs = inference_triggers(tinf) irpts = report_callees(itrigs) @test only(irpts).first == last(itrigs) - @test !isempty(SnoopCompile.JET.get_reports(only(irpts).second)) - @test isempty(SnoopCompile.JET.get_reports(report_caller(itrigs[end]))) + @test !isempty(get_reports(only(irpts).second)) + @test isempty(get_reports(report_caller(itrigs[end]))) end end