Skip to content

Commit

Permalink
add any rank support for softmax and logsoftmax
Browse files Browse the repository at this point in the history
  • Loading branch information
jalvesz committed Sep 29, 2024
1 parent bc2bf5a commit 5c47bf0
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 92 deletions.
23 changes: 23 additions & 0 deletions include/common.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,29 @@ ${prefix + joinstr.join([line.strip() for line in txt.split("\n")]) + suffix}$
#:endif
#:enddef

#! Brace enclosed, comma separated Fortran expressions for a shape.
#!
#! It defines an output variable with the same shape as the input variable.
#!
#! Args:
#! varname (str): Name of the variable to be used as origin
#! origrank (int): Rank of the original variable
#!
#! Returns:
#! Shape expression enclosed in braces, so that it can be used as suffix to
#! define array shapes in declarations.
#!
#:def shape(varname, origrank)
#:assert origrank > 0
#:if origrank > 1
#:call join_lines(joinstr=", ", prefix="(", suffix=")")
#:for i in range(1, origrank+1)
size(${varname}$, ${i}$)
#:endfor
#:endcall
#:endif
#:enddef


#! Generates a routine name from a generic name, rank, type and kind
#!
Expand Down
53 changes: 30 additions & 23 deletions src/stdlib_specialfunctions.fypp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#:include "common.fypp"
#:set RANKS = range(2, MAXRANK + 1)
module stdlib_specialfunctions
use stdlib_kinds, only: int8, int16, int32, int64, sp, dp, xdp, qp

Expand Down Expand Up @@ -271,26 +272,19 @@ module stdlib_specialfunctions
!!
!! Softmax function. Available for ranks 1 to 4
#:for rk, rt in REAL_KINDS_TYPES
pure module function Softmax_r1_${rk}$( x ) result( y )
pure module function Softmax_r1_${rk}$( x , dim ) result( y )
${rt}$, intent(in) :: x(:)
${rt}$ :: y(size(x))
end function
pure module function Softmax_r2_${rk}$( x , dim ) result( y )
${rt}$, intent(in) :: x(:,:)
${rt}$ :: y(size(x,dim=1),size(x,dim=2))
integer, intent(in), optional :: dim
end function
pure module function Softmax_r3_${rk}$( x , dim ) result( y )
${rt}$, intent(in) :: x(:,:,:)
${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3))
integer, intent(in), optional :: dim
end function
pure module function Softmax_r4_${rk}$( x , dim ) result( y )
${rt}$, intent(in) :: x(:,:,:,:)
${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3),size(x,dim=4))
#:for rank in RANKS
pure module function Softmax_r${rank}$_${rk}$( x , dim ) result( y )
${rt}$, intent(in) :: x${ranksuffix(rank)}$
${rt}$ :: y${shape('x', rank)}$
integer, intent(in), optional :: dim
end function
#:endfor
#:endfor
end interface
public :: softmax

Expand All @@ -303,24 +297,37 @@ module stdlib_specialfunctions
${rt}$, intent(in) :: x(:)
${rt}$ :: y(size(x))
end function
pure module function Softmax_grad_r2_${rk}$( x , dim ) result( y )
${rt}$, intent(in) :: x(:,:)
${rt}$ :: y(size(x,dim=1),size(x,dim=2))
#:for rank in RANKS
pure module function Softmax_grad_r${rank}$_${rk}$( x , dim ) result( y )
${rt}$, intent(in) :: x${ranksuffix(rank)}$
${rt}$ :: y${shape('x', rank)}$
integer, intent(in), optional :: dim
end function
pure module function Softmax_grad_r3_${rk}$( x , dim ) result( y )
${rt}$, intent(in) :: x(:,:,:)
${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3))
#:endfor
#:endfor
end interface
public :: Softmax_grad

interface LogSoftmax
!! Version: experimental
!!
!! Softmax function. Available for ranks 1 to 4
#:for rk, rt in REAL_KINDS_TYPES
pure module function LogSoftmax_r1_${rk}$( x, dim ) result( y )
${rt}$, intent(in) :: x(:)
${rt}$ :: y(size(x))
integer, intent(in), optional :: dim
end function
pure module function Softmax_grad_r4_${rk}$( x , dim ) result( y )
${rt}$, intent(in) :: x(:,:,:,:)
${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3),size(x,dim=4))
#:for rank in RANKS
pure module function LogSoftmax_r${rank}$_${rk}$( x , dim ) result( y )
${rt}$, intent(in) :: x${ranksuffix(rank)}$
${rt}$ :: y${shape('x', rank)}$
integer, intent(in), optional :: dim
end function
#:endfor
#:endfor
end interface
public :: Softmax_grad
public :: LogSoftmax

interface Softplus
!! Version: experimental
Expand Down
130 changes: 61 additions & 69 deletions src/stdlib_specialfunctions_activations.fypp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#:include "common.fypp"
#:set RANKS = range(2, MAXRANK + 1)
submodule(stdlib_specialfunctions) stdlib_specialfunctions_activations
implicit none

