Skip to content

Commit 1c7352f

Browse files
authored
Merge pull request #19 from yanbeipang/feat/env_database
Feat/env database
2 parents 9e3e463 + c76bdc2 commit 1c7352f

File tree

3 files changed

+49
-286
lines changed

3 files changed

+49
-286
lines changed

pyproject.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
[project]
22
name = "alibabacloud-dms-mcp-server"
3-
version = "0.1.11"
3+
version = "0.1.12"
44
description = "MCP Server for AlibabaCloud DMS"
55
readme = "README.md"
66
authors = [
77
{ name = "AlibabaCloud DMS" }
88
]
99
requires-python = ">=3.10"
1010
dependencies = [
11-
"alibabacloud-dms-enterprise20181101>=1.72.0",
12-
"alibabacloud-dts20200101>=5.8.1",
11+
"alibabacloud-dms-enterprise20181101>=1.75.0",
1312
"httpx>=0.28.1",
1413
"mcp[cli]>=1.8.1",
1514
]

src/alibabacloud_dms_mcp_server/server.py

Lines changed: 41 additions & 198 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,20 @@
11
import os
22
import logging
3-
import random
4-
import string
5-
import json
63
from contextlib import asynccontextmanager
74
from collections.abc import AsyncGenerator
85
from typing import Dict, Any, Optional, List, Union
9-
from urllib.parse import urlparse
106

117
from pydantic import Field, BaseModel, ConfigDict
8+
129
from mcp.server.fastmcp import FastMCP
1310

1411
from alibabacloud_dms_enterprise20181101.client import Client as dms_enterprise20181101Client
1512
from alibabacloud_tea_openapi import models as open_api_models
1613
from alibabacloud_dms_enterprise20181101 import models as dms_enterprise_20181101_models
17-
from alibabacloud_dts20200101 import models as dts_20200101_models
18-
from alibabacloud_dts20200101.client import Client as DtsClient
19-
from alibabacloud_tea_openapi.models import Config
20-
from alibabacloud_tea_util import models as util_models
2114

2215
# --- Global Logger ---
2316
logger = logging.getLogger(__name__)
2417

25-
g_reserved = '''{
26-
"targetTableMode": "0",
27-
"dbListCaseChangeMode": "default",
28-
"isAnalyzer": false,
29-
"eventMove": false,
30-
"tableAnalyze": false,
31-
"whitelist.dms.online.ddl.enable": false,
32-
"sqlparser.dms.original.ddl": true,
33-
"whitelist.ghost.online.ddl.enable": false,
34-
"sqlparser.ghost.original.ddl": false,
35-
"privilegeMigration": false,
36-
"definer": false,
37-
"privilegeDbList": "[]",
38-
"maxRetryTime": 43200,
39-
"retry.blind.seconds": 600,
40-
"srcSSL": "0",
41-
"srcMySQLType": "HighAvailability",
42-
"destSSL": "0",
43-
"a2aFlag": "2.0",
44-
"channelInfo": "mcp",
45-
"autoStartModulesAfterConfig": "none"
46-
}
47-
'''
48-
4918

5019
# --- Pydantic Models ---
5120
class MyBaseModel(BaseModel):
@@ -159,20 +128,6 @@ def create_client() -> dms_enterprise20181101Client:
159128
return dms_enterprise20181101Client(config)
160129

161130

