diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 4156ca0..5a03348 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -24,10 +24,12 @@ jobs: - 'nightly' os: - ubuntu-latest + - macOS-latest + - windows-latest arch: - x64 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.version }} diff --git a/Project.toml b/Project.toml index 9362895..258ce6b 100644 --- a/Project.toml +++ b/Project.toml @@ -4,10 +4,17 @@ authors = ["pevnak and contributors"] version = "1.0.0" [deps] +BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" +DLFP8Types = "f4c16678-4a16-415b-82ef-ed337c5d6c7c" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" +MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" +Mmap = "a63ad114-7e13-5084-954f-fe012c677804" [compat] +BFloat16s = "0.5" +DLFP8Types = "0.1" JSON3 = "1" +MappedArrays = "0.4" julia = "1.6" [extras] diff --git a/README.md b/README.md index bd723f5..3cd1304 100644 --- a/README.md +++ b/README.md @@ -3,18 +3,14 @@ [![Build Status](https://github.com/FluxML/SafeTensors.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/FluxML/SafeTensors.jl/actions/workflows/CI.yml?query=branch%3Amain) -This packages loads data stored in [safetensor format](https://huggingface.co/docs/safetensors/index). +This packages loads data stored in [safetensor format](https://huggingface.co/docs/safetensors/index). Since Python is row-major and Julia is column-major, the dimensions are permuted such the tensor has the same shape as in python, but everything is correctly ordered. This includes a performance penalty in sense that we cannot be completely copy-free. -The list of dependencies is kept minimal to `JSON3` for parsing the header. - -The package does not allow to save the data. - The main function is `load_safetensors` which returns a `Dict{String,V}` where keys are names of tensors and values are tensors. An example from `runtests` is as follows ```julia julia> using SafeTensors -julia> d = load_safetensors("model.safetensors") +julia> d = load_safetensors("test/model.safetensors") Dict{String, Array} with 27 entries: "int32_357" => Int32[0 7 … 21 28; 35 42 … 56 63; 70 77 … 91 98;;; 1 8 … 22 29… "uint8_3" => UInt8[0x00, 0x01, 0x02] @@ -45,9 +41,73 @@ Dict{String, Array} with 27 entries: "float64_3" => [0.0, 1.0, 2.0] ``` -It is also possible to load just header using unexported `load_header` as +It can also perform a lazy loading with `SafeTensors.deserialize("model.safetensors")` which `mmap` the file and return a `Dict`-like object: ```julia -julia> d = SafeTensors.load_header("model.safetensors") +julia> tensors = SafeTensors.deserialize("test/model.safetensors"; mmap = true #= default to `true`=#); + +julia> tensors["float32_35"] +3×5 mappedarray(ltoh, PermutedDimsArray(reshape(reinterpret(Float32, view(::Vector{UInt8}, 0x0000000000000ef5:0x0000000000000f30)), 5, 3), (2, 1))) with eltype Float32: + 0.0 1.0 2.0 3.0 4.0 + 5.0 6.0 7.0 8.0 9.0 + 10.0 11.0 12.0 13.0 14.0 +``` + +Serialization is also supported: + +```julia +julia> using Random, BFloat16s + +julia> weights = Dict("W"=>randn(BFloat16, 3, 5), "b"=>rand(BFloat16, 3)) +Dict{String, Array{BFloat16}} with 2 entries: + "W" => [0.617188 0.695312 … 0.390625 -2.0; -0.65625 -0.617188 … 0.652344 0.244141; 0.226562 2.70312 … -0.174805 -0.7773… + "b" => [0.111816, 0.566406, 0.283203] + +julia> f = tempname(); + +julia> SafeTensors.serialize(f, weights) + +julia> loaded = SafeTensors.deserialize(f); + +julia> loaded["W"] ≈ weights["W"] +true + +julia> SafeTensors.serialize(f, weights, Dict("Package"=>"SafeTensors.jl", "version"=>"1")) + +julia> loaded = SafeTensors.deserialize(f); + +julia> loaded.metadata +Dict{String, String} with 2 entries: + "Package" => "SafeTensors.jl" + "version" => "1" ``` +Working with gpu: +```julia +julia> loaded["W"] +3×5 mappedarray(ltoh, PermutedDimsArray(reshape(reinterpret(BFloat16, view(::Vector{UInt8}, 0x00000000000000b9:0x00000000000000d6)), 5, 3), (2, 1))) with eltype BFloat16: + 0.542969 0.201172 1.38281 -0.255859 -1.55469 + 0.172852 -0.949219 0.0561523 -1.34375 -0.206055 + -0.0854492 1.17969 -0.265625 -0.871094 2.25 + +julia> using CUDA; CUDA.allowscalar(false) +julia> CuArray(loaded["W"]) +3×5 CuArray{BFloat16, 2, CUDA.Mem.DeviceBuffer}: + 0.542969 0.201172 1.38281 -0.255859 -1.55469 + 0.172852 -0.949219 0.0561523 -1.34375 -0.206055 + -0.0854492 1.17969 -0.265625 -0.871094 2.25 + +julia> gpu_weights = Dict("W"=>CuArray(loaded["W"]), "b"=>CuArray(loaded["b"])) +Dict{String, CuArray{BFloat16, N, CUDA.Mem.DeviceBuffer} where N} with 2 entries: + "W" => [0.542969 0.201172 … -0.255859 -1.55469; 0.172852 -0.949219 … -1.34375 -0.206055; -0.0854492 1.17969 … -0.871094… + "b" => BFloat16[0.871094, 0.773438, 0.703125] + +julia> f = tempname(); + +julia> SafeTensors.serialize(f, gpu_weights) + +julia> SafeTensors.deserialize(f) +SafeTensors.SafeTensor{SubArray{UInt8, 1, Vector{UInt8}, Tuple{UnitRange{UInt64}}, true}} with 2 entries: + "W" => BFloat16[0.542969 0.201172 … -0.255859 -1.55469; 0.172852 -0.949219 … -1.34375 -0.206055; -0.0854492 1.17969 … -… + "b" => BFloat16[0.871094, 0.773438, 0.703125] +``` diff --git a/src/SafeTensors.jl b/src/SafeTensors.jl index 84bb241..59af07e 100644 --- a/src/SafeTensors.jl +++ b/src/SafeTensors.jl @@ -1,129 +1,355 @@ module SafeTensors +using Base: Checked +using Mmap + +using DLFP8Types +using BFloat16s using JSON3 +using JSON3.StructTypes -""" - _gettype(s) - _gettype(s, name) +using MappedArrays: mappedarray - Julia type of the tensor from the string name -""" -function _gettype(s::AbstractString, name) - s == "F16" && return(Float16) - s == "F32" && return(Float32) - s == "F64" && return(Float64) - s == "B" && return(Bool) - s == "BOOL" && return(Bool) - s == "U8" && return(UInt8) - s == "I8" && return(Int8) - s == "I16" && return(Int16) - s == "I32" && return(Int32) - s == "I64" && return(Int64) - s == "BF16" && error("BFloat16 is not supported") - error("unknown type ",s," of the tensor ", name) -end - -_byteoftype(::Type{T}) where {T<:Union{Bool, UInt8, Int8}} = 1 -_byteoftype(::Type{T}) where {T<:Union{Int16, Float16}} = 2 -_byteoftype(::Type{T}) where {T<:Union{Int32, Float32}} = 4 -_byteoftype(::Type{T}) where {T<:Union{Int64, Float64}} = 8 +Base.@enum Dtype::UInt8 begin + # Boolan type + BOOL + # Unsigned byte + U8 + # Signed byte + I8 + # FP8 _ + F8_E5M2 + # FP8 _ + F8_E4M3 + # Signed integer (16-bit) + I16 + # Unsigned integer (16-bit) + U16 + # Half-precision floating point + F16 + # Brain floating point + BF16 + # Signed integer (32-bit) + I32 + # Unsigned integer (32-bit) + U32 + # Floating point (32-bit) + F32 + # Floating point (64-bit) + F64 + # Signed integer (64-bit) + I64 + # Unsigned integer (64-bit) + U64 +end -""" - readtensor!(fio::IO, header::Dict, name::Symbol, header_length; seek_to_start = true) - readtensor!(fio::IO, T, shape, start, stop, name="", header_length; seek_to_start = true) +const typemap = Dict( + BOOL => Bool, + U8 => UInt8, + I8 => Int8, + F8_E5M2 => Float8_E5M2, + F8_E4M3 => Float8_E4M3FN, + I16 => Int16, + U16 => UInt16, + F16 => Float16, + BF16 => BFloat16, + I32 => Int32, + U32 => UInt32, + F32 => Float32, + F64 => Float64, + I64 => Int64, + U64 => UInt64, +) - reads tensor `name` from the file `fio`. - `seek_to_start = true` means that seek(fio, start) will be called to ensure that reading - starts from correct position -""" -function readtensor!(fio::IO, header::JSON3.Object, name::Symbol, header_length; seek_to_start = true) - entry = header[name] - T = _gettype(entry[:dtype], name) - start = entry[:data_offsets][1] + header_length - stop = entry[:data_offsets][2] + header_length - shape = tuple(entry[:shape]...) - readtensor!(fio, T, shape, start, stop, name; seek_to_start) -end - -function readtensor!(fio::IO, T::Type, shape::NTuple{N,<:Integer}, start::Integer, stop::Integer, name=""; seek_to_start = true) where {N} - seek_to_start && seek(fio, start) - n = stop - start - if _byteoftype(T)*prod(shape) != n - s = isempty(name) ? "" : "of tensor "*name - error("length of the stored data",s," does not corresponds to shape of the tensor") - end - x = Array{T,length(shape)}(undef, reverse(shape)...) - read!(fio, x) - if length(shape) > 1 - x = permutedims(x, length(shape):-1:1) - end - return(x) -end - -function names_without_metadata(header) - filter(s -> s !== Symbol("__metadata__"), collect(keys(header))) +tag2type(tag::Dtype) = typemap[tag] +tag2name(tag::Dtype) = Symbol(tag) + +let nametagmap = Dict(v => Dtype(k) for (k, v) in Base.Enums.namemap(Dtype)), + typetagmap = Dict(reverse(kv) for kv in typemap) + global function name2tag(name) + tag = get(nametagmap, Symbol(name), nothing) + isnothing(tag) && error("Unknown Dtype: $name") + return tag + end + global function type2tag(@nospecialize T) + tag = get(typetagmap, T, nothing) + isnothing(tag) && error("Unsupproted data type: $T") + return tag + end end -""" - starts_of_tensors(header) +StructTypes.StructType(::Type{Dtype}) = StructTypes.CustomStruct() +StructTypes.lower(x::Dtype) = tag2name(x) +StructTypes.lowertype(::Type{Dtype}) = Symbol +StructTypes.construct(::Type{Dtype}, x::Symbol) = name2tag(x) - return a sorted list of pairs (name_of_tensor, start) -""" -function starts_of_tensors(header) - ks = names_without_metadata(header) - starts = map(ks) do k - k => header[k][:data_offsets][1] - end - sort!(starts, lt = (i,j) -> i[2] < j[2]) - return(starts) +struct TensorInfo + dtype::Dtype + shape::Tuple{Vararg{UInt}} + data_offsets::NTuple{2, UInt} # rust zero-based offsets, need +1,+0 when used as index end +StructTypes.StructType(::Type{TensorInfo}) = StructTypes.CustomStruct() +StructTypes.lower(x::TensorInfo) = (; dtype = x.dtype, shape = x.shape, data_offsets = x.data_offsets) +StructTypes.lowertype(::Type{TensorInfo}) = @NamedTuple{dtype::Dtype, shape::Vector{UInt}, data_offsets::NTuple{2, UInt}} +StructTypes.construct(::Type{TensorInfo}, x::NamedTuple) = TensorInfo(x.dtype, Tuple(x.shape), x.data_offsets) -""" - is_continuous(header, starts = starts_of_tensors(header)) +struct HashMetadata <: AbstractDict{String, Union{Dict{String, String}, TensorInfo}} + metadata::Union{Dict{String, String}, Nothing} + tensors::Dict{String, TensorInfo} +end +Base.length(m::HashMetadata) = length(m.tensors) + !isnothing(m.metadata) +Base.iterate(m::HashMetadata) = isnothing(m.metadata) ? iterate(m, nothing) : (("__metadata__" => m.metadata), nothing) +Base.iterate(m::HashMetadata, state) = isnothing(state) ? iterate(m.tensors) : iterate(m.tensors, state) +function StructTypes.construct(::Type{HashMetadata}, x::Dict{String, Union{Dict{String, String}, TensorInfo}}) + metadata = get(x, "__metadata__", nothing); delete!(x, "__metadata__") + tensors = Dict{String, TensorInfo}(x) + return HashMetadata(metadata, tensors) +end - return true if tensors in header are correctly aligned and can be read sequentially (which they should) -""" -function is_continuous(header, starts = starts_of_tensors(header)) - i = 0 - for (k, start) in starts - start != i && return(false) - i = header[k]["data_offsets"][2] - end - return(true) +struct Metadata <: AbstractDict{String, TensorInfo} + metadata::Union{Dict{String, String}, Nothing} + tensors::Vector{TensorInfo} + index_map::Dict{String, UInt} +end +function Metadata( + metadata::Union{AbstractDict{String, String}, Nothing}, + tensors::AbstractVector{Pair{String, TensorInfo}} +) + index_map = Dict{String, UInt}(); sizehint!(index_map, length(tensors)) + tensors = map(enumerate(tensors)) do (index, (k, tensor)) + index_map[k] = index + return tensor + end + return Metadata(metadata, tensors, index_map) +end +Base.length(x::Metadata) = length(x.tensors) +function Base.iterate(x::Metadata, s...) + it = iterate(x.index_map, s...) + isnothing(it) && return nothing + (name, index), state = it + tensor = @inbounds x.tensors[index] + return (name => tensor), state +end +function Base.getindex(x::Metadata, name) + index = x.index_map[name] + return @inbounds x.tensors[index] +end + +StructTypes.StructType(::Type{Metadata}) = StructTypes.CustomStruct() +function StructTypes.lower(x::Metadata) + metadata = x.metadata + tensors = Dict{String, TensorInfo}(); sizehint!(tensors, length(x.tensors)) + @inbounds for (name, index) in x.index_map + tensors[name] = x.tensors[index] + end + return HashMetadata(metadata, tensors) +end +StructTypes.lowertype(::Type{Metadata}) = HashMetadata +function StructTypes.construct(::Type{Metadata}, x::HashMetadata) + metadata = x.metadata + tensors = sort!(collect(x.tensors); by = pair -> last(pair).data_offsets) + return Metadata(metadata, tensors) +end + +function validate(metadata::Metadata) + start = 0 + for (i, info) in enumerate(metadata.tensors) + s, e = info.data_offsets + if s != start || e < s + tensor_name = something(findfirst(==(i), metadata.index_map), "no_tensor") + error("Invalid Offset: `$tensor_name`") + end + start = e + nelements = reduce(Checked.checked_mul, info.shape; init = one(UInt)) + nbytes = Checked.checked_mul(nelements, sizeof(tag2type(info.dtype))) + if e - s != nbytes + error("Tensor Invalid Info") + end + end + return start +end + +struct SafeTensor{D} <: AbstractDict{String, AbstractArray} + metadata::Metadata + data::D end +getmetadata(x::SafeTensor) = getfield(x, :metadata) +Base.getproperty(x::SafeTensor, sym::Symbol) = sym == :metadata ? getmetadata(x).metadata : getfield(x, sym) +Base.length(x::SafeTensor) = length(getmetadata(x)) +function Base.iterate(x::SafeTensor, s...) + it = iterate(getmetadata(x), s...) + isnothing(it) && return nothing + ((name, info), state) = it + tensor = _tensorslice(x.data, info) + return (name => tensor), state +end +function Base.getindex(x::SafeTensor, name) + info = getmetadata(x)[name] + return _tensorslice(x.data, info) +end + +_from_le(x) = mappedarray(ltoh, x) +function _changemaj(x, shape::NTuple{N}) where N + perm = ntuple(i->N+1-i, Val(N)) + return PermutedDimsArray(x, perm) +end +function _tensorslice(data, info) + T = tag2type(info.dtype) + shape = Int.(info.shape) + start, stop = info.data_offsets + tensor = @inbounds _changemaj(Base.ReshapedArray(reinterpret(T, @view(data[start+0x1:stop])), reverse(shape), ()), shape) + return _from_le(tensor) +end + +const MAX_HEADER_SIZE = 100_000_000 +function read_metadata(buf::AbstractVector{UInt8}) + buffer_len = length(buf) + buffer_len < 8 && error("Header Too Small") + n = ltoh(@inbounds reinterpret(UInt64, @view(buf[1:8]))[1]) + n > min(MAX_HEADER_SIZE, typemax(Int)) && error("Header Too Large") + stop = Checked.checked_add(UInt(n), 0x8) + stop > buffer_len && error("Invalid Header Length") + metadata = @inbounds JSON3.read(@view(buf[9:Int(stop)]), Metadata) + buffer_end = validate(metadata) + buffer_end + 8 + n != buffer_len && error("Metadata Incomplete Buffer") + return (n, metadata) +end +function deserialize(buffer::AbstractVector{UInt8}) + n, metadata = read_metadata(buffer) + data = @inbounds @view buffer[n+9:end] + return SafeTensor(metadata, data) +end """ - header, header_length = load_header(fio::IO) - header = load_header(filename::AbstractString) + deserialize(file::AbstractString; mmap = true) - loads the header of a stream containing safetensor +Deserialize the lazy [`SafeTensor`](@ref) object. """ -function load_header(fio::IO) - seek(fio, 0) - n = read(fio, Int64) # first read the length of the header - s = read(fio, n) # then read the header - header = JSON3.read(s) - return(header, 8 + n) +function deserialize(file::AbstractString; mmap = true) + if mmap + open(io->deserialize(Mmap.mmap(io, Vector{UInt8})), file) + else + deserialize(read(file)) + end +end + +function prepare( + data::AbstractDict{String, <:AbstractArray}, + data_info::Union{AbstractDict{String, String}, Nothing}, +) + len = length(data) + tensors = Vector{valtype(data)}(undef, len) + hmetadata = Dict{String, TensorInfo}(); sizehint!(hmetadata, len) + data = sort!(collect(data); by = kv -> (type2tag(eltype(last(kv))), first(kv))) + offset = zero(UInt) + for (i, (name, tensor)) in enumerate(data) + dtype = type2tag(eltype(tensor)) + shape = size(tensor) + n = length(reinterpret(UInt8, tensor)) % UInt + noffset = offset + n + info = TensorInfo(dtype, shape, (offset, noffset)) + offset = noffset + hmetadata[name] = info + @inbounds tensors[i] = tensor + end + metadata = HashMetadata(data_info, hmetadata) + metadata_buf = IOBuffer() + JSON3.write(metadata_buf, metadata) + extra = 8 - mod1(metadata_buf.size, 8) + foreach(_->write(metadata_buf, ' '), 1:extra) + n = UInt64(metadata_buf.size) + header_bytes = take!(metadata_buf) + return (n, header_bytes, offset), tensors +end + +function _prepare(data, data_info) + ((n, header_bytes, offset), tensors) = prepare(data, data_info) + header_size = length(header_bytes) % UInt + 0x8 + expected_size = header_size + offset + return expected_size, header_size, n, header_bytes, tensors end -function load_header(filename::AbstractString) - open(first ∘ load_header,filename,"r") +@static if Base.ENDIAN_BOM == 0x04030201 + _to_le(x) = x +else + _to_le(x) = mappedarray(htol, x) end -function load_safetensors(fio::IO, header, tensors, header_length; seek_to_start = true) - Dict(map(k -> String(k) => readtensor!(fio, header, k, header_length; seek_to_start), tensors)) +function _serialize_write(io::IO, expected_size, header_size, n, header_bytes, tensors) + ws = zero(UInt) + ws += write(io, htol(n)) + ws += write(io, header_bytes) + @assert ws == header_size + for tensor in tensors + _tensor = _to_le(collect(_changemaj(tensor, size(tensor)))) + ws += write(io, _tensor) + end + @assert ws == expected_size + return end -function load_safetensors(filename::AbstractString) - open(filename,"r") do fio - header, header_length = load_header(fio) - starts = starts_of_tensors(header) - seek_to_start = !is_continuous(header, starts) - tensors = first.(starts) - load_safetensors(fio, header, tensors, header_length; seek_to_start) - end +function _serialize_copyto!(buf::AbstractVector{UInt8}, expected_size, header_size, n, header_bytes, tensors) + @assert length(buf) == expected_size + copyto!(buf, 0x1, reinterpret(UInt8, [htol(n)]), 0x1, 0x8) + copyto!(buf, 0x9, header_bytes, 0x1, header_size - 0x8) + pos = header_size + 0x1 + for tensor in tensors + _tensor = _to_le(collect(_changemaj(tensor, size(tensor)))) + _tensor = reinterpret(UInt8, _tensor) + len = UInt(length(_tensor)) + copyto!(buf, pos, _tensor, 0x1, len) + pos += len + end + @assert expected_size == pos - 0x1 + return +end + +serialize(buf::AbstractVector{UInt8}, data::AbstractDict{String, <:AbstractArray}, data_info::Union{AbstractDict{String, String}, Nothing} = nothing) = _serialize_copyto!(buf, _prepare(data, data_info)...) +serialize(io::IO, data::AbstractDict{String, <:AbstractArray}, data_info::Union{AbstractDict{String, String}, Nothing} = nothing) = _serialize_write(io, _prepare(data, data_info)...) + +""" + serialize( + file::AbstractString, + data::AbstractDict{String, <:AbstractArray}, + data_info::Union{AbstractDict{String, String}, Nothing} = nothing; + mmap = true, + ) + +Serialize the `Dict` of tensors (`data`) into `file`. Optionally, some extra information can be provided as a + `Dict{String, String}` (`data_info`). +""" +function serialize( + file::AbstractString, + data::AbstractDict{String, <:AbstractArray}, + data_info::Union{AbstractDict{String, String}, Nothing} = nothing; + mmap = true, +) + if mmap + open(file, "w+") do io + expected_size, header_size, n, header_bytes, tensors = _prepare(data, data_info) + buf = Mmap.mmap(io, Vector{UInt8}, expected_size) + _serialize_copyto!(buf, expected_size, header_size, n, header_bytes, tensors) + Mmap.sync!(buf) + end + else + open(io->serialize(io, data, data_info), file, "w+") + end +end + +""" + load_safetensors(filename::AbstractString; mmap = true) + +Eagerly load the tensors in `filename`. +""" +function load_safetensors(filename::AbstractString; mmap = true) + safetensor = deserialize(filename; mmap) + tensors = Dict{String, Array}(); sizehint!(tensors, length(safetensor)) + for (name, tensor) in safetensor + tensors[name] = collect(tensor) + end + return tensors end export load_safetensors diff --git a/test/gendata.py b/test/gendata.py index aceb20b..187a079 100644 --- a/test/gendata.py +++ b/test/gendata.py @@ -20,4 +20,37 @@ tensors[key] = tensor save_file(tensors, "./model.safetensors") -loaded = load_file("./model.safetensors") \ No newline at end of file +loaded = load_file("./model.safetensors") + +import torch +from safetensors.torch import save_file as th_save_file +from safetensors.torch import load_file as th_load_file + +th_dtypes = [ + torch.bool, torch.uint8, torch.int8, torch.float8_e5m2, torch.float8_e4m3fn, + torch.int16, #torch.uint16, + torch.float16, torch.bfloat16, + torch.int32, #torch.uint32, + torch.float32, + torch.int64, #torch.uint64, + torch.float64, +] +th_tensors = {} +for (shape, dtype) in product(shapes, th_dtypes): + name = dtype.__repr__().split('.')[1].lower() + s = ''.join(map(str,shape)) + key = f"{name}_{s}" + if key.startswith("bool"): + tensor = torch.randint(0, 1, shape, dtype=dtype) + elif key.startswith("int") or key.startswith("uint"): + tensor = torch.randint(0, 100, shape, dtype=dtype) + elif key.startswith("float8"): + tensor = torch.randn(shape, dtype=torch.float32).type(dtype) + else: + tensor = torch.randn(shape, dtype=dtype).type(dtype) + tensors[key] = tensor + +th_save_file(tensors, "./torch.safetensors") +loaded = th_load_file("./torch.safetensors") +th_save_file(tensors, "./torch_metadata.safetensors", {"test":"metadata", "version":"2.2"}) +loaded = th_load_file("./torch_metadata.safetensors") diff --git a/test/runtests.jl b/test/runtests.jl index f6808c9..a70d8b6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,5 @@ using SafeTensors +using JSON3 using Test @@ -65,4 +66,129 @@ end t, s = type_and_shape(k) @test check_tensor(t, s, d[k]) end + + @testset "rust test" begin + data = reshape(Float32[0, 3, 1, 4, 2, 5], (1,2,3)) + io = IOBuffer() + SafeTensors.serialize(io, Dict("attn.0"=>data)) + out = take!(io) + @test out == UInt8[ + 64, 0, 0, 0, 0, 0, 0, 0, 123, 34, 97, 116, 116, 110, 46, 48, 34, 58, 123, 34, 100, + 116, 121, 112, 101, 34, 58, 34, 70, 51, 50, 34, 44, 34, 115, 104, 97, 112, 101, 34, + 58, 91, 49, 44, 50, 44, 51, 93, 44, 34, 100, 97, 116, 97, 95, 111, 102, 102, 115, + 101, 116, 115, 34, 58, 91, 48, 44, 50, 52, 93, 125, 125, 0, 0, 0, 0, 0, 0, 128, 63, + 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 128, 64, 0, 0, 160, 64 + ] + @test SafeTensors.deserialize(out)["attn.0"] == data + data = reshape(data, (1,1,2,3)) + SafeTensors.serialize(io, Dict("attn0"=>data)) + out = take!(io) + @test out == UInt8[ + 72, 0, 0, 0, 0, 0, 0, 0, 123, 34, 97, 116, 116, 110, 48, 34, 58, 123, 34, 100, 116, + 121, 112, 101, 34, 58, 34, 70, 51, 50, 34, 44, 34, 115, 104, 97, 112, 101, 34, 58, + 91, 49, 44, 49, 44, 50, 44, 51, 93, 44, 34, 100, 97, 116, 97, 95, 111, 102, 102, + 115, 101, 116, 115, 34, 58, 91, 48, 44, 50, 52, 93, 125, 125, 32, 32, 32, 32, 32, + 32, 32, 0, 0, 0, 0, 0, 0, 128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 128, 64, 0, 0, + 160, 64 + ] + @test SafeTensors.deserialize(out)["attn0"] == data + data = reshape(data, (1,2,3)) + SafeTensors.serialize(io, Dict("attn.0"=>data)) + out = take!(io) + parsed = SafeTensors.deserialize(out) + out_buf = vec(parsed["attn.0"][:, 1, :]) + @test reinterpret(UInt8, out_buf) == UInt8[0,0,0,0,0,0,128,63,0,0,0,64] + @test out_buf == Float32[0,1,2] + out_buf = vec(parsed["attn.0"][:, :, 1]) + @test reinterpret(UInt8, out_buf) == UInt8[0,0,0,0,0,0,64,64] + @test out_buf == Float32[0,3] + serialized = b"8\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[],\"data_offsets\":[0,4]}}\x00\x00\x00\x00" + loaded = SafeTensors.deserialize(serialized) + @test collect(keys(loaded)) == ["test"] + tensor = loaded["test"] + @test size(tensor) == () + @test eltype(tensor) == Int32 + @test iszero(tensor[]) + serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + loaded = SafeTensors.deserialize(serialized) + @test length(loaded) == 1 + @test collect(keys(loaded)) == ["test"] + tensor = loaded["test"] + @test size(tensor) == (2,2) + @test eltype(tensor) == Int32 + @test iszero(tensor) + tensors = Dict{String, SafeTensors.TensorInfo}() + dtype = SafeTensors.F32 + shape = (2,2) + data_offsets = (0, 16) + for i = 1:10 + tensors["weight_$(i-1)"] = SafeTensors.TensorInfo(dtype, shape, data_offsets) + end + metadata = SafeTensors.HashMetadata(nothing, tensors) + serialized = codeunits(JSON3.write(metadata)) + n = length(serialized) + file = tempname() + open(file, "w+") do io + write(io, n) + write(io, serialized) + write(io, b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0") + end + reloaded = read(file) + @test_throws ErrorException("Invalid Offset: `weight_0`") SafeTensors.deserialize(reloaded) + serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00extra_bogus_data_for_polyglot_file" + @test_throws ErrorException("Metadata Incomplete Buffer") SafeTensors.deserialize(serialized) + serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"; + @test_throws ErrorException("Metadata Incomplete Buffer") SafeTensors.deserialize(serialized)s + serialized = b"<\x00\x00\x00\x00\xff\xff\xff{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"; + @test_throws ErrorException("Header Too Large") SafeTensors.deserialize(serialized) + serialized = b"" + @test_throws ErrorException("Header Too Small") SafeTensors.deserialize(serialized) + serialized = b"<\x00\x00\x00\x00\x00\x00\x00" + @test_throws ErrorException("Invalid Header Length") SafeTensors.deserialize(serialized) + if VERSION < v"1.8" + serialized = b"\x01\x00\x00\x00\x00\x00\x00\x00\xff" + @test_throws ArgumentError SafeTensors.deserialize(serialized) + serialized = b"\x01\x00\x00\x00\x00\x00\x00\x00{" + @test_throws ArgumentError SafeTensors.deserialize(serialized) + else + serialized = b"\x01\x00\x00\x00\x00\x00\x00\x00\xff" + @test_throws "ArgumentError: invalid JSON" SafeTensors.deserialize(serialized) + serialized = b"\x01\x00\x00\x00\x00\x00\x00\x00{" + @test_throws "ArgumentError: invalid JSON" SafeTensors.deserialize(serialized) + end + serialized = b"\x06\x00\x00\x00\x00\x00\x00\x00{}\x0D\x20\x09\x0A" + @test iszero(length(SafeTensors.deserialize(serialized))) + serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,0],\"data_offsets\":[0, 0]}}" + loaded = SafeTensors.deserialize(serialized) + @test collect(keys(loaded)) == ["test"] + tensor = loaded["test"] + @test size(tensor) == (2,0) + @test eltype(tensor) == Int32 + @test isempty(tensor) + serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0, 4]}}" + @test_throws ErrorException("Tensor Invalid Info") SafeTensors.deserialize(serialized) + serialized = b"O\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,18446744073709551614],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + @test_throws OverflowError SafeTensors.deserialize(serialized) + serialized = b"N\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,9223372036854775807],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + @test_throws OverflowError SafeTensors.deserialize(serialized) + end + + @testset "torch" begin + for thfile in ("torch", "torch_metadata") + file = joinpath(@__DIR__, "$(thfile).safetensors") + jl_bytes = [] + for use_mmap in (true, false) + torch_tensors = SafeTensors.deserialize(file; mmap = use_mmap) + tfile = tempname() + SafeTensors.serialize(tfile, Dict(torch_tensors), torch_tensors.metadata; mmap = use_mmap) + jl_tensors = SafeTensors.deserialize(tfile; mmap = use_mmap) + push!(jl_bytes, read(tfile)) + @test jl_tensors.metadata == torch_tensors.metadata + for (name, tensor) in torch_tensors + @test collect(jl_tensors[name]) == collect(tensor) + end + end + jl_bytes[1] == jl_bytes[2] + end + end end diff --git a/test/torch.safetensors b/test/torch.safetensors new file mode 100644 index 0000000..d191810 Binary files /dev/null and b/test/torch.safetensors differ diff --git a/test/torch_metadata.safetensors b/test/torch_metadata.safetensors new file mode 100644 index 0000000..0b92720 Binary files /dev/null and b/test/torch_metadata.safetensors differ