Skip to content

Commit 124fee2

Browse files
committed
refactor(models): rename attributes to plural forms and cleanup logging
- Rename n_qubit, n_electron, and n_spin to n_qubits, n_electrons, and n_spins in fcidump.py for consistency. - Add type: ignore to yaml import in fcidump.py. - Remove redundant "Input arguments successfully parsed" log in ising.py. - Remove outdated comments in context.py.
1 parent 5778a91 commit 124fee2

File tree

3 files changed

+19
-21
lines changed

3 files changed

+19
-21
lines changed

qmp/models/fcidump.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pathlib
1010
import hashlib
1111
import torch
12-
import yaml
12+
import yaml # type: ignore[import-untyped]
1313
import openfermion
1414
import platformdirs
1515
from ..networks.mlp import WaveFunctionElectronUpDown as MlpWaveFunction
@@ -151,14 +151,14 @@ def __init__(self, args: ModelConfig) -> None:
151151
torch.save((self.hamiltonian.site, self.hamiltonian.kind, self.hamiltonian.coef), cache_file)
152152
logging.info("OpenFermion Hamiltonian successfully cached")
153153

154-
self.n_qubit: int = n_orbit * 2
155-
self.n_electron: int = n_electron
156-
self.n_spin: int = n_spin
154+
self.n_qubits: int = n_orbit * 2
155+
self.n_electrons: int = n_electron
156+
self.n_spins: int = n_spin
157157
logging.info(
158158
"Identified %d qubits, %d electrons and %d spin",
159-
self.n_qubit,
160-
self.n_electron,
161-
self.n_spin,
159+
self.n_qubits,
160+
self.n_electrons,
161+
self.n_spins,
162162
)
163163

164164
self.ref_energy: float
@@ -201,7 +201,7 @@ def show_config(self, config: torch.Tensor) -> str:
201201
string = "".join(f"{i:08b}"[::-1] for i in config.cpu().numpy())
202202
return (
203203
"["
204-
+ "".join(self._show_config_site(string[index : index + 2]) for index in range(0, self.n_qubit, 2))
204+
+ "".join(self._show_config_site(string[index : index + 2]) for index in range(0, self.n_qubits, 2))
205205
+ "]"
206206
)
207207

@@ -238,11 +238,11 @@ def create(self, model: Model) -> NetworkProto:
238238
logging.info("Hidden layer widths: %a", self.hidden)
239239

240240
network = MlpWaveFunction(
241-
double_sites=model.n_qubit,
241+
double_sites=model.n_qubits,
242242
physical_dim=2,
243243
is_complex=True,
244-
spin_up=(model.n_electron + model.n_spin) // 2,
245-
spin_down=(model.n_electron - model.n_spin) // 2,
244+
spin_up=(model.n_electrons + model.n_spins) // 2,
245+
spin_down=(model.n_electrons - model.n_spins) // 2,
246246
hidden_size=self.hidden,
247247
ordering=+1,
248248
)
@@ -298,11 +298,11 @@ def create(self, model: Model) -> NetworkProto:
298298
)
299299

300300
network = TransformersWaveFunction(
301-
double_sites=model.n_qubit,
301+
double_sites=model.n_qubits,
302302
physical_dim=2,
303303
is_complex=True,
304-
spin_up=(model.n_electron + model.n_spin) // 2,
305-
spin_down=(model.n_electron - model.n_spin) // 2,
304+
spin_up=(model.n_electrons + model.n_spins) // 2,
305+
spin_down=(model.n_electrons - model.n_spins) // 2,
306306
embedding_dim=self.embedding_dim,
307307
heads_num=self.heads_num,
308308
feed_forward_dim=self.feed_forward_dim,
@@ -336,10 +336,10 @@ def create(self, model: Model) -> NetworkProto:
336336
logging.info("Hidden layer widths: %a", self.hidden)
337337

338338
network = MlpWaveFunctionElectron(
339-
sites=model.n_qubit,
339+
sites=model.n_qubits,
340340
physical_dim=2,
341341
is_complex=True,
342-
electrons=model.n_electron,
342+
electrons=model.n_electrons,
343343
hidden_size=self.hidden,
344344
ordering=+1,
345345
)
@@ -394,10 +394,10 @@ def create(self, model: Model) -> NetworkProto:
394394
)
395395

396396
network = TransformersWaveFunctionElectron(
397-
sites=model.n_qubit,
397+
sites=model.n_qubits,
398398
physical_dim=2,
399399
is_complex=True,
400-
electrons=model.n_electron,
400+
electrons=model.n_electrons,
401401
embedding_dim=self.embedding_dim,
402402
heads_num=self.heads_num,
403403
feed_forward_dim=self.feed_forward_dim,

qmp/models/ising.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,6 @@ def _z(i: int, j: int) -> tuple[tuple[tuple[tuple[int, int], ...], complex], ...
167167
return hamiltonian
168168

169169
def __init__(self, args: ModelConfig) -> None:
170-
logging.info("Input arguments successfully parsed")
171-
172170
self.m: int = args.m
173171
self.n: int = args.n
174172
logging.info(

qmp/utility/context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
@dataclasses.dataclass
2020
class RuntimeContext:
2121
"""
22-
This class defines the common runtime environment (logging, device, random seed, checkpoints).
22+
This class defines the common runtime environment.
2323
"""
2424

2525
# The log path for parent job job name, it is only used for loading the checkpoint from the parent job, leave empty to use the current job name

0 commit comments

Comments
 (0)