diff --git a/src/array/darray.jl b/src/array/darray.jl index 6207ee245..fd898f614 100644 --- a/src/array/darray.jl +++ b/src/array/darray.jl @@ -30,7 +30,6 @@ chunks(a::ArrayDomain{N}) where {N} = DomainBlocks( ntuple(i->first(indexes(a)[i]), Val(N)), map(x->[length(x)], indexes(a))) (==)(a::ArrayDomain, b::ArrayDomain) = indexes(a) == indexes(b) -Base.getindex(arr::AbstractArray, d::ArrayDomain) = arr[indexes(d)...] function intersect(a::ArrayDomain, b::ArrayDomain) if a === b @@ -452,7 +451,8 @@ function stage(ctx::Context, d::Distribute) cs = map(d.domainchunks) do c # TODO: fix hashing #hash = uhash(c, Base.hash(Distribute, Base.hash(d.data))) - Dagger.@spawn identity(d.data[c]) + data_view = d.data[indexes(c)...] + Dagger.@spawn identity(data_view) end end return DArray(eltype(d.data), diff --git a/src/array/mul.jl b/src/array/mul.jl index 4cdc7a525..b46fc449e 100644 --- a/src/array/mul.jl +++ b/src/array/mul.jl @@ -30,6 +30,7 @@ function LinearAlgebra.generic_matmatmul!( return gemm_dagger!(C, transA, transB, A, B, _add) end end +# FIXME: Mixed-precision methods function _repartition_matmatmul(C, A, B, transA::Char, transB::Char) partA = A.partitioning.blocksize partB = B.partitioning.blocksize @@ -124,28 +125,26 @@ function gemm_dagger!( # A: NoTrans / B: NoTrans for k in range(1, Ant) mzone = k == 1 ? beta : T(1.0) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn LinearAlgebra.generic_matmatmul!( + InOut(Cc[m, n]), transA, transB, - alpha, In(Ac[m, k]), In(Bc[k, n]), - mzone, - InOut(Cc[m, n]), + LinearAlgebra.MulAddMul(alpha, mzone), ) end else # A: NoTrans / B: [Conj]Trans for k in range(1, Ant) mzone = k == 1 ? 
beta : T(1.0) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn LinearAlgebra.generic_matmatmul!( + InOut(Cc[m, n]), transA, transB, - alpha, In(Ac[m, k]), In(Bc[n, k]), - mzone, - InOut(Cc[m, n]), + LinearAlgebra.MulAddMul(alpha, mzone), ) end end @@ -154,28 +153,26 @@ function gemm_dagger!( # A: [Conj]Trans / B: NoTrans for k in range(1, Amt) mzone = k == 1 ? beta : T(1.0) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn LinearAlgebra.generic_matmatmul!( + InOut(Cc[m, n]), transA, transB, - alpha, In(Ac[k, m]), In(Bc[k, n]), - mzone, - InOut(Cc[m, n]), + LinearAlgebra.MulAddMul(alpha, mzone), ) end else # A: [Conj]Trans / B: [Conj]Trans for k in range(1, Amt) mzone = k == 1 ? beta : T(1.0) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn LinearAlgebra.generic_matmatmul!( + InOut(Cc[m, n]), transA, transB, - alpha, In(Ac[k, m]), In(Bc[n, k]), - mzone, - InOut(Cc[m, n]), + LinearAlgebra.MulAddMul(alpha, mzone), ) end end diff --git a/test/array/linalg/matmul.jl b/test/array/linalg/matmul.jl index e15f74329..f36b340e1 100644 --- a/test/array/linalg/matmul.jl +++ b/test/array/linalg/matmul.jl @@ -29,36 +29,71 @@ function test_gemm!(T, szA, szB, partA, partB) DA = distribute(A, partA) DB = distribute(B, partB) + SA = sprand(T, szA..., 0.1) + SB = sprand(T, szB..., 0.1) + + DSA = distribute(SA, partA) + DSB = distribute(SB, partB) + ## Out-of-place gemm # No transA, No transB + # Dense DC = DA * DB C = A * B @test collect(DC) ≈ C + # Sparse + DSC = DSA * DSB + SC = SA * SB + @test collect(DSC) ≈ SC if szA == szB # No transA, transB + # Dense DC = DA * DB' C = A * B' @test collect(DC) ≈ C + # Sparse + DSC = DSA * DSB' + SC = SA * SB' + @test collect(DSC) ≈ SC # transA, No transB + # Dense DC = DA' * DB C = A' * B @test collect(DC) ≈ C + # Sparse + DSC = DSA' * DSB + SC = SA' * SB + @test collect(DSC) ≈ SC end # transA, transB + # Dense DC = DA' * DB' C = A' * B' @test collect(DC) ≈ C + #= Sparse + DSC = DSA' * DSB' + SC = SA' * SB' + @test collect(DSC) ≈ SC + =# ## In-place gemm # No
transA, No transB + # Dense C = zeros(T, szC...) DC = distribute(C, partC) mul!(C, A, B) mul!(DC, DA, DB) @test collect(DC) ≈ C + #= Sparse + SC = zeros(T, szC...) + DSC = distribute(SC, partC) + mul!(SC, SA, SB) + mul!(DSC, DSA, DSB) + @test collect(DSC) ≈ SC + =# if szA == szB # No transA, transB