diff --git a/apps/tools_webhook/__init__.py b/apps/tools_webhook/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/apps/tools_webhook/app.py b/apps/tools_webhook/app.py new file mode 100644 index 00000000..1d172d6f --- /dev/null +++ b/apps/tools_webhook/app.py @@ -0,0 +1,127 @@ +# app.py + +import hashlib +import json +import logging +import os +from functools import wraps + +import jwt +from flask import Flask, jsonify, request +from jwt import InvalidTokenError + +from tools_webhook.function_map import FUNCTIONS_MAP +from tools_webhook.tool_caller import call_tool + +app = Flask(__name__) + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +def require_signed_request(f): + """ + If SIGNING_SECRET is set, verifies the request content's SHA256 hash + matches 'bodyHash' in the 'X-Signature-Jwt' header using HS256. + If no SIGNING_SECRET is configured, skip the validation entirely. + """ + @wraps(f) + def decorated(*args, **kwargs): + signing_secret = os.environ.get("SIGNING_SECRET", "").strip() + + # 1) If no signing secret is set, skip validation + if not signing_secret: + return f(*args, **kwargs) + + # 2) Attempt to retrieve the JWT from the header + signature_jwt = request.headers.get("X-Signature-Jwt") + if not signature_jwt: + logger.error("Missing X-Signature-Jwt header") + return jsonify({"error": "Missing X-Signature-Jwt header"}), 401 + + # 3) Decode/verify the token with PyJWT, ignoring audience/issuer + try: + decoded = jwt.decode( + signature_jwt, + signing_secret, + algorithms=["HS256"], + options={ + "require": ["bodyHash"], # must have bodyHash + "verify_aud": False, # disable audience check + "verify_iss": False, # disable issuer check + } + ) + except InvalidTokenError as e: + logger.error("Invalid token: %s", e) + return jsonify({"error": f"Invalid token: {str(e)}"}), 401 + + # 4) Compare bodyHash to SHA256(content) + request_data = request.get_json() or {} + content_str = request_data.get("content", "") + actual_hash = hashlib.sha256(content_str.encode("utf-8")).hexdigest() + + if decoded["bodyHash"] != actual_hash: + logger.error("bodyHash mismatch") + return jsonify({"error": "bodyHash mismatch"}), 403 + + return f(*args, **kwargs) + return decorated + +@app.route("/tool_call", methods=["POST"]) +@require_signed_request +def tool_call(): + """ + 1) Parse the incoming JSON (including 'content' as a JSON string). + 2) Extract function name and arguments. + 3) Use call_tool(...) to invoke the function. + 4) Return JSON response with result or error. + """ + req_data = request.get_json() + if not req_data: + logger.warning("No JSON data provided in request body.") + return jsonify({"error": "No JSON data provided"}), 400 + + content_str = req_data.get("content") + if not content_str: + logger.warning("Missing 'content' in request data.") + return jsonify({"error": "Missing 'content' in request data"}), 400 + + # Parse the JSON string in "content" + try: + parsed_content = json.loads(content_str) + except json.JSONDecodeError as e: + logger.error("Unable to parse 'content' as JSON: %s", e) + return jsonify({"error": f"Unable to parse 'content' as JSON: {str(e)}"}), 400 + + # Extract function info + tool_call_data = parsed_content.get("toolCall", {}) + function_data = tool_call_data.get("function", {}) + + function_name = function_data.get("name") + arguments_str = function_data.get("arguments") + + if not function_name: + logger.warning("No function name provided.") + return jsonify({"error": "No function name provided"}), 400 + if not arguments_str: + logger.warning("No arguments string provided.") + return jsonify({"error": "No arguments string provided"}), 400 + + # Parse the arguments, which is also a JSON string + try: + parameters = json.loads(arguments_str) + except json.JSONDecodeError as e: + logger.error("Unable to parse 'arguments' as JSON: %s", e) + return jsonify({"error": f"Unable to parse 'arguments' as JSON: {str(e)}"}), 400 + + try: + result = call_tool(function_name, parameters, FUNCTIONS_MAP) + return jsonify({"result": result}), 200 + except ValueError as val_err: + logger.warning("ValueError in call_tool: %s", val_err) + return jsonify({"error": str(val_err)}), 400 + except Exception as e: + logger.exception("Unexpected error in /tool_call route") + return jsonify({"error": str(e)}), 500 + +if __name__ == "__main__": + app.run(debug=True) diff --git a/apps/tools_webhook/function_map.py b/apps/tools_webhook/function_map.py new file mode 100644 index 00000000..e69b8362 --- /dev/null +++ b/apps/tools_webhook/function_map.py @@ -0,0 +1,26 @@ + +""" +function_map.py + +Defines all the callable functions and a mapping from +string names to these functions. +""" + +def greet(name: str, message: str): + """Return a greeting string.""" + return f"{message}, {name}!" + +def add(a: int, b: int): + """Return the sum of two integers.""" + return a + b + +def get_account_balance(user_id: str): + """Return a mock account balance for the given user_id.""" + return f"User {user_id} has a balance of $123.45." + +# A configurable mapping from function identifiers to actual Python functions +FUNCTIONS_MAP = { + "greet": greet, + "add": add, + "get_account_balance": get_account_balance +} diff --git a/apps/tools_webhook/requirements.txt b/apps/tools_webhook/requirements.txt new file mode 100644 index 00000000..3819ffdc --- /dev/null +++ b/apps/tools_webhook/requirements.txt @@ -0,0 +1,12 @@ +blinker==1.9.0 +click==8.1.8 +Flask==3.1.0 +iniconfig==2.0.0 +itsdangerous==2.2.0 +Jinja2==3.1.5 +MarkupSafe==3.0.2 +packaging==24.2 +pluggy==1.5.0 +PyJWT==2.10.1 +pytest==8.3.4 +Werkzeug==3.1.3 diff --git a/apps/tools_webhook/tests/__init__.py b/apps/tools_webhook/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/apps/tools_webhook/tests/test_app.py b/apps/tools_webhook/tests/test_app.py new file mode 100644 index 00000000..e1458052 --- /dev/null +++ b/apps/tools_webhook/tests/test_app.py @@ -0,0 +1,95 @@ +# tests/test_app.py + +import json +import pytest +from tools_webhook.app import app # If "sidecar" is recognized as a package + +@pytest.fixture +def client(): + """ + A pytest fixture that provides a Flask test client. + The `app.test_client()` allows us to make requests to our Flask app + without running the server. + """ + with app.test_client() as client: + yield client + + +def test_tool_call_greet(client): + # This matches the structure of the request in our code: + # { + # "content": "...a JSON string..." + # } + + # The content we pass is another JSON, so we have to double-escape quotes. + request_data = { + "content": json.dumps({ + "toolCall": { + "function": { + "name": "greet", + "arguments": json.dumps({ + "name": "Alice", + "message": "Hello" + }) + } + } + }) + } + + response = client.post( + "/tool_call", + data=json.dumps(request_data), + content_type="application/json" + ) + + assert response.status_code == 200 + data = response.get_json() + assert data["result"] == "Hello, Alice!" + + +def test_tool_call_missing_params(client): + request_data = { + "content": json.dumps({ + "toolCall": { + "function": { + "name": "greet", + "arguments": json.dumps({ + "name": "Alice" + # Missing "message" + }) + } + } + }) + } + + response = client.post( + "/tool_call", + data=json.dumps(request_data), + content_type="application/json" + ) + assert response.status_code == 400 + data = response.get_json() + assert "Missing required parameter: message" in data["error"] + + +def test_tool_call_invalid_func(client): + request_data = { + "content": json.dumps({ + "toolCall": { + "function": { + "name": "does_not_exist", + "arguments": json.dumps({}) + } + } + }) + } + + response = client.post( + "/tool_call", + data=json.dumps(request_data), + content_type="application/json" + ) + assert response.status_code == 400 + data = response.get_json() + assert "Function 'does_not_exist' not found" in data["error"] + diff --git a/apps/tools_webhook/tests/test_tool_caller.py b/apps/tools_webhook/tests/test_tool_caller.py new file mode 100644 index 00000000..4394049c --- /dev/null +++ b/apps/tools_webhook/tests/test_tool_caller.py @@ -0,0 +1,40 @@ +# tests/test_tool_caller.py + +import pytest +from tools_webhook.tool_caller import call_tool +from tools_webhook.function_map import FUNCTIONS_MAP + +def test_call_tool_greet(): + # Normal case + result = call_tool("greet", {"name": "Alice", "message": "Hello"}, FUNCTIONS_MAP) + assert result == "Hello, Alice!" + +def test_call_tool_add(): + # Normal case + result = call_tool("add", {"a": 2, "b": 5}, FUNCTIONS_MAP) + assert result == 7 + +def test_call_tool_missing_func(): + # Should raise ValueError if function is not in FUNCTIONS_MAP + with pytest.raises(ValueError) as exc_info: + call_tool("non_existent_func", {}, FUNCTIONS_MAP) + assert "Function 'non_existent_func' not found" in str(exc_info.value) + +def test_call_tool_missing_param(): + # greet requires `name` and `message` + with pytest.raises(ValueError) as exc_info: + call_tool("greet", {"name": "Alice"}, FUNCTIONS_MAP) + assert "Missing required parameter: message" in str(exc_info.value) + +def test_call_tool_unexpected_param(): + # `greet` only expects name and message + with pytest.raises(ValueError) as exc_info: + call_tool("greet", {"name": "Alice", "message": "Hello", "extra": "???"}, + FUNCTIONS_MAP) + assert "Unexpected parameter: extra" in str(exc_info.value) + +def test_call_tool_type_conversion_error(): + # `add` expects integers `a` and `b`, so passing a string should fail + with pytest.raises(ValueError) as exc_info: + call_tool("add", {"a": "not_an_int", "b": 3}, FUNCTIONS_MAP) + assert "Parameter 'a' must be of type int" in str(exc_info.value) diff --git a/apps/tools_webhook/tool_caller.py b/apps/tools_webhook/tool_caller.py new file mode 100644 index 00000000..aacf29a0 --- /dev/null +++ b/apps/tools_webhook/tool_caller.py @@ -0,0 +1,69 @@ +# tool_caller.py + +import inspect +import logging + +logger = logging.getLogger(__name__) + +def call_tool(function_name: str, parameters: dict, functions_map: dict): + """ + 1) Lookup a function in functions_map by name. + 2) Validate parameters against the function signature. + 3) Call the function with converted parameters. + 4) Return the result or raise an Exception on error. + """ + + logger.debug("call_tool invoked with function_name=%s, parameters=%s", function_name, parameters) + + # 1) Check if function exists + if function_name not in functions_map: + error_msg = f"Function '{function_name}' not found." + logger.error(error_msg) + raise ValueError(error_msg) + + func = functions_map[function_name] + signature = inspect.signature(func) + + # 2) Identify required parameters + required_params = [ + pname for pname, p in signature.parameters.items() + if p.default == inspect.Parameter.empty + ] + + # Check required params + for rp in required_params: + if rp not in parameters: + error_msg = f"Missing required parameter: {rp}" + logger.error(error_msg) + raise ValueError(error_msg) + + # Check unexpected params + valid_param_names = signature.parameters.keys() + for p in parameters.keys(): + if p not in valid_param_names: + error_msg = f"Unexpected parameter: {p}" + logger.error(error_msg) + raise ValueError(error_msg) + + # 3) Convert types based on annotations (if any) + converted_params = {} + for param_name, param_value in parameters.items(): + param_obj = signature.parameters[param_name] + if param_obj.annotation != inspect.Parameter.empty: + try: + converted_params[param_name] = param_obj.annotation(param_value) + except (ValueError, TypeError) as e: + error_msg = f"Parameter '{param_name}' must be of type {param_obj.annotation.__name__}: {e}" + logger.error(error_msg) + raise ValueError(error_msg) + else: + converted_params[param_name] = param_value + + # 4) Invoke the function + try: + result = func(**converted_params) + logger.debug("Function '%s' returned: %s", function_name, result) + return result + except Exception as e: + logger.exception("Unexpected error calling '%s'", function_name) # logs stack trace + raise