Skip to content

Commit 0b89a58

Browse files
committed
rebel...
1 parent 37e2229 commit 0b89a58

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

qmb/vmc.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,17 +108,21 @@ def closure() -> torch.Tensor:
108108
with torch.no_grad():
109109
psi_dst = network(configs_dst)
110110
hamiltonian_psi_dst = model.apply_within(configs_dst, psi_dst, configs_src)
111+
weight = psi_src / hamiltonian_psi_dst
112+
loss = weight.var()
111113
num = psi_src.conj() @ hamiltonian_psi_dst
112114
den = psi_src.conj() @ psi_src.detach()
113115
energy = num / den
114116
energy = energy.real
115-
energy.backward() # type: ignore[no-untyped-call]
116-
return energy
117+
loss.backward() # type: ignore[no-untyped-call]
118+
loss.energy = energy.item()
119+
return loss
117120

118121
logging.info("Starting local optimization process")
119122

120123
for i in range(self.local_step):
121-
energy: torch.Tensor = optimizer.step(closure) # type: ignore[assignment,arg-type]
124+
loss: torch.Tensor = optimizer.step(closure) # type: ignore[assignment,arg-type]
125+
energy = loss.energy
122126
logging.info("Local optimization in progress, step: %d, energy: %.10f, ref energy: %.10f", i, energy.item(), model.ref_energy)
123127
writer.add_scalar("vmc/energy", energy, data["vmc"]["local"]) # type: ignore[no-untyped-call]
124128
writer.add_scalar("vmc/error", energy - model.ref_energy, data["vmc"]["local"]) # type: ignore[no-untyped-call]

0 commit comments

Comments
 (0)