Skip to content

Commit

Permalink
Implemented WideDimTable
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshuaBillson committed Sep 9, 2023
1 parent 31d22ad commit af9381a
Showing 1 changed file with 204 additions and 15 deletions.
219 changes: 204 additions & 15 deletions src/tables.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
# Tables.jl interface
"""
AbstractDimTable <: Tables.AbstractColumns
Abstract supertype for dim tables
"""
abstract type AbstractDimTable <: Tables.AbstractColumns end

# Tables.jl interface for AbstractDimStack and AbstractDimArray

DimTableSources = Union{AbstractDimStack,AbstractDimArray}

Tables.istable(::Type{<:DimTableSources}) = true
Tables.columnaccess(::Type{<:DimTableSources}) = true
Tables.columns(x::DimTableSources) = DimTable(x)
Tables.columns(x::DimTableSources) = WideDimTable(x)

Tables.columnnames(A::AbstractDimArray) = _colnames(DimStack(A))
Tables.columnnames(s::AbstractDimStack) = _colnames(s)
Expand All @@ -20,6 +27,10 @@ Tables.schema(s::AbstractDimStack) = Tables.schema(DimTable(s))
@inline Tables.getcolumn(t::DimTableSources, dim::DimOrDimType) =
Tables.getcolumn(t, dimnum(t, dim))


# DimColumn


"""
DimColumn{T,D<:Dimension} <: AbstractVector{T}
Expand Down Expand Up @@ -56,8 +67,6 @@ end
dim(c::DimColumn) = getfield(c, :dim)
dimstride(c::DimColumn) = getfield(c, :dimstride)

# Simple Array interface

Base.length(c::DimColumn) = getfield(c, :length)
@inline function Base.getindex(c::DimColumn, i::Int)
Base.@boundscheck checkbounds(c, i)
Expand All @@ -70,6 +79,36 @@ Base.axes(c::DimColumn) = (Base.OneTo(length(c)),)
Base.vec(c::DimColumn{T}) where T = [c[i] for i in eachindex(c)]
Base.Array(c::DimColumn) = vec(c)


# MergedDimColumn


struct MergedDimColumn{T,DS} <: AbstractVector{T}
colname::Symbol
dimcols::DS
end
function MergedDimColumn(dims::DS, name::Symbol) where DS
MergedDimColumn{Tuple{map(eltype, dims)...},DS}(name, dims)
end

colname(c::MergedDimColumn) = getfield(c, :colname)
dimcols(c::MergedDimColumn) = getfield(c, :dimcols)

Base.length(c::MergedDimColumn) = length(first(dimcols(c)))
@inline function Base.getindex(c::MergedDimColumn{T}, i::Int) where T
return map(x -> x[i], dimcols(c))
end
Base.getindex(c::MergedDimColumn, ::Colon) = vec(c)
Base.getindex(c::MergedDimColumn, A::AbstractArray) = [c[i] for i in A]
Base.size(c::MergedDimColumn) = (length(c),)
Base.axes(c::MergedDimColumn) = (Base.OneTo(length(c)),)
Base.vec(c::MergedDimColumn{T}) where T = [c[i] for i in eachindex(c)]
Base.Array(c::MergedDimColumn) = vec(c)


# DimArrayColumn


struct DimArrayColumn{T,A<:AbstractDimArray{T},DS,DL,L} <: AbstractVector{T}
data::A
dimstrides::DS
Expand All @@ -89,8 +128,6 @@ Base.parent(c::DimArrayColumn) = getfield(c, :data)
dimstrides(c::DimArrayColumn) = getfield(c, :dimstrides)
dimlengths(c::DimArrayColumn) = getfield(c, :dimlengths)

# Simple Array interface

Base.length(c::DimArrayColumn) = getfield(c, :length)
@inline function Base.getindex(c::DimArrayColumn, i::Int)
Base.@boundscheck checkbounds(c, i)
Expand All @@ -107,12 +144,9 @@ Base.axes(c::DimArrayColumn) = (Base.OneTo(length(c)),)
Base.vec(c::DimArrayColumn{T}) where T = [c[i] for i in eachindex(c)]
Base.Array(c::DimArrayColumn) = vec(c)

"""
AbstractDimTable <: Tables.AbstractColumns

Abstract supertype for dim tables
"""
abstract type AbstractDimTable <: Tables.AbstractColumns end
# DimTable


"""
DimTable <: AbstractDimTable
Expand Down Expand Up @@ -163,8 +197,6 @@ for func in (:dims, :val, :index, :lookup, :metadata, :order, :sampling, :span,

end

# Tables interface

Tables.istable(::DimTable) = true
Tables.columnaccess(::Type{<:DimTable}) = true
Tables.columns(t::DimTable) = t
Expand Down Expand Up @@ -207,16 +239,173 @@ function _colnames(s::AbstractDimStack)
end


# WideDimTable


"""
WideDimTable <: AbstractDimTable
WideDimTable(A::AbstractDimArray)
Construct a Tables.jl/TableTraits.jl compatible object out of an `AbstractDimArray`.
This table will have a column for the array data and columns for each
`Dimension` index, as a [`DimColumn`]. These are lazy, and generated
as required.
Column names are converted from the dimension types using
[`DimensionalData.dim2key`](@ref). This means type `Ti` becomes the
column name `:Ti`, and `Dim{:custom}` becomes `:custom`.
To get dimension columns, you can index with `Dimension` (`X()`) or
`Dimension` type (`X`) as well as the regular `Int` or `Symbol`.
"""
struct WideDimTable{DS} <: AbstractDimTable
colnames::Vector{Symbol}
dimcolumns::DS
dimarraycolumns::Vector{DimArrayColumn}
end

function WideDimTable(s::AbstractDimStack; mergedims=false)
dims_ = dims(s)
dimcolumns = collect(map(d -> DimColumn(d, dims_), dims_))
dimarraycolumns = collect(map(A -> DimArrayColumn(A, dims_), s))

if mergedims
dimcol = MergedDimColumn(Tuple(dimcolumns), :geometry)
keys = vcat([:geometry], collect(_colnames(s))[length(dims_)+1:end])
return WideDimTable(keys, [dimcol], dimarraycolumns)
else
keys = collect(_colnames(s))
return WideDimTable(keys, dimcolumns, dimarraycolumns)
end
end

function WideDimTable(xs::Vararg{AbstractDimArray}; layernames=[Symbol("layer_$i") for i in eachindex(xs)], mergedims=false)
# Construct DimColumns
dims_ = dims(first(xs))
dimcolumns = map(d -> DimColumn(d, dims_), dims_)
dimnames = collect(map(dim2key, dims_))

# Construct DimArrayColumns
dimarraycolumns = collect(map(A -> DimArrayColumn(A, dims_), xs))

# Merge DimColumns
if mergedims
colnames = vcat([:geometry], layernames)
dimcol = MergedDimColumn(Tuple(dimcolumns), :geometry)
return WideDimTable{typeof(dimcol)}(colnames, dimcol, dimarraycolumns)
else
colnames = vcat(dimnames, layernames)
return WideDimTable{typeof(dimcolumns)}(colnames, dimcolumns, dimarraycolumns)
end
end

function WideDimTable(x::AbstractDimArray; layersfrom=nothing, mergedims=false)
if (layersfrom <: Dimension) && (any(isa.(dims(x), layersfrom)))
nlayers = size(x, layersfrom)
layers = [(@view x[layersfrom(i)]) for i in 1:nlayers]
layernames = Symbol.(["$(dim2key(layersfrom))_$i" for i in 1:nlayers])
return WideDimTable(layers..., layernames=layernames, mergedims=mergedims)
else
# Construct DimColumns
dims_ = dims(x)
dimcolumns = map(d -> DimColumn(d, dims_), dims_)
dimnames = collect(map(dim2key, dims_))

# Construct DimArrayColumn
dimarraycolumn = DimArrayColumn(x, dims_)

# Merge DimColumns
if mergedims
colnames = vcat([:geometry], [:value])
dimcol = MergedDimColumn(Tuple(dimcolumns), :geometry)
return WideDimTable{typeof(dimcol)}(colnames, dimcol, [dimarraycolumn])
else
return WideDimTable{typeof(dimcolumns)}(vcat(dimnames, [:value]), dimcolumns, [dimarraycolumn])
end
end
end

dimcolumns(t::WideDimTable) = getfield(t, :dimcolumns)
dimarraycolumns(t::WideDimTable) = getfield(t, :dimarraycolumns)
dims(t::WideDimTable) = dims(parent(t))

Base.parent(t::WideDimTable) = getfield(t, :colnames)

for func in (:dims, :val, :index, :lookup, :metadata, :order, :sampling, :span, :bounds,
:locus, :name, :label, :units)
@eval $func(t::WideDimTable, args...) = $func(parent(t), args...)

end

Tables.istable(::WideDimTable) = true
Tables.columnaccess(::Type{<:WideDimTable}) = true
Tables.columns(t::WideDimTable) = t
Tables.columnnames(c::WideDimTable) = parent(c)

function Tables.schema(t::WideDimTable)
colnames = parent(t)
types = vcat([map(eltype, dimcolumns(t))...], [map(eltype, dimarraycolumns(t))...])
Tables.Schema(colnames, types)
end

function Tables.schema(t::WideDimTable{<:MergedDimColumn})
colnames = parent(t)
types = vcat([eltype(dimcolumns(t))], [map(eltype, dimarraycolumns(t))...])
Tables.Schema(colnames, types)
end

@inline function Tables.getcolumn(t::WideDimTable, key::Symbol)
keys = parent(t)
i = findfirst(==(key), keys)
n_dimcols = length(dimcolumns(t))
if i <= n_dimcols
return dimcolumns(t)[i]
else
return dimarraycolumns(t)[i - n_dimcols]
end
end

@inline function Tables.getcolumn(t::WideDimTable{<:MergedDimColumn}, key::Symbol)
keys = parent(t)
i = findfirst(==(key), keys)
if i == 1
return dimcolumns(t)
else
return dimarraycolumns(t)[i - 1]
end
n_dimcols = length(dimcolumns(t))
i = findfirst(==(key), keys)
if i <= n_dimcols
return dimcolumns(t)[i]
else
return dimarraycolumns(t)[i - n_dimcols]
end
end

@inline function Tables.getcolumn(t::WideDimTable, ::Type{T}, i::Int, key::Symbol) where T
Tables.getcolumn(t, key)
end


# TableTraits.jl interface


function IteratorInterfaceExtensions.getiterator(x::DimTableSources)
return Tables.datavaluerows(Tables.columntable(x))
return Tables.datavaluerows(Tables.dictcolumntable(x))
end
IteratorInterfaceExtensions.isiterable(::DimTableSources) = true
TableTraits.isiterabletable(::DimTableSources) = true

function IteratorInterfaceExtensions.getiterator(t::DimTable)
return Tables.datavaluerows(Tables.columntable(t))
return Tables.datavaluerows(Tables.dictcolumntable(t))
end
IteratorInterfaceExtensions.isiterable(::DimTable) = true
TableTraits.isiterabletable(::DimTable) = true
function IteratorInterfaceExtensions.getiterator(t::WideDimTable)
return Tables.datavaluerows(Tables.dictcolumntable(t))
end
IteratorInterfaceExtensions.isiterable(::WideDimTable) = true
TableTraits.isiterabletable(::WideDimTable) = true

0 comments on commit af9381a

Please sign in to comment.