diff --git a/metagpt/provider/base_gpt_api.py b/metagpt/provider/base_gpt_api.py
index de61167b9..d6e827fd1 100644
--- a/metagpt/provider/base_gpt_api.py
+++ b/metagpt/provider/base_gpt_api.py
@@ -36,6 +36,10 @@ class BaseGPTAPI(BaseChatbot):
         rsp = self.completion(message)
         return self.get_choice_text(rsp)
 
+    def ask_nonchat(self, prompt: str) -> str:
+        rsp = self.nonchat_completion(prompt)
+        return self.get_nonchat_choice_text(rsp)
+
     async def aask(self, msg: str, system_msgs: Optional[list[str]] = None) -> str:
         if system_msgs:
             message = self._system_msgs(system_msgs) + [self._user_msg(msg)]
@@ -89,6 +93,10 @@ class BaseGPTAPI(BaseChatbot):
         ]
         """
 
+    @abstractmethod
+    def nonchat_completion(self, prompt: str) -> dict:
+        """ for openai.Completion request """
+
     @abstractmethod
     async def acompletion(self, messages: list[dict]):
         """Asynchronous version of completion
@@ -108,6 +116,10 @@ class BaseGPTAPI(BaseChatbot):
         """Required to provide the first text of choice"""
         return rsp.get("choices")[0]["message"]["content"]
 
+    def get_nonchat_choice_text(self, rsp: dict) -> str:
+        """ for openai.Completion """
+        return rsp.get("choices")[0]["text"]
+
     def messages_to_prompt(self, messages: list[dict]):
         """[{"role": "user", "content": msg}] to user: etc."""
         return '\n'.join([f"{i['role']}: {i['content']}" for i in messages])
diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py
index 461986f16..77403dac8 100644
--- a/metagpt/provider/openai_api.py
+++ b/metagpt/provider/openai_api.py
@@ -205,6 +205,17 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
         kwargs.update(kwargs_mode)
         return kwargs
 
+    def _nonchat_cons_kwargs(self, prompt: str) -> dict:
+        kwargs = {
+            "model": self.model,
+            "prompt": prompt,
+            "max_tokens": self.get_max_tokens(prompt),  # TODO adapt if auto_max_tokens is True
+            "n": 1,
+            "stop": None,
+            "temperature": 0.3
+        }
+        return kwargs
+
     async def _achat_completion(self, messages: list[dict]) -> dict:
         rsp = await self.llm.ChatCompletion.acreate(**self._cons_kwargs(messages))
         self._update_costs(rsp.get("usage"))
@@ -220,6 +231,14 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
         # messages = self.messages_to_dict(messages)
         return self._chat_completion(messages)
 
+    def _nonchat_completion(self, prompt: str) -> dict:
+        rsp = self.llm.Completion.create(**self._nonchat_cons_kwargs(prompt))
+        self._update_costs(rsp.get("usage"))
+        return rsp
+
+    def nonchat_completion(self, prompt: str) -> dict:
+        return self._nonchat_completion(prompt)
+
     async def acompletion(self, messages: list[dict]) -> dict:
         # if isinstance(messages[0], Message):
         #     messages = self.messages_to_dict(messages)
@@ -249,7 +268,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
                 usage["completion_tokens"] = completion_tokens
                 return usage
             except Exception as e:
-                logger.error("usage calculation failed! {e}")
+                logger.error(f"usage calculation failed! {e}")
         else:
             return usage
 
@@ -286,7 +305,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
                 completion_tokens = int(usage["completion_tokens"])
                 self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
             except Exception as e:
-                logger.error("updating costs failed! {e}")
+                logger.error(f"updating costs failed! {e}")
 
     def get_costs(self) -> Costs:
         return self._cost_manager.get_costs()
diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py
index a5a65803a..fa451014e 100644
--- a/metagpt/utils/token_counter.py
+++ b/metagpt/utils/token_counter.py
@@ -11,6 +11,7 @@ ref3: https://github.com/hwchase17/langchain/blob/master/langchain/chat_models/o
 import tiktoken
 
 TOKEN_COSTS = {
+    "gpt-3.5-turbo-instruct": {"prompt": 0.0015, "completion": 0.002},
     "gpt-3.5-turbo": {"prompt": 0.0015, "completion": 0.002},
     "gpt-3.5-turbo-0301": {"prompt": 0.0015, "completion": 0.002},
     "gpt-3.5-turbo-0613": {"prompt": 0.0015, "completion": 0.002},
@@ -26,6 +27,7 @@ TOKEN_COSTS = {
 
 
 TOKEN_MAX = {
+    "gpt-3.5-turbo-instruct": 4096,
     "gpt-3.5-turbo": 4096,
     "gpt-3.5-turbo-0301": 4096,
     "gpt-3.5-turbo-0613": 4096,
@@ -48,6 +50,7 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
         print("Warning: model not found. Using cl100k_base encoding.")
         encoding = tiktoken.get_encoding("cl100k_base")
     if model in {
+        "gpt-3.5-turbo-instruct",
         "gpt-3.5-turbo-0613",
         "gpt-3.5-turbo-16k-0613",
         "gpt-4-0314",