@@ -17,7 +17,7 @@
 from inference_perf.config import APIConfig, APIType, CustomTokenizerConfig
 from inference_perf.apis import InferenceAPIData, InferenceInfo, RequestLifecycleMetric, ErrorResponseInfo
 from inference_perf.utils import CustomTokenizer
-from .base import ModelServerClient, PrometheusMetricMetadata
+from .base import ModelServerClient, ModelServerClientSession, PrometheusMetricMetadata
 from typing import List, Optional
 import aiohttp
 import asyncio
@@ -30,6 +30,9 @@
 
 
 class openAIModelServerClient(ModelServerClient):
+    _session: "openAIModelServerClientSession | None" = None
+    _session_lock = asyncio.Lock()
+
     def __init__(
         self,
         metrics_collector: RequestDataCollector,
@@ -70,82 +73,27 @@ def __init__(
         tokenizer_config = CustomTokenizerConfig(pretrained_model_name_or_path=self.model_name)
         self.tokenizer = CustomTokenizer(tokenizer_config)
 
-    async def process_request(self, data: InferenceAPIData, stage_id: int, scheduled_time: float) -> None:
-        payload = await data.to_payload(
-            model_name=self.model_name,
-            max_tokens=self.max_completion_tokens,
-            ignore_eos=self.ignore_eos,
-            streaming=self.api_config.streaming,
-        )
-        headers = {"Content-Type": "application/json"}
-
-        if self.api_key:
-            headers["Authorization"] = f"Bearer {self.api_key}"
-
-        if self.api_config.headers:
-            headers.update(self.api_config.headers)
+    def new_session(self) -> "ModelServerClientSession":
+        return openAIModelServerClientSession(self)
 
-        request_data = json.dumps(payload)
-
-        timeout = aiohttp.ClientTimeout(total=self.timeout) if self.timeout else aiohttp.helpers.sentinel
-
-        async with aiohttp.ClientSession(
-            connector=aiohttp.TCPConnector(limit=self.max_tcp_connections), timeout=timeout
-        ) as session:
-            start = time.perf_counter()
-            try:
-                async with session.post(self.uri + data.get_route(), headers=headers, data=request_data) as response:
-                    response_info = await data.process_response(
-                        response=response, config=self.api_config, tokenizer=self.tokenizer
-                    )
-                    response_content = await response.text()
-
-                    end_time = time.perf_counter()
-                    error = None
-                    if response.status != 200:
-                        error = ErrorResponseInfo(
-                            error_msg=response_content,
-                            error_type=f"{response.status} {response.reason}",
-                        )
-
-                    self.metrics_collector.record_metric(
-                        RequestLifecycleMetric(
-                            stage_id=stage_id,
-                            request_data=request_data,
-                            response_data=response_content,
-                            info=response_info,
-                            error=error,
-                            start_time=start,
-                            end_time=end_time,
-                            scheduled_time=scheduled_time,
-                        )
-                    )
-            except Exception as e:
-                if isinstance(e, asyncio.exceptions.TimeoutError):
-                    logger.error("request timed out:", exc_info=True)
-                else:
-                    logger.error("error occured during request processing:", exc_info=True)
-                failure_info = await data.process_failure(
-                    response=response if "response" in locals() else None,
-                    config=self.api_config,
-                    tokenizer=self.tokenizer,
-                    exception=e,
-                )
-                self.metrics_collector.record_metric(
-                    RequestLifecycleMetric(
-                        stage_id=stage_id,
-                        request_data=request_data,
-                        response_data=response_content if "response_content" in locals() else "",
-                        info=failure_info if failure_info else InferenceInfo(),
-                        error=ErrorResponseInfo(
-                            error_msg=str(e),
-                            error_type=type(e).__name__,
-                        ),
-                        start_time=start,
-                        end_time=time.perf_counter(),
-                        scheduled_time=scheduled_time,
-                    )
-                )
+    async def process_request(self, data: InferenceAPIData, stage_id: int, scheduled_time: float) -> None:
+        """
+        Lazily create the shared internal client session on first use, then
+        use it to process the request.
+        """
+        session: openAIModelServerClientSession
+        # Ensure the shared session is created only once, even under concurrency.
+        async with self._session_lock:
+            if self._session is None:
+                self._session = openAIModelServerClientSession(self)
+            session = self._session
+        await session.process_request(data, stage_id, scheduled_time)
+
+    async def close(self) -> None:
+        """Close the internal session created by process_request, if any."""
+        if self._session is not None:
+            await self._session.close()
+            self._session = None
 
     def get_supported_apis(self) -> List[APIType]:
         return []
@@ -166,3 +114,87 @@ def get_supported_models(self) -> List[str]:
         except Exception as e:
             logger.error(f"Got exception retrieving supported models {e}")
             return []
+
+
+class openAIModelServerClientSession(ModelServerClientSession):
+    def __init__(self, client: openAIModelServerClient):
+        self.client = client
+        self.session = aiohttp.ClientSession(
+            timeout=aiohttp.ClientTimeout(total=client.timeout) if client.timeout else aiohttp.helpers.sentinel,
+            connector=aiohttp.TCPConnector(limit=client.max_tcp_connections),
+        )
+
+    async def process_request(self, data: InferenceAPIData, stage_id: int, scheduled_time: float) -> None:
+        payload = await data.to_payload(
+            model_name=self.client.model_name,
+            max_tokens=self.client.max_completion_tokens,
+            ignore_eos=self.client.ignore_eos,
+            streaming=self.client.api_config.streaming,
+        )
+        headers = {"Content-Type": "application/json"}
+
+        if self.client.api_key:
+            headers["Authorization"] = f"Bearer {self.client.api_key}"
+
+        if self.client.api_config.headers:
+            headers.update(self.client.api_config.headers)
+
+        request_data = json.dumps(payload)
+
+        start = time.perf_counter()
+        try:
+            async with self.session.post(self.client.uri + data.get_route(), headers=headers, data=request_data) as response:
+                response_info = await data.process_response(
+                    response=response, config=self.client.api_config, tokenizer=self.client.tokenizer
+                )
+                response_content = await response.text()
+
+                end_time = time.perf_counter()
+                error = None
+                if response.status != 200:
+                    error = ErrorResponseInfo(
+                        error_msg=response_content,
+                        error_type=f"{response.status} {response.reason}",
+                    )
+
+                self.client.metrics_collector.record_metric(
+                    RequestLifecycleMetric(
+                        stage_id=stage_id,
+                        request_data=request_data,
+                        response_data=response_content,
+                        info=response_info,
+                        error=error,
+                        start_time=start,
+                        end_time=end_time,
+                        scheduled_time=scheduled_time,
+                    )
+                )
+        except Exception as e:
+            if isinstance(e, asyncio.exceptions.TimeoutError):
+                logger.error("request timed out:", exc_info=True)
+            else:
+                logger.error("error occurred during request processing:", exc_info=True)
+            failure_info = await data.process_failure(
+                response=response if "response" in locals() else None,
+                config=self.client.api_config,
+                tokenizer=self.client.tokenizer,
+                exception=e,
+            )
+            self.client.metrics_collector.record_metric(
+                RequestLifecycleMetric(
+                    stage_id=stage_id,
+                    request_data=request_data,
+                    response_data=response_content if "response_content" in locals() else "",
+                    info=failure_info if failure_info else InferenceInfo(),
+                    error=ErrorResponseInfo(
+                        error_msg=str(e),
+                        error_type=type(e).__name__,
+                    ),
+                    start_time=start,
+                    end_time=time.perf_counter(),
+                    scheduled_time=scheduled_time,
+                )
+            )
+
+    async def close(self) -> None:
+        await self.session.close()
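
A minimal usage sketch of the new session lifecycle, outside the diff: it assumes an already constructed openAIModelServerClient (`client`) and an InferenceAPIData request (`request`), since the diff elides how those are built. Workers that want connection isolation hold one session each via new_session(); code that keeps calling client.process_request() instead shares the lazily created internal session, which client.close() tears down.

    import time

    # Hypothetical driver coroutine; `client` and `request` are assumed inputs.
    async def run_one(client: openAIModelServerClient, request: InferenceAPIData) -> None:
        session = client.new_session()  # dedicated aiohttp.ClientSession per worker
        try:
            await session.process_request(request, stage_id=0, scheduled_time=time.perf_counter())
        finally:
            await session.close()  # closes the session and its TCP connector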