Commit 3cbca83

determine best checkpoint using mAP instead of val_loss
Parent: 21bf1e7

2 files changed: 9 additions, 3 deletions

2 files changed

+9
-3
lines changed

keypoint_detection/models/detector.py

Lines changed: 6 additions & 0 deletions
@@ -159,6 +159,8 @@ def __init__(
         # this is for later reference (e.g. checkpoint loading) and consistency.
         self.save_hyperparameters(ignore=["**kwargs", "backbone"])
 
+        self._most_recent_val_mean_ap = 0.0  # cache the most recent validation mean AP and log it every epoch, so that the checkpoint can be chosen based on it.
+
     def forward(self, x: torch.Tensor):
         """
         x shape must be of shape (N,3,H,W)

@@ -386,6 +388,9 @@ def log_and_reset_mean_ap(self, mode: str):
         self.log(f"{mode}/meanAP", mean_ap)
         self.log(f"{mode}/meanAP/meanAP", mean_ap)
 
+        if mode == "validation":
+            self._most_recent_val_mean_ap = mean_ap
+
     def training_epoch_end(self, outputs):
         """
         Called on the end of a training epoch.

@@ -401,6 +406,7 @@ def validation_epoch_end(self, outputs):
         """
         if self.is_ap_epoch():
             self.log_and_reset_mean_ap("validation")
+        self.log("checkpointing_metrics/valmeanAP", self._most_recent_val_mean_ap)
 
     def test_epoch_end(self, outputs):
         """

keypoint_detection/tasks/train_utils.py

Lines changed: 3 additions & 3 deletions
@@ -82,13 +82,13 @@ def create_pl_trainer(hparams: dict, wandb_logger: WandbLogger) -> Trainer:
     )
     # cf https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.loggers.wandb.html
 
-    # would be better to use mAP metric for checkpointing, but this is not calculated every epoch because it is rather expensive
-    # epoch_loss still correlates rather well though
+    # would be better to use mAP metric for checkpointing, but this is not calculated every epoch
+    # so I manually log the last known value to make the callback happy.
     # only store the best checkpoint and only the weights
     # so cannot be used to resume training but only for inference
     # saves storage though and training the detector is usually cheap enough to retrain it from scratch if you need specific weights etc.
     checkpoint_callback = ModelCheckpoint(
-        monitor="validation/epoch_loss", mode="min", save_weights_only=True, save_top_k=1
+        monitor="checkpointing_metrics/valmeanAP", mode="max", save_weights_only=True, save_top_k=1
     )
 
     trainer = pl.Trainer(**trainer_kwargs, callbacks=[early_stopping, checkpoint_callback])
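
With the metric now guaranteed to be present every epoch, the callback side reduces to pointing ModelCheckpoint at the new key and flipping the mode from "min" (loss) to "max" (AP). A sketch of how the two pieces connect, assuming the ToyDetector above; max_epochs, patience, and the EarlyStopping monitor value are illustrative choices, not taken from the repo.

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# keep only the single best set of weights, selected by the cached mean AP;
# mode="max" because higher mAP is better (epoch_loss used mode="min")
checkpoint_callback = ModelCheckpoint(
    monitor="checkpointing_metrics/valmeanAP",
    mode="max",
    save_weights_only=True,
    save_top_k=1,
)
early_stopping = EarlyStopping(monitor="validation/epoch_loss", mode="min", patience=10)

trainer = pl.Trainer(max_epochs=50, callbacks=[early_stopping, checkpoint_callback])
# trainer.fit(ToyDetector(), train_dataloaders=..., val_dataloaders=...)

Note that the cache initializes to 0.0, so the first actually computed mAP should win the comparison and produce the first meaningful checkpoint.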
