mirror of
https://github.com/VectifyAI/PageIndex.git
synced 2026-04-24 23:56:21 +02:00
431 lines
14 KiB
Python
431 lines
14 KiB
Python
import litellm
|
|
import logging
|
|
import time
|
|
import json
|
|
import copy
|
|
import re
|
|
import asyncio
|
|
import PyPDF2
|
|
|
|
# Module-level logger; used by the LLM call helpers in this module for retry/error reporting.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def count_tokens(text, model=None):
    """Return how many tokens litellm counts in *text* for *model*.

    Falsy input (``""`` or ``None``) short-circuits to 0 without calling litellm.
    """
    return litellm.token_counter(model=model, text=text) if text else 0
|
|
|
|
|
|
def llm_completion(model, prompt, chat_history=None, return_finish_reason=False):
    """Send *prompt* to an LLM through litellm and return the reply text.

    A leading "litellm/" prefix on *model* is stripped. The request uses
    temperature 0 and is attempted up to 10 times, sleeping 1 second between
    failed attempts.

    Args:
        model: litellm model name, optionally prefixed with "litellm/".
        prompt: content of the final user message.
        chat_history: optional sequence of earlier message dicts; it is copied
            (not mutated) and the user message is appended after it.
        return_finish_reason: when true, also return "max_output_reached" if
            the provider reported finish_reason == "length", else "finished".

    Returns:
        The message content, or ``(content, finish_reason)`` when
        *return_finish_reason* is true.

    Raises:
        RuntimeError: when every attempt fails; chained to the last exception.
    """
    if model:
        model = model.removeprefix("litellm/")
    max_retries = 10
    messages = list(chat_history) if chat_history else []
    messages.append({"role": "user", "content": prompt})

    # Global litellm switch (drop provider-unsupported params); loop-invariant,
    # so set it once instead of on every retry.
    litellm.drop_params = True

    for attempt in range(1, max_retries + 1):
        try:
            response = litellm.completion(
                model=model,
                messages=messages,
                temperature=0,
            )
            choice = response.choices[0]
            content = choice.message.content
            if return_finish_reason:
                finish_reason = "max_output_reached" if choice.finish_reason == "length" else "finished"
                return content, finish_reason
            return content
        except Exception as e:
            logger.error("Error: %s", e)
            if attempt < max_retries:
                # Only announce a retry when one will actually happen.
                logger.warning("Retrying LLM completion (%d/%d)", attempt, max_retries)
                time.sleep(1)
            else:
                # %-style formatting: string concatenation would raise TypeError
                # for non-str prompts and mask the real failure below.
                logger.error("Max retries reached for prompt: %s", prompt)
                raise RuntimeError(f"LLM call failed after {max_retries} retries") from e
|
|
|
|
|
|
|
|
async def llm_acompletion(model, prompt):
    """Asynchronously send *prompt* to an LLM through litellm and return the reply text.

    A leading "litellm/" prefix on *model* is stripped. The request uses
    temperature 0 and is attempted up to 10 times, awaiting 1 second between
    failed attempts.

    Args:
        model: litellm model name, optionally prefixed with "litellm/".
        prompt: content of the single user message.

    Returns:
        ``response.choices[0].message.content`` of the successful call.

    Raises:
        RuntimeError: when every attempt fails; chained to the last exception.
    """
    if model:
        model = model.removeprefix("litellm/")
    max_retries = 10
    messages = [{"role": "user", "content": prompt}]

    # Global litellm switch (drop provider-unsupported params); loop-invariant,
    # so set it once instead of on every retry.
    litellm.drop_params = True

    for attempt in range(1, max_retries + 1):
        try:
            response = await litellm.acompletion(
                model=model,
                messages=messages,
                temperature=0,
            )
            return response.choices[0].message.content
        except Exception as e:
            logger.error("Error: %s", e)
            if attempt < max_retries:
                # Only announce a retry when one will actually happen.
                logger.warning("Retrying async LLM completion (%d/%d)", attempt, max_retries)
                await asyncio.sleep(1)
            else:
                # %-style formatting: string concatenation would raise TypeError
                # for non-str prompts and mask the real failure below.
                logger.error("Max retries reached for prompt: %s", prompt)
                raise RuntimeError(f"LLM call failed after {max_retries} retries") from e
