@@ -35,9 +35,10 @@ class _DynamicLanczos:
3535 batch_count_apply_within : int
3636 single_extend : bool
3737 first_extend : bool
38- b : bool = False
3938
40- def _extend (self , psi : torch .Tensor , basic_configs : torch .Tensor | None = None ) -> None :
39+ def _extend (self , psi : torch .Tensor = None , basic_configs : torch .Tensor | None = None ) -> None :
40+ if psi is None :
41+ psi = self .psi
4142 if basic_configs is None :
4243 basic_configs = self .configs
4344 logging .info ("Extending basis..." )
@@ -47,7 +48,7 @@ def _extend(self, psi: torch.Tensor, basic_configs: torch.Tensor | None = None)
4748
4849 import time
4950 begin0 = time .time ()
50- extra = self .model .find_relative (basic_configs , psi , self . count_extend , self .configs )
51+ extra = self .model .find_relative (basic_configs , psi , 10000000 , self .configs )
5152 self .configs = torch .cat ([self .configs , extra ])
5253 begin = time .time ()
5354 a = self .model .apply_within (basic_configs , psi , extra )
@@ -59,6 +60,8 @@ def _extend(self, psi: torch.Tensor, basic_configs: torch.Tensor | None = None)
5960 self .psi = torch .nn .functional .pad (self .psi , (0 , count_selected - count_core ))
6061 logging .info ("Basis extended from %d to %d" , count_core , count_selected )
6162
63+ return self
64+
6265 def run (self ) -> typing .Iterable [tuple [torch .Tensor , torch .Tensor , torch .Tensor ]]:
6366 """
6467 Run the Lanczos algorithm.
@@ -106,18 +109,13 @@ def run(self) -> typing.Iterable[tuple[torch.Tensor, torch.Tensor, torch.Tensor]
106109 yield energy , self .configs , psi
107110 else :
108111 # Extend the configuration, during processing the dynamic lanczos.
109- first = True
110- count = 0
111112 for step in range (1 + self .step ):
112113 for _ , [alpha , beta , v ] in zip (range (1 + step ), self ._run ()):
113114 pass
114115 energy , psi = self ._eigh_tridiagonal (alpha , beta , v )
115116 yield energy , self .configs , psi
116- if not first and self .b :
117- break
118117 if step != self .step :
119118 self ._extend (v [- 1 ])
120- first = False
121119
122120 def _run (self ) -> typing .Iterable [tuple [list [torch .Tensor ], list [torch .Tensor ], list [torch .Tensor ]]]:
123121 """
@@ -382,9 +380,11 @@ def main(self, *, model_param: typing.Any = None, network_param: typing.Any = No
382380
383381 logging .info ("Sampling configurations from last iteration" )
384382 configs , original_psi = data ["imag" ]["pool" ]
385- top20436 = original_psi .abs ().argsort (descending = True )[:20436 ]
386- configs = configs [top20436 ]
387- original_psi = original_psi [top20436 ]
383+ import os
384+ if "C" in os .environ :
385+ top20436 = original_psi .abs ().argsort (descending = True )[:int (os .environ ["C" ])]
386+ configs = configs [top20436 ]
387+ original_psi = original_psi [top20436 ]
388388 logging .info ("Sampling completed, unique configurations count: %d" , len (configs ))
389389
390390 for target_energy , configs , original_psi in _DynamicLanczos (
@@ -395,27 +395,8 @@ def main(self, *, model_param: typing.Any = None, network_param: typing.Any = No
395395 threshold = self .krylov_threshold ,
396396 count_extend = 0 ,
397397 batch_count_apply_within = self .local_batch_count_apply_within ,
398- single_extend = self .krylov_single_extend ,
399- first_extend = self .krylov_extend_first ,
400- ).run ():
401- logging .info ("The current energy is %.10f where the sampling count is %d" , target_energy .item (), len (configs ))
402- writer .add_scalar ("imag/lanczos/energy" , target_energy , data ["imag" ]["lanczos" ]) # type: ignore[no-untyped-call]
403- writer .add_scalar ("imag/lanczos/error" , target_energy - model .ref_energy , data ["imag" ]["lanczos" ]) # type: ignore[no-untyped-call]
404- data ["imag" ]["lanczos" ] += 1
405-
406- logging .info ("Computing the target for local optimization" )
407- target_energy : torch .Tensor
408- for target_energy , configs , original_psi in _DynamicLanczos (
409- model = model ,
410- configs = configs ,
411- psi = original_psi ,
412- step = self .krylov_iteration ,
413- threshold = self .krylov_threshold ,
414- count_extend = self .krylov_extend_count ,
415- batch_count_apply_within = self .local_batch_count_apply_within ,
416- single_extend = self .krylov_single_extend ,
417- first_extend = self .krylov_extend_first ,
418- b = True ,
398+ single_extend = True ,
399+ first_extend = False ,
419400 ).run ():
420401 logging .info ("The current energy is %.10f where the sampling count is %d" , target_energy .item (), len (configs ))
421402 writer .add_scalar ("imag/lanczos/energy" , target_energy , data ["imag" ]["lanczos" ]) # type: ignore[no-untyped-call]
@@ -430,9 +411,9 @@ def main(self, *, model_param: typing.Any = None, network_param: typing.Any = No
430411 threshold = self .krylov_threshold ,
431412 count_extend = 0 ,
432413 batch_count_apply_within = self .local_batch_count_apply_within ,
433- single_extend = self . krylov_single_extend ,
434- first_extend = self . krylov_extend_first ,
435- ).run ():
414+ single_extend = True ,
415+ first_extend = False ,
416+ )._extend (). run ():
436417 logging .info ("The current energy is %.10f where the sampling count is %d" , target_energy .item (), len (configs ))
437418 writer .add_scalar ("imag/lanczos/energy" , target_energy , data ["imag" ]["lanczos" ]) # type: ignore[no-untyped-call]
438419 writer .add_scalar ("imag/lanczos/error" , target_energy - model .ref_energy , data ["imag" ]["lanczos" ]) # type: ignore[no-untyped-call]
0 commit comments