Skip to content

Commit f67e4b5

Browse files
committed
[Prover Service] Clean up to test types.rs
1 parent 37ea2c6 commit f67e4b5

File tree

4 files changed

+77
-107
lines changed

4 files changed

+77
-107
lines changed

prover-service/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,6 @@ tokio = { workspace = true }
5353

5454
[dev-dependencies]
5555
aptos-time-service = { workspace = true, features = ["testing"] }
56+
57+
[package.metadata.cargo-machete]
58+
ignored = ["hex"]

prover-service/src/tests/prover_handler.rs

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use crate::request_handler::prover_state::{ProverServiceState, TrainingWheelsKey
66
use crate::request_handler::types::ProverServiceResponse;
77
use crate::request_handler::{handler, prover_handler, training_wheels, types};
88
use crate::tests::types::TestJWKKeyPair;
9-
use crate::tests::types::{ProofTestCase, TestJWTPayload, WithNonce};
9+
use crate::tests::types::{ProofTestCase, TestJWTPayload};
1010
use crate::tests::utils;
1111
use anyhow::anyhow;
1212
use aptos_crypto::ed25519::{Ed25519PrivateKey, Ed25519PublicKey};
@@ -18,7 +18,6 @@ use hyper::{body, Body};
1818
use rand::prelude::ThreadRng;
1919
use rand::thread_rng;
2020
use rust_rapidsnark::FullProver;
21-
use serde::Serialize;
2221
use serial_test::serial;
2322
use std::collections::HashMap;
2423
use std::sync::Arc;
@@ -177,10 +176,7 @@ async fn request_jwt_exp_field_does_not_matter() {
177176
exp: 234342342428348284,
178177
..TestJWTPayload::default()
179178
};
180-
let testcase = ProofTestCase {
181-
..ProofTestCase::default_with_payload(jwt_payload)
182-
}
183-
.compute_nonce();
179+
let testcase = ProofTestCase::default_with_payload(jwt_payload).compute_nonce();
184180

185181
// Handle the prove request, and verify the proof
186182
convert_prove_and_verify(&testcase).await.unwrap();
@@ -206,8 +202,9 @@ async fn request_with_incorrect_nonce() {
206202

207203
#[tokio::test]
208204
#[serial]
205+
#[ignore] // Currently ignored because it takes a long time to run
209206
async fn prove_request_all_sub_lengths() {
210-
for i in [0, 64] {
207+
for i in 0..65 {
211208
// Create a JWT payload with varying lengths of the sub field
212209
let jwt_payload = TestJWTPayload {
213210
sub: Some("a".repeat(i)),
@@ -235,9 +232,7 @@ fn dummy_circuit_load_test() {
235232

236233
/// Helper function that converts a test case to a prover request,
237234
/// sends it to the prover handler, and verifies the returned proof.
238-
async fn convert_prove_and_verify(
239-
testcase: &ProofTestCase<impl Serialize + WithNonce + Clone>,
240-
) -> Result<(), anyhow::Error> {
235+
async fn convert_prove_and_verify(testcase: &ProofTestCase) -> Result<(), anyhow::Error> {
241236
// Start the aptos logger (so test failures print logs)
242237
aptos_logger::Logger::init_for_testing();
243238

@@ -248,7 +243,7 @@ async fn convert_prove_and_verify(
248243

249244
// Create the JWK cache with the test JWK
250245
let test_jwk: HashMap<KeyID, Arc<RSA_JWK>> =
251-
HashMap::from_iter([("test-rsa".to_owned(), Arc::new(jwk_keypair.into_rsa_jwk()))]);
246+
HashMap::from_iter([("test-rsa".to_owned(), Arc::new(jwk_keypair.get_rsa_jwk()))]);
252247
let jwk_cache: HashMap<Issuer, HashMap<KeyID, Arc<RSA_JWK>>> =
253248
HashMap::from_iter([("test.oidc.provider".into(), test_jwk)]);
254249

prover-service/src/tests/training_wheels.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ fn test_validate_jwt_invalid_signature() {
3232
// Verify the JWT signature using a different keypair to simulate an invalid signature
3333
let another_jwk_keypair = utils::generate_test_jwk_keypair();
3434
let result = training_wheels::validate_jwt_signature(
35-
&another_jwk_keypair.into_rsa_jwk(),
35+
&another_jwk_keypair.get_rsa_jwk(),
3636
&prover_request_input.jwt_b64,
3737
);
3838

@@ -64,7 +64,7 @@ fn test_jwt_signature_validation(jwt_payload: TestJWTPayload, expect_success: bo
6464

6565
// Verify the JWT signature
6666
let result = training_wheels::validate_jwt_signature(
67-
&jwk_keypair.into_rsa_jwk(),
67+
&jwk_keypair.get_rsa_jwk(),
6868
&prover_request_input.jwt_b64,
6969
);
7070
if expect_success {

prover-service/src/tests/types.rs

Lines changed: 66 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@ use crate::request_handler::training_wheels;
55
use crate::request_handler::types::{EphemeralPublicKeyBlinder, RequestInput};
66
use crate::tests::utils::{RsaPrivateKey, RsaPublicKey};
77
use crate::utils;
8-
use aptos_crypto::ed25519::{Ed25519PrivateKey, Ed25519PublicKey};
8+
use aptos_crypto::ed25519::Ed25519PrivateKey;
99
use aptos_crypto::encoding_type::EncodingType;
10+
use aptos_crypto::PrivateKey;
1011
use aptos_keyless_common::input_processing::encoding::FromFr;
1112
use aptos_logger::info;
1213
use aptos_types::{
1314
jwks::rsa::RSA_JWK, keyless::Pepper, transaction::authenticator::EphemeralPublicKey,
1415
};
15-
use ark_ff::{BigInteger, PrimeField};
1616
use jsonwebtoken::{Algorithm, Header};
1717
use once_cell::sync::Lazy;
1818
use serde::{Deserialize, Serialize};
@@ -24,16 +24,34 @@ use std::time::{SystemTime, UNIX_EPOCH};
2424
// The name of the local testing config file
2525
const LOCAL_TESTING_CONFIG_FILE_NAME: &str = "config_local_testing.yml";
2626

27+
// Ensures that the local testing setup has been procured
28+
static LOCAL_SETUP_PROCURED: Lazy<bool> = Lazy::new(|| {
29+
// Determine the repository root directory
30+
let mut repo_root = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
31+
let repo_root_found = repo_root.pop();
32+
33+
// Run the setup script to procure the local testing setup
34+
if repo_root_found {
35+
Command::new("bash")
36+
.arg("scripts/task.sh")
37+
.arg("setup")
38+
.arg("procure-testing-setup")
39+
.current_dir(repo_root)
40+
.status()
41+
.is_ok()
42+
} else {
43+
false
44+
}
45+
});
46+
47+
/// JWT payload struct for testing
2748
#[derive(Serialize, Deserialize, Clone)]
2849
pub struct TestJWTPayload {
2950
pub azp: String,
3051
pub aud: String,
31-
#[serde(skip_serializing_if = "Option::is_none")]
3252
pub sub: Option<String>,
33-
#[serde(skip_serializing_if = "Option::is_none")]
3453
pub email: Option<String>,
3554
pub hd: String,
36-
#[serde(skip_serializing_if = "Option::is_none")]
3755
pub email_verified: Option<bool>,
3856
pub at_hash: String,
3957
pub name: String,
@@ -47,12 +65,9 @@ pub struct TestJWTPayload {
4765
pub nonce: String,
4866
}
4967

50-
pub trait WithNonce {
51-
fn with_nonce(&self, nonce: &str) -> Self;
52-
}
53-
54-
impl WithNonce for TestJWTPayload {
55-
fn with_nonce(&self, nonce: &str) -> Self {
68+
impl TestJWTPayload {
69+
/// Creates a new TestJWTPayload with the given nonce
70+
fn new_with_nonce(&self, nonce: &str) -> Self {
5671
Self {
5772
nonce: String::from(nonce),
5873
..self.clone()
@@ -88,16 +103,15 @@ impl Default for TestJWTPayload {
88103
}
89104
}
90105

91-
// JWK keypair trait/struct
92-
106+
/// Trait for test JWK key pairs
93107
pub trait TestJWKKeyPair {
94108
fn pubkey_mod_b64(&self) -> String;
95109
fn kid(&self) -> &str;
96110
fn sign(&self, payload: &impl Serialize) -> String;
97-
#[allow(clippy::all)]
98-
fn into_rsa_jwk(&self) -> RSA_JWK;
111+
fn get_rsa_jwk(&self) -> RSA_JWK;
99112
}
100113

114+
/// Default implementation of TestJWKKeyPair using RSA keys
101115
pub struct DefaultTestJWKKeyPair {
102116
kid: String,
103117
private_key: RsaPrivateKey,
@@ -128,30 +142,34 @@ impl TestJWKKeyPair for DefaultTestJWKKeyPair {
128142
&self.kid
129143
}
130144

131-
#[allow(clippy::all)]
145+
#[allow(clippy::field_reassign_with_default)]
132146
fn sign(&self, payload: &impl Serialize) -> String {
147+
// Create the JWT header
133148
let mut header = Header::default();
134149
header.alg = Algorithm::RS256;
135150
header.kid = Some(self.kid.clone());
136151

152+
// Create the JWT
137153
let jwt =
138154
jsonwebtoken::encode(&header, &payload, &self.private_key.as_encoding_key()).unwrap();
139155

156+
// Verify the signature before returning (to ensure correctness)
140157
let jwk = RSA_JWK::new_256_aqab(self.kid.as_str(), &self.pubkey_mod_b64());
141158
assert!(jwk.verify_signature_without_exp_check(&jwt).is_ok());
142159

143160
jwt
144161
}
145162

146-
fn into_rsa_jwk(&self) -> RSA_JWK {
163+
fn get_rsa_jwk(&self) -> RSA_JWK {
147164
RSA_JWK::new_256_aqab(&self.kid, &self.pubkey_mod_b64())
148165
}
149166
}
150167

168+
/// Struct representing a proof test case
151169
#[derive(Clone)]
152-
pub struct ProofTestCase<T: Serialize + WithNonce + Clone> {
170+
pub struct ProofTestCase {
153171
pub prover_service_config: ProverServiceConfig,
154-
pub jwt_payload: T,
172+
pub jwt_payload: TestJWTPayload,
155173
pub epk: EphemeralPublicKey,
156174
pub epk_blinder_fr: ark_bn254::Fr,
157175
pub pepper: Pepper,
@@ -163,65 +181,38 @@ pub struct ProofTestCase<T: Serialize + WithNonce + Clone> {
163181
pub skip_aud_checks: bool,
164182
}
165183

166-
impl<T: Serialize + WithNonce + Clone> ProofTestCase<T> {
167-
#[allow(clippy::all)]
168-
#[allow(dead_code)]
169-
pub fn new_with_test_epk_and_blinder(
170-
jwt_payload: T,
171-
pepper: Pepper,
172-
exp_date: u64,
173-
exp_horizon: u64,
174-
extra_field: Option<String>,
175-
uid_key: String,
176-
idc_aud: Option<String>,
177-
) -> Self {
184+
impl ProofTestCase {
185+
/// Creates a default test case with the given JWT payload
186+
pub fn default_with_payload(jwt_payload: TestJWTPayload) -> Self {
187+
// Ensure that the local setup has been procured
178188
assert!(*LOCAL_SETUP_PROCURED);
179-
let prover_service_config = get_config();
180-
let circuit_metadata = prover_service_config.load_circuit_params();
181-
let epk = gen_test_ephemeral_pk();
182-
let epk_blinder = gen_test_ephemeral_pk_blinder();
183-
let nonce =
184-
training_wheels::compute_nonce(exp_date, &epk, epk_blinder, &circuit_metadata).unwrap();
185-
let payload_with_nonce = jwt_payload.with_nonce(&nonce.to_string());
186-
187-
Self {
188-
prover_service_config,
189-
jwt_payload: payload_with_nonce as T,
190-
epk,
191-
epk_blinder_fr: epk_blinder,
192-
pepper,
193-
epk_expiry_time_secs: exp_date,
194-
epk_expiry_horizon_secs: exp_horizon,
195-
extra_field,
196-
uid_key,
197-
idc_aud,
198-
skip_aud_checks: false,
199-
}
200-
}
201189

202-
pub fn default_with_payload(jwt_payload: T) -> Self {
203-
assert!(*LOCAL_SETUP_PROCURED);
204-
let epk = gen_test_ephemeral_pk();
205-
let epk_blinder = gen_test_ephemeral_pk_blinder();
206-
let pepper = get_test_pepper();
190+
// Generate test ephemeral public key and blinder
191+
let epk = generate_test_ephemeral_pk();
192+
let epk_blinder = ark_bn254::Fr::from_str("42").unwrap();
193+
let pepper = Pepper::from_number(42);
207194

208195
Self {
209-
prover_service_config: get_config(),
196+
prover_service_config: get_prover_service_config(),
210197
jwt_payload,
211198
epk,
212199
epk_blinder_fr: epk_blinder,
213200
pepper,
214201
epk_expiry_time_secs: 0,
215202
epk_expiry_horizon_secs: 100,
216-
extra_field: Some(String::from("name")),
217-
uid_key: String::from("email"),
203+
extra_field: Some("name".into()),
204+
uid_key: "email".into(),
218205
idc_aud: None,
219206
skip_aud_checks: false,
220207
}
221208
}
222209

210+
/// Computes the nonce and returns a new test case with the updated JWT payload
223211
pub fn compute_nonce(self) -> Self {
212+
// Ensure that the local setup has been procured
224213
assert!(*LOCAL_SETUP_PROCURED);
214+
215+
// Compute the nonce
225216
let circuit_metadata = self.prover_service_config.load_circuit_params();
226217
let nonce = training_wheels::compute_nonce(
227218
self.epk_expiry_time_secs,
@@ -230,17 +221,18 @@ impl<T: Serialize + WithNonce + Clone> ProofTestCase<T> {
230221
&circuit_metadata,
231222
)
232223
.unwrap();
233-
let payload_with_nonce = self.jwt_payload.with_nonce(&nonce.to_string());
224+
225+
// Create a new payload with the nonce
226+
let jwt_payload = self.jwt_payload.new_with_nonce(&nonce.to_string());
234227

235228
Self {
236-
jwt_payload: payload_with_nonce,
229+
jwt_payload,
237230
..self
238231
}
239232
}
240233

234+
/// Converts the test case to a prover request input
241235
pub fn convert_to_prover_request(&self, jwk_keypair: &impl TestJWKKeyPair) -> RequestInput {
242-
let _epk_blinder_hex_string = hex::encode(self.epk_blinder_fr.into_bigint().to_bytes_le());
243-
244236
RequestInput {
245237
jwt_b64: jwk_keypair.sign(&self.jwt_payload),
246238
epk: self.epk.clone(),
@@ -257,45 +249,25 @@ impl<T: Serialize + WithNonce + Clone> ProofTestCase<T> {
257249
}
258250
}
259251

260-
static LOCAL_SETUP_PROCURED: Lazy<bool> = Lazy::new(|| {
261-
let mut repo_root = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
262-
let repo_root_found = repo_root.pop();
263-
if repo_root_found {
264-
Command::new("bash")
265-
.arg("scripts/task.sh")
266-
.arg("setup")
267-
.arg("procure-testing-setup")
268-
.current_dir(repo_root)
269-
.status()
270-
.is_ok()
271-
} else {
272-
false
273-
}
274-
});
275-
276-
pub fn gen_test_ephemeral_pk() -> EphemeralPublicKey {
277-
let ephemeral_private_key: Ed25519PrivateKey = EncodingType::Hex
252+
/// Generates a test ephemeral public key
253+
fn generate_test_ephemeral_pk() -> EphemeralPublicKey {
254+
// Generate a test Ed25519 ephemeral keypair
255+
let ed25519_private_key: Ed25519PrivateKey = EncodingType::Hex
278256
.decode_key(
279257
"zkid test ephemeral private key",
280258
"0x76b8e0ada0f13d90405d6ae55386bd28bdd219b8a08ded1aa836efcc8b770dc7"
281259
.as_bytes()
282260
.to_vec(),
283261
)
284262
.unwrap();
285-
let ephemeral_public_key_unwrapped: Ed25519PublicKey =
286-
Ed25519PublicKey::from(&ephemeral_private_key);
287-
EphemeralPublicKey::ed25519(ephemeral_public_key_unwrapped)
288-
}
289-
290-
pub fn get_test_pepper() -> Pepper {
291-
Pepper::from_number(42)
292-
}
263+
let ed25519_public_key = ed25519_private_key.public_key();
293264

294-
pub fn gen_test_ephemeral_pk_blinder() -> ark_bn254::Fr {
295-
ark_bn254::Fr::from_str("42").unwrap()
265+
// Return the ephemeral public key
266+
EphemeralPublicKey::ed25519(ed25519_public_key)
296267
}
297268

298-
pub fn get_config() -> ProverServiceConfig {
269+
/// Loads and returns the prover service config for local testing
270+
fn get_prover_service_config() -> ProverServiceConfig {
299271
// Read the config file contents
300272
let config_file_contents = utils::read_string_from_file_path(LOCAL_TESTING_CONFIG_FILE_NAME);
301273

0 commit comments

Comments
 (0)