From 4a303692e716e5746cb0c42c1c2095cd00621c36 Mon Sep 17 00:00:00 2001 From: MikaelSlevinsky Date: Tue, 9 Feb 2021 21:12:12 -0600 Subject: [PATCH] add const char TRANS to spinsph2fourier --- examples/spinweighted.c | 4 +- src/drivers.c | 82 +++++++++++++++++++++++++++++------------ src/fasttransforms.h | 4 +- test/test_drivers.c | 11 ++++-- test/test_fftw.c | 8 ++-- 5 files changed, 73 insertions(+), 36 deletions(-) diff --git a/examples/spinweighted.c b/examples/spinweighted.c index e25ae206..9a3577a0 100644 --- a/examples/spinweighted.c +++ b/examples/spinweighted.c @@ -71,7 +71,7 @@ int main(void) { ft_spinsphere_fftw_plan * PA = ft_plan_spinsph_analysis(N, M, 0); ft_execute_spinsph_analysis(PA, F, N, M); - ft_execute_fourier2spinsph(P, F, N, M); + ft_execute_fourier2spinsph('N', P, F, N, M); printf("Its spin-0 spherical harmonic coefficients are:\n\n"); @@ -109,7 +109,7 @@ int main(void) { PA = ft_plan_spinsph_analysis(N, M, 1); ft_execute_spinsph_analysis(PA, F, N, M); - ft_execute_fourier2spinsph(P, F, N, M); + ft_execute_fourier2spinsph('N', P, F, N, M); printmat("U¹sampling", FMT, (double *) F, 2*N, M); printf("\n"); diff --git a/src/drivers.c b/src/drivers.c index 9c5bea68..5b07850c 100644 --- a/src/drivers.c +++ b/src/drivers.c @@ -1146,36 +1146,70 @@ ft_spin_harmonic_plan * ft_plan_spinsph2fourier(const int n, const int s) { return P; } -void ft_execute_spinsph2fourier(const ft_spin_harmonic_plan * P, ft_complex * A, const int N, const int M) { - ft_execute_spinsph_hi2lo(P->SRP, A, P->B, M); +void ft_execute_spinsph2fourier(const char TRANS, const ft_spin_harmonic_plan * P, ft_complex * A, const int N, const int M) { ft_complex alpha = {1.0, 0.0}; - if (P->s%2 == 0) { - cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+3)/4, &alpha, P->P1, N, A, 4*N); - cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+2)/4, &alpha, P->P2, N, A+N, 4*N); - cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+1)/4, &alpha, P->P2, N, A+2*N, 4*N); - cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, M/4, &alpha, P->P1, N, A+3*N, 4*N); + if (TRANS == 'N') { + ft_execute_spinsph_hi2lo(P->SRP, A, P->B, M); + if (P->s%2 == 0) { + cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+3)/4, &alpha, P->P1, N, A, 4*N); + cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+2)/4, &alpha, P->P2, N, A+N, 4*N); + cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+1)/4, &alpha, P->P2, N, A+2*N, 4*N); + cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, M/4, &alpha, P->P1, N, A+3*N, 4*N); + } + else { + cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+3)/4, &alpha, P->P2, N, A, 4*N); + cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+2)/4, &alpha, P->P1, N, A+N, 4*N); + cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+1)/4, &alpha, P->P1, N, A+2*N, 4*N); + cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, M/4, &alpha, P->P2, N, A+3*N, 4*N); + } } - else { - cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+3)/4, &alpha, P->P2, N, A, 4*N); - cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+2)/4, &alpha, P->P1, N, A+N, 4*N); - cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+1)/4, &alpha, P->P1, N, A+2*N, 4*N); - cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, M/4, &alpha, P->P2, N, A+3*N, 4*N); + else if (TRANS == 'T') { + if (P->s%2 == 0) { + cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, (M+3)/4, &alpha, P->P1, N, A, 4*N); + cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, (M+2)/4, &alpha, P->P2, N, A+N, 4*N); + cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, (M+1)/4, &alpha, P->P2, N, A+2*N, 4*N); + cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, M/4, &alpha, P->P1, N, A+3*N, 4*N); + } + else { + cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, (M+3)/4, &alpha, P->P2, N, A, 4*N); + cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, (M+2)/4, &alpha, P->P1, N, A+N, 4*N); + cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, (M+1)/4, &alpha, P->P1, N, A+2*N, 4*N); + cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, M/4, &alpha, P->P2, N, A+3*N, 4*N); + } + ft_execute_spinsph_lo2hi(P->SRP, A, P->B, M); } } -void ft_execute_fourier2spinsph(const ft_spin_harmonic_plan * P, ft_complex * A, const int N, const int M) { +void ft_execute_fourier2spinsph(const char TRANS, const ft_spin_harmonic_plan * P, ft_complex * A, const int N, const int M) { ft_complex alpha = {1.0, 0.0}; - if (P->s%2 == 0) { - cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+3)/4, &alpha, P->P1inv, N, A, 4*N); - cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+2)/4, &alpha, P->P2inv, N, A+N, 4*N); - cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+1)/4, &alpha, P->P2inv, N, A+2*N, 4*N); - cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, M/4, &alpha, P->P1inv, N, A+3*N, 4*N); + if (TRANS == 'N') { + if (P->s%2 == 0) { + cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+3)/4, &alpha, P->P1inv, N, A, 4*N); + cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+2)/4, &alpha, P->P2inv, N, A+N, 4*N); + cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+1)/4, &alpha, P->P2inv, N, A+2*N, 4*N); + cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, M/4, &alpha, P->P1inv, N, A+3*N, 4*N); + } + else { + cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+3)/4, &alpha, P->P2inv, N, A, 4*N); + cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+2)/4, &alpha, P->P1inv, N, A+N, 4*N); + cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+1)/4, &alpha, P->P1inv, N, A+2*N, 4*N); + cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, M/4, &alpha, P->P2inv, N, A+3*N, 4*N); + } + ft_execute_spinsph_lo2hi(P->SRP, A, P->B, M); } - else { - cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+3)/4, &alpha, P->P2inv, N, A, 4*N); - cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+2)/4, &alpha, P->P1inv, N, A+N, 4*N); - cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+1)/4, &alpha, P->P1inv, N, A+2*N, 4*N); - cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, M/4, &alpha, P->P2inv, N, A+3*N, 4*N); + else if (TRANS == 'T') { + ft_execute_spinsph_hi2lo(P->SRP, A, P->B, M); + if (P->s%2 == 0) { + cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, (M+3)/4, &alpha, P->P1inv, N, A, 4*N); + cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, (M+2)/4, &alpha, P->P2inv, N, A+N, 4*N); + cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, (M+1)/4, &alpha, P->P2inv, N, A+2*N, 4*N); + cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, M/4, &alpha, P->P1inv, N, A+3*N, 4*N); + } + else { + cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, (M+3)/4, &alpha, P->P2inv, N, A, 4*N); + cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, (M+2)/4, &alpha, P->P1inv, N, A+N, 4*N); + cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, (M+1)/4, &alpha, P->P1inv, N, A+2*N, 4*N); + cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, M/4, &alpha, P->P2inv, N, A+3*N, 4*N); + } } - ft_execute_spinsph_lo2hi(P->SRP, A, P->B, M); } diff --git a/src/fasttransforms.h b/src/fasttransforms.h index d01b457e..cd7a7160 100644 --- a/src/fasttransforms.h +++ b/src/fasttransforms.h @@ -403,9 +403,9 @@ void ft_destroy_spin_harmonic_plan(ft_spin_harmonic_plan * P); ft_spin_harmonic_plan * ft_plan_spinsph2fourier(const int n, const int s); /// Transform a spin-weighted spherical harmonic expansion to a bivariate Fourier series. -void ft_execute_spinsph2fourier(const ft_spin_harmonic_plan * P, ft_complex * A, const int N, const int M); +void ft_execute_spinsph2fourier(const char TRANS, const ft_spin_harmonic_plan * P, ft_complex * A, const int N, const int M); /// Transform a bivariate Fourier series to a spin-weighted spherical harmonic expansion. -void ft_execute_fourier2spinsph(const ft_spin_harmonic_plan * P, ft_complex * A, const int N, const int M); +void ft_execute_fourier2spinsph(const char TRANS, const ft_spin_harmonic_plan * P, ft_complex * A, const int N, const int M); int ft_fftw_init_threads(void); diff --git a/test/test_drivers.c b/test/test_drivers.c index 1c4df31e..a6baf251 100644 --- a/test/test_drivers.c +++ b/test/test_drivers.c @@ -1453,8 +1453,11 @@ int main(int argc, const char * argv[]) { B = (double *) BC; SP = ft_plan_spinsph2fourier(N, S); - ft_execute_spinsph2fourier(SP, AC, N, M); - ft_execute_fourier2spinsph(SP, AC, N, M); + ft_execute_spinsph2fourier('N', SP, AC, N, M); + ft_execute_fourier2spinsph('N', SP, AC, N, M); + + ft_execute_spinsph2fourier('T', SP, AC, N, M); + ft_execute_fourier2spinsph('T', SP, AC, N, M); err = ft_norm_2arg(A, B, 2*N*M)/ft_norm_1arg(B, 2*N*M); printf("ϵ_2 \t\t\t (N×M, S) = (%5ix%5i,%3i): \t |%20.2e ", N, M, S, err); @@ -1481,10 +1484,10 @@ int main(int argc, const char * argv[]) { ft_complex * AC = spinsphrand(N, M, S); SP = ft_plan_spinsph2fourier(N, S); - FT_TIME(ft_execute_spinsph2fourier(SP, AC, N, M), start, end, NTIMES) + FT_TIME(ft_execute_spinsph2fourier('N', SP, AC, N, M), start, end, NTIMES) printf("%d %.6f", S, elapsed(&start, &end, NTIMES)); - FT_TIME(ft_execute_fourier2spinsph(SP, AC, N, M), start, end, NTIMES) + FT_TIME(ft_execute_fourier2spinsph('N', SP, AC, N, M), start, end, NTIMES) printf(" %.6f\n", elapsed(&start, &end, NTIMES)); free(AC); diff --git a/test/test_fftw.c b/test/test_fftw.c index 7682e4e8..df358d19 100644 --- a/test/test_fftw.c +++ b/test/test_fftw.c @@ -405,10 +405,10 @@ int main(int argc, const char * argv[]) { US = ft_plan_spinsph_synthesis(N, M, S); UA = ft_plan_spinsph_analysis(N, M, S); - ft_execute_spinsph2fourier(SP, AC, N, M); + ft_execute_spinsph2fourier('N', SP, AC, N, M); ft_execute_spinsph_synthesis(US, AC, N, M); ft_execute_spinsph_analysis(UA, AC, N, M); - ft_execute_fourier2spinsph(SP, AC, N, M); + ft_execute_fourier2spinsph('N', SP, AC, N, M); err = ft_norm_2arg(A, B, 2*N*M)/ft_norm_1arg(B, 2*N*M); printf("ϵ_2 \t\t\t (N×M, S) = (%5ix%5i,%3i): \t |%20.2e ", N, M, S, err); @@ -438,10 +438,10 @@ int main(int argc, const char * argv[]) { US = ft_plan_spinsph_synthesis(N, M, S); UA = ft_plan_spinsph_analysis(N, M, S); - FT_TIME({ft_execute_spinsph2fourier(SP, AC, N, M); ft_execute_spinsph_synthesis(US, AC, N, M);}, start, end, NTIMES) + FT_TIME({ft_execute_spinsph2fourier('N', SP, AC, N, M); ft_execute_spinsph_synthesis(US, AC, N, M);}, start, end, NTIMES) printf("%d %.6f", N, elapsed(&start, &end, NTIMES)); - FT_TIME({ft_execute_spinsph_analysis(UA, AC, N, M); ft_execute_fourier2spinsph(SP, AC, N, M);}, start, end, NTIMES) + FT_TIME({ft_execute_spinsph_analysis(UA, AC, N, M); ft_execute_fourier2spinsph('N', SP, AC, N, M);}, start, end, NTIMES) printf(" %.6f\n", elapsed(&start, &end, NTIMES)); free(AC);