Skip to content

Commit 6c004cf

Browse files
committed
Add option to use t0 embedding features
1 parent da5c9e9 commit 6c004cf

File tree

4 files changed

+15
-1
lines changed

4 files changed

+15
-1
lines changed

pvnet/models/late_fusion/late_fusion.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def __init__(
4646
include_generation_history: bool = False,
4747
include_sun: bool = True,
4848
include_time: bool = False,
49+
t0_embedding_dim: int = 0,
4950
location_id_mapping: dict[Any, int] | None = None,
5051
embedding_dim: int = 16,
5152
forecast_minutes: int = 30,
@@ -85,6 +86,8 @@ def __init__(
8586
include_generation_history: Include generation yield data.
8687
include_sun: Include sun azimuth and altitude data.
8788
include_time: Include sine and cosine of dates and times.
89+
t0_embedding_dim: Shape of the embedding of the init-time (t0) of the forecast. Not used
90+
if set to 0.
8891
location_id_mapping: A dictionary mapping the location ID to an integer. ID embedding is
8992
not used if this is not provided.
9093
embedding_dim: Number of embedding dimensions to use for location ID.
@@ -119,6 +122,7 @@ def __init__(
119122
self.include_pv = pv_encoder is not None
120123
self.include_sun = include_sun
121124
self.include_time = include_time
125+
self.t0_embedding_dim = t0_embedding_dim
122126
self.location_id_mapping = location_id_mapping
123127
self.embedding_dim = embedding_dim
124128
self.add_image_embedding_channel = add_image_embedding_channel
@@ -246,6 +250,8 @@ def __init__(
246250
# Update num features
247251
fusion_input_features += 32
248252

253+
fusion_input_features += self.t0_embedding_dim
254+
249255
if include_generation_history:
250256
# Update num features
251257
fusion_input_features += self.history_len + 1
@@ -321,6 +327,9 @@ def forward(self, x: TensorBatch) -> torch.Tensor:
321327
time = self.time_fc1(time)
322328
modes["time"] = time
323329

330+
if self.t0_embedding_dim>0:
331+
modes["t0_embed"] = x["t0_embedding"]
332+
324333
out = self.output_network(modes)
325334

326335
if self.use_quantile_regression:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ readme = {file="README.md", content-type="text/markdown"}
1212
requires-python = ">=3.11,<3.14"
1313

1414
dependencies = [
15-
"ocf-data-sampler>=0.6.0",
15+
"ocf-data-sampler>=1.0.8",
1616
"numpy",
1717
"pandas",
1818
"matplotlib",

tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,7 @@ def raw_late_fusion_model_kwargs_generation_history(model_minutes_kwargs) -> dic
327327
embedding_dim=None,
328328
include_sun=False,
329329
include_time=True,
330+
t0_embedding_dim=3,
330331
include_generation_history=True,
331332
forecast_minutes=480,
332333
history_minutes=60,

tests/test_data/data_config.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,7 @@ input_data:
125125
interval_start_minutes: -60
126126
interval_end_minutes: 480
127127
time_resolution_minutes: 30
128+
129+
t0_embedding:
130+
periods: ["1h", "24h"]
131+
embeddings: ["linear", "cyclic"]

0 commit comments

Comments
 (0)