-
Notifications
You must be signed in to change notification settings - Fork 12
Open
Description
Reproduce with the code below:
fn scatter_add() -> candle_core::Result<()> {
// let device = Device::new_cuda(0)?;
let device = Device::new_cuda(0)?;
let logits_idx_end = 32000_usize;
let logits_idx = Tensor::arange(0_u32, logits_idx_end as u32, &device)?.reshape((1, 32000))?;
let logits_idx_inv = Tensor::zeros_like(&logits_idx)?;
let src = Tensor::arange(0_u32, logits_idx_end as u32, logits_idx.device())?
.expand(logits_idx.shape())?
.contiguous()?;
let start = std::time::Instant::now();
let logits_idx_inv = candle_ext::F::scatter(&logits_idx_inv, &logits_idx, &src, D::Minus1)?;
match device {
Device::Cuda(cuda_dev) => {
cuda_dev.synchronize();
}
_ => {}
}
println!("scatter cost {:?}/{}", start.elapsed(), logits_idx_end);
Ok(())
}rust result(run 2times in the same process)
scatter cost 3.288861ms/32000
scatter cost 3.271358ms/32000
import time

import torch

# Run the identical measurement twice in one process: the first pass includes
# any one-time start-up cost, the second shows the cost once that is paid.
for label in ("first cuda scatter cost ", "cuda scatter cost "):
    logits_idx = torch.arange(0, 32000, dtype=torch.int64, device='cuda').reshape(1, 32000)
    logits_idx_inv = torch.zeros_like(logits_idx)
    src = torch.arange(0, 32000, device='cuda').expand(logits_idx.shape)
    # Finish the setup work before starting the clock so only scatter_ is timed.
    torch.cuda.synchronize()
    # perf_counter_ns is monotonic, unlike the wall-clock time_ns.
    start_time = time.perf_counter_ns()
    logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1, index=logits_idx, src=src)
    # Wait for the scatter to complete before reading the clock.
    torch.cuda.synchronize()
    print(label, time.perf_counter_ns() - start_time, "ns", logits_idx.shape, logits_idx_inv.shape)
# Python result (run 2 times in the same process):
first cuda scatter cost 3191597 ns torch.Size([1, 32000]) torch.Size([1, 32000])
cuda scatter cost 38734 ns torch.Size([1, 32000]) torch.Size([1, 32000])
It seems PyTorch runs much faster after warmup.
Metadata
Metadata
Assignees
Labels
No labels