From 1e8d41d884fbbdbd8aa86f94400fdf8f7b6b64a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20Riedemann?= <38795484+longemen3000@users.noreply.github.com> Date: Tue, 8 Oct 2024 09:14:45 -0300 Subject: [PATCH] add Adapt Ext (#36) * add adapt ext * remove `import Adapt` outside conditional code --- Project.toml | 5 ++++- ext/TransducersAdaptExt.jl | 34 ++++++++++++++++++++++++++++++++++ src/Transducers.jl | 3 ++- src/core.jl | 2 -- src/library.jl | 12 ------------ src/partitionby.jl | 5 ----- 6 files changed, 40 insertions(+), 21 deletions(-) create mode 100644 ext/TransducersAdaptExt.jl diff --git a/Project.toml b/Project.toml index 385ee8fc..4bc35c38 100644 --- a/Project.toml +++ b/Project.toml @@ -22,6 +22,7 @@ SplittablesBase = "171d559e-b47b-412a-8079-5efa626c420e" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [weakdeps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" @@ -29,6 +30,7 @@ OnlineStatsBase = "925886fa-5bf2-5e8e-b522-a9147a512338" Referenceables = "42d2dcc6-99eb-4e98-b66c-637b7d73030e" [extensions] +TransducersAdaptExt = "Adapt" TransducersBlockArraysExt = "BlockArrays" TransducersDataFramesExt = "DataFrames" TransducersLazyArraysExt = "LazyArrays" @@ -60,6 +62,7 @@ Tables = "0.2, 1.0" julia = "1.6" [extras] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" @@ -88,4 +91,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" [targets] -test = ["Aqua", "BlockArrays", "Compat", "DataFrames", "DataTools", "Dates", "Distributed", "Documenter", "Folds", "InteractiveUtils", "LazyArrays", "LiterateTest", "LoadAllPackages", "Maybe", "OnlineStats", "OnlineStatsBase", "PerformanceTestTools", "Pkg", "Random", "Referenceables", "SparseArrays", "StaticArrays", "Statistics", "StructArrays", "Test", "TypedTables"] +test = ["Adapt", "Aqua", "BlockArrays", "Compat", "DataFrames", "DataTools", "Dates", "Distributed", "Documenter", "Folds", "InteractiveUtils", "LazyArrays", "LiterateTest", "LoadAllPackages", "Maybe", "OnlineStats", "OnlineStatsBase", "PerformanceTestTools", "Pkg", "Random", "Referenceables", "SparseArrays", "StaticArrays", "Statistics", "StructArrays", "Test", "TypedTables"] diff --git a/ext/TransducersAdaptExt.jl b/ext/TransducersAdaptExt.jl new file mode 100644 index 00000000..1f78ccd4 --- /dev/null +++ b/ext/TransducersAdaptExt.jl @@ -0,0 +1,34 @@ +module TransducersAdaptExt + +if isdefined(Base,:get_extension) + import Transducers + import Adapt +else + import ..Transducers + import ..Adapt +end + +Adapt.adapt_structure(to, rf::R) where {R <: Transducers.Reduction} = + Transducers.Reduction(Adapt.adapt(to, Transducers.xform(rf)), Adapt.adapt(to, Transducers.inner(rf))) + +Adapt.adapt_structure(to, xf::Transducers.Map) = Transducers.Map(Adapt.adapt(to, xf.f)) + +Adapt.adapt_structure(to, xf::Transducers.MapSplat) = Transducers.MapSplat(Adapt.adapt(to, xf.f)) + +Adapt.adapt_structure(to, xf::Transducers.Filter) = Transducers.Filter(Adapt.adapt(to, xf.pred)) + +Adapt.adapt_structure(to, xf::Transducers.GetIndex{inbounds}) where {inbounds} = + Transducers.GetIndex{inbounds}(Adapt.adapt(to, xf.array)) + +Adapt.adapt_structure(to, xf::Transducers.SetIndex{inbounds}) where {inbounds} = + Transducers.SetIndex{inbounds}(Adapt.adapt(to, xf.array)) + +Adapt.adapt_structure(to, xf::Transducers.ReducePartitionBy) = Transducers.ReducePartitionBy( + Adapt.adapt(to, xf.f), + Adapt.adapt(to, xf.rf), + Adapt.adapt(to, xf.init), +) +end #module + + + diff --git a/src/Transducers.jl b/src/Transducers.jl index 9da2ca79..6e8d29f1 100644 --- a/src/Transducers.jl +++ b/src/Transducers.jl @@ -80,7 +80,6 @@ export AdHocFoldable, using Base.Broadcast: Broadcasted using Base: tail -import Adapt import Accessors import Tables using ArgCheck @@ -156,6 +155,8 @@ const OSNonZeroNObsError = ArgumentError( if !isdefined(Base,:get_extension) using Requires + import Adapt + include("../ext/TransducersAdaptExt.jl") function __init__() @require BlockArrays="8e7c35d0-a365-5155-bbbb-fb81a777f24e" include("../ext/TransducersBlockArraysExt.jl") @require LazyArrays="5078a376-72f3-5289-bfd5-ec5146d43c02" include("../ext/TransducersLazyArraysExt.jl") diff --git a/src/core.jl b/src/core.jl index 829fa8a4..2a203a99 100644 --- a/src/core.jl +++ b/src/core.jl @@ -335,8 +335,6 @@ Transducer(rf::Reduction) = # `Reduction` to `AbstractReduction`. Reduction(::IdentityTransducer, inner) = ensurerf(inner) -Adapt.adapt_structure(to, rf::R) where {R <: Reduction} = - Reduction(Adapt.adapt(to, xform(rf)), Adapt.adapt(to, inner(rf))) """ Transducers.R_{X} diff --git a/src/library.jl b/src/library.jl index 8562faa5..d55f184a 100644 --- a/src/library.jl +++ b/src/library.jl @@ -53,8 +53,6 @@ OutputSize(::Type{<:Map}) = SizeStable() isexpansive(::Map) = false @inline next(rf::R_{Map}, result, input) = next(inner(rf), result, xform(rf).f(input)) -Adapt.adapt_structure(to, xf::Map) = Map(Adapt.adapt(to, xf.f)) - """ MapSplat(f) @@ -83,8 +81,6 @@ isexpansive(::MapSplat) = false @inline next(rf::R_{MapSplat}, result, input) = next(inner(rf), result, xform(rf).f(input...)) -Adapt.adapt_structure(to, xf::MapSplat) = MapSplat(Adapt.adapt(to, xf.f)) - # https://clojure.github.io/clojure/clojure.core-api.html#clojure.core/replace # https://clojuredocs.org/clojure.core/replace """ @@ -290,8 +286,6 @@ end @inline next(rf::R_{Filter}, result, input) = xform(rf).pred(input) ? next(inner(rf), result, input) : result -Adapt.adapt_structure(to, xf::Filter) = Filter(Adapt.adapt(to, xf.pred)) - """ NotA(T) @@ -1562,9 +1556,6 @@ Base.:(==)(xf1::GetIndex{inbounds,A}, xf2::GetIndex{inbounds,A}) where {inbounds,A} = xf1.array == xf2.array -Adapt.adapt_structure(to, xf::GetIndex{inbounds}) where {inbounds} = - GetIndex{inbounds}(Adapt.adapt(to, xf.array)) - """ SetIndex(array) SetIndex{inbounds}(array) @@ -1610,9 +1601,6 @@ Base.:(==)(xf1::SetIndex{inbounds,A}, xf2::SetIndex{inbounds,A}) where {inbounds,A} = xf1.array == xf2.array -Adapt.adapt_structure(to, xf::SetIndex{inbounds}) where {inbounds} = - SetIndex{inbounds}(Adapt.adapt(to, xf.array)) - """ Inject(iterator) diff --git a/src/partitionby.jl b/src/partitionby.jl index 761510c3..c8f3bb4c 100644 --- a/src/partitionby.jl +++ b/src/partitionby.jl @@ -47,11 +47,6 @@ struct ReducePartitionBy{F,RF,Init} <: Transducer end ReducePartitionBy(f, rf) = ReducePartitionBy(f, rf, Init) -Adapt.adapt_structure(to, xf::ReducePartitionBy) = ReducePartitionBy( - Adapt.adapt(to, xf.f), - Adapt.adapt(to, xf.rf), - Adapt.adapt(to, xf.init), -) struct PartitionChunk{K,V} kr::K