-
Notifications
You must be signed in to change notification settings - Fork 0
/
daily_arima.py
91 lines (72 loc) · 2.87 KB
/
daily_arima.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os

import matplotlib.pyplot as plt
import pandas as pd
from statsmodels.tsa.statespace.sarimax import SARIMAX

from data_preprocessing import load_nyt_data, load_ccc_data
# Input CSV locations handed to the loaders from data_preprocessing.
nyt_filepath = './datasets/us-counties-2020.csv'
ccc_filepath = './datasets/ccc_filtered.csv'

covid_data = load_nyt_data(nyt_filepath)
event_data = load_ccc_data(ccc_filepath)

# Left-join events onto the COVID rows so every COVID row is kept; rows are
# matched on date plus county code ('fips' on the left, 'fips_code' on the right).
merged_data = pd.merge(
    covid_data,
    event_data,
    how='left',
    left_on=['date', 'fips'],
    right_on=['date', 'fips_code'],
)

# 'fips_code' is the right-hand join key, so after the left join it is null
# on rows with no matching event; encode "matched" as a 0/1 indicator.
merged_data['has_event'] = (~merged_data['fips_code'].isna()).astype(int)

# Rows without a matching event get 0 for the event-related columns.
for event_col in ('valence', 'size_mean'):
    merged_data[event_col] = merged_data[event_col].fillna(0)
# County to model, selected by comparing the 'fips' column against this string.
# county = '53061'  # alternative county
county = '06037'  # LA county, want to see effect of this

county_data = merged_data[merged_data['fips'] == county]
if county_data.empty:
    print("no data available for the selected county or FIPS code.")
    # SystemExit instead of the interactive-only exit() helper; non-zero
    # status because the script could not do its job.
    raise SystemExit(1)

# Keep only the columns the model uses and drop rows with NaN in either one.
# NOTE(review): row order is taken as-is from the merge; presumably it is
# already chronological - confirm against the loaders.
county_data = county_data[['new_cases', 'has_event']].dropna().reset_index(drop=True)
if county_data.empty:
    print("no valid data after cleaning")
    raise SystemExit(1)

# Treat zero-case rows as missing and carry the previous non-zero value
# forward; leading zeros have nothing to carry forward and stay 0.
# mask() marks the zeros as NaN while keeping a numeric dtype (replacing with
# pd.NA would turn the column into object dtype), and .ffill() replaces the
# deprecated fillna(method='ffill').
reported_cases = county_data['new_cases']
county_data['adjusted_cases'] = (
    reported_cases.mask(reported_cases == 0).ffill().fillna(0)
)
# Model inputs: the adjusted case series is the endogenous variable and the
# 0/1 event indicator is the only exogenous regressor.
# To model the raw series instead, pass county_data['new_cases'] as the
# first SARIMAX argument.
endog_series = county_data['adjusted_cases']
exog_frame = county_data[['has_event']]

# Build and fit the SARIMAX model; both order tuples are tuning knobs.
model = SARIMAX(
    endog_series,
    exog=exog_frame,
    order=(1, 1, 1),  # non-seasonal (p, d, q)
    seasonal_order=(1, 0, 1, 7),  # seasonal (P, D, Q, s); s=7 is weekly if rows are daily
)
# maxiter is raised from the library default; adjust if the fit does not converge.
results = model.fit(maxiter=1000)

# Forecast 30 steps ahead. Future exogenous values must be supplied; here
# has_event is set to 1 for every forecast step.
future_exog = [[1]] * 30
forecast = results.get_forecast(steps=30, exog=future_exog)
forecast_ci = forecast.conf_int()
# Plot the observed series, the adjusted series, and the forecast with its
# confidence band, then save the figure to ./output and display it.
plt.figure(figsize=(15, 9))

# observed cases
plt.plot(county_data['new_cases'], label='Observed New Cases', color='blue')

# adjusted cases (the series the model was fitted on)
plt.plot(
    county_data['adjusted_cases'],
    label='Adjusted Cases (Smoothed)',
    color='green',
    linestyle='--',
)

# The forecast is drawn immediately after the last observed row. The horizon
# is read from the forecast itself so it always matches the number of steps
# that were actually forecast.
history_len = len(county_data)
forecast_x = range(history_len, history_len + len(forecast.predicted_mean))

# forecasted cases
plt.plot(
    forecast_x,
    forecast.predicted_mean,
    label='Forecast',
    color='orange',
)

# confidence interval: first column is the lower bound, second the upper
plt.fill_between(
    forecast_x,
    forecast_ci.iloc[:, 0],
    forecast_ci.iloc[:, 1],
    color='orange',
    alpha=0.3,
)

plt.title('SARIMAX Model: Observed, Adjusted, and Forecasted New Cases for county: ' + county)
plt.xlabel('Time (Days)')
plt.ylabel('New Cases')
plt.legend()
plt.grid()

# savefig does not create missing directories, so make sure the target
# exists instead of failing after the model has already been fitted.
output_dir = './output'
os.makedirs(output_dir, exist_ok=True)
plt.savefig(output_dir + '/SARIMAX_forecast_county_' + county + '.png')
plt.show()