diff --git a/ddpui/api/transform_api.py b/ddpui/api/transform_api.py index a1e2f1d37..22f718b60 100644 --- a/ddpui/api/transform_api.py +++ b/ddpui/api/transform_api.py @@ -25,9 +25,15 @@ EditDbtOperationPayload, LockCanvasRequestSchema, LockCanvasResponseSchema, + GenerateGraphSchema, ) from ddpui.utils.taskprogress import TaskProgress -from ddpui.core.transformfunctions import validate_operation_config, check_canvas_locked +from ddpui.core.transformfunctions import ( + validate_operation_config, + check_canvas_locked, + chat_to_graph, + create_operation_node, +) from ddpui.api.warehouse_api import get_warehouse_data from ddpui.models.tasks import TaskProgressHashPrefix @@ -153,70 +159,71 @@ def post_construct_dbt_model_operation(request, payload: CreateDbtModelPayload): if not org_warehouse: raise HttpError(404, "please setup your warehouse first") - # make sure the orgdbt here is the one we create locally - orgdbt = OrgDbt.objects.filter(org=org, gitrepo_url=None).first() - if not orgdbt: - raise HttpError(404, "dbt workspace not setup") + # # make sure the orgdbt here is the one we create locally + # orgdbt = OrgDbt.objects.filter(org=org, gitrepo_url=None).first() + # if not orgdbt: + # raise HttpError(404, "dbt workspace not setup") - check_canvas_locked(orguser, payload.canvas_lock_id) + # check_canvas_locked(orguser, payload.canvas_lock_id) - if payload.op_type not in dbtautomation_service.OPERATIONS_DICT.keys(): - raise HttpError(422, "Operation not supported") + # if payload.op_type not in dbtautomation_service.OPERATIONS_DICT.keys(): + # raise HttpError(422, "Operation not supported") - is_multi_input_op = payload.op_type in ["join", "unionall"] + # is_multi_input_op = payload.op_type in ["join", "unionall"] - target_model = None - if payload.target_model_uuid: - target_model = OrgDbtModel.objects.filter(uuid=payload.target_model_uuid).first() + # target_model = None + # if payload.target_model_uuid: + # target_model = 
OrgDbtModel.objects.filter(uuid=payload.target_model_uuid).first() - if not target_model: - target_model = OrgDbtModel.objects.create( - uuid=uuid.uuid4(), - orgdbt=orgdbt, - under_construction=True, - ) + # if not target_model: + # target_model = OrgDbtModel.objects.create( + # uuid=uuid.uuid4(), + # orgdbt=orgdbt, + # under_construction=True, + # ) - # only under construction models can be modified - if not target_model.under_construction: - raise HttpError(422, "model is locked") + # # only under construction models can be modified + # if not target_model.under_construction: + # raise HttpError(422, "model is locked") - current_operations_chained = OrgDbtOperation.objects.filter(dbtmodel=target_model).count() + # current_operations_chained = OrgDbtOperation.objects.filter(dbtmodel=target_model).count() - final_config, all_input_models = validate_operation_config( - payload, target_model, is_multi_input_op, current_operations_chained - ) + # final_config, all_input_models = validate_operation_config( + # payload, target_model, is_multi_input_op, current_operations_chained + # ) - # we create edges only with tables/models - for source in all_input_models: - edge = DbtEdge.objects.filter(from_node=source, to_node=target_model).first() - if not edge: - DbtEdge.objects.create( - from_node=source, - to_node=target_model, - ) + # # we create edges only with tables/models + # for source in all_input_models: + # edge = DbtEdge.objects.filter(from_node=source, to_node=target_model).first() + # if not edge: + # DbtEdge.objects.create( + # from_node=source, + # to_node=target_model, + # ) - output_cols = dbtautomation_service.get_output_cols_for_operation( - org_warehouse, payload.op_type, final_config["config"].copy() - ) - logger.info("creating operation") - - dbt_op = OrgDbtOperation.objects.create( - dbtmodel=target_model, - uuid=uuid.uuid4(), - seq=current_operations_chained + 1, - config=final_config, - output_cols=output_cols, - ) + # output_cols = 
@transform_router.post(
    "/agent/chat/",
    auth=auth.CustomAuthMiddleware(),
)
@has_permission(["can_edit_dbt_model"])
def post_generate_graph(request, payload: GenerateGraphSchema):
    """
    Generate a dbt operation node from a natural-language request.

    Fails with a 404 when the org has no OrgDbt without a git repo url;
    otherwise hands the payload to `chat_to_graph` and returns its reply.
    """
    orguser: OrgUser = request.orguser

    # same lookup the other transform endpoints use: the OrgDbt with no git repo attached
    local_workspace = OrgDbt.objects.filter(org=orguser.org, gitrepo_url=None).first()
    if local_workspace is None:
        raise HttpError(404, "dbt workspace not setup")

    return chat_to_graph(orguser.org, orguser, payload)
def create_operation_node(
    org,
    orguser: OrgUser,
    payload: CreateDbtModelPayload,
):
    """
    Chain one more operation onto an under-construction dbt model.

    Steps, in order:
    - require a warehouse and an OrgDbt without a git repo url for the org (404 otherwise)
    - require the caller to hold the canvas lock given in the payload
    - reject operation types missing from dbtautomation_service.OPERATIONS_DICT (422)
    - fetch the target model by uuid, or create a new under-construction one
    - reject models that are no longer under construction (422)
    - validate the operation config and add any missing edges from its input models
    - store the operation with the next sequence number and copy its output
      columns onto the target model

    Returns the new operation serialized by `from_orgdbtoperation`.
    """
    warehouse = OrgWarehouse.objects.filter(org=org).first()
    if not warehouse:
        raise HttpError(404, "please setup your warehouse first")

    # the workspace must be the one we create locally, i.e. no git repo attached
    dbt_workspace = OrgDbt.objects.filter(org=org, gitrepo_url=None).first()
    if not dbt_workspace:
        raise HttpError(404, "dbt workspace not setup")

    check_canvas_locked(orguser, payload.canvas_lock_id)

    supported_ops = dbtautomation_service.OPERATIONS_DICT.keys()
    if payload.op_type not in supported_ops:
        raise HttpError(422, "Operation not supported")

    model = (
        OrgDbtModel.objects.filter(uuid=payload.target_model_uuid).first()
        if payload.target_model_uuid
        else None
    )
    if not model:
        model = OrgDbtModel.objects.create(
            uuid=uuid.uuid4(),
            orgdbt=dbt_workspace,
            under_construction=True,
        )

    # a model can only be modified while it is under construction
    if not model.under_construction:
        raise HttpError(422, "model is locked")

    chain_length = OrgDbtOperation.objects.filter(dbtmodel=model).count()
    takes_multiple_inputs = payload.op_type in ("join", "unionall")

    final_config, input_models = validate_operation_config(
        payload, model, takes_multiple_inputs, chain_length
    )

    # edges are only created between tables/models, and only once per pair
    for upstream in input_models:
        existing_edge = DbtEdge.objects.filter(from_node=upstream, to_node=model).first()
        if existing_edge:
            continue
        DbtEdge.objects.create(from_node=upstream, to_node=model)

    operation_output_cols = dbtautomation_service.get_output_cols_for_operation(
        warehouse, payload.op_type, final_config["config"].copy()
    )
    logger.info("creating operation")

    operation = OrgDbtOperation.objects.create(
        dbtmodel=model,
        uuid=uuid.uuid4(),
        seq=chain_length + 1,
        config=final_config,
        output_cols=operation_output_cols,
    )

    logger.info("created operation")

    # the model always mirrors the output columns of its latest operation
    model.output_cols = operation.output_cols
    model.save()

    logger.info("updated output cols for the model")

    return from_orgdbtoperation(operation, chain_length=operation.seq)
# aggregation functions the agent may return for the "aggregate" operation
_AGGREGATE_FUNCTIONS = ("avg", "sum", "min", "max", "count")


def _clean_llm_token(raw) -> str:
    """Reduce an LLM summary to a bare token: no surrounding whitespace, quotes, backticks or full stop."""
    return str(raw or "").strip(" \t\r\n`'\".")


def _strip_code_fence(raw) -> str:
    """Remove a surrounding markdown code fence (``` or ```json) from an LLM reply."""
    text = str(raw or "").strip()
    if len(text) >= 6 and text.startswith("```") and text.endswith("```"):
        text = text[3:-3]
        # drop an optional language tag such as "json" on the opening fence line
        tag, newline, remainder = text.partition("\n")
        if newline and tag.strip().isalpha():
            text = remainder
    return text.strip()


def _parse_column_list(raw) -> list:
    """Turn the LLM reply to the drop-columns prompt into a list of column names."""
    import ast  # local import: only this helper needs it

    text = _strip_code_fence(raw)
    try:
        # the prompt asks for a python-style list with single quotes, which json cannot parse
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        parsed = None

    if isinstance(parsed, str):
        parsed = [parsed]
    if not isinstance(parsed, (list, tuple)):
        # fall back to a plain comma separated reply
        parsed = text.strip("[]()").split(",")

    columns = [_clean_llm_token(column) for column in parsed]
    columns = [column for column in columns if column]
    if not columns:
        raise HttpError(422, "could not identify the columns to drop from the prompt")
    return columns


