diff --git a/cookbook/pageindex_RAG_simple.ipynb b/cookbook/pageindex_RAG_simple.ipynb index 36dd687..a56d15f 100644 --- a/cookbook/pageindex_RAG_simple.ipynb +++ b/cookbook/pageindex_RAG_simple.ipynb @@ -126,9 +126,9 @@ }, "outputs": [], "source": [ - "import os, json, openai, requests, textwrap\n", + "import json, os, requests\n", "from pageindex import PageIndexClient\n", - "from pprint import pprint\n", + "import pageindex.utils as utils\n", "\n", "# Get your PageIndex API key from https://dash.pageindex.ai/api-keys\n", "PAGEINDEX_API_KEY = \"YOUR_PAGEINDEX_API_KEY\"\n", @@ -137,60 +137,6 @@ "pi_client = PageIndexClient(api_key=PAGEINDEX_API_KEY)" ] }, - { - "cell_type": "markdown", - "metadata": { - "id": "AR7PLeVbcG1N" - }, - "source": [ - "#### 0.3 Define utility functions" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "metadata": { - "id": "hmj3POkDcG1N" - }, - "outputs": [], - "source": [ - "async def call_llm(prompt, model=\"gpt-4.1\", temperature=0):\n", - " client = openai.AsyncOpenAI(api_key=OPENAI_API_KEY)\n", - " response = await client.chat.completions.create(\n", - " model=model,\n", - " messages=[{\"role\": \"user\", \"content\": prompt}],\n", - " temperature=temperature\n", - " )\n", - " return response.choices[0].message.content.strip()\n", - "\n", - "def remove_fields(data, fields=['text'], max_len=None):\n", - " if isinstance(data, dict):\n", - " return {k: remove_fields(v, fields, max_len) for k, v in data.items() if k not in fields}\n", - " elif isinstance(data, list):\n", - " return [remove_fields(item, fields, max_len) for item in data]\n", - " elif isinstance(data, str):\n", - " return data[:max_len] + '...' if max_len is not None and len(data) > max_len else data\n", - " return data\n", - "\n", - "def print_tree(tree, exclude_fields=['text', 'page_index']):\n", - " cleaned_tree = remove_fields(tree.copy(), exclude_fields, max_len=40)\n", - " pprint(cleaned_tree, sort_dicts=False, width=100)\n", - "\n", - "def show(text, width=100):\n", - " for line in text.splitlines():\n", - " print(textwrap.fill(line, width=width))\n", - "\n", - "def create_node_mapping(tree):\n", - " \"\"\"Create a mapping of node_id to node for quick lookup\"\"\"\n", - " def get_all_nodes(tree):\n", - " if isinstance(tree, dict):\n", - " return [tree] + [node for child in tree.get('nodes', []) for node in get_all_nodes(child)]\n", - " elif isinstance(tree, list):\n", - " return [node for item in tree for node in get_all_nodes(item)]\n", - " return []\n", - " return {node[\"node_id\"]: node for node in get_all_nodes(tree) if node.get(\"node_id\")}" - ] - }, { "cell_type": "markdown", "metadata": { @@ -346,7 +292,7 @@ "if pi_client.is_retrieval_ready(doc_id):\n", " tree = pi_client.get_tree(doc_id, node_summary=True)['result']\n", " print('Simplified Tree Structure of the Document:')\n", - " print_tree(tree)\n", + " utils.print_tree(tree)\n", "else:\n", " print(\"Processing document, please try again later...\")" ] @@ -377,7 +323,7 @@ "source": [ "query = \"What are the conclusions in this document?\"\n", "\n", - "tree_without_text = remove_fields(tree.copy(), fields=['text'])\n", + "tree_without_text = utils.remove_fields(tree.copy(), fields=['text'])\n", "\n", "search_prompt = f\"\"\"\n", "You are given a question and a tree structure of a document.\n", @@ -397,7 +343,7 @@ "Directly return the final JSON structure. Do not output anything else.\n", "\"\"\"\n", "\n", - "tree_search_result = await call_llm(search_prompt)" + "tree_search_result = await utils.call_llm(search_prompt, api_key=OPENAI_API_KEY)" ] }, { @@ -438,11 +384,11 @@ } ], "source": [ - "node_map = create_node_mapping(tree)\n", + "node_map = utils.create_node_mapping(tree)\n", "tree_search_result_json = json.loads(tree_search_result)\n", "\n", "print('Reasoning Process:')\n", - "show(tree_search_result_json['thinking'])\n", + "utils.print_wrapped(tree_search_result_json['thinking'])\n", "\n", "print('\\nRetrieved Nodes:')\n", "for node_id in tree_search_result_json[\"node_list\"]:\n", @@ -508,7 +454,7 @@ "relevant_content = \"\\n\\n\".join(node_map[node_id][\"text\"] for node_id in node_list)\n", "\n", "print('Retrieved Context:\\n')\n", - "show(relevant_content[:1000] + '...')" + "utils.print_wrapped(relevant_content[:1000] + '...')" ] }, { @@ -562,8 +508,8 @@ "\"\"\"\n", "\n", "print('Generated Answer:\\n')\n", - "answer = await call_llm(answer_prompt)\n", - "show(answer)" + "answer = await utils.call_llm(answer_prompt, api_key=OPENAI_API_KEY)\n", + "utils.print_wrapped(answer)" ] }, {