add async. various fixes.

This commit is contained in:
Ray 2025-04-20 07:57:07 +08:00
parent 3fbf2d9139
commit b588cd62a1
3 changed files with 134 additions and 111 deletions

View file

@ -7,11 +7,10 @@ import re
from .utils import *
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
import argparse
################### check title in page #########################################################
def check_title_appearance(item, page_list, start_index=1, model=None):
async def check_title_appearance(item, page_list, start_index=1, model=None):
title=item['title']
if 'physical_index' not in item or item['physical_index'] is None:
return {'list_index': item.get('list_index'), 'answer': 'no', 'title':title, 'page_number': None}
@ -37,7 +36,7 @@ def check_title_appearance(item, page_list, start_index=1, model=None):
}}
Directly return the final JSON structure. Do not output anything else."""
response = ChatGPT_API(model=model, prompt=prompt)
response = await ChatGPT_API_async(model=model, prompt=prompt)
response = extract_json(response)
if 'answer' in response:
answer = response['answer']
@ -46,9 +45,9 @@ def check_title_appearance(item, page_list, start_index=1, model=None):
return {'list_index': item['list_index'], 'answer': answer, 'title': title, 'page_number': page_number}
def check_title_appearance_in_start(title, page_text, model=None, logger=None):
async def check_title_appearance_in_start(title, page_text, model=None, logger=None):
prompt = f"""
You will be given given the current section title and the current page_text.
You will be given the current section title and the current page_text.
Your job is to check if the current section starts in the beginning of the given page_text.
If there are other contents before the current section title, then the current section does not start in the beginning of the given page_text.
If the current section title is the first content in the given page_text, then the current section starts in the beginning of the given page_text.
@ -65,36 +64,40 @@ def check_title_appearance_in_start(title, page_text, model=None, logger=None):
}}
Directly return the final JSON structure. Do not output anything else."""
response = ChatGPT_API(model=model, prompt=prompt)
response = await ChatGPT_API_async(model=model, prompt=prompt)
response = extract_json(response)
if logger:
logger.info(f"Response: {response}")
if 'start_begin' in response:
return response['start_begin']
else:
return 'no'
return response.get("start_begin", "no")
def check_title_appearance_in_start_parallel(structure, page_list, model=None, logger=None):
async def check_title_appearance_in_start_concurrent(structure, page_list, model=None, logger=None):
if logger:
logger.info(f"Checking title appearance in start parallel")
with ThreadPoolExecutor(max_workers=10) as executor:
future_to_item = {
executor.submit(check_title_appearance_in_start, item['title'], page_list[item['physical_index']-1][0], model=model, logger=logger): item
for item in structure
}
# Process completed futures and attach results to items
for future in as_completed(future_to_item):
item = future_to_item[future]
try:
result = future.result()
item['appear_start'] = result
except Exception as e:
if logger:
logger.error(f"Error processing item {item['title']}: {str(e)}")
item['appear_start'] = 'no'
logger.info("Checking title appearance in start concurrently")
# skip items without physical_index
for item in structure:
if item.get('physical_index') is None:
item['appear_start'] = 'no'
# only for items with valid physical_index
tasks = []
valid_items = []
for item in structure:
if item.get('physical_index') is not None:
page_text = page_list[item['physical_index'] - 1][0]
tasks.append(check_title_appearance_in_start(item['title'], page_text, model=model, logger=logger))
valid_items.append(item)
results = await asyncio.gather(*tasks, return_exceptions=True)
for item, result in zip(valid_items, results):
if isinstance(result, Exception):
if logger:
logger.error(f"Error checking start for {item['title']}: {result}")
item['appear_start'] = 'no'
else:
item['appear_start'] = result
return structure
@ -505,14 +508,15 @@ def generate_toc_continue(toc_content, part, model="gpt-4o-2024-11-20"):
For the title, you need to extract the original title from the text, only fix the space inconsistency.
The provided text contains tags like <physical_index_X> and <physical_index_X> to indicate the start and end of page X. \
For the physical_index, you need to extract the physical index of the start of the section from the text. Keep the <physical_index_X> format.
The response should be in the following format.
[
{
"structure": <structure index, "x.x.x" or None> (string),
"structure": <structure index, "x.x.x"> (string),
"title": <title of the section, keep the original title>,
"physical_index": "<physical_index_X> (keep the format)" or None
"physical_index": "<physical_index_X> (keep the format)"
},
...
]
@ -538,13 +542,15 @@ def generate_toc_init(part, model=None):
The provided text contains tags like <physical_index_X> and <physical_index_X> to indicate the start and end of page X.
For the physical_index, you need to extract the physical index of the start of the section from the text. Keep the <physical_index_X> format.
The response should be in the following format.
[
{
"structure": <structure index, "x.x.x" or None> (string),
{{
"structure": <structure index, "x.x.x"> (string),
"title": <title of the section, keep the original title>,
"physical_index": "<physical_index_X> (keep the format)" or None
},
"physical_index": "<physical_index_X> (keep the format)"
}},
],
@ -738,7 +744,7 @@ def single_toc_item_index_fixer(section_title, content, model="gpt-4o-2024-11-20
def fix_incorrect_toc(toc_with_page_number, page_list, incorrect_results, start_index=1, model=None, logger=None):
async def fix_incorrect_toc(toc_with_page_number, page_list, incorrect_results, start_index=1, model=None, logger=None):
print(f'start fix_incorrect_toc with {len(incorrect_results)} incorrect results')
incorrect_indices = {result['list_index'] for result in incorrect_results}
@ -746,7 +752,7 @@ def fix_incorrect_toc(toc_with_page_number, page_list, incorrect_results, start_
incorrect_results_and_range_logs = []
# Helper function to process and check a single incorrect item
def process_and_check_item(incorrect_item):
async def process_and_check_item(incorrect_item):
list_index = incorrect_item['list_index']
# Find the previous correct item
prev_correct = None
@ -786,7 +792,7 @@ def fix_incorrect_toc(toc_with_page_number, page_list, incorrect_results, start_
# Check if the result is correct
check_item = incorrect_item.copy()
check_item['physical_index'] = physical_index_int
check_result = check_title_appearance(check_item, page_list, start_index, model)
check_result = await check_title_appearance(check_item, page_list, start_index, model)
return {
'list_index': list_index,
@ -794,20 +800,19 @@ def fix_incorrect_toc(toc_with_page_number, page_list, incorrect_results, start_
'physical_index': physical_index_int,
'is_valid': check_result['answer'] == 'yes'
}
results = []
with ThreadPoolExecutor() as executor:
future_to_item = {executor.submit(process_and_check_item, item): item for item in incorrect_results}
for future in as_completed(future_to_item):
item = future_to_item[future]
try:
result = future.result()
results.append(result)
except Exception as exc:
print(f"Processing item {item} generated an exception: {exc}")
# Process incorrect items concurrently
tasks = [
process_and_check_item(item)
for item in incorrect_results
]
results = await asyncio.gather(*tasks, return_exceptions=True)
for item, result in zip(incorrect_results, results):
if isinstance(result, Exception):
print(f"Processing item {item} generated an exception: {result}")
continue
results = [result for result in results if not isinstance(result, Exception)]
# Update the toc_with_page_number with the fixed indices and check for any invalid results
invalid_results = []
for result in results:
@ -827,7 +832,7 @@ def fix_incorrect_toc(toc_with_page_number, page_list, incorrect_results, start_
def fix_incorrect_toc_with_retries(toc_with_page_number, page_list, incorrect_results, start_index=1, max_attempts=3, model=None, logger=None):
async def fix_incorrect_toc_with_retries(toc_with_page_number, page_list, incorrect_results, start_index=1, max_attempts=3, model=None, logger=None):
print('start fix_incorrect_toc')
fix_attempt = 0
current_toc = toc_with_page_number
@ -836,7 +841,7 @@ def fix_incorrect_toc_with_retries(toc_with_page_number, page_list, incorrect_re
while current_incorrect:
print(f"Fixing {len(current_incorrect)} incorrect results")
current_toc, current_incorrect = fix_incorrect_toc(current_toc, page_list, current_incorrect, start_index, model, logger)
current_toc, current_incorrect = await fix_incorrect_toc(current_toc, page_list, current_incorrect, start_index, model, logger)
fix_attempt += 1
if fix_attempt >= max_attempts:
@ -849,7 +854,7 @@ def fix_incorrect_toc_with_retries(toc_with_page_number, page_list, incorrect_re
################### verify toc #########################################################
def verify_toc(page_list, list_result, start_index=1, N=None, model=None):
async def verify_toc(page_list, list_result, start_index=1, N=None, model=None):
print('start verify_toc')
# Find the last non-None physical_index
last_physical_index = None
@ -879,16 +884,12 @@ def verify_toc(page_list, list_result, start_index=1, N=None, model=None):
item_with_index['list_index'] = idx # Add the original index in list_result
indexed_sample_list.append(item_with_index)
# Run checks in parallel
results = []
with ThreadPoolExecutor(max_workers=10) as executor:
future_to_item = {
executor.submit(check_title_appearance, item, page_list, start_index, model): item
for item in indexed_sample_list
}
for future in as_completed(future_to_item):
results.append(future.result())
# Run checks concurrently
tasks = [
check_title_appearance(item, page_list, start_index, model)
for item in indexed_sample_list
]
results = await asyncio.gather(*tasks)
# Process results
correct_count = 0
@ -910,7 +911,7 @@ def verify_toc(page_list, list_result, start_index=1, N=None, model=None):
################### main process #########################################################
def meta_processor(page_list, mode=None, toc_content=None, toc_page_list=None, start_index=1, opt=None, logger=None):
async def meta_processor(page_list, mode=None, toc_content=None, toc_page_list=None, start_index=1, opt=None, logger=None):
print(mode)
print(f'start_index: {start_index}')
@ -922,7 +923,7 @@ def meta_processor(page_list, mode=None, toc_content=None, toc_page_list=None, s
toc_with_page_number = process_no_toc(page_list, start_index=start_index, model=opt.model, logger=logger)
toc_with_page_number = [item for item in toc_with_page_number if item.get('physical_index') is not None]
accuracy, incorrect_results = verify_toc(page_list, toc_with_page_number, start_index=start_index, model=opt.model)
accuracy, incorrect_results = await verify_toc(page_list, toc_with_page_number, start_index=start_index, model=opt.model)
logger.info({
'mode': 'process_toc_with_page_numbers',
@ -932,26 +933,26 @@ def meta_processor(page_list, mode=None, toc_content=None, toc_page_list=None, s
if accuracy == 1.0 and len(incorrect_results) == 0:
return toc_with_page_number
if accuracy > 0.6 and len(incorrect_results) > 0:
toc_with_page_number, incorrect_results = fix_incorrect_toc_with_retries(toc_with_page_number, page_list, incorrect_results,start_index=start_index, max_attempts=3, model=opt.model, logger=logger)
toc_with_page_number, incorrect_results = await fix_incorrect_toc_with_retries(toc_with_page_number, page_list, incorrect_results,start_index=start_index, max_attempts=3, model=opt.model, logger=logger)
return toc_with_page_number
else:
if mode == 'process_toc_with_page_numbers':
return meta_processor(page_list, mode='process_toc_no_page_numbers', toc_content=toc_content, toc_page_list=toc_page_list, start_index=start_index, opt=opt, logger=logger)
return await meta_processor(page_list, mode='process_toc_no_page_numbers', toc_content=toc_content, toc_page_list=toc_page_list, start_index=start_index, opt=opt, logger=logger)
elif mode == 'process_toc_no_page_numbers':
return meta_processor(page_list, mode='process_no_toc', start_index=start_index, opt=opt, logger=logger)
return await meta_processor(page_list, mode='process_no_toc', start_index=start_index, opt=opt, logger=logger)
else:
raise Exception('Processing failed')
def process_large_node_recursively(node, page_list, opt=None, logger=None):
async def process_large_node_recursively(node, page_list, opt=None, logger=None):
node_page_list = page_list[node['start_index']-1:node['end_index']-1]
token_num = sum([page[1] for page in node_page_list])
if node['end_index'] - node['start_index'] > opt.max_page_num_each_node and token_num >= opt.max_token_num_each_node:
print('large node:', node['title'], 'start_index:', node['start_index'], 'end_index:', node['end_index'], 'token_num:', token_num)
node_toc_tree = meta_processor(node_page_list, mode='process_no_toc', start_index=node['start_index'], opt=opt, logger=logger)
node_toc_tree = check_title_appearance_in_start_parallel(node_toc_tree, page_list, model=opt.model, logger=logger)
node_toc_tree = await meta_processor(node_page_list, mode='process_no_toc', start_index=node['start_index'], opt=opt, logger=logger)
node_toc_tree = await check_title_appearance_in_start_concurrent(node_toc_tree, page_list, model=opt.model, logger=logger)
if node['title'].strip() == node_toc_tree[0]['title'].strip():
node['nodes'] = post_processing(node_toc_tree[1:], node['end_index'])
@ -961,17 +962,20 @@ def process_large_node_recursively(node, page_list, opt=None, logger=None):
node['end_index'] = node_toc_tree[0]['start_index']
if 'nodes' in node and node['nodes']:
for child_node in node['nodes']:
tasks = [
process_large_node_recursively(child_node, page_list, opt, logger=logger)
for child_node in node['nodes']
]
await asyncio.gather(*tasks)
return node
def tree_parser(page_list, opt, logger=None):
check_toc_result = check_toc(page_list, opt)
async def tree_parser(page_list, opt, doc=None, logger=None):
check_toc_result = check_toc(page_list, opt)
logger.info(check_toc_result)
if check_toc_result['toc_content'] is not None and check_toc_result['page_index_given_in_toc'] == 'yes':
toc_with_page_number = meta_processor(
if check_toc_result.get("toc_content") and check_toc_result["toc_content"].strip() and check_toc_result["page_index_given_in_toc"] == "yes":
toc_with_page_number = await meta_processor(
page_list,
mode='process_toc_with_page_numbers',
start_index=1,
@ -980,7 +984,7 @@ def tree_parser(page_list, opt, logger=None):
opt=opt,
logger=logger)
else:
toc_with_page_number = meta_processor(
toc_with_page_number = await meta_processor(
page_list,
mode='process_no_toc',
start_index=1,
@ -988,10 +992,13 @@ def tree_parser(page_list, opt, logger=None):
logger=logger)
toc_with_page_number = add_preface_if_needed(toc_with_page_number)
toc_with_page_number = check_title_appearance_in_start_parallel(toc_with_page_number, page_list, model=opt.model, logger=logger)
toc_with_page_number = await check_title_appearance_in_start_concurrent(toc_with_page_number, page_list, model=opt.model, logger=logger)
toc_tree = post_processing(toc_with_page_number, len(page_list))
for node in toc_tree:
tasks = [
process_large_node_recursively(node, page_list, opt, logger=logger)
for node in toc_tree
]
await asyncio.gather(*tasks)
return toc_tree
@ -1012,13 +1019,15 @@ def page_index_main(doc, opt=None):
logger.info({'total_page_number': len(page_list)})
logger.info({'total_token': sum([page[1] for page in page_list])})
structure = tree_parser(page_list, opt, logger=logger)
structure = asyncio.run(tree_parser(page_list, opt, doc=doc, logger=logger))
if opt.if_add_node_id == 'yes':
write_node_id(structure)
if opt.if_add_node_summary == 'yes':
add_node_text(structure, page_list)
asyncio.run(generate_summaries_for_structure(structure, model=opt.model))
remove_structure_text(structure)
remove_structure_text(structure)
if opt.if_add_node_text == 'yes':
add_node_text_with_labels(structure, page_list)
if opt.if_add_doc_description == 'yes':
doc_description = generate_doc_description(structure, model=opt.model)
return {
@ -1033,7 +1042,7 @@ def page_index_main(doc, opt=None):
def page_index(doc, model=None, toc_check_page_num=None, max_page_num_each_node=None, max_token_num_each_node=None,
if_add_node_id=None, if_add_node_summary=None, if_add_doc_description=None):
if_add_node_id=None, if_add_node_summary=None, if_add_doc_description=None, if_add_node_text=None):
user_opt = {
arg: value for arg, value in locals().items()