Commit f823bbb

Fix a typo.
1 parent 6836cdd commit f823bbb

1 file changed (+6, -6)


qmp/algorithms/pretrain.py

Lines changed: 6 additions & 6 deletions
@@ -11,20 +11,20 @@
 from ..utility.optimizer import initialize_optimizer
 from ..utility.subcommand_dict import subcommand_dict
 
-@dataclassses.dataclass
+@dataclasses.dataclass
 class PretrainConfig:
     """
     Configuration for pretraining quantum many-body models.
     """
 
     common: CommonConfig
 
+    # Dataset path for pretraining
+    dataset_path: str
     # The learning rate for the local optimizer
     learning_rate: float = 1e-3
     # The name of the loss function to use
     loss_name: str = "sum_filtered_angle_scaled_log"
-    # Dataset path for pretraining
-    dataset_path: str
 
     def main(self, *, model_param: typing.Any = None, network_param: typing.Any = None) -> None:
         """
@@ -33,9 +33,9 @@ def main(self, *, model_param: typing.Any = None, network_param: typing.Any = No
 
         model, network, data = self.common.main(model_param=model_param, network_param=network_param)
 
-        dataset = torch.load(self.dataset_path, map_location="cpu", weight_only=True)
-        config = dataset["config"].to(device=self.common.device)
-        psi = dataset["psi"].to(device=self.common.device)
+        dataset = torch.load(self.dataset_path, map_location="cpu", weights_only=True)
+        config = dataset[0].to(device=self.common.device)
+        psi = dataset[1].to(device=self.common.device)
 
         optimizer = initialize_optimizer(
             network.parameters(),
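
Note on the field reorder in the first hunk: Python dataclasses require fields without default values to be declared before fields with defaults, so dataset_path: str has to sit above learning_rate and loss_name; with the old ordering the decorator raises a TypeError when the module is imported. A minimal sketch of the rule (the class name and path below are illustrative, not from the repository):

import dataclasses

@dataclasses.dataclass
class ExampleConfig:
    # Required fields (no default) must come first; placing dataset_path after
    # a defaulted field raises:
    # TypeError: non-default argument 'dataset_path' follows default argument
    dataset_path: str
    learning_rate: float = 1e-3

cfg = ExampleConfig(dataset_path="data/pretrain.pt")
print(cfg.learning_rate)  # 0.001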
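
Note on the loading change in the second hunk: the keyword accepted by torch.load is weights_only, not weight_only, and the switch from dataset["config"] / dataset["psi"] to dataset[0] / dataset[1] suggests the saved dataset file now holds a sequence of tensors (for example a tuple) rather than a dict. A minimal sketch of writing and reading such a file under that assumption; the tensor shapes, dtypes, and file name are invented for illustration, only the call pattern and index order mirror the diff:

import torch

# Illustrative tensors only; real pretraining data comes from elsewhere.
config = torch.randint(0, 2, (1024, 16))        # sampled configurations
psi = torch.randn(1024, dtype=torch.complex64)  # corresponding amplitudes

# Save as a tuple so that index 0 is config and index 1 is psi,
# matching the order the loader expects after this commit.
torch.save((config, psi), "pretrain_dataset.pt")

# weights_only=True restricts unpickling to tensors and simple containers,
# which a tuple of tensors satisfies; load on CPU, then move to the target device.
dataset = torch.load("pretrain_dataset.pt", map_location="cpu", weights_only=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
config = dataset[0].to(device=device)
psi = dataset[1].to(device=device)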
