Skip to content

Commit

Permalink
Merge pull request #594 from NVJY/dev
Browse files Browse the repository at this point in the history
Adds Translators to xarray_wrappers.py
  • Loading branch information
avcopan authored Dec 13, 2024
2 parents b2e7b12 + 90b9089 commit 0667cdc
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 31 deletions.
90 changes: 60 additions & 30 deletions autoreact/ktp_xarray/xarray_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
"""

import xarray
import numpy

# Constructors
def from_data(temps, press, rates):
"""
Construct a KTP DataArray from data
"""
"""Construct a KTP DataArray from data"""

ktp = xarray.DataArray(rates, (("pres", press), ("temp", temps)))

Expand All @@ -18,75 +17,106 @@ def from_data(temps, press, rates):

# Getters
def get_pressures(ktp):
"""
Gets the pressure values
"""
"""Gets the pressure values"""

return ktp.pres.data


def get_temperatures(ktp):
"""
Gets the temperature values
"""
"""Gets the temperature values"""

return ktp.temp.data


def get_values(ktp):
"""
Gets the KTP values
"""
"""Gets the KTP values"""

return ktp.values


def get_pslice(ktp, ip):
"""
Get a slice at a selected pressure value
"""
"""Get a slice at a selected pressure value"""

return ktp.sel(pres=ip)


def get_tslice(ktp, it):
"""
Get a slice at a selected temperature value
"""
"""Get a slice at a selected temperature value"""

return ktp.sel(temp=it)


def get_spec_vals(ktp, it, ip):
"""
Get a specific value at a selected temperature and pressure value
"""
"""Get a specific value at a selected temperature and pressure value"""

return ktp.sel(temp=it, pres=ip)


def get_ipslice(ktp, ip):
"""
Get a slice at a selected pressure index
"""
"""Get a slice at a selected pressure index"""

return ktp.isel(pres=ip)


def get_itslice(ktp, it):
"""
Get a slice at a selected temperature index
"""
"""Get a slice at a selected temperature index"""

return ktp.isel(temp=it)



# Setters
def set_rates(ktp, rates, pres, temp):
"""
Sets the KTP values
"""
"""Sets the KTP values"""

ktp.loc[{"pres": pres, "temp": temp}] = rates
return ktp


# Translators

def dict_from_xarray(xarray_in):
"""Turns an xarray into a ktp_dct"""

ktp_dct = {}
#dict_temps = get_temperatures(xarray)
for pres in get_pressures(xarray_in):
dict_temps = get_temperatures(xarray_in)
dict_kts = []
curr_temps = numpy.copy(dict_temps)
for temp_idx, temp in enumerate(dict_temps):
kt = get_spec_vals(xarray_in, temp, pres)
if numpy.isnan(kt):
curr_temps = numpy.delete(curr_temps, temp_idx)
else:
dict_kts += (float(kt),)
dict_temps = curr_temps
dict_kts = numpy.array(dict_kts, dtype=numpy.float64)
if pres == numpy.inf:
pres = 'high'
ktp_dct[pres] = (dict_temps, dict_kts)
return ktp_dct

#Stopped working on this one because it was less critical!

#def xarray_from_dict(ktp_dct):
# """DOES NOT WORK YET!
# Turns a ktp_dct into an xarray"""
#
# xarray_press = []
# xarray_temps = []
# for pressures, (temps,x) in ktp_dct.items():
# xarray_press.append(pressures)
# for temp in temps:
# if temp not in xarray_temps:
# xarray_temps.append(temp)
# xarray_temps = xarray_temps.sort()
# temporary_kts = numpy.ndarray((len(xarray_press),len(xarray_temps)))
# for pres_idx, pres in enumerate(xarray_press):
# ktp = list(ktp_dct[pres])[1]
# for kt in ktp:
# temporary_kts[pres_idx] = kt
# breakpoint()
# xarray = from_data(xarray_temps, xarray_press, temp_kts)
# breakpoint()
# return xarray
20 changes: 19 additions & 1 deletion autoreact/ktp_xarray/xarray_wrappers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
Rates = [[1e1, 1e2, 1e3, 1e4], [1e5, 1e6, 1e7, 1e8], [1e9, 1e10, 1e11, 1e12]]

Ktp = xarray_wrappers.from_data(Temps, Press, Rates)
Ktp_dct = {1.0: (([1000., 1500., 2000., 2500.]),
([ 10., 100., 1000., 10000.])),
10: (([1000., 1500., 2000., 2500.]),
([1.e+05, 1.e+06, 1.e+07, 1.e+08])),
numpy.inf: (([1000., 1500., 2000., 2500.]),
([1.e+09, 1.e+10, 1.e+11, 1.e+12]))}
print(Ktp)

def test_get_temperatures():
Expand Down Expand Up @@ -60,9 +66,19 @@ def test_get_itslice():

def test_set_rates():
"""Tests the set_rates function"""
new_rates = xarray_wrappers.set_rates(Ktp, 1e11, 10, 2000)
new_rates = xarray_wrappers.set_rates(Ktp, numpy.nan, 10, 2000)
print(new_rates)

def test_dict_from_xarray():
"""Tests the ktp_to_xarray function"""
ktp_dct = xarray_wrappers.dict_from_xarray(Ktp)
print(ktp_dct)

#def test_xarray_from_dict():
# """Tests the set_ktp_dct function"""
# xarray = xarray_wrappers.xarray_from_dict(Ktp_dct)
# print(xarray)


test_get_pressures()
test_get_temperatures()
Expand All @@ -73,3 +89,5 @@ def test_set_rates():
test_get_ipslice()
test_get_itslice()
test_set_rates()
test_dict_from_xarray()
#test_xarray_from_dict()

0 comments on commit 0667cdc

Please sign in to comment.