implement streaming

This commit is contained in:
usamimeri_renko 2024-05-22 22:05:15 +08:00
parent 1b13a28a77
commit bb8ea2eaf9
4 changed files with 29 additions and 2 deletions

View file

@ -10,7 +10,10 @@ from metagpt.logs import log_llm_stream
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
# from sparkai.schema import LLMResult, HumanMessage, AIMessage 由于其使用Pydantic V1,导入会报错
# from sparkai.schema import LLMResult, HumanMessage, AIMessage # 由于其使用Pydantic V1,导入会报错
from metagpt.utils.common import any_to_str
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.token_counter import SPARK_TOKENS
@register_provider(LLMType.SPARK)
@ -21,6 +24,8 @@ class SparkLLM(BaseLLM):
def __init__(self, config: LLMConfig):
self.config = config
self.cost_manager = CostManager(token_costs=SPARK_TOKENS)
self.model = self.config.domain
self._init_client()
def _init_client(self):
@ -60,7 +65,7 @@ class SparkLLM(BaseLLM):
return response
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
return self._achat_completion(messages, timeout)
return await self._achat_completion(messages, timeout)
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
response = self.client.astream(messages)
@ -76,3 +81,6 @@ class SparkLLM(BaseLLM):
self._update_costs(usage)
full_content = "".join(collected_content)
return full_content
def _extract_assistant_rsp(self, context):
return "\n".join([i.content for i in context if "AIMessage" in any_to_str(i)])

View file

@ -258,6 +258,14 @@ BEDROCK_TOKEN_COSTS = {
"ai21.j2-ultra-v1": {"prompt": 0.0188, "completion": 0.0188},
}
# https://xinghuo.xfyun.cn/sparkapi?scr=price
SPARK_TOKENS = {
"general": {"prompt": 0.0, "completion": 0.0}, # Spark-Lite
"generalV2": {"prompt": 0.0188, "completion": 0.0188}, # Spark V2.0
"generalV3": {"prompt": 0.0035, "completion": 0.0035}, # Spark Pro
"generalV3.5": {"prompt": 0.0035, "completion": 0.0035}, # Spark3.5 Max
}
def count_input_tokens(messages, model="gpt-3.5-turbo-0125"):
"""Return the number of tokens used by a list of messages."""