Description
When I use the example code for training HoVer-Net, the following error always occurs during the training step:
RuntimeError Traceback (most recent call last)
Cell In[12], line 39
36 outputs = hovernet(images)
38 # compute loss
---> 39 loss = loss_hovernet(outputs=outputs, ground_truth=[masks, hv], n_classes=6)
41 # track loss
42 minibatch_train_losses.append(loss.item())
File ~/user/anaconda3/envs/pathml/lib/python3.9/site-packages/pathml/ml/models/hovernet.py:603, in loss_hovernet(outputs, ground_truth, n_classes)
600 nucleus_mask = true_mask[:, -1, :, :] == 0
602 # from Eq. 1 in HoVer-Net paper, loss function is composed of two terms for each branch.
--> 603 np_loss_dice = _dice_loss_np_head(np_out, true_mask)
604 np_loss_ce = _ce_loss_np_head(np_out, true_mask)
606 hv_loss_grad = _loss_hv_grad(hv, true_hv, nucleus_mask)
File ~/user/anaconda3/envs/pathml/lib/python3.9/site-packages/pathml/ml/models/hovernet.py:355, in _dice_loss_np_head(np_out, true_mask, epsilon)
353 true_mask = _convert_multiclass_mask_to_binary(true_mask)
354 true_mask = true_mask.type(torch.long)
--> 355 loss = dice_loss(logits=preds, true=true_mask, eps=epsilon)
356 return loss
File ~/user/anaconda3/envs/pathml/lib/python3.9/site-packages/pathml/ml/utils.py:200, in dice_loss(true, logits, eps)
198 num_classes = logits.shape[1]
199 if num_classes == 1:
--> 200 true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)]
201 true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
202 true_1_hot_f = true_1_hot[:, 0:1, :, :]
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)
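For reference, the failing step appears to be the one-hot indexing in dice_loss: the identity matrix is created on the CPU while the target mask lives on the GPU. A minimal sketch that reproduces the same error outside PathML (the variable names here are illustrative, not taken from the library):

```python
# Minimal sketch reproducing the device mismatch, assuming a CUDA-capable machine.
import torch

true = torch.zeros(1, 1, 4, 4, dtype=torch.long, device="cuda")  # GPU target mask
eye = torch.eye(2)              # identity matrix created on the CPU, as in utils.py line 200
one_hot = eye[true.squeeze(1)]  # raises: indices should be either on cpu
                                # or on the same device as the indexed tensor (cpu)
```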
When I try to modify the code at line 200 of pathml/ml/utils.py based on the error message, changing torch.eye(num_classes + 1) to torch.eye(num_classes + 1, device=true.device), it causes a new error:
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
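For completeness, here is a sketch of the modification I applied together with the debugging flag the error message suggests. The value-range check is only my own assumption about what might trigger the illegal memory access, not a confirmed cause, and one_hot_binary is an illustrative helper rather than PathML code:

```python
# Sketch of the device-aware change tried in pathml/ml/utils.py (dice_loss,
# num_classes == 1 branch), plus the debugging setting from the error message.
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"   # must be set before CUDA is initialized

import torch

def one_hot_binary(true: torch.Tensor) -> torch.Tensor:
    # Keep the identity matrix on the same device as the index tensor.
    eye = torch.eye(2, device=true.device)
    # Indices outside [0, 1] would be an out-of-range read on the GPU
    # (assumption: this is one possible source of the illegal memory access).
    assert 0 <= int(true.min()) and int(true.max()) <= 1, "unexpected mask values"
    true_1_hot = eye[true.squeeze(1)]               # (N, H, W, 2)
    return true_1_hot.permute(0, 3, 1, 2).float()   # (N, 2, H, W)
```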