triple checked the translations

This commit is contained in:
brucemeek 2023-07-27 08:40:32 -05:00
parent 5e1dcd8757
commit 0ff252886d
9 changed files with 315 additions and 376 deletions

View file

@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Provide configuration as a singleton.
Provide configuration, singleton
"""
import os
import openai
@ -28,7 +28,7 @@ class NotConfiguredException(Exception):
class Config(metaclass=Singleton):
"""
Typical usage:
Common usage:
config = Config("config.yaml")
secret_key = config.get_key("MY_SECRET_KEY")
print("Secret key:", secret_key)
@ -40,7 +40,7 @@ class Config(metaclass=Singleton):
def __init__(self, yaml_file=default_yaml_file):
self._configs = {}
self._initialize_with_config_files_and_environment(self._configs, yaml_file)
self._init_with_config_files_and_env(self._configs, yaml_file)
logger.info("Config loading done.")
self.global_proxy = self._get("GLOBAL_PROXY")
self.openai_api_key = self._get("OPENAI_API_KEY")
@ -67,26 +67,26 @@ class Config(metaclass=Singleton):
self.google_api_key = self._get("GOOGLE_API_KEY")
self.google_cse_id = self._get("GOOGLE_CSE_ID")
self.search_engine = self._get("SEARCH_ENGINE", SearchEngineType.SERPAPI_GOOGLE)
self.web_browser_engine = WebBrowserEngineType(self._get("WEB_BROWSER_ENGINE", "playwright"))
self.playwright_browser_type = self._get("PLAYWRIGHT_BROWSER_TYPE", "chromium")
self.selenium_browser_type = self._get("SELENIUM_BROWSER_TYPE", "chrome")
self.long_term_memory = self._get('LONG_TERM_MEMORY', False)
if self.long_term_memory:
logger.warning("LONG_TERM_MEMORY is True")
self.max_budget = self._get("MAX_BUDGET", 10.0)
self.total_cost = 0.0
def _initialize_with_config_files_and_environment(self, configs: dict, yaml_file):
"""Load configurations from config/key.yaml, config/config.yaml, and the environment, in decreasing order of priority."""
def _init_with_config_files_and_env(self, configs: dict, yaml_file):
"""Load from config/key.yaml / config/config.yaml / env in decreasing priority"""
configs.update(os.environ)
for _yaml_file in [yaml_file, self.key_yaml_file]:
if not _yaml_file.exists():
continue
# Load local YAML files.
# Load local YAML file
with open(_yaml_file, "r", encoding="utf-8") as file:
yaml_data = yaml.safe_load(file)
if not yaml_data:
@ -98,7 +98,7 @@ class Config(metaclass=Singleton):
return self._configs.get(*args, **kwargs)
def get(self, key, *args, **kwargs):
"""Fetch a value from config/key.yaml, config/config.yaml, or the environment. Raises an error if not found."""
"""Find values from config/key.yaml / config/config.yaml / env, report an error if not found"""
value = self._get(key, *args, **kwargs)
if value is None:
raise ValueError(f"Key '{key}' not found in environment variables or in the YAML file")

View file

@ -9,7 +9,7 @@ from pathlib import Path
def get_project_root():
"""Search upwards level by level for the project root directory."""
"""Search upwards to find the project root directory."""
current_path = Path.cwd()
while True:
if (current_path / '.git').exists() or \

View file

@ -28,7 +28,7 @@ class FaissStore(LocalStore):
def _load(self) -> Optional["FaissStore"]:
index_file, store_file = self._get_index_and_store_fname()
if not (index_file.exists() and store_file.exists()):
logger.info("Missing at least one of index_file/store_file, load failed and return None")
logger.info("At least one of the index_file/store_file is missing. Loading failed and returns None.")
return None
index = faiss.read_index(str(index_file))
with open(str(store_file), "rb") as f:

View file

