mirror of
https://github.com/rowboatlabs/rowboat.git
synced 2026-04-26 17:06:23 +02:00
Add support for other providers - litellm, openrouter
This commit is contained in:
parent
8c2c21a239
commit
14eee3e0c3
24 changed files with 398 additions and 95 deletions
32
apps/rowboat_agents/src/utils/client.py
Normal file
32
apps/rowboat_agents/src/utils/client.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
import os
|
||||
import logging
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
import dotenv
|
||||
dotenv.load_dotenv()
|
||||
|
||||
PROVIDER_BASE_URL = os.getenv('PROVIDER_BASE_URL', '')
|
||||
PROVIDER_API_KEY = os.getenv('PROVIDER_API_KEY', os.getenv('OPENAI_API_KEY', ''))
|
||||
PROVIDER_DEFAULT_MODEL = os.getenv('PROVIDER_DEFAULT_MODEL', 'gpt-4.1')
|
||||
|
||||
client = None
|
||||
if not PROVIDER_API_KEY:
|
||||
raise ValueError("No LLM Provider API key found")
|
||||
|
||||
if PROVIDER_BASE_URL:
|
||||
print(f"Using provider {PROVIDER_BASE_URL} with API key {PROVIDER_API_KEY}")
|
||||
client = AsyncOpenAI(base_url=PROVIDER_BASE_URL, api_key=PROVIDER_API_KEY)
|
||||
else:
|
||||
print("No provider base URL configured, using OpenAI directly")
|
||||
|
||||
completions_client = None
|
||||
if PROVIDER_BASE_URL:
|
||||
print(f"Using provider {PROVIDER_BASE_URL} for completions")
|
||||
completions_client = OpenAI(
|
||||
base_url=PROVIDER_BASE_URL,
|
||||
api_key=PROVIDER_API_KEY
|
||||
)
|
||||
else:
|
||||
print(f"Using OpenAI directly for completions")
|
||||
completions_client = OpenAI(
|
||||
api_key=PROVIDER_API_KEY
|
||||
)
|
||||
|
|
@ -7,6 +7,7 @@ import time
|
|||
from dotenv import load_dotenv
|
||||
from openai import OpenAI
|
||||
|
||||
from src.utils.client import completions_client
|
||||
load_dotenv()
|
||||
|
||||
def setup_logger(name, log_file='./run.log', level=logging.INFO, log_to_file=False):
|
||||
|
|
@ -53,31 +54,28 @@ def get_api_key(key_name):
|
|||
raise ValueError(f"{key_name} not found. Did you set it in the .env file?")
|
||||
return api_key
|
||||
|
||||
openai_client = OpenAI(
|
||||
api_key=get_api_key("OPENAI_API_KEY")
|
||||
)
|
||||
|
||||
def generate_gpt4o_output_from_multi_turn_conv(messages, output_type='json', model="gpt-4o"):
|
||||
return generate_openai_output(messages, output_type, model)
|
||||
|
||||
def generate_openai_output(messages, output_type='not_json', model="gpt-4o", return_completion=False):
|
||||
print(f"In generate_openai_output, using client: {completions_client} and model: {model}")
|
||||
try:
|
||||
if output_type == 'json':
|
||||
chat_completion = openai_client.chat.completions.create(
|
||||
messages=messages,
|
||||
chat_completion = completions_client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_format={"type": "json_object"}
|
||||
)
|
||||
else:
|
||||
chat_completion = openai_client.chat.completions.create(
|
||||
messages=messages,
|
||||
chat_completion = completions_client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
if return_completion:
|
||||
return chat_completion
|
||||
return chat_completion.choices[0].message.content
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue