add openai.Completion support

This commit is contained in:
better629 2023-10-09 13:52:30 +08:00
parent 6c947945d7
commit 81f716b57e
3 changed files with 36 additions and 2 deletions

View file

@ -36,6 +36,10 @@ class BaseGPTAPI(BaseChatbot):
rsp = self.completion(message)
return self.get_choice_text(rsp)
def ask_nonchat(self, prompt: str) -> str:
rsp = self.nonchat_completion(prompt)
return self.get_nonchat_choice_text(rsp)
async def aask(self, msg: str, system_msgs: Optional[list[str]] = None) -> str:
if system_msgs:
message = self._system_msgs(system_msgs) + [self._user_msg(msg)]
@ -89,6 +93,10 @@ class BaseGPTAPI(BaseChatbot):
]
"""
@abstractmethod
def nonchat_completion(self, prompt: str) -> dict:
""" for openai.Completion request """
@abstractmethod
async def acompletion(self, messages: list[dict]):
"""Asynchronous version of completion
@ -108,6 +116,10 @@ class BaseGPTAPI(BaseChatbot):
"""Required to provide the first text of choice"""
return rsp.get("choices")[0]["message"]["content"]
def get_nonchat_choice_text(self, rsp: dict) -> str:
""" for openai.Completion """
return rsp.get("choices")[0]["text"]
def messages_to_prompt(self, messages: list[dict]):
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
return '\n'.join([f"{i['role']}: {i['content']}" for i in messages])

View file

@ -205,6 +205,17 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
kwargs.update(kwargs_mode)
return kwargs
def _nonchat_cons_kwargs(self, prompt: str) -> dict:
kwargs = {
"model": self.model,
"prompt": prompt,
"max_tokens": self.get_max_tokens(prompt), # TODO adapt if auto_max_tokens is True
"n": 1,
"stop": None,
"temperature": 0.3
}
return kwargs
async def _achat_completion(self, messages: list[dict]) -> dict:
rsp = await self.llm.ChatCompletion.acreate(**self._cons_kwargs(messages))
self._update_costs(rsp.get("usage"))
@ -220,6 +231,14 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
# messages = self.messages_to_dict(messages)
return self._chat_completion(messages)
def _nonchat_completion(self, prompt: str) -> dict:
rsp = self.llm.Completion.create(**self._nonchat_cons_kwargs(prompt))
self._update_costs(rsp.get("usage"))
return rsp
def nonchat_completion(self, prompt: str) -> dict:
return self._nonchat_completion(prompt)
async def acompletion(self, messages: list[dict]) -> dict:
# if isinstance(messages[0], Message):
# messages = self.messages_to_dict(messages)
@ -249,7 +268,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
usage["completion_tokens"] = completion_tokens
return usage
except Exception as e:
logger.error("usage calculation failed! {e}")
logger.error(f"usage calculation failed! {e}")
else:
return usage
@ -286,7 +305,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
completion_tokens = int(usage["completion_tokens"])
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error("updating costs failed! {e}")
logger.error(f"updating costs failed! {e}")
def get_costs(self) -> Costs:
return self._cost_manager.get_costs()

View file

@ -11,6 +11,7 @@ ref3: https://github.com/hwchase17/langchain/blob/master/langchain/chat_models/o
import tiktoken
TOKEN_COSTS = {
"gpt-3.5-turbo-instruct": {"prompt": 0.0015, "completion": 0.002},
"gpt-3.5-turbo": {"prompt": 0.0015, "completion": 0.002},
"gpt-3.5-turbo-0301": {"prompt": 0.0015, "completion": 0.002},
"gpt-3.5-turbo-0613": {"prompt": 0.0015, "completion": 0.002},
@ -26,6 +27,7 @@ TOKEN_COSTS = {
TOKEN_MAX = {
"gpt-3.5-turbo-instruct": 4096,
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-0301": 4096,
"gpt-3.5-turbo-0613": 4096,
@ -48,6 +50,7 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
print("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
if model in {
"gpt-3.5-turbo-instruct",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-4-0314",