Commit a85b31b

Share aiohttp.ClientSessions per worker (#282)
Slightly refactor `openAIModelServerClient` to share HTTP client sessions per worker: a new `new_session` method returns an `openAIModelServerClientSession` that wraps a reusable `aiohttp.ClientSession`, allowing the caller to reuse one HTTP client session per worker. The existing `process_request` method now lazily creates a single internal session and delegates to it, preserving the previous behavior.

Prior to this commit, a new `aiohttp.ClientSession` was created for each request. This is not only inefficient, lowering throughput; in certain environments it also leads to inotify watch issues:

    aiodns - WARNING - Failed to create DNS resolver channel with automatic monitoring of resolver configuration changes. This usually means the system ran out of inotify watches. Falling back to socket state callback. Consider increasing the system inotify watch limit: Failed to initialize c-ares channel

Indeed, because a new DNS resolver is created for each new `ClientSession`, creating many `ClientSession`s causes eventual inotify watch exhaustion. Sharing `ClientSession`s solves this issue.

Relevant links:
- https://docs.aiohttp.org/en/stable/http_request_lifecycle.html
- https://stackoverflow.com/questions/62707369/one-aiohttp-clientsession-per-thread
- home-assistant/core#144457 (comment)

Relevant PR: #247 (doesn't address the issue of worker sharing).
1 parent c85e5a4 commit a85b31b
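The pattern the commit adopts can be illustrated with a minimal, self-contained sketch (not from the repository; the endpoint URL, function names, and request count are placeholders): a `ClientSession` per request spins up a new connector and DNS resolver every time, while one session per worker pools connections and creates a single resolver.

```python
import asyncio

import aiohttp

URL = "http://localhost:8000/v1/models"  # placeholder endpoint, not from the repo


async def per_request_sessions(n: int) -> None:
    # Anti-pattern (pre-commit behavior): every request builds a fresh
    # ClientSession, and with it a fresh connector and DNS resolver, which
    # can eventually exhaust inotify watches under aiodns.
    for _ in range(n):
        async with aiohttp.ClientSession() as session:
            async with session.get(URL) as resp:
                await resp.text()


async def shared_session(n: int) -> None:
    # Preferred (post-commit behavior): one ClientSession per worker, reused
    # for every request, so connections are pooled and only one DNS resolver
    # is ever created.
    async with aiohttp.ClientSession() as session:
        for _ in range(n):
            async with session.get(URL) as resp:
                await resp.text()


if __name__ == "__main__":
    asyncio.run(shared_session(100))
```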

File tree

3 files changed: +130 -79 lines changed

- inference_perf/client/modelserver/__init__.py
- inference_perf/client/modelserver/base.py
- inference_perf/client/modelserver/openai_client.py

3 files changed

+130
-79
lines changed

inference_perf/client/modelserver/__init__.py

Lines changed: 8 additions & 2 deletions
@@ -11,10 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .base import ModelServerClient
+from .base import ModelServerClient, ModelServerClientSession
 from .mock_client import MockModelServerClient
 from .vllm_client import vLLMModelServerClient
 from .sglang_client import SGlangModelServerClient


-__all__ = ["ModelServerClient", "MockModelServerClient", "vLLMModelServerClient", "SGlangModelServerClient"]
+__all__ = [
+    "ModelServerClient",
+    "ModelServerClientSession",
+    "MockModelServerClient",
+    "vLLMModelServerClient",
+    "SGlangModelServerClient",
+]

inference_perf/client/modelserver/base.py

Lines changed: 14 additions & 1 deletion
@@ -15,7 +15,6 @@
 from typing import List, Optional, Tuple
 from inference_perf.client.metricsclient.base import MetricsMetadata
 from inference_perf.config import APIConfig, APIType
-
 from inference_perf.apis import InferenceAPIData


@@ -82,6 +81,9 @@ def __init__(self, api_config: APIConfig, timeout: Optional[float] = None, *args
         self.api_config = api_config
         self.timeout = timeout

+    def new_session(self) -> "ModelServerClientSession":
+        return ModelServerClientSession(self)
+
     @abstractmethod
     def get_supported_apis(self) -> List[APIType]:
         raise NotImplementedError
@@ -94,3 +96,14 @@ async def process_request(self, data: InferenceAPIData, stage_id: int, scheduled
     def get_prometheus_metric_metadata(self) -> PrometheusMetricMetadata:
         # assumption: all metrics clients have metrics exported in Prometheus format
         raise NotImplementedError
+
+
+class ModelServerClientSession:
+    def __init__(self, client: ModelServerClient):
+        self.client = client
+
+    async def process_request(self, data: InferenceAPIData, stage_id: int, scheduled_time: float) -> None:
+        await self.client.process_request(data, stage_id, scheduled_time)
+
+    async def close(self) -> None:  # noqa - subclasses optionally override this
+        pass
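A hypothetical worker loop using this new API might look like the following sketch. The `run_worker` function and its `work` argument are illustrative, not part of the commit; only `new_session`, `process_request`, and `close` come from the diff above. Each worker obtains its own session, reuses it for every request, and closes it on shutdown.

```python
from typing import Iterable, Tuple

from inference_perf.apis import InferenceAPIData
from inference_perf.client.modelserver import ModelServerClient


async def run_worker(
    client: ModelServerClient,
    work: Iterable[Tuple[float, InferenceAPIData]],  # (scheduled_time, request) pairs
    stage_id: int,
) -> None:
    session = client.new_session()  # one reusable session per worker
    try:
        for scheduled_time, data in work:
            await session.process_request(data, stage_id, scheduled_time)
    finally:
        await session.close()  # release the underlying HTTP resources
```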

inference_perf/client/modelserver/openai_client.py

Lines changed: 108 additions & 76 deletions
@@ -17,7 +17,7 @@
 from inference_perf.config import APIConfig, APIType, CustomTokenizerConfig
 from inference_perf.apis import InferenceAPIData, InferenceInfo, RequestLifecycleMetric, ErrorResponseInfo
 from inference_perf.utils import CustomTokenizer
-from .base import ModelServerClient, PrometheusMetricMetadata
+from .base import ModelServerClient, ModelServerClientSession, PrometheusMetricMetadata
 from typing import List, Optional
 import aiohttp
 import asyncio
@@ -30,6 +30,9 @@


 class openAIModelServerClient(ModelServerClient):
+    _session: "openAIModelServerClientSession | None" = None
+    _session_lock = asyncio.Lock()
+
     def __init__(
         self,
         metrics_collector: RequestDataCollector,
@@ -70,82 +73,27 @@ def __init__(
             tokenizer_config = CustomTokenizerConfig(pretrained_model_name_or_path=self.model_name)
         self.tokenizer = CustomTokenizer(tokenizer_config)

-    async def process_request(self, data: InferenceAPIData, stage_id: int, scheduled_time: float) -> None:
-        payload = await data.to_payload(
-            model_name=self.model_name,
-            max_tokens=self.max_completion_tokens,
-            ignore_eos=self.ignore_eos,
-            streaming=self.api_config.streaming,
-        )
-        headers = {"Content-Type": "application/json"}
-
-        if self.api_key:
-            headers["Authorization"] = f"Bearer {self.api_key}"
-
-        if self.api_config.headers:
-            headers.update(self.api_config.headers)
+    def new_session(self) -> "ModelServerClientSession":
+        return openAIModelServerClientSession(self)

-        request_data = json.dumps(payload)
-
-        timeout = aiohttp.ClientTimeout(total=self.timeout) if self.timeout else aiohttp.helpers.sentinel
-
-        async with aiohttp.ClientSession(
-            connector=aiohttp.TCPConnector(limit=self.max_tcp_connections), timeout=timeout
-        ) as session:
-            start = time.perf_counter()
-            try:
-                async with session.post(self.uri + data.get_route(), headers=headers, data=request_data) as response:
-                    response_info = await data.process_response(
-                        response=response, config=self.api_config, tokenizer=self.tokenizer
-                    )
-                    response_content = await response.text()
-
-                    end_time = time.perf_counter()
-                    error = None
-                    if response.status != 200:
-                        error = ErrorResponseInfo(
-                            error_msg=response_content,
-                            error_type=f"{response.status} {response.reason}",
-                        )
-
-                    self.metrics_collector.record_metric(
-                        RequestLifecycleMetric(
-                            stage_id=stage_id,
-                            request_data=request_data,
-                            response_data=response_content,
-                            info=response_info,
-                            error=error,
-                            start_time=start,
-                            end_time=end_time,
-                            scheduled_time=scheduled_time,
-                        )
-                    )
-            except Exception as e:
-                if isinstance(e, asyncio.exceptions.TimeoutError):
-                    logger.error("request timed out:", exc_info=True)
-                else:
-                    logger.error("error occured during request processing:", exc_info=True)
-                failure_info = await data.process_failure(
-                    response=response if "response" in locals() else None,
-                    config=self.api_config,
-                    tokenizer=self.tokenizer,
-                    exception=e,
-                )
-                self.metrics_collector.record_metric(
-                    RequestLifecycleMetric(
-                        stage_id=stage_id,
-                        request_data=request_data,
-                        response_data=response_content if "response_content" in locals() else "",
-                        info=failure_info if failure_info else InferenceInfo(),
-                        error=ErrorResponseInfo(
-                            error_msg=str(e),
-                            error_type=type(e).__name__,
-                        ),
-                        start_time=start,
-                        end_time=time.perf_counter(),
-                        scheduled_time=scheduled_time,
-                    )
-                )
+    async def process_request(self, data: InferenceAPIData, stage_id: int, scheduled_time: float) -> None:
+        """
+        Create an internal client session if not already, then use that to
+        process the request.
+        """
+        session: openAIModelServerClientSession
+        # ensure session is only created once.
+        async with self._session_lock:
+            if self._session is None:
+                self._session = openAIModelServerClientSession(self)
+            session = self._session
+        await session.process_request(data, stage_id, scheduled_time)
+
+    async def close(self) -> None:
+        """Close the internal session created by process_request, if any."""
+        if self._session is not None:
+            await self._session.close()
+            self._session = None

     def get_supported_apis(self) -> List[APIType]:
         return []
@@ -166,3 +114,87 @@ def get_supported_models(self) -> List[str]:
         except Exception as e:
             logger.error(f"Got exception retrieving supported models {e}")
             return []
+
+
+class openAIModelServerClientSession(ModelServerClientSession):
+    def __init__(self, client: openAIModelServerClient):
+        self.client = client
+        self.session = aiohttp.ClientSession(
+            timeout=aiohttp.ClientTimeout(total=client.timeout) if client.timeout else aiohttp.helpers.sentinel,
+            connector=aiohttp.TCPConnector(limit=client.max_tcp_connections),
+        )
+
+    async def process_request(self, data: InferenceAPIData, stage_id: int, scheduled_time: float) -> None:
+        payload = await data.to_payload(
+            model_name=self.client.model_name,
+            max_tokens=self.client.max_completion_tokens,
+            ignore_eos=self.client.ignore_eos,
+            streaming=self.client.api_config.streaming,
+        )
+        headers = {"Content-Type": "application/json"}
+
+        if self.client.api_key:
+            headers["Authorization"] = f"Bearer {self.client.api_key}"
+
+        if self.client.api_config.headers:
+            headers.update(self.client.api_config.headers)
+
+        request_data = json.dumps(payload)
+
+        start = time.perf_counter()
+        try:
+            async with self.session.post(self.client.uri + data.get_route(), headers=headers, data=request_data) as response:
+                response_info = await data.process_response(
+                    response=response, config=self.client.api_config, tokenizer=self.client.tokenizer
+                )
+                response_content = await response.text()
+
+                end_time = time.perf_counter()
+                error = None
+                if response.status != 200:
+                    error = ErrorResponseInfo(
+                        error_msg=response_content,
+                        error_type=f"{response.status} {response.reason}",
+                    )
+
+                self.client.metrics_collector.record_metric(
+                    RequestLifecycleMetric(
+                        stage_id=stage_id,
+                        request_data=request_data,
+                        response_data=response_content,
+                        info=response_info,
+                        error=error,
+                        start_time=start,
+                        end_time=end_time,
+                        scheduled_time=scheduled_time,
+                    )
+                )
+        except Exception as e:
+            if isinstance(e, asyncio.exceptions.TimeoutError):
+                logger.error("request timed out:", exc_info=True)
+            else:
+                logger.error("error occured during request processing:", exc_info=True)
+            failure_info = await data.process_failure(
+                response=response if "response" in locals() else None,
+                config=self.client.api_config,
+                tokenizer=self.client.tokenizer,
+                exception=e,
+            )
+            self.client.metrics_collector.record_metric(
+                RequestLifecycleMetric(
+                    stage_id=stage_id,
+                    request_data=request_data,
+                    response_data=response_content if "response_content" in locals() else "",
+                    info=failure_info if failure_info else InferenceInfo(),
+                    error=ErrorResponseInfo(
+                        error_msg=str(e),
+                        error_type=type(e).__name__,
+                    ),
+                    start_time=start,
+                    end_time=time.perf_counter(),
+                    scheduled_time=scheduled_time,
+                )
+            )
+
+    async def close(self) -> None:
+        await self.session.close()
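The lock-guarded lazy initialization in `process_request` above is a general asyncio pattern worth isolating: the first caller creates the shared resource, and the lock prevents concurrent callers from racing to create duplicates. A minimal standalone sketch of the pattern (the class and method names here are illustrative, not from the repository):

```python
import asyncio
from typing import Optional

import aiohttp


class LazySharedSession:
    """One lazily created aiohttp.ClientSession shared by all callers."""

    def __init__(self) -> None:
        self._session: Optional[aiohttp.ClientSession] = None
        self._lock = asyncio.Lock()

    async def get(self) -> aiohttp.ClientSession:
        # The lock ensures exactly one coroutine creates the session, even
        # when many requests arrive before the first creation completes.
        async with self._lock:
            if self._session is None:
                self._session = aiohttp.ClientSession()
        return self._session

    async def close(self) -> None:
        if self._session is not None:
            await self._session.close()
            self._session = None
```

One detail of the commit itself: `_session_lock` is declared as a class attribute, so it is shared by all `openAIModelServerClient` instances, whereas the `self._session = ...` assignment binds per instance; the sketch above keeps both per instance for simplicity.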
