Implement PythonCall extension to convert xarray objects #882

Open
wants to merge 2 commits into
base: main
11 changes: 8 additions & 3 deletions .github/workflows/ci.yml
@@ -8,6 +8,11 @@ on:
defaults:
run:
shell: bash

permissions:
actions: write
contents: read

jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }}
@@ -32,12 +37,12 @@ jobs:
arch: x64
allow_failure: true
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: julia-actions/julia-buildpkg@latest
- uses: julia-actions/cache@v2
- uses: julia-actions/julia-runtest@latest
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v1
3 changes: 3 additions & 0 deletions CondaPkg.toml
@@ -0,0 +1,3 @@
[deps]
xarray = ""
numpy = ""
6 changes: 5 additions & 1 deletion Project.toml
@@ -28,13 +28,15 @@ AlgebraOfGraphics = "cbdf2221-f076-402e-a563-3d30da359d67"
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
DiskArrays = "3c3547ce-8d99-4f5e-a174-61eb10b00ae3"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[extensions]
DimensionalDataAlgebraOfGraphicsExt = "AlgebraOfGraphics"
DimensionalDataCategoricalArraysExt = "CategoricalArrays"
DimensionalDataDiskArraysExt = "DiskArrays"
DimensionalDataMakie = "Makie"
DimensionalDataPythonCall = "PythonCall"
DimensionalDataStatsBase = "StatsBase"

[compat]
@@ -69,6 +71,7 @@ Makie = "0.20, 0.21"
OffsetArrays = "1"
Plots = "1"
PrecompileTools = "1"
PythonCall = "0.9"
Random = "1"
RecipesBase = "0.7, 0.8, 1"
SafeTestsets = "0.1"
@@ -103,6 +106,7 @@ JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
@@ -111,4 +115,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[targets]
test = ["AlgebraOfGraphics", "Aqua", "ArrayInterface", "BenchmarkTools", "CategoricalArrays", "ColorTypes", "Combinatorics", "CoordinateTransformations", "DataFrames", "DiskArrays", "Distributions", "Documenter", "GPUArrays", "ImageFiltering", "ImageTransformations", "JLArrays", "CairoMakie", "OffsetArrays", "Plots", "Random", "SafeTestsets", "StatsBase", "StatsPlots", "Test", "Unitful"]
test = ["AlgebraOfGraphics", "Aqua", "ArrayInterface", "BenchmarkTools", "CategoricalArrays", "ColorTypes", "Combinatorics", "CoordinateTransformations", "DataFrames", "DiskArrays", "Distributions", "Documenter", "GPUArrays", "ImageFiltering", "ImageTransformations", "JLArrays", "CairoMakie", "OffsetArrays", "Plots", "PythonCall", "Random", "SafeTestsets", "StatsBase", "StatsPlots", "Test", "Unitful"]
2 changes: 2 additions & 0 deletions docs/src/.vitepress/config.mts
@@ -28,6 +28,7 @@ const navTemp = {
{ text: 'Tables and DataFrames', link: '/tables' },
{ text: 'CUDA and GPUs', link: '/cuda' },
{ text: 'DiskArrays', link: '/diskarrays' },
{ text: 'Xarray', link: '/xarray' },
{ text: 'Extending DimensionalData', link: '/extending_dd' },
],
},
@@ -96,6 +97,7 @@ export default defineConfig({
{ text: 'Tables and DataFrames', link: '/tables' },
{ text: 'CUDA and GPUs', link: '/cuda' },
{ text: 'DiskArrays', link: '/diskarrays' },
{ text: 'Xarray', link: '/xarray' },
{ text: 'Extending DimensionalData', link: '/extending_dd' },
],
},
27 changes: 27 additions & 0 deletions docs/src/xarray.md
@@ -0,0 +1,27 @@
# Xarray and PythonCall.jl

In the Python ecosystem [Xarray](https://xarray.dev) is by far the most popular
package for working with multidimensional labelled arrays. The main data
structures it provides are:
- [DataArray](https://docs.xarray.dev/en/stable/user-guide/data-structures.html#dataarray),
  analogous to `DimArray`.
- [Dataset](https://docs.xarray.dev/en/stable/user-guide/data-structures.html#dataset),
  analogous to `DimStack`.

DimensionalData integrates with
[PythonCall.jl](https://juliapy.github.io/PythonCall.jl/stable/) to allow
converting these Xarray types to their DimensionalData equivalents:
```julia
import PythonCall: pyconvert

my_dimarray = pyconvert(DimArray, my_dataarray)

my_dimstack = pyconvert(DimStack, my_dataset)
```

Note that:
- The current implementation will make a copy of the underlying arrays.
- Python stores arrays in row-major order whereas Julia stores them in
column-major order, hence the dimensions on a converted `DimArray` will be in
reverse order from the original `DataArray`. This is done to ensure that the
'fast axis' to iterate over is the same dimension in both Julia and Python.
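
The dimension reversal described above can be illustrated without Python at all. The following is a minimal sketch in plain Julia, not the extension's actual code path (which copies the data): `rowmajor_buffer` and `py_shape` are hypothetical stand-ins for the memory and shape of a small NumPy array.

```julia
# Hypothetical stand-in for the memory of the 2×3 NumPy array
# [[1, 2, 3], [4, 5, 6]], which NumPy stores row by row (row-major).
rowmajor_buffer = [1, 2, 3, 4, 5, 6]
py_shape = (2, 3)  # e.g. dims ("time", "length") on the Python side

# Julia fills arrays column by column, so viewing the same buffer with the
# reversed shape keeps every element where it already is in memory.
jl_array = reshape(rowmajor_buffer, reverse(py_shape))

size(jl_array)   # (3, 2), i.e. dims (:length, :time) on the Julia side
jl_array[:, 1]   # [1, 2, 3]: the first Python row is the first Julia column

# Python's element [i, j] (0-based) is Julia's element [j + 1, i + 1] (1-based).
jl_array[1, 2]   # 4, which is element [1, 0] of the NumPy array
```

Iterating over the first index of `jl_array` walks contiguous memory, just as iterating over the last index does in NumPy, which is why the converted `DimArray` lists its dimensions in reverse.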
99 changes: 99 additions & 0 deletions ext/DimensionalDataPythonCall.jl
@@ -0,0 +1,99 @@
module DimensionalDataPythonCall

using DimensionalData
import PythonCall
import PythonCall: Py, pyis, pyconvert, pytype, pybuiltins
import DimensionalData.Lookups: NoLookup

function dtype2type(dtype::String)
    if dtype == "float16"
        Float16
    elseif dtype == "float32"
        Float32
    elseif dtype == "float64"
        Float64
    elseif dtype == "int8"
        Int8
    elseif dtype == "int16"
        Int16
    elseif dtype == "int32"
        Int32
    elseif dtype == "int64"
        Int64
    elseif dtype == "uint8"
        UInt8
    elseif dtype == "uint16"
        UInt16
    elseif dtype == "uint32"
        UInt32
    elseif dtype == "uint64"
        UInt64
    elseif dtype == "bool"
        Bool
    else
        error("Unsupported dtype: '$dtype'")
    end
end

function PythonCall.pyconvert(::Type{DimArray}, x::Py, d=nothing)
    x_pytype = string(pytype(x).__name__)
    if x_pytype != "DataArray"
        if isnothing(d)
            throw(ArgumentError("Cannot convert $(pytype(x)) to a DimArray, it must be an xarray.DataArray"))
        else
            return d
        end
    end

    # Transpose here so that the fast axis remains the same in the Julia array
    data_npy = x.data.T
    data_type = dtype2type(string(data_npy.dtype.name))
    data_ndim = pyconvert(Int, data_npy.ndim)
    data = pyconvert(Array{data_type, data_ndim}, data_npy)

    dim_names = Symbol.(collect(x.dims))
    coord_names = Symbol.(collect(x.coords.keys()))
    lookups_dict = Dict{Symbol, Any}()
    for dim in dim_names
        if dim in coord_names
            coord = getproperty(x, dim).data
            coord_type = dtype2type(string(coord.dtype.name))
            coord_ndim = pyconvert(Int, coord.ndim)

            lookups_dict[dim] = pyconvert(Array{coord_type, coord_ndim}, coord)
        else
            lookups_dict[dim] = NoLookup()
        end
    end

    # Build the lookups in the axis order of the transposed data. A Dict has no
    # guaranteed iteration order, so index it explicitly by dimension name.
    ordered_dims = Tuple(reverse(dim_names))
    lookups = NamedTuple{ordered_dims}(map(dim -> lookups_dict[dim], ordered_dims))

    metadata = pyconvert(Dict, x.attrs)

    array_name = pyis(x.name, pybuiltins.None) ? nothing : string(x.name)

    return DimArray(data, lookups; name=array_name, metadata)
end

function PythonCall.pyconvert(::Type{DimStack}, x::Py, d=nothing)
    x_pytype = string(pytype(x).__name__)
    if x_pytype != "Dataset"
        if isnothing(d)
            throw(ArgumentError("Cannot convert $(pytype(x)) to a DimStack, it must be an xarray.Dataset"))
        else
            return d
        end
    end

    variable_names = Symbol.(collect(x.data_vars.keys()))
    arrays = Dict{Symbol, DimArray}()
    for name in variable_names
        arrays[name] = pyconvert(DimArray, getproperty(x, name))
    end

    metadata = pyconvert(Dict, x.attrs)

    return DimStack(NamedTuple(arrays); metadata)
end

end
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -34,6 +34,7 @@ using DimensionalData, Test, Aqua, SafeTestsets
@time @safetestset "adapt" begin include("adapt.jl") end
@time @safetestset "ecosystem" begin include("ecosystem.jl") end
@time @safetestset "categorical" begin include("categorical.jl") end
@time @safetestset "xarray" begin include("xarray.jl") end
if Sys.islinux()
# Unfortunately this can hang on other platforms.
# Maybe ram use of all the plots on the small CI machine? idk
54 changes: 54 additions & 0 deletions test/xarray.jl
@@ -0,0 +1,54 @@
ENV["JULIA_CONDAPKG_ENV"] = "@dimensionaldata-tests"

using DimensionalData, Test, PythonCall
import DimensionalData.Dimensions: NoLookup, NoMetadata


xr = pyimport("xarray")
np = pyimport("numpy")

data = rand(10, 5)
times = sort(rand(10))
x = xr.DataArray(data,
                 dims=("time", "length"),
                 coords=Dict("time" => times),
                 name="data",
                 attrs=Dict("motor" => "hexapod",
                            "pos" => 0.48,
                            "foo" => np.array([1, 2, 3])))

data2 = rand(10, 2)
x2 = xr.DataArray(data2,
                  dims=("time", "mass"),
                  coords=Dict("time" => times),
                  name="data2",
                  attrs=Dict("motor" => "delay",
                             "pos" => 0.48))

@testset "DataArray to DimArray" begin
    y = pyconvert(DimArray, x)
    @test name(y) == "data"
    @test name.(dims(y)) == (:length, :time)
    @test lookup(y, :time) == times
    @test_broken lookup(y, :length) == NoLookup()
    @test metadata(y) == Dict("motor" => "hexapod",
                              "pos" => 0.48,
                              "foo" => [1, 2, 3])

    @test_throws ArgumentError pyconvert(DimArray, xr)
    @test pyconvert(DimArray, xr, 42) == 42
end

@testset "Dataset to DimStack" begin
    dataset = xr.Dataset(Dict("x" => x, "x2" => x2),
                         attrs=Dict("source" => "interwebs"))
    z = pyconvert(DimStack, dataset)

    @test name(z) == (:x2, :x)
    @test name.(dims(z)) == (:mass, :time, :length)
    @test lookup(z, :time) == times
    @test metadata(z) == Dict("source" => "interwebs")

    @test_throws ArgumentError pyconvert(DimStack, x)
    @test pyconvert(DimStack, x, 42) == 42
end