Skip to content

Commit

Permalink
include test for hyperbolic functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jalvesz committed Sep 29, 2023
1 parent c51ce49 commit c5140a0
Showing 1 changed file with 137 additions and 40 deletions.
177 changes: 137 additions & 40 deletions test/test_fast_math.f90
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ module test_fast_math
use fast_math
implicit none

logical :: verbose = .false. ! change me to .true. if you want to see the results
character (len=*), parameter :: fmt_cr = "(a10,*(f22.12))", fmt_er = "(a10,*(es22.4))"
logical :: verbose = .true. ! set to .false. to silence the per-test timing/accuracy printout
contains

!> Collect all exported unit tests
Expand All @@ -16,7 +15,8 @@ subroutine collect_suite(testsuite)
testsuite = [ &
new_unittest('fast_sum', test_fast_sum) , &
new_unittest('fast_dotp', test_fast_dotproduct) , &
new_unittest('fast_trig', test_fast_trigonometry) &
new_unittest('fast_trig', test_fast_trigonometry) , &
new_unittest('fast_hyper', test_fast_hyperbolic ) &
]
end subroutine

Expand All @@ -28,6 +28,7 @@ subroutine test_fast_sum(error)
integer, parameter :: n = 1e6, ncalc = 4, niter = 20
integer :: iter, i
real(dp) :: times(0:ncalc), times_tot(ncalc)
1 format(a10,': <time> = ',f9.4,' ns/eval, speed-up=',f5.2,'X, rel. error=',es16.4)
!====================================================================================
call random_seed()
block
Expand Down Expand Up @@ -58,10 +59,10 @@ subroutine test_fast_sum(error)
print *,""
print *,"================================================================"
print *," SUM on single precision values "
print "(/,a10,*(a22))", "", "sum_quad", "intrinsic", "fsum_pair", "fsum_chunk"
print fmt_cr,"Value: ", meanval(:)
print fmt_er,"Error: ", abs(err(:)) / abs(meanval(:))
print fmt_cr,"time : ", times_tot(1:ncalc) / niter
print 1, "sum_quad" , times_tot(1), times_tot(2)/times_tot(1), abs(err(1)) / abs(meanval(1))
print 1, "intrinsic" , times_tot(2), times_tot(2)/times_tot(2), abs(err(2)) / abs(meanval(2))
print 1, "fsum_pair" , times_tot(3), times_tot(2)/times_tot(3), abs(err(3)) / abs(meanval(3))
print 1, "fsum_chunk", times_tot(4), times_tot(2)/times_tot(4), abs(err(4)) / abs(meanval(4))
end if

