@@ -46,6 +46,7 @@ def __init__(
4646 include_generation_history : bool = False ,
4747 include_sun : bool = True ,
4848 include_time : bool = False ,
49+ t0_embedding_dim : int = 0 ,
4950 location_id_mapping : dict [Any , int ] | None = None ,
5051 embedding_dim : int = 16 ,
5152 forecast_minutes : int = 30 ,
@@ -85,6 +86,8 @@ def __init__(
8586 include_generation_history: Include generation yield data.
8687 include_sun: Include sun azimuth and altitude data.
8788 include_time: Include sine and cosine of dates and times.
89+ t0_embedding_dim: Shape of the embedding of the init-time (t0) of the forecast. Not used
90+ if set to 0.
8891 location_id_mapping: A dictionary mapping the location ID to an integer. ID embedding is
8992 not used if this is not provided.
9093 embedding_dim: Number of embedding dimensions to use for location ID.
@@ -119,6 +122,7 @@ def __init__(
119122 self .include_pv = pv_encoder is not None
120123 self .include_sun = include_sun
121124 self .include_time = include_time
125+ self .t0_embedding_dim = t0_embedding_dim
122126 self .location_id_mapping = location_id_mapping
123127 self .embedding_dim = embedding_dim
124128 self .add_image_embedding_channel = add_image_embedding_channel
@@ -246,6 +250,8 @@ def __init__(
246250 # Update num features
247251 fusion_input_features += 32
248252
253+ fusion_input_features += self .t0_embedding_dim
254+
249255 if include_generation_history :
250256 # Update num features
251257 fusion_input_features += self .history_len + 1
@@ -321,6 +327,9 @@ def forward(self, x: TensorBatch) -> torch.Tensor:
321327 time = self .time_fc1 (time )
322328 modes ["time" ] = time
323329
330+ if self .t0_embedding_dim > 0 :
331+ modes ["t0_embed" ] = x ["t0_embedding" ]
332+
324333 out = self .output_network (modes )
325334
326335 if self .use_quantile_regression :
0 commit comments