Skip to content

Commit

Permalink
lazy loading of the JET integration
Browse files Browse the repository at this point in the history
JET sometimes ends up being incompatible with the latest version of
Julia, and it can also prevent SnoopCompile.jl from loading.
To get around this, this commit makes the JET integration lazy-loaded,
so that it does not prevent SnoopCompile.jl from loading.
  • Loading branch information
aviatesk committed Apr 10, 2024
1 parent fd09eef commit 043959a
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 102 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
*.jl.*.cov
*.jl.mem
Manifest.toml
Manifest-*.toml
11 changes: 8 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Cthulhu = "f68482b8-f384-11e8-15f7-abe071a5a75f"
FlameGraphs = "08572546-2f56-4bcf-ba4e-bab62c3a3f89"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Expand All @@ -18,11 +17,16 @@ Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
SnoopCompileCore = "e2b509da-e806-4183-be48-004708413034"
YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6"

[weakdeps]
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"

[extensions]
JETExt = "JET"

[compat]
AbstractTrees = "0.3, 0.4"
Cthulhu = "1.5, 2"
FlameGraphs = "0.2, 1"
JET = "0.0, 0.4, 0.5, 0.6, 0.7, 0.8"
OrderedCollections = "1"
Requires = "1"
SnoopCompileCore = "~2.10.0"
Expand All @@ -33,11 +37,12 @@ julia = "1"
ColorTypes = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FixedPointNumbers = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
MethodAnalysis = "85b6ec6f-f7df-4429-9514-a64bcd9ee824"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["ColorTypes", "PrettyTables", "Documenter", "FixedPointNumbers", "MethodAnalysis", "Pkg", "Random", "Test"]
test = ["ColorTypes", "PrettyTables", "Documenter", "FixedPointNumbers", "JET", "MethodAnalysis", "Pkg", "Random", "Test"]
51 changes: 28 additions & 23 deletions docs/src/jet.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,16 @@ The key reason is that SnoopCompile is a dynamic analyzer, and is capable of bri
As always, you need to do the data collection in a fresh session where the calls have not previously been inferred.
After restarting Julia, we can do this:

```
```julia
julia> using SnoopCompile

julia> using JET # this is necessary to enable the integration

julia> list = Any[1,2,3];

julia> lc = Any[list]; # "hide" `list` inside a Vector{Any}

julia> callsum(listcontainer) = sum(listcontainer[1])
callsum (generic function with 1 method)
julia> callsum(listcontainer) = sum(listcontainer[1]);

julia> tinf = @snoopi_deep callsum(lc)
InferenceTimingNode: 0.039239/0.046793 on Core.Compiler.Timings.ROOT() with 2 direct children
Expand All @@ -109,28 +110,32 @@ julia> tinf.children
InferenceTimingNode: 0.000196/0.006685 on sum(::Vector{Any}) with 1 direct children