call check(error, abs(err(ncalc)) / abs(meanval(1)) < tolerance &
Expand Down Expand Up @@ -101,10 +102,10 @@ subroutine test_fast_sum(error)
print *,""
print *,"================================================================"
print *," SUM on single precision values with a mask"
print "(/,a10,*(a22))", "", "sum_quad", "intrinsic", "fsum_pair", "fsum_chunk"
print fmt_cr,"Value: ", meanval(:)
print fmt_er,"Error: ", abs(err(:)) / abs(meanval(:))
print fmt_cr,"time : ", times_tot(1:ncalc) / niter
print 1, "sum_quad" , times_tot(1), times_tot(2)/times_tot(1), abs(err(1)) / abs(meanval(1))
print 1, "intrinsic" , times_tot(2), times_tot(2)/times_tot(2), abs(err(2)) / abs(meanval(2))
print 1, "fsum_pair" , times_tot(3), times_tot(2)/times_tot(3), abs(err(3)) / abs(meanval(3))
print 1, "fsum_chunk", times_tot(4), times_tot(2)/times_tot(4), abs(err(4)) / abs(meanval(4))
end if

call check(error, abs(err(ncalc)) / abs(meanval(1)) < tolerance &
Expand Down Expand Up @@ -140,10 +141,10 @@ subroutine test_fast_sum(error)
print *,""
print *,"================================================================"
print *," SUM on double precision values "
print "(/,a10,*(a22))", "", "sum_quad", "intrinsic", "fsum_pair", "fsum_chunk"
print fmt_cr,"Value: ", meanval(:)
print fmt_er,"Error: ", abs(err(:)) / abs(meanval(:))
print fmt_cr,"time : ", times_tot(1:ncalc) / niter
print 1, "sum_quad" , times_tot(1), times_tot(2)/times_tot(1), abs(err(1)) / abs(meanval(1))
print 1, "intrinsic" , times_tot(2), times_tot(2)/times_tot(2), abs(err(2)) / abs(meanval(2))
print 1, "fsum_pair" , times_tot(3), times_tot(2)/times_tot(3), abs(err(3)) / abs(meanval(3))
print 1, "fsum_chunk", times_tot(4), times_tot(2)/times_tot(4), abs(err(4)) / abs(meanval(4))
end if

call check(error, abs(err(ncalc)) / abs(meanval(1)) < tolerance &
Expand Down Expand Up @@ -183,10 +184,10 @@ subroutine test_fast_sum(error)
print *,""
print *,"================================================================"
print *," SUM on double precision values with a mask"
print "(/,a10,*(a22))", "", "sum_quad", "intrinsic", "fsum_pair", "fsum_chunk"
print fmt_cr,"Value: ", meanval(:)
print fmt_er,"Error: ", abs(err(:)) / abs(meanval(:))
print fmt_cr,"time : ", times_tot(1:ncalc) / niter
print 1, "sum_quad" , times_tot(1), times_tot(2)/times_tot(1), abs(err(1)) / abs(meanval(1))
print 1, "intrinsic" , times_tot(2), times_tot(2)/times_tot(2), abs(err(2)) / abs(meanval(2))
print 1, "fsum_pair" , times_tot(3), times_tot(2)/times_tot(3), abs(err(3)) / abs(meanval(3))
print 1, "fsum_chunk", times_tot(4), times_tot(2)/times_tot(4), abs(err(4)) / abs(meanval(4))
end if

call check(error, abs(err(ncalc)) / abs(meanval(1)) < tolerance &
Expand All @@ -204,6 +205,7 @@ subroutine test_fast_dotproduct(error)
integer, parameter :: n = 1e6, ncalc = 3, niter = 50
integer :: iter, i
real(dp) :: times(0:ncalc), times_tot(ncalc)
1 format(a10,': <time> = ',f9.4,' ns/eval, speed-up=',f5.2,'X, rel. error=',es16.4)
!====================================================================================
call random_seed()
block
Expand Down Expand Up @@ -233,10 +235,9 @@ subroutine test_fast_dotproduct(error)
print *,""
print *,"================================================================"
print *," dot product on single precision values "
print "(/,a10,*(a22))", "", "quad" , "intrinsic", "fprod"
print fmt_cr,"Value: ", meanval(:)
print fmt_er,"Error: ", abs(err(:)) / abs(meanval(:))
print fmt_cr,"time : ", times_tot(1:ncalc) / niter
print 1, "dot_quad" , times_tot(1), times_tot(2)/times_tot(1), abs(err(1)) / abs(meanval(1))
print 1, "intrinsic" , times_tot(2), times_tot(2)/times_tot(2), abs(err(2)) / abs(meanval(2))
print 1, "fprod" , times_tot(3), times_tot(2)/times_tot(3), abs(err(3)) / abs(meanval(3))
end if

call check(error, abs(err(ncalc)) / abs(meanval(1)) < tolerance &
Expand Down Expand Up @@ -271,10 +272,9 @@ subroutine test_fast_dotproduct(error)
print *,""
print *,"================================================================"
print *," dot product on double precision values "
print "(/,a10,*(a22))", "", "quad" , "intrinsic", "fprod"
print fmt_cr,"Value: ", meanval(:)
print fmt_er,"Error: ", abs(err(:)) / abs(meanval(:))
print fmt_cr,"time : ", times_tot(1:ncalc) / niter
print 1, "dot_quad" , times_tot(1), times_tot(2)/times_tot(1), abs(err(1)) / abs(meanval(1))
print 1, "intrinsic" , times_tot(2), times_tot(2)/times_tot(2), abs(err(2)) / abs(meanval(2))
print 1, "fprod" , times_tot(3), times_tot(2)/times_tot(3), abs(err(3)) / abs(meanval(3))
end if

call check(error, abs(err(ncalc)) / abs(meanval(1)) < tolerance &
Expand Down Expand Up @@ -309,10 +309,9 @@ subroutine test_fast_dotproduct(error)
print *,""
print *,"================================================================"
print *," weigthed dot product on single precision values "
print "(/,a10,*(a22))", "", "quad" , "intrinsic", "fprod"
print fmt_cr,"Value: ", meanval(:)
print fmt_er,"Error: ", abs(err(:)) / abs(meanval(:))
print fmt_cr,"time : ", times_tot(1:ncalc) / niter
print 1, "dot_quad" , times_tot(1), times_tot(2)/times_tot(1), abs(err(1)) / abs(meanval(1))
print 1, "intrinsic" , times_tot(2), times_tot(2)/times_tot(2), abs(err(2)) / abs(meanval(2))
print 1, "fprod" , times_tot(3), times_tot(2)/times_tot(3), abs(err(3)) / abs(meanval(3))
end if

call check(error, abs(err(ncalc)) / abs(meanval(1)) < tolerance &
Expand Down Expand Up @@ -347,10 +346,9 @@ subroutine test_fast_dotproduct(error)
print *,""
print *,"================================================================"
print *," weigthed dot product on double precision values "
print "(/,a10,*(a22))", "", "quad" , "intrinsic", "fprod"
print fmt_cr,"Value: ", meanval(:)
print fmt_er,"Error: ", abs(err(:)) / abs(meanval(:))
print fmt_cr,"time : ", times_tot(1:ncalc) / niter
print 1, "dot_quad" , times_tot(1), times_tot(2)/times_tot(1), abs(err(1)) / abs(meanval(1))
print 1, "intrinsic" , times_tot(2), times_tot(2)/times_tot(2), abs(err(2)) / abs(meanval(2))
print 1, "fprod" , times_tot(3), times_tot(2)/times_tot(3), abs(err(3)) / abs(meanval(3))
end if

call check(error, abs(err(ncalc)) / abs(meanval(1)) < tolerance &
Expand All @@ -368,12 +366,12 @@ subroutine test_fast_trigonometry(error)
integer, parameter :: n = 1e6, ncalc = 2
integer :: i
real(dp) :: time(0:ncalc), err
1 format(a10,': <time> = ',f9.4,' ns/eval, speed-up=',f5.2,'X, rel. error=',es16.4)
!====================================================================================
if(verbose)then
print *,""
print *,"================================================================"
print *," Fast trigonometric"
print "(/,a10,*(a22))", "", "Error", "time intrinsic", "time fast", "Speed up"
end if
block
integer, parameter :: wp=sp
Expand All @@ -391,7 +389,7 @@ subroutine test_fast_trigonometry(error)
err = sqrt( sum( y - yref )**2 / n )

if(verbose)then
print fmt_er, "sin r32", err, time(1), time(2), time(1)/time(2)
print 1, "fsin r32" , time(2), time(1)/time(2), err
end if

call check(error, err < tolerance .and. time(2) < time(1) )
Expand All @@ -413,7 +411,7 @@ subroutine test_fast_trigonometry(error)
err = sqrt( sum( y - yref )**2 / n )

if(verbose)then
print fmt_er, "sin r64", err, time(1), time(2), time(1)/time(2)
print 1, "fsin r64" , time(2), time(1)/time(2), err
end if

call check(error, err < tolerance .and. time(2) < time(1) )
Expand All @@ -436,7 +434,7 @@ subroutine test_fast_trigonometry(error)
err = sqrt( sum( y - yref )**2 / n )

if(verbose)then
print fmt_er, "acos r32", err, time(1), time(2), time(1)/time(2)
print 1, "facos r32" , time(2), time(1)/time(2), err
end if

call check(error, err < tolerance .and. time(2) < time(1) )
Expand All @@ -458,14 +456,113 @@ subroutine test_fast_trigonometry(error)
err = sqrt( sum( y - yref )**2 / n )

if(verbose)then
print fmt_er, "acos r64", err, time(1), time(2), time(1)/time(2)
print 1, "facos r64" , time(2), time(1)/time(2), err
end if

call check(error, err < tolerance .and. time(2) < time(1) )
if (allocated(error)) return
end block

end subroutine

!> Accuracy and timing checks for the fast approximations `ftanh` and `ferf`.
!>
!> Each approximation is evaluated, in single and double precision, on `n`
!> evenly spaced samples spanning (-3, 3] and compared with the matching
!> intrinsic (`tanh` / `erf`). A case passes when the relative L2 error is below
!> the per-case tolerance AND the fast version took less CPU time than the
!> intrinsic. The subroutine returns at the first failing case.
!>
!> NOTE(review): the `time(2) < time(1)` criterion depends on machine load and
!> compiler flags, so this test can fail for reasons unrelated to accuracy.
subroutine test_fast_hyperbolic(error)
    !> Error handling: allocated by `check` when a case fails
    type(error_type), allocatable, intent(out) :: error

    !> Internal parameters and variables
    integer, parameter :: n = 1e6, ncalc = 2
    !> `cpu_time` reports seconds; the report below is captioned in nanoseconds
    real(dp), parameter :: sec_to_ns = 1.0e9_dp
    integer :: i
    !> time(0) is the start stamp; after differencing, time(1) is the intrinsic
    !> duration and time(2) the fast-version duration (seconds, n evaluations)
    real(dp) :: time(0:ncalc), err
    1 format(a10,': <time> = ',f9.4,' ns/eval, speed-up=',f5.2,'X, rel. error=',es16.4)
    !====================================================================================
    if (verbose) then
        print *,""
        print *,"================================================================"
        print *," Fast hyperbolic"
    end if

    !> ftanh, single precision
    block
        integer, parameter :: wp = sp
        real(wp), parameter :: tolerance = 1e-5_wp
        real(wp), allocatable :: x(:), y(:), yref(:)

        !> evenly spaced samples: x(i) = 6*(i/n - 1/2), i = 1..n
        allocate( x(n), y(n), yref(n) )
        x(:) = [ (2*(real(i,kind=wp) / n - 0.5_wp)*3._wp , i = 1, n) ]

        call cpu_time(time(0))
        yref = tanh(x) ; call cpu_time(time(1))
        y    = ftanh(x); call cpu_time(time(2))

        !> turn the three time stamps into two durations (back to front so that
        !> time(1) is still a stamp when time(2) is differenced)
        time(2:1:-1) = time(2:1:-1) - time(1:0:-1)
        !> relative L2 error with respect to the intrinsic result
        err = sqrt( fsum(( y - yref )**2) ) / sqrt( fsum(( yref )**2) )

        if (verbose) print 1, "ftanh r32", sec_to_ns*time(2)/n, time(1)/time(2), err

        call check(error, err < tolerance .and. time(2) < time(1) )
        if (allocated(error)) return
    end block

    !> ftanh, double precision
    block
        integer, parameter :: wp = dp
        real(wp), parameter :: tolerance = 1e-5_wp
        real(wp), allocatable :: x(:), y(:), yref(:)

        allocate( x(n), y(n), yref(n) )
        x(:) = [ (2*(real(i,kind=wp) / n - 0.5_wp)*3._wp , i = 1, n) ]

        call cpu_time(time(0))
        yref = tanh(x) ; call cpu_time(time(1))
        y    = ftanh(x); call cpu_time(time(2))

        time(2:1:-1) = time(2:1:-1) - time(1:0:-1)
        err = sqrt( fsum(( y - yref )**2) ) / sqrt( fsum(( yref )**2) )

        if (verbose) print 1, "ftanh r64", sec_to_ns*time(2)/n, time(1)/time(2), err

        call check(error, err < tolerance .and. time(2) < time(1) )
        if (allocated(error)) return
    end block

    !> ferf, single precision (looser tolerance than ftanh)
    block
        integer, parameter :: wp = sp
        real(wp), parameter :: tolerance = 1e-2_wp
        real(wp), allocatable :: x(:), y(:), yref(:)

        allocate( x(n), y(n), yref(n) )
        x(:) = [ (2*(real(i,kind=wp) / n - 0.5_wp)*3._wp , i = 1, n) ]

        call cpu_time(time(0))
        yref = erf(x) ; call cpu_time(time(1))
        y    = ferf(x); call cpu_time(time(2))

        time(2:1:-1) = time(2:1:-1) - time(1:0:-1)
        err = sqrt( fsum(( y - yref )**2) ) / sqrt( fsum(( yref )**2) )

        if (verbose) print 1, "ferf r32", sec_to_ns*time(2)/n, time(1)/time(2), err

        call check(error, err < tolerance .and. time(2) < time(1) )
        if (allocated(error)) return
    end block

    !> ferf, double precision
    block
        integer, parameter :: wp = dp
        real(wp), parameter :: tolerance = 1e-2_wp
        real(wp), allocatable :: x(:), y(:), yref(:)

        allocate( x(n), y(n), yref(n) )
        x(:) = [ (2*(real(i,kind=wp) / n - 0.5_wp)*3._wp , i = 1, n) ]

        call cpu_time(time(0))
        yref = erf(x) ; call cpu_time(time(1))
        y    = ferf(x); call cpu_time(time(2))

        time(2:1:-1) = time(2:1:-1) - time(1:0:-1)
        err = sqrt( fsum(( y - yref )**2) ) / sqrt( fsum(( yref )**2) )

        if (verbose) print 1, "ferf r64", sec_to_ns*time(2)/n, time(1)/time(2), err

        call check(error, err < tolerance .and. time(2) < time(1) )
        if (allocated(error)) return
    end block

end subroutine test_fast_hyperbolic

end module test_fast_math

Expand Down

0 comments on commit c5140a0

Please sign in to comment.