diff --git a/dsgp4/__init__.py b/dsgp4/__init__.py index c049a10..762a1e5 100644 --- a/dsgp4/__init__.py +++ b/dsgp4/__init__.py @@ -7,3 +7,5 @@ from .sgp4_batched import sgp4_batched from . import tle from .tle import TLE +import torch +torch.set_default_dtype(torch.float64) diff --git a/dsgp4/initl.py b/dsgp4/initl.py index 3371980..ec3d147 100644 --- a/dsgp4/initl.py +++ b/dsgp4/initl.py @@ -3,8 +3,6 @@ from . import util -torch.set_default_dtype(torch.float64) - def initl( xke, j2, ecco, epoch, inclo, no, diff --git a/dsgp4/newton_method.py b/dsgp4/newton_method.py index 0314673..60c99e4 100644 --- a/dsgp4/newton_method.py +++ b/dsgp4/newton_method.py @@ -5,7 +5,6 @@ from .sgp4init import sgp4init from . import util from .tle import TLE -#torch.set_default_dtype(torch.float64) def initial_guess(tle_0, time_mjd, target_state=None): """ diff --git a/dsgp4/sgp4.py b/dsgp4/sgp4.py index be6a6a4..4be2cbc 100644 --- a/dsgp4/sgp4.py +++ b/dsgp4/sgp4.py @@ -1,6 +1,6 @@ import numpy import torch -#torch.set_default_dtype(torch.float64) +torch.set_default_dtype(torch.float64) #@torch.jit.script def sgp4(satellite, tsince): diff --git a/dsgp4/sgp4init.py b/dsgp4/sgp4init.py index e94af80..3f0087b 100644 --- a/dsgp4/sgp4init.py +++ b/dsgp4/sgp4init.py @@ -2,7 +2,6 @@ import torch from .initl import initl from .sgp4 import sgp4 -#torch.set_default_dtype(torch.float64) def sgp4init( whichconst, opsmode, satn, epoch, diff --git a/dsgp4/tle.py b/dsgp4/tle.py index 8c78dbb..0965ab6 100644 --- a/dsgp4/tle.py +++ b/dsgp4/tle.py @@ -4,11 +4,8 @@ import torch from . import util - -from sgp4.earth_gravity import wgs84 -#torch.set_default_dtype(torch.float64) - -MU_EARTH = wgs84.mu*1e9 +_, MU_EARTH, _, _, _, _, _, _=util.get_gravity_constants('wgs-84') +MU_EARTH = MU_EARTH*1e9 # This function is from python-sgp4 released under MIT License, (c) 2012–2016 Brandon Rhodes def compute_checksum(line): diff --git a/dsgp4/util.py b/dsgp4/util.py index 609e8fa..1bec0c4 100644 --- a/dsgp4/util.py +++ b/dsgp4/util.py @@ -2,8 +2,6 @@ import numpy as np import torch -#torch.set_default_dtype(torch.float64) - def get_gravity_constants(gravity_constant_name): if gravity_constant_name == 'wgs-72old': mu = 398600.79964 # in km3 / s2 diff --git a/tests/test_differentiability.py b/tests/test_differentiability.py index d5c7b6a..a768a4f 100644 --- a/tests/test_differentiability.py +++ b/tests/test_differentiability.py @@ -25,7 +25,7 @@ def test_velocity(self): data.append(lines[i+1]) data.append(lines[i+2]) tles.append(dsgp4.tle.TLE(data)) - whichconst=dsgp4.util.get_gravity_constants("wgs-72") + whichconst=dsgp4.util.get_gravity_constants("wgs-84") #I filter out deep space and error cases: tles_filtered=[] for idx, tle_satellite in enumerate(tles): @@ -98,7 +98,7 @@ def test_input_gradients(self): data.append(lines[i+1]) data.append(lines[i+2]) tles.append(dsgp4.tle.TLE(data)) - whichconst=dsgp4.util.get_gravity_constants("wgs-72") + whichconst=dsgp4.util.get_gravity_constants("wgs-84") #I filter out deep space and error cases: tles_filtered=[] for idx, tle_satellite in enumerate(tles):