|
|
|
|
|
|
def extract_json(content):
    """Parse a JSON value out of an LLM response string.

    If *content* contains a "```json" fence, the text after it (up to the last
    closing "```", or to the end when the fence is never closed) is parsed;
    otherwise the whole string is. Before parsing, the literal ``None`` is
    rewritten to ``null`` and every whitespace run (newlines included) is
    collapsed to a single space. If that fails, commas directly before ``]``
    or ``}`` (optionally separated by whitespace) are removed and parsing is
    retried once.

    Returns:
        The decoded JSON value, or ``{}`` when nothing parseable is found or
        *content* is not a string.
    """
    # Same logger object as the module-level ``logger`` (same name).
    log = logging.getLogger(__name__)
    try:
        start_idx = content.find("```json")
        if start_idx != -1:
            start_idx += 7  # skip past the "```json" delimiter
            end_idx = content.rfind("```")
            # Without a closing fence (e.g. truncated output) rfind() only
            # finds the opening fence; keep everything after it instead.
            if end_idx < start_idx:
                end_idx = len(content)
            json_content = content[start_idx:end_idx].strip()
        else:
            # No fence: assume the entire content could be JSON.
            json_content = content.strip()

        json_content = json_content.replace('None', 'null')  # Python None -> JSON null
        # str.split() with no argument splits on any whitespace, including \n and \r.
        json_content = ' '.join(json_content.split())

        return json.loads(json_content)
    except json.JSONDecodeError as e:
        log.error("Failed to extract JSON: %s", e)
        try:
            # Whitespace was normalised above, so trailing commas usually look
            # like ", ]" / ", }"; match any whitespace between comma and bracket.
            json_content = re.sub(r',\s*([\]}])', r'\1', json_content)
            return json.loads(json_content)
        except Exception:
            log.error("Failed to parse JSON even after cleanup")
            return {}
    except Exception as e:
        log.error("Unexpected error while extracting JSON: %s", e)
        return {}
|
|
|
|
|
|
def get_json_content(response):
    """Strip Markdown code-fence markers from an LLM response without parsing it.

    Everything up to and including a "```json" opener (when present) is
    dropped, then everything from the last remaining "```" onward is dropped,
    and the result is whitespace-stripped.
    """
    opener = "```json"
    opener_at = response.find(opener)
    if opener_at >= 0:
        response = response[opener_at + len(opener):]

    closer_at = response.rfind("```")
    trimmed = response[:closer_at] if closer_at >= 0 else response
    return trimmed.strip()
|
|
|
|
|
|
def write_node_id(data, node_id=0):
    """Assign zero-padded, pre-order ids to every dict in a node tree (in place).

    Each dict receives 'node_id' = str(counter).zfill(4). Children are looked
    up under every key whose name contains 'nodes'. Returns the next unused
    counter value so recursive calls can continue numbering.
    """
    next_id = node_id
    if isinstance(data, list):
        for element in data:
            next_id = write_node_id(element, next_id)
    elif isinstance(data, dict):
        data['node_id'] = str(next_id).zfill(4)
        next_id += 1
        for key, child in list(data.items()):
            if 'nodes' in key:
                next_id = write_node_id(child, next_id)
    return next_id
|
|
|
|
|
|
def remove_fields(data, fields=None):
    """Return a copy of *data* with the given keys dropped from every dict, at any depth.

    *fields* defaults to ``["text"]``; an empty *fields* also falls back to
    that default. Dicts and lists are rebuilt; any other value is returned
    unchanged (not copied).
    """
    excluded = fields if fields else ["text"]
    if isinstance(data, list):
        return [remove_fields(element, excluded) for element in data]
    if isinstance(data, dict):
        return {
            key: remove_fields(value, excluded)
            for key, value in data.items()
            if key not in excluded
        }
    return data
|
|
|
|
|
|
def structure_to_list(structure):
    """Flatten a node tree into a pre-order list of its dicts.

    The returned dicts are the original objects (not copies). Only the exact
    key 'nodes' is followed. Input that is neither a dict nor a list yields
    None.
    """
    if isinstance(structure, list):
        flattened = []
        for child in structure:
            flattened.extend(structure_to_list(child))
        return flattened
    if isinstance(structure, dict):
        flattened = [structure]
        if 'nodes' in structure:
            flattened.extend(structure_to_list(structure['nodes']))
        return flattened
    return None