Expand Down Expand Up @@ -192,73 +193,44 @@ end function
! Softmax
!==================================================
#:for rk, rt in REAL_KINDS_TYPES
pure module function Softmax_r1_${rk}$( x ) result( y )
pure module function Softmax_r1_${rk}$( x , dim ) result( y )
${rt}$, intent(in) :: x(:)
${rt}$ :: y(size(x))
integer, intent(in), optional :: dim

y = exp(x - maxval(x))
y = y / sum(y)
end function

pure module function Softmax_r2_${rk}$( x , dim ) result( y )
${rt}$, intent(in) :: x(:,:)
${rt}$ :: y(size(x,dim=1),size(x,dim=2))
#:for rank in RANKS
pure module function Softmax_r${rank}$_${rk}$( x , dim ) result( y )
${rt}$, intent(in) :: x${ranksuffix(rank)}$
${rt}$ :: y${shape('x', rank)}$

integer, intent(in), optional :: dim
integer :: dim_, j

dim_ = 1; if(present(dim)) dim_ = dim

if(dim_==1)then
do j = 1, size(x,dim=2)
y(:,j) = Softmax( x(:,j) )
if(dim_<${rank}$)then
do j = 1, size(x,dim=${rank}$)
#:if rank == 2
y${select_subarray(rank, [(rank, 'j')])}$ = Softmax( x${select_subarray(rank, [(rank, 'j')])}$ )
#:else
y${select_subarray(rank, [(rank, 'j')])}$ = Softmax( x${select_subarray(rank, [(rank, 'j')])}$, dim=dim_ )
#:endif
end do
else
do j = 1, size(x,dim=1)
y(j,:) = Softmax( x(j,:) )
end do
end if
end function

pure module function Softmax_r3_${rk}$( x , dim ) result( y )
${rt}$, intent(in) :: x(:,:,:)
${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3))

integer, intent(in), optional :: dim
integer :: dim_, j

dim_ = 1; if(present(dim)) dim_ = dim

if(dim_<=2)then
do j = 1, size(x,dim=3)
y(:,:,j) = Softmax( x(:,:,j) , dim = dim_ )
end do
else
do j = 1, size(x,dim=1)
y(j,:,:) = Softmax( x(j,:,:) , dim = 2 )
end do
end if
end function

pure module function Softmax_r4_${rk}$( x , dim ) result( y )
${rt}$, intent(in) :: x(:,:,:,:)
${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3),size(x,dim=4))

integer, intent(in), optional :: dim
integer :: dim_, j

dim_ = 1; if(present(dim)) dim_ = dim

if(dim_<=3)then
do j = 1, size(x,dim=4)
y(:,:,:,j) = Softmax( x(:,:,:,j) , dim = dim_ )
end do
else
do j = 1, size(x,dim=1)
y(j,:,:,:) = Softmax( x(j,:,:,:) , dim = 3 )
#:if rank == 2
y${select_subarray(rank, [(1, 'j')])}$ = Softmax( x${select_subarray(rank, [(1, 'j')])}$ )
#:else
y${select_subarray(rank, [(1, 'j')])}$ = Softmax( x${select_subarray(rank, [(1, 'j')])}$, dim=${rank-1}$ )
#:endif
end do
end if
end function
#:endfor

pure module function Softmax_grad_r1_${rk}$( x ) result( y )
${rt}$, intent(in) :: x(:)
Expand All @@ -268,9 +240,10 @@ pure module function Softmax_grad_r1_${rk}$( x ) result( y )
y = y * (1._${rk}$ - y)
end function

pure module function Softmax_grad_r2_${rk}$( x , dim ) result( y )
${rt}$, intent(in) :: x(:,:)
${rt}$ :: y(size(x,dim=1),size(x,dim=2))
#:for rank in RANKS
pure module function Softmax_grad_r${rank}$_${rk}$( x , dim ) result( y )
${rt}$, intent(in) :: x${ranksuffix(rank)}$
${rt}$ :: y${shape('x', rank)}$

integer, intent(in), optional :: dim
integer :: dim_
Expand All @@ -280,32 +253,51 @@ pure module function Softmax_grad_r2_${rk}$( x , dim ) result( y )
y = Softmax(x,dim_)
y = y * (1._${rk}$ - y)
end function
#:endfor

pure module function Softmax_grad_r3_${rk}$( x , dim ) result( y )
${rt}$, intent(in) :: x(:,:,:)
${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3))

integer, intent(in), optional :: dim
integer :: dim_
#:endfor

dim_ = 1; if(present(dim)) dim_ = dim

y = Softmax(x,dim_)
y = y * (1._${rk}$ - y)
!==================================================
! LogSoftmax
!==================================================
#:for rk, rt in REAL_KINDS_TYPES
pure module function LogSoftmax_r1_${rk}$( x, dim ) result( y )
${rt}$, intent(in) :: x(:)
${rt}$ :: y(size(x))
integer, intent(in), optional :: dim
y = x - maxval(x)
y = y - log( sum(exp(y)) )
end function

pure module function Softmax_grad_r4_${rk}$( x , dim ) result( y )
${rt}$, intent(in) :: x(:,:,:,:)
${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3),size(x,dim=4))

