Skip to content

Commit ff9ef83

Browse files
authored
Improve the hamiltonian subpackage. (#161)
2 parents 1db7b87 + 3f5d3f4 commit ff9ef83

File tree

3 files changed

+113
-113
lines changed

3 files changed

+113
-113
lines changed

qmp/hamiltonian/_hamiltonian_cpu.cpp

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,6 @@ auto apply_within_interface(
222222
TORCH_CHECK(coef.size(1) == 2, "coef must contain 2 elements for each term.");
223223

224224
auto result_sort_index = torch::arange(result_batch_size, torch::TensorOptions().dtype(torch::kInt64).device(device, device_id));
225-
226225
std::sort(
227226
reinterpret_cast<std::int64_t*>(result_sort_index.data_ptr()),
228227
reinterpret_cast<std::int64_t*>(result_sort_index.data_ptr()) + result_batch_size,
@@ -385,7 +384,6 @@ void find_relative_kernel(
385384
std::int8_t sign = parity ? -1 : +1;
386385
double real = sign * (coef[term_index][0] * psi[batch_index][0] - coef[term_index][1] * psi[batch_index][1]);
387386
double imag = sign * (coef[term_index][0] * psi[batch_index][1] + coef[term_index][1] * psi[batch_index][0]);
388-
// Currently, the weight is calculated as the probability of the state, but it can be changed to other values in the future.
389387
double weight = real * real + imag * imag;
390388
std::array<std::uint8_t, n_qubytes + sizeof(double) / sizeof(std::uint8_t)> value;
391389
for (std::int64_t i = 0; i < sizeof(double) / sizeof(uint8_t); ++i) {
@@ -542,81 +540,87 @@ auto find_relative_interface(
542540
return unique_nonzero_result_config;
543541
}
544542

545-
template<std::int64_t n_qubytes_local>
543+
constexpr std::int64_t max_uint8_t = 256;
544+
using largest_atomic_int = unsigned int; // The largest int type that can be atomicAdd/atomicSub
545+
using smallest_atomic_int = unsigned short int; // The smallest int type that can be atomicCAS
546+
547+
template<std::int64_t n_qubytes>
546548
struct dictionary_tree {
547-
using child_t = dictionary_tree<n_qubytes_local - 1>;
548-
child_t* children[256];
549-
int exist[256];
550-
long long nonzero_count;
549+
using child_t = dictionary_tree<n_qubytes - 1>;
550+
child_t* children[max_uint8_t];
551+
smallest_atomic_int exist[max_uint8_t];
552+
largest_atomic_int nonzero_count;
551553

552554
bool add(const std::uint8_t* begin, double real, double imag) {
553555
std::uint8_t index = *begin;
554556
if (children[index] == nullptr) {
555557
auto new_child = (child_t*)malloc(sizeof(child_t));
556-
if (new_child != nullptr) {
557-
memset(new_child, 0, sizeof(child_t));
558-
children[index] = new_child;
559-
exist[index] = 1;
560-
}
558+
assert(new_child != nullptr);
559+
memset(new_child, 0, sizeof(child_t));
560+
children[index] = new_child;
561+
exist[index] = 1;
561562
}
562563

563564
if (children[index]->add(begin + 1, real, imag)) {
564565
nonzero_count++;
565566
return true;
567+
} else {
568+
return false;
566569
}
567-
return false;
568570
}
569571

570572
template<std::int64_t n_total_qubytes>
571573
void collect(std::uint64_t index, std::array<std::uint8_t, n_total_qubytes>* configs, std::array<double, 2>* psi) {
572574
std::uint64_t size_counter = 0;
573-
for (int i = 0; i < 256; ++i) {
575+
for (std::int64_t i = 0; i < max_uint8_t; ++i) {
574576
if (exist[i]) {
575-
std::uint64_t sub_count = children[i]->nonzero_count;
576-
if (size_counter + sub_count > index) {
577-
configs[index][n_total_qubytes - n_qubytes_local] = i;
578-
children[i]->template collect<n_total_qubytes>(index - size_counter, configs, psi);
579-
if (--nonzero_count == 0) {
577+
std::uint64_t new_size_counter = size_counter + children[i]->nonzero_count;
578+
if (new_size_counter > index) {
579+
std::uint64_t new_index = index - size_counter;
580+
configs[index][n_total_qubytes - n_qubytes] = i;
581+
children[i]->collect<n_total_qubytes>(new_index, &configs[size_counter], &psi[size_counter]);
582+
if (--children[i]->nonzero_count == 0) {
580583
free(children[i]);
581-
}
584+
};
582585
return;
583586
}
584-
size_counter += sub_count;
587+
size_counter = new_size_counter;
585588
}
586589
}
587590
}
588591
};
589592

590593
template<>
591594
struct dictionary_tree<1> {
592-
double values[256][2];
593-
int exist[256];
594-
long long nonzero_count;
595+
double values[max_uint8_t][2];
596+
smallest_atomic_int exist[max_uint8_t];
597+
largest_atomic_int nonzero_count;
595598

596599
bool add(const std::uint8_t* begin, double real, double imag) {
597600
std::uint8_t index = *begin;
598601
values[index][0] += real;
599602
values[index][1] += imag;
600603
if (exist[index] == 0) {
601604
exist[index] = 1;
602-
nonzero_count++;
605+
++nonzero_count;
603606
return true;
607+
} else {
608+
return false;
604609
}
605-
return false;
606610
}
607611

608612
template<std::int64_t n_total_qubytes>
609613
void collect(std::uint64_t index, std::array<std::uint8_t, n_total_qubytes>* configs, std::array<double, 2>* psi) {
610614
std::uint64_t size_counter = 0;
611-
for (int i = 0; i < 256; ++i) {
615+
for (std::int64_t i = 0; i < max_uint8_t; ++i) {
612616
if (exist[i]) {
613617
if (size_counter == index) {
614618
configs[index][n_total_qubytes - 1] = i;
615619
psi[index][0] = values[i][0];
616620
psi[index][1] = values[i][1];
617621
return;
618622
}
619-
size_counter++;
623+
++size_counter;
620624
}
621625
}
622626
}
@@ -774,9 +778,8 @@ auto list_relative_interface(
774778
);
775779

776780
auto result_tree = (dictionary_tree<n_qubytes>*)malloc(sizeof(dictionary_tree<n_qubytes>));
777-
if (result_tree != nullptr) {
778-
memset(result_tree, 0, sizeof(dictionary_tree<n_qubytes>));
779-
}
781+
assert(result_tree != nullptr);
782+
memset(result_tree, 0, sizeof(dictionary_tree<n_qubytes>));
780783

781784
list_relative_kernel_interface<max_op_number, n_qubytes, particle_cut>(
782785
term_number,
@@ -792,6 +795,7 @@ auto list_relative_interface(
792795
);
793796

794797
long long result_size = result_tree->nonzero_count;
798+
795799
auto result_configs = torch::zeros({result_size, n_qubytes}, torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU));
796800
auto result_psi = torch::zeros({result_size, 2}, torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU));
797801

0 commit comments

Comments
 (0)