Skip to content

Commit

Permalink
add const char TRANS to spinsph2fourier
Browse files Browse the repository at this point in the history
  • Loading branch information
MikaelSlevinsky committed Feb 10, 2021
1 parent 5dd9239 commit 4a30369
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 36 deletions.
4 changes: 2 additions & 2 deletions examples/spinweighted.c
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ int main(void) {
ft_spinsphere_fftw_plan * PA = ft_plan_spinsph_analysis(N, M, 0);

ft_execute_spinsph_analysis(PA, F, N, M);
ft_execute_fourier2spinsph(P, F, N, M);
ft_execute_fourier2spinsph('N', P, F, N, M);

printf("Its spin-0 spherical harmonic coefficients are:\n\n");

Expand Down Expand Up @@ -109,7 +109,7 @@ int main(void) {
PA = ft_plan_spinsph_analysis(N, M, 1);

ft_execute_spinsph_analysis(PA, F, N, M);
ft_execute_fourier2spinsph(P, F, N, M);
ft_execute_fourier2spinsph('N', P, F, N, M);

printmat("U¹sampling", FMT, (double *) F, 2*N, M);
printf("\n");
Expand Down
82 changes: 58 additions & 24 deletions src/drivers.c
Original file line number Diff line number Diff line change
Expand Up @@ -1146,36 +1146,70 @@ ft_spin_harmonic_plan * ft_plan_spinsph2fourier(const int n, const int s) {
return P;
}

void ft_execute_spinsph2fourier(const ft_spin_harmonic_plan * P, ft_complex * A, const int N, const int M) {
ft_execute_spinsph_hi2lo(P->SRP, A, P->B, M);
void ft_execute_spinsph2fourier(const char TRANS, const ft_spin_harmonic_plan * P, ft_complex * A, const int N, const int M) {
ft_complex alpha = {1.0, 0.0};
if (P->s%2 == 0) {
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+3)/4, &alpha, P->P1, N, A, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+2)/4, &alpha, P->P2, N, A+N, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+1)/4, &alpha, P->P2, N, A+2*N, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, M/4, &alpha, P->P1, N, A+3*N, 4*N);
if (TRANS == 'N') {
ft_execute_spinsph_hi2lo(P->SRP, A, P->B, M);
if (P->s%2 == 0) {
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+3)/4, &alpha, P->P1, N, A, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+2)/4, &alpha, P->P2, N, A+N, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+1)/4, &alpha, P->P2, N, A+2*N, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, M/4, &alpha, P->P1, N, A+3*N, 4*N);
}
else {
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+3)/4, &alpha, P->P2, N, A, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+2)/4, &alpha, P->P1, N, A+N, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+1)/4, &alpha, P->P1, N, A+2*N, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, M/4, &alpha, P->P2, N, A+3*N, 4*N);
}
}
else {
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+3)/4, &alpha, P->P2, N, A, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+2)/4, &alpha, P->P1, N, A+N, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+1)/4, &alpha, P->P1, N, A+2*N, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, M/4, &alpha, P->P2, N, A+3*N, 4*N);
else if (TRANS == 'T') {
if (P->s%2 == 0) {
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, (M+3)/4, &alpha, P->P1, N, A, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, (M+2)/4, &alpha, P->P2, N, A+N, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, (M+1)/4, &alpha, P->P2, N, A+2*N, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, M/4, &alpha, P->P1, N, A+3*N, 4*N);
}
else {
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, (M+3)/4, &alpha, P->P2, N, A, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, (M+2)/4, &alpha, P->P1, N, A+N, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, (M+1)/4, &alpha, P->P1, N, A+2*N, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, M/4, &alpha, P->P2, N, A+3*N, 4*N);
}
ft_execute_spinsph_lo2hi(P->SRP, A, P->B, M);
}
}

void ft_execute_fourier2spinsph(const ft_spin_harmonic_plan * P, ft_complex * A, const int N, const int M) {
void ft_execute_fourier2spinsph(const char TRANS, const ft_spin_harmonic_plan * P, ft_complex * A, const int N, const int M) {
ft_complex alpha = {1.0, 0.0};
if (P->s%2 == 0) {
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+3)/4, &alpha, P->P1inv, N, A, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+2)/4, &alpha, P->P2inv, N, A+N, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+1)/4, &alpha, P->P2inv, N, A+2*N, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, M/4, &alpha, P->P1inv, N, A+3*N, 4*N);
if (TRANS == 'N') {
if (P->s%2 == 0) {
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+3)/4, &alpha, P->P1inv, N, A, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+2)/4, &alpha, P->P2inv, N, A+N, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+1)/4, &alpha, P->P2inv, N, A+2*N, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, M/4, &alpha, P->P1inv, N, A+3*N, 4*N);
}
else {
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+3)/4, &alpha, P->P2inv, N, A, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+2)/4, &alpha, P->P1inv, N, A+N, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+1)/4, &alpha, P->P1inv, N, A+2*N, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, M/4, &alpha, P->P2inv, N, A+3*N, 4*N);
}
ft_execute_spinsph_lo2hi(P->SRP, A, P->B, M);
}
else {
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+3)/4, &alpha, P->P2inv, N, A, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+2)/4, &alpha, P->P1inv, N, A+N, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, (M+1)/4, &alpha, P->P1inv, N, A+2*N, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, M/4, &alpha, P->P2inv, N, A+3*N, 4*N);
else if (TRANS == 'T') {
ft_execute_spinsph_hi2lo(P->SRP, A, P->B, M);
if (P->s%2 == 0) {
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, (M+3)/4, &alpha, P->P1inv, N, A, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, (M+2)/4, &alpha, P->P2inv, N, A+N, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, (M+1)/4, &alpha, P->P2inv, N, A+2*N, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, M/4, &alpha, P->P1inv, N, A+3*N, 4*N);
}
else {
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, (M+3)/4, &alpha, P->P2inv, N, A, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, (M+2)/4, &alpha, P->P1inv, N, A+N, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, (M+1)/4, &alpha, P->P1inv, N, A+2*N, 4*N);
cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, M/4, &alpha, P->P2inv, N, A+3*N, 4*N);
}
}
ft_execute_spinsph_lo2hi(P->SRP, A, P->B, M);
}
4 changes: 2 additions & 2 deletions src/fasttransforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -403,9 +403,9 @@ void ft_destroy_spin_harmonic_plan(ft_spin_harmonic_plan * P);
ft_spin_harmonic_plan * ft_plan_spinsph2fourier(const int n, const int s);

