Skip to content

Commit

Permalink
add file path setting to RadbeltClass constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobwilliams committed Feb 17, 2024
1 parent 076e03c commit 234298f
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 15 deletions.
34 changes: 28 additions & 6 deletions radbeltpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,46 @@ class RadbeltClass:
Class for using the radbelt model.
"""

def __init__(self) -> None:
"""constructor for the fortran class"""
def __init__(self, aep8_dir : str = None, igrf_dir: str = None) -> None:
"""
Constructor for the fortran class
Parameters
----------
aep8_dir : str
The directory containing the aep8 files. If None, then use the default directory ('data/aep8/')
igrf_dir : str
The directory containing the igrf files. If None, then use the default directory ('data/igrf/')
"""

#`ip` is an integer that represents a c pointer
# to a `radbelt_type` in the Fortran library.
self.ip = radbelt_fortran.radbelt_c_module.initialize_c()

#TODO should allow for passing in the data directories here
# note that None means use defaults,
# but '' means current directory
if aep8_dir is not None:
self.set_trm_file_path(aep8_dir)
if igrf_dir is not None:
self.set_igrf_file_path(igrf_dir)

def __del__(self) -> None:
"""destructor for the fortran class"""

radbelt_fortran.radbelt_c_module.destroy_c(self.ip)

def set_data_files_paths(self, aep8_dir : str, igrf_dir : str) -> None:
"""Set the file paths"""
def set_trm_file_path(self, aep8_dir : str) -> None:
"""Set just the aep8 file path"""

radbelt_fortran.radbelt_c_module.set_trm_file_path_c(self.ip, aep8_dir, len(aep8_dir))

def set_igrf_file_path(self, igrf_dir : str) -> None:
"""Set just the igrf file path"""

#TODO split these up so they can be called separately
radbelt_fortran.radbelt_c_module.set_igrf_file_path_c(self.ip, igrf_dir, len(igrf_dir))

def set_data_files_paths(self, aep8_dir : str, igrf_dir : str) -> None:
"""Set both the file paths"""

radbelt_fortran.radbelt_c_module.set_data_files_paths_c(self.ip, aep8_dir, igrf_dir,
len(aep8_dir), len(igrf_dir))
Expand Down
12 changes: 12 additions & 0 deletions radbeltpy/radbeltpy.pyf
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,18 @@ python module radbelt_fortran ! in
real(c_double),intent(out) :: flux
end subroutine get_flux_g_c

subroutine set_trm_file_path_c(ipointer, aep8_dir, n) ! in :radbelt_fortran:radbelt_c_module.f90:radbelt_c_module
integer(c_intptr_t),intent(in) :: ipointer
character(kind=c_char,len=n),intent(in),depend(n) :: aep8_dir
integer(c_int),intent(in) :: n
end subroutine set_trm_file_path_c

subroutine set_igrf_file_path_c(ipointer, igrf_dir, n) ! in :radbelt_fortran:radbelt_c_module.f90:radbelt_c_module
integer(c_intptr_t),intent(in) :: ipointer
character(kind=c_char,len=n),intent(in),depend(n) :: igrf_dir
integer(c_int),intent(in) :: n
end subroutine set_igrf_file_path_c

subroutine set_data_files_paths_c(ipointer, aep8_dir, igrf_dir, n, m) ! in :radbelt_fortran:radbelt_c_module.f90:radbelt_c_module
integer(c_intptr_t),intent(in) :: ipointer
character(kind=c_char,len=n),intent(in),depend(n) :: aep8_dir
Expand Down
52 changes: 51 additions & 1 deletion src/radbelt_c_module.f90
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,61 @@ subroutine destroy_c(ipointer) bind(C, name="destroy_c")

end subroutine destroy_c

!*****************************************************************************************
!>
! C interface for setting the `trm` data file path

subroutine set_trm_file_path_c(ipointer, aep8_dir, n) bind(C, name="set_trm_file_path_c")

integer(c_intptr_t),intent(in) :: ipointer
integer(c_int),intent(in) :: n !! size of `aep8_dir`
character(kind=c_char,len=1),dimension(n),intent(in) :: aep8_dir

character(len=:),allocatable :: aep8_dir_
type(radbelt_type),pointer :: p

call int_pointer_to_f_pointer(ipointer, p)

if (associated(p)) then
aep8_dir_ = c2f_str(aep8_dir)
call p%set_trm_file_path(aep8_dir_)
else
error stop 'error in set_trm_file_path_c: class is not associated'
end if

end subroutine set_trm_file_path_c
!*****************************************************************************************

!*****************************************************************************************
!>
! C interface for setting the `igrf` data file path

subroutine set_igrf_file_path_c(ipointer, igrf_dir, n) bind(C, name="set_igrf_file_path")

integer(c_intptr_t),intent(in) :: ipointer
integer(c_int),intent(in) :: n !! size of `igrf_dir`
character(kind=c_char,len=1),dimension(n),intent(in) :: igrf_dir

character(len=:),allocatable :: igrf_dir_
type(radbelt_type),pointer :: p

call int_pointer_to_f_pointer(ipointer, p)

if (associated(p)) then
igrf_dir_ = c2f_str(igrf_dir)
call p%set_igrf_file_path(igrf_dir_)
else
error stop 'error in set_igrf_file_path: class is not associated'
end if

end subroutine set_igrf_file_path_c
!*****************************************************************************************

!*****************************************************************************************
!>
! C interface for setting the data file paths

subroutine set_data_files_paths_c(ipointer, aep8_dir, igrf_dir, n, m) bind(C, name="set_data_files_paths_c")
subroutine set_data_files_paths_c(ipointer, aep8_dir, igrf_dir, n, m) bind(C, name="set_data_files_paths_c")

integer(c_intptr_t),intent(in) :: ipointer
integer(c_int),intent(in) :: n !! size of `aep8_dir`
Expand Down
36 changes: 33 additions & 3 deletions src/radbelt_module.f90
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ module radbelt_module
private
generic,public :: get_flux => get_flux_g_, get_flux_c_
procedure :: get_flux_g_, get_flux_c_
procedure,public :: set_data_files_paths
procedure,public :: set_trm_file_path, &
set_igrf_file_path, &
set_data_files_paths
end type radbelt_type

interface get_flux
Expand All @@ -35,6 +37,34 @@ module radbelt_module

contains

!*****************************************************************************************
!>
! Set the `trm` path.

subroutine set_trm_file_path(me, dir)

class(radbelt_type),intent(inout) :: me
character(len=*),intent(in) :: dir

call me%trm%set_data_file_dir(trim(dir))

end subroutine set_trm_file_path
!*****************************************************************************************

!*****************************************************************************************
!>
! Set the `igrf` path.

subroutine set_igrf_file_path(me, dir)

class(radbelt_type),intent(inout) :: me
character(len=*),intent(in) :: dir

call me%igrf%set_data_file_dir(trim(dir))

end subroutine set_igrf_file_path
!*****************************************************************************************

!*****************************************************************************************
!>
! Set the paths to the data files.
Expand All @@ -47,8 +77,8 @@ subroutine set_data_files_paths(me, aep8_dir, igrf_dir)
character(len=*),intent(in) :: aep8_dir
character(len=*),intent(in) :: igrf_dir

call me%trm%set_data_file_dir(trim(aep8_dir))
call me%igrf%set_data_file_dir(trim(igrf_dir))
call me%set_trm_file_path(trim(aep8_dir))
call me%set_igrf_file_path(trim(igrf_dir))

end subroutine set_data_files_paths
!*****************************************************************************************
Expand Down
10 changes: 5 additions & 5 deletions test/radbeltpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
sys.path.insert(0,str(dir)) # assuming the radbelt lib is in python directory
from radbeltpy import RadbeltClass # this is the module being tested

model = RadbeltClass()

# set location of the data files:
# location of the data files:
aep8_dir = str(dir / 'data' / 'aep8')
igrf_dir = str(dir / 'data' / 'igrf')
model.set_data_files_paths(aep8_dir, igrf_dir)

# create the class:
model = RadbeltClass(aep8_dir = aep8_dir, igrf_dir = igrf_dir)

EPS = sys.float_info.epsilon # machine precision for error checking
lon = -45.0
Expand All @@ -36,7 +36,7 @@
print(f'Flux = {flux}')
print(f'Error = {error}')
print(f'Rel Error = {relerror}')
if relerror>10*EPS:
if relerror > 10.0*EPS:
raise Exception('error')

print('')
Expand Down

0 comments on commit 234298f

Please sign in to comment.