diff --git a/cognition_objects/step_templates.py b/cognition_objects/step_templates.py new file mode 100644 index 00000000..73604217 --- /dev/null +++ b/cognition_objects/step_templates.py @@ -0,0 +1,281 @@ +from typing import List, Dict, Any, Iterable, Optional + +from sqlalchemy.orm.attributes import flag_modified +from ..enums import StrategyStepType +from ..business_objects import general +from ..session import session +from ..models import StepTemplates +from ..util import prevent_sql_injection, sql_alchemy_to_dict + + +def get(organization_id: str, template_id: str) -> StepTemplates: + return ( + session.query(StepTemplates) + .filter( + StepTemplates.organization_id == organization_id, + StepTemplates.id == template_id, + ) + .first() + ) + + +def get_all_by_org_id(organization_id: str) -> List[Dict[str, Any]]: + values = [ + sql_alchemy_to_dict(st) + for st in ( + session.query(StepTemplates) + .filter( + StepTemplates.organization_id == organization_id, + ) + .order_by(StepTemplates.created_at.asc()) + .all() + ) + ] + organization_id = prevent_sql_injection(organization_id, isinstance(organization_id, str)) + query = f""" + SELECT jsonb_object_agg(id,C) + FROM ( + SELECT ss.config->>'templateId' id, COUNT(*)c + FROM cognition.strategy_step ss + INNER JOIN cognition.project p + ON ss.project_id = p.id + WHERE p.organization_id = '{organization_id}' + AND ss.step_type = '{StrategyStepType.TEMPLATED.value}' + GROUP BY 1 + )X + """ + template_counts = general.execute_first(query) + template_counts = ( + template_counts[0] if template_counts and template_counts[0] else {} + ) + + values = [ + {**s, "usage_count": template_counts.get(str(s["id"]), 0)} for s in values + ] + + return values + + +def get_all_by_user(organization_id: str, user_id: str) -> List[StepTemplates]: + return ( + session.query(StepTemplates) + .filter( + StepTemplates.organization_id == organization_id, + StepTemplates.created_by == user_id, + ) + .order_by(StepTemplates.created_at.asc()) + .all() + ) + + +# result structure: +# {: { +# "": { +# "strategy_name": , +# "order": , +# "steps": [ +# { +# "step_name": , +# "step_description": , +# "step_type": , +# "progress_text": , +# "execute_if_source_code": , +# "config": , +# "position": +# }, +# ... +# ] +# }, +# ... +# }} +def get_all_existing_steps_for_template_creation(org_id: str) -> Dict[str, Any]: + org_id = prevent_sql_injection(org_id, isinstance(org_id, str)) + query = f""" + SELECT + jsonb_object_agg(proj.project_id::text, proj.proj_json) AS all_projects + FROM ( + SELECT + p.id AS project_id, + jsonb_build_object( + 'project_name', p.name, + 'created_at', p.created_at, + 'strategies', + jsonb_object_agg( + s.id::text, + jsonb_build_object( + 'strategy_name', s.name, + 'order', s."order", + 'steps', + coalesce( + ( + SELECT jsonb_agg( + jsonb_build_object( + 'src_step_id', ss.id, + 'step_name', ss.name, + 'step_description', ss.description, + 'step_type', ss.step_type, + 'progress_text', ss.progress_text, + 'execute_if_source_code', ss.execute_if_source_code, + 'config', ss.config, + 'position', ss.position + ) + ORDER BY ss.position + ) + FROM cognition.strategy_step ss + WHERE ss.project_id = p.id + AND ss.strategy_id = s.id + AND ss.step_type != '{StrategyStepType.TEMPLATED.value}' + ), + '[]' + ) + ) + ) + ) AS proj_json + FROM cognition.project p + INNER JOIN cognition.strategy s + ON s.project_id = p.id + WHERE p.organization_id = '{org_id}' + GROUP BY p.id, p.name + ) AS proj; + """ + result = general.execute_first(query) + if result and result[0]: + return result[0] + return {} + + +def get_step_template_progress_text_lookup_for_strategy( + project_id: str, strategy_id: str, step_id: Optional[str] = None +) -> Dict[str, str]: + project_id = prevent_sql_injection(project_id, isinstance(project_id, str)) + strategy_id = prevent_sql_injection(strategy_id, isinstance(strategy_id, str)) + step_id_filter = "" + if step_id: + step_id = prevent_sql_injection(step_id, isinstance(step_id, str)) + step_id_filter = f"AND ss.id = '{step_id}'" + query = f""" + WITH base AS ( + SELECT + st.config AS config, + ss.config->'variableValues' AS variableValues, + ss.id step_id + FROM cognition.strategy_step ss + INNER JOIN cognition.project p + ON ss.project_id = p.id + INNER JOIN cognition.step_templates st + ON st.id = (ss.config->>'templateId')::UUID + AND st.organization_id = p.organization_id + WHERE + ss.project_id = '{project_id}' + AND ss.strategy_id = '{strategy_id}' + AND ss.step_type = '{StrategyStepType.TEMPLATED.value}' + {step_id_filter} + ) + + + SELECT jsonb_object_agg(step_id,progress_lookup) + FROM ( + SELECT + step_id, + array_agg(jsonb_build_object('template_key',dict_key,'progress_text',resolved_progress_text)) progress_lookup + FROM ( + SELECT + -- Extract one step-object at a time + step_id, + (step_item ->> 'stepName') AS step_name, + step_item->>'stepType'|| '@' || (step_item->>'srcStepId')::TEXT AS dict_key, + -- Raw progressText from the JSON + (step_item ->> 'progressText') AS raw_progress_text, + -- If raw_progress_text matches @@var_@@, replace via variableValues + CASE + WHEN (step_item ->> 'progressText') ~ '^@@var_[0-9a-fA-F\-]{{36}}@@$' + THEN + -- Extract the UUID between "var_" and "@@" + ( + SELECT variableValues ->> inner_uuid + FROM ( + SELECT + regexp_replace(step_item ->> 'progressText', + '^@@var_([0-9a-fA-F\-]{{36}})@@$', + '\\1') AS inner_uuid + ) AS sub + ) + ELSE + (step_item ->> 'progressText') + END AS resolved_progress_text + FROM base + -- Unnest the steps array + CROSS JOIN LATERAL json_array_elements(base.config->'steps') AS step_item + )x + GROUP BY 1 + )y + """ + result = general.execute_first(query) + if result and result[0]: + return result[0] + return {} + + +def create( + org_id: str, + user_id: str, + name: str, + description: str, + config: Dict[str, Any], + with_commit: bool = True, +) -> StepTemplates: + template: StepTemplates = StepTemplates( + organization_id=org_id, + name=name, + description=description, + created_by=user_id, + config=config, + ) + general.add(template, with_commit) + + return template + + +def update( + org_id: str, + template_id: str, + name: Optional[str] = None, + description: Optional[str] = None, + config: Optional[Dict[str, Any]] = None, + with_commit: bool = True, +) -> StepTemplates: + template = get(org_id, template_id) + if not template: + raise ValueError( + f"Template with ID {template_id} not found in organization {org_id}." + ) + + if name is not None: + template.name = name + if description is not None: + template.description = description + if config is not None: + template.config = config + flag_modified(template, "config") + + general.flush_or_commit(with_commit) + + return template + + +def delete(org_id: str, template_id: str, with_commit: bool = True) -> None: + session.query(StepTemplates).filter( + StepTemplates.organization_id == org_id, + StepTemplates.id == template_id, + ).delete() + general.flush_or_commit(with_commit) + + +def delete_many( + org_id: str, template_ids: Iterable[str], with_commit: bool = True +) -> None: + session.query(StepTemplates).filter( + StepTemplates.organization_id == org_id, + StepTemplates.id.in_(template_ids), + ).delete() + general.flush_or_commit(with_commit) diff --git a/cognition_objects/strategy_step.py b/cognition_objects/strategy_step.py index 8aa976a2..bf3b2d92 100644 --- a/cognition_objects/strategy_step.py +++ b/cognition_objects/strategy_step.py @@ -1,6 +1,8 @@ -from typing import List, Optional, Dict, Any +from typing import List, Optional, Dict, Any, Iterable, Tuple from datetime import datetime from sqlalchemy.orm.attributes import flag_modified +from sqlalchemy import tuple_ + from ..business_objects import general from ..session import session @@ -19,6 +21,23 @@ def get(project_id: str, strategy_step_id: str) -> CognitionStrategyStep: ) +def get_all_by_project_and_ids( + project_step_tuple: Iterable[Tuple[str, str]], +) -> List[CognitionStrategyStep]: + if not project_step_tuple: + return [] + + return ( + session.query(CognitionStrategyStep) + .filter( + tuple_(CognitionStrategyStep.project_id, CognitionStrategyStep.id).in_( + project_step_tuple + ) + ) + .all() + ) + + def get_all_by_strategy_id( project_id: str, strategy_id: str ) -> List[CognitionStrategyStep]: diff --git a/enums.py b/enums.py index fc90e646..49a740de 100644 --- a/enums.py +++ b/enums.py @@ -155,6 +155,7 @@ class Tablenames(Enum): EVALUATION_RUN = "evaluation_run" PLAYGROUND_QUESTION = "playground_question" FULL_ADMIN_ACCESS = "full_admin_access" + STEP_TEMPLATES = "step_templates" # templates for strategy steps def snake_case_to_pascal_case(self): # the type name (written in PascalCase) of a table is needed to create backrefs @@ -543,6 +544,7 @@ class StrategyStepType(Enum): NEURAL_SEARCH = "NEURAL_SEARCH" WEBHOOK = "WEBHOOK" GRAPHRAG_SEARCH = "GRAPHRAG_SEARCH" + TEMPLATED = "TEMPLATED" def get_description(self): return STEP_DESCRIPTIONS.get(self, "No description available") @@ -570,6 +572,7 @@ def get_progress_text(self): StrategyStepType.CALL_OTHER_AGENT: "Retrieve results from other agents", StrategyStepType.WEBHOOK: "Webhook", StrategyStepType.GRAPHRAG_SEARCH: "Query GraphRAG index", + StrategyStepType.TEMPLATED: "Templated step", } STEP_WHEN_TO_USE = { @@ -587,6 +590,7 @@ def get_progress_text(self): StrategyStepType.CALL_OTHER_AGENT: "When you want to call another agent", StrategyStepType.WEBHOOK: "When you want to run a webhook", StrategyStepType.GRAPHRAG_SEARCH: "When you want to query a knowledge graph", + StrategyStepType.TEMPLATED: "When you want to reuse existing templates", } STEP_PROGRESS_TEXTS = { @@ -605,6 +609,7 @@ def get_progress_text(self): StrategyStepType.CALL_OTHER_AGENT: "Calling another agent", StrategyStepType.WEBHOOK: "Running webhook", StrategyStepType.GRAPHRAG_SEARCH: "Querying knowledge graph", + StrategyStepType.TEMPLATED: "Running templated step", } STEP_ERRORS = { diff --git a/models.py b/models.py index ab2ef0e9..3aea5150 100644 --- a/models.py +++ b/models.py @@ -1933,6 +1933,35 @@ class GraphRAGIndex(Base): root_dir = Column(String) +class StepTemplates(Base): + __tablename__ = Tablenames.STEP_TEMPLATES.value + __table_args__ = {"schema": "cognition"} + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + organization_id = Column( + UUID(as_uuid=True), + ForeignKey(f"{Tablenames.ORGANIZATION.value}.id", ondelete="CASCADE"), + index=True, + ) + name = Column(String) + description = Column(String) + created_at = Column(DateTime, default=sql.func.now()) + created_by = Column( + UUID(as_uuid=True), + ForeignKey(f"{Tablenames.USER.value}.id", ondelete="SET NULL"), + index=True, + ) + config = Column(JSON) # JSON schema for the step template + # config contains all step configurations in an array & variable fields to be changed on useage + # e.g. + # { + # "variables": [ + # {"name": "Env var", "path": "[0].config.llmConfig.environmentVariable", "hasDefault": True, "defaultValue": "OpenAI Leo"}, + # {"name": "System Prompt", "path": "[0].config.templatePrompt", "hasDefault": False}, + # ], + # "steps": [{...},{...}] + # } + + # =========================== Global tables =========================== class GlobalWebsocketAccess(Base): # table to store prepared websocket configuration. diff --git a/util.py b/util.py index 9ab06ca5..dca8c616 100644 --- a/util.py +++ b/util.py @@ -2,7 +2,7 @@ from typing import Tuple, Any, Union, List, Dict, Optional, Iterable from pydantic import BaseModel from collections.abc import Iterable as collections_abc_Iterable -from re import sub, match, compile +from re import sub, match, compile, IGNORECASE import sqlalchemy import decimal from uuid import UUID @@ -15,6 +15,10 @@ from .business_objects import general CAMEL_CASE_PATTERN = compile(r"^([a-z]+[A-Z]?)*$") +UUID_REGEX_PATTERN = compile( + r"^[0-9a-f]{8}-[0-9a-f]{4}-[1-5][0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$", + IGNORECASE, +) STRING_TRUE_VALUES = {"true", "x", "1", "y"} @@ -105,12 +109,13 @@ def sql_alchemy_to_dict( column_whitelist: Optional[Iterable[str]] = None, column_blacklist: Optional[Iterable[str]] = None, column_rename_map: Optional[Dict[str, str]] = None, + dont_wrap_uuids: bool = True, ): result = __sql_alchemy_to_dict( sql_alchemy_object, column_whitelist, column_blacklist, column_rename_map ) if for_frontend: - return to_frontend_obj(result) + return to_frontend_obj(result, dont_wrap_uuids=dont_wrap_uuids) return result @@ -171,18 +176,29 @@ def rename_columns(data: Any) -> Any: return sql_alchemy_object -def to_frontend_obj(value: Union[List, Dict], blacklist_keys: List[str] = []): +def to_frontend_obj( + value: Union[List, Dict], + blacklist_keys: List[str] = [], + dont_wrap_uuids: bool = True, +): if isinstance(value, dict): return { - to_camel_case(k): ( - to_frontend_obj(v, blacklist_keys=blacklist_keys) + to_camel_case(k, dont_wrap_uuids=dont_wrap_uuids): ( + to_frontend_obj( + v, blacklist_keys=blacklist_keys, dont_wrap_uuids=dont_wrap_uuids + ) if k not in blacklist_keys else v ) for k, v in value.items() } elif is_list_like(value): - return [to_frontend_obj(x, blacklist_keys=blacklist_keys) for x in value] + return [ + to_frontend_obj( + x, blacklist_keys=blacklist_keys, dont_wrap_uuids=dont_wrap_uuids + ) + for x in value + ] else: return to_json_serializable(value) @@ -209,9 +225,11 @@ def to_json_serializable(x: Any): return x -def to_camel_case(name: str): +def to_camel_case(name: str, dont_wrap_uuids: bool = True): if is_camel_case(name): return name + if dont_wrap_uuids and is_uuid(name): + return name name = sub(r"(_|-)+", " ", name).title().replace(" ", "") return "".join([name[0].lower(), name[1:]]) @@ -233,6 +251,16 @@ def is_camel_case(text: str) -> bool: return False +def is_uuid(value: Union[str, UUID]) -> bool: + if isinstance(value, UUID): + return True + elif isinstance(value, str): + if UUID_REGEX_PATTERN.fullmatch(value): + return True + return False + return False + + # str is expected but depending on the attack vector e.g. the type hints don't mean anything so an int could still receive a string # the idea is that every directly inserted variable (e.g. project_id) is run through this function before being used in a plain text query # orm model is sufficient for most cases but for raw queries we mask all directly included variables