julia> report_callees(inference_triggers(tinf))
1-element Vector{Pair{InferenceTrigger, JET.JETCallResult{JET.JETAnalyzer{JET.BasicPass{typeof(JET.basic_function_filter)}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}}:
1-element Vector{Pair{InferenceTrigger, JET.JETCallResult{JET.JETAnalyzer{JET.BasicPass}, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}}}}:
Inference triggered to call sum(::Vector{Any}) from callsum (./REPL[5]:1) with specialization callsum(::Vector{Any}) => ═════ 1 possible error found ═════
@ reducedim.jl:889 Base.#sum#732(Base.:, Base.pairs(Core.NamedTuple()), #self#, a)
│┌ @ reducedim.jl:889 Base._sum(a, dims)
││┌ @ reducedim.jl:893 Base.#_sum#734(Base.pairs(Core.NamedTuple()), #self#, a, _3)
│││┌ @ reducedim.jl:893 Base._sum(Base.identity, a, Base.:)
││││┌ @ reducedim.jl:894 Base.#_sum#735(Base.pairs(Core.NamedTuple()), #self#, f, a, _4)
│││││┌ @ reducedim.jl:894 Base.mapreduce(f, Base.add_sum, a)
││││││┌ @ reducedim.jl:322 Base.#mapreduce#725(Base.:, Base._InitialValue(), #self#, f, op, A)
│││││││┌ @ reducedim.jl:322 Base._mapreduce_dim(f, op, init, A, dims)
││││││││┌ @ reducedim.jl:330 Base._mapreduce(f, op, Base.IndexStyle(A), A)
│││││││││┌ @ reduce.jl:402 Base.mapreduce_empty_iter(f, op, A, Base.IteratorEltype(A))
││││││││││┌ @ reduce.jl:353 Base.reduce_empty_iter(Base.MappingRF(f, op), itr, ItrEltype)
│││││││││││┌ @ reduce.jl:357 Base.reduce_empty(op, Base.eltype(itr))
││││││││││││┌ @ reduce.jl:331 Base.mapreduce_empty(Base.getproperty(op, :f), Base.getproperty(op, :rf), _)
│││││││││││││┌ @ reduce.jl:345 Base.reduce_empty(op, T)
││││││││││││││┌ @ reduce.jl:322 Base.reduce_empty(Base.+, _)
│││││││││││││││┌ @ reduce.jl:313 Base.zero(_)
││││││││││││││││┌ @ missing.jl:106 Base.throw(Base.MethodError(Base.zero, Core.tuple(Base.Any)))
│││││││││││││││││ MethodError: no method matching zero(::Type{Any})
││││││││││││││││└──────────────────
sum(a::Vector{Any}) @ Base ./reducedim.jl:1010
│┌ sum(a::Vector{Any}; dims::Colon, kw::@Kwargs{}) @ Base ./reducedim.jl:1010
││┌ _sum(a::Vector{Any}, ::Colon) @ Base ./reducedim.jl:1014
│││┌ _sum(a::Vector{Any}, ::Colon; kw::@Kwargs{}) @ Base ./reducedim.jl:1014
││││┌ _sum(f::typeof(identity), a::Vector{Any}, ::Colon) @ Base ./reducedim.jl:1015
│││││┌ _sum(f::typeof(identity), a::Vector{Any}, ::Colon; kw::@Kwargs{}) @ Base ./reducedim.jl:1015
││││││┌ mapreduce(f::typeof(identity), op::typeof(Base.add_sum), A::Vector{Any}) @ Base ./reducedim.jl:357
│││││││┌ mapreduce(f::typeof(identity), op::typeof(Base.add_sum), A::Vector{Any}; dims::Colon, init::Base._InitialValue) @ Base ./reducedim.jl:357
││││││││┌ _mapreduce_dim(f::typeof(identity), op::typeof(Base.add_sum), ::Base._InitialValue, A::Vector{Any}, ::Colon) @ Base ./reducedim.jl:365
│││││││││┌ _mapreduce(f::typeof(identity), op::typeof(Base.add_sum), ::IndexLinear, A::Vector{Any}) @ Base ./reduce.jl:432
││││││││││┌ mapreduce_empty_iter(f::typeof(identity), op::typeof(Base.add_sum), itr::Vector{Any}, ItrEltype::Base.HasEltype) @ Base ./reduce.jl:380
│││││││││││┌ reduce_empty_iter(op::Base.MappingRF{typeof(identity), typeof(Base.add_sum)}, itr::Vector{Any}, ::Base.HasEltype) @ Base ./reduce.jl:384
││││││││││││┌ reduce_empty(op::Base.MappingRF{typeof(identity), typeof(Base.add_sum)}, ::Type{Any}) @ Base ./reduce.jl:361
│││││││││││││┌ mapreduce_empty(::typeof(identity), op::typeof(Base.add_sum), T::Type{Any}) @ Base ./reduce.jl:372
││││││││││││││┌ reduce_empty(::typeof(Base.add_sum), ::Type{Any}) @ Base ./reduce.jl:352
│││││││││││││││┌ reduce_empty(::typeof(+), ::Type{Any}) @ Base ./reduce.jl:343
││││││││││││││││┌ zero(::Type{Any}) @ Base ./missing.jl:106
│││││││││││││││││ MethodError: no method matching zero(::Type{Any}): Base.throw(Base.MethodError(zero, tuple(Base.Any)::Tuple{DataType})::MethodError)
││││││││││││││││└────────────────────
```

Because SnoopCompile collected the runtime-dispatched `sum` call, we can pass it to JET.
`report_callees` filters those calls which generate JET reports, allowing you to focus on potential errors.

!!! note
JET integration is enabled only if JET.jl has been loaded into your main session.
This is why there's the `using JET` statement included in the example given.
86 changes: 86 additions & 0 deletions ext/JETExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
module JETExt

@static if isdefined(Base, :get_extension)
import SnoopCompile: report_callee, report_caller, report_callees
using SnoopCompile: SnoopCompile, InferenceTrigger, callerinstance
using SnoopCompile.Cthulhu: specTypes
using JET: report_call, get_reports
else
import ..SnoopCompile: report_callee, report_caller, report_callees
using ..SnoopCompile: SnoopCompile, InferenceTrigger, callerinstance
using ..SnoopCompile.Cthulhu: specTypes
using ..JET: report_call, get_reports
end

"""
report_callee(itrig::InferenceTrigger)
Return the `JET.report_call` for the callee in `itrig`.
"""
SnoopCompile.report_callee(itrig::InferenceTrigger; jetconfigs...) = report_call(specTypes(itrig); jetconfigs...)

"""
report_caller(itrig::InferenceTrigger)
Return the `JET.report_call` for the caller in `itrig`.
"""
SnoopCompile.report_caller(itrig::InferenceTrigger; jetconfigs...) = report_call(specTypes(callerinstance(itrig)); jetconfigs...)

"""
report_callees(itrigs)
Filter `itrigs` for those with a non-passing `JET` report, returning the list of `itrig => report` pairs.
# Examples
```jldoctest jetfib; setup=(using SnoopCompile, JET), filter=[r"\\d direct children", r"[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?/[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?"]
julia> fib(n::Integer) = n ≤ 2 ? n : fib(n-1) + fib(n-2);
julia> function fib(str::String)
n = length(str)
return fib(m) # error is here
end
fib (generic function with 2 methods)
julia> fib(::Dict) = 0; fib(::Vector) = 0;
julia> list = [5, "hello"];
julia> mapfib(list) = map(fib, list)
mapfib (generic function with 1 method)
julia> tinf = @snoopi_deep try mapfib(list) catch end
InferenceTimingNode: 0.049825/0.071476 on Core.Compiler.Timings.ROOT() with 5 direct children
julia> @report_call mapfib(list)
No errors detected
```
JET did not catch the error because the call to `fib` is hidden behind runtime dispatch.
However, when captured by `@snoopi_deep`, we get
```jldoctest jetfib; filter=[r"@ .*", r"REPL\\[\\d+\\]|none"]
julia> report_callees(inference_triggers(tinf))
1-element Vector{Pair{InferenceTrigger, JET.JETCallResult{JET.JETAnalyzer{JET.BasicPass{typeof(JET.basic_function_filter)}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}}:
Inference triggered to call fib(::String) from iterate (./generator.jl:47) inlined into Base.collect_to!(::Vector{Int64}, ::Base.Generator{Vector{Any}, typeof(fib)}, ::Int64, ::Int64) (./array.jl:782) => ═════ 1 possible error found ═════
┌ @ none:3 fib(m)
│ variable `m` is not defined
└──────────
```
"""
function SnoopCompile.report_callees(itrigs; jetconfigs...)
function rr(itrig)
rpt = try
report_callee(itrig; jetconfigs...)
catch err
@warn "skipping $itrig due to report_callee error" exception=err
nothing
end
return itrig => rpt
end
hasreport((itrig, report)) = report !== nothing && !isempty(get_reports(report))

return [itrigrpt for itrigrpt in map(rr, itrigs) if hasreport(itrigrpt)]
end

end # module JETExt
3 changes: 3 additions & 0 deletions src/SnoopCompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ function __init__()
if isdefined(SnoopCompileCore, Symbol("@snoopr"))
@require PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" include("report_invalidations.jl")
end
if isdefined(SnoopCompile, :report_callee) && !isdefined(Base, :get_extension)
@require JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" include("../ext/JETExt.jl")
end
return nothing
end

Expand Down
77 changes: 5 additions & 72 deletions src/parcel_snoopi_deep.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ using Core.Compiler.Timings: InferenceFrameInfo
using SnoopCompileCore: InferenceTiming, InferenceTimingNode, inclusive, exclusive
using Profile
using Cthulhu
using JET

const InferenceNode = Union{InferenceFrameInfo,InferenceTiming,InferenceTimingNode}

Expand Down Expand Up @@ -897,76 +896,10 @@ Cthulhu.specTypes(itrig::InferenceTrigger) = Cthulhu.specTypes(Cthulhu.instance(
Cthulhu.backedges(itrig::InferenceTrigger) = (itrig.callerframes,)
Cthulhu.nextnode(itrig::InferenceTrigger, edge) = (ret = callingframe(itrig); return isempty(ret.callerframes) ? nothing : ret)

"""
report_callee(itrig::InferenceTrigger)
Return the `JET.report_call` for the callee in `itrig`.
"""
report_callee(itrig::InferenceTrigger; jetconfigs...) = report_call(Cthulhu.specTypes(itrig); jetconfigs...)

"""
report_caller(itrig::InferenceTrigger)
Return the `JET.report_call` for the caller in `itrig`.
"""
report_caller(itrig::InferenceTrigger; jetconfigs...) = report_call(Cthulhu.specTypes(callerinstance(itrig)); jetconfigs...)

"""
report_callees(itrigs)
Filter `itrigs` for those with a non-passing `JET` report, returning the list of `itrig => report` pairs.
# Examples
```jldoctest jetfib; setup=(using SnoopCompile, JET), filter=[r"\\d direct children", r"[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?/[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?"]
julia> fib(n::Integer) = n ≤ 2 ? n : fib(n-1) + fib(n-2);
julia> function fib(str::String)
n = length(str)
return fib(m) # error is here
end
fib (generic function with 2 methods)
julia> fib(::Dict) = 0; fib(::Vector) = 0;
julia> list = [5, "hello"];
julia> mapfib(list) = map(fib, list)
mapfib (generic function with 1 method)
julia> tinf = @snoopi_deep try mapfib(list) catch end
InferenceTimingNode: 0.049825/0.071476 on Core.Compiler.Timings.ROOT() with 5 direct children
julia> @report_call mapfib(list)
No errors detected
```
JET did not catch the error because the call to `fib` is hidden behind runtime dispatch.
However, when captured by `@snoopi_deep`, we get
```jldoctest jetfib; filter=[r"@ .*", r"REPL\\[\\d+\\]|none"]
julia> report_callees(inference_triggers(tinf))
1-element Vector{Pair{InferenceTrigger, JET.JETCallResult{JET.JETAnalyzer{JET.BasicPass{typeof(JET.basic_function_filter)}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}}:
Inference triggered to call fib(::String) from iterate (./generator.jl:47) inlined into Base.collect_to!(::Vector{Int64}, ::Base.Generator{Vector{Any}, typeof(fib)}, ::Int64, ::Int64) (./array.jl:782) => ═════ 1 possible error found ═════
┌ @ none:3 fib(m)
│ variable `m` is not defined
└──────────
```
"""
function report_callees(itrigs; jetconfigs...)
function rr(itrig)
rpt = try
report_callee(itrig; jetconfigs...)
catch err
@warn "skipping $itrig due to report_callee error" exception=err
nothing
end
return itrig => rpt
end
hasreport((itrig, report)) = report !== nothing && !isempty(JET.get_reports(report))

return [itrigrpt for itrigrpt in map(rr, itrigs) if hasreport(itrigrpt)]
end
# JET integrations are implemented lazily
function report_callee end
function report_caller end
function report_callees end

filtermod(mod::Module, itrigs::AbstractVector{InferenceTrigger}) = filter(==(mod) callermodule, itrigs)

Expand Down Expand Up @@ -1581,7 +1514,7 @@ function unwrapconst(@nospecialize(arg))
return arg.val
elseif isa(arg, Core.PartialStruct)
return arg.typ
elseif isa(arg, Core.Compiler.MaybeUndef)
elseif @static isdefined(Core.Compiler, :MaybeUndef) ? isa(arg, Core.Compiler.MaybeUndef) : false
return arg.typ
end
return arg
Expand Down
10 changes: 6 additions & 4 deletions test/snoopi_deep.jl
Original file line number Diff line number Diff line change
Expand Up @@ -960,6 +960,8 @@ end
end

if Base.VERSION >= v"1.7"
using JET: report_call, get_reports

@testset "JET integration" begin
function mysum(c) # vendor a simple version of `sum`
isempty(c) && return zero(eltype(c))
Expand All @@ -973,12 +975,12 @@ if Base.VERSION >= v"1.7"

cc = Any[Any[1,2,3]]
tinf = @snoopi_deep call_mysum(cc)
rpt = SnoopCompile.JET.@report_call call_mysum(cc)
@test isempty(SnoopCompile.JET.get_reports(rpt))
rpt = report_call(call_mysum, (Vector{Any},))
@test isempty(get_reports(rpt))
itrigs = inference_triggers(tinf)
irpts = report_callees(itrigs)
@test only(irpts).first == last(itrigs)
@test !isempty(SnoopCompile.JET.get_reports(only(irpts).second))
@test isempty(SnoopCompile.JET.get_reports(report_caller(itrigs[end])))
@test !isempty(get_reports(only(irpts).second))
@test isempty(get_reports(report_caller(itrigs[end])))
end
end

0 comments on commit 043959a

Please sign in to comment.