Skip to content

Commit

Permalink
make DataBatches generic
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander-Barth committed May 21, 2024
1 parent 8242715 commit d1b1fcb
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
21 changes: 16 additions & 5 deletions src/data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ end


#mutable struct NCData{T,N} <: AbstractVector{Tuple{Array{T,N},Array{T,N}}}
mutable struct NCData{T,N #=,TA=#}
mutable struct NCData{T,N #=,TA=#} <: AbstractDataset{T,N}
lon::Vector{T}
lat::Vector{T}
time::Vector{DateTime}
Expand Down Expand Up @@ -590,7 +590,7 @@ function getxy!(dd::NCData{T,4},ind::Integer,xin::AbstractArray{T2,4},xtrue::Abs
end

nobs(dd::NCData) = length(dd.time)
function getobs(dd::NCData{T},index::Int) where T
function getobs(dd::AbstractDataset{T},index::Int) where T
data = (zeros(T,sizex(dd)),zeros(T,sizey(dd)))
return getobs!(dd,data,index)
end
Expand Down Expand Up @@ -769,14 +769,25 @@ struct DataBatches{Atype,T,N,Tdata,Tbatch}
batch::Tbatch
end

function Base.length(d::DataBatches)
b =
if d.batch isa Tuple
d.batch[1]
else
d.batch
end

return ceil(Int,nobs(d.data)/(size(b)[end]))
end

function Random.shuffle!(d::DataBatches)
randperm!(d.perm)
return d
end

function DataBatches(Atype,data::NCData{T,N},batchsize) where {T,N}
function DataBatches(Atype,data::AbstractDataset{T,N},batchsize; shuffle=data.train) where {T,N}
perm =
if data.train
if shuffle
randperm(nobs(data))
else
1:(nobs(data))
Expand All @@ -802,7 +813,7 @@ function DataBatches(Atype,data::NCData{T,N},batchsize) where {T,N}
end

function Base.iterate(d::DataBatches{Atype,T,N},index = 0) where {Atype,T,N}
bs = index+1 : min(index + d.batchsize,length(d.data))
bs = index+1 : min(index + d.batchsize,nobs(d.data))

if length(bs) == 0
return nothing
Expand Down
3 changes: 2 additions & 1 deletion src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ struct ModelVector2_1
directionobs
end


abstract type AbstractDataset{T,N}
end

0 comments on commit d1b1fcb

Please sign in to comment.