Skip to content

Commit 4cdfbd2

Browse files
authored
Templated steps (#171)
* New template type * Adds template table * tmp commit * Basic functionality * Progress messages * Update template * Fix progres after update * added prevent_sql
1 parent 142612e commit 4cdfbd2

File tree

5 files changed

+370
-8
lines changed

5 files changed

+370
-8
lines changed

cognition_objects/step_templates.py

Lines changed: 281 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
from typing import List, Dict, Any, Iterable, Optional
2+
3+
from sqlalchemy.orm.attributes import flag_modified
4+
from ..enums import StrategyStepType
5+
from ..business_objects import general
6+
from ..session import session
7+
from ..models import StepTemplates
8+
from ..util import prevent_sql_injection, sql_alchemy_to_dict
9+
10+
11+
def get(organization_id: str, template_id: str) -> StepTemplates:
12+
return (
13+
session.query(StepTemplates)
14+
.filter(
15+
StepTemplates.organization_id == organization_id,
16+
StepTemplates.id == template_id,
17+
)
18+
.first()
19+
)
20+
21+
22+
def get_all_by_org_id(organization_id: str) -> List[Dict[str, Any]]:
23+
values = [
24+
sql_alchemy_to_dict(st)
25+
for st in (
26+
session.query(StepTemplates)
27+
.filter(
28+
StepTemplates.organization_id == organization_id,
29+
)
30+
.order_by(StepTemplates.created_at.asc())
31+
.all()
32+
)
33+
]
34+
organization_id = prevent_sql_injection(organization_id, isinstance(organization_id, str))
35+
query = f"""
36+
SELECT jsonb_object_agg(id,C)
37+
FROM (
38+
SELECT ss.config->>'templateId' id, COUNT(*)c
39+
FROM cognition.strategy_step ss
40+
INNER JOIN cognition.project p
41+
ON ss.project_id = p.id
42+
WHERE p.organization_id = '{organization_id}'
43+
AND ss.step_type = '{StrategyStepType.TEMPLATED.value}'
44+
GROUP BY 1
45+
)X
46+
"""
47+
template_counts = general.execute_first(query)
48+
template_counts = (
49+
template_counts[0] if template_counts and template_counts[0] else {}
50+
)
51+
52+
values = [
53+
{**s, "usage_count": template_counts.get(str(s["id"]), 0)} for s in values
54+
]
55+
56+
return values
57+
58+
59+
def get_all_by_user(organization_id: str, user_id: str) -> List[StepTemplates]:
60+
return (
61+
session.query(StepTemplates)
62+
.filter(
63+
StepTemplates.organization_id == organization_id,
64+
StepTemplates.created_by == user_id,
65+
)
66+
.order_by(StepTemplates.created_at.asc())
67+
.all()
68+
)
69+
70+
71+
# result structure:
72+
# {<project_id>: {
73+
# "<strategy_id>": {
74+
# "strategy_name": <name>,
75+
# "order": <order>,
76+
# "steps": [
77+
# {
78+
# "step_name": <name>,
79+
# "step_description": <description>,
80+
# "step_type": <type>,
81+
# "progress_text": <progress_text>,
82+
# "execute_if_source_code": <execute_if_source_code>,
83+
# "config": <config>,
84+
# "position": <position>
85+
# },
86+
# ...
87+
# ]
88+
# },
89+
# ...
90+
# }}
91+
def get_all_existing_steps_for_template_creation(org_id: str) -> Dict[str, Any]:
92+
org_id = prevent_sql_injection(org_id, isinstance(org_id, str))
93+
query = f"""
94+
SELECT
95+
jsonb_object_agg(proj.project_id::text, proj.proj_json) AS all_projects
96+
FROM (
97+
SELECT
98+
p.id AS project_id,
99+
jsonb_build_object(
100+
'project_name', p.name,
101+
'created_at', p.created_at,
102+
'strategies',
103+
jsonb_object_agg(
104+
s.id::text,
105+
jsonb_build_object(
106+
'strategy_name', s.name,
107+
'order', s."order",
108+
'steps',
109+
coalesce(
110+
(
111+
SELECT jsonb_agg(
112+
jsonb_build_object(
113+
'src_step_id', ss.id,
114+
'step_name', ss.name,
115+
'step_description', ss.description,
116+
'step_type', ss.step_type,
117+
'progress_text', ss.progress_text,
118+
'execute_if_source_code', ss.execute_if_source_code,
119+
'config', ss.config,
120+
'position', ss.position
121+
)
122+
ORDER BY ss.position
123+
)
124+
FROM cognition.strategy_step ss
125+
WHERE ss.project_id = p.id
126+
AND ss.strategy_id = s.id
127+
AND ss.step_type != '{StrategyStepType.TEMPLATED.value}'
128+
),
129+
'[]'
130+
)
131+
)
132+
)
133+
) AS proj_json
134+
FROM cognition.project p
135+
INNER JOIN cognition.strategy s
136+
ON s.project_id = p.id
137+
WHERE p.organization_id = '{org_id}'
138+
GROUP BY p.id, p.name
139+
) AS proj;
140+
"""
141+
result = general.execute_first(query)
142+
if result and result[0]:
143+
return result[0]
144+
return {}
145+
146+
147+
def get_step_template_progress_text_lookup_for_strategy(
148+
project_id: str, strategy_id: str, step_id: Optional[str] = None
149+
) -> Dict[str, str]:
150+
project_id = prevent_sql_injection(project_id, isinstance(project_id, str))
151+
strategy_id = prevent_sql_injection(strategy_id, isinstance(strategy_id, str))
152+
step_id_filter = ""
153+
if step_id:
154+
step_id = prevent_sql_injection(step_id, isinstance(step_id, str))
155+
step_id_filter = f"AND ss.id = '{step_id}'"
156+
query = f"""
157+
WITH base AS (
158+
SELECT
159+
st.config AS config,
160+
ss.config->'variableValues' AS variableValues,
161+
ss.id step_id
162+
FROM cognition.strategy_step ss
163+
INNER JOIN cognition.project p
164+
ON ss.project_id = p.id
165+
INNER JOIN cognition.step_templates st
166+
ON st.id = (ss.config->>'templateId')::UUID
167+
AND st.organization_id = p.organization_id
168+
WHERE
169+
ss.project_id = '{project_id}'
170+
AND ss.strategy_id = '{strategy_id}'
171+
AND ss.step_type = '{StrategyStepType.TEMPLATED.value}'
172+
{step_id_filter}
173+
)
174+
175+
176+
SELECT jsonb_object_agg(step_id,progress_lookup)
177+
FROM (
178+
SELECT
179+
step_id,
180+
array_agg(jsonb_build_object('template_key',dict_key,'progress_text',resolved_progress_text)) progress_lookup
181+
FROM (
182+
SELECT
183+
-- Extract one step-object at a time
184+
step_id,
185+
(step_item ->> 'stepName') AS step_name,
186+
step_item->>'stepType'|| '@' || (step_item->>'srcStepId')::TEXT AS dict_key,
187+
-- Raw progressText from the JSON
188+
(step_item ->> 'progressText') AS raw_progress_text,
189+
-- If raw_progress_text matches @@var_<id>@@, replace via variableValues
190+
CASE
191+
WHEN (step_item ->> 'progressText') ~ '^@@var_[0-9a-fA-F\-]{{36}}@@$'
192+
THEN
193+
-- Extract the UUID between "var_" and "@@"
194+
(
195+
SELECT variableValues ->> inner_uuid
196+
FROM (
197+
SELECT
198+
regexp_replace(step_item ->> 'progressText',
199+
'^@@var_([0-9a-fA-F\-]{{36}})@@$',
200+
'\\1') AS inner_uuid
201+
) AS sub
202+
)
203+
ELSE
204+
(step_item ->> 'progressText')
205+
END AS resolved_progress_text
206+
FROM base
207+
-- Unnest the steps array
208+
CROSS JOIN LATERAL json_array_elements(base.config->'steps') AS step_item
209+
)x
210+
GROUP BY 1
211+
)y
212+
"""
213+
result = general.execute_first(query)
214+
if result and result[0]:
215+
return result[0]
216+
return {}
217+
218+
219+
def create(
220+
org_id: str,
221+
user_id: str,
222+
name: str,
223+
description: str,
224+
config: Dict[str, Any],
225+
with_commit: bool = True,
226+
) -> StepTemplates:
227+
template: StepTemplates = StepTemplates(
228+
organization_id=org_id,
229+
name=name,
230+
description=description,
231+
created_by=user_id,
232+
config=config,
233+
)
234+
general.add(template, with_commit)
235+
236+
return template
237+
238+
239+
def update(
240+
org_id: str,
241+
template_id: str,
242+
name: Optional[str] = None,
243+
description: Optional[str] = None,
244+
config: Optional[Dict[str, Any]] = None,
245+
with_commit: bool = True,
246+
) -> StepTemplates:
247+
template = get(org_id, template_id)
248+
if not template:
249+
raise ValueError(
250+
f"Template with ID {template_id} not found in organization {org_id}."
251+
)
252+
253+
if name is not None:
254+
template.name = name
255+
if description is not None:
256+
template.description = description
257+
if config is not None:
258+
template.config = config
259+
flag_modified(template, "config")
260+
261+
general.flush_or_commit(with_commit)
262+
263+
return template
264+
265+
266+
def delete(org_id: str, template_id: str, with_commit: bool = True) -> None:
267+
session.query(StepTemplates).filter(
268+
StepTemplates.organization_id == org_id,
269+
StepTemplates.id == template_id,
270+
).delete()
271+
general.flush_or_commit(with_commit)
272+
273+
274+
def delete_many(
275+
org_id: str, template_ids: Iterable[str], with_commit: bool = True
276+
) -> None:
277+
session.query(StepTemplates).filter(
278+
StepTemplates.organization_id == org_id,
279+
StepTemplates.id.in_(template_ids),
280+
).delete()
281+
general.flush_or_commit(with_commit)

cognition_objects/strategy_step.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
from typing import List, Optional, Dict, Any
1+
from typing import List, Optional, Dict, Any, Iterable, Tuple
22
from datetime import datetime
33
from sqlalchemy.orm.attributes import flag_modified
4+
from sqlalchemy import tuple_
5+
46

57
from ..business_objects import general
68
from ..session import session
@@ -19,6 +21,23 @@ def get(project_id: str, strategy_step_id: str) -> CognitionStrategyStep:
1921
)
2022

