Skip to content

Commit 025de69

Browse files
committed
Skip sections that should be filtered from quizzes
1 parent 6ebd654 commit 025de69

File tree

5 files changed

+37
-12
lines changed

5 files changed

+37
-12
lines changed

adt_press/llm/section_quiz.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# mypy: ignore-errors
2-
import instructor
32
from banks import Prompt
4-
from litellm import acompletion
53
from pydantic import BaseModel, ValidationInfo, field_validator
4+
import structlog
5+
6+
from adt_press.llm import get_instructor_client
67

78
from adt_press.models.config import QuizPromptConfig
89
from adt_press.models.quiz import SectionQuiz
@@ -11,6 +12,7 @@
1112
from adt_press.utils.file import cached_read_text_file
1213
from adt_press.utils.languages import Language
1314

15+
log = structlog.get_logger(__name__)
1416

1517
class Quiz(BaseModel):
1618
question: str
@@ -60,7 +62,7 @@ class QuizResponse(BaseModel):
6062

6163
async def generate_quiz(
6264
config: QuizPromptConfig, language: Language, sections: list[PageSection], text_groups_by_id: dict[str, TextGroup]
63-
) -> SectionQuiz:
65+
) -> SectionQuiz | None:
6466
context = dict(
6567
sections=sections,
6668
text_groups=text_groups_by_id,
@@ -69,14 +71,19 @@ async def generate_quiz(
6971
)
7072

7173
prompt = Prompt(cached_read_text_file(config.template_path))
72-
client = instructor.from_litellm(acompletion)
73-
response: QuizResponse = await client.chat.completions.create(
74-
model=config.model,
75-
response_model=QuizResponse,
76-
messages=[m.model_dump(exclude_none=True) for m in prompt.chat_messages(context)],
77-
max_retries=config.max_retries,
78-
context={},
79-
)
74+
client = get_instructor_client()
75+
try:
76+
response: QuizResponse = await client.chat.completions.create(
77+
model=config.model,
78+
response_model=QuizResponse,
79+
messages=[m.model_dump(exclude_none=True) for m in prompt.chat_messages(context)],
80+
max_retries=config.max_retries,
81+
timeout=config.timeout,
82+
context={},
83+
)
84+
except Exception as exc:
85+
log.warning("quiz_generation_failed", error=str(exc))
86+
return None
8087

8188
after_section = sections[-1]
8289

adt_press/nodes/config_nodes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,10 @@ def pruned_section_types_config(config: DictConfig) -> list[SectionTypeName]:
383383
return list[SectionTypeName](config.get("section_filters", {}).get("pruned_section_types", []))
384384

385385

386+
def quiz_count_section_types_config(config: DictConfig) -> list[SectionTypeName]:
387+
return list[SectionTypeName](config.get("section_filters", {}).get("quiz_count_section_types", []))
388+
389+
386390
@cache(behavior="recompute")
387391
def activity_strategy_config(config: DictConfig) -> str:
388392
"""Get the activity generation strategy from config."""

adt_press/nodes/section_nodes.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,13 @@ def quizzes_by_section_id__llm(
9393
filtered_sections_by_page_id: dict[PageID, PageSections],
9494
pdf_text_groups_by_id: dict[TextGroupID, TextGroup],
9595
quiz_prompt_config: QuizPromptConfig,
96+
quiz_count_section_types_config: list[SectionTypeName],
9697
) -> dict[SectionID, SectionQuiz]:
9798
if quiz_prompt_config.sections_per_quiz < 1:
9899
raise ValueError("sections_per_quiz must be at least 1, use quiz_strategy 'none' to disable quizzes")
99100

101+
quiz_count_section_types = set(quiz_count_section_types_config)
102+
100103
async def get_quizzes():
101104
tasks = []
102105
sections = []
@@ -107,6 +110,9 @@ async def get_quizzes():
107110
if section.is_pruned:
108111
continue
109112

113+
if quiz_count_section_types and section.section_type.name not in quiz_count_section_types:
114+
continue
115+
110116
count += 1
111117
sections.append(section)
112118

@@ -118,7 +124,7 @@ async def get_quizzes():
118124
return await gather_with_limit(tasks, quiz_prompt_config.rate_limit)
119125

120126
results = run_async_task(get_quizzes)
121-
return {quiz.section_id: quiz for quiz in results}
127+
return {quiz.section_id: quiz for quiz in results if quiz}
122128

123129

124130
@config.when(explanation_strategy="llm")

config/config.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,13 @@ section_filters:
297297
- back_cover
298298
- credits
299299
- inside_cover
300+
# Only these section types count toward quiz grouping. Empty list means all non-pruned sections count.
301+
quiz_count_section_types:
302+
- boxed_text
303+
- text_only
304+
- text_and_single_image
305+
- text_and_images
306+
- images_only
300307

301308
prompts:
302309
metadata_extraction:

prompts/section_quiz.jinja2

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ For the quiz format:
1616
- The explanation inside each option must describe that specific option
1717
- The option at `answer_index` must have the congratulatory/correct explanation, and the others must provide gentle corrections
1818
- Each explanation option should start with an emoji: ❌ for incorrect answers and ✅ for the correct answer
19+
- IMPORTANT: The `quiz` object MUST include the full `options` array with 3 objects. Do not omit any fields.
1920

2021
Return your response as JSON with this exact structure:
2122
{

0 commit comments

Comments
 (0)