mirror of
https://github.com/VectifyAI/PageIndex.git
synced 2026-04-24 23:56:21 +02:00
Integrate LiteLLM for multi-provider LLM support (#168)
* Integrate litellm for multi-provider LLM support * recover the default config yaml * Use litellm.acompletion for native async support * fix tob * Rename llm_complete/allm_complete to llm_completion/llm_acompletion, remove unused llm_complete_stream * Pin litellm to version 1.82.0 * resolve comments * args from cli is used to overrides config.yaml * Fix get_page_tokens hardcoded model default Pass opt.model to get_page_tokens so tokenization respects the configured model instead of always using gpt-4o-2024-11-20. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * Remove explicit openai dependency from requirements.txt openai is no longer directly imported; it comes in as a transitive dependency of litellm. Pinning it explicitly risks version conflicts. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * Restore openai==1.101.0 pin in requirements.txt litellm==1.82.0 and openai-agents have conflicting openai version requirements, but openai==1.101.0 works at runtime for both. The pin is necessary to prevent litellm from pulling in openai>=2.x which would break openai-agents. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * Remove explicit openai dependency from requirements.txt openai is not directly used; it comes in as a transitive dependency of litellm. No openai-agents in this branch so no pin needed. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix an litellm error log * resolve comments --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
4b4b20f9c4
commit
2403be8f27
5 changed files with 78 additions and 104 deletions
|
|
@ -1,5 +1,4 @@
|
|||
import tiktoken
|
||||
import openai
|
||||
import litellm
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
|
@ -17,95 +16,65 @@ import yaml
|
|||
from pathlib import Path
|
||||
from types import SimpleNamespace as config
|
||||
|
||||
CHATGPT_API_KEY = os.getenv("CHATGPT_API_KEY")
|
||||
# Backward compatibility: support CHATGPT_API_KEY as alias for OPENAI_API_KEY
|
||||
if not os.getenv("OPENAI_API_KEY") and os.getenv("CHATGPT_API_KEY"):
|
||||
os.environ["OPENAI_API_KEY"] = os.getenv("CHATGPT_API_KEY")
|
||||
|
||||
litellm.drop_params = True
|
||||
|
||||
def count_tokens(text, model=None):
|
||||
if not text:
|
||||
return 0
|
||||
enc = tiktoken.encoding_for_model(model)
|
||||
tokens = enc.encode(text)
|
||||
return len(tokens)
|
||||
return litellm.token_counter(model=model, text=text)
|
||||
|
||||
def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
|
||||
|
||||
def llm_completion(model, prompt, chat_history=None, return_finish_reason=False):
|
||||
max_retries = 10
|
||||
client = openai.OpenAI(api_key=api_key)
|
||||
messages = list(chat_history) + [{"role": "user", "content": prompt}] if chat_history else [{"role": "user", "content": prompt}]
|
||||
for i in range(max_retries):
|
||||
try:
|
||||
if chat_history:
|
||||
messages = chat_history
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
else:
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
response = client.chat.completions.create(
|
||||
response = litellm.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
)
|
||||
if response.choices[0].finish_reason == "length":
|
||||
return response.choices[0].message.content, "max_output_reached"
|
||||
else:
|
||||
return response.choices[0].message.content, "finished"
|
||||
|
||||
content = response.choices[0].message.content
|
||||
if return_finish_reason:
|
||||
finish_reason = "max_output_reached" if response.choices[0].finish_reason == "length" else "finished"
|
||||
return content, finish_reason
|
||||
return content
|
||||
except Exception as e:
|
||||
print('************* Retrying *************')
|
||||
logging.error(f"Error: {e}")
|
||||
if i < max_retries - 1:
|
||||
time.sleep(1) # Wait for 1秒 before retrying
|
||||
time.sleep(1)
|
||||
else:
|
||||
logging.error('Max retries reached for prompt: ' + prompt)
|
||||
return "", "error"
|
||||
if return_finish_reason:
|
||||
return "", "error"
|
||||
return ""
|
||||
|
||||
|
||||
|
||||
def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
|
||||
async def llm_acompletion(model, prompt):
|
||||
max_retries = 10
|
||||
client = openai.OpenAI(api_key=api_key)
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
for i in range(max_retries):
|
||||
try:
|
||||
if chat_history:
|
||||
messages = chat_history
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
else:
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
response = client.chat.completions.create(
|
||||
response = await litellm.acompletion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
except Exception as e:
|
||||
print('************* Retrying *************')
|
||||
logging.error(f"Error: {e}")
|
||||
if i < max_retries - 1:
|
||||
time.sleep(1) # Wait for 1秒 before retrying
|
||||
await asyncio.sleep(1)
|
||||
else:
|
||||
logging.error('Max retries reached for prompt: ' + prompt)
|
||||
return "Error"
|
||||
|
||||
|
||||
async def ChatGPT_API_async(model, prompt, api_key=CHATGPT_API_KEY):
|
||||
max_retries = 10
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
for i in range(max_retries):
|
||||
try:
|
||||
async with openai.AsyncOpenAI(api_key=api_key) as client:
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
except Exception as e:
|
||||
print('************* Retrying *************')
|
||||
logging.error(f"Error: {e}")
|
||||
if i < max_retries - 1:
|
||||
await asyncio.sleep(1) # Wait for 1s before retrying
|
||||
else:
|
||||
logging.error('Max retries reached for prompt: ' + prompt)
|
||||
return "Error"
|
||||
return ""
|
||||
|
||||
|
||||
def get_json_content(response):
|
||||
|
|
@ -410,15 +379,14 @@ def add_preface_if_needed(data):
|
|||
|
||||
|
||||
|
||||
def get_page_tokens(pdf_path, model="gpt-4o-2024-11-20", pdf_parser="PyPDF2"):
|
||||
enc = tiktoken.encoding_for_model(model)
|
||||
def get_page_tokens(pdf_path, model=None, pdf_parser="PyPDF2"):
|
||||
if pdf_parser == "PyPDF2":
|
||||
pdf_reader = PyPDF2.PdfReader(pdf_path)
|
||||
page_list = []
|
||||
for page_num in range(len(pdf_reader.pages)):
|
||||
page = pdf_reader.pages[page_num]
|
||||
page_text = page.extract_text()
|
||||
token_length = len(enc.encode(page_text))
|
||||
token_length = litellm.token_counter(model=model, text=page_text)
|
||||
page_list.append((page_text, token_length))
|
||||
return page_list
|
||||
elif pdf_parser == "PyMuPDF":
|
||||
|
|
@ -430,7 +398,7 @@ def get_page_tokens(pdf_path, model="gpt-4o-2024-11-20", pdf_parser="PyPDF2"):
|
|||
page_list = []
|
||||
for page in doc:
|
||||
page_text = page.get_text()
|
||||
token_length = len(enc.encode(page_text))
|
||||
token_length = litellm.token_counter(model=model, text=page_text)
|
||||
page_list.append((page_text, token_length))
|
||||
return page_list
|
||||
else:
|
||||
|
|
@ -533,7 +501,7 @@ def remove_structure_text(data):
|
|||
def check_token_limit(structure, limit=110000):
|
||||
list = structure_to_list(structure)
|
||||
for node in list:
|
||||
num_tokens = count_tokens(node['text'], model='gpt-4o')
|
||||
num_tokens = count_tokens(node['text'], model=None)
|
||||
if num_tokens > limit:
|
||||
print(f"Node ID: {node['node_id']} has {num_tokens} tokens")
|
||||
print("Start Index:", node['start_index'])
|
||||
|
|
@ -609,7 +577,7 @@ async def generate_node_summary(node, model=None):
|
|||
|
||||
Directly return the description, do not include any other text.
|
||||
"""
|
||||
response = await ChatGPT_API_async(model, prompt)
|
||||
response = await llm_acompletion(model, prompt)
|
||||
return response
|
||||
|
||||
|
||||
|
|
@ -654,7 +622,7 @@ def generate_doc_description(structure, model=None):
|
|||
|
||||
Directly return the description, do not include any other text.
|
||||
"""
|
||||
response = ChatGPT_API(model, prompt)
|
||||
response = llm_completion(model, prompt)
|
||||
return response
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue