Skip to content

Commit

Permalink
changed pd version, fixed benchmark model
Browse files Browse the repository at this point in the history
  • Loading branch information
nepslor committed Nov 4, 2024
1 parent 8e4d0f7 commit 7a24a8d
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions pyforecaster/forecasting_models/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __init__(self, period='1d', q_vect=None, val_ratio=None, nodes_at_step=None,
def fit(self, x:pd.DataFrame, y:pd.DataFrame):
self.n_sa = y.shape[1]
# infer sampling time
sampling_time = pd.infer_freq(x.index)
sampling_time = x.index.drop_duplicates().diff().median()

# retrieve support of target distribution
support = np.unique(y.values)
Expand Down Expand Up @@ -109,11 +109,11 @@ def fit(self, x:pd.DataFrame, y:pd.DataFrame):
return self

def predict(self, x, **kwargs):
return (self.predict_probabilities(x) * np.tile(self.support.reshape(1, -1), self.n_sa)).groupby(level=0, axis=1).sum()
return (self.predict_probabilities(x) * np.tile(self.support.reshape(1, -1), self.n_sa)).T.groupby(level=0).sum()

def predict_probabilities(self, x, **kwargs):
# infer sampling time
sampling_time = pd.infer_freq(x.index)
sampling_time = x.index.drop_duplicates().diff().median()
# Create a new column for the time within a day (hours and minutes)
time_of_day = pd.Series(x.index.floor(sampling_time).time)

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
numpy>=1.20.2
optuna>=2.10.0
networkx>=2.6.3
pandas>=1.2.3
pandas>=2.2.3
seaborn>=0.11.1
matplotlib>=3.4.1
scipy>=1.7.1
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
install_requires=['numpy>=1.20.2',
'optuna>=2.10.0',
'networkx>=2.6.3',
'pandas>=1.2.3',
'pandas>=2.2.3',
'seaborn>=0.11.1',
'matplotlib>=3.4.1',
'scipy>=1.7.1',
Expand Down

0 comments on commit 7a24a8d

Please sign in to comment.