修改内容:

调整llm引用模式
添加对星火大模型的支持
在模型报错时可以人工输入
可以通过config直接全人工输入
This commit is contained in:
ziming 2023-09-05 21:52:13 +08:00
parent 206bc252de
commit 6046f9c942
18 changed files with 277 additions and 29 deletions

View file

@ -11,15 +11,16 @@ from typing import Optional
from tenacity import retry, stop_after_attempt, wait_fixed
from metagpt.actions.action_output import ActionOutput
from metagpt.llm import LLM
import metagpt.llm as LLM
from metagpt.utils.common import OutputParser
from metagpt.logs import logger
from metagpt.config import CONFIG
class Action(ABC):
def __init__(self, name: str = '', context=None, llm: LLM = None):
self.name: str = name
if llm is None:
llm = LLM()
llm=LLM.DEFAULT_LLM
self.llm = llm
self.context = context
self.prefix = ""
@ -54,13 +55,42 @@ class Action(ABC):
if not system_msgs:
system_msgs = []
system_msgs.append(self.prefix)
content = await self.llm.aask(prompt, system_msgs)
logger.debug(content)
output_class = ActionOutput.create_model_class(output_class_name, output_data_mapping)
parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping)
logger.debug(parsed_data)
instruct_content = output_class(**parsed_data)
return ActionOutput(content, instruct_content)
if not CONFIG.no_api_mode:
content = await self.llm.aask(prompt, system_msgs)
logger.debug(content)
output_class = ActionOutput.create_model_class(output_class_name, output_data_mapping)
parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping)
logger.debug(parsed_data)
try:
instruct_content = output_class(**parsed_data)
return ActionOutput(content, instruct_content)
except Exception as e:
print('Error:',e)
print('自动运行出错,切换为手动运行')
print('prompt为')
print('\n'.join( system_msgs)+prompt)
print('输入格式:')
print(output_data_mapping)
print('请准备输入,输入完成按ctrl+Z')
while True:
try:
lines=[]
while True:
try:
lines.append(input())
except:
break
content ='\n'.join(lines)
output_class = ActionOutput.create_model_class(output_class_name, output_data_mapping)
parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping)
logger.debug(parsed_data)
instruct_content = output_class(**parsed_data)
return ActionOutput(content, instruct_content)
except Exception as e:
print('Error:',e)
print('输入错误,请重试')
async def run(self, *args, **kwargs):
"""Run action"""

View file

@ -143,4 +143,5 @@ class WritePRD(Action):
format_example=FORMAT_EXAMPLE)
logger.debug(prompt)
prd = await self._aask_v1(prompt, "prd", OUTPUT_MAPPING)
return prd

View file

@ -45,8 +45,18 @@ class Config(metaclass=Singleton):
self.global_proxy = self._get("GLOBAL_PROXY")
self.openai_api_key = self._get("OPENAI_API_KEY")
self.anthropic_api_key = self._get("Anthropic_API_KEY")
#星火大模型相关
self.xinghuo_appid = self._get("xinghuo_appid")
self.xinghuo_api_secret = self._get("xinghuo_api_secret")
self.xinghuo_api_key = self._get("xinghuo_api_key")
self.domain=self._get("domain")
self.Spark_url=self._get("Spark_url")
self.no_api_mode=self._get("no_api_mode")
if (not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key) and (
not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key
)and (
not self.xinghuo_api_key or "APIKey" == self.xinghuo_api_key
):
raise NotConfiguredException("Set OPENAI_API_KEY or Anthropic_API_KEY first")
self.openai_api_base = self._get("OPENAI_API_BASE")

View file

