Added Float type to the function parameter values (#77)

This commit is contained in:
Sampreeth Sarma 2024-09-25 13:29:20 -07:00 committed by GitHub
parent 7505a0fc1f
commit 7f0fcb372b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
26 changed files with 1505 additions and 45 deletions

View file

@ -0,0 +1,29 @@
from typing import List, Optional
# Function for top_employees
def top_employees(grouping: str, ranking_criteria: str, top_n: int):
pass
# Function for aggregate_stats
def aggregate_stats(grouping: str, aggregate_criteria: str, aggregate_type: str):
pass
# Function for employees_projects
def employees_projects(min_performance_score: float, min_years_experience: int, department: str, min_project_count: int = None, months_range: int = None):
pass
# Function for salary_growth
def salary_growth(min_salary_increase_percentage: float, department: str = None):
pass
# Function for promotions_increases
def promotions_increases(year: int, min_salary_increase_percentage: float = None, department: str = None):
pass
# Function for avg_project_performance
def avg_project_performance(min_project_count: int, min_performance_score: float, department: str = None):
pass
# Function for certifications_experience
def certifications_experience(certifications: list, min_years_experience: int, department: str = None):
pass

View file

@ -0,0 +1,78 @@
import inspect
import yaml
import functions # This is your module containing the function definitions
import os
def generate_config_from_function(func):
func_name = func.__name__
func_doc = func.__doc__
# Get function signature
sig = inspect.signature(func)
params = []
# Extract parameter info
for name, param in sig.parameters.items():
param_info = {
'name': name,
'description': f"Provide the {name.replace('_', ' ')}", # Customize as needed
'required': param.default == inspect.Parameter.empty, # True if no default value
'type': param.annotation.__name__ if param.annotation != inspect.Parameter.empty else 'str' # Get type if available
}
params.append(param_info)
# Define the config for this function
config = {
'name': func_name,
'description': func_doc or "",
'parameters': params,
'endpoint': {
'cluster': 'api_server',
'path': f"/{func_name}"
},
'system_prompt': f"You are responsible for handling {func_name} requests."
}
return config
def generate_full_config(module):
config = {'prompt_targets': []}
# Automatically get all functions from the module
functions_list = inspect.getmembers(module, inspect.isfunction)
for func_name, func_obj in functions_list:
func_config = generate_config_from_function(func_obj)
config['prompt_targets'].append(func_config)
return config
def replace_prompt_targets_in_config(file_path, new_prompt_targets):
# Load the existing bolt_config.yaml
with open(file_path, 'r') as file:
config_data = yaml.safe_load(file)
# Replace the prompt_targets section with the new one
config_data['prompt_targets'] = new_prompt_targets
# Write the updated config back to the YAML file
with open("bolt_config.yaml", 'w+') as file:
yaml.dump(config_data, file, sort_keys=False)
print(f"Updated prompt_targets in bolt_config.yaml")
# Main execution
if __name__ == "__main__":
# Path to the existing bolt_config.yaml two directories up
bolt_config_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../bolt_config.yaml'))
# Generate new prompt_targets from the functions module
new_config = generate_full_config(functions)
new_prompt_targets = new_config['prompt_targets']
# Replace the prompt_targets in the existing bolt_config.yaml
replace_prompt_targets_in_config(bolt_config_path, new_prompt_targets)

View file

@ -1,7 +1,6 @@
import random
from typing import List
from fastapi import FastAPI, HTTPException, Response
from datetime import datetime, date, timedelta, timezone
import logging
from pydantic import BaseModel
from utils import load_sql
@ -118,7 +117,7 @@ class TopEmployeesProjects(BaseModel):
months_range: int = None # Optional (for filtering recent projects)
@app.post("/top_employees_projects")
@app.post("/employees_projects")
async def employees_projects(req: TopEmployeesProjects, res: Response):
params, filters = {}, []
@ -225,8 +224,8 @@ class AvgProjPerformanceRequest(BaseModel):
department: str = None # Optional
@app.post("/avg_project_performance")
async def avg_project_performance(req: AvgProjPerformanceRequest, res: Response):
@app.post("/project_performance")
async def project_performance(req: AvgProjPerformanceRequest, res: Response):
params, filters = {}, []
if req.department:
@ -257,7 +256,7 @@ class CertificationsExperienceRequest(BaseModel):
min_years_experience: int
department: str = None # Optional
@app.post("/employees_certifications_experience")
@app.post("/certifications_experience")
async def certifications_experience(req: CertificationsExperienceRequest, res: Response):
# Convert the list of certifications into a format for SQL query
certs_filter = ', '.join([f"'{cert}'" for cert in req.certifications])