Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix segfault on return type #2117

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 27 additions & 9 deletions lib/EnzymeCore/src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ end
function has_frule_from_sig(@nospecialize(TT);
world::UInt=Base.get_world_counter(),
method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing,
caller::Union{Nothing,Core.MethodInstance}=nothing)
caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing)
ft, tt = _annotate_tt(TT)
TT = Tuple{<:FwdConfig, <:Annotation{ft}, Type{<:Annotation}, tt...}
return isapplicable(forward, TT; world, method_table, caller)
Expand All @@ -180,7 +180,7 @@ end
function has_rrule_from_sig(@nospecialize(TT);
world::UInt=Base.get_world_counter(),
method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing,
caller::Union{Nothing,Core.MethodInstance}=nothing)
caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing)
ft, tt = _annotate_tt(TT)
TT = Tuple{<:RevConfig, <:Annotation{ft}, Type{<:Annotation}, tt...}
return isapplicable(augmented_primal, TT; world, method_table, caller)
Expand All @@ -192,7 +192,7 @@ end
function isapplicable(@nospecialize(f), @nospecialize(TT);
world::UInt=Base.get_world_counter(),
method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing,
caller::Union{Nothing,Core.MethodInstance}=nothing)
caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing)
tt = Base.to_tuple_type(TT)
sig = Base.signature_type(f, tt)
mt = ccall(:jl_method_table_for, Any, (Any,), sig)
Expand All @@ -208,18 +208,36 @@ function isapplicable(@nospecialize(f), @nospecialize(TT);
matches = result
end
fullmatch = Core.Compiler._any(match::Core.MethodMatch->match.fully_covers, matches)
if caller !== nothing
fullmatch || add_mt_backedge!(caller, mt, sig)
if !fullmatch
if caller isa Core.MethodInstance
add_mt_backedge!(caller, mt, sig)
elseif caller isa Core.Compiler.MethodLookupResult
for j = 1:Core.Compiler.length(caller)
cmatch = Core.Compiler.getindex(caller, j)::Core.MethodMatch
cspec = Core.Compiler.specialize_method(cmatch)::Core.MethodInstance
add_mt_backedge!(cspec, mt, sig)
end
end
end
if Core.Compiler.isempty(matches)
return false
else
if caller !== nothing
if caller isa Core.MethodInstance
for i = 1:Core.Compiler.length(matches)
match = Core.Compiler.getindex(matches, i)::Core.MethodMatch
edge = Core.Compiler.specialize_method(match)::Core.MethodInstance
add_backedge!(caller, edge, sig)
end
elseif caller isa Core.Compiler.MethodLookupResult
for j = 1:Core.Compiler.length(caller)
cmatch = Core.Compiler.getindex(caller, j)::Core.MethodMatch
cspec = Core.Compiler.specialize_method(cmatch)::Core.MethodInstance
for i = 1:Core.Compiler.length(matches)
match = Core.Compiler.getindex(matches, i)::Core.MethodMatch
edge = Core.Compiler.specialize_method(match)::Core.MethodInstance
add_backedge!(cspec, edge, sig)
end
end
end
return true
end
Expand All @@ -245,7 +263,7 @@ function inactive end
function is_inactive_from_sig(@nospecialize(TT);
world::UInt=Base.get_world_counter(),
method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing,
caller::Union{Nothing,Core.MethodInstance}=nothing)
caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing)
return isapplicable(inactive, TT; world, method_table, caller)
end

Expand All @@ -260,7 +278,7 @@ function inactive_noinl end
function is_inactive_noinl_from_sig(@nospecialize(TT);
world::UInt=Base.get_world_counter(),
method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing,
caller::Union{Nothing,Core.MethodInstance}=nothing)
caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing)
return isapplicable(inactive_noinl, TT; world, method_table, caller)
end

Expand All @@ -275,7 +293,7 @@ function noalias end
function noalias_from_sig(@nospecialize(TT);
world::UInt=Base.get_world_counter(),
method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing,
caller::Union{Nothing,Core.MethodInstance}=nothing)
caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing)
return isapplicable(noalias, TT; world, method_table, caller)
end

Expand Down
Loading
Loading