|
|
|
|
|
|
def get_nodes(structure):
    """Flatten a node tree into a pre-order list of detached node copies.

    Each dict contributes a deep copy of itself without its 'nodes' key.
    Children are looked up under every key whose name contains 'nodes'.
    Input that is neither a dict nor a list yields None.
    """
    if isinstance(structure, dict):
        # Copy only this node's own fields. Deep-copying the whole subtree and
        # then discarding 'nodes' would re-copy every descendant (including its
        # text) once per ancestor level.
        own_fields = {key: value for key, value in structure.items() if key != 'nodes'}
        nodes = [copy.deepcopy(own_fields)]
        for key in list(structure.keys()):
            if 'nodes' in key:
                nodes.extend(get_nodes(structure[key]))
        return nodes
    elif isinstance(structure, list):
        nodes = []
        for item in structure:
            nodes.extend(get_nodes(item))
        return nodes
|
|
|
|
|
|
def get_leaf_nodes(structure):
    """Return deep copies of every leaf dict in a node tree, in pre-order.

    A dict whose 'nodes' entry is missing or empty is a leaf; its copy has no
    'nodes' key. For other dicts, children are looked up under every key whose
    name contains 'nodes'. Input that is neither a dict nor a list yields None.
    """
    if isinstance(structure, dict):
        # list_to_tree() and format_structure() delete empty 'nodes' keys, so
        # a missing key must also mark a leaf rather than raise KeyError.
        if not structure.get('nodes'):
            structure_node = copy.deepcopy(structure)
            structure_node.pop('nodes', None)
            return [structure_node]
        leaf_nodes = []
        for key in list(structure.keys()):
            if 'nodes' in key:
                leaf_nodes.extend(get_leaf_nodes(structure[key]))
        return leaf_nodes
    elif isinstance(structure, list):
        leaf_nodes = []
        for item in structure:
            leaf_nodes.extend(get_leaf_nodes(item))
        return leaf_nodes
|
|
|
|
|
|
async def generate_node_summary(node, model=None):
    """Ask the LLM to describe the main points of ``node['text']``.

    Raises KeyError if *node* has no 'text' entry. Returns whatever
    llm_acompletion() returns for the prompt.
    """
    partial_text = node['text']
    prompt = (
        "You are given a part of a document, your task is to generate a description of the partial document about what are main points covered in the partial document.\n"
        "\n"
        f"Partial Document Text: {partial_text}\n"
        "\n"
        "Directly return the description, do not include any other text.\n"
    )
    return await llm_acompletion(model, prompt)
|
|
|
|
|
|
async def generate_summaries_for_structure(structure, model=None):
    """Attach an LLM-generated 'summary' to every node of *structure*, in place.

    All node summaries are requested concurrently via asyncio.gather; the
    mutated *structure* is returned.
    """
    all_nodes = structure_to_list(structure)
    summaries = await asyncio.gather(
        *(generate_node_summary(target, model=model) for target in all_nodes)
    )
    for target, text in zip(all_nodes, summaries):
        target['summary'] = text
    return structure
|
|
|
|
|
|
def generate_doc_description(structure, model=None):
    """Ask the LLM for a one-sentence description that distinguishes the document.

    *structure* is embedded in the prompt via its ``str()`` form, so every
    field it contains (including any page text) is sent to the model.
    Returns whatever llm_completion() returns for the prompt.
    """
    # Prompt typo fixed: "Your are" -> "You are".
    prompt = (
        "You are an expert in generating descriptions for a document.\n"
        "You are given a structure of a document. Your task is to generate a one-sentence description for the document, which makes it easy to distinguish the document from other documents.\n"
        "\n"
        f"Document Structure: {structure}\n"
        "\n"
        "Directly return the description, do not include any other text.\n"
    )
    return llm_completion(model, prompt)
