mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-11 15:15:18 +02:00
Add files via upload
This commit is contained in:
parent
fbd5d65e6b
commit
8f267b33d2
2 changed files with 36 additions and 23 deletions
|
|
@ -21,11 +21,15 @@ class LLMProviderRegistry:
|
|||
return self.providers[enum]
|
||||
|
||||
|
||||
def register_provider(keys):
    """Return a class decorator that registers the class in ``LLM_REGISTRY``.

    Args:
        keys: Either a single registry key, or a ``list`` / ``tuple`` /
            ``set`` / ``frozenset`` of keys that should all map to the
            decorated class.

    Returns:
        A decorator that registers the class under every key and returns the
        class unchanged.
    """
    # Only concrete containers are expanded. A generic "is it iterable?" check
    # would wrongly split string keys into characters.
    if isinstance(keys, (list, tuple, set, frozenset)):
        key_group = tuple(keys)
    else:
        key_group = (keys,)

    def decorator(cls):
        for key in key_group:
            LLM_REGISTRY.register(key, cls)
        return cls

    return decorator
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
|
|||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import CodeParser, decode_image
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.cost_manager import CostManager, Costs, TokenCostManager
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
from metagpt.utils.token_counter import (
|
||||
count_message_tokens,
|
||||
|
|
@ -50,7 +50,7 @@ See FAQ 5.8
|
|||
raise retry_state.outcome.exception()
|
||||
|
||||
|
||||
@register_provider(LLMType.OPENAI)
|
||||
@register_provider([LLMType.OPENAI, LLMType.FIREWORKS, LLMType.OPEN_LLM, LLMType.MOONSHOT])
|
||||
class OpenAILLM(BaseLLM):
|
||||
"""Check https://platform.openai.com/examples for examples"""
|
||||
|
||||
|
|
@ -84,14 +84,33 @@ class OpenAILLM(BaseLLM):
|
|||
|
||||
return params
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
    """Stream a chat completion, logging each piece, and return the full reply.

    Args:
        messages: Chat messages forwarded to ``self._cons_kwargs``.
        timeout: Timeout value forwarded to ``self._cons_kwargs``.

    Returns:
        The concatenation of every streamed content delta.

    Token usage reported by the service on the finishing chunk is used when
    present; otherwise it is computed with ``self._calc_usage``. In both cases
    the usage is passed to ``self._update_costs``.
    """
    response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
        **self._cons_kwargs(messages, timeout=timeout), stream=True
    )

    usage = None
    collected_messages = []
    async for chunk in response:
        # A chunk may arrive with an empty ``choices`` list; never index into it blindly.
        choice = chunk.choices[0] if chunk.choices else None
        chunk_message = (choice.delta.content or "") if choice is not None else ""  # extract the message
        finish_reason = getattr(choice, "finish_reason", None)
        log_llm_stream(chunk_message)
        collected_messages.append(chunk_message)

        if finish_reason:
            # Some services have usage as an attribute of the chunk, such as Fireworks
            reported = getattr(chunk, "usage", None)
            if reported is None:
                # The usage of some services is an attribute of chunk.choices[0], such as Moonshot
                reported = getattr(choice, "usage", None)
            if reported is not None:
                # The attribute can already be a CompletionUsage object or a plain mapping.
                usage = reported if isinstance(reported, CompletionUsage) else CompletionUsage(**reported)

    log_llm_stream("\n")
    full_reply_content = "".join(collected_messages)
    if not usage:
        # Some services do not provide the usage attribute, such as OpenAI or OpenLLM
        usage = self._calc_usage(messages, full_reply_content)

    self._update_costs(usage)
    return full_reply_content
|
||||
|
||||
def _cons_kwargs(self, messages: list[dict], timeout=3, **extra_kwargs) -> dict:
|
||||
kwargs = {
|
||||
|
|
@ -99,7 +118,7 @@ class OpenAILLM(BaseLLM):
|
|||
"max_tokens": self._get_max_tokens(messages),
|
||||
"n": 1,
|
||||
# "stop": None, # default it's None and gpt4-v can't have this one
|
||||
"temperature": self.config.temperature,
|
||||
"temperature": 0.3,
|
||||
"model": self.model,
|
||||
"timeout": max(self.config.timeout, timeout),
|
||||
}
|
||||
|
|
@ -126,18 +145,7 @@ class OpenAILLM(BaseLLM):
|
|||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
    """Return the completion text for ``messages``.

    Args:
        messages: Chat messages to send.
        stream: When true, the reply is streamed (each token is printed in
            place by ``self._achat_completion_stream``) and the assembled
            text is returned.
        timeout: Timeout forwarded to the underlying completion call.

    Returns:
        The reply text.
    """
    if stream:
        # Return the streamed reply directly. Falling through would discard it
        # and issue a second, non-streaming request for the same messages.
        return await self._achat_completion_stream(messages, timeout=timeout)

    rsp = await self._achat_completion(messages, timeout=timeout)
    return self.get_choice_text(rsp)
|
||||
|
|
@ -261,11 +269,12 @@ class OpenAILLM(BaseLLM):
|
|||
if not self.config.calc_usage:
|
||||
return usage
|
||||
|
||||
model = self.model if not isinstance(self.cost_manager, TokenCostManager) else "open-llm-model"
|
||||
try:
|
||||
usage.prompt_tokens = count_message_tokens(messages, self.model)
|
||||
usage.completion_tokens = count_string_tokens(rsp, self.model)
|
||||
usage.prompt_tokens = count_message_tokens(messages, model)
|
||||
usage.completion_tokens = count_string_tokens(rsp, model)
|
||||
except Exception as e:
|
||||
logger.warning(f"usage calculation failed: {e}")
|
||||
logger.error(f"usage calculation failed: {e}")
|
||||
|
||||
return usage
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue