提交baseline例子;修改context-fill 格式识别方式

This commit is contained in:
didi 2024-09-09 17:17:15 +08:00
parent ca560a844f
commit 4e0a896bdc
13 changed files with 254 additions and 182 deletions

View file

@ -486,11 +486,18 @@ class ActionNode:
def get_field_names(self):
"""
Get the field names from the Pydantic model associated with this ActionNode.
获取与此ActionNode关联的Pydantic模型的字段名称
"""
model_class = self.create_class()
return model_class.model_fields.keys()
def get_field_types(self):
"""
获取与此ActionNode关联的Pydantic模型的字段类型
"""
model_class = self.create_class()
return {field_name: field.annotation for field_name, field in model_class.model_fields.items()}
def xml_compile(self, context):
# TODO 再来一版
@ -529,20 +536,44 @@ class ActionNode:
async def context_fill(self, context):
"""
Fill Context with XML TAG
使用XML标签填充上下文并根据字段类型进行转换包括字符串整数布尔值列表和字典类型
"""
field_names = self.get_field_names()
field_types = self.get_field_types()
extracted_data = {}
content = await self.llm.aask(context)
# TODO 自动解析类型标注的功能
for field_name in field_names:
# Use regex to find content within XML tags matching the field name
pattern = rf"<{field_name}>(.*?)</{field_name}>"
match = re.search(pattern, content, re.DOTALL)
if match:
extracted_data[field_name] = match.group(1).strip()
raw_value = match.group(1).strip()
field_type = field_types.get(field_name)
if field_type == str:
extracted_data[field_name] = raw_value
elif field_type == int:
try:
extracted_data[field_name] = int(raw_value)
except ValueError:
extracted_data[field_name] = 0 # 或者其他默认值
elif field_type == bool:
extracted_data[field_name] = raw_value.lower() in ('true', 'yes', '1', 'on', 'True')
elif field_type == list:
try:
extracted_data[field_name] = eval(raw_value)
if not isinstance(extracted_data[field_name], list):
raise ValueError
except:
extracted_data[field_name] = [] # 默认空列表
elif field_type == dict:
try:
extracted_data[field_name] = eval(raw_value)
if not isinstance(extracted_data[field_name], dict):
raise ValueError
except:
extracted_data[field_name] = {} # 默认空字典
return extracted_data