mirror of https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-15 11:02:36 +02:00

commit 23fdf90d21
Merge branch 'code_intepreter' into code_intepreter_add_vision

# Conflicts:
#	metagpt/tools/__init__.py

19 changed files with 568 additions and 245 deletions
.gitignore (vendored, 1 deletion)
@@ -177,4 +177,3 @@ htmlcov.*
 *.pkl
 *-structure.csv
 *-structure.json
@@ -9,7 +9,7 @@ from metagpt.roles.code_interpreter import CodeInterpreter


 async def main():
-    prompt = """This is a URL of webpage: https://cn.bing.com/
+    prompt = """This is a URL of webpage: 'https://www.baidu.com/' .
 Firstly, utilize Selenium and WebDriver for rendering.
 Secondly, convert image to a webpage including HTML, CSS and JS in one go.
 Finally, save webpage in a text file.
examples/sd_tool_usage.py (new file, 21 lines)
@@ -0,0 +1,21 @@
+# -*- coding: utf-8 -*-
+# @Date    : 1/11/2024 7:06 PM
+# @Author  : stellahong (stellahong@fuzhi.ai)
+# @Desc    :
+import asyncio
+
+from metagpt.roles.code_interpreter import CodeInterpreter
+
+
+async def main(requirement: str = ""):
+    code_interpreter = CodeInterpreter(use_tools=True, goal=requirement)
+    await code_interpreter.run(requirement)
+
+
+if __name__ == "__main__":
+    sd_url = "http://your.sd.service.ip:port"
+    requirement = (
+        f"I want to generate an image of a beautiful girl using the stable diffusion text2image tool, sd_url={sd_url}"
+    )
+
+    asyncio.run(main(requirement))
metagpt/actions/debug_code.py
@@ -85,20 +85,14 @@ class DebugCode(BaseWriteAnalysisCode):
     async def run_reflection(
         self,
-        # goal,
-        # finished_code,
-        # finished_code_result,
         context: List[Message],
         code,
         runtime_result,
     ) -> dict:
         info = []
-        # finished_code_and_result = finished_code + "\n [finished results]\n\n" + finished_code_result
         reflection_prompt = REFLECTION_PROMPT.format(
             debug_example=DEBUG_REFLECTION_EXAMPLE,
             context=context,
-            # goal=goal,
-            # finished_code=finished_code_and_result,
             code=code,
             runtime_result=runtime_result,
         )
@@ -106,33 +100,13 @@ class DebugCode(BaseWriteAnalysisCode):
         info.append(Message(role="system", content=system_prompt))
         info.append(Message(role="user", content=reflection_prompt))

-        # msg = messages_to_str(info)
-        # resp = await self.llm.aask(msg=msg)
         resp = await self.llm.aask_code(messages=info, **create_func_config(CODE_REFLECTION))
         logger.info(f"reflection is {resp}")
         return resp

-    # async def rewrite_code(self, reflection: str = "", context: List[Message] = None) -> str:
-    #     """
-    #     Rewrite the code based on the reflection.
-    #     """
-    #     info = context
-    #     # info.append(Message(role="assistant", content=f"[code context]:{code_context}"
-    #     #                     f"finished code are executable, and you should based on the code to continue your current code debug and improvement"
-    #     #                     f"[reflection]: \n {reflection}"))
-    #     info.append(Message(role="assistant", content=f"[reflection]: \n {reflection}"))
-    #     info.append(Message(role="user", content=f"[improved impl]:\n Return in Python block"))
-    #     msg = messages_to_str(info)
-    #     resp = await self.llm.aask(msg=msg)
-    #     improv_code = CodeParser.parse_code(block=None, text=resp)
-    #     return improv_code
-
     async def run(
         self,
         context: List[Message] = None,
         plan: str = "",
-        # finished_code: str = "",
-        # finished_code_result: str = "",
         code: str = "",
         runtime_result: str = "",
     ) -> str:
@@ -140,14 +114,10 @@ class DebugCode(BaseWriteAnalysisCode):
         Reflect on the currently running code and its error message, then correct the code.
         """
         reflection = await self.run_reflection(
-            # plan,
-            # finished_code=finished_code,
-            # finished_code_result=finished_code_result,
             code=code,
             context=context,
             runtime_result=runtime_result,
         )
-        # Rewrite the code based on the reflection result
-        # improv_code = await self.rewrite_code(reflection, context=context)
         improv_code = reflection["improved_impl"]
         return improv_code
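For reference, a minimal sketch of how the slimmed-down DebugCode action is now driven (mirroring tests/metagpt/actions/test_debug_code.py further down; a configured LLM backend is assumed):

    import asyncio

    from metagpt.actions.debug_code import DebugCode
    from metagpt.schema import Message

    async def demo():
        # context describes the task, code is the failing attempt,
        # runtime_result is the executor's error output
        context = Message(content="Solve the problem in Python: ...")
        improved = await DebugCode().run(
            context=context,
            code="def sort_array(arr): ...",
            runtime_result="Tests failed: ...",
        )
        # run() returns the "improved_impl" field of the LLM's reflection
        print(improved)

    asyncio.run(demo())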
metagpt/actions/execute_code.py
@@ -4,6 +4,7 @@
 @Author  : orange-crow
 @File    : code_executor.py
 """
+import asyncio
 import re
 import traceback
 from abc import ABC, abstractmethod
@@ -81,6 +82,9 @@ class ExecutePyCode(ExecuteCode, Action):
     async def reset(self):
         """reset NotebookClient"""
         await self.terminate()
+
+        # sleep 1s to wait for the kernel to be cleaned up completely
+        await asyncio.sleep(1)
         await self.build()
         self.nb_client = NotebookClient(self.nb, timeout=self.timeout)
@@ -181,7 +185,11 @@ class ExecutePyCode(ExecuteCode, Action):
             await self.nb_client.async_execute_cell(cell, cell_index)
             return True, ""
         except CellTimeoutError:
-            return False, "TimeoutError"
+            assert self.nb_client.km is not None
+            await self.nb_client.km.interrupt_kernel()
+            await asyncio.sleep(1)
+            error_msg = "Cell execution timed out: Execution exceeded the time limit and was stopped; consider optimizing your code for better performance."
+            return False, error_msg
         except DeadKernelError:
             await self.reset()
             return False, "DeadKernelError"
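A minimal sketch of the new timeout behavior (mirroring the updated test_run_with_timeout further down; the executor's timeout is assumed to be configured shorter than the cell's runtime, as the test setup does):

    import asyncio

    from metagpt.actions.execute_code import ExecutePyCode

    async def demo():
        executor = ExecutePyCode()  # assumes a short self.timeout, as in the test setup
        message, success = await executor.run("import time; time.sleep(2)")
        # instead of the bare "TimeoutError", the kernel is interrupted and
        # a descriptive message is returned
        assert not success
        assert message.startswith("Cell execution timed out")

    asyncio.run(demo())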
metagpt/roles/ml_engineer.py
@@ -60,7 +60,6 @@ class MLEngineer(CodeInterpreter):
             if code_execution_count > 0:
                 logger.warning("We got a bug code, now start to debug...")
                 code = await DebugCode().run(
                     plan=self.planner.current_task.instruction,
                     code=self.latest_code,
                     runtime_result=self.working_memory.get(),
                     context=self.debug_context,
metagpt/tools/__init__.py
@@ -6,7 +6,6 @@
 @File    : __init__.py
 """
-
 from enum import Enum

 from pydantic import BaseModel
@@ -72,6 +71,12 @@ TOOL_TYPE_MAPPINGS = {
         desc="Only for evaluating model.",
         usage_prompt=MODEL_EVALUATE_PROMPT,
     ),
+    "stable_diffusion": ToolType(
+        name="stable_diffusion",
+        module="metagpt.tools.sd_engine",
+        desc="Related to text2image, image2image using stable diffusion model.",
+        usage_prompt="",
+    ),
     "vision": ToolType(
         name="vision",
         module=str(TOOL_LIBS_PATH / "vision"),
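The mapping above is how a tool-type name is resolved to the module implementing it. A minimal self-contained sketch of the same pattern (field names taken from the diff; the real ToolType model lives in metagpt/tools/__init__.py):

    from pydantic import BaseModel

    class ToolType(BaseModel):
        name: str
        module: str
        desc: str
        usage_prompt: str = ""

    TOOL_TYPE_MAPPINGS = {
        "stable_diffusion": ToolType(
            name="stable_diffusion",
            module="metagpt.tools.sd_engine",
            desc="Related to text2image, image2image using stable diffusion model.",
        ),
    }

    # a planner can then resolve the implementing module from a task's tool type
    print(TOOL_TYPE_MAPPINGS["stable_diffusion"].module)  # metagpt.tools.sd_engine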
metagpt/tools/functions/libs/data_preprocess.py
@@ -37,8 +37,9 @@ class FillMissingValue(MLProcess):
     def transform(self, df: pd.DataFrame):
         if len(self.features) == 0:
             return df
-        df[self.features] = self.si.transform(df[self.features])
-        return df
+        new_df = df.copy()
+        new_df[self.features] = self.si.transform(new_df[self.features])
+        return new_df


 class MinMaxScale(MLProcess):
@@ -54,8 +55,9 @@ class MinMaxScale(MLProcess):
         self.mms.fit(df[self.features])

     def transform(self, df: pd.DataFrame):
-        df[self.features] = self.mms.transform(df[self.features])
-        return df
+        new_df = df.copy()
+        new_df[self.features] = self.mms.transform(new_df[self.features])
+        return new_df


 class StandardScale(MLProcess):
@@ -71,8 +73,9 @@ class StandardScale(MLProcess):
         self.ss.fit(df[self.features])

     def transform(self, df: pd.DataFrame):
-        df[self.features] = self.ss.transform(df[self.features])
-        return df
+        new_df = df.copy()
+        new_df[self.features] = self.ss.transform(new_df[self.features])
+        return new_df


 class MaxAbsScale(MLProcess):
@@ -88,8 +91,9 @@ class MaxAbsScale(MLProcess):
         self.mas.fit(df[self.features])

     def transform(self, df: pd.DataFrame):
-        df[self.features] = self.mas.transform(df[self.features])
-        return df
+        new_df = df.copy()
+        new_df[self.features] = self.mas.transform(new_df[self.features])
+        return new_df


 class RobustScale(MLProcess):
@@ -105,8 +109,9 @@ class RobustScale(MLProcess):
         self.rs.fit(df[self.features])

     def transform(self, df: pd.DataFrame):
-        df[self.features] = self.rs.transform(df[self.features])
-        return df
+        new_df = df.copy()
+        new_df[self.features] = self.rs.transform(new_df[self.features])
+        return new_df


 class OrdinalEncode(MLProcess):
@@ -122,8 +127,9 @@ class OrdinalEncode(MLProcess):
         self.oe.fit(df[self.features])

     def transform(self, df: pd.DataFrame):
-        df[self.features] = self.oe.transform(df[self.features])
-        return df
+        new_df = df.copy()
+        new_df[self.features] = self.oe.transform(new_df[self.features])
+        return new_df


 class OneHotEncode(MLProcess):
@@ -142,9 +148,9 @@ class OneHotEncode(MLProcess):
         ts_data = self.ohe.transform(df[self.features])
         new_columns = self.ohe.get_feature_names_out(self.features)
         ts_data = pd.DataFrame(ts_data, columns=new_columns, index=df.index)
-        df.drop(self.features, axis=1, inplace=True)
-        df = pd.concat([df, ts_data], axis=1)
-        return df
+        new_df = df.drop(self.features, axis=1)
+        new_df = pd.concat([new_df, ts_data], axis=1)
+        return new_df


 class LabelEncode(MLProcess):
@@ -165,13 +171,14 @@ class LabelEncode(MLProcess):
     def transform(self, df: pd.DataFrame):
         if len(self.features) == 0:
             return df
+        new_df = df.copy()
         for i in range(len(self.features)):
             data_list = df[self.features[i]].astype(str).tolist()
             for unique_item in np.unique(df[self.features[i]].astype(str)):
                 if unique_item not in self.le_encoders[i].classes_:
                     data_list = ["unknown" if x == unique_item else x for x in data_list]
-            df[self.features[i]] = self.le_encoders[i].transform(data_list)
-        return df
+            new_df[self.features[i]] = self.le_encoders[i].transform(data_list)
+        return new_df


 def get_column_info(df: pd.DataFrame) -> dict:
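Every transform above now returns a modified copy instead of mutating the caller's frame in place. A minimal sketch of the new contract (mirroring tests/metagpt/tools/functions/libs/test_data_preprocess.py further down):

    import numpy as np
    import pandas as pd

    from metagpt.tools.functions.libs.data_preprocess import FillMissingValue

    df = pd.DataFrame({"num1": [1.0, 2.0, np.nan, 4.0]})

    fm = FillMissingValue(features=["num1"], strategy="mean")
    filled = fm.fit_transform(df)

    assert filled["num1"].isnull().sum() == 0  # NaN imputed in the returned copy
    assert df["num1"].isnull().sum() == 1      # the input frame is left untouched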
metagpt/tools/functions/libs/feature_engineering.py
@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-
 # @Time    : 2023/11/17 10:33
 # @Author  : lidanyang
-# @File    : feature_engineering.py
+# @File    : test_feature_engineering.py
 # @Desc    : Feature Engineering Tools
 import itertools
@@ -43,9 +43,9 @@ class PolynomialExpansion(MLProcess):
         ts_data = self.poly.transform(df[self.cols].fillna(0))
         column_name = self.poly.get_feature_names_out(self.cols)
         ts_data = pd.DataFrame(ts_data, index=df.index, columns=column_name)
-        df.drop(self.cols, axis=1, inplace=True)
-        df = pd.concat([df, ts_data], axis=1)
-        return df
+        new_df = df.drop(self.cols, axis=1)
+        new_df = pd.concat([new_df, ts_data], axis=1)
+        return new_df


 class CatCount(MLProcess):
@@ -57,8 +57,9 @@ class CatCount(MLProcess):
         self.encoder_dict = df[self.col].value_counts().to_dict()

     def transform(self, df: pd.DataFrame) -> pd.DataFrame:
-        df[f"{self.col}_cnt"] = df[self.col].map(self.encoder_dict)
-        return df
+        new_df = df.copy()
+        new_df[f"{self.col}_cnt"] = new_df[self.col].map(self.encoder_dict)
+        return new_df


 class TargetMeanEncoder(MLProcess):
@@ -71,8 +72,9 @@ class TargetMeanEncoder(MLProcess):
         self.encoder_dict = df.groupby(self.col)[self.label].mean().to_dict()

     def transform(self, df: pd.DataFrame) -> pd.DataFrame:
-        df[f"{self.col}_target_mean"] = df[self.col].map(self.encoder_dict)
-        return df
+        new_df = df.copy()
+        new_df[f"{self.col}_target_mean"] = new_df[self.col].map(self.encoder_dict)
+        return new_df


 class KFoldTargetMeanEncoder(MLProcess):
@@ -96,8 +98,9 @@ class KFoldTargetMeanEncoder(MLProcess):
         self.encoder_dict = tmp.groupby(self.col)[col_name].mean().to_dict()

     def transform(self, df: pd.DataFrame) -> pd.DataFrame:
-        df[f"{self.col}_kf_target_mean"] = df[self.col].map(self.encoder_dict)
-        return df
+        new_df = df.copy()
+        new_df[f"{self.col}_kf_target_mean"] = new_df[self.col].map(self.encoder_dict)
+        return new_df


 class CatCross(MLProcess):
@@ -124,14 +127,15 @@ class CatCross(MLProcess):
         self.combs_map = dict(res)

     def transform(self, df: pd.DataFrame) -> pd.DataFrame:
+        new_df = df.copy()
         for comb in self.combs:
             new_col = f"{comb[0]}_{comb[1]}"
             _map = self.combs_map[new_col]
-            df[new_col] = pd.Series(zip(df[comb[0]], df[comb[1]])).map(_map)
+            new_df[new_col] = pd.Series(zip(new_df[comb[0]], new_df[comb[1]])).map(_map)
             # set the unknown value to a new number
-            df[new_col].fillna(max(_map.values()) + 1, inplace=True)
-            df[new_col] = df[new_col].astype(int)
-        return df
+            new_df[new_col].fillna(max(_map.values()) + 1, inplace=True)
+            new_df[new_col] = new_df[new_col].astype(int)
+        return new_df


 class GroupStat(MLProcess):
@@ -149,12 +153,12 @@ class GroupStat(MLProcess):
         self.group_df = group_df

     def transform(self, df: pd.DataFrame) -> pd.DataFrame:
-        df = df.merge(self.group_df, on=self.group_col, how="left")
-        return df
+        new_df = df.merge(self.group_df, on=self.group_col, how="left")
+        return new_df


 class SplitBins(MLProcess):
-    def __init__(self, cols: str, strategy: str = "quantile"):
+    def __init__(self, cols: list, strategy: str = "quantile"):
         self.cols = cols
         self.strategy = strategy
         self.encoder = None
@@ -164,8 +168,9 @@ class SplitBins(MLProcess):
         self.encoder.fit(df[self.cols].fillna(0))

     def transform(self, df: pd.DataFrame) -> pd.DataFrame:
-        df[self.cols] = self.encoder.transform(df[self.cols].fillna(0))
-        return df
+        new_df = df.copy()
+        new_df[self.cols] = self.encoder.transform(new_df[self.cols].fillna(0))
+        return new_df


 class ExtractTimeComps(MLProcess):
@@ -192,91 +197,8 @@ class ExtractTimeComps(MLProcess):
             time_comps_df["dayofweek"] = time_s.dt.dayofweek + 1
         if "is_weekend" in self.time_comps:
             time_comps_df["is_weekend"] = time_s.dt.dayofweek.isin([5, 6]).astype(int)
-        df = pd.concat([df, time_comps_df], axis=1)
-        return df
-
-
-# @registry.register("feature_engineering", FeShiftByTime)
-# def fe_shift_by_time(df, time_col, group_col, shift_col, periods, freq):
-#     df[time_col] = pd.to_datetime(df[time_col])
-#
-#     def shift_datetime(date, offset, unit):
-#         if unit in ["year", "y", "Y"]:
-#             return date + relativedelta(years=offset)
-#         elif unit in ["month", "m", "M"]:
-#             return date + relativedelta(months=offset)
-#         elif unit in ["day", "d", "D"]:
-#             return date + relativedelta(days=offset)
-#         elif unit in ["week", "w", "W"]:
-#             return date + relativedelta(weeks=offset)
-#         elif unit in ["hour", "h", "H"]:
-#             return date + relativedelta(hours=offset)
-#         else:
-#             return date
-#
-#     def shift_by_time_on_key(
-#         inner_df, time_col, group_col, shift_col, offset, unit, col_name
-#     ):
-#         inner_df = inner_df.drop_duplicates()
-#         inner_df[time_col] = inner_df[time_col].map(
-#             lambda x: shift_datetime(x, offset, unit)
-#         )
-#         inner_df = inner_df.groupby([time_col, group_col], as_index=False)[
-#             shift_col
-#         ].mean()
-#         inner_df.rename(columns={shift_col: col_name}, inplace=True)
-#         return inner_df
-#
-#     shift_df = df[[time_col, group_col, shift_col]].copy()
-#     for period in periods:
-#         new_col_name = f"{group_col}_{shift_col}_lag_{period}_{freq}"
-#         tmp = shift_by_time_on_key(
-#             shift_df, time_col, group_col, shift_col, period, freq, new_col_name
-#         )
-#         df = df.merge(tmp, on=[time_col, group_col], how="left")
-#
-#     return df
-#
-#
-# @registry.register("feature_engineering", FeRollingByTime)
-# def fe_rolling_by_time(df, time_col, group_col, rolling_col, periods, freq, agg_funcs):
-#     df[time_col] = pd.to_datetime(df[time_col])
-#
-#     def rolling_by_time_on_key(inner_df, offset, unit, agg_func, col_name):
-#         time_freq = {
-#             "Y": [365 * offset, "D"],
-#             "M": [30 * offset, "D"],
-#             "D": [offset, "D"],
-#             "W": [7 * offset, "D"],
-#             "H": [offset, "h"],
-#         }
-#
-#         if agg_func not in ["mean", "std", "max", "min", "median", "sum", "count"]:
-#             raise ValueError(f"Invalid agg function: {agg_func}")
-#
-#         rolling_feat = inner_df.rolling(
-#             f"{time_freq[unit][0]}{time_freq[unit][1]}", closed="left"
-#         )
-#         rolling_feat = getattr(rolling_feat, agg_func)()
-#         depth = df.columns.nlevels
-#         rolling_feat = rolling_feat.stack(list(range(depth)))
-#         rolling_feat.name = col_name
-#         return rolling_feat
-#
-#     rolling_df = df[[time_col, group_col, rolling_col]].copy()
-#     for period in periods:
-#         for func in agg_funcs:
-#             new_col_name = f"{group_col}_{rolling_col}_rolling_{period}_{freq}_{func}"
-#             tmp = pd.pivot_table(
-#                 rolling_df,
-#                 index=time_col,
-#                 values=rolling_col,
-#                 columns=group_col,
-#             )
-#             tmp = rolling_by_time_on_key(tmp, period, freq, func, new_col_name)
-#             df = df.merge(tmp, on=[time_col, group_col], how="left")
-#
-#     return df
+        new_df = pd.concat([df, time_comps_df], axis=1)
+        return new_df


 class GeneralSelection(MLProcess):
@@ -302,8 +224,8 @@ class GeneralSelection(MLProcess):
         self.feats = feats

     def transform(self, df: pd.DataFrame) -> pd.DataFrame:
-        df = df[self.feats + [self.label_col]]
-        return df
+        new_df = df[self.feats + [self.label_col]]
+        return new_df


 class TreeBasedSelection(MLProcess):
@@ -344,8 +266,8 @@ class TreeBasedSelection(MLProcess):
         self.feats.append(self.label_col)

     def transform(self, df: pd.DataFrame) -> pd.DataFrame:
-        df = df[self.feats]
-        return df
+        new_df = df[self.feats]
+        return new_df


 class VarianceBasedSelection(MLProcess):
@@ -364,5 +286,5 @@ class VarianceBasedSelection(MLProcess):
         self.feats.append(self.label_col)

     def transform(self, df: pd.DataFrame) -> pd.DataFrame:
-        df = df[self.feats]
-        return df
+        new_df = df[self.feats]
+        return new_df
metagpt/tools/functions/schemas/stable_diffusion.yml (new file, 58 lines)
@@ -0,0 +1,58 @@
+SDEngine:
+  type: class
+  description: "Generate image using stable diffusion model"
+  methods:
+    __init__:
+      description: "Initialize the SDEngine instance."
+      parameters:
+        properties:
+          sd_url:
+            type: str
+            description: "URL of the stable diffusion service."
+    simple_run_t2i:
+      description: "Synchronously call the stable diffusion API with one payload to generate images."
+      parameters:
+        properties:
+          payload:
+            type: dict
+            description: "Dictionary of input parameters for the stable diffusion API."
+          auto_save:
+            type: bool
+            description: "Save generated images automatically."
+        required:
+          - payload
+    run_t2i:
+      type: async function
+      description: "Run the stable diffusion API for multiple prompts, calling the stable diffusion API to generate images."
+      parameters:
+        properties:
+          payloads:
+            type: list
+            description: "List of payloads, each a dictionary of input parameters for the stable diffusion API."
+        required:
+          - payloads
+    construct_payload:
+      description: "Modify and set the API parameters for image generation."
+      parameters:
+        properties:
+          prompt:
+            type: str
+            description: "Text input for image generation."
+        required:
+          - prompt
+      returns:
+        payload:
+          type: dict
+          description: "Updated parameters for the stable diffusion API."
+    save:
+      description: "Save generated images to the output directory."
+      parameters:
+        properties:
+          imgs:
+            type: str
+            description: "Generated images."
+          save_name:
+            type: str
+            description: "Output image name. Default is empty."
+        required:
+          - imgs
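Since the schema is plain YAML, it can be loaded and turned into, for example, an LLM function-calling spec. A minimal illustrative loader (an assumption about downstream use, not necessarily MetaGPT's actual consumer of these files):

    import yaml

    with open("metagpt/tools/functions/schemas/stable_diffusion.yml") as f:
        schema = yaml.safe_load(f)

    # each method entry carries a description plus its parameter properties
    for method, spec in schema["SDEngine"]["methods"].items():
        params = spec.get("parameters", {}).get("properties", {})
        print(method, "->", list(params))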
metagpt/tools/sd_engine.py
@@ -2,13 +2,14 @@
 # @Date    : 2023/7/19 16:28
 # @Author  : stellahong (stellahong@deepwisdom.ai)
 # @Desc    :
-import asyncio
 import base64
+import hashlib
 import io
 import json
 from os.path import join
 from typing import List

+import requests
 from aiohttp import ClientSession
 from PIL import Image, PngImagePlugin

@@ -51,9 +52,9 @@ default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution"


 class SDEngine:
-    def __init__(self):
+    def __init__(self, sd_url=""):
         # Initialize the SDEngine with configuration
-        self.sd_url = CONFIG.get("SD_URL")
+        self.sd_url = sd_url if sd_url else CONFIG.get("SD_URL")
         self.sd_t2i_url = f"{self.sd_url}{CONFIG.get('SD_T2I_API')}"
         # Define default payload settings for SD API
         self.payload = payload
@@ -69,25 +70,36 @@ class SDEngine:
     ):
         # Configure the payload with provided inputs
         self.payload["prompt"] = prompt
-        self.payload["negtive_prompt"] = negtive_prompt
+        self.payload["negative_prompt"] = negtive_prompt
         self.payload["width"] = width
         self.payload["height"] = height
         self.payload["override_settings"]["sd_model_checkpoint"] = sd_model
         logger.info(f"call sd payload is {self.payload}")
         return self.payload

-    def _save(self, imgs, save_name=""):
+    def save(self, imgs, save_name=""):
         save_dir = CONFIG.workspace_path / SD_OUTPUT_FILE_REPO
         if not save_dir.exists():
             save_dir.mkdir(parents=True, exist_ok=True)
         batch_decode_base64_to_image(imgs, str(save_dir), save_name=save_name)

-    async def run_t2i(self, prompts: List):
+    def simple_run_t2i(self, payload: dict, auto_save: bool = True):
+        with requests.Session() as session:
+            logger.debug(self.sd_t2i_url)
+            rsp = session.post(self.sd_t2i_url, json=payload, timeout=600)
+
+        results = rsp.json()["images"]
+        if auto_save:
+            save_name = hashlib.sha256(payload["prompt"][:10].encode()).hexdigest()[:6]
+            self.save(results, save_name=f"output_{save_name}")
+        return results
+
+    async def run_t2i(self, payloads: List):
         # Asynchronously run the SD API for multiple prompts
         session = ClientSession()
-        for payload_idx, payload in enumerate(prompts):
+        for payload_idx, payload in enumerate(payloads):
             results = await self.run(url=self.sd_t2i_url, payload=payload, session=session)
-            self._save(results, save_name=f"output_{payload_idx}")
+            self.save(results, save_name=f"output_{payload_idx}")
         await session.close()

     async def run(self, url, payload, session):
@@ -121,13 +133,3 @@ def batch_decode_base64_to_image(imgs, save_dir="", save_name=""):
     for idx, _img in enumerate(imgs):
         save_name = join(save_dir, save_name)
         decode_base64_to_image(_img, save_name=save_name)
-
-
-if __name__ == "__main__":
-    engine = SDEngine()
-    prompt = "pixel style, game design, a game interface should be minimalistic and intuitive with the score and high score displayed at the top. The snake and its food should be easily distinguishable. The game should have a simple color scheme, with a contrasting color for the snake and its food. Complete interface boundary"
-
-    engine.construct_payload(prompt)
-
-    event_loop = asyncio.get_event_loop()
-    event_loop.run_until_complete(engine.run_t2i(prompt))
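A minimal usage sketch combining the new constructor argument with the synchronous simple_run_t2i path (the sd_url value is a placeholder; a reachable Stable Diffusion WebUI service is assumed):

    from metagpt.tools.sd_engine import SDEngine

    engine = SDEngine(sd_url="http://your.sd.service.ip:port")
    payload = engine.construct_payload("a snake game interface, pixel style")
    images = engine.simple_run_t2i(payload)  # POSTs to the t2i endpoint; saves results by default
    print(f"got {len(images)} base64-encoded image(s)")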
tests/metagpt/actions/test_debug_code.py (new file, 57 lines)
@@ -0,0 +1,57 @@
+# -*- coding: utf-8 -*-
+# @Date    : 1/11/2024 8:51 PM
+# @Author  : stellahong (stellahong@fuzhi.ai)
+# @Desc    :
+
+import pytest
+
+from metagpt.actions.debug_code import DebugCode, messages_to_str
+from metagpt.schema import Message
+
+ErrorStr = """Tested passed:
+
+Tests failed:
+assert sort_array([1, 5, 2, 3, 4]) == [1, 2, 3, 4, 5] # output: [1, 2, 4, 3, 5]
+"""
+
+CODE = """
+def sort_array(arr):
+    # Helper function to count the number of ones in the binary representation
+    def count_ones(n):
+        return bin(n).count('1')
+
+    # Sort the array using a custom key function
+    # The key function returns a tuple (number of ones, value) for each element
+    # This ensures that if two elements have the same number of ones, they are sorted by their value
+    sorted_arr = sorted(arr, key=lambda x: (count_ones(x), x))
+
+    return sorted_arr
+```
+"""
+
+DebugContext = '''Solve the problem in Python:
+def sort_array(arr):
+    """
+    In this Kata, you have to sort an array of non-negative integers according to
+    number of ones in their binary representation in ascending order.
+    For similar number of ones, sort based on decimal value.
+
+    It must be implemented like this:
+    >>> sort_array([1, 5, 2, 3, 4]) == [1, 2, 3, 4, 5]
+    >>> sort_array([-2, -3, -4, -5, -6]) == [-6, -5, -4, -3, -2]
+    >>> sort_array([1, 0, 2, 3, 4]) [0, 1, 2, 3, 4]
+    """
+'''


+@pytest.mark.asyncio
+async def test_debug_code():
+    debug_context = Message(content=DebugContext)
+    new_code = await DebugCode().run(context=debug_context, code=CODE, runtime_result=ErrorStr)
+    assert "def sort_array(arr)" in new_code
+
+
+def test_messages_to_str():
+    debug_context = Message(content=DebugContext)
+    msg_str = messages_to_str([debug_context])
+    assert "user: Solve the problem in Python" in msg_str
tests/metagpt/actions/test_execute_code.py
@@ -96,4 +96,4 @@ async def test_run_with_timeout():
     code = "import time; time.sleep(2)"
     message, success = await pi.run(code)
     assert not success
-    assert message == "TimeoutError"
+    assert message.startswith("Cell execution timed out")
tests/metagpt/actions/test_write_analysis_code.py
@@ -3,8 +3,13 @@ import asyncio
 import pytest

 from metagpt.actions.execute_code import ExecutePyCode
-from metagpt.actions.write_analysis_code import WriteCodeByGenerate, WriteCodeWithTools
+from metagpt.actions.write_analysis_code import (
+    WriteCodeByGenerate,
+    WriteCodeWithTools,
+    WriteCodeWithToolsML,
+)
 from metagpt.logs import logger
 from metagpt.plan.planner import STRUCTURAL_CONTEXT
 from metagpt.schema import Message, Plan, Task

@@ -40,13 +45,15 @@ async def test_tool_recommendation():
     tools = await write_code._tool_recommendation(task, code_steps, available_tools)

     assert len(tools) == 1
-    assert tools[0] == ["fill_missing_value"]
+    assert tools[0] == "fill_missing_value"


 @pytest.mark.asyncio
 async def test_write_code_with_tools():
     write_code = WriteCodeWithTools()
-    messages = []
+    write_code_ml = WriteCodeWithToolsML()

     requirement = "构造数据集并进行数据清洗"
     task_map = {
         "1": Task(
             task_id="1",
@@ -69,10 +76,6 @@ async def test_write_code_with_tools():
             instruction="对数据集进行数据清洗",
             task_type="data_preprocess",
             dependent_task_ids=["1"],
-            code_steps="""
-            {"Step 1": "对数据集进行去重",
-            "Step 2": "对数据集进行缺失值处理"}
-            """,
         ),
     }
     plan = Plan(
@@ -83,10 +86,22 @@ async def test_write_code_with_tools():
     )
     column_info = ""

-    code = await write_code.run(messages, plan, column_info)
+    context = STRUCTURAL_CONTEXT.format(
+        user_requirement=requirement,
+        context=plan.context,
+        tasks=list(task_map.values()),
+        current_task=plan.current_task.model_dump_json(),
+    )
+    context_msg = [Message(content=context, role="user")]
+
+    code = await write_code.run(context_msg, plan)
     assert len(code) > 0
     print(code)

+    code_with_ml = await write_code_ml.run([], plan, column_info)
+    assert len(code_with_ml) > 0
+    print(code_with_ml)
+

 @pytest.mark.asyncio
 async def test_write_code_to_correct_error():
@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-# @Time    : 2023/11/17 10:24
+# @Time    : 2024/1/11 16:14
 # @Author  : lidanyang
 # @File    : __init__.py
 # @Desc    :
tests/metagpt/tools/functions/libs/test_data_preprocess.py (new file, 111 lines)
@@ -0,0 +1,111 @@
+from datetime import datetime
+
+import numpy as np
+import numpy.testing as npt
+import pandas as pd
+import pytest
+
+from metagpt.tools.functions.libs.data_preprocess import (
+    FillMissingValue,
+    LabelEncode,
+    MaxAbsScale,
+    MinMaxScale,
+    OneHotEncode,
+    OrdinalEncode,
+    RobustScale,
+    StandardScale,
+    get_column_info,
+)
+
+
+@pytest.fixture
+def mock_datasets():
+    return pd.DataFrame(
+        {
+            "num1": [1, 2, np.nan, 4, 5],
+            "cat1": ["A", "B", np.nan, "D", "A"],
+            "date1": [
+                datetime(2020, 1, 1),
+                datetime(2020, 1, 2),
+                datetime(2020, 1, 3),
+                datetime(2020, 1, 4),
+                datetime(2020, 1, 5),
+            ],
+        }
+    )
+
+
+def test_fill_missing_value(mock_datasets):
+    fm = FillMissingValue(features=["num1"], strategy="mean")
+    transformed = fm.fit_transform(mock_datasets.copy())
+
+    assert transformed["num1"].isnull().sum() == 0
+
+
+def test_min_max_scale(mock_datasets):
+    mms = MinMaxScale(features=["num1"])
+    transformed = mms.fit_transform(mock_datasets.copy())
+
+    npt.assert_allclose(transformed["num1"].min(), 0)
+    npt.assert_allclose(transformed["num1"].max(), 1)
+
+
+def test_standard_scale(mock_datasets):
+    ss = StandardScale(features=["num1"])
+    transformed = ss.fit_transform(mock_datasets.copy())
+
+    assert int(transformed["num1"].mean()) == 0
+    assert int(transformed["num1"].std()) == 1
+
+
+def test_max_abs_scale(mock_datasets):
+    mas = MaxAbsScale(features=["num1"])
+    transformed = mas.fit_transform(mock_datasets.copy())
+
+    npt.assert_allclose(transformed["num1"].abs().max(), 1)
+
+
+def test_robust_scale(mock_datasets):
+    rs = RobustScale(features=["num1"])
+    transformed = rs.fit_transform(mock_datasets.copy())
+
+    assert int(transformed["num1"].median()) == 0
+
+
+def test_ordinal_encode(mock_datasets):
+    oe = OrdinalEncode(features=["cat1"])
+    transformed = oe.fit_transform(mock_datasets.copy())
+
+    assert transformed["cat1"].max() == 2
+
+
+def test_one_hot_encode(mock_datasets):
+    ohe = OneHotEncode(features=["cat1"])
+    transformed = ohe.fit_transform(mock_datasets.copy())
+
+    assert transformed["cat1_A"].max() == 1
+
+
+def test_label_encode(mock_datasets):
+    le = LabelEncode(features=["cat1"])
+    transformed = le.fit_transform(mock_datasets.copy())
+
+    assert transformed["cat1"].max() == 3
+
+    # test transform with unseen data
+    test = mock_datasets.copy()
+    test["cat1"] = ["A", "B", "C", "D", "E"]
+    transformed = le.transform(test)
+    assert transformed["cat1"].max() == 4
+
+
+def test_get_column_info(mock_datasets):
+    df = mock_datasets
+    column_info = get_column_info(df)
+
+    assert column_info == {
+        "Category": ["cat1"],
+        "Numeric": ["num1"],
+        "Datetime": ["date1"],
+        "Others": [],
+    }
tests/metagpt/tools/functions/libs/test_feature_engineering.py (new file, 174 lines)
@@ -0,0 +1,174 @@
+import numpy as np
+import pandas as pd
+import pytest
+from sklearn.datasets import fetch_california_housing, load_breast_cancer, load_iris
+
+from metagpt.tools.functions.libs.feature_engineering import (
+    CatCount,
+    CatCross,
+    ExtractTimeComps,
+    GeneralSelection,
+    GroupStat,
+    KFoldTargetMeanEncoder,
+    PolynomialExpansion,
+    SplitBins,
+    TargetMeanEncoder,
+    TreeBasedSelection,
+    VarianceBasedSelection,
+)
+
+
+@pytest.fixture
+def mock_dataset():
+    return pd.DataFrame(
+        {
+            "num1": [1, 2, np.nan, 4, 5, 6, 7, 3],
+            "num2": [1, 3, 2, 1, np.nan, 5, 6, 4],
+            "num3": [np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan],
+            "cat1": ["A", "B", np.nan, "D", "E", "C", "B", "A"],
+            "cat2": ["A", "A", "A", "A", "A", "A", "A", "A"],
+            "date1": [
+                "2020-01-01",
+                "2020-01-02",
+                "2020-01-03",
+                "2020-01-04",
+                "2020-01-05",
+                "2020-01-06",
+                "2020-01-07",
+                "2020-01-08",
+            ],
+            "label": [0, 1, 0, 1, 0, 1, 0, 1],
+        }
+    )
+
+
+def load_sklearn_data(data_name):
+    if data_name == "iris":
+        data = load_iris()
+    elif data_name == "breast_cancer":
+        data = load_breast_cancer()
+    elif data_name == "housing":
+        data = fetch_california_housing()
+    else:
+        raise ValueError("data_name not supported")
+
+    X, y, feature_names = data.data, data.target, data.feature_names
+    data = pd.DataFrame(X, columns=feature_names)
+    data["label"] = y
+    return data
+
+
+def test_polynomial_expansion(mock_dataset):
+    pe = PolynomialExpansion(cols=["num1", "num2", "label"], degree=2, label_col="label")
+    transformed = pe.fit_transform(mock_dataset)
+
+    assert len(transformed.columns) == len(mock_dataset.columns) + 3
+
+    # when too many columns
+    data = load_sklearn_data("breast_cancer")
+    cols = [c for c in data.columns if c != "label"]
+    pe = PolynomialExpansion(cols=cols, degree=2, label_col="label")
+    transformed = pe.fit_transform(data)
+
+    assert len(transformed.columns) == len(data.columns) + 55
+
+
+def test_cat_count(mock_dataset):
+    cc = CatCount(col="cat1")
+    transformed = cc.fit_transform(mock_dataset)
+
+    assert "cat1_cnt" in transformed.columns
+    assert transformed["cat1_cnt"][0] == 2
+
+
+def test_target_mean_encoder(mock_dataset):
+    tme = TargetMeanEncoder(col="cat1", label="label")
+    transformed = tme.fit_transform(mock_dataset)
+
+    assert "cat1_target_mean" in transformed.columns
+    assert transformed["cat1_target_mean"][0] == 0.5
+
+
+def test_kfold_target_mean_encoder(mock_dataset):
+    kfme = KFoldTargetMeanEncoder(col="cat1", label="label")
+    transformed = kfme.fit_transform(mock_dataset)
+
+    assert "cat1_kf_target_mean" in transformed.columns
+
+
+def test_cat_cross(mock_dataset):
+    cc = CatCross(cols=["cat1", "cat2"])
+    transformed = cc.fit_transform(mock_dataset)
+
+    assert "cat1_cat2" in transformed.columns
+
+    cc = CatCross(cols=["cat1", "cat2"], max_cat_num=3)
+    transformed = cc.fit_transform(mock_dataset)
+
+    assert "cat1_cat2" not in transformed.columns
+
+
+def test_group_stat(mock_dataset):
+    gs = GroupStat(group_col="cat1", agg_col="num1", agg_funcs=["mean", "sum"])
+    transformed = gs.fit_transform(mock_dataset)
+
+    assert "num1_mean_by_cat1" in transformed.columns
+    assert "num1_sum_by_cat1" in transformed.columns
+
+
+def test_split_bins(mock_dataset):
+    sb = SplitBins(cols=["num1"])
+    transformed = sb.fit_transform(mock_dataset)
+
+    assert transformed["num1"].nunique() <= 5
+    assert all(0 <= x < 5 for x in transformed["num1"])
+
+
+def test_extract_time_comps(mock_dataset):
+    time_comps = ["year", "month", "day", "hour", "dayofweek", "is_weekend"]
+    etc = ExtractTimeComps(time_col="date1", time_comps=time_comps)
+    transformed = etc.fit_transform(mock_dataset.copy())
+
+    for comp in time_comps:
+        assert comp in transformed.columns
+    assert transformed["year"][0] == 2020
+    assert transformed["month"][0] == 1
+    assert transformed["day"][0] == 1
+    assert transformed["hour"][0] == 0
+    assert transformed["dayofweek"][0] == 3
+    assert transformed["is_weekend"][0] == 0
+
+
+def test_general_selection(mock_dataset):
+    gs = GeneralSelection(label_col="label")
+    transformed = gs.fit_transform(mock_dataset.copy())
+
+    assert "num3" not in transformed.columns
+    assert "cat2" not in transformed.columns
+
+
+def test_tree_based_selection(mock_dataset):
+    # regression
+    data = load_sklearn_data("housing")
+    tbs = TreeBasedSelection(label_col="label", task_type="reg")
+    transformed = tbs.fit_transform(data)
+    assert len(transformed.columns) > 1
+
+    # classification
+    data = load_sklearn_data("breast_cancer")
+    tbs = TreeBasedSelection(label_col="label", task_type="cls")
+    transformed = tbs.fit_transform(data)
+    assert len(transformed.columns) > 1
+
+    # multi-classification
+    data = load_sklearn_data("iris")
+    tbs = TreeBasedSelection(label_col="label", task_type="mcls")
+    transformed = tbs.fit_transform(data)
+    assert len(transformed.columns) > 1
+
+
+def test_variance_based_selection(mock_dataset):
+    vbs = VarianceBasedSelection(label_col="label")
+    transformed = vbs.fit_transform(mock_dataset.copy())
+
+    assert "num3" not in transformed.columns
test_register.py (deleted file, 55 lines)
@@ -1,55 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-# @Time    : 2023/11/17 10:24
-# @Author  : lidanyang
-# @File    : test_register.py
-# @Desc    :
-import pytest
-
-from metagpt.tools.functions.register.register import FunctionRegistry
-from metagpt.tools.functions.schemas.base import ToolSchema, tool_field
-
-
-@pytest.fixture
-def registry():
-    return FunctionRegistry()
-
-
-class AddNumbers(ToolSchema):
-    """Add two numbers"""
-
-    num1: int = tool_field(description="First number")
-    num2: int = tool_field(description="Second number")
-
-
-def test_register(registry):
-    @registry.register("module1", AddNumbers)
-    def add_numbers(num1, num2):
-        return num1 + num2
-
-    assert len(registry.functions["module1"]) == 1
-    assert "add_numbers" in registry.functions["module1"]
-
-    with pytest.raises(ValueError):
-
-        @registry.register("module1", AddNumbers)
-        def add_numbers(num1, num2):
-            return num1 + num2
-
-    func = registry.get("module1", "add_numbers")
-    assert func["func"](1, 2) == 3
-    assert func["schema"] == {
-        "name": "add_numbers",
-        "description": "Add two numbers",
-        "parameters": {
-            "type": "object",
-            "properties": {
-                "num1": {"description": "First number", "type": "int"},
-                "num2": {"description": "Second number", "type": "int"},
-            },
-            "required": ["num1", "num2"],
-        },
-    }
-
-    module1_funcs = registry.get_all_by_module("module1")
-    assert len(module1_funcs) == 1
tests/metagpt/tools/functions/test_sd.py (new file, 30 lines)
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+# @Date    : 1/10/2024 10:07 PM
+# @Author  : stellahong (stellahong@fuzhi.ai)
+# @Desc    :
+import pytest
+
+from metagpt.tools.sd_engine import SDEngine
+
+
+def test_sd_tools():
+    engine = SDEngine()
+    prompt = "1boy, hansom"
+    engine.construct_payload(prompt)
+    engine.simple_run_t2i(engine.payload)
+
+
+def test_sd_construct_payload():
+    engine = SDEngine()
+    prompt = "1boy, hansom"
+    engine.construct_payload(prompt)
+    assert "negative_prompt" in engine.payload
+
+
+@pytest.mark.asyncio
+async def test_sd_asyn_t2i():
+    engine = SDEngine()
+    prompt = "1boy, hansom"
+    engine.construct_payload(prompt)
+    await engine.run_t2i([engine.payload])
+    assert "negative_prompt" in engine.payload