Issue with Updating viewpoint_cam.cam_trans_delta and viewpoint_cam.cam_rot_delta in Optimizer #23

@WillLan

I am experiencing an issue where the `viewpoint_cam.cam_trans_delta` and `viewpoint_cam.cam_rot_delta` parameters are not updated as expected during pose optimization. After reviewing the code, I suspect the cause is that `update_pose(viewpoint_cam)` is executed inside the `with torch.no_grad():` block, which I believe prevents gradients from being computed and applied to these parameters.
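For reference, here is a minimal standalone check of what `torch.no_grad()` does and does not affect. It uses a hypothetical one-element parameter `w` and plain SGD instead of the camera deltas and Adam. In PyTorch, `loss.backward()` stores gradients in `.grad` before the `with torch.no_grad():` block is entered. The block only stops autograd from recording the operations executed inside it, so `optimizer.step()` still applies the stored gradients. PyTorch optimizers already run `step()` without gradient tracking internally.

```python
import torch

# Hypothetical one-element parameter standing in for a pose delta.
w = torch.nn.Parameter(torch.tensor([1.0]))
opt = torch.optim.SGD([w], lr=0.1)

loss = (w ** 2).sum()
loss.backward()                  # gradient is computed here: dloss/dw = 2 * w = 2.0
grad_before_step = w.grad.clone()

with torch.no_grad():
    opt.step()                   # still applies the stored gradient: w = 1.0 - 0.1 * 2.0
    y = w * 3                    # operations inside no_grad are not recorded in the graph

print(grad_before_step.item())   # 2.0
print(w.item())                  # approximately 0.8
print(y.requires_grad)           # False
```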

Here is the relevant code snippet:

```python
pose_optimizer = torch.optim.Adam([{"params": [viewpoint_cam.cam_trans_delta], "lr": opt.translation_lr_init},
                                   {"params": [viewpoint_cam.cam_rot_delta], "lr": opt.rotation_lr_init}])

gt_image = viewpoint_cam.original_image.cuda()

progress_bar = tqdm(range(0, pose_iteration), desc="Pose estimation progress")
for iteration in range(pose_iteration):
    voxel_visible_mask = prefilter_voxel(viewpoint_cam, gaussians, pipe, background)
    render_pkg = render(viewpoint_cam, gaussians, pipe, background, visible_mask=voxel_visible_mask, retain_grad=True)
    image = render_pkg["render"]
    rendered_depth = render_pkg["depth"][0]
    occ_mask = get_occlusion_mask(viewpoint_cam=pre_viewpoint_cam1, viewpoint_cam2=viewpoint_cam,
                                  depth=pre_rendered_depth, device=pre_rendered_depth.device, thresh=0.001).detach()

    Ll1 = l1_loss(image[:, occ_mask], gt_image[:, occ_mask])
    loss = Ll1

    # 2D correspondence loss
    if opt.loss_2d_correspondence_weight > 0 and viewpoint_cam.uid > 0:
        view1 = scene.getTrainCameras()[viewpoint_cam.uid - 1]
        view2 = viewpoint_cam
        kp0, kp1, conf = view2.kp0.cuda(), view2.kp1.cuda(), view2.conf.cuda()
        loss_2d = correspondence_2d_loss(kp0, kp1, conf, rendered_depth,
                                         view2.view_world_transform, view1.world_view_transform, view2.intrinsic)
        loss += loss_2d * opt.loss_2d_correspondence_weight

    loss.backward()

    with torch.no_grad():
        pose_optimizer.step()
        pose_optimizer.zero_grad(set_to_none=True)
        gaussians.optimizer.zero_grad(set_to_none=True)
        gaussians.pose_optimizer.zero_grad(set_to_none=True)
        update_pose(viewpoint_cam)

        if iteration % 10 == 0:
            progress_bar.set_postfix({"Loss": f"{loss:.{7}f}"})
            progress_bar.update(10)
```
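One possibility worth ruling out depends on what `update_pose` does in this repository, because its implementation is not shown above. Suppose it follows a common pattern in Gaussian-splatting pose-optimization code: fold the current deltas into the stored pose, then reset `cam_trans_delta` and `cam_rot_delta` to zero. In that case, the deltas read as zero after every iteration even though the pose itself is moving. The sketch below assumes that pattern. `ToyCamera`, `update_pose_sketch`, `hat_se3`, `target_T`, and the learning rates are all hypothetical, and a toy translation loss replaces the rendering loss. The deltas are applied on the left through the SE(3) exponential, computed here with `torch.linalg.matrix_exp`.

```python
import torch


def hat_se3(tau: torch.Tensor) -> torch.Tensor:
    """Map tau = [rho (translation), omega (rotation)] to a 4x4 se(3) matrix."""
    rho, omega = tau[:3], tau[3:]
    zero = torch.zeros((), dtype=tau.dtype)
    skew = torch.stack([
        torch.stack([zero, -omega[2], omega[1]]),
        torch.stack([omega[2], zero, -omega[0]]),
        torch.stack([-omega[1], omega[0], zero]),
    ])
    top = torch.cat([skew, rho.unsqueeze(1)], dim=1)  # 3x4
    bottom = torch.zeros(1, 4, dtype=tau.dtype)        # last row of a twist is zero
    return torch.cat([top, bottom], dim=0)


class ToyCamera:
    """Hypothetical stand-in for viewpoint_cam: world-to-camera R, T plus learnable deltas."""

    def __init__(self) -> None:
        self.R = torch.eye(3)
        self.T = torch.zeros(3)
        self.cam_trans_delta = torch.nn.Parameter(torch.zeros(3))
        self.cam_rot_delta = torch.nn.Parameter(torch.zeros(3))

    def w2c(self) -> torch.Tensor:
        """Stored pose with the deltas applied on the left via the SE(3) exponential."""
        base = torch.eye(4)
        base[:3, :3] = self.R
        base[:3, 3] = self.T
        tau = torch.cat([self.cam_trans_delta, self.cam_rot_delta])
        return torch.linalg.matrix_exp(hat_se3(tau)) @ base


def update_pose_sketch(cam: ToyCamera) -> None:
    """Hypothetical update_pose: fold the deltas into R, T, then zero the deltas."""
    new_w2c = cam.w2c()
    cam.R = new_w2c[:3, :3].clone()
    cam.T = new_w2c[:3, 3].clone()
    cam.cam_trans_delta.data.fill_(0)
    cam.cam_rot_delta.data.fill_(0)


cam = ToyCamera()
target_T = torch.tensor([0.5, -0.2, 1.0])  # hypothetical ground-truth translation
pose_optimizer = torch.optim.Adam([
    {"params": [cam.cam_trans_delta], "lr": 0.05},
    {"params": [cam.cam_rot_delta], "lr": 0.01},
])

initial_loss = (cam.T - target_T).pow(2).sum().item()
for iteration in range(200):
    # Toy loss standing in for the rendering loss; it depends on the deltas.
    loss = (cam.w2c()[:3, 3] - target_T).pow(2).sum()
    loss.backward()  # .grad is filled here, before entering no_grad
    with torch.no_grad():
        pose_optimizer.step()  # applies the stored gradients to the deltas
        pose_optimizer.zero_grad(set_to_none=True)
        update_pose_sketch(cam)  # moves R, T and resets the deltas to zero
final_loss = (cam.T - target_T).pow(2).sum().item()

print("deltas:", cam.cam_trans_delta.tolist(), cam.cam_rot_delta.tolist())
print(f"translation error: {initial_loss:.4f} -> {final_loss:.4f}")
```

If the repository's `update_pose` does reset the deltas, comparing the camera's stored pose across iterations is a more reliable way to tell whether pose optimization is progressing than reading `cam_trans_delta` and `cam_rot_delta`. Use whatever attributes hold that pose (`R` and `T` in this sketch).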
