diff --git a/metagpt/tools/moderation.py b/metagpt/tools/moderation.py new file mode 100644 index 000000000..c56a6afc4 --- /dev/null +++ b/metagpt/tools/moderation.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/9/26 14:27 +@Author : zhanglei +@File : moderation.py +""" +from typing import Union + +from metagpt.llm import LLM + + +class Moderation: + def __init__(self): + self.llm = LLM() + + def moderation(self, content: Union[str, list[str]]): + resp = [] + if content: + moderation_results = self.llm.moderation(content=content) + results = moderation_results.results + for item in results: + resp.append(item.flagged) + + return resp + + async def amoderation(self, content: Union[str, list[str]]): + resp = [] + if content: + moderation_results = await self.llm.amoderation(content=content) + results = moderation_results.results + for item in results: + resp.append(item.flagged) + + return resp + + +if __name__ == "__main__": + moderation = Moderation() + print(moderation.moderation(content=["I will kill you", "The weather is really nice today", "I want to hit you"])) diff --git a/tests/metagpt/tools/test_moderation.py b/tests/metagpt/tools/test_moderation.py new file mode 100644 index 000000000..225acff75 --- /dev/null +++ b/tests/metagpt/tools/test_moderation.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/9/26 14:46 +@Author : zhanglei +@File : test_translate.py +""" + +import pytest + +from metagpt.tools.moderation import Moderation + + +@pytest.mark.parametrize( + ("content",), + [ + [ + ["I will kill you", "The weather is really nice today", "I want to hit you"], + ] + ], +) +def test_moderation(content): + moderation = Moderation() + results = moderation.moderation(content=content) + assert isinstance(results, list) + assert len(results) == len(content) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("content",), + [ + [ + ["I will kill you", "The weather is really nice today", "I want to hit you"], + ] + ], +) +async def test_amoderation(content): + moderation = Moderation() + results = await moderation.amoderation(content=content) + assert isinstance(results, list) + assert len(results) == len(content)