162-
def get_dts_client(region_id: str):
163-
config = Config(
164-
access_key_id=os.getenv('ALIBABA_CLOUD_ACCESS_KEY_ID'),
165-
access_key_secret=os.getenv('ALIBABA_CLOUD_ACCESS_KEY_SECRET'),
166-
security_token=os.getenv('ALIBABA_CLOUD_SECURITY_TOKEN'),
167-
region_id=region_id,
168-
protocol="https",
169-
connect_timeout=10 * 1000,
170-
read_timeout=300 * 1000
171-
)
172-
client = DtsClient(config)
173-
return client
174-
175-
176131
async def add_instance(
177132
db_user: str = Field(description="The username used to connect to the database"),
178133
db_password: str = Field(description="The password used to connect to the database"),
@@ -195,6 +150,8 @@ async def add_instance(
195150
req.instance_id = instance_resource_id
196151
if region:
197152
req.region = region
153+
if mcp.state.real_login_uid:
154+
req.real_login_user_uid = mcp.state.real_login_uid
198155
try:
199156
resp = client.simply_add_instance(req)
200157
return InstanceInfo(**resp.body.to_map()) if resp and resp.body else InstanceInfo()
@@ -211,6 +168,8 @@ async def get_instance(
211168
client = create_client()
212169
req = dms_enterprise_20181101_models.GetInstanceRequest(host=host, port=port)
213170
if sid: req.sid = sid
171+
if mcp.state.real_login_uid:
172+
req.real_login_user_uid = mcp.state.real_login_uid
214173
try:
215174
resp = client.get_instance(req)
216175
instance_data = resp.body.to_map().get('Instance', {}) if resp and resp.body else {}
@@ -221,10 +180,13 @@ async def get_instance(
221180

222181

223182
async def list_instance(
224-
search_key: Optional[str] = Field(default=None, description="Optional search key (e.g., instance host, instance alias, etc.)"),
225-
db_type: Optional[str] = Field(default=None, description="Optional instanceType, or called dbType (e.g., mysql, polardb, oracle, "
226-
"postgresql, sqlserver, polardb-pg, etc.)"),
227-
env_type: Optional[str] = Field(default=None, description="Optional instance environment type (e.g., product, dev, test, etc. )")
183+
search_key: Optional[str] = Field(default=None,
184+
description="Optional search key (e.g., instance host, instance alias, etc.)"),
185+
db_type: Optional[str] = Field(default=None,
186+
description="Optional instanceType, or called dbType (e.g., mysql, polardb, oracle, "
187+
"postgresql, sqlserver, polardb-pg, etc.)"),
188+
env_type: Optional[str] = Field(default=None,
189+
description="Optional instance environment type (e.g., product, dev, test, etc. )")
228190
) -> List[InstanceDetail]:
229191
client = create_client()
230192
req = dms_enterprise_20181101_models.ListInstancesRequest()
@@ -234,6 +196,8 @@ async def list_instance(
234196
req.db_type = db_type
235197
if env_type:
236198
req.env_type = env_type
199+
if mcp.state.real_login_uid:
200+
req.real_login_user_uid = mcp.state.real_login_uid
237201
try:
238202
resp = client.list_instances(req)
239203

@@ -262,6 +226,8 @@ async def search_database(
262226
client = create_client()
263227
req = dms_enterprise_20181101_models.SearchDatabaseRequest(search_key=search_key, page_number=page_number,
264228
page_size=page_size)
229+
if mcp.state.real_login_uid:
230+
req.real_login_user_uid = mcp.state.real_login_uid
265231
try:
266232
resp = client.search_database(req)
267233
if not resp or not resp.body: return []
@@ -287,8 +253,12 @@ async def get_database(
287253
) -> DatabaseDetail:
288254
client = create_client()
289255
req = dms_enterprise_20181101_models.GetDatabaseRequest(host=host, port=port, schema_name=schema_name)
256+
290257
if sid:
291258
req.sid = sid
259+
if mcp.state.real_login_uid:
260+
req.real_login_user_uid = mcp.state.real_login_uid
261+
292262
try:
293263
resp = client.get_database(req)
294264
db_data = resp.body.to_map().get('Database', {}) if resp and resp.body else {}
@@ -310,6 +280,8 @@ async def list_tables( # Renamed from listTable to follow convention
310280
req = dms_enterprise_20181101_models.ListTablesRequest(database_id=database_id, search_name=search_name,
311281
page_number=page_number, page_size=page_size,
312282
return_guid=True)
283+
if mcp.state.real_login_uid:
284+
req.real_login_user_uid = mcp.state.real_login_uid
313285
try:
314286
resp = client.list_tables(req)
315287
return resp.body.to_map() if resp and resp.body else {}
@@ -324,6 +296,8 @@ async def get_meta_table_detail_info(
324296
) -> TableDetail:
325297
client = create_client()
326298
req = dms_enterprise_20181101_models.GetMetaTableDetailInfoRequest(table_guid=table_guid)
299+
if mcp.state.real_login_uid:
300+
req.real_login_user_uid = mcp.state.real_login_uid
327301
try:
328302
resp = client.get_meta_table_detail_info(req)
329303
detail_info = resp.body.to_map().get('DetailInfo', {}) if resp and resp.body else {}
@@ -351,6 +325,8 @@ async def execute_script(
351325
) -> ExecuteScriptResult: # Return the object, __str__ will be used by wrapper if needed
352326
client = create_client()
353327
req = dms_enterprise_20181101_models.ExecuteScriptRequest(db_id=database_id, script=script, logic=logic)
328+
if mcp.state.real_login_uid:
329+
req.real_login_user_uid = mcp.state.real_login_uid
354330
try:
355331
resp = client.execute_script(req)
356332
if not resp or not resp.body:
@@ -447,6 +423,8 @@ async def nl2sql(
447423
client = create_client()
448424
req = dms_enterprise_20181101_models.GenerateSqlFromNLRequest(db_id=database_id, question=question)
449425
if knowledge: req.knowledge = knowledge
426+
if mcp.state.real_login_uid:
427+
req.real_login_user_uid = mcp.state.real_login_uid
450428
try:
451429
resp = client.generate_sql_from_nl(req)
452430
if not resp or not resp.body: return SqlResult(sql=None)
@@ -458,136 +436,6 @@ async def nl2sql(
458436
raise
459437

460438

461-
async def configureDtsJob(
462-
region_id: str = Field(description="The region id of the dts job (e.g., 'cn-hangzhou')"),
463-
job_type: str = Field(
464-
description="The type of job (synchronization job: SYNC, migration job: MIGRATION, data check job: CHECK)"),
465-
source_endpoint_region: str = Field(description="The source endpoint region ID"),
466-
source_endpoint_instance_type: str = Field(
467-
description="The source endpoint instance type (RDS, ECS, EXPRESS, CEN, DG)"),
468-
source_endpoint_engine_name: str = Field(
469-
description="The source endpoint engine name (MySQL, PostgreSQL, SQLServer)"),
470-
source_endpoint_instance_id: str = Field(description="The source endpoint instance ID (e.g., 'rm-xxx')"),
471-
source_endpoint_user_name: str = Field(description="The source endpoint user name"),
472-
source_endpoint_password: str = Field(description="The source endpoint password"),
473-
destination_endpoint_region: str = Field(description="The destination endpoint region ID"),
474-
destination_endpoint_instance_type: str = Field(
475-
description="The destination endpoint instance type (RDS, ECS, EXPRESS, CEN, DG)"),
476-
destination_endpoint_engine_name: str = Field(
477-
description="The destination endpoint engine name (MySQL, PostgreSQL, SQLServer)"),
478-
destination_endpoint_instance_id: str = Field(
479-
description="The destination endpoint instance ID (e.g., 'rm-xxx')"),
480-
destination_endpoint_user_name: str = Field(description="The destination endpoint user name"),
481-
destination_endpoint_password: str = Field(description="The destination endpoint password"),
482-
db_list: Dict[str, Any] = Field(
483-
description='The database objects in JSON format, example 1: migration dtstest database, db_list should like {"dtstest":{"name":"dtstest","all":true}}; example 2: migration one table task01 in dtstest database, db_list should like {"dtstest":{"name":"dtstest","all":false,"Table":{"task01":{"name":"task01","all":true}}}}; example 3: migration two tables task01 and task02 in dtstest database, db_list should like {"dtstest":{"name":"dtstest","all":false,"Table":{"task01":{"name":"task01","all":true},"task02":{"name":"task02","all":true}}}}')
484-
) -> Dict[str, Any]:
485-
try:
486-
db_list_str = json.dumps(db_list, separators=(',', ':'))
487-
logger.info(f"Configure dts job with db_list: {db_list_str}")
488-
489-
# init dts client
490-
client = get_dts_client(region_id)
491-
runtime = util_models.RuntimeOptions()
492-
493-
# create dts instance
494-
create_dts_instance_request = dts_20200101_models.CreateDtsInstanceRequest(
495-
region_id=region_id,
496-
type=job_type,
497-
source_region=source_endpoint_region,
498-
destination_region=destination_endpoint_region,
499-
source_endpoint_engine_name=source_endpoint_engine_name,
500-
destination_endpoint_engine_name=destination_endpoint_engine_name,
501-
pay_type='PostPaid',
502-
quantity=1,
503-
min_du=1,
504-
max_du=4,
505-
instance_class='micro'
506-
)
507-
508-
create_dts_instance_response = client.create_dts_instance_with_options(create_dts_instance_request, runtime)
509-
logger.info(f"Create dts instance response: {create_dts_instance_response.body.to_map()}")
510-
dts_job_id = create_dts_instance_response.body.to_map()['JobId']
511-
512-
# configure dts job
513-
ran_job_name = 'dtsmcp-' + ''.join(random.sample(string.ascii_letters + string.digits, 6))
514-
custom_reserved = json.loads(g_reserved)
515-
dts_mcp_channel = os.getenv('DTS_MCP_CHANNEL')
516-
if dts_mcp_channel and len(dts_mcp_channel) > 0:
517-
logger.info(f"Configure dts job with custom dts mcp channel: {dts_mcp_channel}")
518-
custom_reserved['channelInfo'] = dts_mcp_channel
519-
custom_reserved_str = json.dumps(custom_reserved, separators=(',', ':'))
520-
logger.info(f"Configure dts job with reserved: {custom_reserved_str}")
521-
configure_dts_job_request = dts_20200101_models.ConfigureDtsJobRequest(
522-
region_id=region_id,
523-
dts_job_name=ran_job_name,
524-
source_endpoint_instance_type=source_endpoint_instance_type,
525-
source_endpoint_engine_name=source_endpoint_engine_name,
526-
source_endpoint_instance_id=source_endpoint_instance_id,
527-
source_endpoint_region=source_endpoint_region,
528-
source_endpoint_user_name=source_endpoint_user_name,
529-
source_endpoint_password=source_endpoint_password,
530-
destination_endpoint_instance_type=destination_endpoint_instance_type,
531-
destination_endpoint_instance_id=destination_endpoint_instance_id,
532-
destination_endpoint_engine_name=destination_endpoint_engine_name,
533-
destination_endpoint_region=destination_endpoint_region,
534-
destination_endpoint_user_name=destination_endpoint_user_name,
535-
destination_endpoint_password=destination_endpoint_password,
536-
structure_initialization=True,
537-
data_initialization=True,
538-
data_synchronization=False,
539-
job_type=job_type,
540-
db_list=db_list_str,
541-
reserve=custom_reserved_str
542-
)
543-
544-
if dts_job_id and len(dts_job_id) > 0:
545-
configure_dts_job_request.dts_job_id = dts_job_id
546-
547-
configure_dts_job_response = client.configure_dts_job_with_options(configure_dts_job_request, runtime)
548-
logger.info(f"Configure dts job response: {configure_dts_job_response.body.to_map()}")
549-
return configure_dts_job_response.body.to_map()
550-
except Exception as e:
551-
logger.error(f"Error occurred while configure dts job: {str(e)}")
552-
raise e
553-
554-
555-
async def startDtsJob(
556-
region_id: str = Field(description="The region id of the dts job (e.g., 'cn-hangzhou')"),
557-
dts_job_id: str = Field(description="The job id of the dts job")
558-
) -> Dict[str, Any]:
559-
try:
560-
client = get_dts_client(region_id)
561-
request = dts_20200101_models.StartDtsJobRequest(
562-
region_id=region_id,
563-
dts_job_id=dts_job_id
564-
)
565-
runtime = util_models.RuntimeOptions()
566-
response = client.start_dts_job_with_options(request, runtime)
567-
return response.body.to_map()
568-
except Exception as e:
569-
logger.error(f"Error occurred while start dts job: {str(e)}")
570-
raise e
571-
572-
573-
async def getDtsJob(
574-
region_id: str = Field(description="The region id of the dts job (e.g., 'cn-hangzhou')"),
575-
dts_job_id: str = Field(description="The job id of the dts job")
576-
) -> Dict[str, Any]:
577-
try:
578-
client = get_dts_client(region_id)
579-
request = dts_20200101_models.DescribeDtsJobDetailRequest(
580-
region_id=region_id,
581-
dts_job_id=dts_job_id
582-
)
583-
runtime = util_models.RuntimeOptions()
584-
response = client.describe_dts_job_detail_with_options(request, runtime)
585-
return response.body.to_map()
586-
except Exception as e:
587-
logger.error(f"Error occurred while describe dts job detail: {str(e)}")
588-
raise e
589-
590-
591439
# --- ToolRegistry Class ---
592440
class ToolRegistry:
593441
def __init__(self, mcp: FastMCP):
@@ -678,14 +526,6 @@ async def ask_database_configured(
678526
return AskDatabaseResult(executed_sql=generated_sql,
679527
execution_result=f"Error: An issue occurred while executing the query: {str(e)}")
680528

681-
self.mcp.tool(name="configureDtsJob", description="Configure a dts job.",
682-
annotations={"title": "配置DTS任务", "readOnlyHint": False, "destructiveHint": True})(
683-
configureDtsJob)
684-
self.mcp.tool(name="startDtsJob", description="Start a dts job.",
685-
annotations={"title": "启动DTS任务", "readOnlyHint": False, "destructiveHint": True})(startDtsJob)
686-
self.mcp.tool(name="getDtsJob", description="Get a dts job detail information.",
687-
annotations={"title": "查询DTS任务详细信息", "readOnlyHint": True})(getDtsJob)
688-
689529
def _register_full_toolset(self):
690530
self.mcp.tool(name="addInstance",
691531
description="Add an instance to DMS. The username and password are required. "
@@ -697,7 +537,8 @@ def _register_full_toolset(self):
697537
add_instance)
698538
self.mcp.tool(name="listInstances", description="Search for instances from DMS.",
699539
annotations={"title": "搜索DMS实例列表", "readOnlyHint": True})(list_instance)
700-
self.mcp.tool(name="getInstance", description="Retrieve detailed instance information from DMS using the host and port.",
540+
self.mcp.tool(name="getInstance",
541+
description="Retrieve detailed instance information from DMS using the host and port.",
701542
annotations={"title": "获取DMS实例详情", "readOnlyHint": True})(get_instance)
702543
self.mcp.tool(name="searchDatabase", description="Search databases in DMS by name.",
703544
annotations={"title": "搜索DMS数据库", "readOnlyHint": True})(search_database)
@@ -759,14 +600,6 @@ async def create_data_change_order_wrapper(
759600
self.mcp.tool(name="generateSql", description="Generate SELECT-type SQL queries from natural language input.",
760601
annotations={"title": "自然语言转SQL (DMS)", "readOnlyHint": True})(nl2sql)
761602

762-
self.mcp.tool(name="configureDtsJob", description="Configure a dts job.",
763-
annotations={"title": "配置DTS任务", "readOnlyHint": False, "destructiveHint": True})(
764-
configureDtsJob)
765-
self.mcp.tool(name="startDtsJob", description="Start a dts job.",
766-
annotations={"title": "启动DTS任务", "readOnlyHint": False, "destructiveHint": True})(startDtsJob)
767-
self.mcp.tool(name="getDtsJob", description="Get a dts job detail information.",
768-
annotations={"title": "查询DTS任务详细信息", "readOnlyHint": True})(getDtsJob)
769-
770603

771604
# --- Lifespan Function ---
772605
@asynccontextmanager
@@ -779,7 +612,15 @@ class AppState: pass
779612

780613
app.state = AppState()
781614

782-
app.state.default_database_id = None # Initialize default_database_id
615+
# Initialize realLoginUid
616+
app.state.real_login_uid = None
617+
uid = os.getenv("UID")
618+
if uid:
619+
app.state.real_login_uid = uid
620+
logger.info(f"RealLoginUid environment variable found: {uid}")
621+
622+
# Initialize default_database_id
623+
app.state.default_database_id = None
783624

784625
dms_connection_string = os.getenv("CONNECTION_STRING")
785626
if dms_connection_string:
@@ -902,6 +743,8 @@ class AppState: pass
902743
logger.info("Shutting down DMS MCP Server via lifespan")
903744
if hasattr(app.state, 'default_database_id'):
904745
delattr(app.state, 'default_database_id')
746+
if hasattr(app.state, 'real_login_uid'):
747+
delattr(app.state, 'real_login_uid')
905748

906749

907750
# --- FastMCP Instance Creation & Server Run ---

0 commit comments

Comments
 (0)