mirror of
https://github.com/VectifyAI/PageIndex.git
synced 2026-04-24 23:56:21 +02:00
fix utility functions
This commit is contained in:
parent
3277f16ae1
commit
74ec78af36
1 changed files with 10 additions and 64 deletions
|
|
@ -126,9 +126,9 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os, json, openai, requests, textwrap\n",
|
||||
"import json, os, requests\n",
|
||||
"from pageindex import PageIndexClient\n",
|
||||
"from pprint import pprint\n",
|
||||
"import pageindex.utils as utils\n",
|
||||
"\n",
|
||||
"# Get your PageIndex API key from https://dash.pageindex.ai/api-keys\n",
|
||||
"PAGEINDEX_API_KEY = \"YOUR_PAGEINDEX_API_KEY\"\n",
|
||||
|
|
@ -137,60 +137,6 @@
|
|||
"pi_client = PageIndexClient(api_key=PAGEINDEX_API_KEY)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "AR7PLeVbcG1N"
|
||||
},
|
||||
"source": [
|
||||
"#### 0.3 Define utility functions"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 40,
|
||||
"metadata": {
|
||||
"id": "hmj3POkDcG1N"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"async def call_llm(prompt, model=\"gpt-4.1\", temperature=0):\n",
|
||||
" client = openai.AsyncOpenAI(api_key=OPENAI_API_KEY)\n",
|
||||
" response = await client.chat.completions.create(\n",
|
||||
" model=model,\n",
|
||||
" messages=[{\"role\": \"user\", \"content\": prompt}],\n",
|
||||
" temperature=temperature\n",
|
||||
" )\n",
|
||||
" return response.choices[0].message.content.strip()\n",
|
||||
"\n",
|
||||
"def remove_fields(data, fields=['text'], max_len=None):\n",
|
||||
" if isinstance(data, dict):\n",
|
||||
" return {k: remove_fields(v, fields, max_len) for k, v in data.items() if k not in fields}\n",
|
||||
" elif isinstance(data, list):\n",
|
||||
" return [remove_fields(item, fields, max_len) for item in data]\n",
|
||||
" elif isinstance(data, str):\n",
|
||||
" return data[:max_len] + '...' if max_len is not None and len(data) > max_len else data\n",
|
||||
" return data\n",
|
||||
"\n",
|
||||
"def print_tree(tree, exclude_fields=['text', 'page_index']):\n",
|
||||
" cleaned_tree = remove_fields(tree.copy(), exclude_fields, max_len=40)\n",
|
||||
" pprint(cleaned_tree, sort_dicts=False, width=100)\n",
|
||||
"\n",
|
||||
"def show(text, width=100):\n",
|
||||
" for line in text.splitlines():\n",
|
||||
" print(textwrap.fill(line, width=width))\n",
|
||||
"\n",
|
||||
"def create_node_mapping(tree):\n",
|
||||
" \"\"\"Create a mapping of node_id to node for quick lookup\"\"\"\n",
|
||||
" def get_all_nodes(tree):\n",
|
||||
" if isinstance(tree, dict):\n",
|
||||
" return [tree] + [node for child in tree.get('nodes', []) for node in get_all_nodes(child)]\n",
|
||||
" elif isinstance(tree, list):\n",
|
||||
" return [node for item in tree for node in get_all_nodes(item)]\n",
|
||||
" return []\n",
|
||||
" return {node[\"node_id\"]: node for node in get_all_nodes(tree) if node.get(\"node_id\")}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
|
|
@ -346,7 +292,7 @@
|
|||
"if pi_client.is_retrieval_ready(doc_id):\n",
|
||||
" tree = pi_client.get_tree(doc_id, node_summary=True)['result']\n",
|
||||
" print('Simplified Tree Structure of the Document:')\n",
|
||||
" print_tree(tree)\n",
|
||||
" utils.print_tree(tree)\n",
|
||||
"else:\n",
|
||||
" print(\"Processing document, please try again later...\")"
|
||||
]
|
||||
|
|
@ -377,7 +323,7 @@
|
|||
"source": [
|
||||
"query = \"What are the conclusions in this document?\"\n",
|
||||
"\n",
|
||||
"tree_without_text = remove_fields(tree.copy(), fields=['text'])\n",
|
||||
"tree_without_text = utils.remove_fields(tree.copy(), fields=['text'])\n",
|
||||
"\n",
|
||||
"search_prompt = f\"\"\"\n",
|
||||
"You are given a question and a tree structure of a document.\n",
|
||||
|
|
@ -397,7 +343,7 @@
|
|||
"Directly return the final JSON structure. Do not output anything else.\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"tree_search_result = await call_llm(search_prompt)"
|
||||
"tree_search_result = await utils.call_llm(search_prompt, api_key=OPENAI_API_KEY)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -438,11 +384,11 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"node_map = create_node_mapping(tree)\n",
|
||||
"node_map = utils.create_node_mapping(tree)\n",
|
||||
"tree_search_result_json = json.loads(tree_search_result)\n",
|
||||
"\n",
|
||||
"print('Reasoning Process:')\n",
|
||||
"show(tree_search_result_json['thinking'])\n",
|
||||
"utils.print_wrapped(tree_search_result_json['thinking'])\n",
|
||||
"\n",
|
||||
"print('\\nRetrieved Nodes:')\n",
|
||||
"for node_id in tree_search_result_json[\"node_list\"]:\n",
|
||||
|
|
@ -508,7 +454,7 @@
|
|||
"relevant_content = \"\\n\\n\".join(node_map[node_id][\"text\"] for node_id in node_list)\n",
|
||||
"\n",
|
||||
"print('Retrieved Context:\\n')\n",
|
||||
"show(relevant_content[:1000] + '...')"
|
||||
"utils.print_wrapped(relevant_content[:1000] + '...')"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -562,8 +508,8 @@
|
|||
"\"\"\"\n",
|
||||
"\n",
|
||||
"print('Generated Answer:\\n')\n",
|
||||
"answer = await call_llm(answer_prompt)\n",
|
||||
"show(answer)"
|
||||
"answer = await utils.call_llm(answer_prompt, api_key=OPENAI_API_KEY)\n",
|
||||
"utils.print_wrapped(answer)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue