Skip to content

Commit 0a7a280

Browse files
author
Simba
committed
Move _run_coro to the base class and prevent LiteLLM logging warnings by introducing AsyncLoopRunner
1 parent 9ac1a9e commit 0a7a280

File tree

2 files changed

+79
-63
lines changed

2 files changed

+79
-63
lines changed

adt_eval/mlflow_base.py

Lines changed: 79 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,66 @@
22

33
from __future__ import annotations
44

5+
import asyncio
56
import inspect
7+
import threading
68
from abc import abstractmethod
79
from datetime import datetime
8-
from typing import Any, Dict, List
10+
from typing import Any, Dict, List, Optional
911

1012
import mlflow
1113

1214
from adt_eval.base import BaseEvaluator
1315

1416

17+
class AsyncLoopRunner:
    """Run coroutines on a long-lived event loop in a background thread.

    The loop lives on a daemon thread so coroutines can be executed from
    synchronous code (e.g. mlflow predict callbacks) without starting and
    tearing down a fresh event loop per call.
    """

    def __init__(self) -> None:
        # None means "not running"; both are created lazily by start().
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._thread: Optional[threading.Thread] = None
        # Set by the background thread once its event loop is installed.
        self._ready = threading.Event()

    def start(self) -> None:
        """Start the background loop thread. Idempotent while alive."""
        if self._thread and self._thread.is_alive():
            return

        # Reset the ready flag so a start() that follows close() blocks until
        # the *new* loop is installed. Without this, the Event left set by a
        # previous run makes wait() return before self._loop is assigned,
        # and submit() then fails with "not started".
        self._ready.clear()

        def _runner() -> None:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            self._loop = loop
            self._ready.set()
            try:
                loop.run_forever()
            finally:
                try:
                    # Cancel and drain anything still scheduled so tasks are
                    # not destroyed while pending (avoids asyncio warnings).
                    pending = asyncio.all_tasks(loop)
                    for task in pending:
                        task.cancel()
                    if pending:
                        loop.run_until_complete(
                            asyncio.gather(*pending, return_exceptions=True)
                        )
                finally:
                    loop.close()

        self._thread = threading.Thread(target=_runner, daemon=True)
        self._thread.start()
        # Block until _runner has published the loop.
        self._ready.wait()

    def submit(self, coro):
        """Schedule *coro* on the background loop and block for its result.

        Raises:
            RuntimeError: if the runner has not been started.
        """
        if not self._loop:
            raise RuntimeError("AsyncLoopRunner not started")
        future = asyncio.run_coroutine_threadsafe(coro, self._loop)
        # Propagates the coroutine's return value or exception.
        return future.result()

    def close(self) -> None:
        """Stop the loop, join the thread, and reset to the unstarted state."""
        if not self._loop or not self._thread:
            return
        self._loop.call_soon_threadsafe(self._loop.stop)
        self._thread.join()
        self._loop = None
        self._thread = None
64+
1565
class MLflowEvaluatorBase(BaseEvaluator):
1666
"""Base evaluator that wraps the core run in an MLflow run."""
1767

@@ -77,6 +127,11 @@ def log_run_metrics(self, metrics: Dict[str, Any]) -> None:
77127
if score is not None:
78128
mlflow.log_metric("score", score)
79129

130+
def _run_coro(self, coro):
    """Run *coro* on the evaluator's background loop and return its result.

    Raises:
        RuntimeError: if run() has not installed a loop runner yet.
    """
    runner = getattr(self, "_loop_runner", None)
    if runner is None:
        raise RuntimeError("Async loop runner not initialized")
    return runner.submit(coro)
134+
80135
def get_report_results_and_metrics(self, eval_results) -> tuple[List[Dict[str, Any]], Dict[str, Any]]:
81136
"""Return report-ready results and metrics from mlflow.genai.evaluate output."""
82137
metrics = {}
@@ -96,21 +151,26 @@ async def run(self):
96151
self.configure_mlflow()
97152
run_name = self.get_run_name()
98153
nested = mlflow.active_run() is not None
99-
100-
with mlflow.start_run(run_name=run_name, nested=nested):
101-
self.log_run_params()
102-
cases = self.filter_cases(self.load_data())
103-
eval_dataset = self.build_eval_dataset(cases)
104-
eval_results = mlflow.genai.evaluate(
105-
data=eval_dataset,
106-
predict_fn=self.predict_fn,
107-
scorers=self.get_scorers(),
108-
**self._get_evaluate_kwargs(),
109-
)
110-
111-
results, metrics = self.get_report_results_and_metrics(eval_results)
112-
113-
self.log_run_metrics(metrics)
114-
if results and metrics:
115-
self.generate_report(results, metrics)
116-
return results, metrics
154+
self._loop_runner = AsyncLoopRunner()
155+
self._loop_runner.start()
156+
try:
157+
with mlflow.start_run(run_name=run_name, nested=nested):
158+
self.log_run_params()
159+
cases = self.filter_cases(self.load_data())
160+
eval_dataset = self.build_eval_dataset(cases)
161+
eval_results = mlflow.genai.evaluate(
162+
data=eval_dataset,
163+
predict_fn=self.predict_fn,
164+
scorers=self.get_scorers(),
165+
**self._get_evaluate_kwargs(),
166+
)
167+
168+
results, metrics = self.get_report_results_and_metrics(eval_results)
169+
170+
self.log_run_metrics(metrics)
171+
if results and metrics:
172+
self.generate_report(results, metrics)
173+
return results, metrics
174+
finally:
175+
self._loop_runner.close()
176+
self._loop_runner = None

adt_eval/text_type.py

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ def text_type_per_page_scorer(inputs: Dict[str, Any], outputs: Dict[str, Any]) -
8484
rationale=json.dumps(matches),
8585
)
8686

87-
8887
class TextTypeEvaluator(MLflowEvaluatorBase):
8988
"""Evaluator for text type accuracy."""
9089

@@ -155,49 +154,6 @@ def build_eval_dataset(self, cases: List[Dict[str, Any]]) -> List[Dict[str, Any]
155154
)
156155
return records
157156

158-
def _run_coro(self, coro):
159-
try:
160-
loop = asyncio.get_running_loop()
161-
except RuntimeError:
162-
loop = None
163-
164-
if loop and loop.is_running():
165-
# If an event loop is already running (e.g., in notebooks), offload to a thread
166-
# and wait for the result to avoid nested loop issues.
167-
result_container: Dict[str, Any] = {}
168-
error_container: Dict[str, BaseException] = {}
169-
170-
def _runner():
171-
try:
172-
result_container["value"] = asyncio.run(coro)
173-
except BaseException as exc: # pragma: no cover - re-raise below
174-
error_container["error"] = exc
175-
176-
thread = threading.Thread(target=_runner, daemon=True)
177-
thread.start()
178-
thread.join()
179-
180-
if "error" in error_container:
181-
raise error_container["error"]
182-
return result_container.get("value")
183-
184-
result_container: Dict[str, Any] = {}
185-
error_container: Dict[str, BaseException] = {}
186-
187-
def _runner():
188-
try:
189-
result_container["value"] = asyncio.run(coro)
190-
except BaseException as exc: # pragma: no cover - re-raise below
191-
error_container["error"] = exc
192-
193-
thread = threading.Thread(target=_runner, daemon=True)
194-
thread.start()
195-
thread.join()
196-
197-
if "error" in error_container:
198-
raise error_container["error"]
199-
return result_container.get("value")
200-
201157
def predict_fn(self, **inputs: Any) -> Dict[str, Any]:
202158
page_text = inputs["page_text"]
203159
page_image_path = inputs["page_image_path"]

0 commit comments

Comments
 (0)