Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

oneMKL - L1: nrm2 support #2

Open
wants to merge 25 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions deps/src/onemkl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,32 @@ extern "C" int onemklZgemm(syclQueue_t device_queue, onemklTranspose transA,
return 0;
}

// Euclidean norm (BLAS nrm2) of a double-precision vector.
// Forwards the queue held in device_queue->val together with n, x, incx and
// result to oneMKL's column-major nrm2, then blocks until the operation that
// nrm2 returned has completed, so *result is ready when this function returns.
extern "C" void onemklDnrm2(syclQueue_t device_queue, int64_t n, const double *x,
                            int64_t incx, double *result) {
    namespace blas = oneapi::mkl::blas::column_major;
    auto done = blas::nrm2(device_queue->val, n, x, incx, result);
    done.wait();
}

// Euclidean norm (BLAS nrm2) of a single-precision vector.
// Forwards the queue held in device_queue->val together with n, x, incx and
// result to oneMKL's column-major nrm2, then blocks until the operation that
// nrm2 returned has completed, so *result is ready when this function returns.
extern "C" void onemklSnrm2(syclQueue_t device_queue, int64_t n, const float *x,
                            int64_t incx, float *result) {
    namespace blas = oneapi::mkl::blas::column_major;
    auto done = blas::nrm2(device_queue->val, n, x, incx, result);
    done.wait();
}

// Euclidean norm (BLAS nrm2) of a single-precision complex vector; the norm
// itself is written as a real float.
// The C `float _Complex` input is reinterpreted as std::complex<float>, which
// is the element type handed to oneMKL's column-major nrm2. The call blocks
// until the operation that nrm2 returned has completed.
extern "C" void onemklCnrm2(syclQueue_t device_queue, int64_t n, const float _Complex *x,
                            int64_t incx, float *result) {
    namespace blas = oneapi::mkl::blas::column_major;
    const auto *xs = reinterpret_cast<const std::complex<float> *>(x);
    auto done = blas::nrm2(device_queue->val, n, xs, incx, result);
    done.wait();
}

// Euclidean norm (BLAS nrm2) of a double-precision complex vector; the norm
// itself is written as a real double.
// The C `double _Complex` input is reinterpreted as std::complex<double>,
// which is the element type handed to oneMKL's column-major nrm2. The call
// blocks until the operation that nrm2 returned has completed.
extern "C" void onemklZnrm2(syclQueue_t device_queue, int64_t n, const double _Complex *x,
                            int64_t incx, double *result) {
    namespace blas = oneapi::mkl::blas::column_major;
    const auto *xs = reinterpret_cast<const std::complex<double> *>(x);
    auto done = blas::nrm2(device_queue->val, n, xs, incx, result);
    done.wait();
}

