Merge branch 'dev' into update-unit-test

This commit is contained in:
Stitch-z 2023-12-26 16:11:31 +08:00 committed by GitHub
commit bf0f6bd272
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
148 changed files with 6195 additions and 691 deletions

View file

@ -0,0 +1,9 @@
role:
name: Teacher # Referenced the `Teacher` in `metagpt/roles/teacher.py`.
module: metagpt.roles.teacher # Referenced `metagpt/roles/teacher.py`.
skills: # Refer to the skill `name` of the published skill in `.well-known/skills.yaml`.
- name: text_to_speech
description: Text-to-speech
- name: text_to_image
description: Create a drawing based on the text.

6
.gitignore vendored
View file

@ -159,3 +159,9 @@ workspace/*
tmp
metagpt/roles/idea_agent.py
.aider*
*.bak
# output folder
output
tmp.png

View file

@ -1,8 +1,9 @@
default_stages: [ commit ]
# Install
# 1. pip install pre-commit
# 1. pip install metagpt[dev]
# 2. pre-commit install
# 3. pre-commit run --all-files # make sure all files are clean
repos:
- repo: https://github.com/pycqa/isort
rev: 5.11.5
@ -19,9 +20,10 @@ repos:
rev: v0.0.284
hooks:
- id: ruff
args: [ --fix ]
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
- id: black
args: ['--line-length', '120']
args: ['--line-length', '120']

View file

@ -0,0 +1,18 @@
{
"schema_version": "v1",
"name_for_model": "text processing tools",
"name_for_human": "MetaGPT Text Plugin",
"description_for_model": "Plugins for text processing, including text-to-speech, text-to-image, text-to-embedding, text summarization, text-to-code, vector similarity calculation, web content crawling, and more.",
"description_for_human": "Plugins for text processing, including text-to-speech, text-to-image, text-to-embedding, text summarization, text-to-code, vector similarity calculation, web content crawling, and more.",
"auth": {
"type": "none"
},
"api": {
"type": "openapi",
"url": "https://github.com/iorisa/MetaGPT/blob/feature/assistant_role/.well-known/metagpt_oas3_api.yaml",
"has_user_authentication": false
},
"logo_url": "https://github.com/geekan/MetaGPT/blob/main/docs/resources/MetaGPT-logo.png",
"contact_email": "mashenquan@fuzhi.cn",
"legal_info_url": "https://github.com/geekan/MetaGPT/blob/main/docs/README_CN.md"
}

View file

@ -0,0 +1,338 @@
openapi: "3.0.0"
info:
title: "MetaGPT Export OpenAPIs"
version: "1.0"
servers:
- url: "/oas3"
variables:
port:
default: '8080'
description: HTTP service port
paths:
/tts/azsure:
x-prerequisite:
configurations:
AZURE_TTS_SUBSCRIPTION_KEY:
type: string
description: "For more details, check out: [Azure Text-to_Speech](https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts)"
AZURE_TTS_REGION:
type: string
description: "For more details, check out: [Azure Text-to_Speech](https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts)"
required:
allOf:
- AZURE_TTS_SUBSCRIPTION_KEY
- AZURE_TTS_REGION
post:
summary: "Convert Text to Base64-encoded .wav File Stream"
description: "For more details, check out: [Azure Text-to_Speech](https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts)"
operationId: azure_tts.oas3_azsure_tts
requestBody:
required: true
content:
application/json:
schema:
type: object
required:
- text
properties:
text:
type: string
description: Text to convert
lang:
type: string
description: The language code or locale, e.g., en-US (English - United States)
default: "zh-CN"
voice:
type: string
description: "Voice style, see: [Azure Text-to_Speech](https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts), [Voice Gallery](https://speech.microsoft.com/portal/voicegallery)"
default: "zh-CN-XiaomoNeural"
style:
type: string
description: "Speaking style to express different emotions. For more details, checkout: [Azure Text-to_Speech](https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts)"
default: "affectionate"
role:
type: string
description: "Role to specify age and gender. For more details, checkout: [Azure Text-to_Speech](https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts)"
default: "Girl"
subscription_key:
type: string
description: "Key used to access Azure AI service API, see: [Azure Portal](https://portal.azure.com/) > `Resource Management` > `Keys and Endpoint`"
default: ""
region:
type: string
description: "Location (or region) of your resource, see: [Azure Portal](https://portal.azure.com/) > `Resource Management` > `Keys and Endpoint`"
default: ""
responses:
'200':
description: "Base64-encoded .wav file data if successful, otherwise an empty string."
content:
application/json:
schema:
type: object
properties:
wav_data:
type: string
format: base64
'400':
description: "Bad Request"
'500':
description: "Internal Server Error"
/tts/iflytek:
x-prerequisite:
configurations:
IFLYTEK_APP_ID:
type: string
description: "Application ID is used to access your iFlyTek service API, see: `https://console.xfyun.cn/services/tts`"
IFLYTEK_API_KEY:
type: string
description: "WebAPI argument, see: `https://console.xfyun.cn/services/tts`"
IFLYTEK_API_SECRET:
type: string
description: "WebAPI argument, see: `https://console.xfyun.cn/services/tts`"
required:
allOf:
- IFLYTEK_APP_ID
- IFLYTEK_API_KEY
- IFLYTEK_API_SECRET
post:
summary: "Convert Text to Base64-encoded .mp3 File Stream"
description: "For more details, check out: [iFlyTek](https://console.xfyun.cn/services/tts)"
operationId: iflytek_tts.oas3_iflytek_tts
requestBody:
required: true
content:
application/json:
schema:
type: object
required:
- text
properties:
text:
type: string
description: Text to convert
voice:
type: string
description: "Voice style, see: [iFlyTek Text-to_Speech](https://www.xfyun.cn/doc/tts/online_tts/API.html#%E6%8E%A5%E5%8F%A3%E8%B0%83%E7%94%A8%E6%B5%81%E7%A8%8B)"
default: "xiaoyan"
app_id:
type: string
description: "Application ID is used to access your iFlyTek service API, see: `https://console.xfyun.cn/services/tts`"
default: ""
api_key:
type: string
description: "WebAPI argument, see: `https://console.xfyun.cn/services/tts`"
default: ""
api_secret:
type: string
description: "WebAPI argument, see: `https://console.xfyun.cn/services/tts`"
default: ""
responses:
'200':
description: "Base64-encoded .mp3 file data if successful, otherwise an empty string."
content:
application/json:
schema:
type: object
properties:
wav_data:
type: string
format: base64
'400':
description: "Bad Request"
'500':
description: "Internal Server Error"
/txt2img/openai:
x-prerequisite:
configurations:
OPENAI_API_KEY:
type: string
description: "OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`"
required:
allOf:
- OPENAI_API_KEY
post:
summary: "Convert Text to Base64-encoded Image Data Stream"
operationId: openai_text_to_image.oas3_openai_text_to_image
requestBody:
required: true
content:
application/json:
schema:
type: object
properties:
text:
type: string
description: "The text used for image conversion."
size_type:
type: string
enum: ["256x256", "512x512", "1024x1024"]
default: "1024x1024"
description: "Size of the generated image."
openai_api_key:
type: string
default: ""
description: "OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`"
responses:
'200':
description: "Base64-encoded image data."
content:
application/json:
schema:
type: object
properties:
image_data:
type: string
format: base64
'400':
description: "Bad Request"
'500':
description: "Internal Server Error"
/txt2embedding/openai:
x-prerequisite:
configurations:
OPENAI_API_KEY:
type: string
description: "OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`"
required:
allOf:
- OPENAI_API_KEY
post:
summary: Text to embedding
operationId: openai_text_to_embedding.oas3_openai_text_to_embedding
description: Retrieve an embedding for the provided text using the OpenAI API.
requestBody:
content:
application/json:
schema:
type: object
properties:
input:
type: string
description: The text used for embedding.
model:
type: string
description: "ID of the model to use. For more details, checkout: [models](https://api.openai.com/v1/models)"
enum:
- text-embedding-ada-002
responses:
"200":
description: Successful response
content:
application/json:
schema:
$ref: "#/components/schemas/ResultEmbedding"
"4XX":
description: Client error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"5XX":
description: Server error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
/txt2image/metagpt:
x-prerequisite:
configurations:
METAGPT_TEXT_TO_IMAGE_MODEL_URL:
type: string
description: "Model url."
required:
allOf:
- METAGPT_TEXT_TO_IMAGE_MODEL_URL
post:
summary: "Text to Image"
description: "Generate an image from the provided text using the MetaGPT Text-to-Image API."
operationId: metagpt_text_to_image.oas3_metagpt_text_to_image
requestBody:
required: true
content:
application/json:
schema:
type: object
required:
- text
properties:
text:
type: string
description: "The text used for image conversion."
size_type:
type: string
enum: ["512x512", "512x768"]
default: "512x512"
description: "Size of the generated image."
model_url:
type: string
description: "Model reset API URL for text-to-image."
default: ""
responses:
'200':
description: "Base64-encoded image data."
content:
application/json:
schema:
type: object
properties:
image_data:
type: string
format: base64
'400':
description: "Bad Request"
'500':
description: "Internal Server Error"
components:
schemas:
Embedding:
type: object
description: Represents an embedding vector returned by the embedding endpoint.
properties:
object:
type: string
example: embedding
embedding:
type: array
items:
type: number
example: [0.0023064255, -0.009327292, ...]
index:
type: integer
example: 0
Usage:
type: object
properties:
prompt_tokens:
type: integer
example: 8
total_tokens:
type: integer
example: 8
ResultEmbedding:
type: object
properties:
object:
type: string
example: result_embedding
data:
type: array
items:
$ref: "#/components/schemas/Embedding"
model:
type: string
example: text-embedding-ada-002
usage:
$ref: "#/components/schemas/Usage"
Error:
type: object
properties:
error:
type: string
example: An error occurred

35
.well-known/openapi.yaml Normal file
View file

@ -0,0 +1,35 @@
openapi: "3.0.0"
info:
title: Hello World
version: "1.0"
servers:
- url: /openapi
paths:
/greeting/{name}:
post:
summary: Generate greeting
description: Generates a greeting message.
operationId: hello.post_greeting
responses:
200:
description: greeting response
content:
text/plain:
schema:
type: string
example: "hello dave!"
parameters:
- name: name
in: path
description: Name of the person to greet.
required: true
schema:
type: string
example: "dave"
requestBody:
content:
application/json:
schema:
type: object

161
.well-known/skills.yaml Normal file
View file

@ -0,0 +1,161 @@
skillapi: "0.1.0"
info:
title: "Agent Skill Specification"
version: "1.0"
entities:
Assistant:
summary: assistant
description: assistant
skills:
- name: text_to_speech
description: Generate a voice file from the input text, text-to-speech
id: text_to_speech.text_to_speech
x-prerequisite:
configurations:
AZURE_TTS_SUBSCRIPTION_KEY:
type: string
description: "For more details, check out: [Azure Text-to_Speech](https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts)"
AZURE_TTS_REGION:
type: string
description: "For more details, check out: [Azure Text-to_Speech](https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts)"
IFLYTEK_APP_ID:
type: string
description: "Application ID is used to access your iFlyTek service API, see: `https://console.xfyun.cn/services/tts`"
IFLYTEK_API_KEY:
type: string
description: "WebAPI argument, see: `https://console.xfyun.cn/services/tts`"
IFLYTEK_API_SECRET:
type: string
description: "WebAPI argument, see: `https://console.xfyun.cn/services/tts`"
required:
oneOf:
- allOf:
- AZURE_TTS_SUBSCRIPTION_KEY
- AZURE_TTS_REGION
- allOf:
- IFLYTEK_APP_ID
- IFLYTEK_API_KEY
- IFLYTEK_API_SECRET
parameters:
text:
description: 'The text used for voice conversion.'
required: true
type: string
lang:
description: 'The value can contain a language code such as en (English), or a locale such as en-US (English - United States).'
type: string
enum:
- English
- Chinese
default: Chinese
voice:
description: Name of voice styles
type: string
default: zh-CN-XiaomoNeural
style:
type: string
description: Speaking style to express different emotions like cheerfulness, empathy, and calm.
enum:
- affectionate
- angry
- calm
- cheerful
- depressed
- disgruntled
- embarrassed
- envious
- fearful
- gentle
- sad
- serious
default: affectionate
role:
type: string
description: With roles, the same voice can act as a different age and gender.
enum:
- Girl
- Boy
- OlderAdultFemale
- OlderAdultMale
- SeniorFemale
- SeniorMale
- YoungAdultFemale
- YoungAdultMale
default: Girl
examples:
- ask: 'A girl says "hello world"'
answer: 'text_to_speech(text="hello world", role="Girl")'
- ask: 'A boy affectionate says "hello world"'
answer: 'text_to_speech(text="hello world", role="Boy", style="affectionate")'
- ask: 'A boy says "你好"'
answer: 'text_to_speech(text="你好", role="Boy", lang="Chinese")'
returns:
type: string
format: base64
- name: text_to_image
description: Create a drawing based on the text.
id: text_to_image.text_to_image
x-prerequisite:
configurations:
OPENAI_API_KEY:
type: string
description: "OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`"
METAGPT_TEXT_TO_IMAGE_MODEL_URL:
type: string
description: "Model url."
required:
oneOf:
- OPENAI_API_KEY
- METAGPT_TEXT_TO_IMAGE_MODEL_URL
parameters:
text:
description: 'The text used for image conversion.'
type: string
required: true
size_type:
description: size type
type: string
default: "512x512"
examples:
- ask: 'Draw a girl'
answer: 'text_to_image(text="Draw a girl", size_type="512x512")'
- ask: 'Draw an apple'
answer: 'text_to_image(text="Draw an apple", size_type="512x512")'
returns:
type: string
format: base64
- name: web_search
description: Perform Google searches to provide real-time information.
id: web_search.web_search
x-prerequisite:
configurations:
SEARCH_ENGINE:
type: string
description: "Supported values: serpapi/google/serper/ddg"
SERPER_API_KEY:
type: string
description: "SERPER API KEY, For more details, checkout: `https://serper.dev/api-key`"
required:
allOf:
- SEARCH_ENGINE
- SERPER_API_KEY
parameters:
query:
type: string
description: 'The search query.'
required: true
max_results:
type: number
default: 6
description: 'The number of search results to retrieve.'
examples:
- ask: 'Search for information about artificial intelligence'
answer: 'web_search(query="Search for information about artificial intelligence", max_results=6)'
- ask: 'Find news articles about climate change'
answer: 'web_search(query="Find news articles about climate change", max_results=6)'
returns:
type: string

View file

@ -14,6 +14,8 @@ OPENAI_BASE_URL: "https://api.openai.com/v1"
OPENAI_API_MODEL: "gpt-4-1106-preview"
MAX_TOKENS: 4096
RPM: 10
LLM_TYPE: OpenAI # Except for these three major models OpenAI, MetaGPT LLM, and Azure other large models can be distinguished based on the validity of the key.
TIMEOUT: 60 # Timeout for llm invocation
#### if Spark
#SPARK_APPID : "YOUR_APPID"
@ -119,4 +121,24 @@ RPM: 10
# PROMPT_FORMAT: json #json or markdown
### Agent configurations
# RAISE_NOT_CONFIG_ERROR: true # "true" if the LLM key is not configured, throw a NotConfiguredException, else "false".
# WORKSPACE_PATH_WITH_UID: false # "true" if using `{workspace}/{uid}` as the workspace path; "false" use `{workspace}`.
### Meta Models
#METAGPT_TEXT_TO_IMAGE_MODEL: MODEL_URL
### S3 config
#S3_ACCESS_KEY: "YOUR_S3_ACCESS_KEY"
#S3_SECRET_KEY: "YOUR_S3_SECRET_KEY"
#S3_ENDPOINT_URL: "YOUR_S3_ENDPOINT_URL"
#S3_SECURE: true # true/false
#S3_BUCKET: "YOUR_S3_BUCKET"
### Redis config
#REDIS_HOST: "YOUR_REDIS_HOST"
#REDIS_PORT: "YOUR_REDIS_PORT"
#REDIS_PASSWORD: "YOUR_REDIS_PASSWORD"
#REDIS_DB: "YOUR_REDIS_DB_INDEX, str, 0-based"
# DISABLE_LLM_PROVIDER_CHECK: false

BIN
examples/example.faiss Normal file

Binary file not shown.

10
examples/example.json Normal file
View file

@ -0,0 +1,10 @@
[
{
"source": "Which facial cleanser is good for oily skin?",
"output": "ABC cleanser is preferred by many with oily skin."
},
{
"source": "Is L'Oreal good to use?",
"output": "L'Oreal is a popular brand with many positive reviews."
}
]

BIN
examples/example.pkl Normal file

Binary file not shown.

BIN
examples/faq.xlsx Normal file

Binary file not shown.

View file

@ -2,30 +2,18 @@
# -*- coding: utf-8 -*-
"""
@File : search_kb.py
@Modified By: mashenquan, 2023-12-22. Delete useless codes.
"""
import asyncio
from langchain.embeddings import OpenAIEmbeddings
from metagpt.config import CONFIG
from metagpt.const import DATA_PATH
from metagpt.const import DATA_PATH, EXAMPLE_PATH
from metagpt.document_store import FaissStore
from metagpt.logs import logger
from metagpt.roles import Sales
""" example.json, e.g.
[
{
"source": "Which facial cleanser is good for oily skin?",
"output": "ABC cleanser is preferred by many with oily skin."
},
{
"source": "Is L'Oreal good to use?",
"output": "L'Oreal is a popular brand with many positive reviews."
}
]
"""
def get_store():
embedding = OpenAIEmbeddings(openai_api_key=CONFIG.openai_api_key, openai_api_base=CONFIG.openai_base_url)
@ -33,13 +21,11 @@ def get_store():
async def search():
role = Sales(profile="Sales", store=get_store())
queries = ["Which facial cleanser is good for oily skin?", "Is L'Oreal good to use?"]
for query in queries:
logger.info(f"User: {query}")
result = await role.run(query)
logger.info(result)
store = FaissStore(EXAMPLE_PATH / "example.json")
role = Sales(profile="Sales", store=store)
query = "Which facial cleanser is good for oily skin?"
result = await role.run(query)
logger.info(result)
if __name__ == "__main__":

View file

@ -1,3 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
"""
import asyncio
from metagpt.roles import Searcher

View file

@ -14,6 +14,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type
from pydantic import BaseModel, create_model, root_validator, validator
from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.config import CONFIG
from metagpt.llm import BaseGPTAPI
from metagpt.logs import logger
from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess
@ -260,9 +261,10 @@ class ActionNode:
output_data_mapping: dict,
system_msgs: Optional[list[str]] = None,
schema="markdown", # compatible to original format
timeout=CONFIG.timeout,
) -> (str, BaseModel):
"""Use ActionOutput to wrap the output of aask"""
content = await self.llm.aask(prompt, system_msgs)
content = await self.llm.aask(prompt, system_msgs, timeout=timeout)
logger.debug(f"llm raw output:\n{content}")
output_class = self.create_model_class(output_class_name, output_data_mapping)
@ -289,13 +291,13 @@ class ActionNode:
def set_context(self, context):
self.set_recursive("context", context)
async def simple_fill(self, schema, mode):
async def simple_fill(self, schema, mode, timeout=CONFIG.timeout):
prompt = self.compile(context=self.context, schema=schema, mode=mode)
if schema != "raw":
mapping = self.get_mapping(mode)
class_name = f"{self.key}_AN"
content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=schema)
content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=schema, timeout=timeout)
self.content = content
self.instruct_content = scontent
else:
@ -304,7 +306,7 @@ class ActionNode:
return self
async def fill(self, context, llm, schema="json", mode="auto", strgy="simple"):
async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=CONFIG.timeout):
"""Fill the node(s) with mode.
:param context: Everything we should know when filling node.
@ -320,6 +322,7 @@ class ActionNode:
:param strgy: simple/complex
- simple: run only once
- complex: run each node
:param timeout: Timeout for llm invocation.
:return: self
"""
self.set_llm(llm)
@ -328,12 +331,12 @@ class ActionNode:
schema = self.schema
if strgy == "simple":
return await self.simple_fill(schema=schema, mode=mode)
return await self.simple_fill(schema=schema, mode=mode, timeout=timeout)
elif strgy == "complex":
# 这里隐式假设了拥有children
tmp = {}
for _, i in self.children.items():
child = await i.simple_fill(schema=schema, mode=mode)
child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout)
tmp.update(child.instruct_content.dict())
cls = self.create_children_class()
self.instruct_content = cls(**tmp)

View file

@ -1,4 +1,3 @@
import traceback
from pathlib import Path
from pydantic import Field
@ -8,6 +7,7 @@ from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.schema import Message
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.highlight import highlight
CLONE_PROMPT = """
@ -39,7 +39,7 @@ class CloneFunction(WriteCode):
if isinstance(code_path, str):
code_path = Path(code_path)
code_path.parent.mkdir(parents=True, exist_ok=True)
code_path.write_text(code)
code_path.write_text(code, encoding="utf-8")
logger.info(f"Saving Code to {code_path}")
async def run(self, template_func: str, source_code: str) -> str:
@ -51,20 +51,17 @@ class CloneFunction(WriteCode):
return code
@handle_exception
def run_function_code(func_code: str, func_name: str, *args, **kwargs):
"""Run function code from string code."""
try:
locals_ = {}
exec(func_code, locals_)
func = locals_[func_name]
return func(*args, **kwargs), ""
except Exception:
return "", traceback.format_exc()
locals_ = {}
exec(func_code, locals_)
func = locals_[func_name]
return func(*args, **kwargs), ""
def run_function_script(code_script_path: str, func_name: str, *args, **kwargs):
"""Run function code from script."""
if isinstance(code_script_path, str):
code_path = Path(code_script_path)
code_path = Path(code_script_path)
code = code_path.read_text(encoding="utf-8")
return run_function_code(code, func_name, *args, **kwargs)

View file

@ -19,5 +19,5 @@ class ExecuteTask(Action):
context: list[Message] = []
llm: BaseGPTAPI = Field(default_factory=LLM)
def run(self, *args, **kwargs):
async def run(self, *args, **kwargs):
pass

View file

@ -11,6 +11,3 @@ class FixBug(Action):
"""Fix bug action without any implementation details"""
name: str = "FixBug"
async def run(self, *args, **kwargs):
raise NotImplementedError

View file

@ -0,0 +1,68 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/19
@Author : mashenquan
@File : rebuild_class_view.py
@Desc : Rebuild class view info
"""
import re
from pathlib import Path
from metagpt.actions import Action
from metagpt.config import CONFIG
from metagpt.const import CLASS_VIEW_FILE_REPO, GRAPH_REPO_FILE_REPO
from metagpt.repo_parser import RepoParser
from metagpt.utils.di_graph_repository import DiGraphRepository
from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
class RebuildClassView(Action):
def __init__(self, name="", context=None, llm=None):
super().__init__(name=name, context=context, llm=llm)
async def run(self, with_messages=None, format=CONFIG.prompt_schema):
graph_repo_pathname = CONFIG.git_repo.workdir / GRAPH_REPO_FILE_REPO / CONFIG.git_repo.workdir.name
graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
repo_parser = RepoParser(base_directory=self.context)
class_views = await repo_parser.rebuild_class_views(path=Path(self.context)) # use pylint
await GraphRepository.update_graph_db_with_class_views(graph_db, class_views)
symbols = repo_parser.generate_symbols() # use ast
for file_info in symbols:
await GraphRepository.update_graph_db_with_file_info(graph_db, file_info)
await self._create_mermaid_class_view(graph_db=graph_db)
await self._save(graph_db=graph_db)
async def _create_mermaid_class_view(self, graph_db):
pass
# dataset = await graph_db.select(subject=concat_namespace(filename, class_name), predicate=GraphKeyword.HAS_PAGE_INFO)
# if not dataset:
# logger.warning(f"No page info for {concat_namespace(filename, class_name)}")
# return
# code_block_info = CodeBlockInfo.parse_raw(dataset[0].object_)
# src_code = await read_file_block(filename=Path(self.context) / filename, lineno=code_block_info.lineno, end_lineno=code_block_info.end_lineno)
# code_type = ""
# dataset = await graph_db.select(subject=filename, predicate=GraphKeyword.IS)
# for spo in dataset:
# if spo.object_ in ["javascript", "python"]:
# code_type = spo.object_
# break
# try:
# node = await REBUILD_CLASS_VIEW_NODE.fill(context=f"```{code_type}\n{src_code}\n```", llm=self.llm, to=format)
# class_view = node.instruct_content.dict()["Class View"]
# except Exception as e:
# class_view = RepoParser.rebuild_class_view(src_code, code_type)
# await graph_db.insert(subject=concat_namespace(filename, class_name), predicate=GraphKeyword.HAS_CLASS_VIEW, object_=class_view)
# logger.info(f"{concat_namespace(filename, class_name)} {GraphKeyword.HAS_CLASS_VIEW} {class_view}")
async def _save(self, graph_db):
class_view_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CLASS_VIEW_FILE_REPO)
dataset = await graph_db.select(predicate=GraphKeyword.HAS_CLASS_VIEW)
all_class_view = []
for spo in dataset:
title = f"---\ntitle: {spo.subject}\n---\n"
filename = re.sub(r"[/:]", "_", spo.subject) + ".mmd"
await class_view_file_repo.save(filename=filename, content=title + spo.object_)
all_class_view.append(spo.object_)
await class_view_file_repo.save(filename="all.mmd", content="\n".join(all_class_view))

View file

@ -0,0 +1,33 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/19
@Author : mashenquan
@File : rebuild_class_view_an.py
@Desc : Defines `ActionNode` objects used by rebuild_class_view.py
"""
from metagpt.actions.action_node import ActionNode
CLASS_SOURCE_CODE_BLOCK = ActionNode(
key="Class View",
expected_type=str,
instruction='Generate the mermaid class diagram corresponding to source code in "context."',
example="""
classDiagram
class A {
-int x
+int y
-int speed
-int direction
+__init__(x: int, y: int, speed: int, direction: int)
+change_direction(new_direction: int) None
+move() None
}
""",
)
REBUILD_CLASS_VIEW_NODES = [
CLASS_SOURCE_CODE_BLOCK,
]
REBUILD_CLASS_VIEW_NODE = ActionNode.from_children("RebuildClassView", REBUILD_CLASS_VIEW_NODES)

View file

@ -105,6 +105,7 @@ You are a member of a professional butler team and will provide helpful suggesti
"""
# TOTEST
class SearchAndSummarize(Action):
name: str = ""
content: Optional[str] = None

View file

@ -0,0 +1,111 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/8/28
@Author : mashenquan
@File : skill_action.py
@Desc : Call learned skill
"""
from __future__ import annotations
import ast
import importlib
import traceback
from copy import deepcopy
from typing import Dict, Optional
from metagpt.actions import Action
from metagpt.learn.skill_loader import Skill
from metagpt.logs import logger
from metagpt.schema import Message
# TOTEST
class ArgumentsParingAction(Action):
skill: Skill
ask: str
rsp: Optional[Message] = None
args: Optional[Dict] = None
@property
def prompt(self):
prompt = "You are a function parser. You can convert spoken words into function parameters.\n"
prompt += "\n---\n"
prompt += f"{self.skill.name} function parameters description:\n"
for k, v in self.skill.arguments.items():
prompt += f"parameter `{k}`: {v}\n"
prompt += "\n---\n"
prompt += "Examples:\n"
for e in self.skill.examples:
prompt += f"If want you to do `{e.ask}`, return `{e.answer}` brief and clear.\n"
prompt += "\n---\n"
prompt += (
f"\nRefer to the `{self.skill.name}` function description, and fill in the function parameters according "
'to the example "I want you to do xx" in the Examples section.'
f"\nNow I want you to do `{self.ask}`, return function parameters in Examples format above, brief and "
"clear."
)
return prompt
async def run(self, with_message=None, **kwargs) -> Message:
prompt = self.prompt
rsp = await self.llm.aask(msg=prompt, system_msgs=[])
logger.debug(f"SKILL:{prompt}\n, RESULT:{rsp}")
self.args = ArgumentsParingAction.parse_arguments(skill_name=self.skill.name, txt=rsp)
self.rsp = Message(content=rsp, role="assistant", instruct_content=self.args, cause_by=self)
return self.rsp
@staticmethod
def parse_arguments(skill_name, txt) -> dict:
prefix = skill_name + "("
if prefix not in txt:
logger.error(f"{skill_name} not in {txt}")
return None
if ")" not in txt:
logger.error(f"')' not in {txt}")
return None
begin_ix = txt.find(prefix)
end_ix = txt.rfind(")")
args_txt = txt[begin_ix + len(prefix) : end_ix]
logger.info(args_txt)
fake_expression = f"dict({args_txt})"
parsed_expression = ast.parse(fake_expression, mode="eval")
args = {}
for keyword in parsed_expression.body.keywords:
key = keyword.arg
value = ast.literal_eval(keyword.value)
args[key] = value
return args
class SkillAction(Action):
skill: Skill
args: Dict
rsp: Optional[Message] = None
async def run(self, with_message=None, **kwargs) -> Message:
"""Run action"""
options = deepcopy(kwargs)
if self.args:
for k in self.args.keys():
if k in options:
options.pop(k)
try:
rsp = await self.find_and_call_function(self.skill.name, args=self.args, **options)
self.rsp = Message(content=rsp, role="assistant", cause_by=self)
except Exception as e:
logger.exception(f"{e}, traceback:{traceback.format_exc()}")
self.rsp = Message(content=f"Error: {e}", role="assistant", cause_by=self)
return self.rsp
@staticmethod
async def find_and_call_function(function_name, args, **kwargs) -> str:
try:
module = importlib.import_module("metagpt.learn")
function = getattr(module, function_name)
# Invoke function and return result
result = await function(**args, **kwargs)
return result
except (ModuleNotFoundError, AttributeError):
logger.error(f"{function_name} not found")
raise ValueError(f"{function_name} not found")

View file

@ -91,6 +91,7 @@ flowchart TB
"""
# TOTEST
class SummarizeCode(Action):
name: str = "SummarizeCode"
context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext)

View file

@ -0,0 +1,163 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/8/28
@Author : mashenquan
@File : talk_action.py
@Desc : Act as its a talk
"""
from typing import Optional
from metagpt.actions import Action
from metagpt.config import CONFIG
from metagpt.const import DEFAULT_LANGUAGE
from metagpt.logs import logger
from metagpt.schema import Message
# TOTEST
class TalkAction(Action):
context: str
history_summary: str = ""
knowledge: str = ""
rsp: Optional[Message] = None
@property
def prompt(self):
prompt = ""
if CONFIG.agent_description:
prompt = (
f"You are {CONFIG.agent_description}. Your responses should align with the role-play agreement, "
f"maintaining the character's persona and habits. When faced with unrelated questions, playfully "
f"decline to answer without revealing your AI nature to preserve the character's image.\n\n"
)
prompt += f"Knowledge:\n{self.knowledge}\n\n" if self.knowledge else ""
prompt += f"{self.history_summary}\n\n"
prompt += (
"If the information is insufficient, you can search in the historical conversation or knowledge above.\n"
)
language = CONFIG.language or DEFAULT_LANGUAGE
prompt += (
f"Answer the following questions strictly in {language}, and the answers must follow the Markdown format.\n "
f"{self.context}"
)
logger.debug(f"PROMPT: {prompt}")
return prompt
@property
def prompt_gpt4(self):
kvs = {
"{role}": CONFIG.agent_description or "",
"{history}": self.history_summary or "",
"{knowledge}": self.knowledge or "",
"{language}": CONFIG.language or DEFAULT_LANGUAGE,
"{ask}": self.context,
}
prompt = TalkActionPrompt.FORMATION_LOOSE
for k, v in kvs.items():
prompt = prompt.replace(k, v)
logger.info(f"PROMPT: {prompt}")
return prompt
# async def run_old(self, *args, **kwargs) -> ActionOutput:
# prompt = self.prompt
# rsp = await self.llm.aask(msg=prompt, system_msgs=[])
# logger.debug(f"PROMPT:{prompt}\nRESULT:{rsp}\n")
# self._rsp = ActionOutput(content=rsp)
# return self._rsp
@property
def aask_args(self):
language = CONFIG.language or DEFAULT_LANGUAGE
system_msgs = [
f"You are {CONFIG.agent_description}.",
"Your responses should align with the role-play agreement, "
"maintaining the character's persona and habits. When faced with unrelated questions, playfully "
"decline to answer without revealing your AI nature to preserve the character's image.",
"If the information is insufficient, you can search in the context or knowledge.",
f"Answer the following questions strictly in {language}, and the answers must follow the Markdown format.",
]
format_msgs = []
if self.knowledge:
format_msgs.append({"role": "assistant", "content": self.knowledge})
if self.history_summary:
format_msgs.append({"role": "assistant", "content": self.history_summary})
return self.context, format_msgs, system_msgs
async def run(self, with_message=None, **kwargs) -> Message:
msg, format_msgs, system_msgs = self.aask_args
rsp = await self.llm.aask(msg=msg, format_msgs=format_msgs, system_msgs=system_msgs)
self.rsp = Message(content=rsp, role="assistant", cause_by=self)
return self.rsp
class TalkActionPrompt:
FORMATION = """Formation: "Capacity and role" defines the role you are currently playing;
"[HISTORY_BEGIN]" and "[HISTORY_END]" tags enclose the historical conversation;
"[KNOWLEDGE_BEGIN]" and "[KNOWLEDGE_END]" tags enclose the knowledge may help for your responses;
"Statement" defines the work detail you need to complete at this stage;
"[ASK_BEGIN]" and [ASK_END] tags enclose the questions;
"Constraint" defines the conditions that your responses must comply with.
"Personality" defines your language style
"Insight" provides a deeper understanding of the characters' inner traits.
"Initial" defines the initial setup of a character.
Capacity and role: {role}
Statement: Your responses should align with the role-play agreement, maintaining the
character's persona and habits. When faced with unrelated questions, playfully decline to answer without revealing
your AI nature to preserve the character's image.
[HISTORY_BEGIN]
{history}
[HISTORY_END]
[KNOWLEDGE_BEGIN]
{knowledge}
[KNOWLEDGE_END]
Statement: If the information is insufficient, you can search in the historical conversation or knowledge.
Statement: Unless you are a language professional, answer the following questions strictly in {language}
, and the answers must follow the Markdown format. Strictly excluding any tag likes "[HISTORY_BEGIN]"
, "[HISTORY_END]", "[KNOWLEDGE_BEGIN]", "[KNOWLEDGE_END]" in responses.
{ask}
"""
FORMATION_LOOSE = """Formation: "Capacity and role" defines the role you are currently playing;
"[HISTORY_BEGIN]" and "[HISTORY_END]" tags enclose the historical conversation;
"[KNOWLEDGE_BEGIN]" and "[KNOWLEDGE_END]" tags enclose the knowledge may help for your responses;
"Statement" defines the work detail you need to complete at this stage;
"Constraint" defines the conditions that your responses must comply with.
"Personality" defines your language style
"Insight" provides a deeper understanding of the characters' inner traits.
"Initial" defines the initial setup of a character.
Capacity and role: {role}
Statement: Your responses should maintaining the character's persona and habits. When faced with unrelated questions
, playfully decline to answer without revealing your AI nature to preserve the character's image.
[HISTORY_BEGIN]
{history}
[HISTORY_END]
[KNOWLEDGE_BEGIN]
{knowledge}
[KNOWLEDGE_END]
Statement: If the information is insufficient, you can search in the historical conversation or knowledge.
Statement: Unless you are a language professional, answer the following questions strictly in {language}
, and the answers must follow the Markdown format. Strictly excluding any tag likes "[HISTORY_BEGIN]"
, "[HISTORY_END]", "[KNOWLEDGE_BEGIN]", "[KNOWLEDGE_END]" in responses.
{ask}
"""

View file

@ -123,7 +123,7 @@ class WritePRD(Action):
# logger.info(rsp)
project_name = CONFIG.project_name if CONFIG.project_name else ""
context = CONTEXT_TEMPLATE.format(requirements=requirements, project_name=project_name)
node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, schema=schema)
node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm) # schema=schema
await self._rename_workspace(node)
return node

View file

@ -0,0 +1,193 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/7/27
@Author : mashenquan
@File : write_teaching_plan.py
"""
from typing import Optional
from pydantic import Field
from metagpt.actions import Action
from metagpt.config import CONFIG
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
class WriteTeachingPlanPart(Action):
"""Write Teaching Plan Part"""
context: Optional[str] = None
llm: BaseGPTAPI = Field(default_factory=LLM)
topic: str = ""
language: str = "Chinese"
rsp: Optional[str] = None
async def run(self, with_message=None, **kwargs):
statement_patterns = TeachingPlanBlock.TOPIC_STATEMENTS.get(self.topic, [])
statements = []
for p in statement_patterns:
s = self.format_value(p)
statements.append(s)
formatter = (
TeachingPlanBlock.PROMPT_TITLE_TEMPLATE
if self.topic == TeachingPlanBlock.COURSE_TITLE
else TeachingPlanBlock.PROMPT_TEMPLATE
)
prompt = formatter.format(
formation=TeachingPlanBlock.FORMATION,
role=self.prefix,
statements="\n".join(statements),
lesson=self.context,
topic=self.topic,
language=self.language,
)
logger.debug(prompt)
rsp = await self._aask(prompt=prompt)
logger.debug(rsp)
self._set_result(rsp)
return self.rsp
def _set_result(self, rsp):
if TeachingPlanBlock.DATA_BEGIN_TAG in rsp:
ix = rsp.index(TeachingPlanBlock.DATA_BEGIN_TAG)
rsp = rsp[ix + len(TeachingPlanBlock.DATA_BEGIN_TAG) :]
if TeachingPlanBlock.DATA_END_TAG in rsp:
ix = rsp.index(TeachingPlanBlock.DATA_END_TAG)
rsp = rsp[0:ix]
self.rsp = rsp.strip()
if self.topic != TeachingPlanBlock.COURSE_TITLE:
return
if "#" not in self.rsp or self.rsp.index("#") != 0:
self.rsp = "# " + self.rsp
def __str__(self):
"""Return `topic` value when str()"""
return self.topic
def __repr__(self):
"""Show `topic` value when debug"""
return self.topic
@staticmethod
def format_value(value):
"""Fill parameters inside `value` with `options`."""
if not isinstance(value, str):
return value
if "{" not in value:
return value
merged_opts = CONFIG.options or {}
try:
return value.format(**merged_opts)
except KeyError as e:
logger.warning(f"Parameter is missing:{e}")
for k, v in merged_opts.items():
value = value.replace("{" + f"{k}" + "}", str(v))
return value
class TeachingPlanBlock:
FORMATION = (
'"Capacity and role" defines the role you are currently playing;\n'
'\t"[LESSON_BEGIN]" and "[LESSON_END]" tags enclose the content of textbook;\n'
'\t"Statement" defines the work detail you need to complete at this stage;\n'
'\t"Answer options" defines the format requirements for your responses;\n'
'\t"Constraint" defines the conditions that your responses must comply with.'
)
COURSE_TITLE = "Title"
TOPICS = [
COURSE_TITLE,
"Teaching Hours",
"Teaching Objectives",
"Teaching Content",
"Teaching Methods and Strategies",
"Learning Activities",
"Teaching Time Allocation",
"Assessment and Feedback",
"Teaching Summary and Improvement",
"Vocabulary Cloze",
"Choice Questions",
"Grammar Questions",
"Translation Questions",
]
TOPIC_STATEMENTS = {
COURSE_TITLE: [
"Statement: Find and return the title of the lesson only in markdown first-level header format, "
"without anything else."
],
"Teaching Content": [
'Statement: "Teaching Content" must include vocabulary, analysis, and examples of various grammar '
"structures that appear in the textbook, as well as the listening materials and key points.",
'Statement: "Teaching Content" must include more examples.',
],
"Teaching Time Allocation": [
'Statement: "Teaching Time Allocation" must include how much time is allocated to each '
"part of the textbook content."
],
"Teaching Methods and Strategies": [
'Statement: "Teaching Methods and Strategies" must include teaching focus, difficulties, materials, '
"procedures, in detail."
],
"Vocabulary Cloze": [
'Statement: Based on the content of the textbook enclosed by "[LESSON_BEGIN]" and "[LESSON_END]", '
"create vocabulary cloze. The cloze should include 10 {language} questions with {teaching_language} "
"answers, and it should also include 10 {teaching_language} questions with {language} answers. "
"The key-related vocabulary and phrases in the textbook content must all be included in the exercises.",
],
"Grammar Questions": [
'Statement: Based on the content of the textbook enclosed by "[LESSON_BEGIN]" and "[LESSON_END]", '
"create grammar questions. 10 questions."
],
"Choice Questions": [
'Statement: Based on the content of the textbook enclosed by "[LESSON_BEGIN]" and "[LESSON_END]", '
"create choice questions. 10 questions."
],
"Translation Questions": [
'Statement: Based on the content of the textbook enclosed by "[LESSON_BEGIN]" and "[LESSON_END]", '
"create translation questions. The translation should include 10 {language} questions with "
"{teaching_language} answers, and it should also include 10 {teaching_language} questions with "
"{language} answers."
],
}
# Teaching plan title
PROMPT_TITLE_TEMPLATE = (
"Do not refer to the context of the previous conversation records, "
"start the conversation anew.\n\n"
"Formation: {formation}\n\n"
"{statements}\n"
"Constraint: Writing in {language}.\n"
'Answer options: Encloses the lesson title with "[TEACHING_PLAN_BEGIN]" '
'and "[TEACHING_PLAN_END]" tags.\n'
"[LESSON_BEGIN]\n"
"{lesson}\n"
"[LESSON_END]"
)
# Teaching plan parts:
PROMPT_TEMPLATE = (
"Do not refer to the context of the previous conversation records, "
"start the conversation anew.\n\n"
"Formation: {formation}\n\n"
"Capacity and role: {role}\n"
'Statement: Write the "{topic}" part of teaching plan, '
'WITHOUT ANY content unrelated to "{topic}"!!\n'
"{statements}\n"
'Answer options: Enclose the teaching plan content with "[TEACHING_PLAN_BEGIN]" '
'and "[TEACHING_PLAN_END]" tags.\n'
"Answer options: Using proper markdown format from second-level header format.\n"
"Constraint: Writing in {language}.\n"
"[LESSON_BEGIN]\n"
"{lesson}\n"
"[LESSON_END]"
)
DATA_BEGIN_TAG = "[TEACHING_PLAN_BEGIN]"
DATA_END_TAG = "[TEACHING_PLAN_END]"

View file

@ -44,7 +44,7 @@ you should correctly import the necessary classes based on these file locations!
class WriteTest(Action):
name: str = "WriteTest"
context: Optional[str] = None
context: Optional[TestingContext] = None
llm: BaseGPTAPI = Field(default_factory=LLM)
async def write_code(self, prompt):

View file

@ -6,12 +6,15 @@ Provide configuration, singleton
1. According to Section 2.2.3.11 of RFC 135, add git repository support.
2. Add the parameter `src_workspace` for the old version project path.
"""
import datetime
import json
import os
import warnings
from copy import deepcopy
from enum import Enum
from pathlib import Path
from typing import Any
from uuid import uuid4
import yaml
@ -19,6 +22,7 @@ from metagpt.const import DEFAULT_WORKSPACE_ROOT, METAGPT_ROOT, OPTIONS
from metagpt.logs import logger
from metagpt.tools import SearchEngineType, WebBrowserEngineType
from metagpt.utils.common import require_python_version
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.singleton import Singleton
@ -42,6 +46,8 @@ class LLMProviderEnum(Enum):
FIREWORKS = "fireworks"
OPEN_LLM = "open_llm"
GEMINI = "gemini"
METAGPT = "metagpt"
AZURE_OPENAI = "azure_openai"
OLLAMA = "ollama"
@ -58,7 +64,7 @@ class Config(metaclass=Singleton):
key_yaml_file = METAGPT_ROOT / "config/key.yaml"
default_yaml_file = METAGPT_ROOT / "config/config.yaml"
def __init__(self, yaml_file=default_yaml_file):
def __init__(self, yaml_file=default_yaml_file, cost_data=""):
global_options = OPTIONS.get()
# cli paras
self.project_path = ""
@ -68,32 +74,57 @@ class Config(metaclass=Singleton):
self.max_auto_summarize_code = 0
self._init_with_config_files_and_env(yaml_file)
# The agent needs to be billed per user, so billing information cannot be destroyed when the session ends.
self.cost_manager = CostManager(**json.loads(cost_data)) if cost_data else CostManager()
self._update()
global_options.update(OPTIONS.get())
logger.debug("Config loading done.")
def get_default_llm_provider_enum(self) -> LLMProviderEnum:
for k, v in [
(self.openai_api_key, LLMProviderEnum.OPENAI),
(self.anthropic_api_key, LLMProviderEnum.ANTHROPIC),
(self.zhipuai_api_key, LLMProviderEnum.ZHIPUAI),
(self.fireworks_api_key, LLMProviderEnum.FIREWORKS),
(self.open_llm_api_base, LLMProviderEnum.OPEN_LLM),
(self.gemini_api_key, LLMProviderEnum.GEMINI),
(self.ollama_api_base, LLMProviderEnum.OLLAMA), # reuse logic. but not a key
]:
if self._is_valid_llm_key(k):
# logger.debug(f"Use LLMProvider: {v.value}")
if v == LLMProviderEnum.GEMINI and not require_python_version(req_version=(3, 10)):
warnings.warn("Use Gemini requires Python >= 3.10")
if self.openai_api_key and self.openai_api_model:
logger.info(f"OpenAI API Model: {self.openai_api_model}")
return v
"""Get first valid LLM provider enum"""
mappings = {
LLMProviderEnum.OPENAI: bool(
self._is_valid_llm_key(self.OPENAI_API_KEY) and not self.OPENAI_API_TYPE and self.OPENAI_API_MODEL
),
LLMProviderEnum.ANTHROPIC: self._is_valid_llm_key(self.ANTHROPIC_API_KEY),
LLMProviderEnum.ZHIPUAI: self._is_valid_llm_key(self.ZHIPUAI_API_KEY),
LLMProviderEnum.FIREWORKS: self._is_valid_llm_key(self.FIREWORKS_API_KEY),
LLMProviderEnum.OPEN_LLM: self._is_valid_llm_key(self.OPEN_LLM_API_BASE),
LLMProviderEnum.GEMINI: self._is_valid_llm_key(self.GEMINI_API_KEY),
LLMProviderEnum.METAGPT: bool(
self._is_valid_llm_key(self.OPENAI_API_KEY) and self.OPENAI_API_TYPE == "metagpt"
),
LLMProviderEnum.AZURE_OPENAI: bool(
self._is_valid_llm_key(self.OPENAI_API_KEY)
and self.OPENAI_API_TYPE == "azure"
and self.DEPLOYMENT_NAME
and self.OPENAI_API_VERSION
),
LLMProviderEnum.OLLAMA: self._is_valid_llm_key(self.OLLAMA_API_BASE),
}
provider = None
for k, v in mappings.items():
if v:
provider = k
break
if provider is LLMProviderEnum.GEMINI and not require_python_version(req_version=(3, 10)):
warnings.warn("Use Gemini requires Python >= 3.10")
model_mappings = {
LLMProviderEnum.OPENAI: self.OPENAI_API_MODEL,
LLMProviderEnum.AZURE_OPENAI: self.DEPLOYMENT_NAME,
}
model_name = model_mappings.get(provider)
if model_name:
logger.info(f"{provider} Model: {model_name}")
if provider:
logger.info(f"API: {provider}")
return provider
raise NotConfiguredException("You should config a LLM configuration first")
@staticmethod
def _is_valid_llm_key(k: str) -> bool:
return k and k != "YOUR_API_KEY"
return bool(k and k != "YOUR_API_KEY")
def _update(self):
self.global_proxy = self._get("GLOBAL_PROXY")
@ -111,7 +142,7 @@ class Config(metaclass=Singleton):
if not self._get("DISABLE_LLM_PROVIDER_CHECK"):
_ = self.get_default_llm_provider_enum()
self.openai_base_url = self._get("OPENAI_BASE_URL")
# self.openai_base_url = self._get("OPENAI_BASE_URL")
self.openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy
self.openai_api_type = self._get("OPENAI_API_TYPE")
self.openai_api_version = self._get("OPENAI_API_VERSION")
@ -142,8 +173,7 @@ class Config(metaclass=Singleton):
self.long_term_memory = self._get("LONG_TERM_MEMORY", False)
if self.long_term_memory:
logger.warning("LONG_TERM_MEMORY is True")
self.max_budget = self._get("MAX_BUDGET", 10.0)
self.total_cost = 0.0
self.cost_manager.max_budget = self._get("MAX_BUDGET", 10.0)
self.code_review_k_times = 2
self.puppeteer_config = self._get("PUPPETEER_CONFIG", "")
@ -154,10 +184,18 @@ class Config(metaclass=Singleton):
self.mermaid_engine = self._get("MERMAID_ENGINE", "nodejs")
self.pyppeteer_executable_path = self._get("PYPPETEER_EXECUTABLE_PATH", "")
workspace_uid = (
self._get("WORKSPACE_UID") or f"{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}-{uuid4().hex[-8:]}"
)
self.repair_llm_output = self._get("REPAIR_LLM_OUTPUT", False)
self.prompt_schema = self._get("PROMPT_FORMAT", "json")
self.workspace_path = Path(self._get("WORKSPACE_PATH", DEFAULT_WORKSPACE_ROOT))
val = self._get("WORKSPACE_PATH_WITH_UID")
if val and val.lower() == "true": # for agent
self.workspace_path = self.workspace_path / workspace_uid
self._ensure_workspace_exists()
self.max_auto_summarize_code = self.max_auto_summarize_code or self._get("MAX_AUTO_SUMMARIZE_CODE", 1)
self.timeout = int(self._get("TIMEOUT", 3))
def update_via_cli(self, project_path, project_name, inc, reqa_file, max_auto_summarize_code):
"""update config via cli"""
@ -198,7 +236,8 @@ class Config(metaclass=Singleton):
return i.get(*args, **kwargs)
def get(self, key, *args, **kwargs):
"""Search for a value in config/key.yaml, config/config.yaml, and env; raise an error if not found"""
"""Retrieve values from config/key.yaml, config/config.yaml, and environment variables.
Throw an error if not found."""
value = self._get(key, *args, **kwargs)
if value is None:
raise ValueError(f"Key '{key}' not found in environment variables or in the YAML file")

View file

@ -48,9 +48,10 @@ def get_metagpt_root():
# METAGPT PROJECT ROOT AND VARS
METAGPT_ROOT = get_metagpt_root()
METAGPT_ROOT = get_metagpt_root() # Dependent on METAGPT_PROJECT_ROOT
DEFAULT_WORKSPACE_ROOT = METAGPT_ROOT / "workspace"
EXAMPLE_PATH = METAGPT_ROOT / "examples"
DATA_PATH = METAGPT_ROOT / "data"
RESEARCH_PATH = DATA_PATH / "research"
TUTORIAL_PATH = DATA_PATH / "tutorial_docx"
@ -100,7 +101,27 @@ TEST_CODES_FILE_REPO = "tests"
TEST_OUTPUTS_FILE_REPO = "test_outputs"
CODE_SUMMARIES_FILE_REPO = "docs/code_summaries"
CODE_SUMMARIES_PDF_FILE_REPO = "resources/code_summaries"
RESOURCES_FILE_REPO = "resources"
SD_OUTPUT_FILE_REPO = "resources/SD_Output"
GRAPH_REPO_FILE_REPO = "docs/graph_repo"
CLASS_VIEW_FILE_REPO = "docs/class_views"
YAPI_URL = "http://yapi.deepwisdomai.com/"
DEFAULT_LANGUAGE = "English"
DEFAULT_MAX_TOKENS = 1500
COMMAND_TOKENS = 500
BRAIN_MEMORY = "BRAIN_MEMORY"
SKILL_PATH = "SKILL_PATH"
SERPER_API_KEY = "SERPER_API_KEY"
DEFAULT_TOKEN_SIZE = 500
# format
BASE64_FORMAT = "base64"
# REDIS
REDIS_KEY = "REDIS_KEY"
LLM_API_TIMEOUT = 300
# Message id
IGNORED_MESSAGE_ID = "0"

View file

@ -13,6 +13,7 @@ from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain_core.embeddings import Embeddings
from metagpt.config import CONFIG
from metagpt.const import DATA_PATH
from metagpt.document import IndexableDocument
from metagpt.document_store.base_store import LocalStore
@ -25,7 +26,9 @@ class FaissStore(LocalStore):
):
self.meta_col = meta_col
self.content_col = content_col
self.embedding = embedding or OpenAIEmbeddings()
self.embedding = embedding or OpenAIEmbeddings(
openai_api_key=CONFIG.openai_api_key, openai_api_base=CONFIG.openai_base_url
)
super().__init__(raw_data, cache_dir)
def _load(self) -> Optional["FaissStore"]:

View file

@ -17,6 +17,7 @@ from typing import Iterable, Set
from pydantic import BaseModel, Field
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.roles.role import Role, role_subclass_registry
from metagpt.schema import Message
@ -108,7 +109,7 @@ class Environment(BaseModel):
for role in roles: # setup system message with roles
role.set_env(self)
def publish_message(self, message: Message) -> bool:
def publish_message(self, message: Message, peekable: bool = True) -> bool:
"""
Distribute the message to the recipients.
In accordance with the Message routing structure design in Chapter 2.2.1 of RFC 116, as already planned
@ -173,3 +174,8 @@ class Environment(BaseModel):
def set_subscription(self, obj, tags):
"""Set the labels for message to be consumed by the object"""
self.members[obj] = tags
@staticmethod
def archive(auto_archive=True):
if auto_archive and CONFIG.git_repo:
CONFIG.git_repo.archive()

View file

@ -5,3 +5,9 @@
@Author : alexanderwu
@File : __init__.py
"""
from metagpt.learn.text_to_image import text_to_image
from metagpt.learn.text_to_speech import text_to_speech
from metagpt.learn.google_search import google_search
__all__ = ["text_to_image", "text_to_speech", "google_search"]

View file

@ -0,0 +1,12 @@
from metagpt.tools.search_engine import SearchEngine
async def google_search(query: str, max_results: int = 6, **kwargs):
"""Perform a web search and retrieve search results.
:param query: The search query.
:param max_results: The number of search results to retrieve
:return: The web search results in markdown format.
"""
resluts = await SearchEngine().run(query, max_results=max_results, as_string=False)
return "\n".join(f"{i}. [{j['title']}]({j['link']}): {j['snippet']}" for i, j in enumerate(resluts, 1))

View file

@ -0,0 +1,100 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/8/18
@Author : mashenquan
@File : skill_loader.py
@Desc : Skill YAML Configuration Loader.
"""
from pathlib import Path
from typing import Dict, List, Optional
import aiofiles
import yaml
from pydantic import BaseModel, Field
from metagpt.config import CONFIG
class Example(BaseModel):
ask: str
answer: str
class Returns(BaseModel):
type: str
format: Optional[str] = None
class Parameter(BaseModel):
type: str
description: str = None
class Skill(BaseModel):
name: str
description: str = None
id: str = None
x_prerequisite: Dict = Field(default=None, alias="x-prerequisite")
parameters: Dict[str, Parameter] = None
examples: List[Example]
returns: Returns
@property
def arguments(self) -> Dict:
if not self.parameters:
return {}
ret = {}
for k, v in self.parameters.items():
ret[k] = v.description if v.description else ""
return ret
class Entity(BaseModel):
name: str = None
skills: List[Skill]
class Components(BaseModel):
pass
class SkillsDeclaration(BaseModel):
skillapi: str
entities: Dict[str, Entity]
components: Components = None
@staticmethod
async def load(skill_yaml_file_name: Path = None) -> "SkillsDeclaration":
if not skill_yaml_file_name:
skill_yaml_file_name = Path(__file__).parent.parent.parent / ".well-known/skills.yaml"
async with aiofiles.open(str(skill_yaml_file_name), mode="r") as reader:
data = await reader.read(-1)
skill_data = yaml.safe_load(data)
return SkillsDeclaration(**skill_data)
def get_skill_list(self, entity_name: str = "Assistant") -> Dict:
"""Return the skill name based on the skill description."""
entity = self.entities.get(entity_name)
if not entity:
return {}
# List of skills that the agent chooses to activate.
agent_skills = CONFIG.agent_skills
if not agent_skills:
return {}
class _AgentSkill(BaseModel):
name: str
names = [_AgentSkill(**i).name for i in agent_skills]
return {s.description: s.name for s in entity.skills if s.name in names}
def get_skill(self, name, entity_name: str = "Assistant") -> Skill:
"""Return a skill by name."""
entity = self.entities.get(entity_name)
if not entity:
return None
for sk in entity.skills:
if sk.name == name:
return sk

View file

@ -0,0 +1,24 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/8/18
@Author : mashenquan
@File : text_to_embedding.py
@Desc : Text-to-Embedding skill, which provides text-to-embedding functionality.
"""
from metagpt.config import CONFIG
from metagpt.tools.openai_text_to_embedding import oas3_openai_text_to_embedding
async def text_to_embedding(text, model="text-embedding-ada-002", openai_api_key="", **kwargs):
"""Text to embedding
:param text: The text used for embedding.
:param model: One of ['text-embedding-ada-002'], ID of the model to use. For more details, checkout: `https://api.openai.com/v1/models`.
:param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
:return: A json object of :class:`ResultEmbedding` class if successful, otherwise `{}`.
"""
if CONFIG.OPENAI_API_KEY or openai_api_key:
return await oas3_openai_text_to_embedding(text, model=model, openai_api_key=openai_api_key)
raise EnvironmentError

View file

@ -0,0 +1,38 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/8/18
@Author : mashenquan
@File : text_to_image.py
@Desc : Text-to-Image skill, which provides text-to-image functionality.
"""
from metagpt.config import CONFIG
from metagpt.const import BASE64_FORMAT
from metagpt.tools.metagpt_text_to_image import oas3_metagpt_text_to_image
from metagpt.tools.openai_text_to_image import oas3_openai_text_to_image
from metagpt.utils.s3 import S3
async def text_to_image(text, size_type: str = "512x512", openai_api_key="", model_url="", **kwargs):
"""Text to image
:param text: The text used for image conversion.
:param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
:param size_type: If using OPENAI, the available size options are ['256x256', '512x512', '1024x1024'], while for MetaGPT, the options are ['512x512', '512x768'].
:param model_url: MetaGPT model url
:return: The image data is returned in Base64 encoding.
"""
image_declaration = "data:image/png;base64,"
if CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL or model_url:
base64_data = await oas3_metagpt_text_to_image(text, size_type, model_url)
elif CONFIG.OPENAI_API_KEY or openai_api_key:
base64_data = await oas3_openai_text_to_image(text, size_type)
else:
raise ValueError("Missing necessary parameters.")
s3 = S3()
url = await s3.cache(data=base64_data, file_ext=".png", format=BASE64_FORMAT) if s3.is_valid else ""
if url:
return f"![{text}]({url})"
return image_declaration + base64_data if base64_data else ""

View file

@ -0,0 +1,72 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/8/17
@Author : mashenquan
@File : text_to_speech.py
@Desc : Text-to-Speech skill, which provides text-to-speech functionality
"""
import openai
from metagpt.config import CONFIG
from metagpt.const import BASE64_FORMAT
from metagpt.tools.azure_tts import oas3_azsure_tts
from metagpt.tools.iflytek_tts import oas3_iflytek_tts
from metagpt.utils.s3 import S3
async def text_to_speech(
text,
lang="zh-CN",
voice="zh-CN-XiaomoNeural",
style="affectionate",
role="Girl",
subscription_key="",
region="",
iflytek_app_id="",
iflytek_api_key="",
iflytek_api_secret="",
**kwargs,
):
"""Text to speech
For more details, check out:`https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`
:param lang: The value can contain a language code such as en (English), or a locale such as en-US (English - United States). For more details, checkout: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`
:param voice: For more details, checkout: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`, `https://speech.microsoft.com/portal/voicegallery`
:param style: Speaking style to express different emotions like cheerfulness, empathy, and calm. For more details, checkout: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`
:param role: With roles, the same voice can act as a different age and gender. For more details, checkout: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`
:param text: The text used for voice conversion.
:param subscription_key: key is used to access your Azure AI service API, see: `https://portal.azure.com/` > `Resource Management` > `Keys and Endpoint`
:param region: This is the location (or region) of your resource. You may need to use this field when making calls to this API.
:param iflytek_app_id: Application ID is used to access your iFlyTek service API, see: `https://console.xfyun.cn/services/tts`
:param iflytek_api_key: WebAPI argument, see: `https://console.xfyun.cn/services/tts`
:param iflytek_api_secret: WebAPI argument, see: `https://console.xfyun.cn/services/tts`
:return: Returns the Base64-encoded .wav/.mp3 file data if successful, otherwise an empty string.
"""
if (CONFIG.AZURE_TTS_SUBSCRIPTION_KEY and CONFIG.AZURE_TTS_REGION) or (subscription_key and region):
audio_declaration = "data:audio/wav;base64,"
base64_data = await oas3_azsure_tts(text, lang, voice, style, role, subscription_key, region)
s3 = S3()
url = await s3.cache(data=base64_data, file_ext=".wav", format=BASE64_FORMAT) if s3.is_valid else ""
if url:
return f"[{text}]({url})"
return audio_declaration + base64_data if base64_data else base64_data
if (CONFIG.IFLYTEK_APP_ID and CONFIG.IFLYTEK_API_KEY and CONFIG.IFLYTEK_API_SECRET) or (
iflytek_app_id and iflytek_api_key and iflytek_api_secret
):
audio_declaration = "data:audio/mp3;base64,"
base64_data = await oas3_iflytek_tts(
text=text, app_id=iflytek_app_id, api_key=iflytek_api_key, api_secret=iflytek_api_secret
)
s3 = S3()
url = await s3.cache(data=base64_data, file_ext=".mp3", format=BASE64_FORMAT) if s3.is_valid else ""
if url:
return f"[{text}]({url})"
return audio_declaration + base64_data if base64_data else base64_data
raise openai.InvalidRequestError(
message="AZURE_TTS_SUBSCRIPTION_KEY, AZURE_TTS_REGION, IFLYTEK_APP_ID, IFLYTEK_API_KEY, IFLYTEK_API_SECRET error",
param={},
)

View file

@ -4,11 +4,11 @@
@Time : 2023/6/5 01:44
@Author : alexanderwu
@File : skill_manager.py
@Modified By: mashenquan, 2023/8/20. Remove useless `_llm`
"""
from metagpt.actions import Action
from metagpt.const import PROMPT_PATH
from metagpt.document_store.chromadb_store import ChromaStore
from metagpt.llm import LLM
from metagpt.logs import logger
Skill = Action
@ -18,7 +18,6 @@ class SkillManager:
"""Used to manage all skills"""
def __init__(self):
self._llm = LLM()
self._store = ChromaStore("skill_manager")
self._skills: dict[str:Skill] = {}

View file

@ -0,0 +1,253 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/8/18
@Author : mashenquan
@File : brain_memory.py
@Desc : Used by AgentStore. Used for long-term storage and automatic compression.
@Modified By: mashenquan, 2023/9/4. + redis memory cache.
@Modified By: mashenquan, 2023/12/25. Simplify Functionality.
"""
import json
import re
from typing import Dict, List
from pydantic import BaseModel, Field
from metagpt.config import CONFIG
from metagpt.const import DEFAULT_LANGUAGE
from metagpt.logs import logger
from metagpt.provider import MetaGPTAPI
from metagpt.schema import Message, SimpleMessage
from metagpt.utils.redis import Redis
class BrainMemory(BaseModel):
history: List[Message] = Field(default_factory=list)
knowledge: List[Message] = Field(default_factory=list)
historical_summary: str = ""
last_history_id: str = ""
is_dirty: bool = False
last_talk: str = None
cacheable: bool = True
def add_talk(self, msg: Message):
"""
Add message from user.
"""
msg.role = "user"
self.add_history(msg)
self.is_dirty = True
def add_answer(self, msg: Message):
"""Add message from LLM"""
msg.role = "assistant"
self.add_history(msg)
self.is_dirty = True
def get_knowledge(self) -> str:
texts = [m.content for m in self.knowledge]
return "\n".join(texts)
@staticmethod
async def loads(redis_key: str, redis_conf: Dict = None) -> "BrainMemory":
redis = Redis(conf=redis_conf)
if not redis.is_valid() or not redis_key:
return BrainMemory()
v = await redis.get(key=redis_key)
logger.debug(f"REDIS GET {redis_key} {v}")
if v:
bm = BrainMemory.parse_raw(v)
bm.is_dirty = False
return bm
return BrainMemory()
async def dumps(self, redis_key: str, timeout_sec: int = 30 * 60, redis_conf: Dict = None):
if not self.is_dirty:
return
redis = Redis(conf=redis_conf)
if not redis.is_valid() or not redis_key:
return False
v = self.json(ensure_ascii=False)
if self.cacheable:
await redis.set(key=redis_key, data=v, timeout_sec=timeout_sec)
logger.debug(f"REDIS SET {redis_key} {v}")
self.is_dirty = False
@staticmethod
def to_redis_key(prefix: str, user_id: str, chat_id: str):
return f"{prefix}:{user_id}:{chat_id}"
async def set_history_summary(self, history_summary, redis_key, redis_conf):
if self.historical_summary == history_summary:
if self.is_dirty:
await self.dumps(redis_key=redis_key, redis_conf=redis_conf)
self.is_dirty = False
return
self.historical_summary = history_summary
self.history = []
await self.dumps(redis_key=redis_key, redis_conf=redis_conf)
self.is_dirty = False
def add_history(self, msg: Message):
if msg.id:
if self.to_int(msg.id, 0) <= self.to_int(self.last_history_id, -1):
return
self.history.append(msg.dict())
self.last_history_id = str(msg.id)
self.is_dirty = True
def exists(self, text) -> bool:
for m in reversed(self.history):
if m.get("content") == text:
return True
return False
@staticmethod
def to_int(v, default_value):
try:
return int(v)
except:
return default_value
def pop_last_talk(self):
v = self.last_talk
self.last_talk = None
return v
async def summarize(self, llm, max_words=200, keep_language: bool = False, limit: int = -1, **kwargs):
if isinstance(llm, MetaGPTAPI):
return await self._metagpt_summarize(max_words=max_words)
return await self._openai_summarize(llm=llm, max_words=max_words, keep_language=keep_language, limit=limit)
async def _openai_summarize(self, llm, max_words=200, keep_language: bool = False, limit: int = -1):
texts = [self.historical_summary]
for m in self.history:
texts.append(m.content)
text = "\n".join(texts)
text_length = len(text)
if limit > 0 and text_length < limit:
return text
summary = await llm.summarize(text=text, max_words=max_words, keep_language=keep_language, limit=limit)
if summary:
await self.set_history_summary(history_summary=summary, redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS)
return summary
raise ValueError(f"text too long:{text_length}")
async def _metagpt_summarize(self, max_words=200):
if not self.history:
return ""
total_length = 0
msgs = []
for m in reversed(self.history):
delta = len(m.content)
if total_length + delta > max_words:
left = max_words - total_length
if left == 0:
break
m.content = m.content[0:left]
msgs.append(m.dict())
break
msgs.append(m)
total_length += delta
msgs.reverse()
self.history = msgs
self.is_dirty = True
await self.dumps(redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS_CONF)
self.is_dirty = False
return BrainMemory.to_metagpt_history_format(self.history)
@staticmethod
def to_metagpt_history_format(history) -> str:
mmsg = [SimpleMessage(role=m.role, content=m.content) for m in history]
return json.dumps(mmsg)
async def get_title(self, llm, max_words=5, **kwargs) -> str:
"""Generate text title"""
if isinstance(llm, MetaGPTAPI):
return self.history[0].content if self.history else "New"
summary = await self.summarize(llm=llm, max_words=500)
language = CONFIG.language or DEFAULT_LANGUAGE
command = f"Translate the above summary into a {language} title of less than {max_words} words."
summaries = [summary, command]
msg = "\n".join(summaries)
logger.debug(f"title ask:{msg}")
response = await llm.aask(msg=msg, system_msgs=[])
logger.debug(f"title rsp: {response}")
return response
async def is_related(self, text1, text2, llm):
if isinstance(llm, MetaGPTAPI):
return await self._metagpt_is_related(text1=text1, text2=text2, llm=llm)
return await self._openai_is_related(text1=text1, text2=text2, llm=llm)
@staticmethod
async def _metagpt_is_related(**kwargs):
return False
@staticmethod
async def _openai_is_related(text1, text2, llm, **kwargs):
command = (
f"{text2}\n\nIs there any sentence above related to the following sentence: {text1}.\nIf is there "
"any relevance, return [TRUE] brief and clear. Otherwise, return [FALSE] brief and clear."
)
rsp = await llm.aask(msg=command, system_msgs=[])
result = True if "TRUE" in rsp else False
p2 = text2.replace("\n", "")
p1 = text1.replace("\n", "")
logger.info(f"IS_RELATED:\nParagraph 1: {p2}\nParagraph 2: {p1}\nRESULT: {result}\n")
return result
async def rewrite(self, sentence: str, context: str, llm):
if isinstance(llm, MetaGPTAPI):
return await self._metagpt_rewrite(sentence=sentence, context=context, llm=llm)
return await self._openai_rewrite(sentence=sentence, context=context, llm=llm)
@staticmethod
async def _metagpt_rewrite(sentence: str):
return sentence
@staticmethod
async def _openai_rewrite(sentence: str, context: str, llm):
command = (
f"{context}\n\nExtract relevant information from every preceding sentence and use it to succinctly "
f"supplement or rewrite the following text in brief and clear:\n{sentence}"
)
rsp = await llm.aask(msg=command, system_msgs=[])
logger.info(f"REWRITE:\nCommand: {command}\nRESULT: {rsp}\n")
return rsp
@staticmethod
def extract_info(input_string, pattern=r"\[([A-Z]+)\]:\s*(.+)"):
match = re.match(pattern, input_string)
if match:
return match.group(1), match.group(2)
else:
return None, input_string
@property
def is_history_available(self):
return bool(self.history or self.historical_summary)
@property
def history_text(self):
if len(self.history) == 0 and not self.historical_summary:
return ""
texts = [self.historical_summary] if self.historical_summary else []
for m in self.history[:-1]:
if isinstance(m, Dict):
t = Message(**m).content
elif isinstance(m, Message):
t = m.content
else:
continue
texts.append(t)
return "\n".join(texts)

View file

@ -2,6 +2,7 @@
# -*- coding: utf-8 -*-
"""
@Desc : the implement of Long-term memory
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
"""
from typing import Optional

View file

@ -12,6 +12,7 @@ from typing import Iterable, Set
from pydantic import BaseModel, Field
from metagpt.const import IGNORED_MESSAGE_ID
from metagpt.schema import Message
from metagpt.utils.common import (
any_to_str,
@ -26,6 +27,7 @@ class Memory(BaseModel):
storage: list[Message] = []
index: dict[str, list[Message]] = Field(default_factory=defaultdict(list))
ignore_id: bool = False
def __init__(self, **kwargs):
index = kwargs.get("index", {})
@ -54,6 +56,8 @@ class Memory(BaseModel):
def add(self, message: Message):
"""Add a new message to storage, while updating the index"""
if self.ignore_id:
message.id = IGNORED_MESSAGE_ID
if message in self.storage:
return
self.storage.append(message)
@ -84,6 +88,8 @@ class Memory(BaseModel):
def delete(self, message: Message):
"""Delete the specified message from storage, while updating the index"""
if self.ignore_id:
message.id = IGNORED_MESSAGE_ID
self.storage.remove(message)
if message.cause_by and message in self.index[message.cause_by]:
self.index[message.cause_by].remove(message)

View file

@ -1,11 +1,16 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the implement of memory storage
"""
@Desc : the implement of memory storage
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
"""
from pathlib import Path
from typing import List
from typing import Optional
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores.faiss import FAISS
from langchain_core.embeddings import Embeddings
from metagpt.const import DATA_PATH, MEM_TTL
from metagpt.document_store.faiss_store import FaissStore
@ -19,20 +24,30 @@ class MemoryStorage(FaissStore):
The memory storage with Faiss as ANN search engine
"""
def __init__(self, mem_ttl: int = MEM_TTL):
def __init__(self, mem_ttl: int = MEM_TTL, embedding: Embeddings = None):
self.role_id: str = None
self.role_mem_path: str = None
self.mem_ttl: int = mem_ttl # later use
self.threshold: float = 0.1 # experience value. TODO The threshold to filter similar memories
self._initialized: bool = False
self.embedding = embedding or OpenAIEmbeddings()
self.store: FAISS = None # Faiss engine
@property
def is_initialized(self) -> bool:
return self._initialized
def recover_memory(self, role_id: str) -> List[Message]:
def _load(self) -> Optional["FaissStore"]:
index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss") # langchain FAISS using .faiss
if not (index_file.exists() and store_file.exists()):
logger.info("Missing at least one of index_file/store_file, load failed and return None")
return None
return FAISS.load_local(self.role_mem_path, self.embedding, self.role_id)
def recover_memory(self, role_id: str) -> list[Message]:
self.role_id = role_id
self.role_mem_path = Path(DATA_PATH / f"role_mem/{self.role_id}/")
self.role_mem_path.mkdir(parents=True, exist_ok=True)
@ -49,16 +64,16 @@ class MemoryStorage(FaissStore):
return messages
def _get_index_and_store_fname(self):
def _get_index_and_store_fname(self, index_ext=".index", pkl_ext=".pkl"):
if not self.role_mem_path:
logger.error(f"You should call {self.__class__.__name__}.recover_memory fist when using LongTermMemory")
return None, None
index_fpath = Path(self.role_mem_path / f"{self.role_id}.index")
storage_fpath = Path(self.role_mem_path / f"{self.role_id}.pkl")
index_fpath = Path(self.role_mem_path / f"{self.role_id}{index_ext}")
storage_fpath = Path(self.role_mem_path / f"{self.role_id}{pkl_ext}")
return index_fpath, storage_fpath
def persist(self):
super().persist()
self.store.save_local(self.role_mem_path, self.role_id)
logger.debug(f"Agent {self.role_id} persist memory into local")
def add(self, message: Message) -> bool:
@ -74,7 +89,7 @@ class MemoryStorage(FaissStore):
self.persist()
logger.info(f"Agent {self.role_id}'s memory_storage add a message")
def search_dissimilar(self, message: Message, k=4) -> List[Message]:
def search_dissimilar(self, message: Message, k=4) -> list[Message]:
"""search for dissimilar messages"""
if not self.store:
return []

View file

@ -12,5 +12,16 @@ from metagpt.provider.ollama_api import OllamaGPTAPI
from metagpt.provider.open_llm_api import OpenLLMGPTAPI
from metagpt.provider.openai_api import OpenAIGPTAPI
from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
from metagpt.provider.azure_openai_api import AzureOpenAIGPTAPI
from metagpt.provider.metagpt_api import MetaGPTAPI
__all__ = ["FireWorksGPTAPI", "GeminiGPTAPI", "OpenLLMGPTAPI", "OpenAIGPTAPI", "ZhiPuAIGPTAPI", "OllamaGPTAPI"]
__all__ = [
"FireWorksGPTAPI",
"GeminiGPTAPI",
"OpenLLMGPTAPI",
"OpenAIGPTAPI",
"ZhiPuAIGPTAPI",
"AzureOpenAIGPTAPI",
"MetaGPTAPI",
"OllamaGPTAPI",
]

View file

@ -7,13 +7,13 @@
"""
import anthropic
from anthropic import Anthropic
from anthropic import Anthropic, AsyncAnthropic
from metagpt.config import CONFIG
class Claude2:
def ask(self, prompt):
def ask(self, prompt: str) -> str:
client = Anthropic(api_key=CONFIG.anthropic_api_key)
res = client.completions.create(
@ -23,10 +23,10 @@ class Claude2:
)
return res.completion
async def aask(self, prompt):
client = Anthropic(api_key=CONFIG.anthropic_api_key)
async def aask(self, prompt: str) -> str:
aclient = AsyncAnthropic(api_key=CONFIG.anthropic_api_key)
res = client.completions.create(
res = await aclient.completions.create(
model="claude-2",
prompt=f"{anthropic.HUMAN_PROMPT} {prompt} {anthropic.AI_PROMPT}",
max_tokens_to_sample=1000,

View file

@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
"""
@Time : 2023/5/5 23:08
@Author : alexanderwu
@File : openai.py
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation;
Change cost control from global to company level.
@Modified By: mashenquan, 2023/11/21. Fix bug: ReadTimeout.
@Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x.
"""
from openai import AsyncAzureOpenAI, AzureOpenAI
from openai._base_client import AsyncHttpxClientWrapper, SyncHttpxClientWrapper
from metagpt.config import CONFIG, Config, LLMProviderEnum
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter
@register_provider(LLMProviderEnum.AZURE_OPENAI)
class AzureOpenAIGPTAPI(OpenAIGPTAPI):
"""
Check https://platform.openai.com/examples for examples
"""
def __init__(self):
self.config: Config = CONFIG
self._init_openai()
self.auto_max_tokens = False
RateLimiter.__init__(self, rpm=self.rpm)
def _make_client(self):
kwargs, async_kwargs = self._make_client_kwargs()
# https://learn.microsoft.com/zh-cn/azure/ai-services/openai/how-to/migration?tabs=python-new%2Cdalle-fix
self.client = AzureOpenAI(**kwargs)
self.async_client = AsyncAzureOpenAI(**async_kwargs)
self.model = self.config.DEPLOYMENT_NAME # Used in _calc_usage & _cons_kwargs
def _make_client_kwargs(self) -> (dict, dict):
kwargs = dict(
api_key=self.config.OPENAI_API_KEY,
api_version=self.config.OPENAI_API_VERSION,
azure_endpoint=self.config.OPENAI_BASE_URL,
)
async_kwargs = kwargs.copy()
# to use proxy, openai v1 needs http_client
proxy_params = self._get_proxy_params()
if proxy_params:
kwargs["http_client"] = SyncHttpxClientWrapper(**proxy_params)
async_kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
return kwargs, async_kwargs
def _cons_kwargs(self, messages: list[dict], timeout=3, **configs) -> dict:
kwargs = {
"messages": messages,
"max_tokens": self.get_max_tokens(messages),
"n": 1,
"stop": None,
"temperature": 0.3,
"model": self.model,
}
if configs:
kwargs.update(configs)
kwargs["timeout"] = max(CONFIG.timeout, timeout)
return kwargs

View file

@ -4,6 +4,7 @@
@Time : 2023/5/5 23:00
@Author : alexanderwu
@File : base_chatbot.py
@Modified By: mashenquan, 2023/11/21. Add `timeout`.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
@ -17,13 +18,13 @@ class BaseChatbot(ABC):
use_system_prompt: bool = True
@abstractmethod
def ask(self, msg: str) -> str:
def ask(self, msg: str, timeout=3) -> str:
"""Ask GPT a question and get an answer"""
@abstractmethod
def ask_batch(self, msgs: list) -> str:
def ask_batch(self, msgs: list, timeout=3) -> str:
"""Ask GPT multiple questions and get a series of answers"""
@abstractmethod
def ask_code(self, msgs: list) -> str:
def ask_code(self, msgs: list, timeout=3) -> str:
"""Ask GPT multiple questions and get a piece of code"""

View file

@ -4,12 +4,12 @@
@Time : 2023/5/5 23:04
@Author : alexanderwu
@File : base_gpt_api.py
@Desc : mashenquan, 2023/8/22. + try catch
"""
import json
from abc import abstractmethod
from typing import Optional
from metagpt.logs import logger
from metagpt.provider.base_chatbot import BaseChatbot
@ -33,62 +33,65 @@ class BaseGPTAPI(BaseChatbot):
def _default_system_msg(self):
return self._system_msg(self.system_prompt)
def ask(self, msg: str) -> str:
def ask(self, msg: str, timeout=3) -> str:
message = [self._default_system_msg(), self._user_msg(msg)] if self.use_system_prompt else [self._user_msg(msg)]
rsp = self.completion(message)
rsp = self.completion(message, timeout=timeout)
return self.get_choice_text(rsp)
async def aask(self, msg: str, system_msgs: Optional[list[str]] = None, stream=True) -> str:
async def aask(
self,
msg: str,
system_msgs: Optional[list[str]] = None,
format_msgs: Optional[list[dict[str, str]]] = None,
timeout=3,
stream=True,
) -> str:
if system_msgs:
message = (
self._system_msgs(system_msgs) + [self._user_msg(msg)]
if self.use_system_prompt
else [self._user_msg(msg)]
)
message = self._system_msgs(system_msgs)
else:
message = (
[self._default_system_msg(), self._user_msg(msg)] if self.use_system_prompt else [self._user_msg(msg)]
)
logger.debug(message)
rsp = await self.acompletion_text(message, stream=stream)
message = [self._default_system_msg()]
if format_msgs:
message.extend(format_msgs)
message.append(self._user_msg(msg))
rsp = await self.acompletion_text(message, stream=stream, timeout=timeout)
# logger.debug(rsp)
return rsp
def _extract_assistant_rsp(self, context):
return "\n".join([i["content"] for i in context if i["role"] == "assistant"])
def ask_batch(self, msgs: list) -> str:
def ask_batch(self, msgs: list, timeout=3) -> str:
context = []
for msg in msgs:
umsg = self._user_msg(msg)
context.append(umsg)
rsp = self.completion(context)
rsp = self.completion(context, timeout=timeout)
rsp_text = self.get_choice_text(rsp)
context.append(self._assistant_msg(rsp_text))
return self._extract_assistant_rsp(context)
async def aask_batch(self, msgs: list) -> str:
async def aask_batch(self, msgs: list, timeout=3) -> str:
"""Sequential questioning"""
context = []
for msg in msgs:
umsg = self._user_msg(msg)
context.append(umsg)
rsp_text = await self.acompletion_text(context)
rsp_text = await self.acompletion_text(context, timeout=timeout)
context.append(self._assistant_msg(rsp_text))
return self._extract_assistant_rsp(context)
def ask_code(self, msgs: list[str]) -> str:
def ask_code(self, msgs: list[str], timeout=3) -> str:
"""FIXME: No code segment filtering has been done here, and all results are actually displayed"""
rsp_text = self.ask_batch(msgs)
rsp_text = self.ask_batch(msgs, timeout=timeout)
return rsp_text
async def aask_code(self, msgs: list[str]) -> str:
async def aask_code(self, msgs: list[str], timeout=3) -> str:
"""FIXME: No code segment filtering has been done here, and all results are actually displayed"""
rsp_text = await self.aask_batch(msgs)
rsp_text = await self.aask_batch(msgs, timeout=timeout)
return rsp_text
@abstractmethod
def completion(self, messages: list[dict]):
def completion(self, messages: list[dict], timeout=3):
"""All GPTAPIs are required to provide the standard OpenAI completion interface
[
{"role": "system", "content": "You are a helpful assistant."},
@ -98,7 +101,7 @@ class BaseGPTAPI(BaseChatbot):
"""
@abstractmethod
async def acompletion(self, messages: list[dict]):
async def acompletion(self, messages: list[dict], timeout=3):
"""Asynchronous version of completion
All GPTAPIs are required to provide the standard OpenAI completion interface
[
@ -109,7 +112,7 @@ class BaseGPTAPI(BaseChatbot):
"""
@abstractmethod
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
"""Asynchronous version of completion. Return str. Support stream-print"""
def get_choice_text(self, rsp: dict) -> str:
@ -145,7 +148,7 @@ class BaseGPTAPI(BaseChatbot):
:return dict: return first function of choice, for exmaple,
{'name': 'execute', 'arguments': '{\n "language": "python",\n "code": "print(\'Hello, World!\')"\n}'}
"""
return rsp.get("choices")[0]["message"]["tool_calls"][0]["function"].to_dict()
return rsp.get("choices")[0]["message"]["tool_calls"][0]["function"]
def get_choice_function_arguments(self, rsp: dict) -> dict:
"""Required to provide the first function arguments of choice.
@ -158,8 +161,13 @@ class BaseGPTAPI(BaseChatbot):
def messages_to_prompt(self, messages: list[dict]):
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
return "\n".join([f"{i['role']}: {i['content']}" for i in messages])
return "\n".join([f"{i.role}: {i.content}" for i in messages])
def messages_to_dict(self, messages):
"""objects to [{"role": "user", "content": msg}] etc."""
return [i.to_dict() for i in messages]
@abstractmethod
async def close(self):
"""Close connection"""
pass

View file

@ -2,24 +2,142 @@
# -*- coding: utf-8 -*-
# @Desc : fireworks.ai's api
import openai
import re
from metagpt.config import CONFIG, LLMProviderEnum
from openai import APIConnectionError, AsyncStream
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletionChunk
from tenacity import (
after_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from metagpt.config import CONFIG, Config, LLMProviderEnum
from metagpt.logs import logger
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import CostManager, OpenAIGPTAPI, RateLimiter
from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter, log_and_reraise
from metagpt.utils.cost_manager import CostManager, Costs
MODEL_GRADE_TOKEN_COSTS = {
"-1": {"prompt": 0.0, "completion": 0.0}, # abnormal condition
"16": {"prompt": 0.2, "completion": 0.8}, # 16 means model size <= 16B; 0.2 means $0.2/1M tokens
"80": {"prompt": 0.7, "completion": 2.8}, # 80 means 16B < model size <= 80B
"mixtral-8x7b": {"prompt": 0.4, "completion": 1.6},
}
class FireworksCostManager(CostManager):
def model_grade_token_costs(self, model: str) -> dict[str, float]:
def _get_model_size(model: str) -> float:
size = re.findall(".*-([0-9.]+)b", model)
size = float(size[0]) if len(size) > 0 else -1
return size
if "mixtral-8x7b" in model:
token_costs = MODEL_GRADE_TOKEN_COSTS["mixtral-8x7b"]
else:
model_size = _get_model_size(model)
if 0 < model_size <= 16:
token_costs = MODEL_GRADE_TOKEN_COSTS["16"]
elif 16 < model_size <= 80:
token_costs = MODEL_GRADE_TOKEN_COSTS["80"]
else:
token_costs = MODEL_GRADE_TOKEN_COSTS["-1"]
return token_costs
def update_cost(self, prompt_tokens: int, completion_tokens: int, model: str):
"""
Refs to `https://app.fireworks.ai/pricing` **Developer pricing**
Update the total cost, prompt tokens, and completion tokens.
Args:
prompt_tokens (int): The number of tokens used in the prompt.
completion_tokens (int): The number of tokens used in the completion.
model (str): The model used for the API call.
"""
self.total_prompt_tokens += prompt_tokens
self.total_completion_tokens += completion_tokens
token_costs = self.model_grade_token_costs(model)
cost = (prompt_tokens * token_costs["prompt"] + completion_tokens * token_costs["completion"]) / 1000000
self.total_cost += cost
logger.info(
f"Total running cost: ${self.total_cost:.4f} | Max budget: ${CONFIG.max_budget:.3f} | "
f"Current cost: ${cost:.4f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
)
CONFIG.total_cost = self.total_cost
@register_provider(LLMProviderEnum.FIREWORKS)
class FireWorksGPTAPI(OpenAIGPTAPI):
def __init__(self):
self.__init_fireworks(CONFIG)
self.llm = openai
self.model = CONFIG.fireworks_api_model
self.config: Config = CONFIG
self.__init_fireworks()
self.auto_max_tokens = False
self._cost_manager = CostManager()
self._cost_manager = FireworksCostManager()
RateLimiter.__init__(self, rpm=self.rpm)
def __init_fireworks(self, config: "Config"):
openai.api_key = config.fireworks_api_key
openai.api_base = config.fireworks_api_base
self.rpm = int(config.get("RPM", 10))
def __init_fireworks(self):
self.is_azure = False
self.rpm = int(self.config.get("RPM", 10))
self._make_client()
self.model = self.config.fireworks_api_model # `self.model` should after `_make_client` to rewrite it
def _make_client_kwargs(self) -> (dict, dict):
kwargs = dict(api_key=self.config.fireworks_api_key, base_url=self.config.fireworks_api_base)
async_kwargs = kwargs.copy()
return kwargs, async_kwargs
def _update_costs(self, usage: CompletionUsage):
if self.config.calc_usage and usage:
try:
# use FireworksCostManager not CONFIG.cost_manager
self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
except Exception as e:
logger.error(f"updating costs failed!, exp: {e}")
def get_costs(self) -> Costs:
return self._cost_manager.get_costs()
async def _achat_completion_stream(self, messages: list[dict]) -> str:
response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create(
**self._cons_kwargs(messages), stream=True
)
collected_content = []
usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
# iterate through the stream of events
async for chunk in response:
if chunk.choices:
choice = chunk.choices[0]
choice_delta = choice.delta
finish_reason = choice.finish_reason if hasattr(choice, "finish_reason") else None
if choice_delta.content:
collected_content.append(choice_delta.content)
print(choice_delta.content, end="")
if finish_reason:
# fireworks api return usage when finish_reason is not None
usage = CompletionUsage(**chunk.usage)
full_content = "".join(collected_content)
self._update_costs(usage)
return full_content
@retry(
wait=wait_random_exponential(min=1, max=60),
stop=stop_after_attempt(6),
after=after_log(logger, logger.level("WARNING").name),
retry=retry_if_exception_type(APIConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(
self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
) -> str:
"""when streaming, print each token in place."""
if stream:
return await self._achat_completion_stream(messages)
rsp = await self._achat_completion(messages)
return self.get_choice_text(rsp)

View file

@ -23,7 +23,7 @@ from metagpt.config import CONFIG, LLMProviderEnum
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import CostManager, log_and_reraise
from metagpt.provider.openai_api import log_and_reraise
class GeminiGenerativeModel(GenerativeModel):
@ -53,7 +53,6 @@ class GeminiGPTAPI(BaseGPTAPI):
self.__init_gemini(CONFIG)
self.model = "gemini-pro" # so far only one model
self.llm = GeminiGenerativeModel(model_name=self.model)
self._cost_manager = CostManager()
def __init_gemini(self, config: CONFIG):
genai.configure(api_key=config.gemini_api_key)
@ -76,10 +75,13 @@ class GeminiGPTAPI(BaseGPTAPI):
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
CONFIG.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error(f"google gemini updats costs failed! exp: {e}")
def close(self):
pass
def get_choice_text(self, resp: GenerateContentResponse) -> str:
return resp.text
@ -134,7 +136,9 @@ class GeminiGPTAPI(BaseGPTAPI):
retry=retry_if_exception_type(ConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
async def acompletion_text(
self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
) -> str:
"""response in async with stream or non-stream mode"""
if stream:
return await self._achat_completion_stream(messages)

View file

@ -14,24 +14,35 @@ class HumanProvider(BaseGPTAPI):
This enables replacing LLM anywhere in the framework with a human, thus introducing human interaction
"""
def ask(self, msg: str) -> str:
def ask(self, msg: str, timeout=3) -> str:
logger.info("It's your turn, please type in your response. You may also refer to the context below")
rsp = input(msg)
if rsp in ["exit", "quit"]:
exit()
return rsp
async def aask(self, msg: str, system_msgs: Optional[list[str]] = None) -> str:
return self.ask(msg)
async def aask(
self,
msg: str,
system_msgs: Optional[list[str]] = None,
format_msgs: Optional[list[dict[str, str]]] = None,
generator: bool = False,
timeout=3,
) -> str:
return self.ask(msg, timeout=timeout)
def completion(self, messages: list[dict]):
def completion(self, messages: list[dict], timeout=3):
"""dummy implementation of abstract method in base"""
return []
async def acompletion(self, messages: list[dict]):
async def acompletion(self, messages: list[dict], timeout=3):
"""dummy implementation of abstract method in base"""
return []
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
"""dummy implementation of abstract method in base"""
return []
return ""
async def close(self):
"""Close connection"""
pass

View file

@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-
"""
@Time : 2023/5/5 23:08
@Author : alexanderwu
@File : metagpt_api.py
@Desc : MetaGPT LLM provider.
"""
from metagpt.config import LLMProviderEnum
from metagpt.provider import OpenAIGPTAPI
from metagpt.provider.llm_provider_registry import register_provider
@register_provider(LLMProviderEnum.METAGPT)
class MetaGPTAPI(OpenAIGPTAPI):
def __init__(self):
super().__init__()

View file

@ -19,7 +19,8 @@ from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import CostManager, log_and_reraise
from metagpt.provider.openai_api import log_and_reraise
from metagpt.utils.cost_manager import CostManager
class OllamaCostManager(CostManager):
@ -56,6 +57,9 @@ class OllamaGPTAPI(BaseGPTAPI):
self.model = config.ollama_api_model
def close(self):
pass
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
return kwargs
@ -143,7 +147,9 @@ class OllamaGPTAPI(BaseGPTAPI):
retry=retry_if_exception_type(ConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
async def acompletion_text(
self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
) -> str:
"""response in async with stream or non-stream mode"""
if stream:
return await self._achat_completion_stream(messages)

View file

@ -2,12 +2,14 @@
# -*- coding: utf-8 -*-
# @Desc : self-host open llm model with openai-compatible interface
import openai
from openai.types import CompletionUsage
from metagpt.config import CONFIG, LLMProviderEnum
from metagpt.config import CONFIG, Config, LLMProviderEnum
from metagpt.logs import logger
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import CostManager, OpenAIGPTAPI, RateLimiter
from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter
from metagpt.utils.cost_manager import CostManager, Costs
from metagpt.utils.token_counter import count_message_tokens, count_string_tokens
class OpenLLMCostManager(CostManager):
@ -26,7 +28,7 @@ class OpenLLMCostManager(CostManager):
self.total_completion_tokens += completion_tokens
logger.info(
f"Max budget: ${CONFIG.max_budget:.3f} | "
f"Max budget: ${CONFIG.max_budget:.3f} | reference "
f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
)
CONFIG.total_cost = self.total_cost
@ -35,14 +37,43 @@ class OpenLLMCostManager(CostManager):
@register_provider(LLMProviderEnum.OPEN_LLM)
class OpenLLMGPTAPI(OpenAIGPTAPI):
def __init__(self):
self.__init_openllm(CONFIG)
self.llm = openai
self.model = CONFIG.open_llm_api_model
self.config: Config = CONFIG
self.__init_openllm()
self.auto_max_tokens = False
self._cost_manager = OpenLLMCostManager()
RateLimiter.__init__(self, rpm=self.rpm)
def __init_openllm(self, config: "Config"):
openai.api_key = "sk-xx" # self-host api doesn't need api-key, use the default value
openai.api_base = config.open_llm_api_base
self.rpm = int(config.get("RPM", 10))
def __init_openllm(self):
self.is_azure = False
self.rpm = int(self.config.get("RPM", 10))
self._make_client()
self.model = self.config.open_llm_api_model # `self.model` should after `_make_client` to rewrite it
def _make_client_kwargs(self) -> (dict, dict):
kwargs = dict(api_key="sk-xxx", base_url=self.config.open_llm_api_base)
async_kwargs = kwargs.copy()
return kwargs, async_kwargs
def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
if not CONFIG.calc_usage:
return usage
try:
usage.prompt_tokens = count_message_tokens(messages, "open-llm-model")
usage.completion_tokens = count_string_tokens(rsp, "open-llm-model")
except Exception as e:
logger.error(f"usage calculation failed!: {e}")
return usage
def _update_costs(self, usage: CompletionUsage):
if self.config.calc_usage and usage:
try:
# use OpenLLMCostManager not CONFIG.cost_manager
self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
except Exception as e:
logger.error(f"updating costs failed!, exp: {e}")
def get_costs(self) -> Costs:
return self._cost_manager.get_costs()

View file

@ -3,20 +3,19 @@
@Time : 2023/5/5 23:08
@Author : alexanderwu
@File : openai.py
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation;
Change cost control from global to company level.
@Modified By: mashenquan, 2023/11/21. Fix bug: ReadTimeout.
@Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x.
"""
import asyncio
import json
import time
from typing import NamedTuple, Union
from typing import AsyncIterator, List, Union
from openai import (
APIConnectionError,
AsyncAzureOpenAI,
AsyncOpenAI,
AsyncStream,
AzureOpenAI,
OpenAI,
)
import openai
from openai import APIConnectionError, AsyncOpenAI, AsyncStream, OpenAI
from openai._base_client import AsyncHttpxClientWrapper, SyncHttpxClientWrapper
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion, ChatCompletionChunk
@ -29,15 +28,15 @@ from tenacity import (
)
from metagpt.config import CONFIG, Config, LLMProviderEnum
from metagpt.const import DEFAULT_MAX_TOKENS, DEFAULT_TOKEN_SIZE
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.schema import Message
from metagpt.utils.cost_manager import Costs
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.singleton import Singleton
from metagpt.utils.token_counter import (
TOKEN_COSTS,
count_message_tokens,
count_string_tokens,
get_max_completion_tokens,
@ -69,75 +68,6 @@ class RateLimiter:
self.last_call_time = time.time()
class Costs(NamedTuple):
total_prompt_tokens: int
total_completion_tokens: int
total_cost: float
total_budget: float
class CostManager(metaclass=Singleton):
"""计算使用接口的开销"""
def __init__(self):
self.total_prompt_tokens = 0
self.total_completion_tokens = 0
self.total_cost = 0
self.total_budget = 0
def update_cost(self, prompt_tokens, completion_tokens, model):
"""
Update the total cost, prompt tokens, and completion tokens.
Args:
prompt_tokens (int): The number of tokens used in the prompt.
completion_tokens (int): The number of tokens used in the completion.
model (str): The model used for the API call.
"""
self.total_prompt_tokens += prompt_tokens
self.total_completion_tokens += completion_tokens
cost = (
prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"]
) / 1000
self.total_cost += cost
logger.info(
f"Total running cost: ${self.total_cost:.3f} | Max budget: ${CONFIG.max_budget:.3f} | "
f"Current cost: ${cost:.3f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
)
CONFIG.total_cost = self.total_cost
def get_total_prompt_tokens(self):
"""
Get the total number of prompt tokens.
Returns:
int: The total number of prompt tokens.
"""
return self.total_prompt_tokens
def get_total_completion_tokens(self):
"""
Get the total number of completion tokens.
Returns:
int: The total number of completion tokens.
"""
return self.total_completion_tokens
def get_total_cost(self):
"""
Get the total cost of API calls.
Returns:
float: The total cost of API calls.
"""
return self.total_cost
def get_costs(self) -> Costs:
"""Get all costs"""
return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget)
def log_and_reraise(retry_state):
logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}")
logger.warning(
@ -157,37 +87,23 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
def __init__(self):
self.config: Config = CONFIG
self.__init_openai()
self._init_openai()
self.auto_max_tokens = False
self._cost_manager = CostManager()
RateLimiter.__init__(self, rpm=self.rpm)
def __init_openai(self):
self.is_azure = self.config.openai_api_type == "azure"
self.model = self.config.deployment_name if self.is_azure else self.config.openai_api_model
self.rpm = int(self.config.get("RPM", 10))
def _init_openai(self):
self.rpm = int(self.config.RPM or 10)
self._make_client()
def _make_client(self):
kwargs, async_kwargs = self._make_client_kwargs()
if self.is_azure:
self.client = AzureOpenAI(**kwargs)
self.async_client = AsyncAzureOpenAI(**async_kwargs)
else:
self.client = OpenAI(**kwargs)
self.async_client = AsyncOpenAI(**async_kwargs)
# https://github.com/openai/openai-python#async-usage
self.client = OpenAI(**kwargs)
self.async_client = AsyncOpenAI(**async_kwargs)
self.model = self.config.OPENAI_API_MODEL # Used in _calc_usage & _cons_kwargs
def _make_client_kwargs(self) -> (dict, dict):
if self.is_azure:
kwargs = dict(
api_key=self.config.openai_api_key,
api_version=self.config.openai_api_version,
azure_endpoint=self.config.openai_base_url,
)
else:
kwargs = dict(api_key=self.config.openai_api_key, base_url=self.config.openai_base_url)
kwargs = dict(api_key=self.config.OPENAI_API_KEY, base_url=self.config.OPENAI_BASE_URL)
async_kwargs = kwargs.copy()
# to use proxy, openai v1 needs http_client
@ -202,64 +118,51 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
params = {}
if self.config.openai_proxy:
params = {"proxies": self.config.openai_proxy}
if self.config.openai_base_url:
params["base_url"] = self.config.openai_base_url
if self.config.OPENAI_BASE_URL:
params["base_url"] = self.config.OPENAI_BASE_URL
return params
async def _achat_completion_stream(self, messages: list[dict]) -> str:
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> AsyncIterator[str]:
response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create(
**self._cons_kwargs(messages), stream=True
**self._cons_kwargs(messages, timeout=timeout), stream=True
)
# create variables to collect the stream of chunks
collected_chunks = []
collected_messages = []
# iterate through the stream of events
async for chunk in response:
collected_chunks.append(chunk) # save the event response
if chunk.choices:
chunk_message = chunk.choices[0].delta # extract the message
collected_messages.append(chunk_message) # save the message
if chunk_message.content:
log_llm_stream(chunk_message.content)
print()
chunk_message = chunk.choices[0].delta.content or "" if chunk.choices else "" # extract the message
yield chunk_message
full_reply_content = "".join([m.content for m in collected_messages if m.content])
usage = self._calc_usage(messages, full_reply_content)
self._update_costs(usage)
return full_reply_content
def _cons_kwargs(self, messages: list[dict], **configs) -> dict:
def _cons_kwargs(self, messages: list[dict], timeout=3, **configs) -> dict:
kwargs = {
"messages": messages,
"max_tokens": self.get_max_tokens(messages),
"n": 1,
"stop": None,
"temperature": 0.3,
"timeout": 3,
"model": self.model,
}
if configs:
kwargs.update(configs)
kwargs["timeout"] = max(CONFIG.timeout, timeout)
return kwargs
async def _achat_completion(self, messages: list[dict]) -> ChatCompletion:
rsp: ChatCompletion = await self.async_client.chat.completions.create(**self._cons_kwargs(messages))
async def _achat_completion(self, messages: list[dict], timeout=3) -> ChatCompletion:
kwargs = self._cons_kwargs(messages, timeout=timeout)
rsp: ChatCompletion = await self.async_client.chat.completions.create(**kwargs)
self._update_costs(rsp.usage)
return rsp
def _chat_completion(self, messages: list[dict]) -> ChatCompletion:
rsp: ChatCompletion = self.client.chat.completions.create(**self._cons_kwargs(messages))
def _chat_completion(self, messages: list[dict], timeout=3) -> ChatCompletion:
rsp: ChatCompletion = self.client.chat.completions.create(**self._cons_kwargs(messages, timeout=timeout))
self._update_costs(rsp.usage)
return rsp
def completion(self, messages: list[dict]) -> ChatCompletion:
return self._chat_completion(messages)
def completion(self, messages: list[dict], timeout=3) -> ChatCompletion:
return self._chat_completion(messages, timeout=timeout)
async def acompletion(self, messages: list[dict]) -> ChatCompletion:
return await self._achat_completion(messages)
async def acompletion(self, messages: list[dict], timeout=3) -> ChatCompletion:
return await self._achat_completion(messages, timeout=timeout)
@retry(
wait=wait_random_exponential(min=1, max=60),
@ -268,14 +171,25 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
retry=retry_if_exception_type(APIConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
"""when streaming, print each token in place."""
if stream:
return await self._achat_completion_stream(messages)
rsp = await self._achat_completion(messages)
resp = self._achat_completion_stream(messages, timeout=timeout)
collected_messages = []
async for i in resp:
log_llm_stream(i)
collected_messages.append(i)
full_reply_content = "".join(collected_messages)
usage = self._calc_usage(messages, full_reply_content)
self._update_costs(usage)
return full_reply_content
rsp = await self._achat_completion(messages, timeout=timeout)
return self.get_choice_text(rsp)
def _func_configs(self, messages: list[dict], **kwargs) -> dict:
def _func_configs(self, messages: list[dict], timeout=3, **kwargs) -> dict:
"""
Note: Keep kwargs consistent with the parameters in the https://platform.openai.com/docs/api-reference/chat/create
"""
@ -286,17 +200,16 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
}
kwargs.update(configs)
return self._cons_kwargs(messages, **kwargs)
return self._cons_kwargs(messages=messages, timeout=timeout, **kwargs)
def _chat_completion_function(self, messages: list[dict], **kwargs) -> ChatCompletion:
def _chat_completion_function(self, messages: list[dict], timeout=3, **kwargs) -> ChatCompletion:
rsp: ChatCompletion = self.client.chat.completions.create(**self._func_configs(messages, **kwargs))
self._update_costs(rsp.usage)
return rsp
async def _achat_completion_function(self, messages: list[dict], **chat_configs) -> ChatCompletion:
rsp: ChatCompletion = await self.async_client.chat.completions.create(
**self._func_configs(messages, **chat_configs)
)
async def _achat_completion_function(self, messages: list[dict], timeout=3, **chat_configs) -> ChatCompletion:
kwargs = self._func_configs(messages=messages, timeout=timeout, **chat_configs)
rsp: ChatCompletion = await self.async_client.chat.completions.create(**kwargs)
self._update_costs(rsp.usage)
return rsp
@ -349,8 +262,12 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
>>> rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
"""
messages = self._process_message(messages)
rsp = await self._achat_completion_function(messages, **kwargs)
return self.get_choice_function_arguments(rsp)
try:
rsp = await self._achat_completion_function(messages, **kwargs)
return self.get_choice_function_arguments(rsp)
except openai.BadRequestError as e:
logger.error(f"API TYPE:{CONFIG.OPENAI_API_TYPE}, err:{e}")
raise e
def get_choice_function_arguments(self, rsp: ChatCompletion) -> dict:
"""Required to provide the first function arguments of choice.
@ -380,7 +297,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
return usage
async def acompletion_batch(self, batch: list[list[dict]]) -> list[ChatCompletion]:
async def acompletion_batch(self, batch: list[list[dict]], timeout=3) -> list[ChatCompletion]:
"""Return full JSON"""
split_batches = self.split_batches(batch)
all_results = []
@ -389,16 +306,16 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
logger.info(small_batch)
await self.wait_if_needed(len(small_batch))
future = [self.acompletion(prompt) for prompt in small_batch]
future = [self.acompletion(prompt, timeout=timeout) for prompt in small_batch]
results = await asyncio.gather(*future)
logger.info(results)
all_results.extend(results)
return all_results
async def acompletion_batch_text(self, batch: list[list[dict]]) -> list[str]:
async def acompletion_batch_text(self, batch: list[list[dict]], timeout=3) -> list[str]:
"""Only return plain text"""
raw_results = await self.acompletion_batch(batch)
raw_results = await self.acompletion_batch(batch, timeout=timeout)
results = []
for idx, raw_result in enumerate(raw_results, start=1):
result = self.get_choice_text(raw_result)
@ -409,18 +326,101 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
def _update_costs(self, usage: CompletionUsage):
if CONFIG.calc_usage and usage:
try:
self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
CONFIG.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
except Exception as e:
logger.error("updating costs failed!", e)
logger.error(f"updating costs failed!, exp: {e}")
def get_costs(self) -> Costs:
return self._cost_manager.get_costs()
return CONFIG.cost_manager.get_costs()
def get_max_tokens(self, messages: list[dict]):
if not self.auto_max_tokens:
return CONFIG.max_tokens_rsp
return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp)
def moderation(self, content: Union[str, list[str]]):
return self.client.moderations.create(input=content)
@handle_exception
async def amoderation(self, content: Union[str, list[str]]):
return await self.async_client.moderations.create(input=content)
async def close(self):
"""Close connection"""
if self.client:
self.client.close()
self.client = None
if self.async_client:
await self.async_client.close()
self.async_client = None
async def summarize(self, text: str, max_words=200, keep_language: bool = False, limit: int = -1) -> str:
max_token_count = DEFAULT_MAX_TOKENS
max_count = 100
text_length = len(text)
if limit > 0 and text_length < limit:
return text
summary = ""
while max_count > 0:
if text_length < max_token_count:
summary = await self._get_summary(text=text, max_words=max_words, keep_language=keep_language)
break
padding_size = 20 if max_token_count > 20 else 0
text_windows = self.split_texts(text, window_size=max_token_count - padding_size)
part_max_words = min(int(max_words / len(text_windows)) + 1, 100)
summaries = []
for ws in text_windows:
response = await self._get_summary(text=ws, max_words=part_max_words, keep_language=keep_language)
summaries.append(response)
if len(summaries) == 1:
summary = summaries[0]
break
# Merged and retry
text = "\n".join(summaries)
text_length = len(text)
max_count -= 1 # safeguard
return summary
async def _get_summary(self, text: str, max_words=20, keep_language: bool = False):
"""Generate text summary"""
if len(text) < max_words:
return text
if keep_language:
command = f".Translate the above content into a summary of less than {max_words} words in language of the content strictly."
else:
command = f"Translate the above content into a summary of less than {max_words} words."
msg = text + "\n\n" + command
logger.debug(f"summary ask:{msg}")
response = await self.aask(msg=msg, system_msgs=[])
logger.debug(f"summary rsp: {response}")
return response
@staticmethod
def split_texts(text: str, window_size) -> List[str]:
"""Splitting long text into sliding windows text"""
if window_size <= 0:
window_size = DEFAULT_TOKEN_SIZE
total_len = len(text)
if total_len <= window_size:
return [text]
padding_size = 20 if window_size > 20 else 0
windows = []
idx = 0
data_len = window_size - padding_size
while idx < total_len:
if window_size + idx > total_len: # 不足一个滑窗
windows.append(text[idx:])
break
# 每个窗口少算padding_size自然就可实现滑窗功能, 比如: [1, 2, 3, 4, 5, 6, 7, ....]
# window_size=3, padding_size=1
# [1, 2, 3], [3, 4, 5], [5, 6, 7], ....
# idx=2, | idx=5 | idx=8 | ...
w = text[idx : idx + window_size]
windows.append(w)
idx += data_len
return windows

View file

@ -26,16 +26,19 @@ from metagpt.provider.llm_provider_registry import register_provider
@register_provider(LLMProviderEnum.SPARK)
class SparkAPI(BaseGPTAPI):
class SparkGPTAPI(BaseGPTAPI):
def __init__(self):
logger.warning("当前方法无法支持异步运行。当你使用acompletion时并不能并行访问。")
def close(self):
pass
def ask(self, msg: str) -> str:
message = [self._default_system_msg(), self._user_msg(msg)]
rsp = self.completion(message)
return rsp
async def aask(self, msg: str, system_msgs: Optional[list[str]] = None) -> str:
async def aask(self, msg: str, system_msgs: Optional[list[str]] = None, stream: bool = True) -> str:
if system_msgs:
message = self._system_msgs(system_msgs) + [self._user_msg(msg)]
else:
@ -47,7 +50,9 @@ class SparkAPI(BaseGPTAPI):
def get_choice_text(self, rsp: dict) -> str:
return rsp["payload"]["choices"]["text"][-1]["content"]
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
async def acompletion_text(
self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
) -> str:
# 不支持
logger.error("该功能禁用。")
w = GetMessageFromWeb(messages)

View file

@ -5,7 +5,6 @@
import json
from enum import Enum
import openai
import zhipuai
from requests import ConnectionError
from tenacity import (
@ -20,7 +19,7 @@ from metagpt.config import CONFIG, LLMProviderEnum
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import CostManager, log_and_reraise
from metagpt.provider.openai_api import log_and_reraise
from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
@ -44,12 +43,12 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
self.__init_zhipuai(CONFIG)
self.llm = ZhiPuModelAPI
self.model = "chatglm_turbo" # so far only one model, just use it
self._cost_manager = CostManager()
def __init_zhipuai(self, config: CONFIG):
assert config.zhipuai_api_key
zhipuai.api_key = config.zhipuai_api_key
openai.api_key = zhipuai.api_key # due to use openai sdk, set the api_key but it will't be used.
# due to use openai sdk, set the api_key but it will't be used.
# openai.api_key = zhipuai.api_key # due to use openai sdk, set the api_key but it will't be used.
def _const_kwargs(self, messages: list[dict]) -> dict:
kwargs = {"model": self.model, "prompt": messages, "temperature": 0.3}
@ -61,32 +60,35 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
CONFIG.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error(f"zhipuai updats costs failed! exp: {e}")
def close(self):
pass
def get_choice_text(self, resp: dict) -> str:
"""get the first text of choice from llm response"""
assist_msg = resp.get("data", {}).get("choices", [{"role": "error"}])[-1]
assert assist_msg["role"] == "assistant"
return assist_msg.get("content")
def completion(self, messages: list[dict]) -> dict:
def completion(self, messages: list[dict], timeout=3) -> dict:
resp = self.llm.invoke(**self._const_kwargs(messages))
usage = resp.get("data").get("usage")
self._update_costs(usage)
return resp
async def _achat_completion(self, messages: list[dict]) -> dict:
async def _achat_completion(self, messages: list[dict], timeout=3) -> dict:
resp = await self.llm.ainvoke(**self._const_kwargs(messages))
usage = resp.get("data").get("usage")
self._update_costs(usage)
return resp
async def acompletion(self, messages: list[dict]) -> dict:
return await self._achat_completion(messages)
async def acompletion(self, messages: list[dict], timeout=3) -> dict:
return await self._achat_completion(messages, timeout=timeout)
async def _achat_completion_stream(self, messages: list[dict]) -> str:
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
response = await self.llm.asse_invoke(**self._const_kwargs(messages))
collected_content = []
usage = {}
@ -129,7 +131,7 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
retry=retry_if_exception_type(ConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
"""response in async with stream or non-stream mode"""
if stream:
return await self._achat_completion_stream(messages)

View file

@ -5,19 +5,49 @@
@Author : alexanderwu
@File : repo_parser.py
"""
from __future__ import annotations
import ast
import json
import re
import subprocess
from pathlib import Path
from pprint import pformat
from typing import Dict, List, Optional, Tuple
import aiofiles
import pandas as pd
from pydantic import BaseModel, Field
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.utils.common import any_to_str
from metagpt.utils.exceptions import handle_exception
class RepoFileInfo(BaseModel):
file: str
classes: List = Field(default_factory=list)
functions: List = Field(default_factory=list)
globals: List = Field(default_factory=list)
page_info: List = Field(default_factory=list)
class CodeBlockInfo(BaseModel):
lineno: int
end_lineno: int
type_name: str
tokens: List = Field(default_factory=list)
properties: Dict = Field(default_factory=dict)
class ClassInfo(BaseModel):
name: str
package: Optional[str] = None
attributes: Dict[str, str] = Field(default_factory=dict)
methods: Dict[str, str] = Field(default_factory=dict)
class RepoParser(BaseModel):
base_directory: Path = Field(default=None)
@ -27,31 +57,32 @@ class RepoParser(BaseModel):
"""Parse a Python file in the repository."""
return ast.parse(file_path.read_text()).body
def extract_class_and_function_info(self, tree, file_path):
def extract_class_and_function_info(self, tree, file_path) -> RepoFileInfo:
"""Extract class, function, and global variable information from the AST."""
file_info = {
"file": str(file_path.relative_to(self.base_directory)),
"classes": [],
"functions": [],
"globals": [],
}
file_info = RepoFileInfo(file=str(file_path.relative_to(self.base_directory)))
for node in tree:
info = RepoParser.node_to_str(node)
file_info.page_info.append(info)
if isinstance(node, ast.ClassDef):
class_methods = [m.name for m in node.body if is_func(m)]
file_info["classes"].append({"name": node.name, "methods": class_methods})
file_info.classes.append({"name": node.name, "methods": class_methods})
elif is_func(node):
file_info["functions"].append(node.name)
file_info.functions.append(node.name)
elif isinstance(node, (ast.Assign, ast.AnnAssign)):
for target in node.targets if isinstance(node, ast.Assign) else [node.target]:
if isinstance(target, ast.Name):
file_info["globals"].append(target.id)
file_info.globals.append(target.id)
return file_info
def generate_symbols(self):
def generate_symbols(self) -> List[RepoFileInfo]:
files_classes = []
directory = self.base_directory
for path in directory.rglob("*.py"):
matching_files = []
extensions = ["*.py", "*.js"]
for ext in extensions:
matching_files += directory.rglob(ext)
for path in matching_files:
tree = self._parse_file(path)
file_info = self.extract_class_and_function_info(tree, path)
files_classes.append(file_info)
@ -79,6 +110,215 @@ class RepoParser(BaseModel):
elif mode == "csv":
self.generate_dataframe_structure(output_path)
@staticmethod
def node_to_str(node) -> (int, int, str, str | Tuple):
if any_to_str(node) == any_to_str(ast.Expr):
return CodeBlockInfo(
lineno=node.lineno,
end_lineno=node.end_lineno,
type_name=any_to_str(node),
tokens=RepoParser._parse_expr(node),
)
mappings = {
any_to_str(ast.Import): lambda x: [RepoParser._parse_name(n) for n in x.names],
any_to_str(ast.Assign): RepoParser._parse_assign,
any_to_str(ast.ClassDef): lambda x: x.name,
any_to_str(ast.FunctionDef): lambda x: x.name,
any_to_str(ast.ImportFrom): lambda x: {
"module": x.module,
"names": [RepoParser._parse_name(n) for n in x.names],
},
any_to_str(ast.If): RepoParser._parse_if,
any_to_str(ast.AsyncFunctionDef): lambda x: x.name,
}
func = mappings.get(any_to_str(node))
if func:
code_block = CodeBlockInfo(lineno=node.lineno, end_lineno=node.end_lineno, type_name=any_to_str(node))
val = func(node)
if isinstance(val, dict):
code_block.properties = val
elif isinstance(val, list):
code_block.tokens = val
elif isinstance(val, str):
code_block.tokens = [val]
else:
raise NotImplementedError(f"Not implement:{val}")
return code_block
raise NotImplementedError(f"Not implement code block:{node.lineno}, {node.end_lineno}, {any_to_str(node)}")
@staticmethod
def _parse_expr(node) -> List:
funcs = {
any_to_str(ast.Constant): lambda x: [any_to_str(x.value), RepoParser._parse_variable(x.value)],
any_to_str(ast.Call): lambda x: [any_to_str(x.value), RepoParser._parse_variable(x.value.func)],
}
func = funcs.get(any_to_str(node.value))
if func:
return func(node)
raise NotImplementedError(f"Not implement: {node.value}")
@staticmethod
def _parse_name(n):
if n.asname:
return f"{n.name} as {n.asname}"
return n.name
@staticmethod
def _parse_if(n):
tokens = [RepoParser._parse_variable(n.test.left)]
for item in n.test.comparators:
tokens.append(RepoParser._parse_variable(item))
return tokens
@staticmethod
def _parse_variable(node):
funcs = {
any_to_str(ast.Constant): lambda x: x.value,
any_to_str(ast.Name): lambda x: x.id,
any_to_str(ast.Attribute): lambda x: f"{x.value.id}.{x.attr}",
}
func = funcs.get(any_to_str(node))
if not func:
raise NotImplementedError(f"Not implement:{node}")
return func(node)
@staticmethod
def _parse_assign(node):
return [RepoParser._parse_variable(t) for t in node.targets]
async def rebuild_class_views(self, path: str | Path = None):
if not path:
path = self.base_directory
path = Path(path)
if not path.exists():
return
command = f"pyreverse {str(path)} -o dot"
result = subprocess.run(command, shell=True, check=True, cwd=str(path))
if result.returncode != 0:
raise ValueError(f"{result}")
class_view_pathname = path / "classes.dot"
class_views = await self._parse_classes(class_view_pathname)
packages_pathname = path / "packages.dot"
class_views = RepoParser._repair_namespaces(class_views=class_views, path=path)
class_view_pathname.unlink(missing_ok=True)
packages_pathname.unlink(missing_ok=True)
return class_views
async def _parse_classes(self, class_view_pathname):
class_views = []
if not class_view_pathname.exists():
return class_views
async with aiofiles.open(str(class_view_pathname), mode="r") as reader:
lines = await reader.readlines()
for line in lines:
package_name, info = RepoParser._split_class_line(line)
if not package_name:
continue
class_name, members, functions = re.split(r"(?<!\\)\|", info)
class_info = ClassInfo(name=class_name)
class_info.package = package_name
for m in members.split("\n"):
if not m:
continue
member_name = m.split(":", 1)[0].strip() if ":" in m else m.strip()
class_info.attributes[member_name] = m
for f in functions.split("\n"):
if not f:
continue
function_name, _ = f.split("(", 1)
class_info.methods[function_name] = f
class_views.append(class_info)
return class_views
@staticmethod
def _split_class_line(line):
part_splitor = '" ['
if part_splitor not in line:
return None, None
ix = line.find(part_splitor)
class_name = line[0:ix].replace('"', "")
left = line[ix:]
begin_flag = "label=<{"
end_flag = "}>"
if begin_flag not in left or end_flag not in left:
return None, None
bix = left.find(begin_flag)
eix = left.rfind(end_flag)
info = left[bix + len(begin_flag) : eix]
info = re.sub(r"<br[^>]*>", "\n", info)
return class_name, info
@staticmethod
def _create_path_mapping(path: str | Path) -> Dict[str, str]:
mappings = {
str(path).replace("/", "."): str(path),
}
files = []
try:
directory_path = Path(path)
if not directory_path.exists():
return mappings
for file_path in directory_path.iterdir():
if file_path.is_file():
files.append(str(file_path))
else:
subfolder_files = RepoParser._create_path_mapping(path=file_path)
mappings.update(subfolder_files)
except Exception as e:
logger.error(f"Error: {e}")
for f in files:
mappings[str(Path(f).with_suffix("")).replace("/", ".")] = str(f)
return mappings
@staticmethod
def _repair_namespaces(class_views: List[ClassInfo], path: str | Path) -> List[ClassInfo]:
if not class_views:
return []
c = class_views[0]
full_key = str(path).lstrip("/").replace("/", ".")
root_namespace = RepoParser._find_root(full_key, c.package)
root_path = root_namespace.replace(".", "/")
mappings = RepoParser._create_path_mapping(path=path)
new_mappings = {}
ix_root_namespace = len(root_namespace)
ix_root_path = len(root_path)
for k, v in mappings.items():
nk = k[ix_root_namespace:]
nv = v[ix_root_path:]
new_mappings[nk] = nv
for c in class_views:
c.package = RepoParser._repair_ns(c.package, new_mappings)
return class_views
@staticmethod
def _repair_ns(package, mappings):
file_ns = package
while file_ns != "":
if file_ns not in mappings:
ix = file_ns.rfind(".")
file_ns = file_ns[0:ix]
continue
break
internal_ns = package[ix + 1 :]
ns = mappings[file_ns] + ":" + internal_ns.replace(".", ":")
return ns
@staticmethod
def _find_root(full_key, package) -> str:
left = full_key
while left != "":
if left in package:
break
if "." not in left:
break
ix = left.find(".")
left = left[ix + 1 :]
ix = full_key.rfind(left)
return "." + full_key[0:ix]
def is_func(node):
return isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))

143
metagpt/roles/assistant.py Normal file
View file

@ -0,0 +1,143 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/8/7
@Author : mashenquan
@File : assistant.py
@Desc : I am attempting to incorporate certain symbol concepts from UML into MetaGPT, enabling it to have the
ability to freely construct flows through symbol concatenation. Simultaneously, I am also striving to
make these symbols configurable and standardized, making the process of building flows more convenient.
For more about `fork` node in activity diagrams, see: `https://www.uml-diagrams.org/activity-diagrams.html`
This file defines a `fork` style meta role capable of generating arbitrary roles at runtime based on a
configuration file.
@Modified By: mashenquan, 2023/8/22. A definition has been provided for the return value of _think: returning false
indicates that further reasoning cannot continue.
"""
from enum import Enum
from pathlib import Path
from typing import Optional
from pydantic import Field
from metagpt.actions.skill_action import ArgumentsParingAction, SkillAction
from metagpt.actions.talk_action import TalkAction
from metagpt.config import CONFIG
from metagpt.learn.skill_loader import SkillsDeclaration
from metagpt.logs import logger
from metagpt.memory.brain_memory import BrainMemory
from metagpt.roles import Role
from metagpt.schema import Message
class MessageType(Enum):
Talk = "TALK"
Skill = "SKILL"
class Assistant(Role):
"""Assistant for solving common issues."""
name: str = "Lily"
profile: str = "An assistant"
goal: str = "Help to solve problem"
constraints: str = "Talk in {language}"
desc: str = ""
memory: BrainMemory = Field(default_factory=BrainMemory)
skills: Optional[SkillsDeclaration] = None
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.constraints = self.constraints.format(language=kwargs.get("language") or CONFIG.language or "Chinese")
async def think(self) -> bool:
"""Everything will be done part by part."""
last_talk = await self.refine_memory()
if not last_talk:
return False
if not self.skills:
skill_path = Path(CONFIG.SKILL_PATH) if CONFIG.SKILL_PATH else None
self.skills = await SkillsDeclaration.load(skill_yaml_file_name=skill_path)
prompt = ""
skills = self.skills.get_skill_list()
for desc, name in skills.items():
prompt += f"If the text explicitly want you to {desc}, return `[SKILL]: {name}` brief and clear. For instance: [SKILL]: {name}\n"
prompt += 'Otherwise, return `[TALK]: {talk}` brief and clear. For instance: if {talk} is "xxxx" return [TALK]: xxxx\n\n'
prompt += f"Now what specific action is explicitly mentioned in the text: {last_talk}\n"
rsp = await self._llm.aask(prompt, [])
logger.info(f"THINK: {prompt}\n, THINK RESULT: {rsp}\n")
return await self._plan(rsp, last_talk=last_talk)
async def act(self) -> Message:
result = await self._rc.todo.run()
if not result:
return None
if isinstance(result, str):
msg = Message(content=result, role="assistant", cause_by=self._rc.todo)
elif isinstance(result, Message):
msg = result
else:
msg = Message(
content=result.content, instruct_content=result.instruct_content, cause_by=type(self._rc.todo)
)
self.memory.add_answer(msg)
return msg
async def talk(self, text):
self.memory.add_talk(Message(content=text))
async def _plan(self, rsp: str, **kwargs) -> bool:
skill, text = BrainMemory.extract_info(input_string=rsp)
handlers = {
MessageType.Talk.value: self.talk_handler,
MessageType.Skill.value: self.skill_handler,
}
handler = handlers.get(skill, self.talk_handler)
return await handler(text, **kwargs)
async def talk_handler(self, text, **kwargs) -> bool:
history = self.memory.history_text
text = kwargs.get("last_talk") or text
self._rc.todo = TalkAction(
context=text, knowledge=self.memory.get_knowledge(), history_summary=history, llm=self._llm, **kwargs
)
return True
async def skill_handler(self, text, **kwargs) -> bool:
last_talk = kwargs.get("last_talk")
skill = self.skills.get_skill(text)
if not skill:
logger.info(f"skill not found: {text}")
return await self.talk_handler(text=last_talk, **kwargs)
action = ArgumentsParingAction(skill=skill, llm=self._llm, ask=last_talk, **kwargs)
await action.run(**kwargs)
if action.args is None:
return await self.talk_handler(text=last_talk, **kwargs)
self._rc.todo = SkillAction(
skill=skill, args=action.args, llm=self._llm, name=skill.name, desc=skill.description
)
return True
async def refine_memory(self) -> str:
last_talk = self.memory.pop_last_talk()
if last_talk is None: # No user feedback, unsure if past conversation is finished.
return None
if not self.memory.is_history_available:
return last_talk
history_summary = await self.memory.summarize(max_words=800, keep_language=True, llm=self._llm)
if last_talk and await self.memory.is_related(text1=last_talk, text2=history_summary, llm=self._llm):
# Merge relevant content.
merged = await self.memory.rewrite(sentence=last_talk, context=history_summary, llm=self._llm)
return f"{merged} {last_talk}"
return last_talk
def get_memory(self) -> str:
return self.memory.json()
def load_memory(self, jsn):
try:
self.memory = BrainMemory(**jsn)
except Exception as e:
logger.exception(f"load error:{e}, data:{jsn}")

View file

@ -43,7 +43,7 @@ from metagpt.schema import (
Documents,
Message,
)
from metagpt.utils.common import any_to_str, any_to_str_set
from metagpt.utils.common import any_to_name, any_to_str, any_to_str_set
IS_PASS_PROMPT = """
{context}
@ -78,13 +78,17 @@ class Engineer(Role):
n_borg: int = 1
use_code_review: bool = False
code_todos: list = []
summarize_todos = []
summarize_todos: list = []
next_todo_action: str = ""
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self._init_actions([WriteCode])
self._watch([WriteTasks, SummarizeCode, WriteCode, WriteCodeReview, FixBug])
self.code_todos = []
self.summarize_todos = []
self.next_todo_action = any_to_name(WriteCode)
@staticmethod
def _parse_tasks(task_msg: Document) -> list[str]:
@ -128,8 +132,10 @@ class Engineer(Role):
if self._rc.todo is None:
return None
if isinstance(self._rc.todo, WriteCode):
self.next_todo_action = any_to_name(SummarizeCode)
return await self._act_write_code()
if isinstance(self._rc.todo, SummarizeCode):
self.next_todo_action = any_to_name(WriteCode)
return await self._act_summarize()
return None
@ -301,3 +307,8 @@ class Engineer(Role):
self.summarize_todos.append(SummarizeCode(context=ctx, llm=self._llm))
if self.summarize_todos:
self._rc.todo = self.summarize_todos[0]
@property
def todo(self) -> str:
"""AgentStore uses this attribute to display to the user what actions the current role should take."""
return self.next_todo_action

View file

@ -7,11 +7,11 @@
@Modified By: mashenquan, 2023/11/27. Add `PrepareDocuments` action according to Section 2.2.3.5.1 of RFC 135.
"""
from metagpt.actions import UserRequirement, WritePRD
from metagpt.actions.prepare_documents import PrepareDocuments
from metagpt.config import CONFIG
from metagpt.roles.role import Role
from metagpt.utils.common import any_to_name
class ProductManager(Role):
@ -29,20 +29,28 @@ class ProductManager(Role):
profile: str = "Product Manager"
goal: str = "efficiently create a successful product that meets market demands and user expectations"
constraints: str = "utilize the same language as the user requirements for seamless communication"
todo_action: str = ""
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self._init_actions([PrepareDocuments, WritePRD])
self._watch([UserRequirement, PrepareDocuments])
self.todo_action = any_to_name(PrepareDocuments)
async def _think(self) -> None:
async def _think(self) -> bool:
"""Decide what to do"""
if CONFIG.git_repo:
self._set_state(1)
else:
self._set_state(0)
return self._rc.todo
self.todo_action = any_to_name(WritePRD)
return bool(self._rc.todo)
async def _observe(self, ignore_memory=False) -> int:
return await super()._observe(ignore_memory=True)
@property
def todo(self) -> str:
"""AgentStore uses this attribute to display to the user what actions the current role should take."""
return self.todo_action

View file

@ -1,5 +1,6 @@
#!/usr/bin/env python
"""
@Modified By: mashenquan, 2023/8/22. A definition has been provided for the return value of _think: returning false indicates that further reasoning cannot continue.
@Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.1 and 2.2.2 of RFC 116, change the data type of
the `cause_by` value in the `Message` to a string to support the new message distribution feature.
"""
@ -39,6 +40,17 @@ class Researcher(Role):
if self.language not in ("en-us", "zh-cn"):
logger.warning(f"The language `{self.language}` has not been tested, it may not work.")
async def _think(self) -> bool:
if self._rc.todo is None:
self._set_state(0)
return True
if self._rc.state + 1 < len(self._states):
self._set_state(self._rc.state + 1)
else:
self._rc.todo = None
return False
async def _act(self) -> Message:
logger.info(f"{self._setting}: to do {self._rc.todo}({self._rc.todo.name})")
todo = self._rc.todo

View file

@ -4,6 +4,7 @@
@Time : 2023/5/11 14:42
@Author : alexanderwu
@File : role.py
@Modified By: mashenquan, 2023/8/22. A definition has been provided for the return value of _think: returning false indicates that further reasoning cannot continue.
@Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.1 and 2.2.2 of RFC 116:
1. Merge the `recv` functionality into the `_observe` function. Future message reading operations will be
consolidated within the `_observe` function.
@ -38,6 +39,7 @@ from metagpt.memory import Memory
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.schema import Message, MessageQueue
from metagpt.utils.common import (
any_to_name,
any_to_str,
import_class,
read_json_file,
@ -118,7 +120,7 @@ class RoleContext(BaseModel):
@property
def important_memory(self) -> list[Message]:
"""Get the information corresponding to the watched actions"""
"""Retrieve information corresponding to the attention action."""
return self.memory.get_by_actions(self.watch)
@property
@ -317,6 +319,9 @@ class Role(BaseModel):
# check RoleContext after adding watch actions
self._rc.check(self._role_id)
def is_watch(self, caused_by: str):
return caused_by in self._rc.watch
def subscribe(self, tags: Set[str]):
"""Used to receive Messages with certain tags from the environment. Message will be put into personal message
buffer to be further processed in _observe. By default, a Role subscribes Messages with a tag of its own name
@ -340,6 +345,11 @@ class Role(BaseModel):
env.set_subscription(self, self.subscription)
self.refresh_system_message() # add env message to system message
@property
def action_count(self):
"""Return number of action"""
return len(self._actions)
def _get_prefix(self):
"""Get the role prefix"""
if self.desc:
@ -356,16 +366,18 @@ class Role(BaseModel):
prefix += env_desc
return prefix
async def _think(self) -> None:
"""Think about what to do and decide on the next action"""
async def _think(self) -> bool:
"""Consider what to do and decide on the next course of action. Return false if nothing can be done."""
if len(self._actions) == 1:
# If there is only one action, then only this one can be performed
self._set_state(0)
return
return True
if self.recovered and self._rc.state >= 0:
self._set_state(self._rc.state) # action to run from recovered state
self.recovered = False # avoid max_react_loop out of work
return
self.set_recovered(False) # avoid max_react_loop out of work
return True
prompt = self._get_prefix()
prompt += STATE_TEMPLATE.format(
@ -387,6 +399,7 @@ class Role(BaseModel):
if next_state == -1:
logger.info(f"End actions with {next_state=}")
self._set_state(next_state)
return True
async def _act(self) -> Message:
logger.info(f"{self._setting}: to do {self._rc.todo}({self._rc.todo.name})")
@ -420,17 +433,17 @@ class Role(BaseModel):
async def _observe(self, ignore_memory=False) -> int:
"""Prepare new messages for processing from the message buffer and other sources."""
# Read unprocessed messages from the msg buffer.
news = self._rc.msg_buffer.pop_all()
news = []
if self.recovered:
news = [self.latest_observed_msg] if self.latest_observed_msg else []
else:
self.latest_observed_msg = news[-1] if len(news) > 0 else None # record the latest observed msg
if not news:
news = self._rc.msg_buffer.pop_all()
# Store the read messages in your own memory to prevent duplicate processing.
old_messages = [] if ignore_memory else self._rc.memory.get()
self._rc.memory.add_batch(news)
# Filter out messages of interest.
self._rc.news = self._find_news(news, old_messages)
self._rc.news = [n for n in news if n.cause_by in self._rc.watch and n not in old_messages]
self.latest_observed_msg = self._rc.news[-1] if self._rc.news else None # record the latest observed msg
# Design Rules:
# If you need to further categorize Message objects, you can do so using the Message.set_meta function.
@ -440,6 +453,29 @@ class Role(BaseModel):
logger.debug(f"{self._setting} observed: {news_text}")
return len(self._rc.news)
# async def _observe(self, ignore_memory=False) -> int:
# """Prepare new messages for processing from the message buffer and other sources."""
# # Read unprocessed messages from the msg buffer.
# news = self._rc.msg_buffer.pop_all()
# if self.recovered:
# news = [self.latest_observed_msg] if self.latest_observed_msg else []
# else:
# self.latest_observed_msg = news[-1] if len(news) > 0 else None # record the latest observed msg
#
# # Store the read messages in your own memory to prevent duplicate processing.
# old_messages = [] if ignore_memory else self._rc.memory.get()
# self._rc.memory.add_batch(news)
# # Filter out messages of interest.
# self._rc.news = self._find_news(news, old_messages)
#
# # Design Rules:
# # If you need to further categorize Message objects, you can do so using the Message.set_meta function.
# # msg_buffer is a receiving buffer, avoid adding message data and operations to msg_buffer.
# news_text = [f"{i.role}: {i.content[:20]}..." for i in self._rc.news]
# if news_text:
# logger.debug(f"{self._setting} observed: {news_text}")
# return len(self._rc.news)
def publish_message(self, msg):
"""If the role belongs to env, then the role's messages will be broadcast to env"""
if not msg:
@ -498,23 +534,6 @@ class Role(BaseModel):
self._set_state(state=-1) # current reaction is complete, reset state to -1 and todo back to None
return rsp
# # Replaced by run()
# def recv(self, message: Message) -> None:
# """add message to history."""
# # self._history += f"\n{message}"
# # self._context = self._history
# if message in self._rc.memory.get():
# return
# self._rc.memory.add(message)
# # Replaced by run()
# async def handle(self, message: Message) -> Message:
# """Receive information and reply with actions"""
# # logger.debug(f"{self.name=}, {self.profile=}, {message.role=}")
# self.recv(message)
#
# return await self._react()
def get_memories(self, k=0) -> list[Message]:
"""A wrapper to return the most recent k memories of this role, return all when k=0"""
return self._rc.memory.get(k=k)
@ -551,3 +570,20 @@ class Role(BaseModel):
def is_idle(self) -> bool:
"""If true, all actions have been executed."""
return not self._rc.news and not self._rc.todo and self._rc.msg_buffer.empty()
async def think(self) -> Action:
"""The exported `think` function"""
await self._think()
return self._rc.todo
async def act(self) -> ActionOutput:
"""The exported `act` function"""
msg = await self._act()
return ActionOutput(content=msg.content, instruct_content=msg.instruct_content)
@property
def todo(self) -> str:
"""AgentStore uses this attribute to display to the user what actions the current role should take."""
if self._actions:
return any_to_name(self._actions[0])
return ""

View file

@ -15,14 +15,15 @@ from metagpt.tools import SearchEngineType
class Sales(Role):
name: str = "Xiaomei"
profile: str = "Retail sales guide"
desc: str = "I am a sales guide in retail. My name is Xiaomei. I will answer some customer questions next, and I "
"will answer questions only based on the information in the knowledge base."
"If I feel that you can't get the answer from the reference material, then I will directly reply that"
" I don't know, and I won't tell you that this is from the knowledge base,"
"but pretend to be what I know. Note that each of my replies will be replied in the tone of a "
"professional guide"
name: str = "John Smith"
profile: str = "Retail Sales Guide"
desc: str = (
"As a Retail Sales Guide, my name is John Smith. I specialize in addressing customer inquiries with "
"expertise and precision. My responses are based solely on the information available in our knowledge"
" base. In instances where your query extends beyond this scope, I'll honestly indicate my inability "
"to provide an answer, rather than speculate or assume. Please note, each of my replies will be "
"delivered with the professionalism and courtesy expected of a seasoned sales guide."
)
store: Optional[BaseStore] = None

118
metagpt/roles/teacher.py Normal file
View file

@ -0,0 +1,118 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/7/27
@Author : mashenquan
@File : teacher.py
@Desc : Used by Agent Store
@Modified By: mashenquan, 2023/8/22. A definition has been provided for the return value of _think: returning false indicates that further reasoning cannot continue.
"""
import re
import aiofiles
from metagpt.actions import UserRequirement
from metagpt.actions.write_teaching_plan import TeachingPlanBlock, WriteTeachingPlanPart
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.utils.common import any_to_str
class Teacher(Role):
"""Support configurable teacher roles,
with native and teaching languages being replaceable through configurations."""
name: str = "Lily"
profile: str = "{teaching_language} Teacher"
goal: str = "writing a {language} teaching plan part by part"
constraints: str = "writing in {language}"
desc: str = ""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.name = WriteTeachingPlanPart.format_value(self.name)
self.profile = WriteTeachingPlanPart.format_value(self.profile)
self.goal = WriteTeachingPlanPart.format_value(self.goal)
self.constraints = WriteTeachingPlanPart.format_value(self.constraints)
self.desc = WriteTeachingPlanPart.format_value(self.desc)
async def _think(self) -> bool:
"""Everything will be done part by part."""
if not self._actions:
if not self._rc.news or self._rc.news[0].cause_by != any_to_str(UserRequirement):
raise ValueError("Lesson content invalid.")
actions = []
print(TeachingPlanBlock.TOPICS)
for topic in TeachingPlanBlock.TOPICS:
act = WriteTeachingPlanPart(context=self._rc.news[0].content, topic=topic, llm=self._llm)
actions.append(act)
self._init_actions(actions)
if self._rc.todo is None:
self._set_state(0)
return True
if self._rc.state + 1 < len(self._states):
self._set_state(self._rc.state + 1)
return True
self._rc.todo = None
return False
async def _react(self) -> Message:
ret = Message(content="")
while True:
await self._think()
if self._rc.todo is None:
break
logger.debug(f"{self._setting}: {self._rc.state=}, will do {self._rc.todo}")
msg = await self._act()
if ret.content != "":
ret.content += "\n\n\n"
ret.content += msg.content
logger.info(ret.content)
await self.save(ret.content)
return ret
async def save(self, content):
"""Save teaching plan"""
filename = Teacher.new_file_name(self.course_title)
pathname = CONFIG.workspace_path / "teaching_plan"
pathname.mkdir(exist_ok=True)
pathname = pathname / filename
try:
async with aiofiles.open(str(pathname), mode="w", encoding="utf-8") as writer:
await writer.write(content)
except Exception as e:
logger.error(f"Save failed{e}")
logger.info(f"Save to:{pathname}")
@staticmethod
def new_file_name(lesson_title, ext=".md"):
"""Create a related file name based on `lesson_title` and `ext`."""
# Define the special characters that need to be replaced.
illegal_chars = r'[#@$%!*&\\/:*?"<>|\n\t \']'
# Replace the special characters with underscores.
filename = re.sub(illegal_chars, "_", lesson_title) + ext
return re.sub(r"_+", "_", filename)
@property
def course_title(self):
"""Return course title of teaching plan"""
default_title = "teaching_plan"
for act in self._actions:
if act.topic != TeachingPlanBlock.COURSE_TITLE:
continue
if act.rsp is None:
return default_title
title = act.rsp.lstrip("# \n")
if "\n" in title:
ix = title.index("\n")
title = title[0:ix]
return title
return default_title

View file

@ -23,7 +23,7 @@ from abc import ABC
from asyncio import Queue, QueueEmpty, wait_for
from json import JSONDecodeError
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Type, TypedDict, TypeVar
from typing import Any, Dict, List, Optional, Set, Type, TypeVar
from pydantic import BaseModel, Field
@ -46,7 +46,7 @@ from metagpt.utils.serialize import (
)
class RawMessage(TypedDict):
class SimpleMessage(BaseModel):
content: str
role: str
@ -162,8 +162,7 @@ class Message(BaseModel):
# prefix = '-'.join([self.role, str(self.cause_by)])
if self.instruct_content:
return f"{self.role}: {self.instruct_content.dict()}"
else:
return f"{self.role}: {self.content}"
return f"{self.role}: {self.content}"
def __repr__(self):
return self.__str__()
@ -180,8 +179,19 @@ class Message(BaseModel):
@handle_exception(exception_type=JSONDecodeError, default_return=None)
def load(val):
"""Convert the json string to object."""
i = json.loads(val)
return Message(**i)
try:
m = json.loads(val)
id = m.get("id")
if "id" in m:
del m["id"]
msg = Message(**m)
if id:
msg.id = id
return msg
except JSONDecodeError as err:
logger.error(f"parse json failed: {val}, error:{err}")
return None
class UserMessage(Message):

View file

@ -90,9 +90,12 @@ class Team(BaseModel):
CONFIG.max_budget = investment
logger.info(f"Investment: ${investment}.")
def _check_balance(self):
if CONFIG.total_cost > CONFIG.max_budget:
raise NoMoneyException(CONFIG.total_cost, f"Insufficient funds: {CONFIG.max_budget}")
@staticmethod
def _check_balance():
if CONFIG.cost_manager.total_cost > CONFIG.cost_manager.max_budget:
raise NoMoneyException(
CONFIG.cost_manager.total_cost, f"Insufficient funds: {CONFIG.cost_manager.max_budget}"
)
def run_project(self, idea, send_to: str = ""):
"""Run a project from publishing user requirement."""
@ -100,7 +103,8 @@ class Team(BaseModel):
# Human requirement.
self.env.publish_message(
Message(role="Human", content=idea, cause_by=UserRequirement, send_to=send_to or MESSAGE_ROUTE_TO_ALL)
Message(role="Human", content=idea, cause_by=UserRequirement, send_to=send_to or MESSAGE_ROUTE_TO_ALL),
peekable=False,
)
def start_project(self, idea, send_to: str = ""):
@ -120,7 +124,7 @@ class Team(BaseModel):
logger.info(self.json(ensure_ascii=False))
@serialize_decorator
async def run(self, n_round=3, idea="", send_to=""):
async def run(self, n_round=3, idea="", send_to="", auto_archive=True):
"""Run company until target round or no money"""
if idea:
self.run_project(idea=idea, send_to=send_to)
@ -132,6 +136,5 @@ class Team(BaseModel):
self._check_balance()
await self.env.run()
if CONFIG.git_repo:
CONFIG.git_repo.archive()
self.env.archive(auto_archive)
return self.env.history

View file

@ -22,3 +22,8 @@ class WebBrowserEngineType(Enum):
PLAYWRIGHT = "playwright"
SELENIUM = "selenium"
CUSTOM = "custom"
@classmethod
def __missing__(cls, key):
"""Default type conversion"""
return cls.CUSTOM

View file

@ -4,39 +4,110 @@
@Time : 2023/6/9 22:22
@Author : Leo Xiao
@File : azure_tts.py
@Modified by: mashenquan, 2023/8/17. Azure TTS OAS3 api, which provides text-to-speech functionality
"""
import asyncio
import base64
from pathlib import Path
from uuid import uuid4
import aiofiles
from azure.cognitiveservices.speech import AudioConfig, SpeechConfig, SpeechSynthesizer
from metagpt.config import CONFIG
from metagpt.config import CONFIG, Config
from metagpt.logs import logger
class AzureTTS:
"""https://learn.microsoft.com/zh-cn/azure/cognitive-services/speech-service/language-support?tabs=tts#voice-styles-and-roles"""
"""Azure Text-to-Speech"""
@classmethod
def synthesize_speech(cls, lang, voice, role, text, output_file):
subscription_key = CONFIG.get("AZURE_TTS_SUBSCRIPTION_KEY")
region = CONFIG.get("AZURE_TTS_REGION")
speech_config = SpeechConfig(subscription=subscription_key, region=region)
def __init__(self, subscription_key, region):
"""
:param subscription_key: key is used to access your Azure AI service API, see: `https://portal.azure.com/` > `Resource Management` > `Keys and Endpoint`
:param region: This is the location (or region) of your resource. You may need to use this field when making calls to this API.
"""
self.subscription_key = subscription_key if subscription_key else CONFIG.AZURE_TTS_SUBSCRIPTION_KEY
self.region = region if region else CONFIG.AZURE_TTS_REGION
# 参数参考https://learn.microsoft.com/zh-cn/azure/cognitive-services/speech-service/language-support?tabs=tts#voice-styles-and-roles
async def synthesize_speech(self, lang, voice, text, output_file):
speech_config = SpeechConfig(subscription=self.subscription_key, region=self.region)
speech_config.speech_synthesis_voice_name = voice
audio_config = AudioConfig(filename=output_file)
synthesizer = SpeechSynthesizer(speech_config=speech_config, audio_config=audio_config)
# if voice=="zh-CN-YunxiNeural":
ssml_string = f"""
<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' xml:lang='{lang}' xmlns:mstts='http://www.w3.org/2001/mstts'>
<voice name='{voice}'>
<mstts:express-as style='affectionate' role='{role}'>
{text}
</mstts:express-as>
</voice>
</speak>
"""
# More detail: https://learn.microsoft.com/en-us/azure/ai-services/speech-service/speech-synthesis-markup-voice
ssml_string = (
"<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' "
f"xml:lang='{lang}' xmlns:mstts='http://www.w3.org/2001/mstts'>"
f"<voice name='{voice}'>{text}</voice></speak>"
)
synthesizer.speak_ssml_async(ssml_string).get()
return synthesizer.speak_ssml_async(ssml_string).get()
@staticmethod
def role_style_text(role, style, text):
return f'<mstts:express-as role="{role}" style="{style}">{text}</mstts:express-as>'
@staticmethod
def role_text(role, text):
return f'<mstts:express-as role="{role}">{text}</mstts:express-as>'
@staticmethod
def style_text(style, text):
return f'<mstts:express-as style="{style}">{text}</mstts:express-as>'
# Export
async def oas3_azsure_tts(text, lang="", voice="", style="", role="", subscription_key="", region=""):
"""Text to speech
For more details, check out:`https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`
:param lang: The value can contain a language code such as en (English), or a locale such as en-US (English - United States). For more details, checkout: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`
:param voice: For more details, checkout: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`, `https://speech.microsoft.com/portal/voicegallery`
:param style: Speaking style to express different emotions like cheerfulness, empathy, and calm. For more details, checkout: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`
:param role: With roles, the same voice can act as a different age and gender. For more details, checkout: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`
:param text: The text used for voice conversion.
:param subscription_key: key is used to access your Azure AI service API, see: `https://portal.azure.com/` > `Resource Management` > `Keys and Endpoint`
:param region: This is the location (or region) of your resource. You may need to use this field when making calls to this API.
:return: Returns the Base64-encoded .wav file data if successful, otherwise an empty string.
"""
if not text:
return ""
if not lang:
lang = "zh-CN"
if not voice:
voice = "zh-CN-XiaomoNeural"
if not role:
role = "Girl"
if not style:
style = "affectionate"
if not subscription_key:
subscription_key = CONFIG.AZURE_TTS_SUBSCRIPTION_KEY
if not region:
region = CONFIG.AZURE_TTS_REGION
xml_value = AzureTTS.role_style_text(role=role, style=style, text=text)
tts = AzureTTS(subscription_key=subscription_key, region=region)
filename = Path(__file__).resolve().parent / (str(uuid4()).replace("-", "") + ".wav")
try:
await tts.synthesize_speech(lang=lang, voice=voice, text=xml_value, output_file=str(filename))
async with aiofiles.open(filename, mode="rb") as reader:
data = await reader.read()
base64_string = base64.b64encode(data).decode("utf-8")
filename.unlink()
except Exception as e:
logger.error(f"text:{text}, error:{e}")
return ""
return base64_string
if __name__ == "__main__":
azure_tts = AzureTTS()
azure_tts.synthesize_speech("zh-CN", "zh-CN-YunxiNeural", "Boy", "Hello, I am Kaka", "output.wav")
Config()
loop = asyncio.new_event_loop()
v = loop.create_task(oas3_azsure_tts("测试test"))
loop.run_until_complete(v)
print(v)

27
metagpt/tools/hello.py Normal file
View file

@ -0,0 +1,27 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/5/2 16:03
@Author : mashenquan
@File : hello.py
@Desc : Implement the OpenAPI Specification 3.0 demo and use the following command to test the HTTP service:
curl -X 'POST' \
'http://localhost:8080/openapi/greeting/dave' \
-H 'accept: text/plain' \
-H 'Content-Type: application/json' \
-d '{}'
"""
import connexion
# openapi implement
async def post_greeting(name: str) -> str:
return f"Hello {name}\n"
if __name__ == "__main__":
app = connexion.AioHttpApp(__name__, specification_dir="../../.well-known/")
app.add_api("openapi.yaml", arguments={"title": "Hello World Example"})
app.run(port=8080)

View file

@ -0,0 +1,162 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/8/17
@Author : mashenquan
@File : iflytek_tts.py
@Desc : iFLYTEK TTS OAS3 api, which provides text-to-speech functionality
"""
import asyncio
import base64
import hashlib
import hmac
import json
import uuid
from datetime import datetime
from enum import Enum
from pathlib import Path
from time import mktime
from typing import Optional
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time
import aiofiles
import websockets as websockets
from pydantic import BaseModel
from metagpt.config import CONFIG
from metagpt.logs import logger
class IFlyTekTTSStatus(Enum):
STATUS_FIRST_FRAME = 0 # The first frame
STATUS_CONTINUE_FRAME = 1 # The intermediate frame
STATUS_LAST_FRAME = 2 # The last frame
class AudioData(BaseModel):
audio: str
status: int
ced: str
class IFlyTekTTSResponse(BaseModel):
code: int
message: str
data: Optional[AudioData] = None
sid: str
DEFAULT_IFLYTEK_VOICE = "xiaoyan"
class IFlyTekTTS(object):
def __init__(self, app_id: str, api_key: str, api_secret: str):
"""
:param app_id: Application ID is used to access your iFlyTek service API, see: `https://console.xfyun.cn/services/tts`
:param api_key: WebAPI argument, see: `https://console.xfyun.cn/services/tts`
:param api_secret: WebAPI argument, see: `https://console.xfyun.cn/services/tts`
"""
self.app_id = app_id or CONFIG.IFLYTEK_APP_ID
self.api_key = api_key or CONFIG.IFLYTEK_API_KEY
self.api_secret = api_secret or CONFIG.API_SECRET
async def synthesize_speech(self, text, output_file: str, voice=DEFAULT_IFLYTEK_VOICE):
url = self._create_url()
data = {
"common": {"app_id": self.app_id},
"business": {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": voice, "tte": "utf8"},
"data": {"status": 2, "text": str(base64.b64encode(text.encode("utf-8")), "UTF8")},
}
req = json.dumps(data)
async with websockets.connect(url) as websocket:
# send request
await websocket.send(req)
# receive frames
async with aiofiles.open(str(output_file), "w") as writer:
while True:
v = await websocket.recv()
rsp = IFlyTekTTSResponse(**json.loads(v))
if rsp.data:
await writer.write(rsp.data.audio)
if rsp.data.status != IFlyTekTTSStatus.STATUS_LAST_FRAME.value:
continue
break
def _create_url(self):
"""Create request url"""
url = "wss://tts-api.xfyun.cn/v2/tts"
# Generate a timestamp in RFC1123 format
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))
signature_origin = "host: " + "ws-api.xfyun.cn" + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + "/v2/tts " + "HTTP/1.1"
# Perform HMAC-SHA256 encryption
signature_sha = hmac.new(
self.api_secret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256
).digest()
signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8")
authorization_origin = 'api_key="%s", algorithm="%s", headers="%s", signature="%s"' % (
self.api_key,
"hmac-sha256",
"host date request-line",
signature_sha,
)
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8")
# Combine the authentication parameters of the request into a dictionary.
v = {"authorization": authorization, "date": date, "host": "ws-api.xfyun.cn"}
# Concatenate the authentication parameters to generate the URL.
url = url + "?" + urlencode(v)
return url
# Export
async def oas3_iflytek_tts(text: str, voice: str = "", app_id: str = "", api_key: str = "", api_secret: str = ""):
"""Text to speech
For more details, check out:`https://www.xfyun.cn/doc/tts/online_tts/API.html`
:param voice: Default `xiaoyan`. For more details, checkout: `https://www.xfyun.cn/doc/tts/online_tts/API.html#%E6%8E%A5%E5%8F%A3%E8%B0%83%E7%94%A8%E6%B5%81%E7%A8%8B`
:param text: The text used for voice conversion.
:param app_id: Application ID is used to access your iFlyTek service API, see: `https://console.xfyun.cn/services/tts`
:param api_key: WebAPI argument, see: `https://console.xfyun.cn/services/tts`
:param api_secret: WebAPI argument, see: `https://console.xfyun.cn/services/tts`
:return: Returns the Base64-encoded .mp3 file data if successful, otherwise an empty string.
"""
if not app_id:
app_id = CONFIG.IFLYTEK_APP_ID
if not api_key:
api_key = CONFIG.IFLYTEK_API_KEY
if not api_secret:
api_secret = CONFIG.IFLYTEK_API_SECRET
if not voice:
voice = CONFIG.IFLYTEK_VOICE or DEFAULT_IFLYTEK_VOICE
filename = Path(__file__).parent / (uuid.uuid4().hex + ".mp3")
try:
tts = IFlyTekTTS(app_id=app_id, api_key=api_key, api_secret=api_secret)
await tts.synthesize_speech(text=text, output_file=str(filename), voice=voice)
async with aiofiles.open(str(filename), mode="r") as reader:
base64_string = await reader.read()
except Exception as e:
logger.error(f"text:{text}, error:{e}")
base64_string = ""
finally:
filename.unlink()
return base64_string
if __name__ == "__main__":
asyncio.get_event_loop().run_until_complete(
oas3_iflytek_tts(
text="你好hello",
app_id="f7acef62",
api_key="fda72e3aa286042a492525816a5efa08",
api_secret="ZDk3NjdiMDBkODJlOWQ1NjRjMGI2NDY4",
)
)

View file

@ -0,0 +1,44 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/8/17
@Author : mashenquan
@File : metagpt_oas3_api_svc.py
@Desc : MetaGPT OpenAPI Specification 3.0 REST API service
"""
import asyncio
import sys
from pathlib import Path
import connexion
sys.path.append(str(Path(__file__).resolve().parent.parent.parent)) # fix-bug: No module named 'metagpt'
def oas_http_svc():
"""Start the OAS 3.0 OpenAPI HTTP service"""
app = connexion.AioHttpApp(__name__, specification_dir="../../.well-known/")
app.add_api("metagpt_oas3_api.yaml")
app.add_api("openapi.yaml")
app.run(port=8080)
async def async_main():
"""Start the OAS 3.0 OpenAPI HTTP service in the background."""
loop = asyncio.get_event_loop()
loop.run_in_executor(None, oas_http_svc)
# TODO: replace following codes:
while True:
await asyncio.sleep(1)
print("sleep")
def main():
print("http://localhost:8080/oas3/ui/")
oas_http_svc()
if __name__ == "__main__":
# asyncio.run(async_main())
main()

View file

@ -0,0 +1,110 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/8/18
@Author : mashenquan
@File : metagpt_text_to_image.py
@Desc : MetaGPT Text-to-Image OAS3 api, which provides text-to-image functionality.
"""
import asyncio
import base64
from typing import Dict, List
import aiohttp
import requests
from pydantic import BaseModel
from metagpt.config import CONFIG, Config
from metagpt.logs import logger
class MetaGPTText2Image:
def __init__(self, model_url):
"""
:param model_url: Model reset api url
"""
self.model_url = model_url if model_url else CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL
async def text_2_image(self, text, size_type="512x512"):
"""Text to image
:param text: The text used for image conversion.
:param size_type: One of ['512x512', '512x768']
:return: The image data is returned in Base64 encoding.
"""
headers = {"Content-Type": "application/json"}
dims = size_type.split("x")
data = {
"prompt": text,
"negative_prompt": "(easynegative:0.8),black, dark,Low resolution",
"override_settings": {"sd_model_checkpoint": "galaxytimemachinesGTM_photoV20"},
"seed": -1,
"batch_size": 1,
"n_iter": 1,
"steps": 20,
"cfg_scale": 11,
"width": int(dims[0]),
"height": int(dims[1]), # 768,
"restore_faces": False,
"tiling": False,
"do_not_save_samples": False,
"do_not_save_grid": False,
"enable_hr": False,
"hr_scale": 2,
"hr_upscaler": "Latent",
"hr_second_pass_steps": 0,
"hr_resize_x": 0,
"hr_resize_y": 0,
"hr_upscale_to_x": 0,
"hr_upscale_to_y": 0,
"truncate_x": 0,
"truncate_y": 0,
"applied_old_hires_behavior_to": None,
"eta": None,
"sampler_index": "DPM++ SDE Karras",
"alwayson_scripts": {},
}
class ImageResult(BaseModel):
images: List
parameters: Dict
try:
async with aiohttp.ClientSession() as session:
async with session.post(self.model_url, headers=headers, json=data) as response:
result = ImageResult(**await response.json())
if len(result.images) == 0:
return ""
return result.images[0]
except requests.exceptions.RequestException as e:
logger.error(f"An error occurred:{e}")
return ""
# Export
async def oas3_metagpt_text_to_image(text, size_type: str = "512x512", model_url=""):
"""Text to image
:param text: The text used for image conversion.
:param model_url: Model reset api
:param size_type: One of ['512x512', '512x768']
:return: The image data is returned in Base64 encoding.
"""
if not text:
return ""
if not model_url:
model_url = CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL
return await MetaGPTText2Image(model_url).text_2_image(text, size_type=size_type)
if __name__ == "__main__":
Config()
loop = asyncio.new_event_loop()
task = loop.create_task(oas3_metagpt_text_to_image("Panda emoji"))
v = loop.run_until_complete(task)
print(v)
data = base64.b64decode(v)
with open("tmp.png", mode="wb") as writer:
writer.write(data)
print(v)

View file

@ -0,0 +1,90 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/8/18
@Author : mashenquan
@File : openai_text_to_embedding.py
@Desc : OpenAI Text-to-Embedding OAS3 api, which provides text-to-embedding functionality.
For more details, checkout: `https://platform.openai.com/docs/api-reference/embeddings/object`
"""
import asyncio
from typing import List
import aiohttp
import requests
from pydantic import BaseModel
from metagpt.config import CONFIG, Config
from metagpt.logs import logger
class Embedding(BaseModel):
"""Represents an embedding vector returned by embedding endpoint."""
object: str # The object type, which is always "embedding".
embedding: List[
float
] # The embedding vector, which is a list of floats. The length of vector depends on the model as listed in the embedding guide.
index: int # The index of the embedding in the list of embeddings.
class Usage(BaseModel):
prompt_tokens: int
total_tokens: int
class ResultEmbedding(BaseModel):
object: str
data: List[Embedding]
model: str
usage: Usage
class OpenAIText2Embedding:
def __init__(self, openai_api_key):
"""
:param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
"""
self.openai_api_key = openai_api_key if openai_api_key else CONFIG.OPENAI_API_KEY
async def text_2_embedding(self, text, model="text-embedding-ada-002"):
"""Text to embedding
:param text: The text used for embedding.
:param model: One of ['text-embedding-ada-002'], ID of the model to use. For more details, checkout: `https://api.openai.com/v1/models`.
:return: A json object of :class:`ResultEmbedding` class if successful, otherwise `{}`.
"""
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.openai_api_key}"}
data = {"input": text, "model": model}
try:
async with aiohttp.ClientSession() as session:
async with session.post("https://api.openai.com/v1/embeddings", headers=headers, json=data) as response:
return await response.json()
except requests.exceptions.RequestException as e:
logger.error(f"An error occurred:{e}")
return {}
# Export
async def oas3_openai_text_to_embedding(text, model="text-embedding-ada-002", openai_api_key=""):
"""Text to embedding
:param text: The text used for embedding.
:param model: One of ['text-embedding-ada-002'], ID of the model to use. For more details, checkout: `https://api.openai.com/v1/models`.
:param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
:return: A json object of :class:`ResultEmbedding` class if successful, otherwise `{}`.
"""
if not text:
return ""
if not openai_api_key:
openai_api_key = CONFIG.OPENAI_API_KEY
return await OpenAIText2Embedding(openai_api_key).text_2_embedding(text, model=model)
if __name__ == "__main__":
Config()
loop = asyncio.new_event_loop()
task = loop.create_task(oas3_openai_text_to_embedding("Panda emoji"))
v = loop.run_until_complete(task)
print(v)

View file

@ -0,0 +1,86 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/8/17
@Author : mashenquan
@File : openai_text_to_image.py
@Desc : OpenAI Text-to-Image OAS3 api, which provides text-to-image functionality.
"""
import asyncio
import base64
import aiohttp
import requests
from metagpt.config import Config
from metagpt.llm import LLM
from metagpt.logs import logger
class OpenAIText2Image:
def __init__(self):
"""
:param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
"""
self._llm = LLM()
self._client = self._llm.async_client
def __del__(self):
if self._llm:
self._llm.close()
async def text_2_image(self, text, size_type="1024x1024"):
"""Text to image
:param text: The text used for image conversion.
:param size_type: One of ['256x256', '512x512', '1024x1024']
:return: The image data is returned in Base64 encoding.
"""
try:
result = await self._client.images.generate(prompt=text, n=1, size=size_type)
except Exception as e:
logger.error(f"An error occurred:{e}")
return ""
if result and len(result.data) > 0:
return await OpenAIText2Image.get_image_data(result.data[0].url)
return ""
@staticmethod
async def get_image_data(url):
"""Fetch image data from a URL and encode it as Base64
:param url: Image url
:return: Base64-encoded image data.
"""
try:
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
response.raise_for_status() # 如果是 4xx 或 5xx 响应,会引发异常
image_data = await response.read()
base64_image = base64.b64encode(image_data).decode("utf-8")
return base64_image
except requests.exceptions.RequestException as e:
logger.error(f"An error occurred:{e}")
return ""
# Export
async def oas3_openai_text_to_image(text, size_type: str = "1024x1024"):
"""Text to image
:param text: The text used for image conversion.
:param size_type: One of ['256x256', '512x512', '1024x1024']
:return: The image data is returned in Base64 encoding.
"""
if not text:
return ""
return await OpenAIText2Image().text_2_image(text, size_type=size_type)
if __name__ == "__main__":
Config()
loop = asyncio.new_event_loop()
task = loop.create_task(oas3_openai_text_to_image("Panda emoji"))
v = loop.run_until_complete(task)
print(v)

View file

@ -6,7 +6,6 @@ import asyncio
import base64
import io
import json
import os
from os.path import join
from typing import List
@ -14,8 +13,7 @@ from aiohttp import ClientSession
from PIL import Image, PngImagePlugin
from metagpt.config import CONFIG
# from metagpt.const import WORKSPACE_ROOT
from metagpt.const import SD_OUTPUT_FILE_REPO
from metagpt.logs import logger
payload = {
@ -79,10 +77,10 @@ class SDEngine:
return self.payload
def _save(self, imgs, save_name=""):
save_dir = CONFIG.workspace_path / "resources" / "SD_Output"
if not os.path.exists(save_dir):
os.makedirs(save_dir, exist_ok=True)
batch_decode_base64_to_image(imgs, save_dir, save_name=save_name)
save_dir = CONFIG.workspace_path / SD_OUTPUT_FILE_REPO
if not save_dir.exists():
save_dir.mkdir(parents=True, exist_ok=True)
batch_decode_base64_to_image(imgs, str(save_dir), save_name=save_name)
async def run_t2i(self, prompts: List):
# Asynchronously run the SD API for multiple prompts

View file

@ -1,4 +1,7 @@
#!/usr/bin/env python
"""
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
"""
from __future__ import annotations
@ -17,14 +20,16 @@ class WebBrowserEngine:
run_func: Callable[..., Coroutine[Any, Any, WebPage | list[WebPage]]] | None = None,
):
engine = engine or CONFIG.web_browser_engine
if engine is None:
raise NotImplementedError
if engine == WebBrowserEngineType.PLAYWRIGHT:
if WebBrowserEngineType(engine) is WebBrowserEngineType.PLAYWRIGHT:
module = "metagpt.tools.web_browser_engine_playwright"
run_func = importlib.import_module(module).PlaywrightWrapper().run
elif engine == WebBrowserEngineType.SELENIUM:
elif WebBrowserEngineType(engine) is WebBrowserEngineType.SELENIUM:
module = "metagpt.tools.web_browser_engine_selenium"
run_func = importlib.import_module(module).SeleniumWrapper().run
elif engine == WebBrowserEngineType.CUSTOM:
elif WebBrowserEngineType(engine) is WebBrowserEngineType.CUSTOM:
run_func = run_func
else:
raise NotImplementedError
@ -47,6 +52,6 @@ if __name__ == "__main__":
import fire
async def main(url: str, *urls: str, engine_type: Literal["playwright", "selenium"] = "playwright", **kwargs):
return await WebBrowserEngine(WebBrowserEngineType(engine_type), **kwargs).run(url, *urls)
return await WebBrowserEngine(engine=WebBrowserEngineType(engine_type), **kwargs).run(url, *urls)
fire.Fire(main)

View file

@ -1,4 +1,8 @@
#!/usr/bin/env python
"""
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
"""
from __future__ import annotations
import asyncio
@ -144,6 +148,6 @@ if __name__ == "__main__":
import fire
async def main(url: str, *urls: str, browser_type: str = "chromium", **kwargs):
return await PlaywrightWrapper(browser_type, **kwargs).run(url, *urls)
return await PlaywrightWrapper(browser_type=browser_type, **kwargs).run(url, *urls)
fire.Fire(main)

View file

@ -1,11 +1,15 @@
#!/usr/bin/env python
"""
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
"""
from __future__ import annotations
import asyncio
import importlib
from concurrent import futures
from copy import deepcopy
from typing import Literal
from typing import Dict, Literal
from selenium.webdriver.common.by import By
from selenium.webdriver.support import expected_conditions as EC
@ -29,6 +33,7 @@ class SeleniumWrapper:
def __init__(
self,
options: Dict,
browser_type: Literal["chrome", "firefox", "edge", "ie"] | None = None,
launch_kwargs: dict | None = None,
*,
@ -120,6 +125,6 @@ if __name__ == "__main__":
import fire
async def main(url: str, *urls: str, browser_type: str = "chrome", **kwargs):
return await SeleniumWrapper(browser_type, **kwargs).run(url, *urls)
return await SeleniumWrapper(browser_type=browser_type, **kwargs).run(url, *urls)
fire.Fire(main)

View file

@ -23,7 +23,7 @@ import sys
import traceback
import typing
from pathlib import Path
from typing import Any, List, Tuple, Union, get_args, get_origin
from typing import Any, Callable, List, Tuple, Union, get_args, get_origin
import aiofiles
import loguru
@ -48,7 +48,7 @@ def check_cmd_exists(command) -> int:
return result
def require_python_version(req_version: tuple[int]) -> bool:
def require_python_version(req_version: Tuple) -> bool:
if not (2 <= len(req_version) <= 3):
raise ValueError("req_version should be (3, 9) or (3, 10, 13)")
return True if sys.version_info > req_version else False
@ -367,7 +367,7 @@ def get_class_name(cls) -> str:
return f"{cls.__module__}.{cls.__name__}"
def any_to_str(val: str | typing.Callable) -> str:
def any_to_str(val: str | Callable) -> str:
"""Return the class name or the class name of the object, or 'val' if it's a string type."""
if isinstance(val, str):
return val
@ -406,6 +406,21 @@ def is_subscribed(message: "Message", tags: set):
return False
def any_to_name(val):
"""
Convert a value to its name by extracting the last part of the dotted path.
:param val: The value to convert.
:return: The name of the value.
"""
return any_to_str(val).split(".")[-1]
def concat_namespace(*args) -> str:
return ":".join(str(value) for value in args)
def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]:
"""
Generates a logging function to be used after a call is retried.
@ -520,3 +535,20 @@ async def aread(file_path: str) -> str:
async with aiofiles.open(str(file_path), mode="r") as reader:
content = await reader.read()
return content
async def read_file_block(filename: str | Path, lineno: int, end_lineno: int):
if not Path(filename).exists():
return ""
lines = []
async with aiofiles.open(str(filename), mode="r") as reader:
ix = 0
while ix < end_lineno:
ix += 1
line = await reader.readline()
if ix < lineno:
continue
if ix > end_lineno:
break
lines.append(line)
return "".join(lines)

View file

@ -0,0 +1,82 @@
# -*- coding: utf-8 -*-
"""
@Time : 2023/8/28
@Author : mashenquan
@File : openai.py
@Desc : mashenquan, 2023/8/28. Separate the `CostManager` class to support user-level cost accounting.
"""
from typing import NamedTuple
from pydantic import BaseModel
from metagpt.logs import logger
from metagpt.utils.token_counter import TOKEN_COSTS
class Costs(NamedTuple):
total_prompt_tokens: int
total_completion_tokens: int
total_cost: float
total_budget: float
class CostManager(BaseModel):
"""Calculate the overhead of using the interface."""
total_prompt_tokens: int = 0
total_completion_tokens: int = 0
total_budget: float = 0
max_budget: float = 10.0
total_cost: float = 0
def update_cost(self, prompt_tokens, completion_tokens, model):
"""
Update the total cost, prompt tokens, and completion tokens.
Args:
prompt_tokens (int): The number of tokens used in the prompt.
completion_tokens (int): The number of tokens used in the completion.
model (str): The model used for the API call.
"""
self.total_prompt_tokens += prompt_tokens
self.total_completion_tokens += completion_tokens
cost = (
prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"]
) / 1000
self.total_cost += cost
logger.info(
f"Total running cost: ${self.total_cost:.3f} | Max budget: ${self.max_budget:.3f} | "
f"Current cost: ${cost:.3f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
)
def get_total_prompt_tokens(self):
"""
Get the total number of prompt tokens.
Returns:
int: The total number of prompt tokens.
"""
return self.total_prompt_tokens
def get_total_completion_tokens(self):
"""
Get the total number of completion tokens.
Returns:
int: The total number of completion tokens.
"""
return self.total_completion_tokens
def get_total_cost(self):
"""
Get the total cost of API calls.
Returns:
float: The total cost of API calls.
"""
return self.total_cost
def get_costs(self) -> Costs:
"""Get all costs"""
return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget)

View file

@ -0,0 +1,84 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/19
@Author : mashenquan
@File : di_graph_repository.py
@Desc : Graph repository based on DiGraph
"""
from __future__ import annotations
import json
from pathlib import Path
from typing import List
import aiofiles
import networkx
from metagpt.utils.graph_repository import SPO, GraphRepository
class DiGraphRepository(GraphRepository):
def __init__(self, name: str, **kwargs):
super().__init__(name=name, **kwargs)
self._repo = networkx.DiGraph()
async def insert(self, subject: str, predicate: str, object_: str):
self._repo.add_edge(subject, object_, predicate=predicate)
async def upsert(self, subject: str, predicate: str, object_: str):
pass
async def update(self, subject: str, predicate: str, object_: str):
pass
async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]:
result = []
for s, o, p in self._repo.edges(data="predicate"):
if subject and subject != s:
continue
if predicate and predicate != p:
continue
if object_ and object_ != o:
continue
result.append(SPO(subject=s, predicate=p, object_=o))
return result
def json(self) -> str:
m = networkx.node_link_data(self._repo)
data = json.dumps(m)
return data
async def save(self, path: str | Path = None):
data = self.json()
path = path or self._kwargs.get("root")
if not path.exists():
path.mkdir(parents=True, exist_ok=True)
pathname = Path(path) / self.name
async with aiofiles.open(str(pathname.with_suffix(".json")), mode="w", encoding="utf-8") as writer:
await writer.write(data)
async def load(self, pathname: str | Path):
async with aiofiles.open(str(pathname), mode="r", encoding="utf-8") as reader:
data = await reader.read(-1)
m = json.loads(data)
self._repo = networkx.node_link_graph(m)
@staticmethod
async def load_from(pathname: str | Path) -> GraphRepository:
pathname = Path(pathname)
name = pathname.with_suffix("").name
root = pathname.parent
graph = DiGraphRepository(name=name, root=root)
if pathname.exists():
await graph.load(pathname=pathname)
return graph
@property
def root(self) -> str:
return self._kwargs.get("root")
@property
def pathname(self) -> Path:
p = Path(self.root) / self.name
return p.with_suffix(".json")

View file

@ -0,0 +1,150 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/19
@Author : mashenquan
@File : graph_repository.py
@Desc : Superclass for graph repository.
"""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List
from pydantic import BaseModel
from metagpt.repo_parser import ClassInfo, RepoFileInfo
from metagpt.utils.common import concat_namespace
class GraphKeyword:
IS = "is"
CLASS = "class"
FUNCTION = "function"
SOURCE_CODE = "source_code"
NULL = "<null>"
GLOBAL_VARIABLE = "global_variable"
CLASS_FUNCTION = "class_function"
CLASS_PROPERTY = "class_property"
HAS_CLASS = "has_class"
HAS_PAGE_INFO = "has_page_info"
HAS_CLASS_VIEW = "has_class_view"
HAS_SEQUENCE_VIEW = "has_sequence_view"
HAS_ARGS_DESC = "has_args_desc"
HAS_TYPE_DESC = "has_type_desc"
class SPO(BaseModel):
subject: str
predicate: str
object_: str
class GraphRepository(ABC):
def __init__(self, name: str, **kwargs):
self._repo_name = name
self._kwargs = kwargs
@abstractmethod
async def insert(self, subject: str, predicate: str, object_: str):
pass
@abstractmethod
async def upsert(self, subject: str, predicate: str, object_: str):
pass
@abstractmethod
async def update(self, subject: str, predicate: str, object_: str):
pass
@abstractmethod
async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]:
pass
@property
def name(self) -> str:
return self._repo_name
@staticmethod
async def update_graph_db_with_file_info(graph_db: "GraphRepository", file_info: RepoFileInfo):
await graph_db.insert(subject=file_info.file, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)
file_types = {".py": "python", ".js": "javascript"}
file_type = file_types.get(Path(file_info.file).suffix, GraphKeyword.NULL)
await graph_db.insert(subject=file_info.file, predicate=GraphKeyword.IS, object_=file_type)
for c in file_info.classes:
class_name = c.get("name", "")
await graph_db.insert(
subject=file_info.file,
predicate=GraphKeyword.HAS_CLASS,
object_=concat_namespace(file_info.file, class_name),
)
await graph_db.insert(
subject=concat_namespace(file_info.file, class_name),
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS,
)
methods = c.get("methods", [])
for fn in methods:
await graph_db.insert(
subject=concat_namespace(file_info.file, class_name, fn),
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS_FUNCTION,
)
for f in file_info.functions:
await graph_db.insert(
subject=concat_namespace(file_info.file, f), predicate=GraphKeyword.IS, object_=GraphKeyword.FUNCTION
)
for g in file_info.globals:
await graph_db.insert(
subject=concat_namespace(file_info.file, g),
predicate=GraphKeyword.IS,
object_=GraphKeyword.GLOBAL_VARIABLE,
)
for code_block in file_info.page_info:
if code_block.tokens:
await graph_db.insert(
subject=concat_namespace(file_info.file, *code_block.tokens),
predicate=GraphKeyword.HAS_PAGE_INFO,
object_=code_block.json(ensure_ascii=False),
)
for k, v in code_block.properties.items():
await graph_db.insert(
subject=concat_namespace(file_info.file, k, v),
predicate=GraphKeyword.HAS_PAGE_INFO,
object_=code_block.json(ensure_ascii=False),
)
@staticmethod
async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_views: List[ClassInfo]):
for c in class_views:
filename, class_name = c.package.split(":", 1)
await graph_db.insert(subject=filename, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)
file_types = {".py": "python", ".js": "javascript"}
file_type = file_types.get(Path(filename).suffix, GraphKeyword.NULL)
await graph_db.insert(subject=filename, predicate=GraphKeyword.IS, object_=file_type)
await graph_db.insert(subject=filename, predicate=GraphKeyword.HAS_CLASS, object_=class_name)
await graph_db.insert(
subject=c.package,
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS,
)
for vn, vt in c.attributes.items():
await graph_db.insert(
subject=concat_namespace(c.package, vn),
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS_PROPERTY,
)
await graph_db.insert(
subject=concat_namespace(c.package, vn), predicate=GraphKeyword.HAS_TYPE_DESC, object_=vt
)
for fn, desc in c.methods.items():
await graph_db.insert(
subject=concat_namespace(c.package, fn),
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS_FUNCTION,
)
await graph_db.insert(
subject=concat_namespace(c.package, fn),
predicate=GraphKeyword.HAS_ARGS_DESC,
object_=desc,
)

View file

@ -18,17 +18,15 @@ from metagpt.config import CONFIG
def make_sk_kernel():
kernel = sk.Kernel()
if CONFIG.openai_api_type == "azure":
if CONFIG.OPENAI_API_TYPE == "azure":
kernel.add_chat_service(
"chat_completion",
AzureChatCompletion(
deployment_name=CONFIG.deployment_name, endpoint=CONFIG.openai_base_url, api_key=CONFIG.openai_api_key
),
AzureChatCompletion(CONFIG.DEPLOYMENT_NAME, CONFIG.OPENAI_BASE_URL, CONFIG.OPENAI_API_KEY),
)
else:
kernel.add_chat_service(
"chat_completion",
OpenAIChatCompletion(model_id=CONFIG.openai_api_model, api_key=CONFIG.openai_api_key),
OpenAIChatCompletion(CONFIG.OPENAI_API_MODEL, CONFIG.OPENAI_API_KEY),
)
return kernel

View file

@ -4,11 +4,14 @@
@Time : 2023/7/4 10:53
@Author : alexanderwu alitrack
@File : mermaid.py
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
"""
import asyncio
import os
from pathlib import Path
import aiofiles
from metagpt.config import CONFIG
from metagpt.const import METAGPT_ROOT
from metagpt.logs import logger
@ -29,7 +32,9 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name)
tmp = Path(f"{output_file_without_suffix}.mmd")
tmp.write_text(mermaid_code, encoding="utf-8")
async with aiofiles.open(tmp, "w", encoding="utf-8") as f:
await f.write(mermaid_code)
# tmp.write_text(mermaid_code, encoding="utf-8")
engine = CONFIG.mermaid_engine.lower()
if engine == "nodejs":
@ -88,7 +93,8 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
return 0
MMC1 = """classDiagram
MMC1 = """
classDiagram
class Main {
-SearchEngine search_engine
+main() str
@ -118,9 +124,11 @@ MMC1 = """classDiagram
SearchEngine --> Index
SearchEngine --> Ranking
SearchEngine --> Summary
Index --> KnowledgeBase"""
Index --> KnowledgeBase
"""
MMC2 = """sequenceDiagram
MMC2 = """
sequenceDiagram
participant M as Main
participant SE as SearchEngine
participant I as Index
@ -136,11 +144,11 @@ MMC2 = """sequenceDiagram
R-->>SE: return ranked_results
SE->>S: summarize_results(ranked_results)
S-->>SE: return summary
SE-->>M: return summary"""
SE-->>M: return summary
"""
if __name__ == "__main__":
loop = asyncio.new_event_loop()
result = loop.run_until_complete(mermaid_to_file(MMC1, METAGPT_ROOT / f"{CONFIG.mermaid_engine}/1"))
result = loop.run_until_complete(mermaid_to_file(MMC2, METAGPT_ROOT / f"{CONFIG.mermaid_engine}/1"))
result = loop.run_until_complete(mermaid_to_file(MMC2, METAGPT_ROOT / f"{CONFIG.mermaid_engine}/2"))
loop.close()

219
metagpt/utils/redis.py Normal file
View file

@ -0,0 +1,219 @@
# !/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author: Hui
# @Desc: { redis client }
# @Date: 2022/11/28 10:12
import json
import traceback
from datetime import timedelta
from enum import Enum
from typing import Awaitable, Callable, Dict, Optional, Union
from redis import asyncio as aioredis
from metagpt.config import CONFIG
from metagpt.logs import logger
class RedisTypeEnum(Enum):
"""Redis 数据类型"""
String = "String"
List = "List"
Hash = "Hash"
Set = "Set"
ZSet = "ZSet"
def make_url(
dialect: str,
*,
user: Optional[str] = None,
password: Optional[str] = None,
host: Optional[str] = None,
port: Optional[Union[str, int]] = None,
name: Optional[Union[str, int]] = None,
) -> str:
url_parts = [f"{dialect}://"]
if user or password:
if user:
url_parts.append(user)
if password:
url_parts.append(f":{password}")
url_parts.append("@")
if not host and not dialect.startswith("sqlite"):
host = "127.0.0.1"
if host:
url_parts.append(f"{host}")
if port:
url_parts.append(f":{port}")
# 比如redis可能传入0
if name is not None:
url_parts.append(f"/{name}")
return "".join(url_parts)
class RedisAsyncClient(aioredis.Redis):
"""异步的客户端
例子::
rdb = RedisAsyncClient()
print(rdb.url)
Args:
host: 服务器地址
port: 服务器端口
user: 用户名
db: 数据库
password: 密码
decode_responses: 字符串输入被编码成utf8存储在Redis里了而取出来的时候还是被编码后的bytes需要显示的decode才能变成字符串
health_check_interval: 定时检测连接防止出现ConnectionErrors (104, Connection reset by peer)
"""
def __init__(
self,
host: str = "localhost",
port: int = 6379,
db: int = 0,
password: str = None,
decode_responses=True,
health_check_interval=10,
socket_connect_timeout=5,
retry_on_timeout=True,
socket_keepalive=True,
**kwargs,
):
super().__init__(
host=host,
port=port,
db=db,
password=password,
decode_responses=decode_responses,
health_check_interval=health_check_interval,
socket_connect_timeout=socket_connect_timeout,
retry_on_timeout=retry_on_timeout,
socket_keepalive=socket_keepalive,
**kwargs,
)
self.url = make_url("redis", host=host, port=port, name=db, password=password)
class RedisCacheInfo(object):
"""统一缓存信息类"""
def __init__(self, key, timeout: Union[int, timedelta] = timedelta(seconds=60), data_type=RedisTypeEnum.String):
"""
缓存信息类初始化
Args:
key: 缓存的key
timeout: 缓存过期时间, 单位秒
data_type: 缓存采用的数据结构 (不传并不影响用于标记业务采用的是什么数据结构)
"""
self.key = key
self.timeout = timeout
self.data_type = data_type
def __str__(self):
return f"cache key {self.key} timeout {self.timeout}s"
class RedisManager:
client: RedisAsyncClient = None
@classmethod
def init_redis_conn(cls, host, port, password, db):
"""初始化redis 连接"""
if cls.client is None:
cls.client = RedisAsyncClient(host=host, port=port, password=password, db=db)
@classmethod
async def set_with_cache_info(cls, redis_cache_info: RedisCacheInfo, value):
"""
根据 RedisCacheInfo 设置 Redis 缓存
:param redis_cache_info: RedisCacheInfo缓存信息对象
:param value: 缓存的值
:return:
"""
await cls.client.setex(redis_cache_info.key, redis_cache_info.timeout, value)
@classmethod
async def get_with_cache_info(cls, redis_cache_info: RedisCacheInfo):
"""
根据 RedisCacheInfo 获取 Redis 缓存
:param redis_cache_info: RedisCacheInfo 缓存信息对象
:return:
"""
cache_info = await cls.client.get(redis_cache_info.key)
return cache_info
@classmethod
async def del_with_cache_info(cls, redis_cache_info: RedisCacheInfo):
"""
根据 RedisCacheInfo 删除 Redis 缓存
:param redis_cache_info: RedisCacheInfo缓存信息对象
:return:
"""
await cls.client.delete(redis_cache_info.key)
@staticmethod
async def get_or_set_cache(cache_info: RedisCacheInfo, fetch_data_func: Callable[[], Awaitable[dict]]) -> dict:
"""
获取缓存数据如果缓存不存在则从提供的函数中获取并设置缓存
当前版本仅支持 json 形式的 string 格式数据
"""
serialized_data = await RedisManager.get_with_cache_info(cache_info)
if serialized_data:
return json.loads(serialized_data)
data = await fetch_data_func()
try:
serialized_data = json.dumps(data)
await RedisManager.set_with_cache_info(cache_info, serialized_data)
except Exception as e:
logger.warning(f"数据 {data} 通过 json 进行序列化缓存失败:{e}")
return data
@classmethod
def is_valid(cls):
return cls.client is not None
class Redis:
def __init__(self, conf: Dict = None):
try:
host = CONFIG.REDIS_HOST
port = int(CONFIG.REDIS_PORT)
pwd = CONFIG.REDIS_PASSWORD
db = CONFIG.REDIS_DB
RedisManager.init_redis_conn(host=host, port=port, password=pwd, db=db)
except Exception as e:
logger.warning(f"Redis initialization has failed:{e}")
def is_valid(self):
return RedisManager.is_valid()
async def get(self, key: str) -> str:
if not self.is_valid() or not key:
return None
try:
v = await RedisManager.get_with_cache_info(redis_cache_info=RedisCacheInfo(key=key))
return v
except Exception as e:
logger.exception(f"{e}, stack:{traceback.format_exc()}")
return None
async def set(self, key: str, data: str, timeout_sec: int):
if not self.is_valid() or not key:
return
try:
await RedisManager.set_with_cache_info(
redis_cache_info=RedisCacheInfo(key=key, timeout=timeout_sec), value=data
)
except Exception as e:
logger.exception(f"{e}, stack:{traceback.format_exc()}")

View file

@ -230,9 +230,11 @@ def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["R
elif retry_state.kwargs:
func_param_output = retry_state.kwargs.get("output", "")
exp_str = str(retry_state.outcome.exception())
fix_str = "try to fix it, " if CONFIG.repair_llm_output else ""
logger.warning(
f"parse json from content inside [CONTENT][/CONTENT] failed at retry "
f"{retry_state.attempt_number}, try to fix it, exp: {exp_str}"
f"{retry_state.attempt_number}, {fix_str}exp: {exp_str}"
)
repaired_output = repair_invalid_json(func_param_output, exp_str)

170
metagpt/utils/s3.py Normal file
View file

@ -0,0 +1,170 @@
import base64
import os.path
import traceback
import uuid
from pathlib import Path
from typing import Optional
import aioboto3
import aiofiles
from metagpt.config import CONFIG
from metagpt.const import BASE64_FORMAT
from metagpt.logs import logger
class S3:
"""A class for interacting with Amazon S3 storage."""
def __init__(self):
self.session = aioboto3.Session()
self.auth_config = {
"service_name": "s3",
"aws_access_key_id": CONFIG.S3_ACCESS_KEY,
"aws_secret_access_key": CONFIG.S3_SECRET_KEY,
"endpoint_url": CONFIG.S3_ENDPOINT_URL,
}
async def upload_file(
self,
bucket: str,
local_path: str,
object_name: str,
) -> None:
"""Upload a file from the local path to the specified path of the storage bucket specified in s3.
Args:
bucket: The name of the S3 storage bucket.
local_path: The local file path, including the file name.
object_name: The complete path of the uploaded file to be stored in S3, including the file name.
Raises:
Exception: If an error occurs during the upload process, an exception is raised.
"""
try:
async with self.session.client(**self.auth_config) as client:
async with aiofiles.open(local_path, mode="rb") as reader:
body = await reader.read()
await client.put_object(Body=body, Bucket=bucket, Key=object_name)
logger.info(f"Successfully uploaded the file to path {object_name} in bucket {bucket} of s3.")
except Exception as e:
logger.error(f"Failed to upload the file to path {object_name} in bucket {bucket} of s3: {e}")
raise e
async def get_object_url(
self,
bucket: str,
object_name: str,
) -> str:
"""Get the URL for a downloadable or preview file stored in the specified S3 bucket.
Args:
bucket: The name of the S3 storage bucket.
object_name: The complete path of the file stored in S3, including the file name.
Returns:
The URL for the downloadable or preview file.
Raises:
Exception: If an error occurs while retrieving the URL, an exception is raised.
"""
try:
async with self.session.client(**self.auth_config) as client:
file = await client.get_object(Bucket=bucket, Key=object_name)
return str(file["Body"].url)
except Exception as e:
logger.error(f"Failed to get the url for a downloadable or preview file: {e}")
raise e
async def get_object(
self,
bucket: str,
object_name: str,
) -> bytes:
"""Get the binary data of a file stored in the specified S3 bucket.
Args:
bucket: The name of the S3 storage bucket.
object_name: The complete path of the file stored in S3, including the file name.
Returns:
The binary data of the requested file.
Raises:
Exception: If an error occurs while retrieving the file data, an exception is raised.
"""
try:
async with self.session.client(**self.auth_config) as client:
s3_object = await client.get_object(Bucket=bucket, Key=object_name)
return await s3_object["Body"].read()
except Exception as e:
logger.error(f"Failed to get the binary data of the file: {e}")
raise e
async def download_file(
self, bucket: str, object_name: str, local_path: str, chunk_size: Optional[int] = 128 * 1024
) -> None:
"""Download an S3 object to a local file.
Args:
bucket: The name of the S3 storage bucket.
object_name: The complete path of the file stored in S3, including the file name.
local_path: The local file path where the S3 object will be downloaded.
chunk_size: The size of data chunks to read and write at a time. Default is 128 KB.
Raises:
Exception: If an error occurs during the download process, an exception is raised.
"""
try:
async with self.session.client(**self.auth_config) as client:
s3_object = await client.get_object(Bucket=bucket, Key=object_name)
stream = s3_object["Body"]
async with aiofiles.open(local_path, mode="wb") as writer:
while True:
file_data = await stream.read(chunk_size)
if not file_data:
break
await writer.write(file_data)
except Exception as e:
logger.error(f"Failed to download the file from S3: {e}")
raise e
async def cache(self, data: str, file_ext: str, format: str = "") -> str:
"""Save data to remote S3 and return url"""
object_name = uuid.uuid4().hex + file_ext
path = Path(__file__).parent
pathname = path / object_name
try:
async with aiofiles.open(str(pathname), mode="wb") as file:
if format == BASE64_FORMAT:
data = base64.b64decode(data)
await file.write(data)
bucket = CONFIG.S3_BUCKET
object_pathname = CONFIG.S3_BUCKET or "system"
object_pathname += f"/{object_name}"
object_pathname = os.path.normpath(object_pathname)
await self.upload_file(bucket=bucket, local_path=str(pathname), object_name=object_pathname)
pathname.unlink(missing_ok=True)
return await self.get_object_url(bucket=bucket, object_name=object_pathname)
except Exception as e:
logger.exception(f"{e}, stack:{traceback.format_exc()}")
pathname.unlink(missing_ok=True)
return None
@property
def is_valid(self):
is_invalid = (
not CONFIG.S3_ACCESS_KEY
or CONFIG.S3_ACCESS_KEY == "YOUR_S3_ACCESS_KEY"
or not CONFIG.S3_SECRET_KEY
or CONFIG.S3_SECRET_KEY == "YOUR_S3_SECRET_KEY"
or not CONFIG.S3_ENDPOINT_URL
or CONFIG.S3_ENDPOINT_URL == "YOUR_S3_ENDPOINT_URL"
or not CONFIG.S3_BUCKET
or CONFIG.S3_BUCKET == "YOUR_S3_BUCKET"
)
if is_invalid:
logger.info("S3 is invalid")
return not is_invalid

View file

@ -84,6 +84,13 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
elif "gpt-4" == model:
print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
return count_message_tokens(messages, model="gpt-4-0613")
elif "open-llm-model" == model:
"""
For self-hosted open_llm api, they include lots of different models. The message tokens calculation is
inaccurate. It's a reference result.
"""
tokens_per_message = 0 # ignore conversation message template prefix
tokens_per_name = 0
else:
raise NotImplementedError(
f"num_tokens_from_messages() is not implemented for model {model}. "
@ -112,7 +119,11 @@ def count_string_tokens(string: str, model_name: str) -> int:
Returns:
int: The number of tokens in the text string.
"""
encoding = tiktoken.encoding_for_model(model_name)
try:
encoding = tiktoken.encoding_for_model(model_name)
except KeyError:
print("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
return len(encoding.encode(string))

View file

@ -31,11 +31,11 @@ tenacity==8.2.2
tiktoken==0.5.2
tqdm==4.64.0
#unstructured[local-inference]
# playwright
# selenium>4
# webdriver_manager<3.9
anthropic==0.3.6
typing-inspect==0.8.0
aiofiles
typing_extensions==4.7.0
libcst==1.0.1
qdrant-client==1.4.0
@ -44,9 +44,18 @@ pytest-mock==3.11.1
ta==0.10.2
semantic-kernel==0.4.0.dev0
wrapt==1.15.0
websocket-client==0.58.0
#aiohttp_jinja2
#azure-cognitiveservices-speech~=1.31.0
#aioboto3~=11.3.0
#redis==4.3.5
websocket-client==1.6.2
aiofiles==23.2.1
gitpython==3.1.40
zhipuai==1.0.7
socksio~=1.0.0
gitignore-parser==0.1.9
# connexion[swagger-ui]
websockets~=12.0
networkx~=3.2.1
google-generativeai==0.3.1
playwright==1.40.0

View file

@ -1,8 +1,6 @@
"""wutils: handy tools
"""
"""Setup script for MetaGPT."""
import subprocess
from codecs import open
from os import path
from pathlib import Path
from setuptools import Command, find_packages, setup
@ -20,13 +18,9 @@ class InstallMermaidCLI(Command):
print(f"Error occurred: {e.output}")
here = path.abspath(path.dirname(__file__))
with open(path.join(here, "README.md"), encoding="utf-8") as f:
long_description = f.read()
with open(path.join(here, "requirements.txt"), encoding="utf-8") as f:
requirements = [line.strip() for line in f if line]
here = Path(__file__).resolve().parent
long_description = (here / "README.md").read_text(encoding="utf-8")
requirements = (here / "requirements.txt").read_text(encoding="utf-8").splitlines()
setup(
name="metagpt",
@ -49,6 +43,8 @@ setup(
"search-ddg": ["duckduckgo-search==3.8.5"],
"pyppeteer": ["pyppeteer>=1.0.2"],
"ocr": ["paddlepaddle==2.4.2", "paddleocr>=2.0.1", "tabulate==0.9.0"],
"dev": ["pylint~=3.0.3", "black~=23.3.0", "isort~=5.12.0", "pre-commit~=3.6.0"],
"test": ["pytest", "pytest-cov", "pytest-asyncio", "pytest-mock"],
},
cmdclass={
"install_mermaid": InstallMermaidCLI,

View file

@ -13,17 +13,17 @@ from unittest.mock import Mock
import pytest
from metagpt.config import CONFIG
from metagpt.config import CONFIG, Config
from metagpt.const import DEFAULT_WORKSPACE_ROOT
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.provider.openai_api import OpenAIGPTAPI as GPTAPI
from metagpt.utils.git_repository import GitRepository
class Context:
def __init__(self):
self._llm_ui = None
self._llm_api = GPTAPI()
self._llm_api = LLM(provider=CONFIG.get_default_llm_provider_enum())
@property
def llm_api(self):
@ -96,3 +96,8 @@ def setup_and_teardown_git_repo(request):
# Register the function for destroying the environment.
request.addfinalizer(fin)
@pytest.fixture(scope="session", autouse=True)
def init_config():
Config()

View file

@ -0,0 +1,143 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/24 20:32
@Author : alexanderwu
@File : mock_json.py
"""
PRD = {
"Language": "zh_cn",
"Programming Language": "Python",
"Original Requirements": "写一个简单的cli贪吃蛇",
"Project Name": "cli_snake",
"Product Goals": ["创建一个简单易用的贪吃蛇游戏", "提供良好的用户体验", "支持不同难度级别"],
"User Stories": [
"作为玩家,我希望能够选择不同的难度级别",
"作为玩家,我希望在每局游戏结束后能够看到我的得分",
"作为玩家,我希望在输掉游戏后能够重新开始",
"作为玩家,我希望看到简洁美观的界面",
"作为玩家,我希望能够在手机上玩游戏",
],
"Competitive Analysis": ["贪吃蛇游戏A界面简单缺乏响应式特性", "贪吃蛇游戏B美观且响应式的界面显示最高得分", "贪吃蛇游戏C响应式界面显示最高得分但有很多广告"],
"Competitive Quadrant Chart": 'quadrantChart\n title "Reach and engagement of campaigns"\n x-axis "Low Reach" --> "High Reach"\n y-axis "Low Engagement" --> "High Engagement"\n quadrant-1 "We should expand"\n quadrant-2 "Need to promote"\n quadrant-3 "Re-evaluate"\n quadrant-4 "May be improved"\n "Game A": [0.3, 0.6]\n "Game B": [0.45, 0.23]\n "Game C": [0.57, 0.69]\n "Game D": [0.78, 0.34]\n "Game E": [0.40, 0.34]\n "Game F": [0.35, 0.78]\n "Our Target Product": [0.5, 0.6]',
"Requirement Analysis": "",
"Requirement Pool": [["P0", "主要代码..."], ["P0", "游戏算法..."]],
"UI Design draft": "基本功能描述,简单的风格和布局。",
"Anything UNCLEAR": "",
}
DESIGN = {
"Implementation approach": "我们将使用Python编程语言并选择合适的开源框架来实现贪吃蛇游戏。我们将分析需求中的难点并选择合适的开源框架来简化开发流程。",
"File list": ["main.py", "game.py"],
"Data structures and interfaces": "\nclassDiagram\n class Game {\n -int width\n -int height\n -int score\n -int speed\n -List<Point> snake\n -Point food\n +__init__(width: int, height: int, speed: int)\n +start_game()\n +change_direction(direction: str)\n +game_over()\n +update_snake()\n +update_food()\n +check_collision()\n }\n class Point {\n -int x\n -int y\n +__init__(x: int, y: int)\n }\n Game --> Point\n",
"Program call flow": "\nsequenceDiagram\n participant M as Main\n participant G as Game\n M->>G: start_game()\n M->>G: change_direction(direction)\n G->>G: update_snake()\n G->>G: update_food()\n G->>G: check_collision()\n G-->>G: game_over()\n",
"Anything UNCLEAR": "",
}
TASKS = {
"Required Python packages": ["pygame==2.0.1"],
"Required Other language third-party packages": ["No third-party dependencies required"],
"Logic Analysis": [
["game.py", "Contains Game class and related functions for game logic"],
["main.py", "Contains the main function, imports Game class from game.py"],
],
"Task list": ["game.py", "main.py"],
"Full API spec": "",
"Shared Knowledge": "'game.py' contains functions shared across the project.",
"Anything UNCLEAR": "",
}
FILE_GAME = """## game.py
import pygame
import random
class Point:
def __init__(self, x: int, y: int):
self.x = x
self.y = y
class Game:
def __init__(self, width: int, height: int, speed: int):
self.width = width
self.height = height
self.score = 0
self.speed = speed
self.snake = [Point(width // 2, height // 2)]
self.food = self._create_food()
def start_game(self):
pygame.init()
self._display = pygame.display.set_mode((self.width, self.height))
pygame.display.set_caption('Snake Game')
self._clock = pygame.time.Clock()
self._running = True
while self._running:
self._handle_events()
self._update_snake()
self._update_food()
self._check_collision()
self._draw_screen()
self._clock.tick(self.speed)
def change_direction(self, direction: str):
# Update the direction of the snake based on user input
pass
def game_over(self):
# Display game over message and handle game over logic
pass
def _create_food(self) -> Point:
# Create and return a new food Point
return Point(random.randint(0, self.width - 1), random.randint(0, self.height - 1))
def _handle_events(self):
for event in pygame.event.get():
if event.type == pygame.QUIT:
self._running = False
def _update_snake(self):
# Update the position of the snake based on its direction
pass
def _update_food(self):
# Update the position of the food if the snake eats it
pass
def _check_collision(self):
# Check for collision between the snake and the walls or itself
pass
def _draw_screen(self):
self._display.fill((0, 0, 0)) # Clear the screen
# Draw the snake and food on the screen
pygame.display.update()
if __name__ == "__main__":
game = Game(800, 600, 15)
game.start_game()
"""
FILE_GAME_CR_1 = """## Code Review: game.py
1. Yes, the code is implemented as per the requirements. It initializes the game with the specified width, height, and speed, and starts the game loop.
2. No, the logic for handling events and updating the snake, food, and collision is not implemented. To correct this, we need to implement the logic for handling events, updating the snake and food positions, and checking for collisions.
3. Yes, the existing code follows the "Data structures and interfaces" by defining the Game and Point classes with the specified attributes and methods.
4. No, several functions such as change_direction, game_over, _update_snake, _update_food, and _check_collision are not implemented. These functions need to be implemented to complete the game logic.
5. Yes, all necessary pre-dependencies have been imported. The required pygame package is imported at the beginning of the file.
6. No, methods from other files are not being reused as there are no other files being imported or referenced in the current code.
## Actions
1. Implement the logic for handling events, updating the snake and food positions, and checking for collisions within the Game class.
2. Implement the change_direction and game_over methods to handle user input and game over logic.
3. Implement the _update_snake method to update the position of the snake based on its direction.
4. Implement the _update_food method to update the position of the food if the snake eats it.
5. Implement the _check_collision method to check for collision between the snake and the walls or itself.
## Code Review Result
LBTM"""

View file

@ -3,7 +3,7 @@
"""
@Time : 2023/5/18 23:51
@Author : alexanderwu
@File : mock.py
@File : mock_markdown.py
"""
PRD_SAMPLE = """## Original Requirements

View file

@ -5,9 +5,16 @@
@Author : alexanderwu
@File : test_action.py
"""
from metagpt.actions import Action, WritePRD, WriteTest
from metagpt.actions import Action, ActionType, WritePRD, WriteTest
def test_action_repr():
actions = [Action(), WriteTest(), WritePRD()]
assert "WriteTest" in str(actions)
def test_action_type():
assert ActionType.WRITE_PRD.value == WritePRD
assert ActionType.WRITE_TEST.value == WriteTest
assert ActionType.WRITE_PRD.name == "WRITE_PRD"
assert ActionType.WRITE_TEST.name == "WRITE_TEST"

View file

@ -5,6 +5,8 @@
@Author : alexanderwu
@File : test_action_node.py
"""
from typing import List, Tuple
import pytest
from metagpt.actions import Action
@ -29,7 +31,7 @@ async def test_debate_two_roles():
team = Team(investment=10.0, env=env, roles=[biden, trump])
history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="Biden", n_round=3)
assert "BidenSay" in history
assert "Biden" in history
@pytest.mark.asyncio
@ -39,7 +41,7 @@ async def test_debate_one_role_in_env():
env = Environment(desc="US election live broadcast")
team = Team(investment=10.0, env=env, roles=[biden])
history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="Biden", n_round=3)
assert "Debate" in history
assert "Biden" in history
@pytest.mark.asyncio
@ -86,3 +88,47 @@ async def test_action_node_two_layer():
assert node_b in root.children.values()
json_template = root.compile(context="123", schema="json", mode="auto")
assert "i-a" in json_template
t_dict = {
"Required Python third-party packages": '"""\nflask==1.1.2\npygame==2.0.1\n"""\n',
"Required Other language third-party packages": '"""\nNo third-party packages required for other languages.\n"""\n',
"Full API spec": '"""\nopenapi: 3.0.0\ninfo:\n title: Web Snake Game API\n version: 1.0.0\npaths:\n /game:\n get:\n summary: Get the current game state\n responses:\n \'200\':\n description: A JSON object of the game state\n post:\n summary: Send a command to the game\n requestBody:\n required: true\n content:\n application/json:\n schema:\n type: object\n properties:\n command:\n type: string\n responses:\n \'200\':\n description: A JSON object of the updated game state\n"""\n',
"Logic Analysis": [
["app.py", "Main entry point for the Flask application. Handles HTTP requests and responses."],
["game.py", "Contains the Game and Snake classes. Handles the game logic."],
["static/js/script.js", "Handles user interactions and updates the game UI."],
["static/css/styles.css", "Defines the styles for the game UI."],
["templates/index.html", "The main page of the web application. Displays the game UI."],
],
"Task list": ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"],
"Shared Knowledge": "\"\"\"\n'game.py' contains the Game and Snake classes which are responsible for the game logic. The Game class uses an instance of the Snake class.\n\n'app.py' is the main entry point for the Flask application. It creates an instance of the Game class and handles HTTP requests and responses.\n\n'static/js/script.js' is responsible for handling user interactions and updating the game UI based on the game state returned by 'app.py'.\n\n'static/css/styles.css' defines the styles for the game UI.\n\n'templates/index.html' is the main page of the web application. It displays the game UI and loads 'static/js/script.js' and 'static/css/styles.css'.\n\"\"\"\n",
"Anything UNCLEAR": "We need clarification on how the high score should be stored. Should it persist across sessions (stored in a database or a file) or should it reset every time the game is restarted? Also, should the game speed increase as the snake grows, or should it remain constant throughout the game?",
}
WRITE_TASKS_OUTPUT_MAPPING = {
"Required Python third-party packages": (str, ...),
"Required Other language third-party packages": (str, ...),
"Full API spec": (str, ...),
"Logic Analysis": (List[Tuple[str, str]], ...),
"Task list": (List[str], ...),
"Shared Knowledge": (str, ...),
"Anything UNCLEAR": (str, ...),
}
def test_create_model_class():
test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING)
assert test_class.__name__ == "test_class"
def test_create_model_class_with_mapping():
t = ActionNode.create_model_class("test_class_1", WRITE_TASKS_OUTPUT_MAPPING)
t1 = t(**t_dict)
value = t1.dict()["Task list"]
assert value == ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"]
if __name__ == "__main__":
test_create_model_class()
test_create_model_class_with_mapping()

View file

@ -1,53 +0,0 @@
#!/usr/bin/env python
# coding: utf-8
"""
@Time : 2023/7/11 10:49
@Author : chengmaoyu
@File : test_action_output
"""
from typing import List, Tuple
from metagpt.actions.action_node import ActionNode
t_dict = {
"Required Python third-party packages": '"""\nflask==1.1.2\npygame==2.0.1\n"""\n',
"Required Other language third-party packages": '"""\nNo third-party packages required for other languages.\n"""\n',
"Full API spec": '"""\nopenapi: 3.0.0\ninfo:\n title: Web Snake Game API\n version: 1.0.0\npaths:\n /game:\n get:\n summary: Get the current game state\n responses:\n \'200\':\n description: A JSON object of the game state\n post:\n summary: Send a command to the game\n requestBody:\n required: true\n content:\n application/json:\n schema:\n type: object\n properties:\n command:\n type: string\n responses:\n \'200\':\n description: A JSON object of the updated game state\n"""\n',
"Logic Analysis": [
["app.py", "Main entry point for the Flask application. Handles HTTP requests and responses."],
["game.py", "Contains the Game and Snake classes. Handles the game logic."],
["static/js/script.js", "Handles user interactions and updates the game UI."],
["static/css/styles.css", "Defines the styles for the game UI."],
["templates/index.html", "The main page of the web application. Displays the game UI."],
],
"Task list": ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"],
"Shared Knowledge": "\"\"\"\n'game.py' contains the Game and Snake classes which are responsible for the game logic. The Game class uses an instance of the Snake class.\n\n'app.py' is the main entry point for the Flask application. It creates an instance of the Game class and handles HTTP requests and responses.\n\n'static/js/script.js' is responsible for handling user interactions and updating the game UI based on the game state returned by 'app.py'.\n\n'static/css/styles.css' defines the styles for the game UI.\n\n'templates/index.html' is the main page of the web application. It displays the game UI and loads 'static/js/script.js' and 'static/css/styles.css'.\n\"\"\"\n",
"Anything UNCLEAR": "We need clarification on how the high score should be stored. Should it persist across sessions (stored in a database or a file) or should it reset every time the game is restarted? Also, should the game speed increase as the snake grows, or should it remain constant throughout the game?",
}
WRITE_TASKS_OUTPUT_MAPPING = {
"Required Python third-party packages": (str, ...),
"Required Other language third-party packages": (str, ...),
"Full API spec": (str, ...),
"Logic Analysis": (List[Tuple[str, str]], ...),
"Task list": (List[str], ...),
"Shared Knowledge": (str, ...),
"Anything UNCLEAR": (str, ...),
}
def test_create_model_class():
test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING)
assert test_class.__name__ == "test_class"
def test_create_model_class_with_mapping():
t = ActionNode.create_model_class("test_class_1", WRITE_TASKS_OUTPUT_MAPPING)
t1 = t(**t_dict)
value = t1.dict()["Task list"]
assert value == ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"]
if __name__ == "__main__":
test_create_model_class()
test_create_model_class_with_mapping()

View file

@ -1,6 +1,13 @@
import os
import tempfile
import pytest
from metagpt.actions.clone_function import CloneFunction, run_function_code
from metagpt.actions.clone_function import (
CloneFunction,
run_function_code,
run_function_script,
)
source_code = """
import pandas as pd
@ -55,3 +62,40 @@ async def test_clone_function():
assert not msg
expected_df = get_expected_res()
assert df.equals(expected_df)
def test_run_function_script():
# 创建一个临时文件并写入脚本内容
script_content = """def valid_function(arg1, arg2):\n return arg1 + arg2\n"""
with tempfile.NamedTemporaryFile(mode="w+", suffix=".py", delete=False) as temp_file:
temp_file.write(script_content)
temp_file_path = temp_file.name
invalid_script_content = """def valid_function(arg1, arg2)\n return arg1 + arg2\n"""
with tempfile.NamedTemporaryFile(mode="w+", suffix=".py", delete=False) as error_temp_file:
error_temp_file.write(invalid_script_content)
error_temp_file_path = error_temp_file.name
try:
# 正常情况下运行脚本
result, _ = run_function_script(temp_file_path, "valid_function", 1, arg2=2)
assert result == 3
# 不存在的脚本路径
with pytest.raises(FileNotFoundError):
run_function_script("nonexistent/path/script.py", "valid_function", 1, arg2=2)
# 无效的脚本内容
result, traceback = run_function_script(error_temp_file_path, "invalid_function", 1, arg2=2)
assert not result
assert "SyntaxError" in traceback
# 函数调用失败的情况
result, traceback = run_function_script(temp_file_path, "function_that_raises_exception", 1, arg2=2)
assert not result
assert "KeyError" in traceback
finally:
# 删除临时文件
if os.path.exists(temp_file_path):
os.remove(temp_file_path)

View file

@ -13,7 +13,7 @@ from metagpt.const import PRDS_FILE_REPO
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.file_repository import FileRepository
from tests.metagpt.actions.mock import PRD_SAMPLE
from tests.metagpt.actions.mock_markdown import PRD_SAMPLE
@pytest.mark.asyncio
@ -22,9 +22,9 @@ async def test_design_api():
for prd in inputs:
await FileRepository.save_file("new_prd.txt", content=prd, relative_path=PRDS_FILE_REPO)
design_api = WriteDesign("design_api")
design_api = WriteDesign()
result = await design_api.run([Message(content=prd, instruct_content=None)])
result = await design_api.run(Message(content=prd, instruct_content=None))
logger.info(result)
assert result

View file

@ -26,7 +26,7 @@ API列表:
"""
_ = "API设计看起来非常合理满足了PRD中的所有需求。"
design_api_review = DesignReview("design_api_review")
design_api_review = DesignReview()
result = await design_api_review.run(prd, api_design)

View file

@ -0,0 +1,17 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/25 22:38
@Author : alexanderwu
@File : test_fix_bug.py
"""
import pytest
from metagpt.actions.fix_bug import FixBug
@pytest.mark.asyncio
async def test_fix_bug():
fix_bug = FixBug()
assert fix_bug.name == "FixBug"

Some files were not shown because too many files have changed in this diff Show more