Skip to content

Commit 02b1d12

Browse files
committed
...
1 parent 2946bea commit 02b1d12

File tree

1 file changed

+16
-35
lines changed

1 file changed

+16
-35
lines changed

qmb/imag.py

Lines changed: 16 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,10 @@ class _DynamicLanczos:
3535
batch_count_apply_within: int
3636
single_extend: bool
3737
first_extend: bool
38-
b: bool = False
3938

40-
def _extend(self, psi: torch.Tensor, basic_configs: torch.Tensor | None = None) -> None:
39+
def _extend(self, psi: torch.Tensor = None, basic_configs: torch.Tensor | None = None) -> None:
40+
if psi is None:
41+
psi = self.psi
4142
if basic_configs is None:
4243
basic_configs = self.configs
4344
logging.info("Extending basis...")
@@ -47,7 +48,7 @@ def _extend(self, psi: torch.Tensor, basic_configs: torch.Tensor | None = None)
4748

4849
import time
4950
begin0 = time.time()
50-
extra = self.model.find_relative(basic_configs, psi, self.count_extend, self.configs)
51+
extra = self.model.find_relative(basic_configs, psi, 10000000, self.configs)
5152
self.configs = torch.cat([self.configs, extra])
5253
begin = time.time()
5354
a = self.model.apply_within(basic_configs, psi, extra)
@@ -59,6 +60,8 @@ def _extend(self, psi: torch.Tensor, basic_configs: torch.Tensor | None = None)
5960
self.psi = torch.nn.functional.pad(self.psi, (0, count_selected - count_core))
6061
logging.info("Basis extended from %d to %d", count_core, count_selected)
6162

63+
return self
64+
6265
def run(self) -> typing.Iterable[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
6366
"""
6467
Run the Lanczos algorithm.
@@ -106,18 +109,13 @@ def run(self) -> typing.Iterable[tuple[torch.Tensor, torch.Tensor, torch.Tensor]
106109
yield energy, self.configs, psi
107110
else:
108111
# Extend the configuration, during processing the dynamic lanczos.
109-
first = True
110-
count = 0
111112
for step in range(1 + self.step):
112113
for _, [alpha, beta, v] in zip(range(1 + step), self._run()):
113114
pass
114115
energy, psi = self._eigh_tridiagonal(alpha, beta, v)
115116
yield energy, self.configs, psi
116-
if not first and self.b:
117-
break
118117
if step != self.step:
119118
self._extend(v[-1])
120-
first = False
121119

122120
def _run(self) -> typing.Iterable[tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]]:
123121
"""
@@ -382,9 +380,11 @@ def main(self, *, model_param: typing.Any = None, network_param: typing.Any = No
382380

383381
logging.info("Sampling configurations from last iteration")
384382
configs, original_psi = data["imag"]["pool"]
385-
top20436 = original_psi.abs().argsort(descending=True)[:20436]
386-
configs = configs[top20436]
387-
original_psi = original_psi[top20436]
383+
import os
384+
if "C" in os.environ:
385+
top20436 = original_psi.abs().argsort(descending=True)[:int(os.environ["C"])]
386+
configs = configs[top20436]
387+
original_psi = original_psi[top20436]
388388
logging.info("Sampling completed, unique configurations count: %d", len(configs))
389389

390390
for target_energy, configs, original_psi in _DynamicLanczos(
@@ -395,27 +395,8 @@ def main(self, *, model_param: typing.Any = None, network_param: typing.Any = No
395395
threshold=self.krylov_threshold,
396396
count_extend=0,
397397
batch_count_apply_within=self.local_batch_count_apply_within,
398-
single_extend=self.krylov_single_extend,
399-
first_extend=self.krylov_extend_first,
400-
).run():
401-
logging.info("The current energy is %.10f where the sampling count is %d", target_energy.item(), len(configs))
402-
writer.add_scalar("imag/lanczos/energy", target_energy, data["imag"]["lanczos"]) # type: ignore[no-untyped-call]
403-
writer.add_scalar("imag/lanczos/error", target_energy - model.ref_energy, data["imag"]["lanczos"]) # type: ignore[no-untyped-call]
404-
data["imag"]["lanczos"] += 1
405-
406-
logging.info("Computing the target for local optimization")
407-
target_energy: torch.Tensor
408-
for target_energy, configs, original_psi in _DynamicLanczos(
409-
model=model,
410-
configs=configs,
411-
psi=original_psi,
412-
step=self.krylov_iteration,
413-
threshold=self.krylov_threshold,
414-
count_extend=self.krylov_extend_count,
415-
batch_count_apply_within=self.local_batch_count_apply_within,
416-
single_extend=self.krylov_single_extend,
417-
first_extend=self.krylov_extend_first,
418-
b=True,
398+
single_extend=True,
399+
first_extend=False,
419400
).run():
420401
logging.info("The current energy is %.10f where the sampling count is %d", target_energy.item(), len(configs))
421402
writer.add_scalar("imag/lanczos/energy", target_energy, data["imag"]["lanczos"]) # type: ignore[no-untyped-call]
@@ -430,9 +411,9 @@ def main(self, *, model_param: typing.Any = None, network_param: typing.Any = No
430411
threshold=self.krylov_threshold,
431412
count_extend=0,
432413
batch_count_apply_within=self.local_batch_count_apply_within,
433-
single_extend=self.krylov_single_extend,
434-
first_extend=self.krylov_extend_first,
435-
).run():
414+
single_extend=True,
415+
first_extend=False,
416+
)._extend().run():
436417
logging.info("The current energy is %.10f where the sampling count is %d", target_energy.item(), len(configs))
437418
writer.add_scalar("imag/lanczos/energy", target_energy, data["imag"]["lanczos"]) # type: ignore[no-untyped-call]
438419
writer.add_scalar("imag/lanczos/error", target_energy - model.ref_energy, data["imag"]["lanczos"]) # type: ignore[no-untyped-call]

0 commit comments

Comments
 (0)