# Provenance: reconstructed from a whitespace-collapsed git diff of
# code/07_05_random_forest.py (blob 7f5b8a3 -> 3aa5d52, hunk "@@ -15,6 +15,11 @@").
# The hunk only adds lines, so what follows is the post-patch code for the
# region near line 15 of that file.
# `df` (the input DataFrame) and `pd` (presumably pandas) are expected to be
# defined earlier in the script; they are not visible in this hunk.

# Columns used as-is, and columns expanded into dummy variables below.
cont_vars = ['dist', 'st_x', 'st_y', 'period_time_remaining', 'empty']
cat_vars = ['pos', 'hand', 'period']

# Relabel periods 1, 2, 3, ... as 'P1', 'P2', 'P3', ... before building the
# dummy variables, so the resulting columns are named P1, P2, ... instead of
# the bare values 1, 2, ..., which can cause issues downstream.
# NOTE: this step is not idempotent -- running it again on an already
# relabelled frame prepends a second 'P' (e.g. 'PP1').
df['period'] = 'P' + df['period'].astype(str)

# Build dummy columns for each categorical variable and place them side by
# side (axis=1). No prefix is passed, so column names are the category values.
df_cat = pd.concat([pd.get_dummies(df[x]) for x in cat_vars], axis=1)

# Combine the continuous columns with the dummy columns, column-wise.
df_all = pd.concat([df[cont_vars], df_cat], axis=1)