|
| 1 | +--- |
| 2 | +title: Core |
| 3 | +description: Core |
| 4 | +--- |
| 5 | + |
| 6 | +## |
| 7 | + |
| 8 | +HierarchicalForecast contains pure Python implementations of |
| 9 | +hierarchical reconciliation methods as well as a |
| 10 | +`core.HierarchicalReconciliation` wrapper class that enables easy |
| 11 | +interaction with these methods through pandas DataFrames containing the |
| 12 | +hierarchical time series and the base predictions. |
| 13 | + |
| 14 | +The `core.HierarchicalReconciliation` reconciliation class operates with |
| 15 | +the hierarchical time series pd.DataFrame `Y_df`, the base predictions |
| 16 | +pd.DataFrame `Y_hat_df`, and the aggregation constraints matrix `S_df`. For |
| 17 | +more information on creating the aggregation constraints matrix, see |
| 18 | +the utils [aggregation |
| 19 | +method](https://nixtlaverse.nixtla.io/hierarchicalforecast/src/utils.html#aggregate). |
| 20 | + |
| 21 | +::: hierarchicalforecast.core.HierarchicalReconciliation |
| 22 | + options: |
| 23 | + members: |
| 24 | + - reconcile |
| 25 | + - bootstrap_reconcile |
| 26 | + |
| 27 | +### Example |
| 28 | + |
| 29 | +```python |
| 30 | +import pandas as pd |
| 31 | + |
| 32 | +from hierarchicalforecast.core import HierarchicalReconciliation |
| 33 | +from hierarchicalforecast.methods import BottomUp, MinTrace |
| 34 | +from hierarchicalforecast.utils import aggregate |
| 35 | +from hierarchicalforecast.evaluation import evaluate |
| 36 | +from statsforecast.core import StatsForecast |
| 37 | +from statsforecast.models import AutoETS |
| 38 | +from utilsforecast.losses import mase, rmse |
| 39 | +from functools import partial |
| 40 | + |
| 41 | +# Load TourismSmall dataset |
| 42 | +df = pd.read_csv('https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/tourism.csv') |
| 43 | +df = df.rename({'Trips': 'y', 'Quarter': 'ds'}, axis=1) |
| 44 | +df.insert(0, 'Country', 'Australia') |
| 45 | +qs = df['ds'].str.replace(r'(\d+) (Q\d)', r'\1-\2', regex=True) |
| 46 | +df['ds'] = pd.PeriodIndex(qs, freq='Q').to_timestamp() |
| 47 | + |
| 48 | +# Create hierarchical series based on geographic levels and purpose |
| 49 | +# (the quarterly ds strings were already converted to timestamps above) |
| 50 | +hierarchy_levels = [['Country'], |
| 51 | + ['Country', 'State'], |
| 52 | + ['Country', 'Purpose'], |
| 53 | + ['Country', 'State', 'Region'], |
| 54 | + ['Country', 'State', 'Purpose'], |
| 55 | + ['Country', 'State', 'Region', 'Purpose']] |
| 56 | + |
| 57 | +Y_df, S_df, tags = aggregate(df=df, spec=hierarchy_levels) |
| 58 | + |
| 59 | +# Split train/test sets |
| 60 | +Y_test_df = Y_df.groupby('unique_id').tail(8) |
| 61 | +Y_train_df = Y_df.drop(Y_test_df.index) |
| 62 | + |
| 63 | +# Compute base auto-ETS predictions |
| 64 | +# Be careful to identify the correct data frequency: this data is quarterly with quarter-start timestamps ('QS') |
| 65 | +fcst = StatsForecast(models=[AutoETS(season_length=4, model='ZZA')], freq='QS', n_jobs=-1) |
| 66 | +Y_hat_df = fcst.forecast(df=Y_train_df, h=8, fitted=True) |
| 67 | +Y_fitted_df = fcst.forecast_fitted_values() |
| 68 | + |
| 69 | +reconcilers = [ |
| 70 | + BottomUp(), |
| 71 | + MinTrace(method='ols'), |
| 72 | + MinTrace(method='mint_shrink'), |
| 73 | + ] |
| 74 | +hrec = HierarchicalReconciliation(reconcilers=reconcilers) |
| 75 | +Y_rec_df = hrec.reconcile(Y_hat_df=Y_hat_df, |
| 76 | + Y_df=Y_fitted_df, |
| 77 | + S_df=S_df, tags=tags) |
| 78 | + |
| 79 | +# Evaluate |
| 80 | +eval_tags = {} |
| 81 | +eval_tags['Total'] = tags['Country'] |
| 82 | +eval_tags['Purpose'] = tags['Country/Purpose'] |
| 83 | +eval_tags['State'] = tags['Country/State'] |
| 84 | +eval_tags['Regions'] = tags['Country/State/Region'] |
| 85 | +eval_tags['Bottom'] = tags['Country/State/Region/Purpose'] |
| 86 | + |
| 87 | +Y_rec_df_with_y = Y_rec_df.merge(Y_test_df, on=['unique_id', 'ds'], how='left') |
| 88 | +mase_p = partial(mase, seasonality=4) |
| 89 | + |
| 90 | +evaluation = evaluate(Y_rec_df_with_y, |
| 91 | + metrics=[mase_p, rmse], |
| 92 | + tags=eval_tags, |
| 93 | + train_df=Y_train_df) |
| 94 | + |
| 95 | +numeric_cols = evaluation.select_dtypes(include="number").columns |
| 96 | +evaluation[numeric_cols] = evaluation[numeric_cols].map('{:.2f}'.format) |
| 97 | +``` |
0 commit comments