remove sync api in openai

This commit is contained in:
geekan 2023-12-26 16:33:15 +08:00
parent bb1b9823d0
commit 4007fc87d6
4 changed files with 16 additions and 75 deletions

View file

@ -90,7 +90,6 @@ class BaseGPTAPI(BaseChatbot):
rsp_text = await self.aask_batch(msgs, timeout=timeout)
return rsp_text
@abstractmethod
def completion(self, messages: list[dict], timeout=3):
"""All GPTAPIs are required to provide the standard OpenAI completion interface
[
@ -166,8 +165,3 @@ class BaseGPTAPI(BaseChatbot):
def messages_to_dict(self, messages):
"""objects to [{"role": "user", "content": msg}] etc."""
return [i.to_dict() for i in messages]
@abstractmethod
async def close(self):
"""Close connection"""
pass

View file

@ -84,20 +84,21 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
def __init__(self):
self.config: Config = CONFIG
self._init_openai()
self._init_client()
self.auto_max_tokens = False
RateLimiter.__init__(self, rpm=self.rpm)
super().__init__()
def _init_openai(self):
self.rpm = int(self.config.openai_api_rpm)
self._init_client()
def _init_client(self):
kwargs = self._make_client_kwargs()
# https://github.com/openai/openai-python#async-usage
self.aclient = AsyncOpenAI(**kwargs)
self.model = self.config.OPENAI_API_MODEL # Used in _calc_usage & _cons_kwargs
def _make_client_kwargs(self) -> (dict, dict):
def _init_client(self):
"""https://github.com/openai/openai-python#async-usage"""
kwargs = self._make_client_kwargs()
self.aclient = AsyncOpenAI(**kwargs)
def _make_client_kwargs(self) -> dict:
kwargs = {"api_key": self.config.OPENAI_API_KEY, "base_url": self.config.OPENAI_BASE_URL}
# to use proxy, openai v1 needs http_client
@ -124,19 +125,18 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
chunk_message = chunk.choices[0].delta.content or "" if chunk.choices else "" # extract the message
yield chunk_message
def _cons_kwargs(self, messages: list[dict], timeout=3, **configs) -> dict:
def _cons_kwargs(self, messages: list[dict], timeout=3, **extra_kwargs) -> dict:
kwargs = {
"messages": messages,
"max_tokens": self.get_max_tokens(messages),
"max_tokens": self._get_max_tokens(messages),
"n": 1,
"stop": None,
"temperature": 0.3,
"model": self.model,
"timeout": max(CONFIG.timeout, timeout),
}
if configs:
kwargs.update(configs)
kwargs["timeout"] = max(CONFIG.timeout, timeout)
if extra_kwargs:
kwargs.update(extra_kwargs)
return kwargs
async def _achat_completion(self, messages: list[dict], timeout=3) -> ChatCompletion:
@ -242,36 +242,10 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
usage.prompt_tokens = count_message_tokens(messages, self.model)
usage.completion_tokens = count_string_tokens(rsp, self.model)
except Exception as e:
logger.error(f"usage calculation failed!: {e}")
logger.error(f"usage calculation failed: {e}")
return usage
async def acompletion_batch(self, batch: list[list[dict]], timeout=3) -> list[ChatCompletion]:
"""Return full JSON"""
split_batches = self.split_batches(batch)
all_results = []
for small_batch in split_batches:
logger.info(small_batch)
await self.wait_if_needed(len(small_batch))
future = [self.acompletion(prompt, timeout=timeout) for prompt in small_batch]
results = await asyncio.gather(*future)
logger.info(results)
all_results.extend(results)
return all_results
async def acompletion_batch_text(self, batch: list[list[dict]], timeout=3) -> list[str]:
"""Only return plain text"""
raw_results = await self.acompletion_batch(batch, timeout=timeout)
results = []
for idx, raw_result in enumerate(raw_results, start=1):
result = self.get_choice_text(raw_result)
results.append(result)
logger.info(f"Result of task {idx}: {result}")
return results
@handle_exception
def _update_costs(self, usage: CompletionUsage):
if CONFIG.calc_usage and usage:
@ -280,11 +254,12 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
def get_costs(self) -> Costs:
return CONFIG.cost_manager.get_costs()
def get_max_tokens(self, messages: list[dict]):
def _get_max_tokens(self, messages: list[dict]):
if not self.auto_max_tokens:
return CONFIG.max_tokens_rsp
return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp)
@handle_exception
async def amoderation(self, content: Union[str, list[str]]):
"""Moderate content."""
return await self.aclient.moderations.create(input=content)

View file

@ -23,18 +23,11 @@ async def test_llm_aask(llm):
assert len(rsp) > 0
@pytest.mark.asyncio
async def test_llm_aask_batch(llm):
assert len(await llm.aask_batch(["hi", "write python hello world."])) > 0
@pytest.mark.asyncio
async def test_llm_acompletion(llm):
hello_msg = [{"role": "user", "content": "hello"}]
rsp = await llm.acompletion(hello_msg)
assert len(rsp.choices[0].message.content) > 0
assert len(await llm.acompletion_batch([hello_msg])) > 0
assert len(await llm.acompletion_batch_text([hello_msg])) > 0
if __name__ == "__main__":

View file

@ -1,21 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/5/7 17:23
@Author : alexanderwu
@File : test_custom_aio_session.py
"""
from metagpt.logs import logger
from metagpt.provider.openai_api import OpenAIGPTAPI
async def try_hello(api):
batch = [[{"role": "user", "content": "hello"}]]
results = await api.acompletion_batch_text(batch)
return results
async def aask_batch(api: OpenAIGPTAPI):
results = await api.aask_batch(["hi", "write python hello world."])
logger.info(results)
return results