|
|
|
|
|
|
def list_to_tree(data):
    """Build a nested tree from a flat list of TOC entries.

    Each item's 'structure' is a dotted code such as "1", "1.2" or "1.2.3".
    An item is attached to the earlier item whose code is its own code minus
    the last dotted component. Items with no parent code, or whose parent has
    not appeared earlier in *data*, become roots. Output nodes contain only
    'title', 'start_index' and 'end_index', plus 'nodes' when they have
    children.
    """
    def get_parent_structure(structure):
        """Return the parent's dotted code, or None for a top-level/empty code."""
        if not structure:
            return None
        parts = str(structure).split('.')
        return '.'.join(parts[:-1]) if len(parts) > 1 else None

    # First pass: create nodes and link each one to its parent (if seen).
    nodes = {}
    root_nodes = []

    for item in data:
        structure = item.get('structure')
        node = {
            'title': item.get('title'),
            'start_index': item.get('start_index'),
            'end_index': item.get('end_index'),
            'nodes': []
        }

        # Parent codes are always strings, so key by the string form too;
        # otherwise numeric codes from parsed JSON (1, 1.1) never match.
        key = str(structure) if structure is not None else None
        nodes[key] = node

        parent_structure = get_parent_structure(structure)
        if parent_structure and parent_structure in nodes:
            nodes[parent_structure]['nodes'].append(node)
        else:
            # No parent code, or parent not (yet) seen: treat as a root.
            root_nodes.append(node)

    def clean_node(node):
        """Recursively drop empty 'nodes' lists; returns *node*."""
        if not node['nodes']:
            del node['nodes']
        else:
            for child in node['nodes']:
                clean_node(child)
        return node

    return [clean_node(node) for node in root_nodes]
|
|
|
|
|
|
def post_processing(structure, end_physical_index):
    """Turn a flat TOC list with physical page numbers into a page-range tree.

    Every item (mutated in place) gets 'start_index' = its 'physical_index'
    and an 'end_index'. That is the next item's 'physical_index', minus 1 when
    the next item has appear_start == 'yes'; the last item ends at
    *end_physical_index*. The list is then nested with list_to_tree(). If
    that yields no roots (which, given list_to_tree, appears to happen only
    for empty input), 'appear_start' and 'physical_index' are stripped and
    the flat list is returned.
    """
    final_position = len(structure) - 1
    for position, entry in enumerate(structure):
        entry['start_index'] = entry.get('physical_index')
        if position == final_position:
            entry['end_index'] = end_physical_index
            continue
        successor = structure[position + 1]
        if successor.get('appear_start') == 'yes':
            # The next section starts at the top of its page, so this one ends a page earlier.
            entry['end_index'] = successor['physical_index'] - 1
        else:
            entry['end_index'] = successor['physical_index']

    tree = list_to_tree(structure)
    if tree:
        return tree

    # Fallback: return the flat list without the intermediate bookkeeping keys.
    for entry in structure:
        entry.pop('appear_start', None)
        entry.pop('physical_index', None)
    return structure
|
|
|
|
|
|
def reorder_dict(data, key_order):
    """Return a new dict with *data*'s keys arranged in *key_order*.

    Keys of *data* that are not listed are dropped, and listed keys missing
    from *data* are skipped. A falsy *key_order* returns *data* itself.
    """
    if not key_order:
        return data
    ordered = {}
    for key in key_order:
        if key in data:
            ordered[key] = data[key]
    return ordered
|
|
|
|
|
|
def format_structure(structure, order=None):
    """Recursively reorder (and filter) node dicts to the keys in *order*.

    Empty or missing 'nodes' entries are removed from each dict in place
    before reorder_dict() builds the reordered copy; keys not in *order*
    (including 'nodes') are dropped. A falsy *order* returns *structure*
    untouched.
    """
    if not order:
        return structure
    if isinstance(structure, list):
        return [format_structure(element, order) for element in structure]
    if isinstance(structure, dict):
        if 'nodes' in structure:
            structure['nodes'] = format_structure(structure['nodes'], order)
        if not structure.get('nodes'):
            structure.pop('nodes', None)
        return reorder_dict(structure, order)
    return structure
|
|
|
|
|
|
def create_clean_structure_for_description(structure):
    """
    Build a trimmed copy of a node tree for document-description prompts.

    Each dict keeps only 'title', 'node_id', 'summary' and 'prefix_summary'
    (when present), plus 'nodes' when it has non-empty children; everything
    else, such as 'text', is dropped. Non-container values pass through.
    """
    if isinstance(structure, list):
        return [create_clean_structure_for_description(child) for child in structure]
    if not isinstance(structure, dict):
        return structure

    kept = {
        field: structure[field]
        for field in ('title', 'node_id', 'summary', 'prefix_summary')
        if field in structure
    }
    children = structure.get('nodes')
    if children:
        kept['nodes'] = create_clean_structure_for_description(children)
    return kept
