Mirror of https://github.com/VectifyAI/PageIndex.git, synced 2026-05-09 14:52:36 +02:00
feat: add PageIndex SDK with local/cloud dual-mode support (#207)
parent f2dcffc0b7
commit c7fe93bb56
45 changed files with 4225 additions and 274 deletions
0 pageindex/index/__init__.py Normal file
2 pageindex/index/legacy_utils.py Normal file
@@ -0,0 +1,2 @@
# Re-export from the original utils.py for backward compatibility
from ..utils import *
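Since legacy_utils.py just star-imports the original module, existing call sites keep resolving to the same objects. A minimal sanity check (hypothetical snippet, assuming the original pageindex/utils.py exposes count_tokens):

```python
from pageindex.utils import count_tokens
from pageindex.index.legacy_utils import count_tokens as legacy_count_tokens

# The shim re-exports, so both names point at the same function object
assert legacy_count_tokens is count_tokens
```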
1155 pageindex/index/page_index.py Normal file
File diff suppressed because it is too large
341 pageindex/index/page_index_md.py Normal file
@@ -0,0 +1,341 @@
import asyncio
import json
import re
import os

try:
    from .legacy_utils import *
except ImportError:
    from legacy_utils import *


async def get_node_summary(node, summary_token_threshold=200, model=None):
    node_text = node.get('text')
    num_tokens = count_tokens(node_text, model=model)
    if num_tokens < summary_token_threshold:
        return node_text
    else:
        return await generate_node_summary(node, model=model)


async def generate_summaries_for_structure_md(structure, summary_token_threshold, model=None):
    nodes = structure_to_list(structure)
    tasks = [get_node_summary(node, summary_token_threshold=summary_token_threshold, model=model) for node in nodes]
    summaries = await asyncio.gather(*tasks)

    for node, summary in zip(nodes, summaries):
        if not node.get('nodes'):
            node['summary'] = summary
        else:
            node['prefix_summary'] = summary
    return structure


def extract_nodes_from_markdown(markdown_content):
    header_pattern = r'^(#{1,6})\s+(.+)$'
    code_block_pattern = r'^```'
    node_list = []

    lines = markdown_content.split('\n')
    in_code_block = False

    for line_num, line in enumerate(lines, 1):
        stripped_line = line.strip()

        # Check for code block delimiters (triple backticks)
        if re.match(code_block_pattern, stripped_line):
            in_code_block = not in_code_block
            continue

        # Skip empty lines
        if not stripped_line:
            continue

        # Only look for headers when not inside a code block
        if not in_code_block:
            match = re.match(header_pattern, stripped_line)
            if match:
                title = match.group(2).strip()
                node_list.append({'node_title': title, 'line_num': line_num})

    return node_list, lines


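A quick illustration of the header scan and the code-fence guard (hypothetical snippet, not part of the commit):

```python
fence = "`" * 3  # a literal triple-backtick line, built indirectly for readability
sample = "\n".join([
    "# Intro",           # line 1: real header
    fence,               # line 2: opens a code block
    "# not a header",    # line 3: ignored while inside the fence
    fence,               # line 4: closes the code block
    "## Details",        # line 5: real header
])
nodes, lines = extract_nodes_from_markdown(sample)
# nodes == [{'node_title': 'Intro', 'line_num': 1},
#           {'node_title': 'Details', 'line_num': 5}]
```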
def extract_node_text_content(node_list, markdown_lines):
    all_nodes = []
    for node in node_list:
        line_content = markdown_lines[node['line_num'] - 1]
        header_match = re.match(r'^(#{1,6})', line_content)

        if header_match is None:
            print(f"Warning: Line {node['line_num']} does not contain a valid header: '{line_content}'")
            continue

        processed_node = {
            'title': node['node_title'],
            'line_num': node['line_num'],
            'level': len(header_match.group(1))
        }
        all_nodes.append(processed_node)

    for i, node in enumerate(all_nodes):
        start_line = node['line_num'] - 1
        if i + 1 < len(all_nodes):
            end_line = all_nodes[i + 1]['line_num'] - 1
        else:
            end_line = len(markdown_lines)

        node['text'] = '\n'.join(markdown_lines[start_line:end_line]).strip()
    return all_nodes


def update_node_list_with_text_token_count(node_list, model=None):

    def find_all_children(parent_index, parent_level, node_list):
        """Find all direct and indirect children of a parent node"""
        children_indices = []

        # Look for children after the parent
        for i in range(parent_index + 1, len(node_list)):
            current_level = node_list[i]['level']

            # If we hit a node at the same or higher level than the parent, stop
            if current_level <= parent_level:
                break

            # This is a descendant
            children_indices.append(i)

        return children_indices

    # Make a copy to avoid modifying the original
    result_list = node_list.copy()

    # Process nodes from end to beginning to ensure children are processed before parents
    for i in range(len(result_list) - 1, -1, -1):
        current_node = result_list[i]
        current_level = current_node['level']

        # Get all children of this node
        children_indices = find_all_children(i, current_level, result_list)

        # Start with the node's own text
        total_text = current_node.get('text', '')

        # Add all children's text
        for child_index in children_indices:
            child_text = result_list[child_index].get('text', '')
            if child_text:
                total_text += '\n' + child_text

        # Calculate token count for combined text
        result_list[i]['text_token_count'] = count_tokens(total_text, model=model)

    return result_list


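Because the list is walked from the end backwards, each node's `text_token_count` ends up covering its whole subtree: its own text plus every descendant's. A small sketch (assuming `count_tokens` falls back to a default tokenizer when `model` is None):

```python
nodes = [
    {'level': 1, 'text': 'Chapter text.'},      # parent
    {'level': 2, 'text': 'Subsection text.'},   # descendant of the parent
    {'level': 1, 'text': 'Next chapter.'},      # sibling: not counted above
]
rolled = update_node_list_with_text_token_count(nodes)
# rolled[0]['text_token_count'] covers 'Chapter text.' plus 'Subsection text.';
# rolled[2]['text_token_count'] covers only its own text.
```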
def tree_thinning_for_index(node_list, min_node_token=None, model=None):
    def find_all_children(parent_index, parent_level, node_list):
        children_indices = []

        for i in range(parent_index + 1, len(node_list)):
            current_level = node_list[i]['level']

            if current_level <= parent_level:
                break

            children_indices.append(i)

        return children_indices

    # No threshold given: nothing to thin
    if min_node_token is None:
        return node_list.copy()

    result_list = node_list.copy()
    nodes_to_remove = set()

    for i in range(len(result_list) - 1, -1, -1):
        if i in nodes_to_remove:
            continue

        current_node = result_list[i]
        current_level = current_node['level']

        total_tokens = current_node.get('text_token_count', 0)

        if total_tokens < min_node_token:
            children_indices = find_all_children(i, current_level, result_list)

            children_texts = []
            for child_index in sorted(children_indices):
                if child_index not in nodes_to_remove:
                    child_text = result_list[child_index].get('text', '')
                    if child_text.strip():
                        children_texts.append(child_text)
                    nodes_to_remove.add(child_index)

            if children_texts:
                parent_text = current_node.get('text', '')
                merged_text = parent_text
                for child_text in children_texts:
                    if merged_text and not merged_text.endswith('\n'):
                        merged_text += '\n\n'
                    merged_text += child_text

                result_list[i]['text'] = merged_text
                result_list[i]['text_token_count'] = count_tokens(merged_text, model=model)

    for index in sorted(nodes_to_remove, reverse=True):
        result_list.pop(index)

    return result_list


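The effect of thinning, sketched on a tiny input: a parent whose subtree falls under `min_node_token` absorbs its children's text, and the children drop out of the flat list.

```python
nodes = update_node_list_with_text_token_count([
    {'level': 1, 'text': 'Short chapter.'},
    {'level': 2, 'text': 'Tiny subsection.'},
])
thinned = tree_thinning_for_index(nodes, min_node_token=5000)
# len(thinned) == 1: the level-2 node is merged into its parent, whose 'text'
# now ends with 'Tiny subsection.' and whose token count is recomputed.
```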
def build_tree_from_nodes(node_list):
    if not node_list:
        return []

    stack = []
    root_nodes = []
    node_counter = 1

    for node in node_list:
        current_level = node['level']

        tree_node = {
            'title': node['title'],
            'node_id': str(node_counter).zfill(4),
            'text': node['text'],
            'line_num': node['line_num'],
            'nodes': []
        }
        node_counter += 1

        while stack and stack[-1][1] >= current_level:
            stack.pop()

        if not stack:
            root_nodes.append(tree_node)
        else:
            parent_node, _ = stack[-1]
            parent_node['nodes'].append(tree_node)

        stack.append((tree_node, current_level))

    return root_nodes


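The stack holds the current ancestor chain; popping entries at the same or shallower level is what turns the flat header list into nesting. For example (hypothetical input):

```python
flat = [
    {'title': 'A',   'level': 1, 'text': '', 'line_num': 1},
    {'title': 'A.1', 'level': 2, 'text': '', 'line_num': 2},
    {'title': 'A.2', 'level': 2, 'text': '', 'line_num': 3},
    {'title': 'B',   'level': 1, 'text': '', 'line_num': 4},
]
tree = build_tree_from_nodes(flat)
assert [n['title'] for n in tree] == ['A', 'B']                  # two roots
assert [n['title'] for n in tree[0]['nodes']] == ['A.1', 'A.2']  # nested under 'A'
```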
def clean_tree_for_output(tree_nodes):
    cleaned_nodes = []

    for node in tree_nodes:
        cleaned_node = {
            'title': node['title'],
            'node_id': node['node_id'],
            'text': node['text'],
            'line_num': node['line_num']
        }

        if node['nodes']:
            cleaned_node['nodes'] = clean_tree_for_output(node['nodes'])

        cleaned_nodes.append(cleaned_node)

    return cleaned_nodes


async def md_to_tree(md_path, if_thinning=False, min_token_threshold=None, if_add_node_summary=False, summary_token_threshold=None, model=None, if_add_doc_description=False, if_add_node_text=False, if_add_node_id=True):
    with open(md_path, 'r', encoding='utf-8') as f:
        markdown_content = f.read()
    line_count = markdown_content.count('\n') + 1

    print("Extracting nodes from markdown...")
    node_list, markdown_lines = extract_nodes_from_markdown(markdown_content)

    print("Extracting text content from nodes...")
    nodes_with_content = extract_node_text_content(node_list, markdown_lines)

    if if_thinning:
        nodes_with_content = update_node_list_with_text_token_count(nodes_with_content, model=model)
        print("Thinning nodes...")
        nodes_with_content = tree_thinning_for_index(nodes_with_content, min_token_threshold, model=model)

    print("Building tree from nodes...")
    tree_structure = build_tree_from_nodes(nodes_with_content)

    if if_add_node_id:
        write_node_id(tree_structure)

    print("Formatting tree structure...")

    if if_add_node_summary:
        # Always include text for summary generation
        tree_structure = format_structure(tree_structure, order=['title', 'node_id', 'line_num', 'summary', 'prefix_summary', 'text', 'nodes'])

        print("Generating summaries for each node...")
        tree_structure = await generate_summaries_for_structure_md(tree_structure, summary_token_threshold=summary_token_threshold, model=model)

        if not if_add_node_text:
            # Remove text after summary generation if not requested
            tree_structure = format_structure(tree_structure, order=['title', 'node_id', 'line_num', 'summary', 'prefix_summary', 'nodes'])

        if if_add_doc_description:
            print("Generating document description...")
            clean_structure = create_clean_structure_for_description(tree_structure)
            doc_description = generate_doc_description(clean_structure, model=model)
            return {
                'doc_name': os.path.splitext(os.path.basename(md_path))[0],
                'doc_description': doc_description,
                'line_count': line_count,
                'structure': tree_structure,
            }
    else:
        # No summaries needed, format based on text preference
        if if_add_node_text:
            tree_structure = format_structure(tree_structure, order=['title', 'node_id', 'line_num', 'summary', 'prefix_summary', 'text', 'nodes'])
        else:
            tree_structure = format_structure(tree_structure, order=['title', 'node_id', 'line_num', 'summary', 'prefix_summary', 'nodes'])

    return {
        'doc_name': os.path.splitext(os.path.basename(md_path))[0],
        'line_count': line_count,
        'structure': tree_structure,
    }


if __name__ == "__main__":
    # MD_NAME = 'Detect-Order-Construct'
    MD_NAME = 'cognitive-load'
    MD_PATH = os.path.join(os.path.dirname(__file__), '..', 'examples/documents/', f'{MD_NAME}.md')

    MODEL = "gpt-4.1"
    IF_THINNING = False
    THINNING_THRESHOLD = 5000
    SUMMARY_TOKEN_THRESHOLD = 200
    IF_SUMMARY = True

    tree_structure = asyncio.run(md_to_tree(
        md_path=MD_PATH,
        if_thinning=IF_THINNING,
        min_token_threshold=THINNING_THRESHOLD,
        if_add_node_summary=IF_SUMMARY,
        summary_token_threshold=SUMMARY_TOKEN_THRESHOLD,
        model=MODEL))

    print('\n' + '=' * 60)
    print('TREE STRUCTURE')
    print('=' * 60)
    print_json(tree_structure)

    print('\n' + '=' * 60)
    print('TABLE OF CONTENTS')
    print('=' * 60)
    print_toc(tree_structure['structure'])

    output_path = os.path.join(os.path.dirname(__file__), '..', 'results', f'{MD_NAME}_structure.json')
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(tree_structure, f, indent=2, ensure_ascii=False)

    print(f"\nTree structure saved to: {output_path}")
122 pageindex/index/pipeline.py Normal file
@@ -0,0 +1,122 @@
# pageindex/index/pipeline.py
from __future__ import annotations

from ..parser.protocol import ContentNode, ParsedDocument


def detect_strategy(nodes: list[ContentNode]) -> str:
    """Determine which indexing strategy to use based on node data."""
    if any(n.level is not None for n in nodes):
        return "level_based"
    return "content_based"


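A sketch of the routing decision. The ContentNode constructor arguments below are illustrative guesses based on the fields this module reads (title, content, index, level, tokens); the real signature lives in pageindex/parser/protocol.py:

```python
md_nodes = [ContentNode(title="Intro", content="# Intro", index=1, level=1, tokens=2)]
pdf_nodes = [ContentNode(title=None, content="page text", index=1, level=None, tokens=2)]

assert detect_strategy(md_nodes) == "level_based"      # explicit heading levels present
assert detect_strategy(pdf_nodes) == "content_based"   # no levels: use the LLM pipeline
```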
def build_tree_from_levels(nodes: list[ContentNode]) -> list[dict]:
    """Strategy 0: Build a tree from explicit level information.

    Adapted from pageindex/page_index_md.py:build_tree_from_nodes."""
    stack = []
    root_nodes = []

    for node in nodes:
        tree_node = {
            "title": node.title or "",
            "text": node.content,
            "line_num": node.index,
            "nodes": [],
        }
        current_level = node.level or 1

        while stack and stack[-1][1] >= current_level:
            stack.pop()

        if not stack:
            root_nodes.append(tree_node)
        else:
            parent_node, _ = stack[-1]
            parent_node["nodes"].append(tree_node)

        stack.append((tree_node, current_level))

    return root_nodes


def _run_async(coro):
    """Run an async coroutine, handling the case where an event loop is already running."""
    import asyncio
    import concurrent.futures
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No running loop -- safe to call asyncio.run directly
        return asyncio.run(coro)
    # Already inside an event loop -- run in a separate thread so asyncio.run
    # gets a fresh loop (and a RuntimeError raised by the coroutine itself is
    # not mistaken for "no running loop")
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()


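Callers use the same one-liner whether or not a loop is already running; a minimal sketch:

```python
async def _ping():
    return "pong"

# Outside any event loop this falls through to asyncio.run; inside a running
# loop (a Jupyter cell, a web handler) the same call is dispatched to a
# worker thread instead of raising RuntimeError.
assert _run_async(_ping()) == "pong"
```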
def build_index(parsed: ParsedDocument, model: str | None = None, opt=None) -> dict:
    """Main entry point: ParsedDocument -> tree structure dict.

    Routes to the appropriate strategy and runs enhancement."""
    from .utils import (write_node_id, add_node_text, remove_structure_text,
                        generate_summaries_for_structure, generate_doc_description,
                        create_clean_structure_for_description)
    from ..config import IndexConfig

    if opt is None:
        opt = IndexConfig(model=model) if model else IndexConfig()

    nodes = parsed.nodes
    strategy = detect_strategy(nodes)

    if strategy == "level_based":
        structure = build_tree_from_levels(nodes)
        # For level-based, text is already in the tree nodes
    else:
        # Strategies 1-3: convert the ContentNode list to page_list format for the existing pipeline
        page_list = [(n.content, n.tokens) for n in nodes]
        structure = _run_async(_content_based_pipeline(page_list, opt))

    # Unified enhancement
    if opt.if_add_node_id:
        write_node_id(structure)

    if strategy != "level_based":
        if opt.if_add_node_text or opt.if_add_node_summary:
            add_node_text(structure, page_list)

    if opt.if_add_node_summary:
        _run_async(generate_summaries_for_structure(structure, model=opt.model))

    if not opt.if_add_node_text and strategy != "level_based":
        remove_structure_text(structure)

    result = {
        "doc_name": parsed.doc_name,
        "structure": structure,
    }

    if opt.if_add_doc_description:
        clean_structure = create_clean_structure_for_description(structure)
        result["doc_description"] = generate_doc_description(
            clean_structure, model=opt.model
        )

    return result


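A hypothetical end-to-end call; the ParsedDocument/ContentNode constructors shown here are illustrative, and the enhancement flags come from IndexConfig defaults:

```python
doc = ParsedDocument(
    doc_name="cognitive-load",
    nodes=[ContentNode(title="Intro", content="# Intro ...", index=1, level=1, tokens=10)],
)
result = build_index(doc, model="gpt-4.1")
# With level-based input the tree skeleton comes straight from the heading levels
print(result["doc_name"], len(result["structure"]))
```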
class _NullLogger:
    """Minimal logger that satisfies the tree_parser interface without writing files."""
    def info(self, message, **kwargs): pass
    def error(self, message, **kwargs): pass
    def debug(self, message, **kwargs): pass


async def _content_based_pipeline(page_list, opt):
    """Strategies 1-3: delegates to the existing PDF pipeline from pageindex/page_index.py.

    The page_list is already in the format expected by tree_parser:
    [(page_text, token_count), ...]
    """
    from .page_index import tree_parser

    logger = _NullLogger()
    structure = await tree_parser(page_list, opt, doc=None, logger=logger)
    return structure
431 pageindex/index/utils.py Normal file
@@ -0,0 +1,431 @@
import litellm
import logging
import time
import json
import copy
import asyncio
import PyPDF2

logger = logging.getLogger(__name__)


def count_tokens(text, model=None):
    if not text:
        return 0
    return litellm.token_counter(model=model, text=text)


def llm_completion(model, prompt, chat_history=None, return_finish_reason=False):
    if model:
        model = model.removeprefix("litellm/")
    max_retries = 10
    if chat_history:
        messages = list(chat_history) + [{"role": "user", "content": prompt}]
    else:
        messages = [{"role": "user", "content": prompt}]
    for i in range(max_retries):
        try:
            litellm.drop_params = True
            response = litellm.completion(
                model=model,
                messages=messages,
                temperature=0,
            )
            content = response.choices[0].message.content
            if return_finish_reason:
                finish_reason = "max_output_reached" if response.choices[0].finish_reason == "length" else "finished"
                return content, finish_reason
            return content
        except Exception as e:
            logger.warning("Retrying LLM completion (%d/%d)", i + 1, max_retries)
            logger.error(f"Error: {e}")
            if i < max_retries - 1:
                time.sleep(1)
            else:
                logger.error('Max retries reached for prompt: ' + prompt)
                raise RuntimeError(f"LLM call failed after {max_retries} retries") from e


async def llm_acompletion(model, prompt):
    if model:
        model = model.removeprefix("litellm/")
    max_retries = 10
    messages = [{"role": "user", "content": prompt}]
    for i in range(max_retries):
        try:
            litellm.drop_params = True
            response = await litellm.acompletion(
                model=model,
                messages=messages,
                temperature=0,
            )
            return response.choices[0].message.content
        except Exception as e:
            logger.warning("Retrying async LLM completion (%d/%d)", i + 1, max_retries)
            logger.error(f"Error: {e}")
            if i < max_retries - 1:
                await asyncio.sleep(1)
            else:
                logger.error('Max retries reached for prompt: ' + prompt)
                raise RuntimeError(f"LLM call failed after {max_retries} retries") from e


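Both helpers retry up to ten times with a one-second pause before raising; a minimal call (model name illustrative, and a litellm API key is assumed to be configured):

```python
answer = llm_completion("gpt-4.1", "Reply with the single word OK.")

text, reason = llm_completion("gpt-4.1", "Summarize this repo.", return_finish_reason=True)
# reason is "finished", or "max_output_reached" if the model hit its output cap
```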
def extract_json(content):
    try:
        # First, try to extract JSON enclosed within ```json and ```
        start_idx = content.find("```json")
        if start_idx != -1:
            start_idx += 7  # Adjust index to start after the delimiter
            end_idx = content.rfind("```")
            json_content = content[start_idx:end_idx].strip()
        else:
            # If no delimiters, assume the entire content could be JSON
            json_content = content.strip()

        # Clean up common issues that might cause parsing errors
        json_content = json_content.replace('None', 'null')  # Replace Python None with JSON null
        json_content = json_content.replace('\n', ' ').replace('\r', ' ')  # Remove newlines
        json_content = ' '.join(json_content.split())  # Normalize whitespace

        # Attempt to parse and return the JSON object
        return json.loads(json_content)
    except json.JSONDecodeError as e:
        logging.error(f"Failed to extract JSON: {e}")
        # Try to clean up the content further if initial parsing fails
        try:
            # Remove any trailing commas before closing brackets/braces
            json_content = json_content.replace(',]', ']').replace(',}', '}')
            return json.loads(json_content)
        except Exception:
            logging.error("Failed to parse JSON even after cleanup")
            return {}
    except Exception as e:
        logging.error(f"Unexpected error while extracting JSON: {e}")
        return {}


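For instance, a fenced model reply parses cleanly (the fence string is built indirectly to keep this snippet readable). Note that the blanket 'None' -> 'null' substitution also rewrites 'None' occurring inside string values, so prompts should steer models away from that literal:

```python
fence = "`" * 3
response = "Here you go:\n" + fence + 'json\n{"titles": ["Intro"], "count": 1}\n' + fence
assert extract_json(response) == {"titles": ["Intro"], "count": 1}
```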
def get_json_content(response):
    start_idx = response.find("```json")
    if start_idx != -1:
        start_idx += 7
        response = response[start_idx:]

    end_idx = response.rfind("```")
    if end_idx != -1:
        response = response[:end_idx]

    json_content = response.strip()
    return json_content


def write_node_id(data, node_id=0):
    if isinstance(data, dict):
        data['node_id'] = str(node_id).zfill(4)
        node_id += 1
        for key in list(data.keys()):
            if 'nodes' in key:
                node_id = write_node_id(data[key], node_id)
    elif isinstance(data, list):
        for index in range(len(data)):
            node_id = write_node_id(data[index], node_id)
    return node_id


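IDs are assigned in depth-first order, so a parent always gets a lower number than its descendants:

```python
tree = [
    {'title': 'A', 'nodes': [{'title': 'A.1', 'nodes': []}]},
    {'title': 'B', 'nodes': []},
]
write_node_id(tree)
# A -> '0000', A.1 -> '0001', B -> '0002'
assert tree[1]['node_id'] == '0002'
```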
def remove_fields(data, fields=None):
    fields = fields or ["text"]
    if isinstance(data, dict):
        return {k: remove_fields(v, fields)
                for k, v in data.items() if k not in fields}
    elif isinstance(data, list):
        return [remove_fields(item, fields) for item in data]
    return data


def structure_to_list(structure):
    if isinstance(structure, dict):
        nodes = [structure]
        if 'nodes' in structure:
            nodes.extend(structure_to_list(structure['nodes']))
        return nodes
    elif isinstance(structure, list):
        nodes = []
        for item in structure:
            nodes.extend(structure_to_list(item))
        return nodes
    # Scalars (or None) contribute no nodes
    return []


def get_nodes(structure):
    if isinstance(structure, dict):
        structure_node = copy.deepcopy(structure)
        structure_node.pop('nodes', None)
        nodes = [structure_node]
        for key in list(structure.keys()):
            if 'nodes' in key:
                nodes.extend(get_nodes(structure[key]))
        return nodes
    elif isinstance(structure, list):
        nodes = []
        for item in structure:
            nodes.extend(get_nodes(item))
        return nodes
    return []


def get_leaf_nodes(structure):
    if isinstance(structure, dict):
        if not structure.get('nodes'):
            structure_node = copy.deepcopy(structure)
            structure_node.pop('nodes', None)
            return [structure_node]
        else:
            leaf_nodes = []
            for key in list(structure.keys()):
                if 'nodes' in key:
                    leaf_nodes.extend(get_leaf_nodes(structure[key]))
            return leaf_nodes
    elif isinstance(structure, list):
        leaf_nodes = []
        for item in structure:
            leaf_nodes.extend(get_leaf_nodes(item))
        return leaf_nodes
    return []


async def generate_node_summary(node, model=None):
    prompt = f"""You are given a part of a document. Your task is to generate a description of this partial document, covering the main points it contains.

Partial Document Text: {node['text']}

Directly return the description, do not include any other text.
"""
    response = await llm_acompletion(model, prompt)
    return response


async def generate_summaries_for_structure(structure, model=None):
    nodes = structure_to_list(structure)
    tasks = [generate_node_summary(node, model=model) for node in nodes]
    summaries = await asyncio.gather(*tasks)

    for node, summary in zip(nodes, summaries):
        node['summary'] = summary
    return structure


def generate_doc_description(structure, model=None):
    prompt = f"""You are an expert in generating descriptions for a document.
You are given the structure of a document. Your task is to generate a one-sentence description for the document, which makes it easy to distinguish the document from other documents.

Document Structure: {structure}

Directly return the description, do not include any other text.
"""
    response = llm_completion(model, prompt)
    return response


def list_to_tree(data):
    def get_parent_structure(structure):
        """Helper function to get the parent structure code"""
        if not structure:
            return None
        parts = str(structure).split('.')
        return '.'.join(parts[:-1]) if len(parts) > 1 else None

    # First pass: Create nodes and track parent-child relationships
    nodes = {}
    root_nodes = []

    for item in data:
        structure = item.get('structure')
        node = {
            'title': item.get('title'),
            'start_index': item.get('start_index'),
            'end_index': item.get('end_index'),
            'nodes': []
        }

        nodes[structure] = node

        # Find parent
        parent_structure = get_parent_structure(structure)

        if parent_structure:
            # Add as child to parent if parent exists
            if parent_structure in nodes:
                nodes[parent_structure]['nodes'].append(node)
            else:
                root_nodes.append(node)
        else:
            # No parent, this is a root node
            root_nodes.append(node)

    # Helper function to clean empty children arrays
    def clean_node(node):
        if not node['nodes']:
            del node['nodes']
        else:
            for child in node['nodes']:
                clean_node(child)
        return node

    # Clean and return the tree
    return [clean_node(node) for node in root_nodes]


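Structure codes are dotted paths, so '2.1' nests under '2'; a child whose parent code has not been seen yet is promoted to a root, which means the flat list must arrive parents-first. A hypothetical input:

```python
flat = [
    {'structure': '1',   'title': 'Intro',    'start_index': 1, 'end_index': 2},
    {'structure': '2',   'title': 'Methods',  'start_index': 3, 'end_index': 6},
    {'structure': '2.1', 'title': 'Datasets', 'start_index': 4, 'end_index': 6},
]
tree = list_to_tree(flat)
# 'Intro' loses its empty 'nodes' key; 'Datasets' nests under 'Methods'
assert tree[1]['nodes'][0]['title'] == 'Datasets'
```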
def post_processing(structure, end_physical_index):
    # First convert page_number to start_index in the flat list
    for i, item in enumerate(structure):
        item['start_index'] = item.get('physical_index')
        if i < len(structure) - 1:
            if structure[i + 1].get('appear_start') == 'yes':
                item['end_index'] = structure[i + 1]['physical_index'] - 1
            else:
                item['end_index'] = structure[i + 1]['physical_index']
        else:
            item['end_index'] = end_physical_index
    tree = list_to_tree(structure)
    if tree:
        return tree
    else:
        # Fall back to the flat list, dropping the helper fields
        for node in structure:
            node.pop('appear_start', None)
            node.pop('physical_index', None)
        return structure


def reorder_dict(data, key_order):
    if not key_order:
        return data
    return {key: data[key] for key in key_order if key in data}


def format_structure(structure, order=None):
    if not order:
        return structure
    if isinstance(structure, dict):
        if 'nodes' in structure:
            structure['nodes'] = format_structure(structure['nodes'], order)
            if not structure.get('nodes'):
                structure.pop('nodes', None)
        structure = reorder_dict(structure, order)
    elif isinstance(structure, list):
        structure = [format_structure(item, order) for item in structure]
    return structure


def create_clean_structure_for_description(structure):
    """
    Create a clean structure for document description generation,
    excluding unnecessary fields like 'text'.
    """
    if isinstance(structure, dict):
        clean_node = {}
        # Only include essential fields for description
        for key in ['title', 'node_id', 'summary', 'prefix_summary']:
            if key in structure:
                clean_node[key] = structure[key]

        # Recursively process child nodes
        if 'nodes' in structure and structure['nodes']:
            clean_node['nodes'] = create_clean_structure_for_description(structure['nodes'])

        return clean_node
    elif isinstance(structure, list):
        return [create_clean_structure_for_description(item) for item in structure]
    else:
        return structure


def _get_text_of_pages(page_list, start_page, end_page):
    """Concatenate text from page_list for pages [start_page, end_page] (1-indexed)."""
    text = ""
    for page_num in range(start_page - 1, end_page):
        text += page_list[page_num][0]
    return text


def add_node_text(node, page_list):
    """Recursively add 'text' field to each node from page_list content.

    Each node must have 'start_index' and 'end_index' (1-indexed page numbers).
    page_list is [(page_text, token_count), ...].
    """
    if isinstance(node, dict):
        start_page = node.get('start_index')
        end_page = node.get('end_index')
        if start_page is not None and end_page is not None:
            node['text'] = _get_text_of_pages(page_list, start_page, end_page)
        if 'nodes' in node:
            add_node_text(node['nodes'], page_list)
    elif isinstance(node, list):
        for item in node:
            add_node_text(item, page_list)


def remove_structure_text(data):
    if isinstance(data, dict):
        data.pop('text', None)
        if 'nodes' in data:
            remove_structure_text(data['nodes'])
    elif isinstance(data, list):
        for item in data:
            remove_structure_text(item)
    return data


# ── Functions migrated from retrieve.py ──────────────────────────────────────

def parse_pages(pages: str) -> list[int]:
    """Parse a pages string like '5-7', '3,8', or '12' into a sorted list of ints."""
    result = []
    for part in pages.split(','):
        part = part.strip()
        if '-' in part:
            start_str, end_str = part.split('-', 1)
            start, end = int(start_str.strip()), int(end_str.strip())
            if start > end:
                raise ValueError(f"Invalid range '{part}': start must be <= end")
            result.extend(range(start, end + 1))
        else:
            result.append(int(part))
    result = [p for p in result if p >= 1]
    result = sorted(set(result))
    if len(result) > 1000:
        raise ValueError(f"Page range too large: {len(result)} pages (max 1000)")
    return result


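A few illustrative calls:

```python
assert parse_pages('5-7') == [5, 6, 7]
assert parse_pages('3, 8, 3') == [3, 8]        # duplicates collapse
assert parse_pages('0-2, 12') == [1, 2, 12]    # pages below 1 are dropped
```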
def get_pdf_page_content(file_path: str, page_nums: list[int]) -> list[dict]:
    """Extract text for specific PDF pages (1-indexed), opening the PDF once."""
    with open(file_path, 'rb') as f:
        pdf_reader = PyPDF2.PdfReader(f)
        total = len(pdf_reader.pages)
        valid_pages = [p for p in page_nums if 1 <= p <= total]
        return [
            {'page': p, 'content': pdf_reader.pages[p - 1].extract_text() or ''}
            for p in valid_pages
        ]


def get_md_page_content(structure: list, page_nums: list[int]) -> list[dict]:
    """
    For Markdown documents, 'pages' are line numbers.
    Find nodes whose line_num falls within [min(page_nums), max(page_nums)] and return their text.
    """
    if not page_nums:
        return []
    min_line, max_line = min(page_nums), max(page_nums)
    results = []
    seen = set()

    def _traverse(nodes):
        for node in nodes:
            ln = node.get('line_num')
            if ln and min_line <= ln <= max_line and ln not in seen:
                seen.add(ln)
                results.append({'page': ln, 'content': node.get('text', '')})
            if node.get('nodes'):
                _traverse(node['nodes'])

    _traverse(structure)
    results.sort(key=lambda x: x['page'])
    return results