mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-05 14:55:18 +02:00
remove sync api in openai
This commit is contained in:
parent
bb1b9823d0
commit
4007fc87d6
4 changed files with 16 additions and 75 deletions
|
|
@ -90,7 +90,6 @@ class BaseGPTAPI(BaseChatbot):
|
|||
rsp_text = await self.aask_batch(msgs, timeout=timeout)
|
||||
return rsp_text
|
||||
|
||||
@abstractmethod
|
||||
def completion(self, messages: list[dict], timeout=3):
|
||||
"""All GPTAPIs are required to provide the standard OpenAI completion interface
|
||||
[
|
||||
|
|
@ -166,8 +165,3 @@ class BaseGPTAPI(BaseChatbot):
|
|||
def messages_to_dict(self, messages):
|
||||
"""objects to [{"role": "user", "content": msg}] etc."""
|
||||
return [i.to_dict() for i in messages]
|
||||
|
||||
@abstractmethod
|
||||
async def close(self):
|
||||
"""Close connection"""
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -84,20 +84,21 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
def __init__(self):
|
||||
self.config: Config = CONFIG
|
||||
self._init_openai()
|
||||
self._init_client()
|
||||
self.auto_max_tokens = False
|
||||
RateLimiter.__init__(self, rpm=self.rpm)
|
||||
super().__init__()
|
||||
|
||||
def _init_openai(self):
|
||||
self.rpm = int(self.config.openai_api_rpm)
|
||||
self._init_client()
|
||||
|
||||
def _init_client(self):
|
||||
kwargs = self._make_client_kwargs()
|
||||
# https://github.com/openai/openai-python#async-usage
|
||||
self.aclient = AsyncOpenAI(**kwargs)
|
||||
self.model = self.config.OPENAI_API_MODEL # Used in _calc_usage & _cons_kwargs
|
||||
|
||||
def _make_client_kwargs(self) -> (dict, dict):
|
||||
def _init_client(self):
|
||||
"""https://github.com/openai/openai-python#async-usage"""
|
||||
kwargs = self._make_client_kwargs()
|
||||
self.aclient = AsyncOpenAI(**kwargs)
|
||||
|
||||
def _make_client_kwargs(self) -> dict:
|
||||
kwargs = {"api_key": self.config.OPENAI_API_KEY, "base_url": self.config.OPENAI_BASE_URL}
|
||||
|
||||
# to use proxy, openai v1 needs http_client
|
||||
|
|
@ -124,19 +125,18 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
chunk_message = chunk.choices[0].delta.content or "" if chunk.choices else "" # extract the message
|
||||
yield chunk_message
|
||||
|
||||
def _cons_kwargs(self, messages: list[dict], timeout=3, **configs) -> dict:
|
||||
def _cons_kwargs(self, messages: list[dict], timeout=3, **extra_kwargs) -> dict:
|
||||
kwargs = {
|
||||
"messages": messages,
|
||||
"max_tokens": self.get_max_tokens(messages),
|
||||
"max_tokens": self._get_max_tokens(messages),
|
||||
"n": 1,
|
||||
"stop": None,
|
||||
"temperature": 0.3,
|
||||
"model": self.model,
|
||||
"timeout": max(CONFIG.timeout, timeout),
|
||||
}
|
||||
if configs:
|
||||
kwargs.update(configs)
|
||||
kwargs["timeout"] = max(CONFIG.timeout, timeout)
|
||||
|
||||
if extra_kwargs:
|
||||
kwargs.update(extra_kwargs)
|
||||
return kwargs
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout=3) -> ChatCompletion:
|
||||
|
|
@ -242,36 +242,10 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
usage.prompt_tokens = count_message_tokens(messages, self.model)
|
||||
usage.completion_tokens = count_string_tokens(rsp, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"usage calculation failed!: {e}")
|
||||
logger.error(f"usage calculation failed: {e}")
|
||||
|
||||
return usage
|
||||
|
||||
async def acompletion_batch(self, batch: list[list[dict]], timeout=3) -> list[ChatCompletion]:
|
||||
"""Return full JSON"""
|
||||
split_batches = self.split_batches(batch)
|
||||
all_results = []
|
||||
|
||||
for small_batch in split_batches:
|
||||
logger.info(small_batch)
|
||||
await self.wait_if_needed(len(small_batch))
|
||||
|
||||
future = [self.acompletion(prompt, timeout=timeout) for prompt in small_batch]
|
||||
results = await asyncio.gather(*future)
|
||||
logger.info(results)
|
||||
all_results.extend(results)
|
||||
|
||||
return all_results
|
||||
|
||||
async def acompletion_batch_text(self, batch: list[list[dict]], timeout=3) -> list[str]:
|
||||
"""Only return plain text"""
|
||||
raw_results = await self.acompletion_batch(batch, timeout=timeout)
|
||||
results = []
|
||||
for idx, raw_result in enumerate(raw_results, start=1):
|
||||
result = self.get_choice_text(raw_result)
|
||||
results.append(result)
|
||||
logger.info(f"Result of task {idx}: {result}")
|
||||
return results
|
||||
|
||||
@handle_exception
|
||||
def _update_costs(self, usage: CompletionUsage):
|
||||
if CONFIG.calc_usage and usage:
|
||||
|
|
@ -280,11 +254,12 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
def get_costs(self) -> Costs:
|
||||
return CONFIG.cost_manager.get_costs()
|
||||
|
||||
def get_max_tokens(self, messages: list[dict]):
|
||||
def _get_max_tokens(self, messages: list[dict]):
|
||||
if not self.auto_max_tokens:
|
||||
return CONFIG.max_tokens_rsp
|
||||
return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp)
|
||||
|
||||
@handle_exception
|
||||
async def amoderation(self, content: Union[str, list[str]]):
|
||||
"""Moderate content."""
|
||||
return await self.aclient.moderations.create(input=content)
|
||||
|
|
|
|||
|
|
@ -23,18 +23,11 @@ async def test_llm_aask(llm):
|
|||
assert len(rsp) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_aask_batch(llm):
|
||||
assert len(await llm.aask_batch(["hi", "write python hello world."])) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_acompletion(llm):
|
||||
hello_msg = [{"role": "user", "content": "hello"}]
|
||||
rsp = await llm.acompletion(hello_msg)
|
||||
assert len(rsp.choices[0].message.content) > 0
|
||||
assert len(await llm.acompletion_batch([hello_msg])) > 0
|
||||
assert len(await llm.acompletion_batch_text([hello_msg])) > 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1,21 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/5/7 17:23
|
||||
@Author : alexanderwu
|
||||
@File : test_custom_aio_session.py
|
||||
"""
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI
|
||||
|
||||
|
||||
async def try_hello(api):
|
||||
batch = [[{"role": "user", "content": "hello"}]]
|
||||
results = await api.acompletion_batch_text(batch)
|
||||
return results
|
||||
|
||||
|
||||
async def aask_batch(api: OpenAIGPTAPI):
|
||||
results = await api.aask_batch(["hi", "write python hello world."])
|
||||
logger.info(results)
|
||||
return results
|
||||
Loading…
Add table
Add a link
Reference in a new issue