@ -19,9 +19,8 @@ type_mapping = {
np.ndarray: DataType.FLOAT_VECTOR
}
def columns_to_milvus_schema(columns: dict, primary_col_name: str = "", desc: str = ""):
"""Assuming the structure of columns is str: regular type."""
"""Assuming the structure of columns is str: standard type"""
fields = []
for col, ctype in columns.items():
if ctype == str:
@ -34,13 +33,11 @@ def columns_to_milvus_schema(columns: dict, primary_col_name: str = "", desc: st
schema = CollectionSchema(fields, description=desc)
return schema
class MilvusConnection(TypedDict):
alias: str
host: str
port: str
class MilvusStore(BaseStore):
"""
FIXME: ADD TESTS
@ -79,8 +76,8 @@ class MilvusStore(BaseStore):
"""
FIXME: ADD TESTS
https://milvus.io/docs/v2.0.x/search.md
All search and query operations within Milvus are executed in memory. Load the collection to memory before conducting a vector similarity search.
Noting the above description, is this logic serious? The time taken for this should be long, right?
All search and query operations within Milvus are executed in memory. Load the collection into memory before conducting a vector similarity search.
Noting the above description, is this logic serious? This should be time-consuming, right?
"""
search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
results = self.collection.search(
@ -91,7 +88,7 @@ class MilvusStore(BaseStore):
expr=None,
consistency_level="Strong"
)
# FIXME: results contains ids, but to get the actual values from the ids, the query interface still needs to be called.
# FIXME: results contain an id, but to get the actual value for the id, you still need to call the query interface
return results
def write(self, name, schema, *args, **kwargs):

View file

@ -20,71 +20,72 @@ summary. Pick a suitable emoji for every bullet point. Your response should be i
a YouTube video, use the following text: {{CONTENT}}.
"""
# From GCP-VertexAI-Text Summarization
# From GCP-VertexAI-Text Summarization (SUMMARIZE_PROMPT_2-5 are all from this)
# https://github.com/GoogleCloudPlatform/generative-ai/blob/main/language/examples/prompt-design/text_summarization.ipynb
# For longer documents, a map-reduce process is needed, see the following notebook
# Long documents need a map-reduce process, see the following notebook
# https://github.com/GoogleCloudPlatform/generative-ai/blob/main/language/examples/document-summarization/summarization_large_documents.ipynb
SUMMARIZE_PROMPT_2 = """
Provide a very short summary, no more than three sentences, for the following article:
Our quantum computers work by manipulating qubits in a manner we call quantum algorithms.
The challenge is that qubits are extremely sensitive, to the extent that even stray light can introduce calculation errors a problem that intensifies as quantum computers scale.
This has notable ramifications since the most effective quantum algorithms we know for executing valuable applications necessitate that our qubits' error rates be significantly lower than current levels.
To address this discrepancy, quantum error correction is essential.
Quantum error correction safeguards information by distributing it over several physical qubits, forming a logical qubit. This is believed to be the sole method to create a large-scale quantum computer with sufficiently low error rates for practical calculations.
Rather than computing on individual qubits, we will utilize logical qubits. By transforming a greater number of physical qubits on our quantum processor into a single logical qubit, we aim to reduce error rates, enabling viable quantum algorithms.
Our quantum computers work by manipulating qubits in an orchestrated fashion that we call quantum algorithms.
The challenge is that qubits are so sensitive that even stray light can cause calculation errors and the problem worsens as quantum computers grow.
This has significant consequences, since the best quantum algorithms that we know for running useful applications require the error rates of our qubits to be far lower than we have today.
To bridge this gap, we will need quantum error correction.
Quantum error correction protects information by encoding it across multiple physical qubits to form a logical qubit, and is believed to be the only way to produce a large-scale quantum computer with error rates low enough for useful calculations.
Instead of computing on the individual qubits themselves, we will then compute on logical qubits. By encoding larger numbers of physical qubits on our quantum processor into one logical qubit, we hope to reduce the error rates to enable useful quantum algorithms.
Summary:
"""
SUMMARIZE_PROMPT_3 = """
Provide a TL;DR for the following article:
Our quantum computers operate by controlling qubits in a method termed quantum algorithms.
The problem is that qubits are incredibly delicate, so much so that even minimal light interference can introduce computational errors and this issue becomes more pronounced as quantum computers expand.
This is consequential because the most potent quantum algorithms we are aware of, for practical applications, demand that our qubits' error rates be substantially below current standards.
To mitigate this, quantum error correction is pivotal.
Quantum error correction secures data by distributing it across numerous physical qubits, generating a logical qubit. It's believed to be the exclusive approach to develop a large-scale quantum computer with error rates low enough for practical operations.
Instead of operations on individual qubits, we'll focus on logical qubits. By encoding a greater number of physical qubits on our quantum device into a single logical qubit, we aspire to diminish error rates and enable efficient quantum algorithms.
Our quantum computers work by manipulating qubits in an orchestrated fashion that we call quantum algorithms.
The challenge is that qubits are so sensitive that even stray light can cause calculation errors and the problem worsens as quantum computers grow.
This has significant consequences, since the best quantum algorithms that we know for running useful applications require the error rates of our qubits to be far lower than we have today.
To bridge this gap, we will need quantum error correction.
Quantum error correction protects information by encoding it across multiple physical qubits to form a logical qubit, and is believed to be the only way to produce a large-scale quantum computer with error rates low enough for useful calculations.
Instead of computing on the individual qubits themselves, we will then compute on logical qubits. By encoding larger numbers of physical qubits on our quantum processor into one logical qubit, we hope to reduce the error rates to enable useful quantum algorithms.
TL;DR:
"""
SUMMARIZE_PROMPT_4 = """
Provide a very short summary in four bullet points for the following article:
Our quantum computers function by manipulating qubits through a method known as quantum algorithms.
The dilemma is that qubits are exceedingly fragile, so much so that even minimal light can lead to computational inaccuracies and this problem amplifies as quantum computers become larger.
This is significant because the most proficient quantum algorithms known to us, suitable for real-world applications, necessitate that our qubits' error rates be significantly below what we currently observe.
To bridge this disparity, quantum error correction becomes indispensable.
Quantum error correction secures data by spreading it across multiple physical qubits, resulting in a logical qubit. It's perceived as the only technique to manufacture a large-scale quantum computer with error rates sufficiently low for practical tasks.
Instead of operating on individual qubits directly, we'll be utilizing logical qubits. By converting more physical qubits on our quantum machine into a single logical qubit, we intend to lower error rates, facilitating effective quantum algorithms.
Our quantum computers work by manipulating qubits in an orchestrated fashion that we call quantum algorithms.
The challenge is that qubits are so sensitive that even stray light can cause calculation errors and the problem worsens as quantum computers grow.
This has significant consequences, since the best quantum algorithms that we know for running useful applications require the error rates of our qubits to be far lower than we have today.
To bridge this gap, we will need quantum error correction.
Quantum error correction protects information by encoding it across multiple physical qubits to form a logical qubit, and is believed to be the only way to produce a large-scale quantum computer with error rates low enough for useful calculations.
Instead of computing on the individual qubits themselves, we will then compute on logical qubits. By encoding larger numbers of physical qubits on our quantum processor into one logical qubit, we hope to reduce the error rates to enable useful quantum algorithms.
Bulletpoints:
"""
SUMMARIZE_PROMPT_5 = """
Please generate a summary of the following conversation and at the end summarize the to-do's for the support Agent:
Customer: Hi, I'm Larry, and I received the wrong item.
Support Agent: Hi, Larry. How would you like this issue to be resolved?
Support Agent: Hi, Larry. How would you like to see this resolved?
Customer: That's alright. I'd like to return the item and receive a refund, please.
Customer: That's alright. I want to return the item and get a refund, please.
Support Agent: Certainly. I can process the refund for you right now. Could I have your order number, please?
Support Agent: Of course. I can process the refund for you now. Can I have your order number, please?
Customer: It's [ORDER NUMBER].
Support Agent: Thanks. I've processed the refund, and you should receive your funds within 14 days.
Support Agent: Thank you. I've processed the refund, and you will receive your money back within 14 days.
Customer: I appreciate it.
Customer: Thank you very much.
Support Agent: You're welcome, Larry. Have a great day!
Support Agent: You're welcome, Larry. Have a good day!
Summary:
"""
# - def summarize(doc: str) -> str # Input a document and receive a summary.
"""

View file

@ -1,243 +1,187 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/5/5 23:08
@Time : 2023/5/11 14:43
@Author : alexanderwu
@File : openai.py
@File : engineer.py
"""
import asyncio
import time
from functools import wraps
from typing import NamedTuple
import shutil
from collections import OrderedDict
from pathlib import Path
import openai
from metagpt.config import CONFIG
from metagpt.const import WORKSPACE_ROOT
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.utils.singleton import Singleton
from metagpt.utils.token_counter import (
TOKEN_COSTS,
count_message_tokens,
count_string_tokens,
)
from metagpt.roles import Role
from metagpt.actions import WriteCode, WriteCodeReview, WriteTasks, WriteDesign
from metagpt.schema import Message
from metagpt.utils.common import CodeParser
def retry(max_retries):
def decorator(f):
@wraps(f)
async def wrapper(*args, **kwargs):
for i in range(max_retries):
async def gather_ordered_k(coros, k) -> list:
tasks = OrderedDict()
results = [None] * len(coros)
done_queue = asyncio.Queue()
for i, coro in enumerate(coros):
if len(tasks) >= k:
done, _ = await asyncio.wait(tasks.keys(), return_when=asyncio.FIRST_COMPLETED)
for task in done:
index = tasks.pop(task)
await done_queue.put((index, task.result()))
task = asyncio.create_task(coro)
tasks[task] = i
if tasks:
done, _ = await asyncio.wait(tasks.keys())
for task in done:
index = tasks[task]
await done_queue.put((index, task.result()))
while not done_queue.empty():
index, result = await done_queue.get()
results[index] = result
return results
class Engineer(Role):
def __init__(self, name="Alex", profile="Engineer", goal="Write elegant, readable, extensible, efficient code",
constraints="The code you write should conform to code standard like PEP8, be modular, easy to read and maintain",
n_borg=1, use_code_review=False):
super().__init__(name, profile, goal, constraints)
self._init_actions([WriteCode])
self.use_code_review = use_code_review
if self.use_code_review:
self._init_actions([WriteCode, WriteCodeReview])
self._watch([WriteTasks])
self.todos = []
self.n_borg = n_borg
@classmethod
def parse_tasks(cls, task_msg: Message) -> list[str]:
if not task_msg.instruct_content:
return task_msg.instruct_content.dict().get("Task list")
return CodeParser.parse_file_list(block="Task list", text=task_msg.content)
@classmethod
def parse_code(cls, code_text: str) -> str:
return CodeParser.parse_code(block="", text=code_text)
@classmethod
def parse_workspace(cls, system_design_msg: Message) -> str:
if not system_design_msg.instruct_content:
return system_design_msg.instruct_content.dict().get("Python package name")
return CodeParser.parse_str(block="Python package name", text=system_design_msg.content)
def get_workspace(self) -> Path:
msg = self._rc.memory.get_by_action(WriteDesign)[-1]
if not msg:
return WORKSPACE_ROOT / 'src'
workspace = self.parse_workspace(msg)
# Codes are written in workspace/{package_name}/{package_name}
return WORKSPACE_ROOT / workspace / workspace
def recreate_workspace(self):
workspace = self.get_workspace()
try:
shutil.rmtree(workspace)
except FileNotFoundError:
pass # Directory does not exist, but we don't care
workspace.mkdir(parents=True, exist_ok=True)
def write_file(self, filename: str, code: str):
workspace = self.get_workspace()
file = workspace / filename
file.parent.mkdir(parents=True, exist_ok=True)
file.write_text(code)
def recv(self, message: Message) -> None:
self._rc.memory.add(message)
if message in self._rc.important_memory:
self.todos = self.parse_tasks(message)
async def _act_mp(self) -> Message:
# self.recreate_workspace()
todo_coros = []
for todo in self.todos:
todo_coro = WriteCode().run(
context=self._rc.memory.get_by_actions([WriteTasks, WriteDesign]),
filename=todo
)
todo_coros.append(todo_coro)
rsps = await gather_ordered_k(todo_coros, self.n_borg)
for todo, code_rsp in zip(self.todos, rsps):
_ = self.parse_code(code_rsp)
logger.info(todo)
logger.info(code_rsp)
# self.write_file(todo, code)
msg = Message(content=code_rsp, role=self.profile, cause_by=type(self._rc.todo))
self._rc.memory.add(msg)
del self.todos[0]
logger.info(f'Done {self.get_workspace()} generating.')
msg = Message(content="all done.", role=self.profile, cause_by=type(self._rc.todo))
return msg
async def _act_sp(self) -> Message:
for todo in self.todos:
code_rsp = await WriteCode().run(
context=self._rc.history,
filename=todo
)
# logger.info(todo)
# logger.info(code_rsp)
# code = self.parse_code(code_rsp)
self.write_file(todo, code_rsp)
msg = Message(content=code_rsp, role=self.profile, cause_by=type(self._rc.todo))
self._rc.memory.add(msg)
logger.info(f'Done {self.get_workspace()} generating.')
msg = Message(content="all done.", role=self.profile, cause_by=type(self._rc.todo))
return msg
async def _act_sp_precision(self) -> Message:
for todo in self.todos:
"""
# Select essential information from historical information to reduce prompt length (summarized from human experience)
1. All from Architect
2. All from ProjectManager
3. Do you need other codes (currently needed)?
TODO: The goal is not to need it. Once tasks are split clearly, according to the design idea, the code can be written clearly for each file without other codes. If it can't, it means that it still needs to be defined more clearly, this is the key to write long code.
"""
context = []
msg = self._rc.memory.get_by_actions([WriteDesign, WriteTasks, WriteCode])
for m in msg:
context.append(m.content)
context_str = "\n".join(context)
# Writing code
code = await WriteCode().run(
context=context_str,
filename=todo
)
# Code review
if self.use_code_review:
try:
return await f(*args, **kwargs)
except Exception:
if i == max_retries - 1:
raise
await asyncio.sleep(2 ** i)
return wrapper
return decorator
rewrite_code = await WriteCodeReview().run(
context=context_str,
code=code,
filename=todo
)
code = rewrite_code
except Exception as e:
logger.error("code review failed!", e)
pass
self.write_file(todo, code)
msg = Message(content=code, role=self.profile, cause_by=WriteCode)
self._rc.memory.add(msg)
logger.info(f'Done {self.get_workspace()} generating.')
msg = Message(content="all done.", role=self.profile, cause_by=WriteCode)
return msg
class RateLimiter:
"""Rate control class, each call goes through wait_if_needed, sleep if rate control is needed."""
def __init__(self, rpm):
self.last_call_time = 0
self.interval = 1.1 * 60 / rpm # Using 1.1 since strict adherence to time can still lead to QoS issues; consider simple error retry later.
self.rpm = rpm
def split_batches(self, batch):
return [batch[i:i + self.rpm] for i in range(0, len(batch), self.rpm)]
async def wait_if_needed(self, num_requests):
current_time = time.time()
elapsed_time = current_time - self.last_call_time
if elapsed_time < self.interval * num_requests:
remaining_time = self.interval * num_requests - elapsed_time
logger.info(f"sleep {remaining_time}")
await asyncio.sleep(remaining_time)
self.last_call_time = time.time()
class Costs(NamedTuple):
total_prompt_tokens: int
total_completion_tokens: int
total_cost: float
total_budget: float
class CostManager(metaclass=Singleton):
"""Calculate the cost of using the API."""
def __init__(self):
self.total_prompt_tokens = 0
self.total_completion_tokens = 0
self.total_cost = 0
self.total_budget = 0
def update_cost(self, prompt_tokens, completion_tokens, model):
"""
Update the total cost, prompt tokens, and completion tokens.
Args:
prompt_tokens (int): The number of tokens used in the prompt.
completion_tokens (int): The number of tokens used in the completion.
model (str): The model used for the API call.
"""
self.total_prompt_tokens += prompt_tokens
self.total_completion_tokens += completion_tokens
cost = (
prompt_tokens * TOKEN_COSTS[model]["prompt"]
+ completion_tokens * TOKEN_COSTS[model]["completion"]
) / 1000
self.total_cost += cost
logger.info(f"Total running cost: ${self.total_cost:.3f} | Max budget: ${CONFIG.max_budget:.3f} | "
f"Current cost: ${cost:.3f}, {prompt_tokens=}, {completion_tokens=}")
CONFIG.total_cost = self.total_cost
def get_total_prompt_tokens(self):
"""Get the total number of prompt tokens."""
return self.total_prompt_tokens
def get_total_completion_tokens(self):
"""Get the total number of completion tokens."""
return self.total_completion_tokens
def get_total_cost(self):
"""Get the total cost of API calls."""
return self.total_cost
def get_costs(self) -> Costs:
"""Get all costs."""
return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget)
class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
"""
Check https://platform.openai.com/examples for examples.
"""
def __init__(self):
self.__init_openai(CONFIG)
self.llm = openai
self.model = CONFIG.openai_api_model
self._cost_manager = CostManager()
RateLimiter.__init__(self, rpm=self.rpm)
def __init_openai(self, config):
openai.api_key = config.openai_api_key
if config.openai_api_base:
openai.api_base = config.openai_api_base
if config.openai_api_type:
openai.api_type = config.openai_api_type
openai.api_version = config.openai_api_version
self.rpm = int(config.get("RPM", 10))
async def _achat_completion_stream(self, messages: list[dict]) -> str:
response = await openai.ChatCompletion.acreate(
**self._cons_kwargs(messages),
stream=True
)
# create variables to collect the stream of chunks
collected_chunks = []
collected_messages = []
# iterate through the stream of events
async for chunk in response:
collected_chunks.append(chunk) # save the event response
chunk_message = chunk['choices'][0]['delta'] # extract the message
collected_messages.append(chunk_message) # save the message
if "content" in chunk_message:
print(chunk_message["content"], end="")
print()
full_reply_content = ''.join([m.get('content', '') for m in collected_messages])
usage = self._calc_usage(messages, full_reply_content)
self._update_costs(usage)
return full_reply_content
def _cons_kwargs(self, messages: list[dict]) -> dict:
if CONFIG.openai_api_type == 'azure':
kwargs = {
"deployment_id": CONFIG.deployment_id,
"messages": messages,
"max_tokens": CONFIG.max_tokens_rsp,
"n": 1,
"stop": None,
"temperature": 0.3
}
else:
kwargs = {
"model": self.model,
"messages": messages,
"max_tokens": CONFIG.max_tokens_rsp,
"n": 1,
"stop": None,
"temperature": 0.3
}
return kwargs
async def _achat_completion(self, messages: list[dict]) -> dict:
rsp = await self.llm.ChatCompletion.acreate(**self._cons_kwargs(messages))
self._update_costs(rsp.get('usage'))
return rsp
def _chat_completion(self, messages: list[dict]) -> dict:
rsp = self.llm.ChatCompletion.create(**self._cons_kwargs(messages))
self._update_costs(rsp)
return rsp
def completion(self, messages: list[dict]) -> dict:
return self._chat_completion(messages)
async def acompletion(self, messages: list[dict]) -> dict:
return await self._achat_completion(messages)
@retry(max_retries=6)
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
"""When streaming, print each token in place."""
if stream:
return await self._achat_completion_stream(messages)
rsp = await self._achat_completion(messages)
return self.get_choice_text(rsp)
def _calc_usage(self, messages: list[dict], rsp: str) -> dict:
usage = {}
prompt_tokens = count_message_tokens(messages, self.model)
completion_tokens = count_string_tokens(rsp, self.model)
usage['prompt_tokens'] = prompt_tokens
usage['completion_tokens'] = completion_tokens
return usage
async def acompletion_batch(self, batch: list[list[dict]]) -> list[dict]:
"""Returns the full JSON."""
split_batches = self.split_batches(batch)
all_results = []
for small_batch in split_batches:
logger.info(small_batch)
await self.wait_if_needed(len(small_batch))
future = [self.acompletion(prompt) for prompt in small_batch]
results = await asyncio.gather(*future)
logger.info(results)
all_results.extend(results)
return all_results
async def acompletion_batch_text(self, batch: list[list[dict]]) -> list[str]:
"""Returns only plain text."""
raw_results = await self.acompletion_batch(batch)
results = []
for idx, raw_result in enumerate(raw_results, start=1):
result = self.get_choice_text(raw_result)
results.append(result)
logger.info(f"Result of task {idx}: {result}")
return results
def _update_costs(self, usage: dict):
prompt_tokens = int(usage['prompt_tokens'])
completion_tokens = int(usage['completion_tokens'])
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
def get_costs(self) -> Costs:
return self._cost_manager.get_costs()
async def _act(self) -> Message:
if self.use_code_review:
return await self._act_sp_precision()
return await self._act_sp()