|
|
|
|
|
|
def _get_text_of_pages(page_list, start_page, end_page):
|
|
"""Concatenate text from page_list for pages [start_page, end_page] (1-indexed)."""
|
|
text = ""
|
|
for page_num in range(start_page - 1, end_page):
|
|
text += page_list[page_num][0]
|
|
return text
|
|
|
|
|
|
def add_node_text(node, page_list):
    """Recursively add a 'text' field to each node from page_list content, in place.

    A dict gets 'text' only when both 'start_index' and 'end_index' (1-indexed
    page numbers) are present and not None; its 'nodes' entry is then walked
    too. page_list is [(page_text, token_count), ...]. Returns None.
    """
    if isinstance(node, list):
        for child in node:
            add_node_text(child, page_list)
        return
    if not isinstance(node, dict):
        return

    first_page, last_page = node.get('start_index'), node.get('end_index')
    if first_page is not None and last_page is not None:
        node['text'] = _get_text_of_pages(page_list, first_page, last_page)
    if 'nodes' in node:
        add_node_text(node['nodes'], page_list)
|
|
|
|
|
|
def remove_structure_text(data):
    """Strip the 'text' key from every dict in a node tree, in place.

    Only the exact key 'nodes' is followed into children. Returns *data*.
    """
    if isinstance(data, list):
        for child in data:
            remove_structure_text(child)
    elif isinstance(data, dict):
        data.pop('text', None)
        if 'nodes' in data:
            remove_structure_text(data['nodes'])
    return data
|
|
|
|
|
|
# ── Functions migrated from retrieve.py ──────────────────────────────────────
|
|
|
|
def parse_pages(pages: str) -> list[int]:
    """Parse a pages string like '5-7', '3,8', or '12' into a sorted list of ints.

    Comma-separated parts are either a single page number or an inclusive
    'start-end' range. Pages below 1 are discarded and duplicates removed.

    Raises:
        ValueError: if a part is not an integer or range, a range has
            start > end, or more than 1000 distinct pages would be returned.
    """
    max_pages = 1000
    selected: set[int] = set()
    for part in pages.split(','):
        part = part.strip()
        if '-' in part:
            start_text, end_text = part.split('-', 1)
            start, end = int(start_text.strip()), int(end_text.strip())
            if start > end:
                raise ValueError(f"Invalid range '{part}': start must be <= end")
            # Pages below 1 are discarded anyway. Checking the span before
            # expanding stops input like '1-10000000000' from exhausting memory.
            start = max(start, 1)
            span = end - start + 1
            if span > max_pages:
                raise ValueError(f"Page range too large: {span} pages (max {max_pages})")
            selected.update(range(start, end + 1))
        else:
            page = int(part)
            if page >= 1:
                selected.add(page)
        # Bound the running total too, so many small ranges cannot add up unchecked.
        if len(selected) > max_pages:
            raise ValueError(f"Page range too large: {len(selected)} pages (max {max_pages})")
    return sorted(selected)
|
|
|
|
|
|
def get_pdf_page_content(file_path: str, page_nums: list[int]) -> list[dict]:
    """Extract text for specific PDF pages (1-indexed), opening the PDF once.

    Page numbers outside 1..page_count are silently skipped. The remaining
    pages keep their order from *page_nums*, duplicates included. A page whose
    extract_text() returns nothing yields an empty 'content' string.
    """
    with open(file_path, 'rb') as handle:
        reader = PyPDF2.PdfReader(handle)
        page_count = len(reader.pages)
        extracted = []
        for number in page_nums:
            if not 1 <= number <= page_count:
                continue
            text = reader.pages[number - 1].extract_text() or ''
            extracted.append({'page': number, 'content': text})
        return extracted
|
|
|
|
|
|
def get_md_page_content(structure: list, page_nums: list[int]) -> list[dict]:
    """
    For Markdown documents, 'pages' are line numbers.

    Collect the text of every node whose truthy 'line_num' lies within
    [min(page_nums), max(page_nums)]. Nodes are visited depth-first in
    pre-order, and only the first node seen for a given line number is kept.
    The result is sorted by line number.
    """
    if not page_nums:
        return []
    low, high = min(page_nums), max(page_nums)

    collected = {}
    # Explicit stack; children are pushed reversed so they pop in document order.
    pending = list(reversed(structure))
    while pending:
        node = pending.pop()
        line = node.get('line_num')
        if line and low <= line <= high and line not in collected:
            collected[line] = node.get('text', '')
        children = node.get('nodes')
        if children:
            pending.extend(reversed(children))

    return [{'page': line, 'content': collected[line]} for line in sorted(collected)]
|