mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-17 15:35:21 +02:00
add openai.Completion support
This commit is contained in:
parent
6c947945d7
commit
81f716b57e
3 changed files with 36 additions and 2 deletions
|
|
@ -36,6 +36,10 @@ class BaseGPTAPI(BaseChatbot):
|
|||
rsp = self.completion(message)
|
||||
return self.get_choice_text(rsp)
|
||||
|
||||
def ask_nonchat(self, prompt: str) -> str:
    """Send ``prompt`` as a non-chat completion request and return its text.

    Args:
        prompt: Prompt string passed unchanged to ``self.nonchat_completion``.

    Returns:
        Whatever ``self.get_nonchat_choice_text`` extracts from the response.
    """
    rsp = self.nonchat_completion(prompt)
    return self.get_nonchat_choice_text(rsp)
|
||||
|
||||
async def aask(self, msg: str, system_msgs: Optional[list[str]] = None) -> str:
|
||||
if system_msgs:
|
||||
message = self._system_msgs(system_msgs) + [self._user_msg(msg)]
|
||||
|
|
@ -89,6 +93,10 @@ class BaseGPTAPI(BaseChatbot):
|
|||
]
|
||||
"""
|
||||
|
||||
@abstractmethod
def nonchat_completion(self, prompt: str) -> dict:
    """Issue an openai.Completion (non-chat) request for ``prompt``.

    Abstract hook: this declaration has no body of its own.

    Args:
        prompt: Prompt string for the completion request.

    Returns:
        The response as a dict, per the annotation.
    """
|
||||
|
||||
@abstractmethod
|
||||
async def acompletion(self, messages: list[dict]):
|
||||
"""Asynchronous version of completion
|
||||
|
|
@ -108,6 +116,10 @@ class BaseGPTAPI(BaseChatbot):
|
|||
"""Required to provide the first text of choice"""
|
||||
return rsp.get("choices")[0]["message"]["content"]
|
||||
|
||||
def get_nonchat_choice_text(self, rsp: dict) -> str:
    """Return the text of the first choice in an openai.Completion response.

    Args:
        rsp: Response dict; the value is read from ``choices[0]["text"]``.

    Returns:
        The ``"text"`` field of the first entry under ``"choices"``.
    """
    return rsp.get("choices")[0]["text"]
|
||||
|
||||
def messages_to_prompt(self, messages: list[dict]) -> str:
    """[{"role": "user", "content": msg}] to user: <msg> etc.

    Args:
        messages: Dicts that each carry a ``"role"`` and a ``"content"`` key.

    Returns:
        One ``"<role>: <content>"`` line per message, joined with newlines;
        an empty string when ``messages`` is empty.
    """
    return '\n'.join(f"{m['role']}: {m['content']}" for m in messages)
|
||||
|
|
|
|||
|
|
@ -205,6 +205,17 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
kwargs.update(kwargs_mode)
|
||||
return kwargs
|
||||
|
||||
def _nonchat_cons_kwargs(self, prompt: str) -> dict:
    """Build the keyword arguments for an openai.Completion request.

    Args:
        prompt: Prompt string placed under the ``"prompt"`` key.

    Returns:
        A dict with ``model`` (from ``self.model``), ``prompt``,
        ``max_tokens`` (from ``self.get_max_tokens(prompt)``), and the fixed
        values ``n=1``, ``stop=None`` and ``temperature=0.3``.
    """
    return {
        "model": self.model,
        "prompt": prompt,
        # TODO adapt if auto_max_tokens is True
        "max_tokens": self.get_max_tokens(prompt),
        "n": 1,
        "stop": None,
        "temperature": 0.3,
    }
|
||||
|
||||
async def _achat_completion(self, messages: list[dict]) -> dict:
|
||||
rsp = await self.llm.ChatCompletion.acreate(**self._cons_kwargs(messages))
|
||||
self._update_costs(rsp.get("usage"))
|
||||
|
|
@ -220,6 +231,14 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
# messages = self.messages_to_dict(messages)
|
||||
return self._chat_completion(messages)
|
||||
|
||||
def _nonchat_completion(self, prompt: str) -> dict:
    """Call ``self.llm.Completion.create`` for ``prompt`` and record usage.

    Args:
        prompt: Prompt string; the request arguments come from
            ``self._nonchat_cons_kwargs(prompt)``.

    Returns:
        The response object returned by ``Completion.create``, unchanged.
    """
    rsp = self.llm.Completion.create(**self._nonchat_cons_kwargs(prompt))
    # Hand the response's "usage" entry (None when absent) to cost tracking.
    self._update_costs(rsp.get("usage"))
    return rsp
|
||||
|
||||
def nonchat_completion(self, prompt: str) -> dict:
    """Run a non-chat completion for ``prompt``.

    Args:
        prompt: Prompt string forwarded to ``self._nonchat_completion``.

    Returns:
        The value returned by ``self._nonchat_completion(prompt)``.
    """
    return self._nonchat_completion(prompt)
|
||||
|
||||
async def acompletion(self, messages: list[dict]) -> dict:
|
||||
# if isinstance(messages[0], Message):
|
||||
# messages = self.messages_to_dict(messages)
|
||||
|
|
@ -249,7 +268,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
usage["completion_tokens"] = completion_tokens
|
||||
return usage
|
||||
except Exception as e:
|
||||
logger.error("usage calculation failed! {e}")
|
||||
logger.error(f"usage calculation failed! {e}")
|
||||
else:
|
||||
return usage
|
||||
|
||||
|
|
@ -286,7 +305,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
completion_tokens = int(usage["completion_tokens"])
|
||||
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error("updating costs failed! {e}")
|
||||
logger.error(f"updating costs failed! {e}")
|
||||
|
||||
def get_costs(self) -> Costs:
    """Return the costs reported by ``self._cost_manager.get_costs()``."""
    return self._cost_manager.get_costs()
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ ref3: https://github.com/hwchase17/langchain/blob/master/langchain/chat_models/o
|
|||
import tiktoken
|
||||
|
||||
TOKEN_COSTS = {
|
||||
"gpt-3.5-turbo-instruct": {"prompt": 0.0015, "completion": 0.002},
|
||||
"gpt-3.5-turbo": {"prompt": 0.0015, "completion": 0.002},
|
||||
"gpt-3.5-turbo-0301": {"prompt": 0.0015, "completion": 0.002},
|
||||
"gpt-3.5-turbo-0613": {"prompt": 0.0015, "completion": 0.002},
|
||||
|
|
@ -26,6 +27,7 @@ TOKEN_COSTS = {
|
|||
|
||||
|
||||
TOKEN_MAX = {
|
||||
"gpt-3.5-turbo-instruct": 4096,
|
||||
"gpt-3.5-turbo": 4096,
|
||||
"gpt-3.5-turbo-0301": 4096,
|
||||
"gpt-3.5-turbo-0613": 4096,
|
||||
|
|
@ -48,6 +50,7 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
|
|||
print("Warning: model not found. Using cl100k_base encoding.")
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
if model in {
|
||||
"gpt-3.5-turbo-instruct",
|
||||
"gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
"gpt-4-0314",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue