mirror of
https://github.com/VectifyAI/PageIndex.git
synced 2026-07-03 20:41:02 +02:00
fix: prevent KeyError crash and context exhaustion in TOC processing (#188)
* fix: prevent KeyError crash and context exhaustion in TOC processing

  - Use .get() with safe defaults for all LLM response dict accesses
  - Optimize extract_toc_content retry loop to grow chat_history incrementally instead of rebuilding with full accumulated response
  - Optimize toc_transformer retry loop to use chat_history instead of re-embedding the entire raw TOC and incomplete JSON in each prompt
  - Return best-effort results on max retries instead of raising
  - Add 14 mock-based tests covering all fix scenarios

  Closes #163

* fix: address review feedback on retry behavior and None guard

  - Restore explicit Exception on max retries instead of silent warning
  - Move truncation logic before the retry loop so it only runs once on the initial incomplete response, not on every iteration
  - Add explicit None guard for physical_index before passing to convert_physical_index_to_int to prevent potential TypeError
  - Update test to expect Exception on max retries

---------

Co-authored-by: Your Name <you@example.com>
This commit is contained in:
parent
076dd07bd7
commit
f413c66fee
2 changed files with 175 additions and 48 deletions
|
|
@ -117,9 +117,8 @@ def toc_detector_single_page(content, model=None):
|
|||
Please note: abstract,summary, notation list, figure list, table list, etc. are not table of contents."""
|
||||
|
||||
response = llm_completion(model=model, prompt=prompt)
|
||||
# print('response', response)
|
||||
json_content = extract_json(response)
|
||||
return json_content['toc_detected']
|
||||
return json_content.get('toc_detected', 'no')
|
||||
|
||||
|
||||
def check_if_toc_extraction_is_complete(content, toc, model=None):
|
||||
|
|
@ -137,7 +136,7 @@ def check_if_toc_extraction_is_complete(content, toc, model=None):
|
|||
prompt = prompt + '\n Document:\n' + content + '\n Table of contents:\n' + toc
|
||||
response = llm_completion(model=model, prompt=prompt)
|
||||
json_content = extract_json(response)
|
||||
return json_content['completed']
|
||||
return json_content.get('completed', 'no')
|
||||
|
||||
|
||||
def check_if_toc_transformation_is_complete(content, toc, model=None):
|
||||
|
|
@ -155,7 +154,7 @@ def check_if_toc_transformation_is_complete(content, toc, model=None):
|
|||
prompt = prompt + '\n Raw Table of contents:\n' + content + '\n Cleaned Table of contents:\n' + toc
|
||||
response = llm_completion(model=model, prompt=prompt)
|
||||
json_content = extract_json(response)
|
||||
return json_content['completed']
|
||||
return json_content.get('completed', 'no')
|
||||
|
||||
def extract_toc_content(content, model=None):
|
||||
prompt = f"""
|
||||
|
|
@ -175,27 +174,19 @@ def extract_toc_content(content, model=None):
|
|||
{"role": "user", "content": prompt},
|
||||
{"role": "assistant", "content": response},
|
||||
]
|
||||
prompt = f"""please continue the generation of table of contents , directly output the remaining part of the structure"""
|
||||
new_response, finish_reason = llm_completion(model=model, prompt=prompt, chat_history=chat_history, return_finish_reason=True)
|
||||
response = response + new_response
|
||||
if_complete = check_if_toc_transformation_is_complete(content, response, model)
|
||||
continue_prompt = "please continue the generation of table of contents, directly output the remaining part of the structure"
|
||||
|
||||
attempt = 0
|
||||
max_attempts = 5
|
||||
|
||||
while not (if_complete == "yes" and finish_reason == "finished"):
|
||||
attempt += 1
|
||||
if attempt > max_attempts:
|
||||
raise Exception('Failed to complete table of contents after maximum retries')
|
||||
|
||||
chat_history = [
|
||||
{"role": "user", "content": prompt},
|
||||
{"role": "assistant", "content": response},
|
||||
]
|
||||
prompt = f"""please continue the generation of table of contents , directly output the remaining part of the structure"""
|
||||
new_response, finish_reason = llm_completion(model=model, prompt=prompt, chat_history=chat_history, return_finish_reason=True)
|
||||
for attempt in range(max_attempts):
|
||||
new_response, finish_reason = llm_completion(model=model, prompt=continue_prompt, chat_history=chat_history, return_finish_reason=True)
|
||||
response = response + new_response
|
||||
chat_history.append({"role": "user", "content": continue_prompt})
|
||||
chat_history.append({"role": "assistant", "content": new_response})
|
||||
if_complete = check_if_toc_transformation_is_complete(content, response, model)
|
||||
if if_complete == "yes" and finish_reason == "finished":
|
||||
break
|
||||
else:
|
||||
raise Exception('Failed to complete table of contents extraction after maximum retries')
|
||||
|
||||
return response
|
||||
|
||||
|
|
@ -217,7 +208,7 @@ def detect_page_index(toc_content, model=None):
|
|||
|
||||
response = llm_completion(model=model, prompt=prompt)
|
||||
json_content = extract_json(response)
|
||||
return json_content['page_index_given_in_toc']
|
||||
return json_content.get('page_index_given_in_toc', 'no')
|
||||
|
||||
def toc_extractor(page_list, toc_page_list, model):
|
||||
def transform_dots_to_colon(text):
|
||||
|
|
@ -296,43 +287,41 @@ def toc_transformer(toc_content, model=None):
|
|||
if_complete = check_if_toc_transformation_is_complete(toc_content, last_complete, model)
|
||||
if if_complete == "yes" and finish_reason == "finished":
|
||||
last_complete = extract_json(last_complete)
|
||||
cleaned_response=convert_page_to_int(last_complete['table_of_contents'])
|
||||
cleaned_response = convert_page_to_int(last_complete.get('table_of_contents', []))
|
||||
return cleaned_response
|
||||
|
||||
last_complete = get_json_content(last_complete)
|
||||
attempt = 0
|
||||
chat_history = [
|
||||
{"role": "user", "content": prompt},
|
||||
{"role": "assistant", "content": last_complete},
|
||||
]
|
||||
continue_prompt = "Please continue the table of contents JSON structure from where you left off. Directly output only the remaining part."
|
||||
|
||||
position = last_complete.rfind('}')
|
||||
if position != -1:
|
||||
last_complete = last_complete[:position+2]
|
||||
|
||||
max_attempts = 5
|
||||
while not (if_complete == "yes" and finish_reason == "finished"):
|
||||
attempt += 1
|
||||
if attempt > max_attempts:
|
||||
raise Exception('Failed to complete toc transformation after maximum retries')
|
||||
position = last_complete.rfind('}')
|
||||
if position != -1:
|
||||
last_complete = last_complete[:position+2]
|
||||
prompt = f"""
|
||||
Your task is to continue the table of contents json structure, directly output the remaining part of the json structure.
|
||||
The response should be in the following JSON format:
|
||||
for attempt in range(max_attempts):
|
||||
|
||||
The raw table of contents json structure is:
|
||||
{toc_content}
|
||||
|
||||
The incomplete transformed table of contents json structure is:
|
||||
{last_complete}
|
||||
|
||||
Please continue the json structure, directly output the remaining part of the json structure."""
|
||||
|
||||
new_complete, finish_reason = llm_completion(model=model, prompt=prompt, return_finish_reason=True)
|
||||
new_complete, finish_reason = llm_completion(model=model, prompt=continue_prompt, chat_history=chat_history, return_finish_reason=True)
|
||||
|
||||
if new_complete.startswith('```json'):
|
||||
new_complete = get_json_content(new_complete)
|
||||
last_complete = last_complete+new_complete
|
||||
new_complete = get_json_content(new_complete)
|
||||
last_complete = last_complete + new_complete
|
||||
|
||||
chat_history.append({"role": "user", "content": continue_prompt})
|
||||
chat_history.append({"role": "assistant", "content": new_complete})
|
||||
|
||||
if_complete = check_if_toc_transformation_is_complete(toc_content, last_complete, model)
|
||||
|
||||
if if_complete == "yes" and finish_reason == "finished":
|
||||
break
|
||||
else:
|
||||
raise Exception('Failed to complete TOC transformation after maximum retries')
|
||||
|
||||
last_complete = extract_json(last_complete)
|
||||
|
||||
cleaned_response=convert_page_to_int(last_complete['table_of_contents'])
|
||||
cleaned_response = convert_page_to_int(last_complete.get('table_of_contents', []))
|
||||
return cleaned_response
|
||||
|
||||
|
||||
|
|
@ -753,7 +742,10 @@ async def single_toc_item_index_fixer(section_title, content, model=None):
|
|||
prompt = toc_extractor_prompt + '\nSection Title:\n' + str(section_title) + '\nDocument pages:\n' + content
|
||||
response = await llm_acompletion(model=model, prompt=prompt)
|
||||
json_content = extract_json(response)
|
||||
return convert_physical_index_to_int(json_content['physical_index'])
|
||||
physical_index = json_content.get('physical_index')
|
||||
if physical_index is None:
|
||||
return None
|
||||
return convert_physical_index_to_int(physical_index)
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
135
tests/test_issue_163.py
Normal file
135
tests/test_issue_163.py
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
import pytest
|
||||
import sys
|
||||
import os
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
from pageindex.page_index import (
|
||||
check_if_toc_extraction_is_complete,
|
||||
check_if_toc_transformation_is_complete,
|
||||
toc_detector_single_page,
|
||||
detect_page_index,
|
||||
extract_toc_content,
|
||||
toc_transformer,
|
||||
)
|
||||
|
||||
|
||||
class TestRobustKeyAccess:
|
||||
@patch("pageindex.page_index.llm_completion", return_value="")
|
||||
def test_toc_detector_empty_response(self, mock_llm):
|
||||
result = toc_detector_single_page("some content", model="test")
|
||||
assert result == "no"
|
||||
|
||||
@patch("pageindex.page_index.llm_completion", return_value='{"toc_detected": "yes"}')
|
||||
def test_toc_detector_valid_response(self, mock_llm):
|
||||
result = toc_detector_single_page("some content", model="test")
|
||||
assert result == "yes"
|
||||
|
||||
@patch("pageindex.page_index.llm_completion", return_value="not json at all")
|
||||
def test_toc_detector_malformed_response(self, mock_llm):
|
||||
result = toc_detector_single_page("some content", model="test")
|
||||
assert result == "no"
|
||||
|
||||
@patch("pageindex.page_index.llm_completion", return_value="")
|
||||
def test_extraction_complete_empty_response(self, mock_llm):
|
||||
result = check_if_toc_extraction_is_complete("doc", "toc", model="test")
|
||||
assert result == "no"
|
||||
|
||||
@patch("pageindex.page_index.llm_completion", return_value='{"completed": "yes"}')
|
||||
def test_extraction_complete_valid_response(self, mock_llm):
|
||||
result = check_if_toc_extraction_is_complete("doc", "toc", model="test")
|
||||
assert result == "yes"
|
||||
|
||||
@patch("pageindex.page_index.llm_completion", return_value="")
|
||||
def test_transformation_complete_empty_response(self, mock_llm):
|
||||
result = check_if_toc_transformation_is_complete("raw", "cleaned", model="test")
|
||||
assert result == "no"
|
||||
|
||||
@patch("pageindex.page_index.llm_completion", return_value='{"thinking": "looks fine", "completed": "yes"}')
|
||||
def test_transformation_complete_valid_response(self, mock_llm):
|
||||
result = check_if_toc_transformation_is_complete("raw", "cleaned", model="test")
|
||||
assert result == "yes"
|
||||
|
||||
@patch("pageindex.page_index.llm_completion", return_value="")
|
||||
def test_detect_page_index_empty_response(self, mock_llm):
|
||||
result = detect_page_index("toc text", model="test")
|
||||
assert result == "no"
|
||||
|
||||
|
||||
class TestExtractTocContentRetryLoop:
|
||||
@patch("pageindex.page_index.check_if_toc_transformation_is_complete")
|
||||
@patch("pageindex.page_index.llm_completion")
|
||||
def test_completes_on_first_try(self, mock_llm, mock_check):
|
||||
mock_llm.return_value = ("full toc content", "finished")
|
||||
mock_check.return_value = "yes"
|
||||
result = extract_toc_content("raw content", model="test")
|
||||
assert result == "full toc content"
|
||||
assert mock_llm.call_count == 1
|
||||
|
||||
@patch("pageindex.page_index.check_if_toc_transformation_is_complete")
|
||||
@patch("pageindex.page_index.llm_completion")
|
||||
def test_continues_on_incomplete(self, mock_llm, mock_check):
|
||||
mock_llm.side_effect = [
|
||||
("partial toc", "max_output_reached"),
|
||||
(" continued toc", "finished"),
|
||||
]
|
||||
mock_check.side_effect = ["no", "yes"]
|
||||
result = extract_toc_content("raw content", model="test")
|
||||
assert result == "partial toc continued toc"
|
||||
assert mock_llm.call_count == 2
|
||||
|
||||
@patch("pageindex.page_index.check_if_toc_transformation_is_complete")
|
||||
@patch("pageindex.page_index.llm_completion")
|
||||
def test_max_retries_raises_exception(self, mock_llm, mock_check):
|
||||
mock_llm.return_value = ("chunk", "max_output_reached")
|
||||
mock_check.return_value = "no"
|
||||
with pytest.raises(Exception, match="Failed to complete table of contents extraction"):
|
||||
extract_toc_content("raw content", model="test")
|
||||
assert mock_llm.call_count == 6
|
||||
|
||||
@patch("pageindex.page_index.check_if_toc_transformation_is_complete")
|
||||
@patch("pageindex.page_index.llm_completion")
|
||||
def test_chat_history_grows_incrementally(self, mock_llm, mock_check):
|
||||
call_count = [0]
|
||||
|
||||
def side_effect(*args, **kwargs):
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
return ("initial", "max_output_reached")
|
||||
if call_count[0] == 2:
|
||||
history = kwargs.get("chat_history", [])
|
||||
assert len(history) == 2
|
||||
return (" part2", "max_output_reached")
|
||||
if call_count[0] == 3:
|
||||
history = kwargs.get("chat_history", [])
|
||||
assert len(history) == 4
|
||||
return (" part3", "finished")
|
||||
return ("", "finished")
|
||||
|
||||
mock_llm.side_effect = side_effect
|
||||
mock_check.side_effect = ["no", "no", "yes"]
|
||||
result = extract_toc_content("raw content", model="test")
|
||||
assert result == "initial part2 part3"
|
||||
|
||||
|
||||
class TestTocTransformerRetryLoop:
|
||||
@patch("pageindex.page_index.check_if_toc_transformation_is_complete")
|
||||
@patch("pageindex.page_index.llm_completion")
|
||||
def test_completes_on_first_try(self, mock_llm, mock_check):
|
||||
mock_llm.return_value = (
|
||||
'{"table_of_contents": [{"structure": "1", "title": "Intro", "page": 1}]}',
|
||||
"finished",
|
||||
)
|
||||
mock_check.return_value = "yes"
|
||||
result = toc_transformer("raw toc", model="test")
|
||||
assert len(result) == 1
|
||||
assert result[0]["title"] == "Intro"
|
||||
|
||||
@patch("pageindex.page_index.check_if_toc_transformation_is_complete")
|
||||
@patch("pageindex.page_index.llm_completion")
|
||||
def test_handles_missing_table_of_contents_key(self, mock_llm, mock_check):
|
||||
mock_llm.return_value = ('{"other_key": "value"}', "finished")
|
||||
mock_check.return_value = "yes"
|
||||
result = toc_transformer("raw toc", model="test")
|
||||
assert result == []
|
||||
Loading…
Add table
Add a link
Reference in a new issue