1
1
import os
2
2
import logging
3
- import random
4
- import string
5
- import json
6
3
from contextlib import asynccontextmanager
7
4
from collections .abc import AsyncGenerator
8
5
from typing import Dict , Any , Optional , List , Union
9
- from urllib .parse import urlparse
10
6
11
7
from pydantic import Field , BaseModel , ConfigDict
8
+
12
9
from mcp .server .fastmcp import FastMCP
13
10
14
11
from alibabacloud_dms_enterprise20181101 .client import Client as dms_enterprise20181101Client
15
12
from alibabacloud_tea_openapi import models as open_api_models
16
13
from alibabacloud_dms_enterprise20181101 import models as dms_enterprise_20181101_models
17
- from alibabacloud_dts20200101 import models as dts_20200101_models
18
- from alibabacloud_dts20200101 .client import Client as DtsClient
19
- from alibabacloud_tea_openapi .models import Config
20
- from alibabacloud_tea_util import models as util_models
21
14
22
15
# --- Global Logger ---
23
16
logger = logging .getLogger (__name__ )
24
17
25
- g_reserved = '''{
26
- "targetTableMode": "0",
27
- "dbListCaseChangeMode": "default",
28
- "isAnalyzer": false,
29
- "eventMove": false,
30
- "tableAnalyze": false,
31
- "whitelist.dms.online.ddl.enable": false,
32
- "sqlparser.dms.original.ddl": true,
33
- "whitelist.ghost.online.ddl.enable": false,
34
- "sqlparser.ghost.original.ddl": false,
35
- "privilegeMigration": false,
36
- "definer": false,
37
- "privilegeDbList": "[]",
38
- "maxRetryTime": 43200,
39
- "retry.blind.seconds": 600,
40
- "srcSSL": "0",
41
- "srcMySQLType": "HighAvailability",
42
- "destSSL": "0",
43
- "a2aFlag": "2.0",
44
- "channelInfo": "mcp",
45
- "autoStartModulesAfterConfig": "none"
46
- }
47
- '''
48
-
49
18
50
19
# --- Pydantic Models ---
51
20
class MyBaseModel (BaseModel ):
@@ -159,20 +128,6 @@ def create_client() -> dms_enterprise20181101Client:
159
128
return dms_enterprise20181101Client (config )
160
129
161
130
162
- def get_dts_client (region_id : str ):
163
- config = Config (
164
- access_key_id = os .getenv ('ALIBABA_CLOUD_ACCESS_KEY_ID' ),
165
- access_key_secret = os .getenv ('ALIBABA_CLOUD_ACCESS_KEY_SECRET' ),
166
- security_token = os .getenv ('ALIBABA_CLOUD_SECURITY_TOKEN' ),
167
- region_id = region_id ,
168
- protocol = "https" ,
169
- connect_timeout = 10 * 1000 ,
170
- read_timeout = 300 * 1000
171
- )
172
- client = DtsClient (config )
173
- return client
174
-
175
-
176
131
async def add_instance (
177
132
db_user : str = Field (description = "The username used to connect to the database" ),
178
133
db_password : str = Field (description = "The password used to connect to the database" ),
@@ -195,6 +150,8 @@ async def add_instance(
195
150
req .instance_id = instance_resource_id
196
151
if region :
197
152
req .region = region
153
+ if mcp .state .real_login_uid :
154
+ req .real_login_user_uid = mcp .state .real_login_uid
198
155
try :
199
156
resp = client .simply_add_instance (req )
200
157
return InstanceInfo (** resp .body .to_map ()) if resp and resp .body else InstanceInfo ()
@@ -211,6 +168,8 @@ async def get_instance(
211
168
client = create_client ()
212
169
req = dms_enterprise_20181101_models .GetInstanceRequest (host = host , port = port )
213
170
if sid : req .sid = sid
171
+ if mcp .state .real_login_uid :
172
+ req .real_login_user_uid = mcp .state .real_login_uid
214
173
try :
215
174
resp = client .get_instance (req )
216
175
instance_data = resp .body .to_map ().get ('Instance' , {}) if resp and resp .body else {}
@@ -221,10 +180,13 @@ async def get_instance(
221
180
222
181
223
182
async def list_instance (
224
- search_key : Optional [str ] = Field (default = None , description = "Optional search key (e.g., instance host, instance alias, etc.)" ),
225
- db_type : Optional [str ] = Field (default = None , description = "Optional instanceType, or called dbType (e.g., mysql, polardb, oracle, "
226
- "postgresql, sqlserver, polardb-pg, etc.)" ),
227
- env_type : Optional [str ] = Field (default = None , description = "Optional instance environment type (e.g., product, dev, test, etc. )" )
183
+ search_key : Optional [str ] = Field (default = None ,
184
+ description = "Optional search key (e.g., instance host, instance alias, etc.)" ),
185
+ db_type : Optional [str ] = Field (default = None ,
186
+ description = "Optional instanceType, or called dbType (e.g., mysql, polardb, oracle, "
187
+ "postgresql, sqlserver, polardb-pg, etc.)" ),
188
+ env_type : Optional [str ] = Field (default = None ,
189
+ description = "Optional instance environment type (e.g., product, dev, test, etc. )" )
228
190
) -> List [InstanceDetail ]:
229
191
client = create_client ()
230
192
req = dms_enterprise_20181101_models .ListInstancesRequest ()
@@ -234,6 +196,8 @@ async def list_instance(
234
196
req .db_type = db_type
235
197
if env_type :
236
198
req .env_type = env_type
199
+ if mcp .state .real_login_uid :
200
+ req .real_login_user_uid = mcp .state .real_login_uid
237
201
try :
238
202
resp = client .list_instances (req )
239
203
@@ -262,6 +226,8 @@ async def search_database(
262
226
client = create_client ()
263
227
req = dms_enterprise_20181101_models .SearchDatabaseRequest (search_key = search_key , page_number = page_number ,
264
228
page_size = page_size )
229
+ if mcp .state .real_login_uid :
230
+ req .real_login_user_uid = mcp .state .real_login_uid
265
231
try :
266
232
resp = client .search_database (req )
267
233
if not resp or not resp .body : return []
@@ -287,8 +253,12 @@ async def get_database(
287
253
) -> DatabaseDetail :
288
254
client = create_client ()
289
255
req = dms_enterprise_20181101_models .GetDatabaseRequest (host = host , port = port , schema_name = schema_name )
256
+
290
257
if sid :
291
258
req .sid = sid
259
+ if mcp .state .real_login_uid :
260
+ req .real_login_user_uid = mcp .state .real_login_uid
261
+
292
262
try :
293
263
resp = client .get_database (req )
294
264
db_data = resp .body .to_map ().get ('Database' , {}) if resp and resp .body else {}
@@ -310,6 +280,8 @@ async def list_tables( # Renamed from listTable to follow convention
310
280
req = dms_enterprise_20181101_models .ListTablesRequest (database_id = database_id , search_name = search_name ,
311
281
page_number = page_number , page_size = page_size ,
312
282
return_guid = True )
283
+ if mcp .state .real_login_uid :
284
+ req .real_login_user_uid = mcp .state .real_login_uid
313
285
try :
314
286
resp = client .list_tables (req )
315
287
return resp .body .to_map () if resp and resp .body else {}
@@ -324,6 +296,8 @@ async def get_meta_table_detail_info(
324
296
) -> TableDetail :
325
297
client = create_client ()
326
298
req = dms_enterprise_20181101_models .GetMetaTableDetailInfoRequest (table_guid = table_guid )
299
+ if mcp .state .real_login_uid :
300
+ req .real_login_user_uid = mcp .state .real_login_uid
327
301
try :
328
302
resp = client .get_meta_table_detail_info (req )
329
303
detail_info = resp .body .to_map ().get ('DetailInfo' , {}) if resp and resp .body else {}
@@ -351,6 +325,8 @@ async def execute_script(
351
325
) -> ExecuteScriptResult : # Return the object, __str__ will be used by wrapper if needed
352
326
client = create_client ()
353
327
req = dms_enterprise_20181101_models .ExecuteScriptRequest (db_id = database_id , script = script , logic = logic )
328
+ if mcp .state .real_login_uid :
329
+ req .real_login_user_uid = mcp .state .real_login_uid
354
330
try :
355
331
resp = client .execute_script (req )
356
332
if not resp or not resp .body :
@@ -447,6 +423,8 @@ async def nl2sql(
447
423
client = create_client ()
448
424
req = dms_enterprise_20181101_models .GenerateSqlFromNLRequest (db_id = database_id , question = question )
449
425
if knowledge : req .knowledge = knowledge
426
+ if mcp .state .real_login_uid :
427
+ req .real_login_user_uid = mcp .state .real_login_uid
450
428
try :
451
429
resp = client .generate_sql_from_nl (req )
452
430
if not resp or not resp .body : return SqlResult (sql = None )
@@ -458,136 +436,6 @@ async def nl2sql(
458
436
raise
459
437
460
438
461
- async def configureDtsJob (
462
- region_id : str = Field (description = "The region id of the dts job (e.g., 'cn-hangzhou')" ),
463
- job_type : str = Field (
464
- description = "The type of job (synchronization job: SYNC, migration job: MIGRATION, data check job: CHECK)" ),
465
- source_endpoint_region : str = Field (description = "The source endpoint region ID" ),
466
- source_endpoint_instance_type : str = Field (
467
- description = "The source endpoint instance type (RDS, ECS, EXPRESS, CEN, DG)" ),
468
- source_endpoint_engine_name : str = Field (
469
- description = "The source endpoint engine name (MySQL, PostgreSQL, SQLServer)" ),
470
- source_endpoint_instance_id : str = Field (description = "The source endpoint instance ID (e.g., 'rm-xxx')" ),
471
- source_endpoint_user_name : str = Field (description = "The source endpoint user name" ),
472
- source_endpoint_password : str = Field (description = "The source endpoint password" ),
473
- destination_endpoint_region : str = Field (description = "The destination endpoint region ID" ),
474
- destination_endpoint_instance_type : str = Field (
475
- description = "The destination endpoint instance type (RDS, ECS, EXPRESS, CEN, DG)" ),
476
- destination_endpoint_engine_name : str = Field (
477
- description = "The destination endpoint engine name (MySQL, PostgreSQL, SQLServer)" ),
478
- destination_endpoint_instance_id : str = Field (
479
- description = "The destination endpoint instance ID (e.g., 'rm-xxx')" ),
480
- destination_endpoint_user_name : str = Field (description = "The destination endpoint user name" ),
481
- destination_endpoint_password : str = Field (description = "The destination endpoint password" ),
482
- db_list : Dict [str , Any ] = Field (
483
- description = 'The database objects in JSON format, example 1: migration dtstest database, db_list should like {"dtstest":{"name":"dtstest","all":true}}; example 2: migration one table task01 in dtstest database, db_list should like {"dtstest":{"name":"dtstest","all":false,"Table":{"task01":{"name":"task01","all":true}}}}; example 3: migration two tables task01 and task02 in dtstest database, db_list should like {"dtstest":{"name":"dtstest","all":false,"Table":{"task01":{"name":"task01","all":true},"task02":{"name":"task02","all":true}}}}' )
484
- ) -> Dict [str , Any ]:
485
- try :
486
- db_list_str = json .dumps (db_list , separators = (',' , ':' ))
487
- logger .info (f"Configure dts job with db_list: { db_list_str } " )
488
-
489
- # init dts client
490
- client = get_dts_client (region_id )
491
- runtime = util_models .RuntimeOptions ()
492
-
493
- # create dts instance
494
- create_dts_instance_request = dts_20200101_models .CreateDtsInstanceRequest (
495
- region_id = region_id ,
496
- type = job_type ,
497
- source_region = source_endpoint_region ,
498
- destination_region = destination_endpoint_region ,
499
- source_endpoint_engine_name = source_endpoint_engine_name ,
500
- destination_endpoint_engine_name = destination_endpoint_engine_name ,
501
- pay_type = 'PostPaid' ,
502
- quantity = 1 ,
503
- min_du = 1 ,
504
- max_du = 4 ,
505
- instance_class = 'micro'
506
- )
507
-
508
- create_dts_instance_response = client .create_dts_instance_with_options (create_dts_instance_request , runtime )
509
- logger .info (f"Create dts instance response: { create_dts_instance_response .body .to_map ()} " )
510
- dts_job_id = create_dts_instance_response .body .to_map ()['JobId' ]
511
-
512
- # configure dts job
513
- ran_job_name = 'dtsmcp-' + '' .join (random .sample (string .ascii_letters + string .digits , 6 ))
514
- custom_reserved = json .loads (g_reserved )
515
- dts_mcp_channel = os .getenv ('DTS_MCP_CHANNEL' )
516
- if dts_mcp_channel and len (dts_mcp_channel ) > 0 :
517
- logger .info (f"Configure dts job with custom dts mcp channel: { dts_mcp_channel } " )
518
- custom_reserved ['channelInfo' ] = dts_mcp_channel
519
- custom_reserved_str = json .dumps (custom_reserved , separators = (',' , ':' ))
520
- logger .info (f"Configure dts job with reserved: { custom_reserved_str } " )
521
- configure_dts_job_request = dts_20200101_models .ConfigureDtsJobRequest (
522
- region_id = region_id ,
523
- dts_job_name = ran_job_name ,
524
- source_endpoint_instance_type = source_endpoint_instance_type ,
525
- source_endpoint_engine_name = source_endpoint_engine_name ,
526
- source_endpoint_instance_id = source_endpoint_instance_id ,
527
- source_endpoint_region = source_endpoint_region ,
528
- source_endpoint_user_name = source_endpoint_user_name ,
529
- source_endpoint_password = source_endpoint_password ,
530
- destination_endpoint_instance_type = destination_endpoint_instance_type ,
531
- destination_endpoint_instance_id = destination_endpoint_instance_id ,
532
- destination_endpoint_engine_name = destination_endpoint_engine_name ,
533
- destination_endpoint_region = destination_endpoint_region ,
534
- destination_endpoint_user_name = destination_endpoint_user_name ,
535
- destination_endpoint_password = destination_endpoint_password ,
536
- structure_initialization = True ,
537
- data_initialization = True ,
538
- data_synchronization = False ,
539
- job_type = job_type ,
540
- db_list = db_list_str ,
541
- reserve = custom_reserved_str
542
- )
543
-
544
- if dts_job_id and len (dts_job_id ) > 0 :
545
- configure_dts_job_request .dts_job_id = dts_job_id
546
-
547
- configure_dts_job_response = client .configure_dts_job_with_options (configure_dts_job_request , runtime )
548
- logger .info (f"Configure dts job response: { configure_dts_job_response .body .to_map ()} " )
549
- return configure_dts_job_response .body .to_map ()
550
- except Exception as e :
551
- logger .error (f"Error occurred while configure dts job: { str (e )} " )
552
- raise e
553
-
554
-
555
- async def startDtsJob (
556
- region_id : str = Field (description = "The region id of the dts job (e.g., 'cn-hangzhou')" ),
557
- dts_job_id : str = Field (description = "The job id of the dts job" )
558
- ) -> Dict [str , Any ]:
559
- try :
560
- client = get_dts_client (region_id )
561
- request = dts_20200101_models .StartDtsJobRequest (
562
- region_id = region_id ,
563
- dts_job_id = dts_job_id
564
- )
565
- runtime = util_models .RuntimeOptions ()
566
- response = client .start_dts_job_with_options (request , runtime )
567
- return response .body .to_map ()
568
- except Exception as e :
569
- logger .error (f"Error occurred while start dts job: { str (e )} " )
570
- raise e
571
-
572
-
573
- async def getDtsJob (
574
- region_id : str = Field (description = "The region id of the dts job (e.g., 'cn-hangzhou')" ),
575
- dts_job_id : str = Field (description = "The job id of the dts job" )
576
- ) -> Dict [str , Any ]:
577
- try :
578
- client = get_dts_client (region_id )
579
- request = dts_20200101_models .DescribeDtsJobDetailRequest (
580
- region_id = region_id ,
581
- dts_job_id = dts_job_id
582
- )
583
- runtime = util_models .RuntimeOptions ()
584
- response = client .describe_dts_job_detail_with_options (request , runtime )
585
- return response .body .to_map ()
586
- except Exception as e :
587
- logger .error (f"Error occurred while describe dts job detail: { str (e )} " )
588
- raise e
589
-
590
-
591
439
# --- ToolRegistry Class ---
592
440
class ToolRegistry :
593
441
def __init__ (self , mcp : FastMCP ):
@@ -678,14 +526,6 @@ async def ask_database_configured(
678
526
return AskDatabaseResult (executed_sql = generated_sql ,
679
527
execution_result = f"Error: An issue occurred while executing the query: { str (e )} " )
680
528
681
- self .mcp .tool (name = "configureDtsJob" , description = "Configure a dts job." ,
682
- annotations = {"title" : "配置DTS任务" , "readOnlyHint" : False , "destructiveHint" : True })(
683
- configureDtsJob )
684
- self .mcp .tool (name = "startDtsJob" , description = "Start a dts job." ,
685
- annotations = {"title" : "启动DTS任务" , "readOnlyHint" : False , "destructiveHint" : True })(startDtsJob )
686
- self .mcp .tool (name = "getDtsJob" , description = "Get a dts job detail information." ,
687
- annotations = {"title" : "查询DTS任务详细信息" , "readOnlyHint" : True })(getDtsJob )
688
-
689
529
def _register_full_toolset (self ):
690
530
self .mcp .tool (name = "addInstance" ,
691
531
description = "Add an instance to DMS. The username and password are required. "
@@ -697,7 +537,8 @@ def _register_full_toolset(self):
697
537
add_instance )
698
538
self .mcp .tool (name = "listInstances" , description = "Search for instances from DMS." ,
699
539
annotations = {"title" : "搜索DMS实例列表" , "readOnlyHint" : True })(list_instance )
700
- self .mcp .tool (name = "getInstance" , description = "Retrieve detailed instance information from DMS using the host and port." ,
540
+ self .mcp .tool (name = "getInstance" ,
541
+ description = "Retrieve detailed instance information from DMS using the host and port." ,
701
542
annotations = {"title" : "获取DMS实例详情" , "readOnlyHint" : True })(get_instance )
702
543
self .mcp .tool (name = "searchDatabase" , description = "Search databases in DMS by name." ,
703
544
annotations = {"title" : "搜索DMS数据库" , "readOnlyHint" : True })(search_database )
@@ -759,14 +600,6 @@ async def create_data_change_order_wrapper(
759
600
self .mcp .tool (name = "generateSql" , description = "Generate SELECT-type SQL queries from natural language input." ,
760
601
annotations = {"title" : "自然语言转SQL (DMS)" , "readOnlyHint" : True })(nl2sql )
761
602
762
- self .mcp .tool (name = "configureDtsJob" , description = "Configure a dts job." ,
763
- annotations = {"title" : "配置DTS任务" , "readOnlyHint" : False , "destructiveHint" : True })(
764
- configureDtsJob )
765
- self .mcp .tool (name = "startDtsJob" , description = "Start a dts job." ,
766
- annotations = {"title" : "启动DTS任务" , "readOnlyHint" : False , "destructiveHint" : True })(startDtsJob )
767
- self .mcp .tool (name = "getDtsJob" , description = "Get a dts job detail information." ,
768
- annotations = {"title" : "查询DTS任务详细信息" , "readOnlyHint" : True })(getDtsJob )
769
-
770
603
771
604
# --- Lifespan Function ---
772
605
@asynccontextmanager
@@ -779,7 +612,15 @@ class AppState: pass
779
612
780
613
app .state = AppState ()
781
614
782
- app .state .default_database_id = None # Initialize default_database_id
615
+ # Initialize realLoginUid
616
+ app .state .real_login_uid = None
617
+ uid = os .getenv ("UID" )
618
+ if uid :
619
+ app .state .real_login_uid = uid
620
+ logger .info (f"RealLoginUid environment variable found: { uid } " )
621
+
622
+ # Initialize default_database_id
623
+ app .state .default_database_id = None
783
624
784
625
dms_connection_string = os .getenv ("CONNECTION_STRING" )
785
626
if dms_connection_string :
@@ -902,6 +743,8 @@ class AppState: pass
902
743
logger .info ("Shutting down DMS MCP Server via lifespan" )
903
744
if hasattr (app .state , 'default_database_id' ):
904
745
delattr (app .state , 'default_database_id' )
746
+ if hasattr (app .state , 'real_login_uid' ):
747
+ delattr (app .state , 'real_login_uid' )
905
748
906
749
907
750
# --- FastMCP Instance Creation & Server Run ---
0 commit comments