#:for rank in RANKS
pure module function LogSoftmax_r${rank}$_${rk}$( x , dim ) result( y )
${rt}$, intent(in) :: x${ranksuffix(rank)}$
${rt}$ :: y${shape('x', rank)}$

integer, intent(in), optional :: dim
integer :: dim_
integer :: dim_, j

dim_ = 1; if(present(dim)) dim_ = dim

y = Softmax(x,dim_)
y = y * (1._${rk}$ - y)

if(dim_<${rank}$)then
do j = 1, size(x,dim=${rank}$)
#:if rank == 2
y${select_subarray(rank, [(rank, 'j')])}$ = LogSoftmax( x${select_subarray(rank, [(rank, 'j')])}$ )
#:else
y${select_subarray(rank, [(rank, 'j')])}$ = LogSoftmax( x${select_subarray(rank, [(rank, 'j')])}$, dim=dim_ )
#:endif
end do
else
do j = 1, size(x,dim=1)
#:if rank == 2
y${select_subarray(rank, [(1, 'j')])}$ = LogSoftmax( x${select_subarray(rank, [(1, 'j')])}$ )
#:else
y${select_subarray(rank, [(1, 'j')])}$ = LogSoftmax( x${select_subarray(rank, [(1, 'j')])}$, dim=${rank-1}$ )
#:endif
end do
end if
end function
#:endfor

#:endfor

Expand Down
71 changes: 71 additions & 0 deletions test/specialfunctions/test_specialfunctions_activations.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ contains

testsuite = [ &
new_unittest("sigmoid", test_sigmoid), &
new_unittest("logsoftmax", test_logsoftmax), &
new_unittest("gelu" , test_gelu ), &
new_unittest("softmax", test_softmax) &
]
Expand Down Expand Up @@ -134,6 +135,76 @@ contains

end subroutine test_softmax

subroutine test_logsoftmax(error)
type(error_type), allocatable, intent(out) :: error

real(sp) :: x(3,3,3), y(3,3,3), y_ref(3,3,3)

x = reshape( [ 0.755308866500854,-0.789980888366699, 0.88806813955307 ,&
-1.210636496543884, 0.746919095516205, 0.177668794989586,&
0.540819883346558, 0.291532933712006,-0.324642956256866,&

1.94184136390686 , 0.951070547103882,-2.303410291671753,&
0.59752631187439 , 1.189722180366516, 1.401878595352173,&
-0.262732744216919, 0.421907186508179,-0.200457707047462,&

-0.702468574047089, 0.153426378965378, 0.330110251903534,&
-1.16956090927124 ,-0.845042765140533,-1.364316940307617,&
-1.679381608963013,-1.497506022453308,-1.194215059280396 ] ,[3,3,3] )

!> LogSoftmax on dim = 1
y = LogSoftmax(x,dim=1)

y_ref = reshape( [ -0.856636286,-2.40192604,-0.723877013,&
-2.49238253,-0.534826934,-1.10407722 ,&
-0.788554132,-1.03784108,-1.65401697 ,&

-0.326149583,-1.31692040,-4.57140112 ,&
-1.61804128,-1.02584541,-0.813688993 ,&
-1.39805317,-0.713413179,-1.33577800 ,&

-1.81836534,-0.962470412,-0.785786569,&
-1.16514850,-0.840630412,-1.35990453 ,&
-1.34127355,-1.15939808,-0.856107056 ],[3,3,3] )

!> LogSoftmax on dim = 2
y = LogSoftmax(x,dim=2)

y_ref = reshape( [ -0.666278005,-2.15167999, -0.581566215,&
-2.63222337 ,-0.614779949,-1.29196548 ,&
-0.880766988,-1.07016611,-1.79427731 ,&

-0.315551817,-1.05034387,-3.90906072 ,&
-1.65986681 ,-0.811692238,-0.203771874,&
-2.52012587 ,-1.57950723 ,-1.80610812 ,&

-0.694792688,-0.444887042,-0.337523341,&
-1.16188502 ,-1.44335616 ,-2.03195047 ,&
-1.67170572 ,-2.09581947 ,-1.86184871 ],[3,3,3] )

call check(error, norm2(y-y_ref) < tol_sp )
if (allocated(error)) return

!> LogSoftmax on dim = 3
y = LogSoftmax(x,dim=3)

y_ref = reshape( [ -1.50595474 , -2.22700500 ,-0.478398114,&
-2.09693313 , -1.01544499 ,-1.52940571 ,&
-0.442325860, -0.835677147,-0.936625183,&

-0.319422185, -0.485953659,-3.66987658 ,&
-0.288770229, -0.572641909,-0.305195898,&
-1.24587846 , -0.705302894,-0.812439919,&

-2.96373224 , -1.28359783 ,-1.03635597 ,&
-2.05585742 , -2.60740685 ,-3.07139134 ,&
-2.66252732 , -2.62471604 ,-1.80619729 ],[3,3,3] )

call check(error, norm2(y-y_ref) < tol_sp )
if (allocated(error)) return

end subroutine test_logsoftmax


end module test_specialfunctions_activation

Expand Down

0 comments on commit 5c47bf0

Please sign in to comment.