diff --git a/django_mongodb_cli/repo.py b/django_mongodb_cli/repo.py index 50c0be3..2a1c2de 100644 --- a/django_mongodb_cli/repo.py +++ b/django_mongodb_cli/repo.py @@ -118,6 +118,45 @@ def clone(repo, context, repo_names, all_repos, install): click.echo(context.get_help()) +@repo.command() +@click.argument("repo_names", nargs=-1) +@click.option("-a", "--all-repos", is_flag=True, help="Install all repositories") +@click.pass_context +@pass_repo +def install(repo, context, repo_names, all_repos): + """Install repositories (like 'clone -i').""" + repos, url_pattern, _ = get_repos("pyproject.toml") + repo_name_map = get_repo_name_map(repos, url_pattern) + + if all_repos and repo_names: + click.echo("Cannot specify both repo names and --all-repos") + return + + # If -a/--all-repos is given + if all_repos: + click.echo(f"Updating {len(repo_name_map)} repositories...") + for repo_name, repo_url in repo_name_map.items(): + clone_path = os.path.join(context.obj.home, repo_name) + if os.path.exists(clone_path): + install_package(clone_path) + return + + # If specific repo names are given + if repo_names: + not_found = [] + for repo_name in repo_names: + clone_path = os.path.join(context.obj.home, repo_name) + if os.path.exists(clone_path): + install_package(clone_path) + else: + not_found.append(repo_name) + for name in not_found: + click.echo(f"Repository '{name}' not found.") + return + + click.echo(context.get_help()) + + @repo.command(context_settings={"ignore_unknown_options": True}) @click.argument("repo_name", required=False) @click.argument("args", nargs=-1) diff --git a/qe.py b/qe.py index 4be1a1a..7c50ace 100644 --- a/qe.py +++ b/qe.py @@ -1,5 +1,4 @@ import code -import os from bson.binary import STANDARD from bson.codec_options import CodecOptions @@ -8,25 +7,20 @@ from pymongo.errors import EncryptedCollectionError from django_mongodb_backend.encryption import ( get_auto_encryption_opts, - get_customer_master_key, + get_kms_providers, + get_key_vault_namespace, ) -HOME = os.environ.get("HOME") - -kms_providers = { - "local": { - "key": get_customer_master_key(), - }, -} +kms_providers = get_kms_providers() +key_vault_namespace = get_key_vault_namespace() client = MongoClient( auto_encryption_opts=get_auto_encryption_opts( - crypt_shared_lib_path=f"{HOME}/Downloads/mongo_crypt_shared_v1-macos-arm64-enterprise-8.0.10/lib/mongo_crypt_v1.dylib", + key_vault_namespace=key_vault_namespace, kms_providers=kms_providers, ) ) -key_vault_namespace = client.options.auto_encryption_opts._key_vault_namespace codec_options = CodecOptions(uuid_representation=STANDARD) client_encryption = ClientEncryption( kms_providers, key_vault_namespace, client, codec_options diff --git a/test/settings/django.py b/test/settings/django.py index 8dd684e..e1b3baa 100644 --- a/test/settings/django.py +++ b/test/settings/django.py @@ -2,22 +2,25 @@ from django_mongodb_backend import encryption, parse_uri -kms_providers = encryption.get_kms_providers() - -auto_encryption_opts = encryption.get_auto_encryption_opts( - kms_providers=kms_providers, +# Queryable Encryption settings +KEY_VAULT_NAMESPACE = encryption.get_key_vault_namespace() +KMS_PROVIDERS = encryption.get_kms_providers() +KMS_PROVIDER = encryption.KMS_PROVIDER +AUTO_ENCRYPTION_OPTS = encryption.get_auto_encryption_opts( + key_vault_namespace=KEY_VAULT_NAMESPACE, + kms_providers=KMS_PROVIDERS, ) DATABASE_URL = os.environ.get("MONGODB_URI", "mongodb://localhost:27017") DATABASES = { "default": parse_uri( DATABASE_URL, - db_name="djangotests", + db_name="test", ), "encrypted": parse_uri( DATABASE_URL, - options={"auto_encryption_opts": auto_encryption_opts}, - db_name="encrypted_djangotests", + options={"auto_encryption_opts": AUTO_ENCRYPTION_OPTS}, + db_name="encrypted", ), } @@ -25,3 +28,14 @@ PASSWORD_HASHERS = ("django.contrib.auth.hashers.MD5PasswordHasher",) SECRET_KEY = "django_tests_secret_key" USE_TZ = False + + +class TestRouter: + def allow_migrate(self, db, app_label, model_name=None, **hints): + if db == "encrypted": + if app_label != "encryption_": + return False + return None + + +DATABASE_ROUTERS = [TestRouter()]