Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion prover-service/src/external_resources/jwk_fetcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ pub trait JWKIssuerInterface {
}

/// A simple JWK issuer struct
#[derive(Clone, Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct JWKIssuer {
issuer_name: String,
issuer_jwk_url: String,
Expand Down
2 changes: 1 addition & 1 deletion prover-service/src/external_resources/prover_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ const GENERATE_WITNESS_JS_FILE_NAME: &str = "generate_witness.js";
const MAIN_WASM_FILE_NAME: &str = "main.wasm";

/// The prover service configuration
#[derive(Debug, Serialize, Deserialize, Clone)]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(default, deny_unknown_fields)]
pub struct ProverServiceConfig {
pub setup_dir: String,
Expand Down
8 changes: 4 additions & 4 deletions prover-service/src/request_handler/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ use std::{convert::Infallible, sync::Arc};

// The list of endpoints/paths offered by the Prover Service.
// Note: if you update these paths, please also update the "ALL_PATHS" array below.
const ABOUT_PATH: &str = "/about";
const CONFIG_PATH: &str = "/config";
const HEALTH_CHECK_PATH: &str = "/healthcheck";
pub const ABOUT_PATH: &str = "/about";
pub const CONFIG_PATH: &str = "/config";
pub const HEALTH_CHECK_PATH: &str = "/healthcheck";
pub const JWK_PATH: &str = "/cached/jwk";
const PROVE_PATH: &str = "/v0/prove";
pub const PROVE_PATH: &str = "/v0/prove";

// An array of all known endpoints/paths
pub const ALL_PATHS: [&str; 5] = [
Expand Down
21 changes: 14 additions & 7 deletions prover-service/src/request_handler/prover_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,13 +232,20 @@ async fn generate_groth16_proof(
// Generate the JSON proof
let full_prover = prover_service_state.full_prover();
let full_prover_locked = full_prover.lock().await;
let (proof_json, _internal_metrics) = match full_prover_locked.prove(witness_file_path) {
Ok((proof_json, metrics)) => (proof_json.to_string(), metrics),
Err(error) => {
return Err(ProverServiceError::UnexpectedError(format!(
"Failed to generate rapidsnark proof! Error: {:?}",
error
)));
let (proof_json, _internal_metrics) = match full_prover_locked.as_ref() {
Some(full_prover_locked) => match full_prover_locked.prove(witness_file_path) {
Ok((proof_json, metrics)) => (proof_json.to_string(), metrics),
Err(error) => {
return Err(ProverServiceError::UnexpectedError(format!(
"Failed to generate rapidsnark proof! Error: {:?}",
error
)));
}
},
None => {
return Err(ProverServiceError::UnexpectedError(
"The full prover was not initialized correctly!".into(),
));
}
};

Expand Down
43 changes: 39 additions & 4 deletions prover-service/src/request_handler/prover_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@ use rust_rapidsnark::FullProver;
use std::sync::Arc;
use tokio::sync::Mutex;

#[cfg(test)]
use aptos_crypto::Uniform;

/// The shared state of the prover service (used across all requests)
pub struct ProverServiceState {
prover_service_config: Arc<ProverServiceConfig>,
circuit_config: CircuitConfig,
deployment_information: DeploymentInformation,
training_wheels_key_pair: TrainingWheelsKeyPair,
full_prover: Arc<Mutex<FullProver>>,
full_prover: Arc<Mutex<Option<FullProver>>>,
jwk_cache: JWKCache,
}

Expand All @@ -39,7 +42,32 @@ impl ProverServiceState {
circuit_config: circuit_configuration,
deployment_information,
training_wheels_key_pair,
full_prover: Arc::new(Mutex::new(full_prover)),
full_prover: Arc::new(Mutex::new(Some(full_prover))),
jwk_cache,
}
}

#[cfg(test)]
/// Creates a new prover service state for testing purposes
pub fn new_for_testing(
training_wheels_key_pair: TrainingWheelsKeyPair,
prover_service_config: Arc<ProverServiceConfig>,
deployment_information: DeploymentInformation,
jwk_cache: JWKCache,
) -> Self {
// Create a circuit configuration for testing
let circuit_configuration = CircuitConfig::new();

// Don't initialize any full prover for testing
let full_prover = Arc::new(Mutex::new(None));

// Create the prover service state
ProverServiceState {
prover_service_config,
circuit_config: circuit_configuration,
deployment_information,
training_wheels_key_pair,
full_prover,
jwk_cache,
}
}
Expand All @@ -59,8 +87,8 @@ impl ProverServiceState {
self.jwk_cache.clone()
}

/// Returns an Arc reference to the full prover instance
pub fn full_prover(&self) -> Arc<Mutex<FullProver>> {
/// Returns an Arc reference to the full prover instance (if one exists)
pub fn full_prover(&self) -> Arc<Mutex<Option<FullProver>>> {
self.full_prover.clone()
}

Expand Down Expand Up @@ -92,6 +120,13 @@ impl TrainingWheelsKeyPair {
}
}

#[cfg(test)]
/// Creates a new training wheels key pair for testing purposes
pub fn new_for_testing() -> Self {
let signing_key = Ed25519PrivateKey::generate_for_testing();
Self::from_sk(signing_key)
}

/// Returns a reference to the signing key
pub fn signing_key(&self) -> &Ed25519PrivateKey {
&self.signing_key
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use crate::tests::common::gen_test_jwk_keypair_with_kid_override;
use crate::tests::common::types::{ProofTestCase, TestJWTPayload};
use aptos_keyless_common::input_processing::encoding::DecodedJWT;

// TODO: avoid the external test dependencies!

// This test uses a demo auth0 tenant owned by oliver.he@aptoslabs.com
#[tokio::test]
async fn test_federated_jwk_fetch() {
Expand Down
234 changes: 234 additions & 0 deletions prover-service/src/tests/jwk_fetcher.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
// Copyright (c) Aptos Foundation

use crate::{
error::ProverServiceError,
external_resources::{
jwk_fetcher,
jwk_fetcher::{JWKCache, JWKIssuerInterface, KeyID},
},
};
use aptos_infallible::Mutex;
use aptos_time_service::TimeService;
use aptos_types::{jwks, jwks::rsa::RSA_JWK};
use std::{collections::HashMap, sync::Arc, time::Duration};
use tokio::{task::JoinHandle, time::timeout};

// Test issuer names for Apple and Google
const ISSUER_APPLE: &str = "issuer_apple";
const ISSUER_GOOGLE: &str = "issuer_google";

// JWK refresh interval used in tests
const JWK_REFRESH_INTERVAL_SECS: u64 = 10; // JWK refresh interval (secs)

// Maximum wait time (secs) for each test to complete
const MAX_TEST_WAIT_SECS: u64 = 10;

/// A mock JWK issuer (for testing JWK fetching and caching)
struct MockJWKIssuer {
issuer_name: String,
jwk_key_set: HashMap<KeyID, Arc<RSA_JWK>>,
num_fetch_failures: Arc<Mutex<u64>>, // The number of fetch failures to simulate
}

impl MockJWKIssuer {
pub fn new(
issuer_name: String,
jwk_key_set: HashMap<KeyID, Arc<RSA_JWK>>,
num_fetch_failures: u64,
) -> Self {
Self {
issuer_name,
jwk_key_set,
num_fetch_failures: Arc::new(Mutex::new(num_fetch_failures)),
}
}
}

#[async_trait::async_trait]
impl JWKIssuerInterface for MockJWKIssuer {
fn issuer_name(&self) -> String {
self.issuer_name.clone()
}

fn issuer_jwk_url(&self) -> String {
"".into() // The URL is not used in the mock implementation
}

async fn fetch_jwks(&self) -> anyhow::Result<HashMap<KeyID, Arc<RSA_JWK>>> {
// If there are failures to simulate, decrement the counter and
// return an error, otherwise, return the JWK key set.
let mut num_fetch_failures = self.num_fetch_failures.lock();
if *num_fetch_failures > 0 {
*num_fetch_failures -= 1;
Err(anyhow::anyhow!(
"Simulated fetch failure for issuer {}! Failures remaining: {}",
self.issuer_name,
*num_fetch_failures
))
} else {
Ok(self.jwk_key_set.clone())
}
}
}

#[tokio::test(flavor = "multi_thread")]
async fn jwk_cache_updates_after_interval() {
// Create test JWK key sets
let apple_jwk_key_set = create_test_jwk_key_set(ISSUER_APPLE.into());
let google_jwk_key_set = create_test_jwk_key_set(ISSUER_GOOGLE.into());

// Create mock JWK issuers for Apple and Google
let apple_jwk_issuer = Arc::new(MockJWKIssuer::new(
ISSUER_APPLE.into(),
apple_jwk_key_set.clone(),
0,
));
let google_jwk_issuer = Arc::new(MockJWKIssuer::new(
ISSUER_GOOGLE.into(),
google_jwk_key_set.clone(),
0,
));

// Verify that the JWK cache is updated correctly
verify_jwk_cache_updates(
apple_jwk_key_set,
google_jwk_key_set,
apple_jwk_issuer,
google_jwk_issuer,
)
.await;
}

#[tokio::test(flavor = "multi_thread")]
async fn jwk_cache_updates_after_failures() {
// Create test JWK key sets
let apple_jwk_key_set = create_test_jwk_key_set(ISSUER_APPLE.into());
let google_jwk_key_set = create_test_jwk_key_set(ISSUER_GOOGLE.into());

// Create mock JWK issuers that fail a few times before succeeding
let num_fetch_failures = 3;
let apple_jwk_issuer = Arc::new(MockJWKIssuer::new(
ISSUER_APPLE.into(),
apple_jwk_key_set.clone(),
num_fetch_failures,
));
let google_jwk_issuer = Arc::new(MockJWKIssuer::new(
ISSUER_GOOGLE.into(),
google_jwk_key_set.clone(),
num_fetch_failures,
));

// Verify that the JWK cache is updated correctly
verify_jwk_cache_updates(
apple_jwk_key_set,
google_jwk_key_set,
apple_jwk_issuer,
google_jwk_issuer,
)
.await;
}

/// Advances the mock time service by the given number of seconds
async fn advance_time_secs(time_service: TimeService, seconds: u64) {
let mock_time_service = time_service.into_mock();
mock_time_service
.advance_async(Duration::from_secs(seconds))
.await;
}

/// Creates a JWK cache and starts the refresh loops for known issuers
fn create_and_start_jwk_cache(
apple_jwk_issuer: Arc<dyn JWKIssuerInterface + Send + Sync>,
google_jwk_issuer: Arc<dyn JWKIssuerInterface + Send + Sync>,
time_service: TimeService,
) -> JWKCache {
// Create the JWK cache
let jwk_cache = Arc::new(Mutex::new(HashMap::new()));

// Start the JWK refresh loops for Apple and Google
let jwk_refresh_rate = Duration::from_secs(JWK_REFRESH_INTERVAL_SECS);
jwk_fetcher::start_jwk_refresh_loop(
apple_jwk_issuer,
jwk_cache.clone(),
jwk_refresh_rate,
time_service.clone(),
);
jwk_fetcher::start_jwk_refresh_loop(
google_jwk_issuer,
jwk_cache.clone(),
jwk_refresh_rate,
time_service,
);

jwk_cache
}

/// Creates a test JWK key set with a single RSA key
fn create_test_jwk_key_set(issuer_name: String) -> HashMap<KeyID, Arc<RSA_JWK>> {
// Create several keys with different IDs
let mut key_set = HashMap::new();
for i in 0..5 {
let key_id = format!("{}_key_{}", issuer_name, i);
key_set.insert(key_id, Arc::new(jwks::insecure_test_rsa_jwk()));
}

key_set
}

/// Verifies that the cached resources are eventually updated
/// correctly after the resource fetcher is started.
async fn verify_jwk_cache_updates(
apple_jwk_key_set: HashMap<KeyID, Arc<RSA_JWK>>,
google_jwk_key_set: HashMap<KeyID, Arc<RSA_JWK>>,
apple_jwk_issuer: Arc<dyn JWKIssuerInterface + Send + Sync>,
google_jwk_issuer: Arc<dyn JWKIssuerInterface + Send + Sync>,
) {
// Create the JWK cache and start the refresh loops
let time_service = TimeService::mock();
let jwk_cache =
create_and_start_jwk_cache(apple_jwk_issuer, google_jwk_issuer, time_service.clone());

// Verify that initially the cache is empty
assert!(jwk_cache.lock().get(ISSUER_APPLE).is_none());
assert!(jwk_cache.lock().get(ISSUER_GOOGLE).is_none());

// Spawn a task that advances time and verifies the cache is eventually updated
let cache_verification_task: JoinHandle<Result<(), ProverServiceError>> = tokio::spawn({
let time_service = time_service.clone();
let jwk_cache = jwk_cache.clone();

async move {
loop {
// Advance time by the refresh interval
advance_time_secs(time_service.clone(), JWK_REFRESH_INTERVAL_SECS + 1).await;

// Grab the cached key sets
let cached_apple_jwk_set = jwk_cache.lock().get(ISSUER_APPLE).cloned();
let cached_google_jwk_set = jwk_cache.lock().get(ISSUER_GOOGLE).cloned();

// Check if the cache has been updated correctly
match (cached_apple_jwk_set, cached_google_jwk_set) {
(Some(cached_apple_jwk_set), Some(cached_google_jwk_set)) => {
assert_eq!(cached_apple_jwk_set, apple_jwk_key_set);
assert_eq!(cached_google_jwk_set, google_jwk_key_set);
return Ok(());
}
_ => {
// Yield to allow other tasks to run
tokio::task::yield_now().await;
}
}
}
}
});

// Verify that the JWK cache is eventually updated
if let Err(error) = timeout(
Duration::from_secs(MAX_TEST_WAIT_SECS),
cache_verification_task,
)
.await
{
panic!("Failed waiting for JWK cache to be updated: {}", error);
}
}
4 changes: 3 additions & 1 deletion prover-service/src/tests/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
// Copyright (c) Aptos Foundation

pub mod common;
pub mod jwk_fetching;
pub mod federated_jwk;
pub mod jwk_fetcher;
pub mod playground;
pub mod request_handler;
pub mod smoke;
pub mod training_wheels;
Loading
Loading