View file

@ -1,7 +1,8 @@
# -*- coding: utf-8 -*-
# @Date : 2023/7/19 16:28
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
# @Description :
import os
import asyncio
from os.path import join
@ -55,6 +56,7 @@ payload = {
default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution"
class SDEngine:
def __init__(self):
# Initialize the SDEngine with configuration
@ -65,7 +67,7 @@ class SDEngine:
self.payload = payload
logger.info(self.sd_t2i_url)
def construct_payload(self, prompt, negtive_prompt=default_negative_prompt, width=512, height=512,
def construct_payload(self, prompt, negative_prompt=default_negative_prompt, width=512, height=512,
sd_model="galaxytimemachinesGTM_photoV20"):
# Configure the payload with provided inputs
self.payload["prompt"] = prompt
@ -101,11 +103,11 @@ class SDEngine:
return imgs
async def run_i2i(self):
# TODO: Add image-to-image API call
# TODO: Add a method to call the image-to-image interface
raise NotImplementedError
async def run_sam(self):
# TODO: Add SAM API call
# TODO: Add a method to call the SAM interface
raise NotImplementedError
def decode_base64_to_image(img, save_name):

View file

@ -6,18 +6,18 @@ from pathlib import Path
from metagpt.provider.openai_api import OpenAIGPTAPI as GPTAPI
ICL_SAMPLE = '''API Definition:
ICL_SAMPLE = '''Interface definition:
```text
API Name: Tag Elements
API Path: /projects/{project_key}/node-tags
Method: POST
Interface Name: Tag Elements
Interface Path: /projects/{project_key}/node-tags
Method: POST
Request Parameters:
Path Parameters:
project_key
Body Parameters:
Name Type Required Default Value Description
Name Type Required Default Value Description
nodes array Yes Nodes
node_key string No Node key
tags array No Original node tag list
@ -26,11 +26,11 @@ operations array Yes
tags array No Operation tag list
mode string No Operation type ADD / DELETE
Response Data:
Name Type Required Default Value Description
Return Data:
Name Type Required Default Value Description
code integer Yes Status code
msg string Yes Message
data object Yes Response data
data object Yes Return data
list array No Node list true / false
node_type string No Node type DATASET / RECIPE
node_key string No Node key
@ -43,7 +43,7 @@ Unit Test:
[
("project_key", [{"node_key": "dataset_001", "tags": ["tag1", "tag2"], "node_type": "DATASET"}], [{"tags": ["new_tag1"], "mode": "ADD"}], "success"),
("project_key", [{"node_key": "dataset_002", "tags": ["tag1", "tag2"], "node_type": "DATASET"}], [{"tags": ["tag1"], "mode": "DELETE"}], "success"),
("", [{"node_key": "dataset_001", "tags": ["tag1", "tag2"], "node_type": "DATASET"}], [{"tags": ["new_tag1"], "mode": "ADD"}], "Missing required parameter project_key"),
("", [{"node_key": "dataset_001", "tags": ["tag1", "tag2"], "node_type": "DATASET"}], [{"tags": ["new_tag1"], "mode": "ADD"}], "Missing necessary parameter project_key"),
(123, [{"node_key": "dataset_001", "tags": ["tag1", "tag2"], "node_type": "DATASET"}], [{"tags": ["new_tag1"], "mode": "ADD"}], "Incorrect parameter type"),
("project_key", [{"node_key": "a"*201, "tags": ["tag1", "tag2"], "node_type": "DATASET"}], [{"tags": ["new_tag1"], "mode": "ADD"}], "Request parameter exceeds field boundary")
]
@ -51,64 +51,67 @@ Unit Test:
def test_node_tags(project_key, nodes, operations, expected_msg):
pass
```
Above is an example of an API definition and a unit test sample.
Next, please play the role of a test manager from Google with 20 years of experience. After I provide the API definition, reply with the unit test. There are a few requirements:
1. Only output one '@pytest.mark.parametrize' and its corresponding 'test_<API_name>' function (with only a 'pass' statement inside, no implementation).
-- The function parameters should include 'expected_msg' for result validation.
2. The generated test cases should use shorter text or numbers and be as compact as possible.
The above is an example of interface definition and unit test.
Next, please act as an expert test manager with 20 years of experience at Google.
After I provide the interface definition, please reply with the unit test.
There are a few requirements:
1. Only output one `@pytest.mark.parametrize` and the corresponding test_<interface name> function
(with a pass inside, not implemented).
-- The function parameters should include expected_msg for result validation.
2. The generated test cases should use shorter text or numbers and be as concise as possible.
3. If comments are needed, use Chinese.
If you understand, please wait for me to provide the API definition and only reply with "Understood" to save tokens.
If you understand, please wait for me to provide the interface definition
and only reply with "Understood" to save tokens.
'''
ACT_PROMPT_PREFIX = '''Reference test types: such as missing request parameters, field boundary validation, incorrect field type.
ACT_PROMPT_PREFIX = '''Reference test types: such as missing request parameters, field boundary checks, incorrect field types.
Please output 10 test cases within a `@pytest.mark.parametrize` scope.
```text
'''
YFT_PROMPT_PREFIX = '''Reference test types: such as SQL injection, cross-site scripting (XSS), illegal access and unauthorized access, authentication and authorization, parameter validation, exception handling, file upload and download.
YFT_PROMPT_PREFIX = '''Reference test types: such as SQL injection, cross-site scripting (XSS), illegal access and unauthorized access, authentication and authorization, parameter verification, exception handling, file upload and download.
Please output 10 test cases within a `@pytest.mark.parametrize` scope.
```text
'''
OCR_API_DOC = '''```text
API Name: OCR Recognition
API Path: /api/v1/contract/treaty/task/ocr
Method: POST
API Name: OCR Recognition
API Path: /api/v1/contract/treaty/task/ocr
Method: POST
Request Parameters:
Path Parameters:
Body Parameters:
Name Type Required Default Value Remarks
file_id string Yes
box array Yes
contract_id number Yes Contract ID
start_time string No yyyy-mm-dd
end_time string No yyyy-mm-dd
extract_type number No Recognition type 1-During import 2-After import, default is 1
Name Type Mandatory Default Value Remarks
file_id string Yes
box array Yes
contract_id number Yes Contract ID
start_time string No yyyy-mm-dd
end_time string No yyyy-mm-dd
extract_type number No Recognition Type 1- During Import 2- After Import, Default is 1
Response Data:
Name Type Required Default Value Remarks
code integer Yes
message string Yes
data object Yes
Return Data:
Name Type Mandatory Default Value Remarks
code integer Yes
message string Yes
data object Yes
'''
class UTGenerator:
"""UT Generator: Construct UT through API documentation."""
"""UT Generator: Constructs UT from API documentation."""
def __init__(self, swagger_file: str, ut_py_path: str, questions_path: str,
chatgpt_method: str = "API", template_prefix=YFT_PROMPT_PREFIX) -> None:
"""Initialize the UT Generator.
"""Initializes the UT generator.
Args:
swagger_file: Path to the swagger file.
swagger_file: Path to the swagger.
ut_py_path: Path to store test cases.
questions_path: Path to store templates for future investigation.
chatgpt_method: API
template_prefix: Use template, default is YFT_UT_PROMPT.
questions_path: Path to store templates for further investigation.
chatgpt_method: API.
template_prefix: Use template, defaults to YFT_UT_PROMPT.
"""
self.swagger_file = swagger_file
self.ut_py_path = ut_py_path
@ -116,56 +119,56 @@ class UTGenerator:
assert chatgpt_method in ["API"], "Invalid chatgpt_method"
self.chatgpt_method = chatgpt_method
# ICL: In-Context Learning. Provide an example here for GPT to mimic.
# ICL: In-Context Learning. Here, an example is provided for GPT to follow.
self.icl_sample = ICL_SAMPLE
self.template_prefix = template_prefix
def get_swagger_json(self) -> dict:
"""Load Swagger JSON from a local file."""
"""Loads Swagger JSON from a local file."""
with open(self.swagger_file, "r", encoding="utf-8") as file:
swagger_json = json.load(file)
return swagger_json
def __parameter_to_string(self, prop, required, name=""):
def __para_to_str(self, prop, required, name=""):
name = name or prop["name"]
ptype = prop["type"]
title = prop.get("title", "")
desc = prop.get("description", "")
return f'{name}\t{ptype}\t{"Yes" if required else "No"}\t{title}\t{desc}'
def _parameter_to_string(self, prop):
def _para_to_str(self, prop):
required = prop.get("required", False)
return self.__parameter_to_string(prop, required)
return self.__para_to_str(prop, required)
def parameter_to_string(self, name, prop, prop_object_required):
def para_to_str(self, name, prop, prop_object_required):
required = name in prop_object_required
return self.__parameter_to_string(prop, required, name)
return self.__para_to_str(prop, required, name)
def build_object_properties(self, node, prop_object_required, level: int = 0) -> str:
"""Recursively output properties of object and array[object] types.
"""Recursively outputs properties of object and array[object] types.
Args:
node (_type_): Value of the sub-item.
prop_object_required (_type_): Indicates if it's a required item.
node: Value of the sub-item.
prop_object_required: Whether it's a required item.
level: Current recursion depth.
"""
doc = ""
def dive_into_object(node):
"""If it's an object type, recursively output its properties."""
"""If it's an object type, recursively outputs its properties."""
if node.get("type") == "object":
sub_properties = node.get("properties", {})
return self.build_object_properties(sub_properties, prop_object_required, level=level + 1)
return ""
if node.get("in", "") in ["query", "header", "formData"]:
doc += f'{"\t" * level}{self._parameter_to_string(node)}\n'
doc += f'{"\t" * level}{self._para_to_str(node)}\n'
doc += dive_into_object(node)
return doc
for name, prop in node.items():
doc += f'{"\t" * level}{self.parameter_to_string(name, prop, prop_object_required)}\n'
doc += f'{"\t" * level}{self.para_to_str(name, prop, prop_object_required)}\n'
doc += dive_into_object(prop)
if prop["type"] == "array":
items = prop.get("items", {})
@ -173,10 +176,10 @@ class UTGenerator:
return doc
def get_tags_mapping(self) -> dict:
"""Process tag and path.
"""Handles tag and path mapping.
Returns:
Dictionary: Correspondence of tag to path.
Dict: Mapping of tag to path.
"""
swagger_data = self.get_swagger_json()
paths = swagger_data["paths"]
@ -194,7 +197,7 @@ class UTGenerator:
return tags
def generate_ut(self, include_tags) -> bool:
"""Generate test case files."""
"""Generates test case files."""
tags = self.get_tags_mapping()
for tag, paths in tags.items():
if include_tags is None or tag in include_tags:
@ -204,19 +207,19 @@ class UTGenerator:
def build_api_doc(self, node: dict, path: str, method: str) -> str:
summary = node["summary"]
doc = f"API Name: {summary}\nAPI Path: {path}\nMethod: {method.upper()}\n"
doc += "\nRequest Parameters:\n"
doc = f"Interface name: {summary}\nInterface path: {path}\nMethod: {method.upper()}\n"
doc += "\nRequest parameters:\n"
if "parameters" in node:
parameters = node["parameters"]
doc += "Path Parameters:\n"
doc += "Path parameters:\n"
# param["in"]: path / formData / body / query / header
for param in parameters:
if param["in"] == "path":
doc += f'{param["name"]}\n'
doc += f'{param["name"]} \n'
doc += "\nBody Parameters:\n"
doc += "Name\tType\tRequired\tDefault Value\tRemarks\n"
doc += "\nBody parameters:\n"
doc += "Name\tType\tRequired\tDefault\tNotes\n"
for param in parameters:
if param["in"] == "body":
schema = param.get("schema", {})
@ -227,8 +230,8 @@ class UTGenerator:
doc += self.build_object_properties(param, [])
# Output return data information
doc += "\nReturn Data:\n"
doc += "Name\tType\tRequired\tDefault Value\tRemarks\n"
doc += "\nReturn data:\n"
doc += "Name\tType\tRequired\tDefault\tNotes\n"
responses = node["responses"]
response = responses.get("200", {})
schema = response.get("schema", {})
@ -242,12 +245,13 @@ class UTGenerator:
return doc
def _store(self, data, base, folder, fname):
"""Store data in a file."""
file_path = self.get_file_path(Path(base) / folder, fname)
with open(file_path, "w", encoding="utf-8") as file:
file.write(data)
def ask_gpt_and_save(self, question: str, tag: str, fname: str):
"""Generate a question and store both question and answer."""
"""Generate a question and save both the question and answer."""
messages = [self.icl_sample, question]
result = self.gpt_msgs_to_code(messages=messages)
@ -255,11 +259,11 @@ class UTGenerator:
self._store(result, self.ut_py_path, tag, f"{fname}.py")
def _generate_ut(self, tag, paths):
"""Handle structure under the data path.
"""Process the structure under the data path.
Args:
tag (_type_): Module name.
paths (_type_): Path Object.
tag: Module name.
paths: Path Object.
"""
for path, path_obj in paths.items():
for method, node in path_obj.items():
@ -269,7 +273,7 @@ class UTGenerator:
self.ask_gpt_and_save(question, tag, summary)
def gpt_msgs_to_code(self, messages: list) -> str:
"""Choose based on different invocation methods."""
"""Choose based on different call methods."""
result = ''
if self.chatgpt_method == "API":
result = GPTAPI().ask_code(msgs=messages)

View file

@ -14,10 +14,11 @@ from typing import List, Tuple
from metagpt.logs import logger
def check_command_exists(command) -> int:
""" Check if a command exists.
def check_cmd_exists(command) -> int:
"""Check if a command exists.
:param command: Command to check.
:return: Returns 0 if the command exists, else returns non-zero.
:return: Returns 0 if the command exists, otherwise returns a non-zero value.
"""
check_command = 'command -v ' + command + ' >/dev/null 2>&1 || { echo >&2 "no mermaid"; exit 1; }'
result = os.system(check_command)
@ -28,19 +29,19 @@ class OutputParser:
@classmethod
def parse_blocks(cls, text: str):
# Firstly, split the text into different blocks based on "##".
# First, split the text into different blocks based on "##".
blocks = text.split("##")
# Create a dictionary to store the title and content of each block.
block_dict = {}
# Loop through all blocks.
# Iterate through all blocks.
for block in blocks:
# If block is not empty, continue processing.
# If the block is not empty, continue processing.
if block.strip() != "":
# Split block's title and content and trim them.
# Split the block's title and content and trim whitespace.
block_title, block_content = block.split("\n", 1)
# There may be errors in LLM, correct it here.
# LLM might make mistakes; correct it here.
if block_title[-1] == ":":
block_title = block_title[:-1]
block_dict[block_title.strip()] = block_content.strip()
@ -84,7 +85,7 @@ class OutputParser:
block_dict = cls.parse_blocks(data)
parsed_data = {}
for block, content in block_dict.items():
# Try to remove code markers.
# Try to remove the code marker.
try:
content = cls.parse_code(text=content)
except Exception:
@ -103,7 +104,7 @@ class OutputParser:
block_dict = cls.parse_blocks(data)
parsed_data = {}
for block, content in block_dict.items():
# Try to remove code markers.
# Try to remove the code marker.
try:
content = cls.parse_code(text=content)
except Exception:
@ -135,17 +136,17 @@ class CodeParser:
@classmethod
def parse_blocks(cls, text: str):
# Firstly, split the text into different blocks based on "##".
# First, split the text into different blocks based on "##".
blocks = text.split("##")
# Create a dictionary to store the title and content of each block.
block_dict = {}
# Loop through all blocks.
# Iterate through all blocks.
for block in blocks:
# If block is not empty, continue processing.
# If the block is not empty, continue processing.
if block.strip() != "":
# Split block's title and content and trim them.
# Split the block's title and content and trim whitespace.
block_title, block_content = block.split("\n", 1)
block_dict[block_title.strip()] = block_content.strip()
@ -160,7 +161,7 @@ class CodeParser:
if match:
code = match.group(1)
else:
logger.error(f"{pattern} did not match the following text:")
logger.error(f"{pattern} not match following text:")
logger.error(text)
raise Exception
return code
@ -213,14 +214,4 @@ def print_members(module, indent=0):
prefix = ' ' * indent
for name, obj in inspect.getmembers(module):
print(name, obj)
if inspect.isclass(obj):
print(f'{prefix}Class: {name}')
# print the methods within the class
if name in ['__class__', '__base__']:
continue
print_members(obj, indent + 2)
elif inspect.isfunction(obj):
print(f'{prefix}Function: {name}')
elif inspect.ismethod(obj):
print(f'{prefix}Method: {name}')
if inspect