mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
feat: upgrade openai to 1.x
This commit is contained in:
parent
4845dafb94
commit
45aa451ec6
6 changed files with 174 additions and 144 deletions
|
|
@ -4,6 +4,7 @@
|
|||
@Time : 2023/5/11 14:45
|
||||
@Author : alexanderwu
|
||||
@File : llm.py
|
||||
@Modified By: mashenquan, 2023-12-4. Upgrade openai to 1.x
|
||||
"""
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
|
|
@ -11,7 +12,9 @@ from metagpt.provider.anthropic_api import Claude2 as Claude
|
|||
from metagpt.provider.human_provider import HumanProvider
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI
|
||||
from metagpt.provider.spark_api import SparkAPI
|
||||
from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
|
||||
# openai v1.x removed the 'api_requestor', making interfaces built on it no longer functional.
|
||||
# More: https://github.com/openai/openai-python/discussions/742
|
||||
# from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
|
||||
|
||||
_ = HumanProvider() # Avoid pre-commit error
|
||||
|
||||
|
|
@ -25,8 +28,8 @@ def LLM() -> "BaseGPTAPI":
|
|||
llm = Claude()
|
||||
elif CONFIG.spark_api_key:
|
||||
llm = SparkAPI()
|
||||
elif CONFIG.zhipuai_api_key:
|
||||
llm = ZhiPuAIGPTAPI()
|
||||
# elif CONFIG.zhipuai_api_key: # openai v1.x removed the 'api_requestor'
|
||||
# llm = ZhiPuAIGPTAPI()
|
||||
else:
|
||||
raise RuntimeError("You should config a LLM configuration first")
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
@Time : 2023/5/5 23:00
|
||||
@Author : alexanderwu
|
||||
@File : base_chatbot.py
|
||||
@Modified By: mashenquan, 2023/11/21. Add `timeout`.
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
|
@ -17,13 +18,13 @@ class BaseChatbot(ABC):
|
|||
use_system_prompt: bool = True
|
||||
|
||||
@abstractmethod
|
||||
def ask(self, msg: str) -> str:
|
||||
def ask(self, msg: str, timeout=3) -> str:
|
||||
"""Ask GPT a question and get an answer"""
|
||||
|
||||
@abstractmethod
|
||||
def ask_batch(self, msgs: list) -> str:
|
||||
def ask_batch(self, msgs: list, timeout=3) -> str:
|
||||
"""Ask GPT multiple questions and get a series of answers"""
|
||||
|
||||
@abstractmethod
|
||||
def ask_code(self, msgs: list) -> str:
|
||||
def ask_code(self, msgs: list, timeout=3) -> str:
|
||||
"""Ask GPT multiple questions and get a piece of code"""
|
||||
|
|
|
|||
|
|
@ -33,23 +33,27 @@ class BaseGPTAPI(BaseChatbot):
|
|||
def _default_system_msg(self):
|
||||
return self._system_msg(self.system_prompt)
|
||||
|
||||
def ask(self, msg: str) -> str:
|
||||
def ask(self, msg: str, timeout=3) -> str:
|
||||
message = [self._default_system_msg(), self._user_msg(msg)] if self.use_system_prompt else [self._user_msg(msg)]
|
||||
rsp = self.completion(message)
|
||||
rsp = self.completion(message, timeout=timeout)
|
||||
return self.get_choice_text(rsp)
|
||||
|
||||
async def aask(self, msg: str, system_msgs: Optional[list[str]] = None) -> str:
|
||||
async def aask(
|
||||
self,
|
||||
msg: str,
|
||||
system_msgs: Optional[list[str]] = None,
|
||||
format_msgs: Optional[list[dict[str, str]]] = None,
|
||||
generator: bool = False,
|
||||
timeout=3,
|
||||
) -> str:
|
||||
if system_msgs:
|
||||
message = (
|
||||
self._system_msgs(system_msgs) + [self._user_msg(msg)]
|
||||
if self.use_system_prompt
|
||||
else [self._user_msg(msg)]
|
||||
)
|
||||
message = self._system_msgs(system_msgs)
|
||||
else:
|
||||
message = (
|
||||
[self._default_system_msg(), self._user_msg(msg)] if self.use_system_prompt else [self._user_msg(msg)]
|
||||
)
|
||||
rsp = await self.acompletion_text(message, stream=True)
|
||||
message = [self._default_system_msg()]
|
||||
if format_msgs:
|
||||
message.extend(format_msgs)
|
||||
message.append(self._user_msg(msg))
|
||||
rsp = await self.acompletion_text(message, stream=True, generator=generator, timeout=timeout)
|
||||
logger.debug(message)
|
||||
# logger.debug(rsp)
|
||||
return rsp
|
||||
|
|
@ -57,38 +61,38 @@ class BaseGPTAPI(BaseChatbot):
|
|||
def _extract_assistant_rsp(self, context):
|
||||
return "\n".join([i["content"] for i in context if i["role"] == "assistant"])
|
||||
|
||||
def ask_batch(self, msgs: list) -> str:
|
||||
def ask_batch(self, msgs: list, timeout=3) -> str:
|
||||
context = []
|
||||
for msg in msgs:
|
||||
umsg = self._user_msg(msg)
|
||||
context.append(umsg)
|
||||
rsp = self.completion(context)
|
||||
rsp = self.completion(context, timeout=timeout)
|
||||
rsp_text = self.get_choice_text(rsp)
|
||||
context.append(self._assistant_msg(rsp_text))
|
||||
return self._extract_assistant_rsp(context)
|
||||
|
||||
async def aask_batch(self, msgs: list) -> str:
|
||||
async def aask_batch(self, msgs: list, timeout=3) -> str:
|
||||
"""Sequential questioning"""
|
||||
context = []
|
||||
for msg in msgs:
|
||||
umsg = self._user_msg(msg)
|
||||
context.append(umsg)
|
||||
rsp_text = await self.acompletion_text(context)
|
||||
rsp_text = await self.acompletion_text(context, timeout=timeout)
|
||||
context.append(self._assistant_msg(rsp_text))
|
||||
return self._extract_assistant_rsp(context)
|
||||
|
||||
def ask_code(self, msgs: list[str]) -> str:
|
||||
def ask_code(self, msgs: list[str], timeout=3) -> str:
|
||||
"""FIXME: No code segment filtering has been done here, and all results are actually displayed"""
|
||||
rsp_text = self.ask_batch(msgs)
|
||||
rsp_text = self.ask_batch(msgs, timeout=timeout)
|
||||
return rsp_text
|
||||
|
||||
async def aask_code(self, msgs: list[str]) -> str:
|
||||
async def aask_code(self, msgs: list[str], timeout=3) -> str:
|
||||
"""FIXME: No code segment filtering has been done here, and all results are actually displayed"""
|
||||
rsp_text = await self.aask_batch(msgs)
|
||||
rsp_text = await self.aask_batch(msgs, timeout=timeout)
|
||||
return rsp_text
|
||||
|
||||
@abstractmethod
|
||||
def completion(self, messages: list[dict]):
|
||||
def completion(self, messages: list[dict], timeout=3):
|
||||
"""All GPTAPIs are required to provide the standard OpenAI completion interface
|
||||
[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
|
|
@ -98,7 +102,7 @@ class BaseGPTAPI(BaseChatbot):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def acompletion(self, messages: list[dict]):
|
||||
async def acompletion(self, messages: list[dict], timeout=3):
|
||||
"""Asynchronous version of completion
|
||||
All GPTAPIs are required to provide the standard OpenAI completion interface
|
||||
[
|
||||
|
|
@ -109,7 +113,7 @@ class BaseGPTAPI(BaseChatbot):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
|
||||
"""Asynchronous version of completion. Return str. Support stream-print"""
|
||||
|
||||
def get_choice_text(self, rsp: dict) -> str:
|
||||
|
|
@ -145,7 +149,7 @@ class BaseGPTAPI(BaseChatbot):
|
|||
:return dict: return first function of choice, for exmaple,
|
||||
{'name': 'execute', 'arguments': '{\n "language": "python",\n "code": "print(\'Hello, World!\')"\n}'}
|
||||
"""
|
||||
return rsp.get("choices")[0]["message"]["tool_calls"][0]["function"].to_dict()
|
||||
return rsp.get("choices")[0]["message"]["tool_calls"][0]["function"]
|
||||
|
||||
def get_choice_function_arguments(self, rsp: dict) -> dict:
|
||||
"""Required to provide the first function arguments of choice.
|
||||
|
|
@ -163,3 +167,8 @@ class BaseGPTAPI(BaseChatbot):
|
|||
def messages_to_dict(self, messages):
|
||||
"""objects to [{"role": "user", "content": msg}] etc."""
|
||||
return [i.to_dict() for i in messages]
|
||||
|
||||
@abstractmethod
|
||||
async def close(self):
|
||||
"""Close connection"""
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -14,24 +14,32 @@ class HumanProvider(BaseGPTAPI):
|
|||
This enables replacing LLM anywhere in the framework with a human, thus introducing human interaction
|
||||
"""
|
||||
|
||||
def ask(self, msg: str) -> str:
|
||||
def ask(self, msg: str, timeout=3) -> str:
|
||||
logger.info("It's your turn, please type in your response. You may also refer to the context below")
|
||||
rsp = input(msg)
|
||||
if rsp in ["exit", "quit"]:
|
||||
exit()
|
||||
return rsp
|
||||
|
||||
async def aask(self, msg: str, system_msgs: Optional[list[str]] = None) -> str:
|
||||
return self.ask(msg)
|
||||
async def aask(self, msg: str,
|
||||
system_msgs: Optional[list[str]] = None,
|
||||
format_msgs: Optional[list[dict[str, str]]] = None,
|
||||
generator: bool = False,
|
||||
timeout=3,) -> str:
|
||||
return self.ask(msg, timeout=timeout)
|
||||
|
||||
def completion(self, messages: list[dict]):
|
||||
def completion(self, messages: list[dict], timeout=3):
|
||||
"""dummy implementation of abstract method in base"""
|
||||
return []
|
||||
|
||||
async def acompletion(self, messages: list[dict]):
|
||||
async def acompletion(self, messages: list[dict], timeout=3):
|
||||
"""dummy implementation of abstract method in base"""
|
||||
return []
|
||||
|
||||
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
|
||||
"""dummy implementation of abstract method in base"""
|
||||
return []
|
||||
return ""
|
||||
|
||||
async def close(self):
|
||||
"""Close connection"""
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -3,18 +3,23 @@
|
|||
@Time : 2023/5/5 23:08
|
||||
@Author : alexanderwu
|
||||
@File : openai.py
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation;
|
||||
Change cost control from global to company level.
|
||||
@Modified By: mashenquan, 2023/11/21. Fix bug: ReadTimeout.
|
||||
@Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x.
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
from typing import NamedTuple, Union
|
||||
|
||||
import openai
|
||||
from openai.error import APIConnectionError
|
||||
from openai import APIConnectionError, AsyncAzureOpenAI, AsyncOpenAI, RateLimitError
|
||||
from openai.types import CompletionUsage
|
||||
from tenacity import (
|
||||
after_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
wait_fixed,
|
||||
)
|
||||
|
||||
|
|
@ -143,47 +148,31 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.__init_openai(CONFIG)
|
||||
self.llm = openai
|
||||
self.model = CONFIG.openai_api_model
|
||||
self.auto_max_tokens = False
|
||||
self.rpm = int(CONFIG.get("RPM", 10))
|
||||
if CONFIG.openai_api_type == "azure":
|
||||
# https://learn.microsoft.com/zh-cn/azure/ai-services/openai/how-to/migration?tabs=python-new%2Cdalle-fix
|
||||
self._client = AsyncAzureOpenAI(
|
||||
api_key=CONFIG.openai_api_key,
|
||||
api_version=CONFIG.openai_api_version,
|
||||
azure_endpoint=CONFIG.openai_api_base,
|
||||
)
|
||||
else:
|
||||
# https://github.com/openai/openai-python#async-usage
|
||||
self._client = AsyncOpenAI(api_key=CONFIG.openai_api_key, base_url=CONFIG.openai_api_base)
|
||||
self._cost_manager = CostManager()
|
||||
RateLimiter.__init__(self, rpm=self.rpm)
|
||||
|
||||
def __init_openai(self, config):
|
||||
openai.api_key = config.openai_api_key
|
||||
if config.openai_api_base:
|
||||
openai.api_base = config.openai_api_base
|
||||
if config.openai_api_type:
|
||||
openai.api_type = config.openai_api_type
|
||||
openai.api_version = config.openai_api_version
|
||||
if config.openai_proxy:
|
||||
openai.proxy = config.openai_proxy
|
||||
self.rpm = int(config.get("RPM", 10))
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict]) -> str:
|
||||
response = await openai.ChatCompletion.acreate(**self._cons_kwargs(messages), stream=True)
|
||||
|
||||
# create variables to collect the stream of chunks
|
||||
collected_chunks = []
|
||||
collected_messages = []
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
|
||||
kwargs = self._cons_kwargs(messages, timeout=timeout)
|
||||
response = await self._client.chat.completions.create(**kwargs, stream=True)
|
||||
# iterate through the stream of events
|
||||
async for chunk in response:
|
||||
collected_chunks.append(chunk) # save the event response
|
||||
choices = chunk["choices"]
|
||||
if len(choices) > 0:
|
||||
chunk_message = chunk["choices"][0].get("delta", {}) # extract the message
|
||||
collected_messages.append(chunk_message) # save the message
|
||||
if "content" in chunk_message:
|
||||
print(chunk_message["content"], end="")
|
||||
print()
|
||||
chunk_message = chunk.choices[0].delta.content or "" # extract the message
|
||||
yield chunk_message
|
||||
|
||||
full_reply_content = "".join([m.get("content", "") for m in collected_messages])
|
||||
usage = self._calc_usage(messages, full_reply_content)
|
||||
self._update_costs(usage)
|
||||
return full_reply_content
|
||||
|
||||
def _cons_kwargs(self, messages: list[dict], **configs) -> dict:
|
||||
def _cons_kwargs(self, messages: list[dict], timeout=3, **configs) -> dict:
|
||||
kwargs = {
|
||||
"messages": messages,
|
||||
"max_tokens": self.get_max_tokens(messages),
|
||||
|
|
@ -196,39 +185,27 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
kwargs.update(configs)
|
||||
|
||||
if CONFIG.openai_api_type == "azure":
|
||||
if CONFIG.deployment_name and CONFIG.deployment_id:
|
||||
raise ValueError("You can only use one of the `deployment_id` or `deployment_name` model")
|
||||
elif not CONFIG.deployment_name and not CONFIG.deployment_id:
|
||||
raise ValueError("You must specify `DEPLOYMENT_NAME` or `DEPLOYMENT_ID` parameter")
|
||||
kwargs_mode = (
|
||||
{"engine": CONFIG.deployment_name}
|
||||
if CONFIG.deployment_name
|
||||
else {"deployment_id": CONFIG.deployment_id}
|
||||
)
|
||||
kwargs["model"] = CONFIG.deployment_id
|
||||
else:
|
||||
kwargs_mode = {"model": self.model}
|
||||
kwargs.update(kwargs_mode)
|
||||
kwargs["model"] = self.model
|
||||
kwargs["timeout"] = max(CONFIG.TIMEOUT, timeout) if CONFIG.TIMEOUT is not None else timeout
|
||||
|
||||
return kwargs
|
||||
|
||||
async def _achat_completion(self, messages: list[dict]) -> dict:
|
||||
rsp = await self.llm.ChatCompletion.acreate(**self._cons_kwargs(messages))
|
||||
self._update_costs(rsp.get("usage"))
|
||||
return rsp
|
||||
async def _achat_completion(self, messages: list[dict], timeout=3) -> dict:
|
||||
kwargs = self._cons_kwargs(messages, timeout=timeout)
|
||||
rsp = await self._client.chat.completions.create(**kwargs)
|
||||
self._update_costs(rsp.usage)
|
||||
return rsp.dict()
|
||||
|
||||
def _chat_completion(self, messages: list[dict]) -> dict:
|
||||
rsp = self.llm.ChatCompletion.create(**self._cons_kwargs(messages))
|
||||
self._update_costs(rsp)
|
||||
return rsp
|
||||
def completion(self, messages: list[dict], timeout=3) -> dict:
|
||||
loop = self.get_event_loop()
|
||||
return loop.run_until_complete(self.acompletion(messages, timeout=timeout))
|
||||
|
||||
def completion(self, messages: list[dict]) -> dict:
|
||||
async def acompletion(self, messages: list[dict], timeout=3) -> dict:
|
||||
# if isinstance(messages[0], Message):
|
||||
# messages = self.messages_to_dict(messages)
|
||||
return self._chat_completion(messages)
|
||||
|
||||
async def acompletion(self, messages: list[dict]) -> dict:
|
||||
# if isinstance(messages[0], Message):
|
||||
# messages = self.messages_to_dict(messages)
|
||||
return await self._achat_completion(messages)
|
||||
return await self._achat_completion(messages, timeout=timeout)
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
|
|
@ -237,14 +214,34 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
retry=retry_if_exception_type(APIConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
|
||||
@retry(
|
||||
stop=stop_after_attempt(6),
|
||||
wait=wait_exponential(1),
|
||||
after=after_log(logger, logger.level("WARNING").name),
|
||||
retry=retry_if_exception_type(RateLimitError),
|
||||
reraise=True,
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
|
||||
"""when streaming, print each token in place."""
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages)
|
||||
rsp = await self._achat_completion(messages)
|
||||
resp = self._achat_completion_stream(messages, timeout=timeout)
|
||||
if generator:
|
||||
return resp
|
||||
|
||||
collected_messages = []
|
||||
async for i in resp:
|
||||
print(i, end="")
|
||||
collected_messages.append(i)
|
||||
|
||||
full_reply_content = "".join(collected_messages)
|
||||
usage = self._calc_usage(messages, full_reply_content)
|
||||
self._update_costs(usage)
|
||||
return full_reply_content
|
||||
|
||||
rsp = await self._achat_completion(messages, timeout=timeout)
|
||||
return self.get_choice_text(rsp)
|
||||
|
||||
def _func_configs(self, messages: list[dict], **kwargs) -> dict:
|
||||
def _func_configs(self, messages: list[dict], timeout=3, **kwargs) -> dict:
|
||||
"""
|
||||
Note: Keep kwargs consistent with the parameters in the https://platform.openai.com/docs/api-reference/chat/create
|
||||
"""
|
||||
|
|
@ -255,17 +252,17 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
}
|
||||
kwargs.update(configs)
|
||||
|
||||
return self._cons_kwargs(messages, **kwargs)
|
||||
return self._cons_kwargs(messages=messages, timeout=timeout, **kwargs)
|
||||
|
||||
def _chat_completion_function(self, messages: list[dict], **kwargs) -> dict:
|
||||
rsp = self.llm.ChatCompletion.create(**self._func_configs(messages, **kwargs))
|
||||
self._update_costs(rsp.get("usage"))
|
||||
return rsp
|
||||
def _chat_completion_function(self, messages: list[dict], timeout=3, **kwargs) -> dict:
|
||||
loop = self.get_event_loop()
|
||||
return loop.run_until_complete(self._achat_completion_function(messages=messages, timeout=timeout, **kwargs))
|
||||
|
||||
async def _achat_completion_function(self, messages: list[dict], **chat_configs) -> dict:
|
||||
rsp = await self.llm.ChatCompletion.acreate(**self._func_configs(messages, **chat_configs))
|
||||
self._update_costs(rsp.get("usage"))
|
||||
return rsp
|
||||
async def _achat_completion_function(self, messages: list[dict], timeout=3, **chat_configs) -> dict:
|
||||
kwargs = self._func_configs(messages=messages, timeout=timeout, **chat_configs)
|
||||
rsp = await self._client.chat.completions.create(**kwargs)
|
||||
self._update_costs(rsp.usage)
|
||||
return rsp.dict()
|
||||
|
||||
def _process_message(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
|
||||
"""convert messages to list[dict]."""
|
||||
|
|
@ -319,21 +316,22 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
rsp = await self._achat_completion_function(messages, **kwargs)
|
||||
return self.get_choice_function_arguments(rsp)
|
||||
|
||||
def _calc_usage(self, messages: list[dict], rsp: str) -> dict:
|
||||
usage = {}
|
||||
def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
|
||||
if CONFIG.calc_usage:
|
||||
try:
|
||||
prompt_tokens = count_message_tokens(messages, self.model)
|
||||
completion_tokens = count_string_tokens(rsp, self.model)
|
||||
usage["prompt_tokens"] = prompt_tokens
|
||||
usage["completion_tokens"] = completion_tokens
|
||||
usage = CompletionUsage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
return usage
|
||||
except Exception as e:
|
||||
logger.error("usage calculation failed!", e)
|
||||
else:
|
||||
return usage
|
||||
return CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
|
||||
|
||||
async def acompletion_batch(self, batch: list[list[dict]]) -> list[dict]:
|
||||
async def acompletion_batch(self, batch: list[list[dict]], timeout=3) -> list[dict]:
|
||||
"""Return full JSON"""
|
||||
split_batches = self.split_batches(batch)
|
||||
all_results = []
|
||||
|
|
@ -342,16 +340,16 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
logger.info(small_batch)
|
||||
await self.wait_if_needed(len(small_batch))
|
||||
|
||||
future = [self.acompletion(prompt) for prompt in small_batch]
|
||||
future = [self.acompletion(prompt, timeout=timeout) for prompt in small_batch]
|
||||
results = await asyncio.gather(*future)
|
||||
logger.info(results)
|
||||
all_results.extend(results)
|
||||
|
||||
return all_results
|
||||
|
||||
async def acompletion_batch_text(self, batch: list[list[dict]]) -> list[str]:
|
||||
async def acompletion_batch_text(self, batch: list[list[dict]], timeout=3) -> list[str]:
|
||||
"""Only return plain text"""
|
||||
raw_results = await self.acompletion_batch(batch)
|
||||
raw_results = await self.acompletion_batch(batch, timeout=timeout)
|
||||
results = []
|
||||
for idx, raw_result in enumerate(raw_results, start=1):
|
||||
result = self.get_choice_text(raw_result)
|
||||
|
|
@ -359,14 +357,11 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
logger.info(f"Result of task {idx}: {result}")
|
||||
return results
|
||||
|
||||
def _update_costs(self, usage: dict):
|
||||
def _update_costs(self, usage: CompletionUsage):
|
||||
if CONFIG.calc_usage:
|
||||
try:
|
||||
prompt_tokens = int(usage["prompt_tokens"])
|
||||
completion_tokens = int(usage["completion_tokens"])
|
||||
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error("updating costs failed!", e)
|
||||
prompt_tokens = usage.prompt_tokens
|
||||
completion_tokens = usage.completion_tokens
|
||||
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
return self._cost_manager.get_costs()
|
||||
|
|
@ -377,18 +372,8 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp)
|
||||
|
||||
def moderation(self, content: Union[str, list[str]]):
|
||||
try:
|
||||
if not content:
|
||||
logger.error("content cannot be empty!")
|
||||
else:
|
||||
rsp = self._moderation(content=content)
|
||||
return rsp
|
||||
except Exception as e:
|
||||
logger.error(f"moderating failed:{e}")
|
||||
|
||||
def _moderation(self, content: Union[str, list[str]]):
|
||||
rsp = self.llm.Moderation.create(input=content)
|
||||
return rsp
|
||||
loop = self.get_event_loop()
|
||||
loop.run_until_complete(self.amoderation(content=content))
|
||||
|
||||
async def amoderation(self, content: Union[str, list[str]]):
|
||||
try:
|
||||
|
|
@ -401,5 +386,25 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
logger.error(f"moderating failed:{e}")
|
||||
|
||||
async def _amoderation(self, content: Union[str, list[str]]):
|
||||
rsp = await self.llm.Moderation.acreate(input=content)
|
||||
rsp = await self._client.moderations.create(input=content)
|
||||
return rsp
|
||||
|
||||
async def close(self):
|
||||
"""Close connection"""
|
||||
if not self._client:
|
||||
return
|
||||
await self._client.close()
|
||||
self._client = None
|
||||
|
||||
@staticmethod
|
||||
def get_event_loop():
|
||||
try:
|
||||
return asyncio.get_event_loop()
|
||||
except RuntimeError as e:
|
||||
if "There is no current event loop in thread" in str(e):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
return loop
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ langchain==0.0.231
|
|||
loguru==0.6.0
|
||||
meilisearch==0.21.0
|
||||
numpy==1.24.3
|
||||
openai>=0.28.1
|
||||
openai>=1.3.6
|
||||
openpyxl
|
||||
beautifulsoup4==4.12.2
|
||||
pandas==2.0.3
|
||||
|
|
@ -42,9 +42,13 @@ qdrant-client==1.4.0
|
|||
pytest-mock==3.11.1
|
||||
open-interpreter==0.1.7; python_version>"3.9"
|
||||
ta==0.10.2
|
||||
semantic-kernel==0.3.13.dev0
|
||||
semantic-kernel
|
||||
wrapt==1.15.0
|
||||
websocket-client==0.58.0
|
||||
#aiohttp_jinja2
|
||||
#azure-cognitiveservices-speech~=1.31.0
|
||||
#aioboto3~=11.3.0
|
||||
#redis==4.3.5
|
||||
websocket-client==1.6.2
|
||||
aiofiles==23.2.1
|
||||
gitpython==3.1.40
|
||||
zhipuai==1.0.7
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue