feat: +common talk role

This commit is contained in:
莘权 马 2023-08-26 16:59:12 +08:00
parent 4fe3d6e879
commit 2c593bedea
13 changed files with 261 additions and 46 deletions

View file

@ -5,3 +5,11 @@
@Author : alexanderwu
@File : __init__.py
"""
from metagpt.learn.text_to_image import text_to_image
from metagpt.learn.text_to_speech import text_to_speech
__all__ = [
"text_to_image",
"text_to_speech",
]

View file

@ -1,14 +1,26 @@
from pathlib import Path
from typing import List, Dict
from typing import List, Dict, Optional
import yaml
from pydantic import BaseModel
class Example(BaseModel):
ask: str
answer: str
class Returns(BaseModel):
type: str
format: Optional[str] = None
class Skill(BaseModel):
name: str
description: str
id: str
requisite: List[str]
arguments: Dict
examples: List[Example]
returns: Returns
class EntitySkills(BaseModel):
@ -26,13 +38,26 @@ class SkillLoader:
skills = yaml.safe_load(file)
self._skills = SkillsDeclaration(**skills)
def get_skill_list(self, entity_name: str = "Assistant"):
if not self._skills or entity_name not in self._skills.entities:
def get_skill_list(self, entity_name: str = "Assistant") -> Dict:
entity_skills = self.get_entity(entity_name)
if not entity_skills:
return {}
entity_skills = self._skills.entities.get(entity_name)
description_to_name_mappings = {}
for s in entity_skills.skills:
description_to_name_mappings[s.description] = s.name
return description_to_name_mappings
def get_skill(self, name, entity_name: str = "Assistant") -> Skill:
entity = self.get_entity(entity_name)
if not entity:
return None
for sk in entity.skills:
if sk.name == name:
return sk
def get_entity(self, name) -> EntitySkills:
if not self._skills:
return None
return self._skills.entities.get(name)

View file

@ -16,7 +16,7 @@ from metagpt.utils.common import initialize_environment
@skill_metadata(name="Text to Embedding",
description="Convert the text into embeddings.",
requisite="`OPENAI_API_KEY`")
def text_to_embedding(text, model="text-embedding-ada-002", openai_api_key=""):
def text_to_embedding(text, model="text-embedding-ada-002", openai_api_key="", **kwargs):
"""Text to embedding
:param text: The text used for embedding.

View file

@ -17,7 +17,7 @@ from metagpt.utils.common import initialize_environment
@skill_metadata(name="Text to image",
description="Create a drawing based on the text.",
requisite="`OPENAI_API_KEY` or `METAGPT_TEXT_TO_IMAGE_MODEL`")
def text_to_image(text, size_type: str = "512x512", openai_api_key="", model_url=""):
def text_to_image(text, size_type: str = "512x512", openai_api_key="", model_url="", **kwargs):
"""Text to image
:param text: The text used for image conversion.
@ -27,8 +27,14 @@ def text_to_image(text, size_type: str = "512x512", openai_api_key="", model_url
:return: The image data is returned in Base64 encoding.
"""
initialize_environment()
image_declaration = "data:image/png;base64,"
if os.environ.get("METAGPT_TEXT_TO_IMAGE_MODEL") or model_url:
return oas3_metagpt_text_to_image(text, size_type, model_url)
data = oas3_metagpt_text_to_image(text, size_type, model_url)
return image_declaration + data if data else ""
if os.environ.get("OPENAI_API_KEY") or openai_api_key:
return oas3_openai_text_to_image(text, size_type, openai_api_key)
data = oas3_openai_text_to_image(text, size_type, openai_api_key)
return image_declaration + data if data else ""
raise EnvironmentError

View file

@ -17,7 +17,7 @@ from metagpt.utils.common import initialize_environment
description="Text-to-speech",
requisite="`AZURE_TTS_SUBSCRIPTION_KEY` and `AZURE_TTS_REGION`")
def text_to_speech(text, lang="zh-CN", voice="zh-CN-XiaomoNeural", style="affectionate", role="Girl",
subscription_key="", region=""):
subscription_key="", region="", **kwargs):
"""Text to speech
For more details, check out:`https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`
@ -32,8 +32,10 @@ def text_to_speech(text, lang="zh-CN", voice="zh-CN-XiaomoNeural", style="affect
"""
initialize_environment()
audio_declaration = "data:audio/wav;base64,"
if (os.environ.get("AZURE_TTS_SUBSCRIPTION_KEY") and os.environ.get("AZURE_TTS_REGION")) or \
(subscription_key and region):
return oas3_azsure_tts(text, lang, voice, style, role, subscription_key, region)
data = oas3_azsure_tts(text, lang, voice, style, role, subscription_key, region)
return audio_declaration + data if data else data
raise EnvironmentError