@ -8,10 +8,11 @@
from metagpt.provider.anthropic_api import Claude2 as Claude
from metagpt.provider.openai_api import OpenAIGPTAPI as LLM
from metagpt.provider.spark_api import Spark
DEFAULT_LLM = LLM()
DEFAULT_LLM = Spark()
CLAUDE_LLM = Claude()
SPARK_LLM = Spark()
async def ai_func(prompt):
"""使用LLM进行QA

View file

@ -8,7 +8,7 @@
from metagpt.actions import Action
from metagpt.const import PROMPT_PATH
from metagpt.document_store.chromadb_store import ChromaStore
from metagpt.llm import LLM
import metagpt.llm as LLM
from metagpt.logs import logger
Skill = Action
@ -18,7 +18,7 @@ class SkillManager:
"""用来管理所有技能"""
def __init__(self):
self._llm = LLM()
self._llm=LLM.DEFAULT_LLM
self._store = ChromaStore('skill_manager')
self._skills: dict[str: Skill] = {}

View file

@ -5,13 +5,13 @@
@Author : alexanderwu
@File : manager.py
"""
from metagpt.llm import LLM
import metagpt.llm as LLM
from metagpt.logs import logger
from metagpt.schema import Message
class Manager:
def __init__(self, llm: LLM = LLM()):
def __init__(self, llm: llm=LLM.DEFAULT_LLM):
self.llm = llm # Large Language Model
self.role_directions = {
"BOSS": "Product Manager",

View file

@ -10,10 +10,10 @@
```python
from typing import Optional
from abc import ABC
from metagpt.llm import LLM # 大语言模型类似GPT
import metagpt.llm as LLM # 大语言模型类似GPT
class Action(ABC):
def __init__(self, name='', context=None, llm: LLM = LLM()):
def __init__(self, name='', context=None, llm: llm=LLM.DEFAULT_LLM):
self.name = name
self.llm = llm
self.context = context

View file

@ -0,0 +1,137 @@
import _thread as thread
import base64
import datetime
import hashlib
import hmac
import json
from urllib.parse import urlparse
import ssl
from datetime import datetime
from time import mktime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time
import websocket # 使用websocket_client
answer = ""
class Ws_Param(object):
# 初始化
def __init__(self, APPID, APIKey, APISecret, Spark_url):
self.APPID = APPID
self.APIKey = APIKey
self.APISecret = APISecret
self.host = urlparse(Spark_url).netloc
self.path = urlparse(Spark_url).path
self.Spark_url = Spark_url
# 生成url
def create_url(self):
# 生成RFC1123格式的时间戳
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))
# 拼接字符串
signature_origin = "host: " + self.host + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + self.path + " HTTP/1.1"
# 进行hmac-sha256进行加密
signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
digestmod=hashlib.sha256).digest()
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
# 将请求的鉴权参数组合为字典
v = {
"authorization": authorization,
"date": date,
"host": self.host
}
# 拼接鉴权参数生成url
url = self.Spark_url + '?' + urlencode(v)
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释比对相同参数时生成的url与自己代码生成的url是否一致
return url
# 收到websocket错误的处理
def on_error(ws, error):
print("### error:", error)
# 收到websocket关闭的处理
def on_close(ws,one,two):
print(" ")
# 收到websocket连接建立的处理
def on_open(ws):
thread.start_new_thread(run, (ws,))
def run(ws, *args):
data = json.dumps(gen_params(appid=ws.appid, domain= ws.domain,question=ws.question))
ws.send(data)
# 收到websocket消息的处理
def on_message(ws, message):
# print(message)
data = json.loads(message)
code = data['header']['code']
if code != 0:
print(f'请求错误: {code}, {data}')
ws.close()
else:
choices = data["payload"]["choices"]
status = choices["status"]
content = choices["text"][0]["content"]
print(content,end ="")
global answer
answer += content
# print(1)
if status == 2:
ws.close()
def gen_params(appid, domain,question):
"""
通过appid和用户的提问来生成请参数
"""
data = {
"header": {
"app_id": appid,
"uid": "1234"
},
"parameter": {
"chat": {
"domain": domain,
"random_threshold": 0.5,
"max_tokens": 2048,
"auditing": "default"
}
},
"payload": {
"message": {
"text": question
}
}
}
return data
def main(appid, api_key, api_secret, Spark_url,domain, question):
# print("星火:")
wsParam = Ws_Param(appid, api_key, api_secret, Spark_url)
websocket.enableTrace(False)
wsUrl = wsParam.create_url()
ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
ws.appid = appid
ws.question = question
ws.domain = domain
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})

View file

@ -0,0 +1,57 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/7/21 11:15
@Author : Leo Xiao
@File : anthropic_api.py
"""
from typing import Optional
from metagpt.provider import SparkApi
from metagpt.config import CONFIG
def getlength(text):
length = 0
for content in text:
temp = content["content"]
leng = len(temp)
length += leng
return length
def checklen(text):
while (getlength(text) > 8000):
del text[0]
return text
class Spark:
system_prompt = 'You are a helpful assistant.'
def _user_msg(self, msg: str) -> dict[str, str]:
return {"role": "user", "content": msg}
def _assistant_msg(self, msg: str) -> dict[str, str]:
return {"role": "assistant", "content": msg}
def _system_msg(self, msg: str) -> dict[str, str]:
return {"role": "system", "content": msg}
def _system_msgs(self, msgs: list[str]) -> list[dict[str, str]]:
return [self._system_msg(msg) for msg in msgs]
def _default_system_msg(self):
return self._system_msg(self.system_prompt)
def ask(self, msg: str):
message = [self._user_msg(msg)]
SparkApi.main(CONFIG.xinghuo_appid,CONFIG.xinghuo_api_key,CONFIG.xinghuo_api_secret,"ws://spark-api.xf-yun.com/v2.1/chat","generalv2",message)
rsp = SparkApi.answer
return rsp
async def aask(self, msg: str, system_msgs: Optional[list[str]] = None) -> str:
if system_msgs:
message = self._system_msgs(system_msgs) + [self._user_msg(msg)]
else:
message = [self._user_msg(msg)]
SparkApi.main(CONFIG.xinghuo_appid,CONFIG.xinghuo_api_key,CONFIG.xinghuo_api_secret,"ws://spark-api.xf-yun.com/v2.1/chat","generalv2",message)
rsp = SparkApi.answer
return rsp

View file

@ -14,7 +14,7 @@ from pydantic import BaseModel, Field
# from metagpt.environment import Environment
from metagpt.config import CONFIG
from metagpt.actions import Action, ActionOutput
from metagpt.llm import LLM
from metagpt import llm as LLM
from metagpt.logs import logger
from metagpt.memory import Memory, LongTermMemory
from metagpt.schema import Message
@ -94,7 +94,7 @@ class Role:
"""角色/代理"""
def __init__(self, name="", profile="", goal="", constraints="", desc=""):
self._llm = LLM()
self._llm=LLM.DEFAULT_LLM
self._setting = RoleSetting(name=name, profile=profile, goal=goal, constraints=constraints, desc=desc)
self._states = []
self._actions = []