implement streaming

This commit is contained in:
usamimeri_renko 2024-05-22 22:05:15 +08:00
parent 1b13a28a77
commit bb8ea2eaf9
4 changed files with 29 additions and 2 deletions

View file

@ -0,0 +1,10 @@
# 适用于讯飞星火的spark-lite 参考 https://www.xfyun.cn/doc/spark/Web.html#_2-function-call%E8%AF%B4%E6%98%8E
llm:
api_type: 'spark'
# 对应模型的url 参考 https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E
base_url: "ws(s)://spark-api.xf-yun.com/v1.1/chat"
app_id: ""
api_key: ""
api_secret: ""
domain: "general" # 取值为 [general,generalv2,generalv3,generalv3.5] 和url一一对应

View file

@ -10,7 +10,10 @@ from metagpt.logs import log_llm_stream
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
# from sparkai.schema import LLMResult, HumanMessage, AIMessage 由于其使用Pydantic V1,导入会报错
# from sparkai.schema import LLMResult, HumanMessage, AIMessage # 由于其使用Pydantic V1,导入会报错
from metagpt.utils.common import any_to_str
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.token_counter import SPARK_TOKENS
@register_provider(LLMType.SPARK)
@ -21,6 +24,8 @@ class SparkLLM(BaseLLM):
def __init__(self, config: LLMConfig):
self.config = config
self.cost_manager = CostManager(token_costs=SPARK_TOKENS)
self.model = self.config.domain
self._init_client()
def _init_client(self):
@ -60,7 +65,7 @@ class SparkLLM(BaseLLM):
return response
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
return self._achat_completion(messages, timeout)
return await self._achat_completion(messages, timeout)
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
response = self.client.astream(messages)
@ -76,3 +81,6 @@ class SparkLLM(BaseLLM):
self._update_costs(usage)
full_content = "".join(collected_content)
return full_content
def _extract_assistant_rsp(self, context):
return "\n".join([i.content for i in context if "AIMessage" in any_to_str(i)])

View file

@ -258,6 +258,14 @@ BEDROCK_TOKEN_COSTS = {
"ai21.j2-ultra-v1": {"prompt": 0.0188, "completion": 0.0188},
}
# https://xinghuo.xfyun.cn/sparkapi?scr=price
SPARK_TOKENS = {
"general": {"prompt": 0.0, "completion": 0.0}, # Spark-Lite
"generalV2": {"prompt": 0.0188, "completion": 0.0188}, # Spark V2.0
"generalV3": {"prompt": 0.0035, "completion": 0.0035}, # Spark Pro
"generalV3.5": {"prompt": 0.0035, "completion": 0.0035}, # Spark3.5 Max
}
def count_input_tokens(messages, model="gpt-3.5-turbo-0125"):
"""Return the number of tokens used by a list of messages."""

View file

@ -71,3 +71,4 @@ dashscope==1.14.1
rank-bm25==0.2.2 # for tool recommendation
gymnasium==0.29.1
boto3~=1.34.69
spark_ai_python~=0.3.30