Description
I am experiencing an issue where the viewpoint_cam.cam_trans_delta and viewpoint_cam.cam_rot_delta parameters are not updated as expected during pose optimization. After reviewing the code, I believe the cause is that update_pose(viewpoint_cam) is executed inside the with torch.no_grad() block, which prevents gradients from being computed for these parameters and therefore prevents them from being updated.
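To test the no_grad hypothesis in isolation, the toy below reproduces the same ordering as the loop in this issue (loss.backward() outside the block, Adam step and zero_grad inside torch.no_grad()). It is a minimal sketch: the two zero tensors are hypothetical stand-ins for cam_trans_delta and cam_rot_delta, and a simple quadratic loss replaces the rendering loss.

```python
import torch

# Hypothetical stand-ins for viewpoint_cam.cam_trans_delta / cam_rot_delta.
trans_delta = torch.zeros(3, requires_grad=True)
rot_delta = torch.zeros(3, requires_grad=True)

# Same optimizer layout as the issue: one param group per delta.
optimizer = torch.optim.Adam([
    {"params": [trans_delta], "lr": 1e-2},
    {"params": [rot_delta], "lr": 1e-2},
])

target = torch.tensor([1.0, -2.0, 0.5])

# Toy loss that depends on both deltas (replaces the rendering loss).
loss = ((trans_delta - target) ** 2).sum() + ((rot_delta + target) ** 2).sum()
loss.backward()  # autograd populates .grad here, outside the no_grad block

# Both deltas have gradients before the no_grad block is entered.
assert trans_delta.grad is not None and rot_delta.grad is not None

before = trans_delta.detach().clone()
with torch.no_grad():
    optimizer.step()                       # same ordering as the issue's loop
    optimizer.zero_grad(set_to_none=True)

print("changed:", not torch.equal(before, trans_delta.detach()))  # prints "changed: True"
```

In this toy the parameter does move after a step taken inside torch.no_grad(), so it can be worth checking separately whether the real deltas receive a non-None .grad from render(...) at all.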
Here is the relevant code snippet:
import torch
from tqdm import tqdm

# prefilter_voxel, render, get_occlusion_mask, l1_loss, correspondence_2d_loss and
# update_pose come from the project's codebase; viewpoint_cam, pre_viewpoint_cam1,
# pre_rendered_depth, gaussians, scene, pipe, background, opt and pose_iteration
# are defined earlier in the script.
pose_optimizer = torch.optim.Adam([
    {"params": [viewpoint_cam.cam_trans_delta], "lr": opt.translation_lr_init},
    {"params": [viewpoint_cam.cam_rot_delta], "lr": opt.rotation_lr_init},
])
gt_image = viewpoint_cam.original_image.cuda()
progress_bar = tqdm(range(0, pose_iteration), desc="Pose estimation progress")
for iteration in range(pose_iteration):
    # Render only the voxels visible from the current camera.
    voxel_visible_mask = prefilter_voxel(viewpoint_cam, gaussians, pipe, background)
    render_pkg = render(viewpoint_cam, gaussians, pipe, background,
                        visible_mask=voxel_visible_mask, retain_grad=True)
    image = render_pkg["render"]
    rendered_depth = render_pkg["depth"][0]

    # Photometric L1 loss restricted to pixels that pass the occlusion mask.
    occ_mask = get_occlusion_mask(viewpoint_cam=pre_viewpoint_cam1, viewpoint_cam2=viewpoint_cam,
                                  depth=pre_rendered_depth, device=pre_rendered_depth.device,
                                  thresh=0.001).detach()
    Ll1 = l1_loss(image[:, occ_mask], gt_image[:, occ_mask])
    loss = Ll1

    # 2D correspondence loss
    if opt.loss_2d_correspondence_weight > 0 and viewpoint_cam.uid > 0:
        view1 = scene.getTrainCameras()[viewpoint_cam.uid - 1]
        view2 = viewpoint_cam
        kp0, kp1, conf = view2.kp0.cuda(), view2.kp1.cuda(), view2.conf.cuda()
        loss_2d = correspondence_2d_loss(kp0, kp1, conf, rendered_depth,
                                         view2.view_world_transform, view1.world_view_transform,
                                         view2.intrinsic)
        loss += loss_2d * opt.loss_2d_correspondence_weight

    # Backpropagate the total loss.
    loss.backward()

    with torch.no_grad():
        pose_optimizer.step()
        pose_optimizer.zero_grad(set_to_none=True)
        gaussians.optimizer.zero_grad(set_to_none=True)
        gaussians.pose_optimizer.zero_grad(set_to_none=True)
        update_pose(viewpoint_cam)  # the call this issue is about
        if iteration % 10 == 0:
            progress_bar.set_postfix({"Loss": f"{loss:.7f}"})
            progress_bar.update(10)
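A small diagnostic can narrow down where the update is lost. The helper below is a hypothetical sketch (pose_grad_report is not part of the project): it assumes the camera object exposes cam_trans_delta and cam_rot_delta tensors, as in the snippet above, and reports the gradient norm of each, or None when no gradient reached that tensor. In the loop above it would be called right after loss.backward() and before pose_optimizer.step().

```python
from types import SimpleNamespace

import torch


def pose_grad_report(cam):
    """Return the gradient norm of each pose-delta tensor on `cam`.

    Assumes `cam` has `cam_trans_delta` and `cam_rot_delta` attributes.
    A value of None means autograd did not populate `.grad` for that tensor.
    """
    report = {}
    for name in ("cam_trans_delta", "cam_rot_delta"):
        grad = getattr(cam, name).grad
        report[name] = None if grad is None else grad.norm().item()
    return report


# Usage with a hypothetical stand-in camera: only the translation delta
# takes part in the loss, so only it receives a gradient.
cam = SimpleNamespace(
    cam_trans_delta=torch.zeros(3, requires_grad=True),
    cam_rot_delta=torch.zeros(3, requires_grad=True),
)
(cam.cam_trans_delta * 2.0).sum().backward()
report = pose_grad_report(cam)
print(report)  # cam_rot_delta maps to None because it was not used in the loss
```

If both entries are None in the real loop, the gradient is already missing when backward() finishes, before any code inside the no_grad block runs.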