Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize ArbitraryMotion #396

Closed
wants to merge 9 commits into from
2 changes: 1 addition & 1 deletion KomaMRIBase/src/KomaMRIBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using Reexport
#Datatypes
using Parameters
#Simulation
using Interpolations
@reexport using Interpolations
#Reconstruction
using MRIBase
@reexport using MRIBase:
Expand Down
6 changes: 5 additions & 1 deletion KomaMRIBase/src/datatypes/Phantom.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,11 @@ Base.:(≈)(obj1::Phantom, obj2::Phantom) = reduce(&, [getfield(obj1, field
Base.:(==)(m1::MotionModel, m2::MotionModel) = false
Base.:(≈)(m1::MotionModel, m2::MotionModel) = false

"""Separate object spins in a sub-group"""
"""
obj = obj[p]

Separate object spins in a sub-group
"""
Base.getindex(obj::Phantom, p::Union{AbstractRange,AbstractVector,Colon}) = begin
fields = []
for field in Iterators.filter(x -> !(x == :name), fieldnames(Phantom))
Expand Down
118 changes: 69 additions & 49 deletions KomaMRIBase/src/datatypes/phantom/motion/ArbitraryMotion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,26 @@
# Interpolator{T,Degree,ETPType},
# Degree = Linear,Cubic....
# ETPType = Periodic, Flat...
const LinearInterpolator = Interpolations.Extrapolation{

const Interpolator = Interpolations.Extrapolation{
T,
1,
Interpolations.GriddedInterpolation{T,1,V,Gridded{Linear{Throw{OnGrid}}},Tuple{V}},
Gridded{Linear{Throw{OnGrid}}},
Periodic{Nothing},
} where {T<:Real,V<:AbstractVector{T}}
N,
Interpolations.GriddedInterpolation{
T,
N,
V,
Itp,
K
},
Itp,
Interpolations.Periodic{Nothing}
} where {
T<:Real,
N,
V<:AbstractArray{T},
K<:Tuple{Vararg{AbstractVector{T}}},
Itp<:Tuple{Vararg{Union{Interpolations.Gridded{Linear{Throw{OnGrid}}}, Interpolations.NoInterp}}}
}

"""
motion = ArbitraryMotion(period_durations, dx, dy, dz)
Expand Down Expand Up @@ -43,39 +56,11 @@
)
```
"""
struct ArbitraryMotion{T<:Real,V<:AbstractVector{T}} <: MotionModel{T}
struct ArbitraryMotion{T} <: MotionModel{T}
period_durations::Vector{T}
dx::Array{T,2}
dy::Array{T,2}
dz::Array{T,2}
ux::Vector{LinearInterpolator{T,V}}
uy::Vector{LinearInterpolator{T,V}}
uz::Vector{LinearInterpolator{T,V}}
end

function ArbitraryMotion(
period_durations::AbstractVector{T},
dx::AbstractArray{T,2},
dy::AbstractArray{T,2},
dz::AbstractArray{T,2},
) where {T<:Real}
@warn "Note that ArbitraryMotion is under development so it is not optimized so far" maxlog = 1
Ns = size(dx)[1]
num_pieces = size(dx)[2] + 1
limits = times(period_durations, num_pieces)

#! format: off
Δ = zeros(Ns,length(limits),4)
Δ[:,:,1] = hcat(repeat(hcat(zeros(Ns,1),dx),1,length(period_durations)),zeros(Ns,1))
Δ[:,:,2] = hcat(repeat(hcat(zeros(Ns,1),dy),1,length(period_durations)),zeros(Ns,1))
Δ[:,:,3] = hcat(repeat(hcat(zeros(Ns,1),dz),1,length(period_durations)),zeros(Ns,1))

etpx = [extrapolate(interpolate((limits,), Δ[i,:,1], Gridded(Linear())), Periodic()) for i in 1:Ns]
etpy = [extrapolate(interpolate((limits,), Δ[i,:,2], Gridded(Linear())), Periodic()) for i in 1:Ns]
etpz = [extrapolate(interpolate((limits,), Δ[i,:,3], Gridded(Linear())), Periodic()) for i in 1:Ns]
#! format: on

return ArbitraryMotion(period_durations, dx, dy, dz, etpx, etpy, etpz)
dx::Matrix{T}
dy::Matrix{T}
dz::Matrix{T}
end

function Base.getindex(
Expand All @@ -85,8 +70,6 @@
for field in fieldnames(ArbitraryMotion)
if field in (:dx, :dy, :dz)
push!(fields, getfield(motion, field)[p, :])
elseif field in (:ux, :uy, :uz)
push!(fields, getfield(motion, field)[p])
else
push!(fields, getfield(motion, field))
end
Expand Down Expand Up @@ -133,16 +116,53 @@
return limits
end

# TODO: Calculate interpolation functions "on the fly"

function get_itp_functions(motion::ArbitraryMotion{T}, Ns::Int) where {T<:Real}
dx = hcat(repeat(hcat(zeros(T ,Ns, 1), motion.dx), 1, length(motion.period_durations)), zeros(T ,Ns, 1))
dy = hcat(repeat(hcat(zeros(T ,Ns, 1), motion.dy), 1, length(motion.period_durations)), zeros(T ,Ns, 1))
dz = hcat(repeat(hcat(zeros(T ,Ns, 1), motion.dz), 1, length(motion.period_durations)), zeros(T ,Ns, 1))
if Ns > 1
nodes = ([i*one(T) for i=1:Ns], times(motion))
itpx = extrapolate(interpolate(nodes, dx, (NoInterp(), Gridded(Linear()))), Periodic())
itpy = extrapolate(interpolate(nodes, dy, (NoInterp(), Gridded(Linear()))), Periodic())
itpz = extrapolate(interpolate(nodes, dz, (NoInterp(), Gridded(Linear()))), Periodic())

Check warning on line 128 in KomaMRIBase/src/datatypes/phantom/motion/ArbitraryMotion.jl

View check run for this annotation

Codecov / codecov/patch

KomaMRIBase/src/datatypes/phantom/motion/ArbitraryMotion.jl#L125-L128

Added lines #L125 - L128 were not covered by tests
else
nodes = (times(motion), )
itpx = extrapolate(interpolate(nodes, dx[:], (Gridded(Linear()), )), Periodic())
itpy = extrapolate(interpolate(nodes, dy[:], (Gridded(Linear()), )), Periodic())
itpz = extrapolate(interpolate(nodes, dz[:], (Gridded(Linear()), )), Periodic())
end
return itpx, itpy, itpz
end

function get_itp_results(
itpx::Interpolator{T},
itpy::Interpolator{T},
itpz::Interpolator{T},
t::AbstractArray{T},
Ns::Int
) where {T<:Real}
if Ns > 1
id = similar(t, Ns)
id .= (1:Ns)

Check warning on line 147 in KomaMRIBase/src/datatypes/phantom/motion/ArbitraryMotion.jl

View check run for this annotation

Codecov / codecov/patch

KomaMRIBase/src/datatypes/phantom/motion/ArbitraryMotion.jl#L146-L147

Added lines #L146 - L147 were not covered by tests
# Grid
idx = 1*id .+ 0*t # spin id
t = 0*id .+ 1*t # time instants
return itpx.(idx, t), itpy.(idx, t), itpz.(idx, t)

Check warning on line 151 in KomaMRIBase/src/datatypes/phantom/motion/ArbitraryMotion.jl

View check run for this annotation

Codecov / codecov/patch

KomaMRIBase/src/datatypes/phantom/motion/ArbitraryMotion.jl#L149-L151

Added lines #L149 - L151 were not covered by tests
else
return itpx.(t), itpy.(t), itpz.(t)
end
end

function get_spin_coords(
motion::ArbitraryMotion{T},
x::AbstractVector{T},
y::AbstractVector{T},
z::AbstractVector{T},
x::Vector{T},
y::Vector{T},
z::Vector{T},
t::AbstractArray{T},
) where {T<:Real}
xt = x .+ reduce(vcat, [etp.(t) for etp in motion.ux])
yt = y .+ reduce(vcat, [etp.(t) for etp in motion.uy])
zt = z .+ reduce(vcat, [etp.(t) for etp in motion.uz])
return xt, yt, zt
end
Ns = size(motion.dx)[1]
itp = get_itp_functions(motion, Ns)
ux, uy, uz = get_itp_results(itp..., t, Ns)
return x .+ ux, y .+ uy, z .+ uz
end
4 changes: 4 additions & 0 deletions KomaMRICore/src/KomaMRICore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@ using CUDA

# KomaMRIBase
@reexport using KomaMRIBase
@reexport import KomaMRIBase.get_spin_coords # This should not be necessary, but it is

# Rawdata
include("rawdata/ISMRMRD.jl")
# Datatypes
include("datatypes/Spinor.jl")
include("other/DiffusionModel.jl")
# Simulator
include("simulation/GPUArbitraryMotion.jl")
include("simulation/GPUFunctions.jl")
include("simulation/SimulatorCore.jl")

Expand All @@ -26,6 +28,8 @@ export signal_to_raw_data
# Simulator
export Mag
export simulate, simulate_slice_profile
# Spin coordinates
export get_spin_coords
# Spinors
export Spinor, Rx, Ry, Rz, Q, Un

Expand Down
17 changes: 17 additions & 0 deletions KomaMRICore/src/simulation/GPUArbitraryMotion.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
function get_spin_coords(

Check warning on line 1 in KomaMRICore/src/simulation/GPUArbitraryMotion.jl

View check run for this annotation

Codecov / codecov/patch

KomaMRICore/src/simulation/GPUArbitraryMotion.jl#L1

Added line #L1 was not covered by tests
motion::ArbitraryMotion{T},
x::AbstractVector{T},
y::AbstractVector{T},
z::AbstractVector{T},
t::AbstractArray{T},
) where {T<:Real}
Ns = size(motion.dx)[1]
itpx, itpy, itpz = KomaMRIBase.get_itp_functions(motion, Ns)

Check warning on line 9 in KomaMRICore/src/simulation/GPUArbitraryMotion.jl

View check run for this annotation

Codecov / codecov/patch

KomaMRICore/src/simulation/GPUArbitraryMotion.jl#L8-L9

Added lines #L8 - L9 were not covered by tests
# To GPU
itpx = adapt(CuArray{T}, itpx)
itpy = adapt(CuArray{T}, itpy)
itpz = adapt(CuArray{T}, itpz) # Problem: too many CPU -> GPU transfers
ux, uy, uz = KomaMRIBase.get_itp_results(itpx, itpy, itpz, t, Ns)
return x .+ ux, y .+ uy, z .+ uz

Check warning on line 15 in KomaMRICore/src/simulation/GPUArbitraryMotion.jl

View check run for this annotation

Codecov / codecov/patch

KomaMRICore/src/simulation/GPUArbitraryMotion.jl#L11-L15

Added lines #L11 - L15 were not covered by tests
end

31 changes: 5 additions & 26 deletions KomaMRICore/src/simulation/GPUFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,24 +45,8 @@
# GPU adaptor
struct KomaCUDAAdaptor end
adapt_storage(to::KomaCUDAAdaptor, x) = CUDA.cu(x)
adapt_storage(to::KomaCUDAAdaptor, x::NoMotion) = NoMotion{Float32}()
adapt_storage(to::KomaCUDAAdaptor, x::SimpleMotion) = f32(x)
function adapt_storage(to::KomaCUDAAdaptor, x::ArbitraryMotion)
fields = []
for field in fieldnames(ArbitraryMotion)
if field in (:ux, :uy, :uz)
push!(fields, adapt(KomaCUDAAdaptor(), getfield(x, field)))
else
push!(fields, f32(getfield(x, field)))
end
end
return ArbitraryMotion(fields...)
end
function adapt_storage(
to::KomaCUDAAdaptor, x::Vector{LinearInterpolator{T,V}}
) where {T<:Real,V<:AbstractVector{T}}
return CUDA.cu.(x)
end
adapt_storage(to::KomaCUDAAdaptor, x::MotionModel) = f32(x) # Motion models are not passed to GPU

Check warning on line 48 in KomaMRICore/src/simulation/GPUFunctions.jl

View check run for this annotation

Codecov / codecov/patch

KomaMRICore/src/simulation/GPUFunctions.jl#L48

Added line #L48 was not covered by tests


"""
gpu(x)
Expand Down Expand Up @@ -113,15 +97,10 @@
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs)
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Complex}) = convert.(Complex{T}, xs)
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Bool}) = xs
adapt_storage(T::Type{<:Real}, xs::SimpleMotion) = SimpleMotion(paramtype(T, xs.types))

adapt_storage(T::Type{<:Real}, xs::NoMotion) = NoMotion{T}()
function adapt_storage(T::Type{<:Real}, xs::ArbitraryMotion)
fields = []
for field in fieldnames(ArbitraryMotion)
push!(fields, paramtype(T, getfield(xs, field)))
end
return ArbitraryMotion(fields...)
end
adapt_storage(T::Type{<:Real}, xs::SimpleMotion) = SimpleMotion(paramtype(T, xs.types))
adapt_storage(T::Type{<:Real}, xs::ArbitraryMotion) = ArbitraryMotion( (paramtype.(Ref(T), getfield.(Ref(xs), fieldnames(ArbitraryMotion))))... )

"""
f32(m)
Expand Down
4 changes: 2 additions & 2 deletions KomaMRIPlots/src/ui/DisplayFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -792,8 +792,8 @@ function plot_image(
image;
height=600,
width=nothing,
zmin=minimum(abs.(image[:])),
zmax=maximum(abs.(image[:])),
zmin=minimum(image[:]),
zmax=maximum(image[:]),
darkmode=false,
title="",
)
Expand Down
Loading