/// Transform a spin-weighted spherical harmonic expansion to a bivariate Fourier series.
void ft_execute_spinsph2fourier(const ft_spin_harmonic_plan * P, ft_complex * A, const int N, const int M);
void ft_execute_spinsph2fourier(const char TRANS, const ft_spin_harmonic_plan * P, ft_complex * A, const int N, const int M);
/// Transform a bivariate Fourier series to a spin-weighted spherical harmonic expansion.
void ft_execute_fourier2spinsph(const ft_spin_harmonic_plan * P, ft_complex * A, const int N, const int M);
void ft_execute_fourier2spinsph(const char TRANS, const ft_spin_harmonic_plan * P, ft_complex * A, const int N, const int M);


int ft_fftw_init_threads(void);
Expand Down
11 changes: 7 additions & 4 deletions test/test_drivers.c
Original file line number Diff line number Diff line change
Expand Up @@ -1453,8 +1453,11 @@ int main(int argc, const char * argv[]) {
B = (double *) BC;
SP = ft_plan_spinsph2fourier(N, S);

ft_execute_spinsph2fourier(SP, AC, N, M);
ft_execute_fourier2spinsph(SP, AC, N, M);
ft_execute_spinsph2fourier('N', SP, AC, N, M);
ft_execute_fourier2spinsph('N', SP, AC, N, M);

ft_execute_spinsph2fourier('T', SP, AC, N, M);
ft_execute_fourier2spinsph('T', SP, AC, N, M);

err = ft_norm_2arg(A, B, 2*N*M)/ft_norm_1arg(B, 2*N*M);
printf("ϵ_2 \t\t\t (N×M, S) = (%5ix%5i,%3i): \t |%20.2e ", N, M, S, err);
Expand All @@ -1481,10 +1484,10 @@ int main(int argc, const char * argv[]) {
ft_complex * AC = spinsphrand(N, M, S);
SP = ft_plan_spinsph2fourier(N, S);

FT_TIME(ft_execute_spinsph2fourier(SP, AC, N, M), start, end, NTIMES)
FT_TIME(ft_execute_spinsph2fourier('N', SP, AC, N, M), start, end, NTIMES)
printf("%d %.6f", S, elapsed(&start, &end, NTIMES));

FT_TIME(ft_execute_fourier2spinsph(SP, AC, N, M), start, end, NTIMES)
FT_TIME(ft_execute_fourier2spinsph('N', SP, AC, N, M), start, end, NTIMES)
printf(" %.6f\n", elapsed(&start, &end, NTIMES));

free(AC);
Expand Down
8 changes: 4 additions & 4 deletions test/test_fftw.c
Original file line number Diff line number Diff line change
Expand Up @@ -405,10 +405,10 @@ int main(int argc, const char * argv[]) {
US = ft_plan_spinsph_synthesis(N, M, S);
UA = ft_plan_spinsph_analysis(N, M, S);

ft_execute_spinsph2fourier(SP, AC, N, M);
ft_execute_spinsph2fourier('N', SP, AC, N, M);
ft_execute_spinsph_synthesis(US, AC, N, M);
ft_execute_spinsph_analysis(UA, AC, N, M);
ft_execute_fourier2spinsph(SP, AC, N, M);
ft_execute_fourier2spinsph('N', SP, AC, N, M);

err = ft_norm_2arg(A, B, 2*N*M)/ft_norm_1arg(B, 2*N*M);
printf("ϵ_2 \t\t\t (N×M, S) = (%5ix%5i,%3i): \t |%20.2e ", N, M, S, err);
Expand Down Expand Up @@ -438,10 +438,10 @@ int main(int argc, const char * argv[]) {
US = ft_plan_spinsph_synthesis(N, M, S);
UA = ft_plan_spinsph_analysis(N, M, S);

FT_TIME({ft_execute_spinsph2fourier(SP, AC, N, M); ft_execute_spinsph_synthesis(US, AC, N, M);}, start, end, NTIMES)
FT_TIME({ft_execute_spinsph2fourier('N', SP, AC, N, M); ft_execute_spinsph_synthesis(US, AC, N, M);}, start, end, NTIMES)
printf("%d %.6f", N, elapsed(&start, &end, NTIMES));

FT_TIME({ft_execute_spinsph_analysis(UA, AC, N, M); ft_execute_fourier2spinsph(SP, AC, N, M);}, start, end, NTIMES)
FT_TIME({ft_execute_spinsph_analysis(UA, AC, N, M); ft_execute_fourier2spinsph('N', SP, AC, N, M);}, start, end, NTIMES)
printf(" %.6f\n", elapsed(&start, &end, NTIMES));

free(AC);
Expand Down

0 comments on commit 4a30369

Please sign in to comment.