From 0b3955a78b3ba5f2d8982f871753a063f549e783 Mon Sep 17 00:00:00 2001 From: Alexis Montoison <35051714+amontoison@users.noreply.github.com> Date: Mon, 2 Dec 2024 22:05:13 -0600 Subject: [PATCH] [oneMKL] Fix gesvd! (#485) --- lib/mkl/wrappers_lapack.jl | 23 ++++++++++++----------- test/onemkl.jl | 10 ++++++++++ 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/lib/mkl/wrappers_lapack.jl b/lib/mkl/wrappers_lapack.jl index f6a64d0b..177c2510 100644 --- a/lib/mkl/wrappers_lapack.jl +++ b/lib/mkl/wrappers_lapack.jl @@ -304,30 +304,31 @@ for (bname, fname, elty, relty) in ((:onemklSgesvd_scratchpad_size, :onemklSgesv jobvt::Char, A::oneStridedMatrix{$elty}) m, n = size(A) + k = min(m, n) lda = max(1, stride(A, 2)) U = if jobu === 'A' oneMatrix{$elty}(undef, m, m) - elseif jobu == 'S' || jobu === 'O' - oneMatrix{$elty}(undef, m, min(m, n)) - elseif jobu === 'N' - oneMatrix{$elty}(undef, 0, 0) # Equivalence of CU_NULL? + elseif jobu === 'S' + oneMatrix{$elty}(undef, m, k) + elseif jobu === 'N' || jobu === 'O' + ZE_NULL else error("jobu must be one of 'A', 'S', 'O', or 'N'") end - ldu = U == oneMatrix{$elty}(undef, 0, 0) ? 1 : max(1, stride(U, 2)) - S = oneVector{$relty}(undef, min(m, n)) + ldu = U == ZE_NULL ? 1 : max(1, stride(U, 2)) + S = oneVector{$relty}(undef, k) Vt = if jobvt === 'A' oneMatrix{$elty}(undef, n, n) - elseif jobvt === 'S' || jobvt === 'O' - oneMatrix{$elty}(undef, min(m, n), n) - elseif jobvt === 'N' - oneMatrix{$elty}(undef, 0, 0) + elseif jobvt === 'S' + oneMatrix{$elty}(undef, k, n) + elseif jobvt === 'N' || jobvt === 'O' + ZE_NULL else error("jobvt must be one of 'A', 'S', 'O', or 'N'") end - ldvt = Vt == oneArray{$elty}(undef, 0, 0) ? 1 : max(1, stride(Vt, 2)) + ldvt = Vt == ZE_NULL ? 1 : max(1, stride(Vt, 2)) queue = global_queue(context(A), device()) scratchpad_size = $bname(sycl_queue(queue), jobu, jobvt, m, n, lda, ldu, ldvt) diff --git a/test/onemkl.jl b/test/onemkl.jl index 19f06926..50651584 100644 --- a/test/onemkl.jl +++ b/test/onemkl.jl @@ -1421,6 +1421,16 @@ end d_A = oneMatrix(A) U, Σ, Vt = oneMKL.gesvd!('A', 'A', d_A) @test A ≈ collect(U[:,1:n] * Diagonal(Σ) * Vt) + + for jobu in ('A', 'S', 'N', 'O') + for jobvt in ('A', 'S', 'N', 'O') + (jobu == 'A') && (jobvt == 'A') && continue + (jobu == 'O') && (jobvt == 'O') && continue + d_A = oneMatrix(A) + U2, Σ2, Vt2 = oneMKL.gesvd!(jobu, jobvt, d_A) + @test Σ ≈ Σ2 + end + end end @testset "syevd! -- heevd!" begin