def _parse_where_clause(raw) -> dict:
    """Turn the LLM reply to the where prompt into a single clause of the where config."""
    text = _strip_code_fence(raw)
    # tolerate prose before/after the json object
    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end < start:
        raise HttpError(422, "could not identify the filter condition from the prompt")
    try:
        clause = json.loads(text[start : end + 1])
    except json.JSONDecodeError as error:
        raise HttpError(422, "could not identify the filter condition from the prompt") from error

    missing = [key for key in ("column", "operator", "operand_value") if key not in clause]
    if missing:
        raise HttpError(422, f"filter condition is missing: {', '.join(missing)}")

    is_col = clause.get("is_col", False)
    if isinstance(is_col, str):
        is_col = is_col.strip().lower() == "true"

    return {
        "column": clause["column"],
        "operator": clause["operator"],
        "operand": {
            "value": clause["operand_value"],
            "is_col": bool(is_col),
        },
    }


def _run_reflection_chats(sender, recipient, message, summary_prompts) -> list:
    """Run one single-turn chat per summary prompt (in a single batch) and return the summaries in order."""
    chats = [
        {
            "sender": sender,
            "recipient": recipient,
            "message": message,
            "summary_method": "reflection_with_llm",
            "summary_args": {"summary_prompt": summary_prompt},
            "max_turns": 1,
            "clear_history": True,
        }
        for summary_prompt in summary_prompts
    ]
    return [chat_result.summary for chat_result in initiate_chats(chats)]


def chat_to_graph(org, orguser: OrgUser, payload: GenerateGraphSchema):
    """
    Build a dbt operation node from a natural-language request.

    1. Asks the LLM to classify payload.config["query"] into one of the
       operation types in dbtautomation_service.OPERATIONS_DICT
    2. Asks the LLM for the details that operation needs and converts the reply
       into the operation config (only dropcolumns, aggregate and where so far)
    3. Overwrites payload.op_type / payload.config with the result and
       delegates to create_operation_node

    Raises HttpError(422) when the query is missing, when the classified
    operation is unknown or has no config builder here, or when the LLM reply
    cannot be converted into a config.

    Returns whatever create_operation_node returns.
    """
    prompt = (payload.config or {}).get("query")
    if not isinstance(prompt, str) or not prompt.strip():
        raise HttpError(422, "config.query is required")

    op_types = list(dbtautomation_service.OPERATIONS_DICT)
    optypes_str = "\n".join(op_types)

    llm_config = {"model": "gpt-4", "cache_seed": None, "api_key": os.getenv("OPENAI_API_KEY")}

    classify_optype_agent = ConversableAgent(
        name="classify_optype_agent",
        system_message=f"You will be given a user prompt that you need to classify into one of the sql operations present in {optypes_str}",
        llm_config=llm_config,
        human_input_mode="NEVER",
    )

    user_proxy_agent = ConversableAgent(
        name="user_proxy_agent",
        llm_config=False,
        human_input_mode="NEVER",
        code_execution_config=False,
    )

    def ask(*summary_prompts):
        return _run_reflection_chats(
            user_proxy_agent, classify_optype_agent, prompt, summary_prompts
        )

    (raw_operation_type,) = ask(
        f"Return the operation type as single word from the following list: {optypes_str}, Just give the response as a single type among the follwing and no other sentence"
    )

    # the model may answer with different casing, quotes or a trailing full stop
    known_op_types = {op_type.lower(): op_type for op_type in op_types}
    operation_type = known_op_types.get(_clean_llm_token(raw_operation_type).lower())
    logger.info(f"agent classified the prompt as operation: {operation_type}")
    if operation_type is None:
        raise HttpError(422, "Operation not supported")

    if operation_type == "dropcolumns":
        (raw_columns,) = ask(
            "Return the column names to perform operations in the following array format: ['column name 1', 'column name 2', 'column name 3']"
        )
        logger.info(f"agent reply for columns: {raw_columns}")
        operation_config = {"columns": _parse_column_list(raw_columns)}

    elif operation_type == "aggregate":
        # asked in one batch, in this order: aggregation function, target column, output column name
        raw_operation, raw_column, raw_output_name = ask(
            "From the following list: [avg, sum, min, max, count], identify the aggregation operation mentioned in the input. Respond with only the operation name.",
            "Extract the column name mentioned in the input for aggregation. If no column is mentioned, respond with 'none'.",
            "Generate the output column name by combining the target column with the aggregation operation. Example: 'salary' + 'avg' -> 'salary_avg'. If no column is given, return 'invalid'.",
        )
        operation = _clean_llm_token(raw_operation).lower()
        column = _clean_llm_token(raw_column)
        output_column_name = _clean_llm_token(raw_output_name)

        if operation not in _AGGREGATE_FUNCTIONS:
            raise HttpError(422, "could not identify the aggregation to apply from the prompt")
        # 'none' / 'invalid' are the sentinels the prompts ask for when no column was mentioned
        if (
            not column
            or column.lower() == "none"
            or not output_column_name
            or output_column_name.lower() == "invalid"
        ):
            raise HttpError(422, "could not identify the column to aggregate from the prompt")

        operation_config = {
            "aggregate_on": [
                {
                    "operation": operation,
                    "column": column,
                    "output_column_name": output_column_name,
                }
            ]
        }

    elif operation_type == "where":
        (raw_clause,) = ask(
            "Extract the following details from the sentence:\n"
            "1. The column name (e.g., age, salary).\n"
            "2. The logical operator out of these ('between','=', '!=', '>=', '<=', '>', '<').\n"
            "3. The operand value (mention whether it is a column value or a constant).\n"
            "Respond in the following JSON format:\n"
            "{ \"column\": \"\", \"operator\": \"\", \"operand_value\": \"\", \"is_col\": }"
        )
        logger.info(f"agent reply for where clause: {raw_clause}")
        operation_config = {
            "where_type": "and",
            "clauses": [_parse_where_clause(raw_clause)],
        }

    else:
        # without a builder the payload would still carry {"query": ...} as its config
        raise HttpError(
            422, f"operation '{operation_type}' cannot be generated from a prompt yet"
        )

    payload.op_type = operation_type
    payload.config = operation_config

    return create_operation_node(org, orguser, payload)
class GenerateGraphSchema(Schema):
    """
    Request body for generating a dbt operation node from a natural-language prompt.

    NOTE(review): the field set looks like it mirrors CreateDbtModelPayload, since
    the payload is handed on to create_operation_node - confirm against that schema.
    """

    # chat_to_graph reads the user prompt from config["query"] and then replaces
    # config with the generated operation config
    config: dict
    # required in the request, but chat_to_graph overwrites it with the
    # operation type returned by the LLM
    op_type: str
    # uuid of the model to chain the operation onto; when empty or not found,
    # create_operation_node creates a new under-construction model
    target_model_uuid: str = ""
    # the next three fields are not read in the code visible here; presumably
    # consumed by validate_operation_config - TODO confirm
    input_uuid: str = ""
    source_columns: list[str] = []
    other_inputs: list[InputModelPayload] = []
    # passed to check_canvas_locked by create_operation_node
    canvas_lock_id: str = None
distlib==0.3.8 +distro==1.9.0 Django==4.2 django-cors-headers==3.14.0 django-extensions==3.2.3 @@ -53,10 +56,12 @@ django-flags==5.0.13 django-ninja==0.21.0 django-prometheus==2.3.1 djangorestframework==3.14.0 +docker==7.1.0 exceptiongroup==1.1.1 executing==1.2.0 Faker==17.6.0 filelock==3.16.1 +FLAML==2.3.1 fsspec==2023.1.0 future==0.18.3 google-api-core==2.19.0 @@ -90,6 +95,7 @@ isodate==0.6.1 isort==5.12.0 jedi==0.19.0 Jinja2==3.1.2 +jiter==0.6.1 jmespath==1.0.1 jsonpatch==1.32 jsonpointer==2.3 @@ -117,6 +123,7 @@ networkx==2.8.8 nodeenv==1.9.1 numpy==1.25.2 oauthlib==3.2.2 +openai==1.52.1 orjson==3.8.8 packaging==23.0 paramiko==3.4.0 @@ -181,14 +188,17 @@ sshtunnel==0.4.0 stack-data==0.6.2 starkbank-ecdsa==2.2.0 stringcase==1.2.0 +termcolor==2.5.0 text-unidecode==1.3 +tiktoken==0.8.0 toml==0.10.2 tomli==2.0.1 tomlkit==0.11.7 tornado==6.3.2 +tqdm==4.66.5 traitlets==5.9.0 typer==0.7.0 -typing_extensions==4.5.0 +typing_extensions==4.12.2 tzdata==2022.7 tzlocal==4.3 urllib3==1.26.15