Skip to content

Commit 5db1bdf

Browse files
authored
add Dinov2 backbone (#43)
* add dinoV2 with linear head * format * improve dino decoder and freeze encoder * improve naming in model docs * fix deprecated mamba setup in pytest workflow * fix typo * try to fix dependency hell * dependency merde * invalidate cached env upon changes
1 parent 3cbca83 commit 5db1bdf

File tree

8 files changed

+183
-17
lines changed

8 files changed

+183
-17
lines changed

.github/workflows/pytest.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,14 @@ jobs:
1414
run: |
1515
sudo apt-get install fonts-freefont-ttf
1616
- name: install conda env with micromamba
17-
uses: mamba-org/provision-with-micromamba@main
17+
uses: mamba-org/setup-micromamba@v2
1818
with:
1919
channel-priority: strict
2020
environment-file: environment.yaml
2121
cache-env: true
22+
# add hash of environment.yaml and setup.py
23+
cache-environment-key: environment-${{ steps.date.outputs.date }} -${{ hashFiles('environment.yaml') }} -${{ hashFiles('setup.py') }}
24+
cache-downloads-key: downloads-${{ steps.date.outputs.date }} - ${{ hashFiles('environment.yaml') }} -${{ hashFiles('setup.py') }}
2225
- name: Conda list
2326
shell: bash -l {0}
2427
run: conda list

environment.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@ dependencies:
88
- pytorch=1.13
99
- pytorch-cuda=11.7
1010
- torchvision
11+
- mkl==2024.0 # bug, https://github.com/pytorch/pytorch/issues/123097
1112
- pip
1213
- pip:
1314
- wandb>=0.13.7 # quick fix, gh actions failed to install wandb https://github.com/tlpss/keypoint-detection/actions/runs/3204224778/jobs/5235259475
14-
- setuptools==59.5.0
15+
- setuptools==70.0
1516
- -e .

keypoint_detection/data/coco_dataset.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def __getitem__(self, index) -> Tuple[torch.Tensor, IMG_KEYPOINTS_TYPE]:
117117
image = self.image_to_tensor_transform(image)
118118
return image, keypoints
119119

120-
def prepare_dataset(self):
120+
def prepare_dataset(self): # noqa: C901
121121
"""Prepares the dataset to map from COCO to (img, [keypoints for each channel])
122122
123123
Returns:
@@ -161,7 +161,11 @@ def prepare_dataset(self):
161161
for semantic_type, keypoints in keypoint_dict.items():
162162
for keypoint in keypoints:
163163

164-
if min(keypoint[:2]) < 0 or keypoint[0] > img_dict[img_id].width or keypoint[1] > img_dict[img_id].height:
164+
if (
165+
min(keypoint[:2]) < 0
166+
or keypoint[0] > img_dict[img_id].width
167+
or keypoint[1] > img_dict[img_id].height
168+
):
165169
print("keypoint outside of image, ignoring.")
166170
continue
167171
if self.is_keypoint_visible(keypoint):

keypoint_detection/models/backbones/backbone_factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from keypoint_detection.models.backbones.base_backbone import Backbone
55
from keypoint_detection.models.backbones.convnext_unet import ConvNeXtUnet
66
from keypoint_detection.models.backbones.dilated_cnn import DilatedCnn
7+
from keypoint_detection.models.backbones.dinov2 import DinoV2Up
78
from keypoint_detection.models.backbones.maxvit_unet import MaxVitPicoUnet, MaxVitUnet
89
from keypoint_detection.models.backbones.mobilenetv3 import MobileNetV3
910
from keypoint_detection.models.backbones.s3k import S3K
@@ -20,6 +21,7 @@ class BackboneFactory:
2021
S3K,
2122
DilatedCnn,
2223
MobileNetV3,
24+
DinoV2Up,
2325
]
2426

2527
@staticmethod
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import timm
2+
import torch
3+
import torch.nn as nn
4+
from torchvision.models.feature_extraction import create_feature_extractor
5+
from torchvision.transforms import Resize
6+
7+
from keypoint_detection.models.backbones.base_backbone import Backbone
8+
9+
10+
class UpSamplingBlock(nn.Module):
    """
    A minimal upsampling block (its parameters are learnt from scratch, so it is kept small).

    x --> 2x upsample --> conv --> batchnorm --> relu
    """

    def __init__(self, n_channels_in, n_channels_out, kernel_size):
        super().__init__()

        self.conv1 = nn.Conv2d(
            in_channels=n_channels_in,
            out_channels=n_channels_out,
            kernel_size=kernel_size,
            bias=False,  # bias would be cancelled by the batchnorm that follows
            padding="same",
        )
        self.norm1 = nn.BatchNorm2d(n_channels_out)
        self.relu1 = nn.ReLU()

    def forward(self, x):
        # interpolate's default mode is nearest-neighbour, which (unlike bilinear) is deterministic
        upsampled = nn.functional.interpolate(x, scale_factor=2.0)
        out = self.relu1(self.norm1(self.conv1(upsampled)))
        # a second conv (as in the original UNet up-block) decreased performance,
        # probably because a small dataset lacks the data to fit the extra parameters
        return out
44+
class DinoV2Up(Backbone):
    """
    Backbone based on a frozen DINOv2 ViT-S model and a number of conv-based upsampling
    blocks to go from patch-level to pixel-level features.
    Images are resized to 518x518 before being fed to the ViT (518 = 37 patches x 14 px).

    The DINOv2 paper considers adding both a linear layer and a full-blown DPT head to the
    intermediate outputs of the last 4 blocks of the ViT. This model can be considered a
    simpler alternative to the DPT head that also aims to increase the resolution of the
    features.

    The upsampling blocks add about 6M params, bringing the total to ~28M params.
    Only these blocks are trained; the DINO encoder is frozen.

    DINOv2 paper: https://arxiv.org/pdf/2304.07193#page=13.87
    DPT paper: https://arxiv.org/abs/2103.13413

    NOTE: the head is most likely not the optimal architecture. Reducing the #params in the
    decoder does not work for sure, and unfreezing the DINO model doesn't work either
    (for small datasets).
    """

    def __init__(self, **kwargs):
        super().__init__()
        # ViT-S/14 DINOv2: 384-dim tokens, 14x14 patches -> 37x37 patch grid at 518x518.
        self.encoder = timm.create_model(
            "vit_small_patch14_dinov2.lvd142m",
            pretrained=True,
            num_classes=0,  # remove classifier nn.Linear
        )

        # this DINOv2 ViT expects a fixed 518x518 input resolution
        self.img_resizer = Resize((518, 518))

        # tap the token outputs of the last 4 transformer blocks, as in the DINOv2 paper
        self.feature_extractor = create_feature_extractor(
            self.encoder, ["blocks.8", "blocks.9", "blocks.10", "blocks.11"]
        )

        # freeze the feature extractor; only the upsampling blocks are trained
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

        self.upsamplingblocks = nn.ModuleList(
            [
                UpSamplingBlock(4 * 384, 384, 3),  # 4 concatenated ViT-S feature maps
                UpSamplingBlock(384, 192, 3),
                UpSamplingBlock(192, 96, 3),
                UpSamplingBlock(96, 96, 3),
            ]
        )

    def forward(self, x):
        """Map an image batch (B, 3, H, W) to pixel-level features (B, 96, H, W)."""
        orig_image_shape = x.shape[-2:]
        x = self.img_resizer(x)
        features = self.feature_extractor(x)  # dict of 4x (B, 1370, 384) token tensors
        features = list(features.values())
        # concatenate the features along the token embedding dim -> (B, 1370, 4*384)
        features = torch.cat(features, dim=2)
        # drop the class token -> (B, 1369, 4*384); 1369 = 37*37 patch tokens
        features = features[:, 1:]

        # reshape to (B, 37, 37, 4*384); use reshape (not view): the class-token slice
        # above makes the tensor non-contiguous for batch sizes > 1
        features = features.reshape(features.shape[0], 37, 37, -1)

        # permute to (B, 4*384, 37, 37)
        features = features.permute(0, 3, 1, 2)

        # upsample 3 times 2x: 37 -> 74 -> 148 -> 296
        for i in range(3):
            features = self.upsamplingblocks[i](features)

        # resize to 518/2 = 259 so that the final 2x upsample lands exactly on 518
        features = nn.functional.interpolate(features, size=(259, 259))
        # upsample a final time to 518
        features = self.upsamplingblocks[-1](features)

        # now resize back to the original image shape
        features = nn.functional.interpolate(features, size=orig_image_shape)
        return features

    def get_n_channels_out(self):
        # the final upsampling block outputs 96 channels
        return 96
125+
if __name__ == "__main__":
    # quick smoke test: build the backbone, report trainable params, run a dummy forward pass
    model = DinoV2Up()

    trainable_param_counts = [p.numel() for p in model.parameters() if p.requires_grad]
    print(f"num trainable params = {sum(trainable_param_counts)/10**6:.2f} M")

    dummy_input = torch.zeros((1, 3, 512, 512))
    output = model(dummy_input)
    print(output.shape)

keypoint_detection/models/detector.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def __init__(
159159
# this is for later reference (e.g. checkpoint loading) and consistency.
160160
self.save_hyperparameters(ignore=["**kwargs", "backbone"])
161161

162-
self._most_recent_val_mean_ap = 0.0 # used to store the most recent validation mean AP and log it in each epoch, so that checkpoint can be chosen based on this one.
162+
self._most_recent_val_mean_ap = 0.0 # used to store the most recent validation mean AP and log it in each epoch, so that checkpoint can be chosen based on this one.
163163

164164
def forward(self, x: torch.Tensor):
165165
"""
@@ -306,7 +306,7 @@ def log_channel_predictions_grids(self, image_grids, mode: str):
306306
for channel_configuration, grid in zip(self.keypoint_channel_configuration, image_grids):
307307
label = get_logging_label_from_channel_configuration(channel_configuration, mode)
308308
image_caption = "top: predicted heatmaps, bottom: gt heatmaps"
309-
self.logger.experiment.log({label: wandb.Image(grid, caption=image_caption,file_type="jpg")})
309+
self.logger.experiment.log({label: wandb.Image(grid, caption=image_caption, file_type="jpg")})
310310

311311
def visualize_predicted_keypoints(self, result_dict):
312312
images = result_dict["input_images"]
@@ -388,7 +388,7 @@ def log_and_reset_mean_ap(self, mode: str):
388388
self.log(f"{mode}/meanAP", mean_ap)
389389
self.log(f"{mode}/meanAP/meanAP", mean_ap)
390390

391-
if mode== "validation":
391+
if mode == "validation":
392392
self._most_recent_val_mean_ap = mean_ap
393393

394394
def training_epoch_end(self, outputs):

keypoint_detection/models/metrics.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
import torch
1313
from torchmetrics import Metric
14-
from torchmetrics.utilities import check_forward_full_state_property
1514

1615

1716
@dataclass
@@ -239,11 +238,35 @@ def _zero_aware_division(num: float, denom: float) -> float:
239238
return num / denom
240239

241240

242-
if __name__ == "__main__":
243-
print(
244-
check_forward_full_state_property(
245-
KeypointAPMetric,
246-
init_args={"keypoint_threshold_distance": 2.0},
247-
input_args={"detected_keypoints": [DetectedKeypoint(10, 20, 0.02)], "gt_keypoints": [Keypoint(10, 23)]},
248-
)
249-
)
241+
# if __name__ == "__main__":
242+
# print(
243+
# check_forward_full_state_property(
244+
# KeypointAPMetric,
245+
# init_args={"keypoint_threshold_distance": 2.0},
246+
# input_args={"detected_keypoints": [DetectedKeypoint(10, 20, 0.02)], "gt_keypoints": [Keypoint(10, 23)]},
247+
# )
248+
# )
249+
250+
251+
# if __name__ == "__main__":
252+
# import numpy as np
253+
# from sklearn.metrics import average_precision_score, precision_recall_curve
254+
# import matplotlib.pyplot as plt
255+
256+
# y_true = np.array([1, 1, 0, 1,0,0,0,0])
257+
# y_scores = np.array([0.1, 0.4, 0.35, 0.8,0.01,0.01,0.01,0.01])
258+
259+
# y_true = np.random.randint(0,2,100)
260+
# y_scores = np.random.rand(100)
261+
# sklearn_precisions, sklearn_recalls, _ = precision_recall_curve(y_true, y_scores)
262+
# sklearnAP = average_precision_score(y_true, y_scores)
263+
264+
# print(f"sklearn AP: {sklearnAP}")
265+
# my_precisions, my_recalls = calculate_precision_recall([ClassifiedKeypoint(None,None,y_scores[i],None,y_true[i]) for i in range(len(y_true))], sum(y_true))
266+
# myAP = calculate_ap_from_pr(my_precisions, my_recalls)
267+
# print(f"my AP: {myAP}")
268+
269+
# plt.plot(sklearn_recalls, sklearn_precisions, label=f"sklearn AP: {sklearnAP}")
270+
# plt.plot(my_recalls, my_precisions, label=f"my AP: {myAP}")
271+
# plt.legend()
272+
# plt.savefig("test.png")

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
"pytest",
2020
"pre-commit",
2121
"scikit-image",
22-
"albumentations",
22+
"albumentations<2.0", # >=2.0 requires higher version of pydantic than wandb currently allows
2323
"matplotlib",
2424
"pydantic>=2.0.0", # 2.0 has breaking changes
2525
"fiftyone",

0 commit comments

Comments
 (0)