2123

24+
def get_all_by_project_and_ids(
25+
project_step_tuple: Iterable[Tuple[str, str]],
26+
) -> List[CognitionStrategyStep]:
27+
if not project_step_tuple:
28+
return []
29+
30+
return (
31+
session.query(CognitionStrategyStep)
32+
.filter(
33+
tuple_(CognitionStrategyStep.project_id, CognitionStrategyStep.id).in_(
34+
project_step_tuple
35+
)
36+
)
37+
.all()
38+
)
39+
40+
2241
def get_all_by_strategy_id(
2342
project_id: str, strategy_id: str
2443
) -> List[CognitionStrategyStep]:

enums.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ class Tablenames(Enum):
155155
EVALUATION_RUN = "evaluation_run"
156156
PLAYGROUND_QUESTION = "playground_question"
157157
FULL_ADMIN_ACCESS = "full_admin_access"
158+
STEP_TEMPLATES = "step_templates" # templates for strategy steps
158159

159160
def snake_case_to_pascal_case(self):
160161
# the type name (written in PascalCase) of a table is needed to create backrefs
@@ -543,6 +544,7 @@ class StrategyStepType(Enum):
543544
NEURAL_SEARCH = "NEURAL_SEARCH"
544545
WEBHOOK = "WEBHOOK"
545546
GRAPHRAG_SEARCH = "GRAPHRAG_SEARCH"
547+
TEMPLATED = "TEMPLATED"
546548

