mirror of
https://github.com/VectifyAI/PageIndex.git
synced 2026-04-24 23:56:21 +02:00
fix node summary
This commit is contained in:
parent
19faaad74f
commit
c22778f85d
1 changed file with 17 additions and 1 deletion
|
|
@ -5,14 +5,30 @@ import tiktoken
|
||||||
from utils import *
|
from utils import *
|
||||||
|
|
||||||
def count_tokens(text, model='gpt-4o'):
    """Count the tokens in ``text`` using the tokenizer for ``model``.

    Args:
        text: The string to tokenize; falsy values (None, "") count as 0.
        model: Model name passed to ``tiktoken.encoding_for_model``.

    Returns:
        The number of tokens produced by the model's encoder.
    """
    # Short-circuit on empty/None input — no encoder lookup needed.
    if not text:
        return 0
    encoder = tiktoken.encoding_for_model(model)
    return len(encoder.encode(text))
|
||||||
|
|
||||||
|
|
||||||
|
async def get_node_summary(node, summary_token_threshold=200, model=None):
    """Return a summary for ``node``.

    Nodes whose text is shorter than ``summary_token_threshold`` tokens are
    returned verbatim; longer nodes are summarized by the LLM via
    ``generate_node_summary``.

    Args:
        node: Dict-like node; its ``'text'`` entry is read.
        summary_token_threshold: Token count below which the raw text is
            returned unchanged.
        model: Optional model name forwarded to ``generate_node_summary``.

    Returns:
        The node's own text, or an LLM-generated summary for long nodes.
    """
    text = node.get('text')
    # Cheap nodes skip the LLM round-trip entirely.
    if count_tokens(text) < summary_token_threshold:
        return text
    return await generate_node_summary(node, model=model)
|
||||||
|
|
||||||
|
|
||||||
async def generate_summaries_for_structure_md(structure, model="gpt-4.1"):
|
async def generate_summaries_for_structure_md(structure, model="gpt-4.1"):
|
||||||
nodes = structure_to_list(structure)
|
nodes = structure_to_list(structure)
|
||||||
tasks = [generate_node_summary(node, model=model) for node in nodes]
|
tasks = [get_node_summary(node, model=model) for node in nodes]
|
||||||
summaries = await asyncio.gather(*tasks)
|
summaries = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
for node, summary in zip(nodes, summaries):
|
for node, summary in zip(nodes, summaries):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue