diff --git a/business_objects/embedding.py b/business_objects/embedding.py index 1e02d222..480b5006 100644 --- a/business_objects/embedding.py +++ b/business_objects/embedding.py @@ -321,8 +321,11 @@ def __build_payload_selector( if ( data_type != enums.DataTypes.TEXT.value and data_type != enums.DataTypes.LLM_RESPONSE.value + and data_type != enums.DataTypes.PERMISSION.value ): payload_selector += f"'{attr}', (r.\"data\"->>'{attr}')::{data_type}" + if data_type == enums.DataTypes.PERMISSION.value: + payload_selector += f"'{attr}', r.\"data\"->'{attr}'" else: payload_selector += f"'{attr}', r.\"data\"->>'{attr}'" payload_selector = f"json_build_object({payload_selector}) payload" @@ -391,7 +394,8 @@ def get_tensors_and_attributes_for_qdrant( WHERE et.project_id = '{project_id}' AND et.embedding_id = '{embedding_id}' """ if record_ids: - query += f" AND r.id IN ('{','.join(record_ids)}')" + _record_ids = "','".join(record_ids) + query += f" AND r.id IN ('{_record_ids}')" return general.execute_all(query) diff --git a/business_objects/monitor.py b/business_objects/monitor.py index 6baaf5f4..540741a8 100644 --- a/business_objects/monitor.py +++ b/business_objects/monitor.py @@ -1,4 +1,5 @@ from typing import Any, List, Optional +import datetime from . import general from .. 
import enums from ..models import TaskQueue, Organization @@ -9,6 +10,7 @@ markdown_file as markdown_file_db_bo, file_extraction as file_extraction_db_bo, file_transformation as file_transformation_db_bo, + integration as integration_db_bo, ) FILE_CACHING_IN_PROGRESS_STATES = [ @@ -197,6 +199,27 @@ def set_parse_cognition_file_task_to_failed( general.commit() +def set_integration_task_to_failed( + integration_id: str, + is_synced: bool = False, + error_message: Optional[str] = None, + state: Optional[ + enums.CognitionMarkdownFileState + ] = enums.CognitionMarkdownFileState.FAILED, + with_commit: bool = True, +) -> None: + # argument `state` is a workaround for cognition-gateway/api/routes/integrations.delete_many + integration_db_bo.update( + id=integration_id, + state=state, + finished_at=datetime.datetime.now(datetime.timezone.utc), + is_synced=is_synced, + error_message=error_message, + last_synced_at=datetime.datetime.now(datetime.timezone.utc), + with_commit=with_commit, + ) + + def __select_running_information_source_payloads( project_id: Optional[str] = None, only_running: bool = False, diff --git a/business_objects/project.py b/business_objects/project.py index d3d192b0..26cc8d7a 100644 --- a/business_objects/project.py +++ b/business_objects/project.py @@ -1,16 +1,14 @@ from typing import List, Optional, Any, Dict, Union, Set -from sqlalchemy.sql import func -from sqlalchemy import cast, Integer +from sqlalchemy.sql import func, cast from sqlalchemy.sql.functions import coalesce - - +from sqlalchemy import Integer from . import general, attribute - from .. 
import enums from ..session import session -from ..models import ( - Project, - Record, +from ..models import Project, Record, Attribute +from ..integration_objects.helper import ( + REFINERY_ATTRIBUTE_ACCESS_GROUPS, + REFINERY_ATTRIBUTE_ACCESS_USERS, ) from ..util import prevent_sql_injection @@ -118,12 +116,18 @@ def __build_sql_data_slices_by_project(project_id: str) -> str: project.id = '{project_id}'::UUID; """ -def get_dropdown_list_project_list(org_id: str) -> List[Dict[str, str]]: +def get_dropdown_list_project_list( + org_id: str, project_id: Optional[str] = None +) -> List[Dict[str, str]]: org_id = prevent_sql_injection(org_id, isinstance(org_id, str)) + prj_filter = "" + if project_id: + project_id = prevent_sql_injection(project_id, isinstance(project_id, str)) + prj_filter = f"AND p.id = '{project_id}'" query = f""" SELECT array_agg(jsonb_build_object('value', p.id,'name',p.NAME)) FROM public.project p - WHERE p.organization_id = '{org_id}' AND p.status != '{enums.ProjectStatus.HIDDEN.value}' + WHERE p.organization_id = '{org_id}' AND p.status != '{enums.ProjectStatus.HIDDEN.value}' {prj_filter} """ values = general.execute_first(query) @@ -155,6 +159,56 @@ def get_all(organization_id: str) -> List[Project]: ) +def get_all_with_access_management(org_id: str) -> List[Dict[str, Any]]: + org_id_safe = prevent_sql_injection(org_id, isinstance(org_id, str)) + + hidden_status = enums.ProjectStatus.HIDDEN.value + permission_data_type = enums.DataTypes.PERMISSION.value + automatically_created_state = enums.AttributeState.AUTOMATICALLY_CREATED.value + access_groups_attr = REFINERY_ATTRIBUTE_ACCESS_GROUPS + access_users_attr = REFINERY_ATTRIBUTE_ACCESS_USERS + + query = f""" + SELECT DISTINCT + p.*, + COALESCE((ci.config -> 'extract_kwargs' ->> 'sync_sharepoint_permissions')::BOOLEAN,FALSE) AS is_sharepoint_sync_active + FROM + public.project p + JOIN + public.attribute a ON p.id = a.project_id + LEFT JOIN + cognition.integration ci ON p.id = ci.project_id + 
WHERE + p.organization_id = '{org_id_safe}' + AND p.status != '{hidden_status}' + AND a.name IN ('{access_groups_attr}', '{access_users_attr}') + AND a.user_created = FALSE + AND a.data_type = '{permission_data_type}' + AND a.state = '{automatically_created_state}'; + """ + + values = general.execute_all(query) + return values + + +def check_access_management_active(project_id: str) -> bool: + return ( + session.query(Project) + .join(Attribute, Project.id == Attribute.project_id) + .filter( + Project.id == project_id, + Attribute.name.in_( + [REFINERY_ATTRIBUTE_ACCESS_GROUPS, REFINERY_ATTRIBUTE_ACCESS_USERS] + ), + Attribute.user_created == False, + Attribute.data_type == enums.DataTypes.PERMISSION.value, + Attribute.state == enums.AttributeState.AUTOMATICALLY_CREATED.value, + ) + .count() + > 0 + ) + + def get_all_by_user_organization_id(organization_id: str) -> List[Project]: projects = ( session.query(Project).filter(Project.organization_id == organization_id).all() diff --git a/business_objects/record.py b/business_objects/record.py index 7d51a5d5..36572292 100644 --- a/business_objects/record.py +++ b/business_objects/record.py @@ -1,6 +1,6 @@ from __future__ import with_statement from typing import List, Dict, Any, Optional, Tuple, Iterable -from sqlalchemy import cast, Text +from sqlalchemy import cast, Text, String from sqlalchemy.orm.attributes import flag_modified from sqlalchemy.sql.expression import bindparam from sqlalchemy import update @@ -15,6 +15,10 @@ Attribute, RecordTokenized, ) +from ..integration_objects.helper import ( + REFINERY_ATTRIBUTE_ACCESS_GROUPS, + REFINERY_ATTRIBUTE_ACCESS_USERS, +) from ..session import session from ..util import prevent_sql_injection @@ -609,7 +613,7 @@ def count_missing_tokenized_records(project_id: str) -> int: query = f""" SELECT COUNT(*) FROM ( - {get_records_without_tokenization(project_id, None, query_only = True)} + {get_records_without_tokenization(project_id, None, query_only=True)} ) record_query """ 
return general.execute_first(query)[0] @@ -807,6 +811,29 @@ def delete_user_created_attribute( general.flush_or_commit(with_commit) +def delete_access_management_attributes( + project_id: str, with_commit: bool = True +) -> None: + access_groups_attribute_item = attribute.get_by_name( + project_id, REFINERY_ATTRIBUTE_ACCESS_GROUPS + ) + access_users_attribute_item = attribute.get_by_name( + project_id, REFINERY_ATTRIBUTE_ACCESS_USERS + ) + + if access_users_attribute_item and access_groups_attribute_item: + record_items = get_all(project_id=project_id) + for i, record_item in enumerate(record_items): + if record_item.data.get(access_groups_attribute_item.name): + del record_item.data[access_groups_attribute_item.name] + if record_item.data.get(access_users_attribute_item.name): + del record_item.data[access_users_attribute_item.name] + flag_modified(record_item, "data") + if (i + 1) % 1000 == 0: + general.flush_or_commit(with_commit) + general.flush_or_commit(with_commit) + + def delete_duplicated_rats(with_commit: bool = False) -> None: # no project so run for all to prevent expensive join with record table query = """ @@ -925,3 +952,19 @@ def get_first_no_text_column(project_id: str, record_id: str) -> str: WHERE r.project_id = '{project_id}' AND r.id = '{record_id}' """ return general.execute_first(query)[0] + + +def get_record_ids_by_running_ids(project_id: str, running_ids: List[int]) -> List[str]: + return [ + row[0] + for row in ( + session.query(cast(Record.id, String)) + .filter( + Record.project_id == project_id, + Record.data[attribute.get_running_id_name(project_id)] + .as_integer() + .in_(running_ids), + ) + .all() + ) + ] diff --git a/cognition_objects/group.py b/cognition_objects/group.py new file mode 100644 index 00000000..75dac238 --- /dev/null +++ b/cognition_objects/group.py @@ -0,0 +1,124 @@ +from datetime import datetime +from typing import List, Optional +from ..business_objects import general +from ..session import session +from ..models 
import CognitionGroup + + +def get(group_id: str) -> CognitionGroup: + return session.query(CognitionGroup).filter(CognitionGroup.id == group_id).first() + + +def get_with_organization_id(organization_id: str, group_id: str) -> CognitionGroup: + return ( + session.query(CognitionGroup) + .filter( + CognitionGroup.organization_id == organization_id, + CognitionGroup.id == group_id, + ) + .first() + ) + + +def get_all(organization_id: str) -> List[CognitionGroup]: + return ( + session.query(CognitionGroup) + .filter(CognitionGroup.organization_id == organization_id) + .order_by(CognitionGroup.name.asc()) + .all() + ) + + +def get_all_by_integration_id( + organization_id: str, integration_id: str +) -> List[CognitionGroup]: + integration_id_json = CognitionGroup.meta_data.op("->>")("integration_id") + + return ( + session.query(CognitionGroup) + .filter( + CognitionGroup.organization_id == organization_id, + integration_id_json == integration_id, + ) + .order_by(CognitionGroup.name.asc()) + .all() + ) + + +def get_all_by_integration_id_permission_grouped( + organization_id: str, integration_id: str +) -> List[CognitionGroup]: + integration_id_json = CognitionGroup.meta_data.op("->>")("integration_id") + + integration_groups = ( + session.query(CognitionGroup) + .filter( + CognitionGroup.organization_id == organization_id, + integration_id_json == integration_id, + ) + .all() + ) + + return {group.meta_data.get("permission_id"): group for group in integration_groups} + + +def get_by_name_and_integration( + organization_id: str, integration_id: str, name: str +) -> CognitionGroup: + integration_id_json = CognitionGroup.meta_data.op("->>")("integration_id") + return ( + session.query(CognitionGroup) + .filter( + CognitionGroup.organization_id == organization_id, + CognitionGroup.name == name, + integration_id_json == integration_id, + ) + .first() + ) + + +def create_group( + organization_id: str, + name: str, + description: str, + created_by: str, + created_at: 
Optional[datetime] = None, + with_commit: bool = True, + meta_data: Optional[dict] = None, +) -> CognitionGroup: + group = CognitionGroup( + organization_id=organization_id, + name=name, + description=description, + created_by=created_by, + created_at=created_at, + meta_data=meta_data, + ) + general.add(group, with_commit) + return group + + +def update_group( + group_id: str, + name: Optional[str] = None, + description: Optional[str] = None, + with_commit: bool = True, + meta_data: Optional[dict] = None, +) -> CognitionGroup: + group = get(group_id) + + if name is not None: + group.name = name + if description is not None: + group.description = description + if meta_data is not None: + group.meta_data = meta_data + general.flush_or_commit(with_commit) + + return group + + +def delete(organization_id: str, group_id: str, with_commit: bool = True) -> None: + group = get_with_organization_id(organization_id, group_id) + if group: + general.delete(group, with_commit) diff --git a/cognition_objects/group_member.py b/cognition_objects/group_member.py new file mode 100644 index 00000000..4f5661b4 --- /dev/null +++ b/cognition_objects/group_member.py @@ -0,0 +1,103 @@ +from datetime import datetime +from typing import Optional, List +from ..business_objects import general, user +from . 
import group +from ..session import session +from ..models import CognitionGroupMember + + +def get(group_id: str, id: str) -> CognitionGroupMember: + return ( + session.query(CognitionGroupMember) + .filter( + CognitionGroupMember.group_id == group_id, CognitionGroupMember.id == id + ) + .first() + ) + + +def get_by_group_and_user(group_id: str, user_id: str) -> CognitionGroupMember: + return ( + session.query(CognitionGroupMember) + .filter( + CognitionGroupMember.group_id == group_id, + CognitionGroupMember.user_id == user_id, + ) + .first() + ) + + +def get_by_user_id(user_id: str) -> List[CognitionGroupMember]: + return ( + session.query(CognitionGroupMember) + .filter(CognitionGroupMember.user_id == user_id) + .all() + ) + + +def get_all_by_group(group_id: str) -> List[CognitionGroupMember]: + return ( + session.query(CognitionGroupMember) + .filter(CognitionGroupMember.group_id == group_id) + .all() + ) + + +def get_all_by_group_count(group_id: str) -> int: + return ( + session.query(CognitionGroupMember) + .filter(CognitionGroupMember.group_id == group_id) + .count() + ) + + +def create( + group_id: str, + user_id: str, + created_at: Optional[datetime] = None, + with_commit: bool = True, +) -> CognitionGroupMember: + already_exist = get_by_group_and_user(group_id=group_id, user_id=user_id) + if already_exist: + return already_exist + + group_item = group.get(group_id) + user_item = user.get(user_id) + if not group_item or not user_item: + raise Exception("Group or user not found") + if group_item.organization_id != user_item.organization_id: + raise Exception("User not in the same organization as the group") + + group_member = CognitionGroupMember( + group_id=group_id, + user_id=user_id, + created_at=created_at, + ) + general.add(group_member, with_commit) + return group_member + + +def delete_by_group_and_user_id( + group_id: str, user_id: str, with_commit: bool = True +) -> None: + group_member = get_by_group_and_user(group_id, user_id) + if group_member: 
+ general.delete(group_member, with_commit) + + +def delete_by_user_id(user_id: str, with_commit: bool = True) -> None: + group_memberships = ( + session.query(CognitionGroupMember) + .filter(CognitionGroupMember.user_id == user_id) + .all() + ) + for membership in group_memberships: + general.delete(membership, with_commit=False) + general.flush_or_commit(with_commit) + + +def clear_by_group_id(group_id: str, with_commit: bool = True) -> None: + group_memberships = get_all_by_group(group_id) + for membership in group_memberships: + general.delete(membership, with_commit=False) + general.flush_or_commit(with_commit) diff --git a/cognition_objects/integration.py b/cognition_objects/integration.py new file mode 100644 index 00000000..7da1e0b6 --- /dev/null +++ b/cognition_objects/integration.py @@ -0,0 +1,255 @@ +from typing import List, Optional, Dict, Union +import datetime +from sqlalchemy import func +from sqlalchemy.orm.attributes import flag_modified + +from ..business_objects import general +from ..session import session +from ..models import CognitionIntegration, CognitionGroup +from ..enums import ( + CognitionMarkdownFileState, + CognitionIntegrationType, +) + +FINISHED_STATES = [ + CognitionMarkdownFileState.FINISHED.value, + CognitionMarkdownFileState.FAILED.value, +] + + +def get_by_ids(ids: List[str]) -> List[CognitionIntegration]: + return ( + session.query(CognitionIntegration) + .filter(CognitionIntegration.id.in_(ids)) + .all() + ) + + +def get_by_id(id: str) -> CognitionIntegration: + return ( + session.query(CognitionIntegration) + .filter(CognitionIntegration.id == id) + .first() + ) + + +def get_all( + integration_type: Optional[str] = None, + exclude_failed: bool = False, + only_synced: bool = False, +) -> List[CognitionIntegration]: + query = session.query(CognitionIntegration) + if integration_type: + query = query.filter(CognitionIntegration.type == integration_type) + if exclude_failed: + query = query.filter( + CognitionIntegration.state 
!= CognitionMarkdownFileState.FAILED.value + ) + if only_synced: + query = query.filter(CognitionIntegration.is_synced == True) + return query.order_by(CognitionIntegration.created_at.desc()).all() + + +def get_all_in_org( + org_id: str, + integration_type: Optional[str] = None, + only_synced: bool = False, +) -> List[CognitionIntegration]: + query = session.query(CognitionIntegration).filter( + CognitionIntegration.organization_id == org_id + ) + if integration_type: + query = query.filter(CognitionIntegration.type == integration_type) + if only_synced: + query = query.filter(CognitionIntegration.is_synced == True) + return query.order_by(CognitionIntegration.created_at.desc()).all() + + +def get_all_in_org_paginated( + org_id: str, + integration_type: Optional[str] = None, + page: int = 1, + page_size: int = 10, +) -> List[CognitionIntegration]: + query = session.query(CognitionIntegration).filter( + CognitionIntegration.organization_id == org_id, + ) + + if integration_type: + query = query.filter(CognitionIntegration.type == integration_type) + + return ( + query.order_by(CognitionIntegration.created_at.desc()) + .limit(page_size) + .offset(max(0, (page - 1) * page_size)) + .all() + ) + + +def get_all_by_project_id(project_id: str) -> List[CognitionIntegration]: + return ( + session.query(CognitionIntegration) + .filter( + CognitionIntegration.project_id == project_id, + ) + .order_by(CognitionIntegration.created_at.desc()) + .all() + ) + + +def get_last_synced_at( + org_id: str, integration_type: Optional[str] = None +) -> datetime.datetime: + query = session.query(func.max(CognitionIntegration.last_synced_at)).filter( + CognitionIntegration.organization_id == org_id + ) + if integration_type: + query = query.filter(CognitionIntegration.type == integration_type) + result = query.first() + return result[0] if result else None + + +def count_org_integrations(org_id: str) -> Dict[str, int]: + counts = ( + session.query(CognitionIntegration.type, 
func.count(CognitionIntegration.id)) + .filter( + CognitionIntegration.organization_id == org_id, + ) + .group_by(CognitionIntegration.type) + .all() + ) + return {cognition_type: count for cognition_type, count in counts} + + +def create( + org_id: str, + user_id: str, + name: str, + description: str, + tokenizer: str, + state: str, + integration_type: CognitionIntegrationType, + integration_config: Dict, + llm_config: Dict, + started_at: Optional[datetime.datetime] = None, + created_at: Optional[datetime.datetime] = None, + finished_at: Optional[datetime.datetime] = None, + id: Optional[str] = None, + project_id: Optional[str] = None, + with_commit: bool = True, +) -> CognitionIntegration: + integration: CognitionIntegration = CognitionIntegration( + id=id, + organization_id=org_id, + project_id=project_id, + created_by=user_id, + updated_by=user_id, + created_at=created_at, + started_at=started_at, + finished_at=finished_at, + name=name, + description=description, + tokenizer=tokenizer, + state=state, + type=integration_type.value, + config=integration_config, + llm_config=llm_config, + delta_criteria={"delta_url": None}, + ) + general.add(integration, with_commit) + + return integration + + +def update( + id: str, + updated_by: Optional[str] = None, + name: Optional[str] = None, + description: Optional[str] = None, + tokenizer: Optional[str] = None, + state: Optional[CognitionMarkdownFileState] = None, + integration_config: Optional[int] = None, + llm_config: Optional[Dict] = None, + error_message: Optional[str] = None, + started_at: Optional[datetime.datetime] = None, + finished_at: Optional[Union[str, datetime.datetime]] = None, + last_synced_at: Optional[datetime.datetime] = None, + is_synced: Optional[Union[str, bool]] = None, + delta_criteria: Optional[Dict[str, str]] = None, + with_commit: bool = True, +) -> Optional[CognitionIntegration]: + integration: CognitionIntegration = get_by_id(id) + if not integration: + return None + + if updated_by is not 
None: + integration.updated_by = updated_by + if name is not None: + integration.name = name + if description is not None: + integration.description = description + if tokenizer is not None: + integration.tokenizer = tokenizer + if state is not None: + integration.state = state.value + if integration_config is not None: + integration.config = integration_config + flag_modified(integration, "config") + if llm_config is not None: + integration.llm_config = llm_config + flag_modified(integration, "llm_config") + if started_at is not None: + integration.started_at = started_at + if last_synced_at is not None: + integration.last_synced_at = last_synced_at + if delta_criteria is not None: + integration.delta_criteria = delta_criteria + flag_modified(integration, "delta_criteria") + if error_message is not None: + if error_message == "NULL": + integration.error_message = None + else: + integration.error_message = error_message + if is_synced is not None: + if is_synced == "NULL": + integration.is_synced = None + else: + integration.is_synced = is_synced + if finished_at is not None: + if finished_at == "NULL": + integration.finished_at = None + else: + integration.finished_at = finished_at + + general.add(integration, with_commit) + return integration + + +def execution_finished(id: str) -> bool: + if not get_by_id(id): + return True + return bool( + session.query(CognitionIntegration) + .filter( + CognitionIntegration.id == id, + CognitionIntegration.state.in_(FINISHED_STATES), + ) + .first() + ) + + +def delete_many( + ids: List[str], delete_cognition_groups: bool = True, with_commit: bool = True +) -> None: + ( + session.query(CognitionIntegration) + .filter(CognitionIntegration.id.in_(ids)) + .delete(synchronize_session=False) + ) + if delete_cognition_groups: + ( + session.query(CognitionGroup) + .filter(CognitionGroup.meta_data.op("->>")("integration_id").in_(ids)) + .delete(synchronize_session=False) + ) + general.flush_or_commit(with_commit) diff --git 
a/cognition_objects/integration_access.py b/cognition_objects/integration_access.py new file mode 100644 index 00000000..548e3e71 --- /dev/null +++ b/cognition_objects/integration_access.py @@ -0,0 +1,88 @@ +from typing import List, Optional +from datetime import datetime + +from ..business_objects import general +from ..session import session +from ..models import CognitionIntegrationAccess +from ..enums import CognitionIntegrationType + + +def get_by_id(id: str) -> CognitionIntegrationAccess: + return ( + session.query(CognitionIntegrationAccess) + .filter(CognitionIntegrationAccess.id == id) + .first() + ) + + +def get_by_org_id(org_id: str) -> List[CognitionIntegrationAccess]: + return ( + session.query(CognitionIntegrationAccess) + .filter(CognitionIntegrationAccess.organization_id == org_id) + .all() + ) + + +def get( + org_id: str, integration_type: Optional[CognitionIntegrationType] = None +) -> List[CognitionIntegrationAccess]: + query = session.query(CognitionIntegrationAccess).filter( + CognitionIntegrationAccess.organization_id == org_id, + ) + if integration_type: + query = query.filter( + CognitionIntegrationAccess.integration_type == integration_type.value + ) + return query.order_by(CognitionIntegrationAccess.created_at.asc()).all() + + +def get_all() -> List[CognitionIntegrationAccess]: + return ( + session.query(CognitionIntegrationAccess) + .order_by(CognitionIntegrationAccess.created_at.desc()) + .all() + ) + + +def create( + org_id: str, + user_id: str, + integration_types: List[CognitionIntegrationType], + with_commit: bool = True, + created_at: Optional[datetime] = None, +) -> CognitionIntegrationAccess: + integration_access: CognitionIntegrationAccess = CognitionIntegrationAccess( + organization_id=org_id, + created_by=user_id, + created_at=created_at, + integration_types=[ + integration_type.value for integration_type in integration_types + ], + ) + general.add(integration_access, with_commit) + + return integration_access + + +def update( 
+ id: str, + org_id: Optional[str] = None, + integration_types: Optional[List[CognitionIntegrationType]] = None, + with_commit: bool = True, +) -> CognitionIntegrationAccess: + integration_access = get_by_id(id) + if org_id: + integration_access.organization_id = org_id + if integration_types: + integration_access.integration_types = [ + integration_type.value for integration_type in integration_types + ] + general.add(integration_access, with_commit) + return integration_access + + +def delete(id: str, with_commit: bool = True) -> None: + session.query(CognitionIntegrationAccess).filter( + CognitionIntegrationAccess.id == id + ).delete() + general.flush_or_commit(with_commit) diff --git a/enums.py b/enums.py index 49a740de..b9a9461d 100644 --- a/enums.py +++ b/enums.py @@ -10,6 +10,7 @@ class DataTypes(Enum): TEXT = "TEXT" LLM_RESPONSE = "LLM_RESPONSE" EMBEDDING_LIST = "EMBEDDING_LIST" # only for embeddings & default hidden + PERMISSION = "PERMISSION" # used for access control UNKNOWN = "UNKNOWN" @@ -155,6 +156,17 @@ class Tablenames(Enum): EVALUATION_RUN = "evaluation_run" PLAYGROUND_QUESTION = "playground_question" FULL_ADMIN_ACCESS = "full_admin_access" + GROUP = "group" # used for group based access control + GROUP_MEMBER = "group_member" # used for group based access control + PERMISSION = "permission" # used for access control + INTEGRATION = "integration" + INTEGRATION_ACCESS = "integration_access" + + # Individial integrations + INTEGRATION_GITHUB_FILE = "github_file" + INTEGRATION_GITHUB_ISSUE = "github_issue" + INTEGRATION_PDF = "pdf" + INTEGRATION_SHAREPOINT = "sharepoint" STEP_TEMPLATES = "step_templates" # templates for strategy steps def snake_case_to_pascal_case(self): @@ -494,6 +506,7 @@ class TaskType(Enum): TASK_QUEUE_ACTION = "task_queue_action" RUN_COGNITION_MACRO = "RUN_COGNITION_MACRO" PARSE_COGNITION_FILE = "PARSE_COGNITION_FILE" + EXECUTE_INTEGRATION = "EXECUTE_INTEGRATION" class TaskQueueAction(Enum): @@ -501,6 +514,7 @@ class 
TaskQueueAction(Enum): SEND_WEBSOCKET = "SEND_WEBSOCKET" FINISH_COGNITION_SETUP = "FINISH_COGNITION_SETUP" RUN_WEAK_SUPERVISION = "RUN_WEAK_SUPERVISION" + POSTPROCESS_INTEGRATION = "POSTPROCESS_INTEGRATION" class AgreementType(Enum): @@ -682,6 +696,10 @@ class CognitionMarkdownFileState(Enum): FINISHED = "FINISHED" FAILED = "FAILED" + @classmethod + def all(cls): + return [e.value for e in cls] + class CognitionInterfaceType(Enum): CHAT = "CHAT" @@ -890,3 +908,20 @@ class AdminQueries(Enum): "AVG_MESSAGES_PER_CONVERSATION_GLOBAL" # parameter options: organization_id ) AVG_MESSAGES_PER_CONVERSATION = "AVG_MESSAGES_PER_CONVERSATION" # parameter options: period (days, weeks or months), slices, organization_id + + +class CognitionIntegrationType(Enum): + SHAREPOINT = "SHAREPOINT" + GITHUB_FILE = "GITHUB_FILE" + GITHUB_ISSUE = "GITHUB_ISSUE" + PDF = "PDF" + + @staticmethod + def from_string(value: str): + changed_value = value.upper().replace(" ", "_").replace("-", "_") + try: + return CognitionIntegrationType[changed_value] + except KeyError: + raise KeyError( + f"Could not parse CognitionIntegrationType from string '{changed_value}'" + ) diff --git a/integration_objects/__init__.py b/integration_objects/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/integration_objects/helper.py b/integration_objects/helper.py new file mode 100644 index 00000000..5ac51b73 --- /dev/null +++ b/integration_objects/helper.py @@ -0,0 +1,58 @@ +from typing import Set + +from ..enums import Tablenames + + +REFINERY_ATTRIBUTE_ACCESS_GROUPS = "" +REFINERY_ATTRIBUTE_ACCESS_USERS = "" + +DEFAULT_METADATA = {"source", "minio_file_name", "running_id"} +TABLE_METADATA = { + Tablenames.INTEGRATION_PDF.value: {"file_path", "page", "total_pages", "title"}, + Tablenames.INTEGRATION_GITHUB_FILE.value: {"path", "sha", "code_language"}, + Tablenames.INTEGRATION_GITHUB_ISSUE.value: { + "url", + "state", + "number", + "assignee", + "milestone", + }, + 
Tablenames.INTEGRATION_SHAREPOINT.value: { + "extension", + "object_id", + "parent_path", + "name", + "web_url", + f"{Tablenames.INTEGRATION_SHAREPOINT.value}_created_by", + "modified_by", + "created", + "modified", + "description", + "size", + "mime_type", + "hashes", + "permissions", + "file_properties", + }, +} + + +def get_supported_metadata_keys(table_name: str) -> Set[str]: + """ + Function for controlling and documenting the dynamic metadata fields associated with different integration types. + + The `TABLE_METADATA` dictionary defines which metadata keys are expected and allowed for each integration table + (e.g., `integration.sharepoint`, `integration.github_file`). Each value contains a set of keys specific to that integration, while + the `DEFAULT_METADATA` (`source`, `running_id`, `minio_file_name`) are always included. + + During extraction, metadata is dynamically attached to each document according to the rules defined here. + + This function is used by the integration object logic (see `src.util.integration #delta_load`) to + validate and filter metadata before persisting it, ensuring consistency and preventing unwanted fields from being + stored in the database. 
+ + Example: + get_supported_metadata("pdf") + # returns: {"source", "minio_file_name", "running_id", "file_path", "page", "total_pages", "title"} + """ + return DEFAULT_METADATA.union(TABLE_METADATA.get(table_name, set())) diff --git a/integration_objects/manager.py b/integration_objects/manager.py new file mode 100644 index 00000000..7979678f --- /dev/null +++ b/integration_objects/manager.py @@ -0,0 +1,235 @@ +from typing import List, Optional, Dict, Union, Type, Any +from datetime import datetime +from sqlalchemy import func +from sqlalchemy.orm.attributes import flag_modified + +from ..business_objects import general +from ..cognition_objects import integration as integration_db_bo +from ..session import session +from .helper import get_supported_metadata_keys + + +def get( + IntegrationModel: Type, + integration_id: str, + id: Optional[str] = None, +) -> Union[List[object], object]: + query = session.query(IntegrationModel).filter( + IntegrationModel.integration_id == integration_id, + ) + if id is not None: + query = query.filter(IntegrationModel.id == id) + return query.first() + return query.order_by(IntegrationModel.created_at.desc()).all() + + +def get_by_id( + IntegrationModel: Type, + id: str, +) -> object: + return session.query(IntegrationModel).filter(IntegrationModel.id == id).first() + + +def get_by_running_id( + IntegrationModel: Type, + integration_id: str, + running_id: int, +) -> object: + return ( + session.query(IntegrationModel) + .filter( + IntegrationModel.integration_id == integration_id, + IntegrationModel.running_id == running_id, + ) + .first() + ) + + +def get_by_source( + IntegrationModel: Type, + integration_id: str, + source: str, +) -> object: + return ( + session.query(IntegrationModel) + .filter( + IntegrationModel.integration_id == integration_id, + IntegrationModel.source == source, + ) + .first() + ) + + +def get_all_by_integration_id( + IntegrationModel: Type, + integration_id: str, +) -> List[object]: + return ( + 
session.query(IntegrationModel) + .filter(IntegrationModel.integration_id == integration_id) + .order_by(IntegrationModel.created_at) + .all() + ) + + +def get_all_by_project_id( + IntegrationModel: Type, + project_id: str, +) -> List[object]: + integrations = integration_db_bo.get_all_by_project_id(project_id) + return ( + session.query(IntegrationModel) + .filter( + IntegrationModel.integration_id.in_([i.id for i in integrations]), + ) + .order_by(IntegrationModel.created_at.asc()) + .all() + ) + + +def get_existing_integration_records( + IntegrationModel: Type, + integration_id: str, + by: str = "source", +) -> Dict[str, object]: + return { + getattr(record, by, record.source): record + for record in get_all_by_integration_id(IntegrationModel, integration_id) + } + + +def get_running_ids( + IntegrationModel: Type, + integration_id: str, + by: str = "source", +) -> Dict[str, int]: + return dict( + session.query( + getattr(IntegrationModel, by, IntegrationModel.source), + func.coalesce(func.max(IntegrationModel.running_id), 0), + ) + .filter(IntegrationModel.integration_id == integration_id) + .group_by(getattr(IntegrationModel, by, IntegrationModel.source)) + .all() + ) + + +def create( + IntegrationModel: Type, + created_by: str, + integration_id: str, + running_id: int, + created_at: Optional[datetime] = None, + error_message: Optional[str] = None, + id: Optional[str] = None, + with_commit: bool = True, + **metadata, +) -> Optional[object]: + if not integration_db_bo.get_by_id(integration_id): + # If the integration does not exist, + # it was likely deleted during runtime + print(f"Integration with id '{integration_id}' not found", flush=True) + return + integration_record = IntegrationModel( + created_by=created_by, + integration_id=integration_id, + running_id=running_id, + created_at=created_at, + error_message=error_message, + id=id, + **metadata, + ) + + general.add(integration_record, with_commit) + + return integration_record + + +def update( + 
def delete_many(
    IntegrationModel: Type,
    ids: List[str],
    with_commit: bool = False,
) -> None:
    """Bulk-delete records by primary key.

    Uses a query-level delete (synchronize_session=False), so in-session
    instances are not expired and ORM-level cascades do not run.
    """
    session.query(IntegrationModel).filter(IntegrationModel.id.in_(ids)).delete(
        synchronize_session=False
    )
    general.flush_or_commit(with_commit)


def clear_history(
    IntegrationModel: Type,
    id: str,
    with_commit: bool = False,
) -> None:
    """Reset the record's delta_criteria JSON to None."""
    integration_record = get_by_id(IntegrationModel, id)
    if integration_record is None:
        # FIX: avoid AttributeError when the record was deleted in the
        # meantime (same race the create/update guards handle).
        return
    integration_record.delta_criteria = None
    # JSON column: mark as modified so SQLAlchemy persists the change.
    flag_modified(integration_record, "delta_criteria")
    general.flush_or_commit(with_commit)


def get_supported_metadata(
    table_name: str, metadata: Dict[str, Union[str, int, float, bool]]
) -> Dict[str, Any]:
    """Filter ``metadata`` to the keys supported by ``table_name`` and
    rename generic bookkeeping keys (id, created_by, ...) with a table prefix.
    """
    supported_keys = get_supported_metadata_keys(table_name)
    supported_metadata = {
        key: metadata[key] for key in supported_keys.intersection(metadata.keys())
    }
    return __rename_metadata(table_name, supported_metadata)
def __rename_metadata(
    table_name: str, metadata: Dict[str, Union[str, int, float, bool]]
) -> Dict[str, Any]:
    """Prefix generic bookkeeping keys with the table name to avoid collisions."""
    generic_keys = ("id", "created_by", "created_at", "updated_by", "updated_at")
    rename_keys = {key: f"{table_name}_{key}" for key in generic_keys}
    return {rename_keys.get(key, key): value for key, value in metadata.items()}


class CognitionGroup(Base):
    # A named user group inside an organization (cognition schema).
    __tablename__ = Tablenames.GROUP.value
    __table_args__ = {"schema": "cognition"}
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    organization_id = Column(
        UUID(as_uuid=True),
        ForeignKey(f"{Tablenames.ORGANIZATION.value}.id", ondelete="CASCADE"),
        index=True,
    )
    name = Column(String)
    description = Column(String)
    created_at = Column(DateTime, default=sql.func.now())
    created_by = Column(
        UUID(as_uuid=True),
        ForeignKey(f"{Tablenames.USER.value}.id", ondelete="SET NULL"),
        index=True,
    )
    meta_data = Column(JSON)


class CognitionGroupMember(Base):
    # Join table: membership of a user in a CognitionGroup.
    __tablename__ = Tablenames.GROUP_MEMBER.value
    __table_args__ = {"schema": "cognition"}
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    group_id = Column(
        UUID(as_uuid=True),
        ForeignKey(f"cognition.{Tablenames.GROUP.value}.id", ondelete="CASCADE"),
        index=True,
    )
    user_id = Column(
        UUID(as_uuid=True),
        ForeignKey(f"{Tablenames.USER.value}.id", ondelete="CASCADE"),
        index=True,
    )
    created_at = Column(DateTime, default=sql.func.now())
class FullAdminAccess(Base):
    # Allow-list of emails granted full admin access.
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    email = Column(String, unique=True)
    meta_info = Column(JSON)


class CognitionIntegration(Base):
    # One configured external integration (e.g. GitHub, SharePoint, PDF import)
    # of an organization, optionally bound to a project.
    __tablename__ = Tablenames.INTEGRATION.value
    __table_args__ = {"schema": "cognition"}
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    organization_id = Column(
        UUID(as_uuid=True),
        ForeignKey(f"{Tablenames.ORGANIZATION.value}.id", ondelete="CASCADE"),
        index=True,
    )
    project_id = Column(
        UUID(as_uuid=True),
        ForeignKey(f"{Tablenames.PROJECT.value}.id", ondelete="SET NULL"),
        index=True,
    )
    created_by = Column(
        UUID(as_uuid=True),
        ForeignKey(f"{Tablenames.USER.value}.id", ondelete="SET NULL"),
        index=True,
    )
    created_at = Column(DateTime, default=sql.func.now())
    updated_by = Column(
        UUID(as_uuid=True),
        ForeignKey(f"{Tablenames.USER.value}.id", ondelete="SET NULL"),
        index=True,
    )
    updated_at = Column(DateTime, onupdate=sql.func.now())
    # Sync-run timestamps; set by the task runner.
    started_at = Column(DateTime)
    finished_at = Column(DateTime)
    name = Column(String)
    description = Column(String)
    tokenizer = Column(String)
    state = Column(String)  # of type enums.CognitionMarkdownFileState.*.value
    type = Column(String)  # of type enums.CognitionIntegrationType.*.value
    config = Column(JSON)
    """JSON object that contains the configuration for the integration type.
    Examples:
    - For a webhook integration, it might contain the URL and headers.
    - For an API integration, it might contain the API key and endpoint.
    - For a database integration, it might contain the connection string and credentials.

    """

    llm_config = Column(JSON)
    error_message = Column(String)
    # Sync bookkeeping — see monitor.set_integration_task_to_failed in HEAD,
    # which writes is_synced/last_synced_at together with state.
    is_synced = Column(Boolean, nullable=True)
    last_synced_at = Column(DateTime)
    delta_criteria = Column(JSON)  # cleared by integration_objects.manager.clear_history
class CognitionIntegrationAccess(Base):
    # Which integration types an organization is allowed to use.
    __tablename__ = Tablenames.INTEGRATION_ACCESS.value
    __table_args__ = {"schema": "cognition"}
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    created_by = Column(
        UUID(as_uuid=True),
        ForeignKey(f"{Tablenames.USER.value}.id", ondelete="SET NULL"),
        index=True,
    )
    created_at = Column(DateTime, default=sql.func.now())
    organization_id = Column(
        UUID(as_uuid=True),
        ForeignKey(f"{Tablenames.ORGANIZATION.value}.id", ondelete="CASCADE"),
        index=True,
    )
    integration_types = Column(
        ARRAY(String)
    )  # of type enums.CognitionIntegrationType.*.value


class IntegrationGithubFile(Base):
    # One synced file from a GitHub repository integration.
    # (integration_id, running_id, source) is unique per integration.
    __tablename__ = Tablenames.INTEGRATION_GITHUB_FILE.value
    __table_args__ = (
        UniqueConstraint(
            "integration_id",
            "running_id",
            "source",
            name=f"unique_{__tablename__}_source",
        ),
        {"schema": "integration"},
    )
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    created_by = Column(
        UUID(as_uuid=True),
        ForeignKey(f"{Tablenames.USER.value}.id", ondelete="SET NULL"),
        index=True,
    )
    updated_by = Column(
        UUID(as_uuid=True),
        ForeignKey(f"{Tablenames.USER.value}.id", ondelete="SET NULL"),
        index=True,
    )
    created_at = Column(DateTime, default=sql.func.now())
    updated_at = Column(DateTime, onupdate=sql.func.now())
    integration_id = Column(
        UUID(as_uuid=True),
        ForeignKey(f"cognition.{Tablenames.INTEGRATION.value}.id", ondelete="CASCADE"),
        index=True,
    )
    running_id = Column(Integer, index=True)
    source = Column(String, index=True)
    minio_file_name = Column(String)
    error_message = Column(String)

    # GitHub-specific metadata.
    path = Column(String)
    sha = Column(String)
    code_language = Column(String)
class IntegrationGithubIssue(Base):
    # One synced issue from a GitHub integration.
    # (integration_id, running_id, source) is unique per integration.
    __tablename__ = Tablenames.INTEGRATION_GITHUB_ISSUE.value
    __table_args__ = (
        UniqueConstraint(
            "integration_id",
            "running_id",
            "source",
            name=f"unique_{__tablename__}_source",
        ),
        {"schema": "integration"},
    )

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    created_by = Column(
        UUID(as_uuid=True),
        ForeignKey(f"{Tablenames.USER.value}.id", ondelete="SET NULL"),
        index=True,
    )
    updated_by = Column(
        UUID(as_uuid=True),
        ForeignKey(f"{Tablenames.USER.value}.id", ondelete="SET NULL"),
        index=True,
    )
    created_at = Column(DateTime, default=sql.func.now())
    updated_at = Column(DateTime, onupdate=sql.func.now())
    integration_id = Column(
        UUID(as_uuid=True),
        ForeignKey(f"cognition.{Tablenames.INTEGRATION.value}.id", ondelete="CASCADE"),
        index=True,
    )
    running_id = Column(Integer, index=True)
    source = Column(String, index=True)
    minio_file_name = Column(String)
    error_message = Column(String)

    # GitHub issue metadata.
    url = Column(String)
    state = Column(String)
    assignee = Column(String)
    milestone = Column(String)
    number = Column(Integer)


class IntegrationPdf(Base):
    # One synced PDF page; see helper's supported-metadata example in HEAD
    # (source, minio_file_name, running_id, file_path, page, total_pages, title).
    __tablename__ = Tablenames.INTEGRATION_PDF.value
    __table_args__ = (
        UniqueConstraint(
            "integration_id",
            "running_id",
            "source",
            name=f"unique_{__tablename__}_source",
        ),
        {"schema": "integration"},
    )
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    created_by = Column(
        UUID(as_uuid=True),
        ForeignKey(f"{Tablenames.USER.value}.id", ondelete="SET NULL"),
        index=True,
    )
    updated_by = Column(
        UUID(as_uuid=True),
        ForeignKey(f"{Tablenames.USER.value}.id", ondelete="SET NULL"),
        index=True,
    )
    created_at = Column(DateTime, default=sql.func.now())
    updated_at = Column(DateTime, onupdate=sql.func.now())
    integration_id = Column(
        UUID(as_uuid=True),
        ForeignKey(f"cognition.{Tablenames.INTEGRATION.value}.id", ondelete="CASCADE"),
        index=True,
    )
    running_id = Column(Integer, index=True)
    source = Column(String, index=True)
    minio_file_name = Column(String)
    error_message = Column(String)

    # PDF-specific metadata.
    file_path = Column(String)
    page = Column(Integer)
    total_pages = Column(Integer)
    title = Column(String)
class IntegrationSharepoint(Base):
    # One synced SharePoint drive item.
    # (integration_id, running_id, source) is unique per integration.
    __tablename__ = Tablenames.INTEGRATION_SHAREPOINT.value
    __table_args__ = (
        UniqueConstraint(
            "integration_id",
            "running_id",
            "source",
            name=f"unique_{__tablename__}_source",
        ),
        {"schema": "integration"},
    )
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    created_by = Column(
        UUID(as_uuid=True),
        ForeignKey(f"{Tablenames.USER.value}.id", ondelete="SET NULL"),
        index=True,
    )
    updated_by = Column(
        UUID(as_uuid=True),
        ForeignKey(f"{Tablenames.USER.value}.id", ondelete="SET NULL"),
        index=True,
    )
    created_at = Column(DateTime, default=sql.func.now())
    updated_at = Column(DateTime, onupdate=sql.func.now())
    integration_id = Column(
        UUID(as_uuid=True),
        ForeignKey(f"cognition.{Tablenames.INTEGRATION.value}.id", ondelete="CASCADE"),
        index=True,
    )
    running_id = Column(Integer, index=True)
    source = Column(String, index=True)
    minio_file_name = Column(String)
    error_message = Column(String)

    # SharePoint item metadata (mirrors the Graph driveItem fields;
    # sharepoint_created_by is renamed to avoid clashing with created_by above).
    extension = Column(String)
    object_id = Column(String)
    parent_path = Column(String)
    name = Column(String)
    web_url = Column(String)
    sharepoint_created_by = Column(String)
    modified_by = Column(String)
    created = Column(DateTime, default=None)
    modified = Column(DateTime, default=None)
    description = Column(String)
    size = Column(Integer)
    mime_type = Column(String)
    hashes = Column(JSON)
    permissions = Column(JSON)
    file_properties = Column(JSON)


# Applied in order by to_snake_case: first split an upper-case word run
# ("HTTPServer" -> "HTTP_Server"), then split lower/digit-to-upper boundaries.
SNAKE_CASE_PATTERNS = [
    compile(r"(.)([A-Z][a-z]+)"),
    compile(r"([a-z0-9])([A-Z])"),
]
def to_snake_case(name: str) -> str:
    """Convert a camelCase identifier to snake_case.

    Input that does not match the module's camel-case pattern is returned
    unchanged.  Based on the classic two-regex approach:
    https://stackoverflow.com/questions/1175208
    """
    if not is_camel_case(name):
        return name
    result = name
    for splitter in SNAKE_CASE_PATTERNS:
        result = splitter.sub(r"\1_\2", result)
    return result.lower()