@@ -222,7 +222,6 @@ auto apply_within_interface(
222222 TORCH_CHECK (coef.size (1 ) == 2 , " coef must contain 2 elements for each term." );
223223
224224 auto result_sort_index = torch::arange (result_batch_size, torch::TensorOptions ().dtype (torch::kInt64 ).device (device, device_id));
225-
226225 std::sort (
227226 reinterpret_cast <std::int64_t *>(result_sort_index.data_ptr ()),
228227 reinterpret_cast <std::int64_t *>(result_sort_index.data_ptr ()) + result_batch_size,
@@ -385,7 +384,6 @@ void find_relative_kernel(
385384 std::int8_t sign = parity ? -1 : +1 ;
386385 double real = sign * (coef[term_index][0 ] * psi[batch_index][0 ] - coef[term_index][1 ] * psi[batch_index][1 ]);
387386 double imag = sign * (coef[term_index][0 ] * psi[batch_index][1 ] + coef[term_index][1 ] * psi[batch_index][0 ]);
388- // Currently, the weight is calculated as the probability of the state, but it can be changed to other values in the future.
389387 double weight = real * real + imag * imag;
390388 std::array<std::uint8_t , n_qubytes + sizeof (double ) / sizeof (std::uint8_t )> value;
391389 for (std::int64_t i = 0 ; i < sizeof (double ) / sizeof (uint8_t ); ++i) {
@@ -542,81 +540,87 @@ auto find_relative_interface(
542540 return unique_nonzero_result_config;
543541}
544542
545- template <std::int64_t n_qubytes_local>
543+ constexpr std::int64_t max_uint8_t = 256 ;
544+ using largest_atomic_int = unsigned int ; // The largest int type that can be atomicAdd/atomicSub
545+ using smallest_atomic_int = unsigned short int ; // The smallest int type that can be atomicCAS
546+
547+ template <std::int64_t n_qubytes>
546548struct dictionary_tree {
547- using child_t = dictionary_tree<n_qubytes_local - 1 >;
548- child_t * children[256 ];
549- int exist[256 ];
550- long long nonzero_count;
549+ using child_t = dictionary_tree<n_qubytes - 1 >;
550+ child_t * children[max_uint8_t ];
551+ smallest_atomic_int exist[max_uint8_t ];
552+ largest_atomic_int nonzero_count;
551553
552554 bool add (const std::uint8_t * begin, double real, double imag) {
553555 std::uint8_t index = *begin;
554556 if (children[index] == nullptr ) {
555557 auto new_child = (child_t *)malloc (sizeof (child_t ));
556- if (new_child != nullptr ) {
557- memset (new_child, 0 , sizeof (child_t ));
558- children[index] = new_child;
559- exist[index] = 1 ;
560- }
558+ assert (new_child != nullptr );
559+ memset (new_child, 0 , sizeof (child_t ));
560+ children[index] = new_child;
561+ exist[index] = 1 ;
561562 }
562563
563564 if (children[index]->add (begin + 1 , real, imag)) {
564565 nonzero_count++;
565566 return true ;
567+ } else {
568+ return false ;
566569 }
567- return false ;
568570 }
569571
570572 template <std::int64_t n_total_qubytes>
571573 void collect (std::uint64_t index, std::array<std::uint8_t , n_total_qubytes>* configs, std::array<double , 2 >* psi) {
572574 std::uint64_t size_counter = 0 ;
573- for (int i = 0 ; i < 256 ; ++i) {
575+ for (std:: int64_t i = 0 ; i < max_uint8_t ; ++i) {
574576 if (exist[i]) {
575- std::uint64_t sub_count = children[i]->nonzero_count ;
576- if (size_counter + sub_count > index) {
577- configs[index][n_total_qubytes - n_qubytes_local] = i;
578- children[i]->template collect <n_total_qubytes>(index - size_counter, configs, psi);
579- if (--nonzero_count == 0 ) {
577+ std::uint64_t new_size_counter = size_counter + children[i]->nonzero_count ;
578+ if (new_size_counter > index) {
579+ std::uint64_t new_index = index - size_counter;
580+ configs[index][n_total_qubytes - n_qubytes] = i;
581+ children[i]->collect <n_total_qubytes>(new_index, &configs[size_counter], &psi[size_counter]);
582+ if (--children[i]->nonzero_count == 0 ) {
580583 free (children[i]);
581- }
584+ };
582585 return ;
583586 }
584- size_counter += sub_count ;
587+ size_counter = new_size_counter ;
585588 }
586589 }
587590 }
588591};
589592
590593template <>
591594struct dictionary_tree <1 > {
592- double values[256 ][2 ];
593- int exist[256 ];
594- long long nonzero_count;
595+ double values[max_uint8_t ][2 ];
596+ smallest_atomic_int exist[max_uint8_t ];
597+ largest_atomic_int nonzero_count;
595598
596599 bool add (const std::uint8_t * begin, double real, double imag) {
597600 std::uint8_t index = *begin;
598601 values[index][0 ] += real;
599602 values[index][1 ] += imag;
600603 if (exist[index] == 0 ) {
601604 exist[index] = 1 ;
602- nonzero_count++ ;
605+ ++nonzero_count ;
603606 return true ;
607+ } else {
608+ return false ;
604609 }
605- return false ;
606610 }
607611
608612 template <std::int64_t n_total_qubytes>
609613 void collect (std::uint64_t index, std::array<std::uint8_t , n_total_qubytes>* configs, std::array<double , 2 >* psi) {
610614 std::uint64_t size_counter = 0 ;
611- for (int i = 0 ; i < 256 ; ++i) {
615+ for (std:: int64_t i = 0 ; i < max_uint8_t ; ++i) {
612616 if (exist[i]) {
613617 if (size_counter == index) {
614618 configs[index][n_total_qubytes - 1 ] = i;
615619 psi[index][0 ] = values[i][0 ];
616620 psi[index][1 ] = values[i][1 ];
617621 return ;
618622 }
619- size_counter++ ;
623+ ++size_counter ;
620624 }
621625 }
622626 }
@@ -774,9 +778,8 @@ auto list_relative_interface(
774778 );
775779
776780 auto result_tree = (dictionary_tree<n_qubytes>*)malloc (sizeof (dictionary_tree<n_qubytes>));
777- if (result_tree != nullptr ) {
778- memset (result_tree, 0 , sizeof (dictionary_tree<n_qubytes>));
779- }
781+ assert (result_tree != nullptr );
782+ memset (result_tree, 0 , sizeof (dictionary_tree<n_qubytes>));
780783
781784 list_relative_kernel_interface<max_op_number, n_qubytes, particle_cut>(
782785 term_number,
@@ -792,6 +795,7 @@ auto list_relative_interface(
792795 );
793796
794797 long long result_size = result_tree->nonzero_count ;
798+
795799 auto result_configs = torch::zeros ({result_size, n_qubytes}, torch::TensorOptions ().dtype (torch::kUInt8 ).device (torch::kCPU ));
796800 auto result_psi = torch::zeros ({result_size, 2 }, torch::TensorOptions ().dtype (torch::kFloat64 ).device (torch::kCPU ));
797801
0 commit comments