Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions torchvision/models/detection/roi_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,18 @@ def _onnx_heatmaps_to_keypoints_loop(


def heatmaps_to_keypoints(maps, rois):
"""Extract predicted keypoint locations from heatmaps. Output has shape
(#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
for each keypoint.
"""Extract predicted keypoint locations from heatmaps.

Args:
maps (Tensor[K, N, H, W]): The predicted heatmaps, where K is the number of RoIs,
N is the number of keypoints, and H, W are the heatmap spatial dimensions.
rois (Tensor[K, 4]): The RoI boxes in ``(x1, y1, x2, y2)`` format.

Returns:
tuple:
- **xy_preds** (Tensor[K, N, 3]): The predicted keypoint locations, where the last
dimension contains ``(x, y, v)`` with x, y being coordinates and v being visibility (always 1).
- **scores** (Tensor[K, N]): The heatmap scores at the predicted keypoint locations.
"""
# This function converts a discrete image coordinate in a HEATMAP_SIZE x
# HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
Expand Down
Loading