547549
def get_description(self):
548550
return STEP_DESCRIPTIONS.get(self, "No description available")
@@ -570,6 +572,7 @@ def get_progress_text(self):
570572
StrategyStepType.CALL_OTHER_AGENT: "Retrieve results from other agents",
571573
StrategyStepType.WEBHOOK: "Webhook",
572574
StrategyStepType.GRAPHRAG_SEARCH: "Query GraphRAG index",
575+
StrategyStepType.TEMPLATED: "Templated step",
573576
}
574577

575578
STEP_WHEN_TO_USE = {
@@ -587,6 +590,7 @@ def get_progress_text(self):
587590
StrategyStepType.CALL_OTHER_AGENT: "When you want to call another agent",
588591
StrategyStepType.WEBHOOK: "When you want to run a webhook",
589592
StrategyStepType.GRAPHRAG_SEARCH: "When you want to query a knowledge graph",
593+
StrategyStepType.TEMPLATED: "When you want to reuse existing templates",
590594
}
591595

592596
STEP_PROGRESS_TEXTS = {
@@ -605,6 +609,7 @@ def get_progress_text(self):
605609
StrategyStepType.CALL_OTHER_AGENT: "Calling another agent",
606610
StrategyStepType.WEBHOOK: "Running webhook",
607611
StrategyStepType.GRAPHRAG_SEARCH: "Querying knowledge graph",
612+
StrategyStepType.TEMPLATED: "Running templated step",
608613
}
609614

610615
STEP_ERRORS = {

models.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1933,6 +1933,35 @@ class GraphRAGIndex(Base):
19331933
root_dir = Column(String)
19341934

19351935

1936+
class StepTemplates(Base):
1937+
__tablename__ = Tablenames.STEP_TEMPLATES.value
1938+
__table_args__ = {"schema": "cognition"}
1939+
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
1940+
organization_id = Column(
1941+
UUID(as_uuid=True),
1942+
ForeignKey(f"{Tablenames.ORGANIZATION.value}.id", ondelete="CASCADE"),
1943+
index=True,
1944+
)
1945+
name = Column(String)
1946+
description = Column(String)
1947+
created_at = Column(DateTime, default=sql.func.now())
1948+
created_by = Column(
1949+
UUID(as_uuid=True),
1950+
ForeignKey(f"{Tablenames.USER.value}.id", ondelete="SET NULL"),
1951+
index=True,
1952+
)
1953+
config = Column(JSON) # JSON schema for the step template
1954+
# config contains all step configurations in an array & variable fields to be changed on useage
1955+
# e.g.
1956+
# {
1957+
# "variables": [
1958+
# {"name": "Env var", "path": "[0].config.llmConfig.environmentVariable", "hasDefault": True, "defaultValue": "OpenAI Leo"},
1959+
# {"name": "System Prompt", "path": "[0].config.templatePrompt", "hasDefault": False},
1960+
# ],
1961+
# "steps": [{...},{...}]
1962+
# }
1963+
1964+
19361965
# =========================== Global tables ===========================
19371966
class GlobalWebsocketAccess(Base):
19381967
# table to store prepared websocket configuration.

0 commit comments

Comments
 (0)