1+ use std:: sync:: Arc ;
2+
13use color_eyre:: SectionExt ;
24use opentelemetry:: {
35 global,
@@ -7,34 +9,58 @@ use opentelemetry::{
79use opentelemetry_http:: HeaderExtractor ;
810use opentelemetry_semantic_conventions:: { attribute, resource} ;
911use poem:: {
10- http:: HeaderValue , middleware:: Middleware , web:: {
12+ http:: HeaderValue ,
13+ middleware:: Middleware ,
14+ web:: {
1115 headers:: { self , HeaderMapExt } ,
1216 RealIp ,
13- } , Endpoint , FromRequest , IntoResponse , PathPattern , Request , Response , Result
17+ } ,
18+ Endpoint , FromRequest , IntoResponse , PathPattern , Request , Response , Result ,
1419} ;
1520
1621/// Middleware that injects the OpenTelemetry trace ID into the response headers.
1722#[ derive( Default ) ]
18- pub struct TraceId ;
23+ pub struct TraceId < T > {
24+ tracer : Arc < T > ,
25+ }
1926
20- impl < E : Endpoint > Middleware < E > for TraceId {
21- type Output = TraceIdEndpoint < E > ;
27+ impl < T > TraceId < T > {
28+ pub fn new ( tracer : Arc < T > ) -> Self {
29+ Self { tracer }
30+ }
31+ }
32+
33+ impl < T , E > Middleware < E > for TraceId < T >
34+ where
35+ E : Endpoint ,
36+ T : Tracer + Send + Sync ,
37+ T :: Span : Send + Sync + ' static ,
38+ {
39+ type Output = TraceIdEndpoint < T , E > ;
2240
2341 fn transform ( & self , ep : E ) -> Self :: Output {
24- TraceIdEndpoint { inner : ep }
42+ TraceIdEndpoint {
43+ inner : ep,
44+ tracer : self . tracer . clone ( ) ,
45+ }
2546 }
2647}
2748
2849/// The endpoint wrapper produced by the TraceId middleware.
29- pub struct TraceIdEndpoint < E > {
50+ pub struct TraceIdEndpoint < T , E > {
3051 inner : E ,
52+ tracer : Arc < T > ,
3153}
3254
33- impl < E : Endpoint > Endpoint for TraceIdEndpoint < E > {
55+ impl < T , E > Endpoint for TraceIdEndpoint < T , E >
56+ where
57+ E : Endpoint ,
58+ T : Tracer + Send + Sync ,
59+ T :: Span : Send + Sync + ' static ,
60+ {
3461 type Output = Response ;
3562
3663 async fn call ( & self , req : Request ) -> Result < Self :: Output > {
37- let tracer = global:: tracer ( "edgeserver" ) ;
3864 // // Execute the inner endpoint.
3965 // let response = self.inner.call(req).await?;
4066
@@ -56,6 +82,8 @@ impl<E: Endpoint> Endpoint for TraceIdEndpoint<E> {
5682
5783 // Ok(response)
5884
85+ let tracer = self . tracer . clone ( ) ;
86+
5987 let remote_addr = RealIp :: from_request_without_body ( & req)
6088 . await
6189 . ok ( )
@@ -96,7 +124,7 @@ impl<E: Endpoint> Endpoint for TraceIdEndpoint<E> {
96124 . span_builder ( format ! ( "{} {}" , method, req. uri( ) ) )
97125 . with_kind ( SpanKind :: Server )
98126 . with_attributes ( attributes)
99- . start_with_context ( & tracer, & parent_cx) ;
127+ . start_with_context ( & * tracer, & parent_cx) ;
100128
101129 span. add_event ( "request.started" . to_string ( ) , vec ! [ ] ) ;
102130
0 commit comments