From e85b999a81e852784dadf1fff72ebc34fae25baf Mon Sep 17 00:00:00 2001 From: Jakub Wronowski Date: Wed, 14 Feb 2024 20:44:54 +0100 Subject: [PATCH 1/2] multidimensional fnt --- Manifest.toml | 176 +++++++++++++++++++++++++++++++ Project.toml | 3 + src/NumberTheoreticTransforms.jl | 2 + src/fnt.jl | 61 +++++------ test/fnt.jl | 46 +++++++- 5 files changed, 254 insertions(+), 34 deletions(-) create mode 100644 Manifest.toml diff --git a/Manifest.toml b/Manifest.toml new file mode 100644 index 0000000..bc176e5 --- /dev/null +++ b/Manifest.toml @@ -0,0 +1,176 @@ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.10.0" +manifest_format = "2.0" +project_hash = "4ae2d6d0e9c3229ff0a83387a28e10dc57c90a74" + +[[deps.ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +version = "1.1.1" + +[[deps.Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + +[[deps.Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[deps.CodeTracking]] +deps = ["InteractiveUtils", "UUIDs"] +git-tree-sha1 = "c0216e792f518b39b22212127d4a84dc31e4e386" +uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2" +version = "1.3.5" + +[[deps.Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[deps.Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[deps.Downloads]] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +version = "1.6.0" + +[[deps.FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" + +[[deps.InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[deps.JuliaInterpreter]] +deps = ["CodeTracking", "InteractiveUtils", "Random", "UUIDs"] +git-tree-sha1 = "04663b9e1eb0d0eabf76a6d0752e0dac83d53b36" +uuid = "aa1ae85d-cabe-5617-a682-6adf51b2e16a" +version = "0.9.28" + +[[deps.LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" +version = "0.6.4" + +[[deps.LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" +version = "8.4.0+0" + +[[deps.LibGit2]] +deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[deps.LibGit2_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] +uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" +version = "1.6.4+0" + +[[deps.LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" +version = "1.11.0+1" + +[[deps.Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[deps.Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[deps.LoweredCodeUtils]] +deps = ["JuliaInterpreter"] +git-tree-sha1 = "20ce1091ba18bcdae71ad9b71ee2367796ba6c48" +uuid = "6f1432cf-f94c-5a45-995e-cdbf5db27b0b" +version = "2.4.4" + +[[deps.Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[deps.MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" +version = "2.28.2+1" + +[[deps.MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2023.1.10" + +[[deps.NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.2.0" + +[[deps.OrderedCollections]] +git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.6.3" + +[[deps.Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +version = "1.10.0" + +[[deps.Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[deps.REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[deps.Random]] +deps = ["SHA"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[deps.Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.3.0" + +[[deps.Revise]] +deps = ["CodeTracking", "Distributed", "FileWatching", "JuliaInterpreter", "LibGit2", "LoweredCodeUtils", "OrderedCollections", "Pkg", "REPL", "Requires", "UUIDs", "Unicode"] +git-tree-sha1 = "3fe4e5b9cdbb9bbc851c57b149e516acc07f8f72" +uuid = "295af30f-e4ad-537b-8983-00126c2a3abe" +version = "3.5.13" + +[[deps.SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" + +[[deps.Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[deps.Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[deps.TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.3" + +[[deps.Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +version = "1.10.0" + +[[deps.UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[deps.Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[deps.Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.2.13+1" + +[[deps.nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" +version = "1.52.0+1" + +[[deps.p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "17.4.0+2" diff --git a/Project.toml b/Project.toml index 488cffb..c90f6d6 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,9 @@ uuid = "8497c1d1-af72-4391-8d22-bdd566511a1c" authors = ["Jakub Wronowski "] version = "1.0.0" +[deps] +Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" + [compat] julia = "1" diff --git a/src/NumberTheoreticTransforms.jl b/src/NumberTheoreticTransforms.jl index f8e43fb..6382891 100644 --- a/src/NumberTheoreticTransforms.jl +++ b/src/NumberTheoreticTransforms.jl @@ -1,5 +1,7 @@ module NumberTheoreticTransforms +using Revise + include("ntt.jl") include("fnt.jl") diff --git a/src/fnt.jl b/src/fnt.jl index 83011cf..ee7bede 100644 --- a/src/fnt.jl +++ b/src/fnt.jl @@ -50,7 +50,7 @@ end Order input to perform radix-2 structured calculation. It sorts array by bit-reversed 0-based sample index. """ -function radix2sort!(data::Array{T, 1}) where {T<:Integer} +function radix2sort!(data::AbstractArray{T, 1}) where {T<:Integer} N = length(data) @assert ispow2(N) @@ -71,7 +71,16 @@ function radix2sort!(data::Array{T, 1}) where {T<:Integer} return data end -function fnt!(x::Array{T, 1}, g::T, q::T) where {T<:Integer} +function radix2sort!(data::AbstractArray{T, N}) where {T<:Integer, N} + for d in 1:N + other_dims = tuple(filter(!=(d), 1:N)...) + for s in eachslice(data, dims = other_dims) + radix2sort!(s) + end + end +end + +function fnt!(x::AbstractArray{T, 1}, g::T, q::T) where {T<:Integer} N = length(x) @assert ispow2(N) @assert isfermat(q) @@ -104,26 +113,6 @@ function fnt!(x::Array{T, 1}, g::T, q::T) where {T<:Integer} return x end -""" - fnt!(x, g, q) - -In-place version of `fnt`. That means it will store result in the `x` array. -""" -function fnt!(x::Array{T,2}, g::T, q::T) where {T<:Integer} - N, M = size(x) - @assert N == M #TODO: make it work for N != M (need different g for each dim) - - for n in 1:N - x[n, :] = fnt!(x[n, :], g, q) - end - - for m in 1:M - x[:, m] = fnt!(x[:, m], g, q) - end - - return x -end - """ fnt(x, g, q) @@ -131,7 +120,7 @@ The Fermat Number Transform returns the same result as `ntt` function using more performant algorithm. When `q` has \$ 2^{2^t}+1 \$ form the calculation can be performed with O(N*log(N)) operation instead of O(N^2) for `ntt`. """ -function fnt(x::Array{T}, g::T, q::T) where {T<:Integer} +function fnt(x::AbstractArray{T}, g::T, q::T) where {T<:Integer} return fnt!(copy(x), g, q) end @@ -140,7 +129,7 @@ end In-place version of `ifnt`. That means it will store result in the `y` array. """ -function ifnt!(y::Array{T,1}, g::T, q::T) where {T<:Integer} +function ifnt!(y::AbstractArray{T,1}, g::T, q::T) where {T<:Integer} N = length(y) inv_N = invmod(N, q) inv_g = invmod(g, q) @@ -153,17 +142,25 @@ function ifnt!(y::Array{T,1}, g::T, q::T) where {T<:Integer} return x end -function ifnt!(y::Array{T,2}, g::T, q::T) where {T<:Integer} - N, M = size(y) - - for m in 1:M - y[:, m] = ifnt!(y[:, m], g, q) +function fnt!(y::AbstractArray{T, N}, g::T, q::T) where {T<:Integer, N} + for d in 1:N + other_dims = tuple(filter(!=(d), 1:N)...) + for s in eachslice(y, dims = other_dims) + fnt!(s, g, q) + end end - for n in 1:N - y[n, :] = ifnt!(y[n, :], g, q) + return y +end + +function ifnt!(y::AbstractArray{T, N}, g::T, q::T) where {T<:Integer, N} + for d in 1:N + other_dims = tuple(filter(!=(d), 1:N)...) + for s in eachslice(y, dims = other_dims) + ifnt!(s, g, q) + end end - + return y end diff --git a/test/fnt.jl b/test/fnt.jl index 4d5bead..7c4944a 100644 --- a/test/fnt.jl +++ b/test/fnt.jl @@ -1,9 +1,7 @@ @testset "FNT 1D" begin for t in BigInt.(1:13) - @show t x = [1:2^(t+1);] .|> BigInt - @show length(x) g = 2 |> BigInt q = 2^2^t + 1 |> BigInt @time @test ifnt(fnt(x, g, q), g, q) == x @@ -122,4 +120,48 @@ end x = mod.(rand(0:limit, 1000), (q-1)^2) @test mod.(x, q) == modfermat.(x, q) end +end + +@testset "multidimensional fnt timings" begin + + for i in 1:2 + (g, q, n) = (2255, 65537, 65536) + + rnd = rand(0:16, n); + x = copy(rnd); + y1, tm1 = @timed fnt(x, g, q); + cmplx1 = n * log2(n); + + + (g, q, n) = (2256, 65537, 256); + x = reshape(copy(rnd), (n, n)); + y2, tm2 = @timed fnt(x, g, q); + cmplx2 = 2*n* n*log2(n); + + n = 65536^(1/3); + cmplx3 = 3*n * 2*n* n*log2(n); + + (g, q, n) = (1024, 65537, 16); + x = reshape(copy(rnd), (n, n, n, n)); + y4, tm4 = @timed fnt(x, g, q); + + (g, q, n) = (256, 65537, 4); + x = reshape(copy(rnd), (n, n, n, n, n, n, n, n)); + y8, tm8 = @timed fnt(x, g, q); # 0.03 s + + @show tm1 tm2 tm4 tm8 + end + + (g, q, n) = (233, 65537, 4096) + rnd = rand(0:16, n, n); + x = copy(rnd); + y1, tm1 = @timed fnt(x, g, q); + cmplx1 = n * log2(n); + + (g, q, n) = (255, 65537, 64) + rnd = rand(0:16, n, n, n, n); + x = copy(rnd); + y2, tm2 = @timed fnt(x, g, q); + + @show tm1 tm2 end \ No newline at end of file From ab3fe4c615bb5521926f6f4bca03ada840a7616a Mon Sep 17 00:00:00 2001 From: Jakub Wronowski Date: Wed, 14 Feb 2024 20:51:04 +0100 Subject: [PATCH 2/2] useful scripts --- scripts/image-deconv-big.jl | 24 +++++++++++++++++ scripts/image-deconv.jl | 52 +++++++++++++++++++++++++++++++++++++ scripts/inverse-matrix.jl | 31 ++++++++++++++++++++++ src/utils.jl | 19 ++++++++++++++ 4 files changed, 126 insertions(+) create mode 100644 scripts/image-deconv-big.jl create mode 100644 scripts/image-deconv.jl create mode 100644 scripts/inverse-matrix.jl create mode 100644 src/utils.jl diff --git a/scripts/image-deconv-big.jl b/scripts/image-deconv-big.jl new file mode 100644 index 0000000..c6050f2 --- /dev/null +++ b/scripts/image-deconv-big.jl @@ -0,0 +1,24 @@ +using Images, TestImages, Colors, ImageView +using ZernikePolynomials, NumberTheoreticTransforms + +image_float = channelview(testimage("cameraman")) +image = map(x -> x.:i, image_float) .|> BigInt + +blur_float = evaluateZernike(LinRange(-41,41,512), [12, 4, 0], [1.0, -1.0, 2.0], index=:OSA) +blur_float ./= (sum(blur_float)) +blur = blur_float .|> Normed{UInt8, 8} .|> x -> x.:i .|> BigInt +blur = blur[begin + 224:begin + 287,begin + 224:begin + 287] +blur = circshift(blur, (32, 32)) + +(g, q, n) = BigInt.((2, 4294967297, 64)) +image = image[257:256+n, 257:256+n] +image = image[256+32+begin:256-32+begin+127, 256+64+begin:256+begin+127] + +X = fnt(image, g, q) +H = fnt(blur, g, q) +Y = mod.(X .* H, q) +y = ifnt(Y, g, q) +blurred_image = y .>> 8 + +imshow(UInt8.(image)) +imshow(UInt8.(blurred_image)) diff --git a/scripts/image-deconv.jl b/scripts/image-deconv.jl new file mode 100644 index 0000000..980ee1e --- /dev/null +++ b/scripts/image-deconv.jl @@ -0,0 +1,52 @@ +using Images, TestImages, Colors, ImageView +using ZernikePolynomials, NumberTheoreticTransforms + +image_float = channelview(testimage("cameraman")) +image = map(x -> x.:i, image_float) .|> Int64 + +# lens abberation blur model +blur_float = evaluateZernike(LinRange(-41,41,512), [12, 4, 0], [1.0, -1.0, 2.0], index=:OSA) +blur_float ./= (sum(blur_float)) +blur = blur_float .|> Normed{UInt8, 8} .|> x -> x.:i .|> Int64 +blur = circshift(blur, (256, 256)) + +# 2D convolution with FNT +t = 4 +(g, q) = (314, 2^2^t+1) # g for N = 512 found with scripts/find-ntt.jl +X = fnt(image, g, q) +H = fnt(blur, g, q) +Y = mod.(X .* H, q) +y = ifnt(Y, g, q) +blurred_image = y .>> 8 + +imshow(image) +imshow(blurred_image) + +(g, q, n) = (255, 65537, 64) +image = image[257:256+n, 257:256+n] + +blur_float = evaluateZernike(LinRange(-41,41,512), [12, 4, 0], [1.0, -1.0, 2.0], index=:OSA) +blur_float ./= (sum(blur_float)) +blur = blur_float .|> Normed{UInt8, 8} .|> x -> x.:i .|> Int64 +blur = blur[begin + 224:begin + 287,begin + 224:begin + 287] +blur = circshift(blur, (32, 32)) + +X = fnt(image, g, q) +H = fnt(blur, g, q) +Y = mod.(X .* H, q) +y = ifnt(Y, g, q) +blurred_image = y + +imshow(image) +imshow(blurred_image) + + +(g, q, n) = (8, 4294967297, 64) +X = fnt(image, g, q) +H = fnt(blur, g, q) +Y = mod.(X .* H, q) +y = ifnt(Y, g, q) +blurred_image = y .>> 24 +imshow(image) +imshow(blur) +imshow(blurred_image) \ No newline at end of file diff --git a/scripts/inverse-matrix.jl b/scripts/inverse-matrix.jl new file mode 100644 index 0000000..57663f4 --- /dev/null +++ b/scripts/inverse-matrix.jl @@ -0,0 +1,31 @@ + + +using InvertedIndices # for cleaner code, you can remove this if you really want to. +function cofactor(A::AbstractMatrix, T = Int64) + ax = axes(A) + out = similar(A, T, ax) + for col in ax[1] + for row in ax[2] + out[col, row] = (T(-1))^(T(col) + T(row)) * LinearAlgebra.det_bareiss(A[Not(col), Not(row)]) + end + end + return out +end + +# mod 17 +q = 17 +T = [1 1 1 1; 1 4 16 13; 1 16 1 16; 1 13 16 4] + +d = LinearAlgebra.det_bareiss(T) +d_inv = invmod(d, q) + +A = mod.(cofactor(T, Int64), q) +T_inv = mod.(d_inv * A, q) + +@assert mod.(T * T_inv, q) == I(4) + + +using Mods + +T = [1 1 1 1; 1 4 16 13; 1 16 1 16; 1 13 16 4] .|> Mod{17} +inv(T) diff --git a/src/utils.jl b/src/utils.jl new file mode 100644 index 0000000..67ede4e --- /dev/null +++ b/src/utils.jl @@ -0,0 +1,19 @@ + +import LinearAlgebra: det +function det(matrix::Matrix{Z64{m}}) where m + LinearAlgebra.det_bareiss(matrix) +end + +function adjugate(matrix::Matrix{Z64{m}}) where m + out = similar(matrix) + for (col, row) in Iterators.product(axes(matrix)...) + d = det(matrix[vcat(begin:col-1, col+1:end), vcat(begin:row-1, row+1:end),]) + res = Z{m}(powermod(m-1, col+row, m)) * d + out[col, row] = res + end + return out +end + +function inv(matrix::Matrix{Z64{m}}) where m + inv(det(matrix)) * adjugate(matrix) +end