Update ConText Fill

This commit is contained in:
didi 2024-08-23 20:43:29 +08:00
parent 02c7c4ea47
commit a3ff25430e
5 changed files with 245 additions and 210 deletions

View file

@ -9,6 +9,7 @@ NOTE: You should use typing.List instead of list to do type annotation. Because
we can use typing to extract the type of the node, but we cannot use built-in list to extract.
"""
import json
import re
import typing
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Type, Union
@ -482,14 +483,34 @@ class ActionNode:
# If there are multiple fields, we might want to use self.key to find the right one
return self.key
def get_field_names(self):
"""
Get the field names from the Pydantic model associated with this ActionNode.
"""
model_class = self.create_class()
return model_class.model_fields.keys()
def xml_compile(self, context):
pass
field_names = self.get_field_names()
# Construct the example using the field names
examples = []
for field_name in field_names:
examples.append(f"<{field_name}>content</{field_name}>")
# Join all examples into a single string
example_str = "\n".join(examples)
# Add the example to the context
context += f"""
### format example (must be strictly followed) (do not include any other formats except for the given XML format)
{example_str}
"""
print(context)
return context
async def code_fill(self, context, function_name=None, timeout=USE_CONFIG_TIMEOUT):
"""
fill CodeBlock Node
Fill CodeBlock Using ``` ```
"""
field_name = self.get_field_name()
prompt = context
content = await self.llm.aask(prompt, timeout=timeout)
@ -499,9 +520,19 @@ class ActionNode:
async def context_fill(self, context):
"""
这个地方的代码实现的目的是
Fill Context with XML TAG
"""
pass
field_names = self.get_field_names()
extracted_data = {}
content = await self.llm.aask(context)
for field_name in field_names:
# Use regex to find content within XML tags matching the field name
pattern = rf"<{field_name}>(.*?)</{field_name}>"
match = re.search(pattern, content, re.DOTALL)
if match:
extracted_data[field_name] = match.group(1).strip()
return extracted_data
async def fill(
self,
@ -550,7 +581,7 @@ class ActionNode:
使用xml_compile但是这个版本没有办法实现system message temperature
"""
context = self.xml_compile(context=self.context)
result = await self.context_fill(context, timeout)
result = await self.context_fill(context)
self.instruct_content = self.create_class()(**result)
return self