extern "C" void onemklDcopy(syclQueue_t device_queue, int64_t n, const double *x,
int64_t incx, double *y, int64_t incy) {
oneapi::mkl::blas::column_major::copy(device_queue->val, n, x, incx, y, incy);
Expand Down
10 changes: 10 additions & 0 deletions deps/src/onemkl.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,16 @@ int onemklZgemm(syclQueue_t device_queue, onemklTranspose transA,
const double _Complex *B, int64_t ldb, double _Complex beta,
double _Complex *C, int64_t ldc);

// Supported Level-1: nrm2 (Euclidean norm of a vector).
// One entry point per element type: D = double, S = float,
// C = float _Complex, Z = double _Complex.
//   device_queue  queue wrapper the operation is submitted through
//   n             number of elements of x to process
//   x             input vector
//   incx          stride between consecutive elements of x
//   result        location the norm is written to; always a real type, even
//                 for the complex variants (float for C, double for Z)
// NOTE(review): x and result are presumably device-accessible (USM) pointers,
// as required by the oneMKL call these wrap -- confirm against callers.
void onemklDnrm2(syclQueue_t device_queue, int64_t n, const double *x,
int64_t incx, double *result);
void onemklSnrm2(syclQueue_t device_queue, int64_t n, const float *x,
int64_t incx, float *result);
void onemklCnrm2(syclQueue_t device_queue, int64_t n, const float _Complex *x,
int64_t incx, float *result);
void onemklZnrm2(syclQueue_t device_queue, int64_t n, const double _Complex *x,
int64_t incx, double *result);

void onemklDcopy(syclQueue_t device_queue, int64_t n, const double *x,
int64_t incx, double *y, int64_t incy);
void onemklScopy(syclQueue_t device_queue, int64_t n, const float *x,
Expand Down
27 changes: 26 additions & 1 deletion lib/mkl/libonemkl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,31 @@ function onemklZgemm(device_queue, transA, transB, m, n, k, alpha, A, lda, B, ld
C::ZePtr{ComplexF64}, ldc::Int64)::Cint
end

"""
    onemklDnrm2(device_queue, n, x, incx, result)

Call the C entry point `onemklDnrm2` in `liboneapi_support`.

`x` is passed as a `ZePtr{Cdouble}` and `result` as a `RefOrZeRef{Cdouble}`;
`n` and `incx` are passed as `Int64`. The C function returns `void`, so this
wrapper returns `nothing`.
"""
function onemklDnrm2(device_queue, n, x, incx, result)
    @ccall liboneapi_support.onemklDnrm2(
        device_queue::syclQueue_t,
        n::Int64,
        x::ZePtr{Cdouble},
        incx::Int64,
        result::RefOrZeRef{Cdouble}
    )::Cvoid
end

"""
    onemklSnrm2(device_queue, n, x, incx, result)

Call the C entry point `onemklSnrm2` in `liboneapi_support`.

`x` is passed as a `ZePtr{Cfloat}` and `result` as a `RefOrZeRef{Cfloat}`;
`n` and `incx` are passed as `Int64`. The C function returns `void`, so this
wrapper returns `nothing`.
"""
function onemklSnrm2(device_queue, n, x, incx, result)
    @ccall liboneapi_support.onemklSnrm2(
        device_queue::syclQueue_t,
        n::Int64,
        x::ZePtr{Cfloat},
        incx::Int64,
        result::RefOrZeRef{Cfloat}
    )::Cvoid
end

"""
    onemklCnrm2(device_queue, n, x, incx, result)

Call the C entry point `onemklCnrm2` in `liboneapi_support`.

`x` is passed as a `ZePtr{ComplexF32}` while `result` is a real-valued
`RefOrZeRef{Cfloat}`; `n` and `incx` are passed as `Int64`. The C function
returns `void`, so this wrapper returns `nothing`.
"""
function onemklCnrm2(device_queue, n, x, incx, result)
    @ccall liboneapi_support.onemklCnrm2(
        device_queue::syclQueue_t,
        n::Int64,
        x::ZePtr{ComplexF32},
        incx::Int64,
        result::RefOrZeRef{Cfloat}
    )::Cvoid
end

"""
    onemklZnrm2(device_queue, n, x, incx, result)

Call the C entry point `onemklZnrm2` in `liboneapi_support`.

`x` is passed as a `ZePtr{ComplexF64}` while `result` is a real-valued
`RefOrZeRef{Cdouble}`; `n` and `incx` are passed as `Int64`. The C function
returns `void`, so this wrapper returns `nothing`.
"""
function onemklZnrm2(device_queue, n, x, incx, result)
    @ccall liboneapi_support.onemklZnrm2(
        device_queue::syclQueue_t,
        n::Int64,
        x::ZePtr{ComplexF64},
        incx::Int64,
        result::RefOrZeRef{Cdouble}
    )::Cvoid
end


function onemklDcopy(device_queue, n, x, incx, y, incy)
@ccall liboneapi_support.onemklDcopy(device_queue::syclQueue_t, n::Int64,
x::ZePtr{Cdouble}, incx::Int64,
Expand Down Expand Up @@ -104,4 +129,4 @@ end
"""
    onemklZamin(device_queue, n, x, incx, result)

Call the C entry point `onemklZamin` in `liboneapi_support`.

`x` is passed as a `ZePtr{ComplexF64}` and `result` as a `ZePtr{Int64}`;
`n` and `incx` are passed as `Int64`. The C function returns `void`, so this
wrapper returns `nothing`.
"""
function onemklZamin(device_queue, n, x, incx, result)
    @ccall liboneapi_support.onemklZamin(
        device_queue::syclQueue_t,
        n::Int64,
        x::ZePtr{ComplexF64},
        incx::Int64,
        result::ZePtr{Int64}
    )::Cvoid
end
end
2 changes: 2 additions & 0 deletions lib/mkl/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ function gemm_dispatch!(C::oneStridedVecOrMat, A, B, alpha::Number=true, beta::N
end
end

"""
    LinearAlgebra.norm(x::oneStridedVecOrMat{<:onemklFloat})

Compute the norm of `x` by handing all `length(x)` elements to `oneMKL.nrm2`.
"""
function LinearAlgebra.norm(x::oneStridedVecOrMat{<:onemklFloat})
    return oneMKL.nrm2(length(x), x)
end

for NT in (Number, Real)
# NOTE: alpha/beta also ::Real to avoid ambiguities with certain Base methods
@eval begin
Expand Down
18 changes: 17 additions & 1 deletion lib/mkl/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,23 @@ function Base.convert(::Type{onemklTranspose}, trans::Char)
end
end


# level 1
## nrm2
# Generate one `nrm2(n, x)` method per supported element type. Each entry lists
# the C wrapper to call, the element type of `x` the method accepts, and the
# real type of the scalar it returns.
for (wrapper, elt, ret) in ((:onemklDnrm2, :Float64, :Float64),
                            (:onemklSnrm2, :Float32, :Float32),
                            (:onemklCnrm2, :ComplexF32, :Float32),
                            (:onemklZnrm2, :ComplexF64, :Float64))
    @eval function nrm2(n::Integer, x::oneStridedArray{$elt})
        q = global_queue(context(x), device(x))
        # One-element `oneArray` that the wrapper writes the norm into.
        out = oneArray{$ret}([0])
        $wrapper(sycl_queue(q), n, x, stride(x, 1), out)
        # Materialise the buffer as a host `Array` and return its only entry.
        return Array(out)[1]
    end
end

#
# BLAS
Expand Down
17 changes: 12 additions & 5 deletions test/onemkl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,17 @@ m = 20
oneMKL.copy!(m,A,B)
@test Array(A) == Array(B)

# testing oneMKL max and min
a = convert.(T, [1.0, 2.0, -0.8, 5.0, 3.0])
ca = oneArray(a)
@test BLAS.iamax(a) == oneMKL.iamax(ca)
@test oneMKL.iamin(ca) == 3
@testset "nrm2" begin
    # Exercise the nrm2 primitive through `norm` on a random length-`m` vector.
    # NOTE(review): `testf` is presumably the helper that runs the function on
    # both host and device inputs and compares the results -- confirm.
    host_vec = rand(T, m)
    @test testf(norm, host_vec)
end

@testset "iamax/iamin" begin
    # Index of the largest / smallest magnitude entry.
    host_vals = T[1.0, 2.0, -0.8, 5.0, 3.0]
    dev_vals = oneArray(host_vals)
    # iamax must agree with the CPU BLAS reference on the same data.
    @test BLAS.iamax(host_vals) == oneMKL.iamax(dev_vals)
    # The smallest magnitude (0.8) sits at 1-based position 3.
    @test oneMKL.iamin(dev_vals) == 3
end
end # level 1 testset
end