Skip to content

Commit

Permalink
better treatment of 0-values in evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
gugerlir committed Nov 1, 2023
1 parent a6fdb47 commit f5d7506
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
8 changes: 6 additions & 2 deletions rainforest/performance/eval_calculate.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,11 @@ def calcScoresStations(precip_ref : pd.DataFrame,

th_ref, th_est = threshold if isinstance(threshold, list) else [threshold] * 2

perf_scores = pd.DataFrame(columns=['YESref_YESpred','NOref_NOpred','NOref_YESpred','YESref_NOpred',
score_names = ['YESref_YESpred','NOref_NOpred','NOref_YESpred','YESref_NOpred',
'RMSE','corr_p','scatter','logBias', 'n_values',
'n_events_db','sum_ref_db','sum_pred_db'],
'n_events_db','sum_ref_db','sum_pred_db']
score_db_names = ['n_events_db','sum_ref_db','sum_pred_db', 'RMSE','corr_p','scatter','logBias']
perf_scores = pd.DataFrame(columns=score_names,
index=precip_ref.columns.unique())

stations = METSTATIONS.copy()
Expand Down Expand Up @@ -94,6 +96,7 @@ def calcScoresStations(precip_ref : pd.DataFrame,
np.round(10*np.log10(sum_pred_db/sum_ref_db),decimals=4)
else:
logging.info('No measurements for station {}'.format(station_id))
perf_scores.loc[station_id, ['logBias','RMSE','corr_p','scatter']] = np.nan
continue
try:
scores = det_cont_fct(precip_est[station_id][double_cond.index].to_numpy(dtype=float),
Expand All @@ -104,6 +107,7 @@ def calcScoresStations(precip_ref : pd.DataFrame,
perf_scores.at[station_id, 'scatter']= np.round(scores['scatter'],decimals=4)
except:
logging.info('Could not calculate scores for station {}'.format(station_id))
perf_scores.loc[station_id, ['RMSE','corr_p','scatter']] = np.nan

return perf_scores

Expand Down
12 changes: 9 additions & 3 deletions rainforest/performance/eval_get_estimates.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
from ..common.io_data import read_cart


def getGaugeObservations(gaugefolder, t0=None, t1=None, slf_stations=False):
def getGaugeObservations(gaugefolder, t0=None, t1=None, slf_stations=False,
missing2nan=False):
"""_summary_
Args:
Expand All @@ -50,7 +51,9 @@ def getGaugeObservations(gaugefolder, t0=None, t1=None, slf_stations=False):
dtype = {'TIMESTAMP':int, 'STATION': str})

gauge_all = gauge_all.compute().drop_duplicates()
gauge_all = gauge_all.replace(-9999,np.nan)

if missing2nan:
gauge_all = gauge_all.replace(-9999,np.nan)

# Assure that datetime object is in UTC
if t0 != None :
Expand Down Expand Up @@ -291,7 +294,8 @@ def extractEstimatesFromMaps(self, slf_stations=False, tagg_hourly=True, save_ou
#-----------------------------------
logging.info('Get gauge observations')
gauge_all = getGaugeObservations(gaugefolder=self.gaugefolder,
t0 = self.tstart, t1= self.tend, slf_stations = False)
t0 = self.tstart, t1= self.tend, slf_stations = False,
missing2nan=False)
stations = np.unique(gauge_all['STATION'])

# Get all model files
Expand Down Expand Up @@ -360,6 +364,8 @@ def extractEstimatesFromMaps(self, slf_stations=False, tagg_hourly=True, save_ou
col = gauge_all.loc[gauge_all['STATION'] == ss, 'RRE150Z0'].to_frame(name=ss)
col.set_index(gauge_all.loc[(gauge_all['STATION'] == ss),'TIME'], inplace=True)
precip_qpe['GAUGE'][ss] = col
precip_qpe['GAUGE'][ss].fillna(0, inplace=True)
precip_qpe['GAUGE'][ss].loc[precip_qpe['GAUGE'][ss] < 0] = np.nan

for m in list_models:
precip_qpe[m] = pd.DataFrame(precip_qpe[m], index=tstamps_10min, columns=precip_qpe['GAUGE'].columns)
Expand Down

0 comments on commit f5d7506

Please sign in to comment.