From 3887067777d9ef8025230ca2a0f19fbb4b0ff318 Mon Sep 17 00:00:00 2001 From: Victor Moene Date: Mon, 6 Oct 2025 14:49:23 +0200 Subject: [PATCH] Added vagrant vm spawning provider Ticket: ENT-5725 Signed-off-by: Victor Moene --- MANIFEST.in | 1 + cf_remote/Vagrantfile | 64 +++++++++ cf_remote/aramid.py | 6 +- cf_remote/commands.py | 156 +++++++++++++--------- cf_remote/main.py | 72 ++++++++-- cf_remote/paths.py | 4 + cf_remote/spawn.py | 299 ++++++++++++++++++++++++++++++++++++------ cf_remote/ssh.py | 16 ++- 8 files changed, 499 insertions(+), 119 deletions(-) create mode 100644 cf_remote/Vagrantfile diff --git a/MANIFEST.in b/MANIFEST.in index f48a209..7b325ee 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1,2 @@ include cf_remote/nt-discovery.sh +include cf_remote/Vagrantfile diff --git a/cf_remote/Vagrantfile b/cf_remote/Vagrantfile new file mode 100644 index 0000000..a913bf0 --- /dev/null +++ b/cf_remote/Vagrantfile @@ -0,0 +1,64 @@ +# -*- mode: ruby -*- +# vi: set ft=ruby : +require 'json' + +config_path = File.expand_path("../config.json", __FILE__) +if !File.file?(config_path) + abort "Need a config file for this host at " + config_path +end +config = JSON.parse(File.read(config_path)) + +VM_BOX = config['box'] +VM_COUNT = config['count'] +VM_MEMORY = config['memory'] +VM_CPUS = config['cpus'] +VM_PROVISION = File.expand_path(config['provision'], __FILE__) +VM_NAME = config['name'] +SYNC_FOLDER = config['sync_folder'] + +Vagrant.configure("2") do |config| + + (1..VM_COUNT).each do |i| + + config.vm.define "#{VM_NAME}-#{i}" do |node| + node.vm.hostname = "#{VM_NAME}-#{i}" + node.vm.network "private_network", ip: "192.168.56.#{9 + i}" + node.vm.box = VM_BOX + + node.vm.provision "shell" do |s| + ssh_pub_key = File.readlines("#{Dir.home}/.ssh/id_rsa.pub").first.strip + s.inline = <<-SHELL + echo #{ssh_pub_key} >> /home/vagrant/.ssh/authorized_keys + echo #{ssh_pub_key} >> /root/.ssh/authorized_keys + SHELL + end + node.vm.provision "bootstrap", type: "shell", path: VM_PROVISION + node.vm.synced_folder ".", "/vagrant", + rsync__args: ["--verbose", "--archive", "--delete", "-z", "--links"] + node.vm.synced_folder "#{SYNC_FOLDER}", "/synched_folder", + rsync__args: ["--verbose", "--archive", "--delete", "-z", "--links"] + + # https://bugs.launchpad.net/cloud-images/+bug/1874453 + NOW = Time.now.strftime("%d.%m.%Y.%H:%M:%S") + FILENAME = "serial-debug-%s.log" % NOW + node.vm.provider "virtualbox" do |vb| + vb.memory = VM_MEMORY + vb.cpus = VM_CPUS + vb.customize [ "guestproperty", "set", :id, "/VirtualBox/GuestAdd/VBoxService/--timesync-set-threshold", 1000 ] + vb.customize [ "modifyvm", :id, "--uart1", "0x3F8", "4" ] + vb.customize [ "modifyvm", :id, "--uartmode1", "file", + File.join(Dir.pwd, FILENAME) ] + end + + node.vm.provider :libvirt do |v, override| + v.memory = VM_MEMORY + v.cpus = VM_CPUS + # Fedora 30+ uses QEMU sessions by default, breaking pretty much all + # previously working Vagrantfiles: + # https://fedoraproject.org/wiki/Changes/Vagrant_2.2_with_QEMU_Session#Upgrade.2Fcompatibility_impact + v.qemu_use_session = false + override.vm.synced_folder "#{SYNC_FOLDER}", "/synched_folder", type: :rsync + end + end + end +end diff --git a/cf_remote/aramid.py b/cf_remote/aramid.py index 565ea5f..c822f48 100644 --- a/cf_remote/aramid.py +++ b/cf_remote/aramid.py @@ -29,6 +29,7 @@ import time from urllib.parse import urlparse from cf_remote import log +from cf_remote.paths import SSH_CONFIG_FPATH DEFAULT_SSH_ARGS = [ "-oLogLevel=ERROR", @@ -37,6 +38,8 @@ "-oBatchMode=yes", "-oHostKeyAlgorithms=+ssh-rsa", "-oPubkeyAcceptedKeyTypes=+ssh-rsa", + "-F", + "{}".format(SSH_CONFIG_FPATH), ] """Default arguments to use with all SSH commands (incl. 'scp' and 'rsync')""" @@ -338,10 +341,11 @@ def execute( if host.port != _DEFAULT_SSH_PORT: port_args += ["-p", str(host.port)] proc = subprocess.Popen( - ["ssh", host.login] + ["ssh"] + DEFAULT_SSH_ARGS + port_args + host.extra_ssh_args + + [host.login] + [commands[i]], stdout=subprocess.PIPE, stderr=subprocess.PIPE, diff --git a/cf_remote/commands.py b/cf_remote/commands.py index e998a7e..924e944 100644 --- a/cf_remote/commands.py +++ b/cf_remote/commands.py @@ -18,9 +18,11 @@ from cf_remote.web import download_package from cf_remote.paths import ( cf_remote_dir, + cf_remote_packages_dir, CLOUD_CONFIG_FPATH, CLOUD_STATE_FPATH, - cf_remote_packages_dir, + SSH_CONFIG_FPATH, + SSH_CONFIGS_JSON_FPATH, ) from cf_remote.utils import ( copy_file, @@ -37,7 +39,14 @@ CFRChecksumError, CFRUserError, ) -from cf_remote.spawn import VM, VMRequest, Providers, AWSCredentials, GCPCredentials +from cf_remote.spawn import ( + CloudVM, + VMRequest, + Providers, + AWSCredentials, + GCPCredentials, + VagrantVM, +) from cf_remote.spawn import spawn_vms, destroy_vms, dump_vms_info, get_cloud_driver from cf_remote import log from cf_remote import cloud_data @@ -393,6 +402,9 @@ def spawn( network=None, public_ip=True, extend_group=False, + vagrant_cpus=None, + vagrant_sync=None, + vagrant_provision=None, ): creds_data = None if os.path.exists(CLOUD_CONFIG_FPATH): @@ -469,13 +481,20 @@ def spawn( network=network, role=role, spawned_cb=print_progress_dot, + vagrant_cpus=vagrant_cpus, + vagrant_sync=vagrant_sync, + vagrant_provision=vagrant_provision, ) except ValueError as e: print("\nError: Failed to spawn VMs - " + str(e)) return 1 print("DONE") - if public_ip and (not all(vm.public_ips for vm in vms)): + if ( + provider != Providers.VAGRANT + and public_ip + and (not all(vm.public_ips for vm in vms)) + ): print("Waiting for VMs to get IP addresses...", end="") sys.stdout.flush() # STDOUT is line-buffered while not all(vm.public_ips for vm in vms): @@ -488,6 +507,16 @@ def spawn( else: vms_info[group_key] = dump_vms_info(vms) + if provider == Providers.VAGRANT: + vmdir = vms[0].vmdir + ssh_config = read_json(SSH_CONFIGS_JSON_FPATH) + if not ssh_config: + ssh_config = {} + + with open(os.path.join(vmdir, "vagrant-ssh-config"), "r") as f: + ssh_config[group_key] = f.read() + write_json(SSH_CONFIGS_JSON_FPATH, ssh_config) + write_json(CLOUD_STATE_FPATH, vms_info) print("Details about the spawned VMs can be found in %s" % CLOUD_STATE_FPATH) @@ -510,6 +539,36 @@ def _delete_saved_group(vms_info, group_name): del vms_info[group_name] +def _get_cloud_vms(provider, creds, region, group): + if creds is None: + raise CFRExitError("Missing/incomplete {} credentials".format(provider.upper())) + driver = get_cloud_driver(provider, creds, region) + + assert driver is not None + + nodes = driver.list_nodes() + for name, vm_info in group.items(): + if name == "meta": + continue + vm_uuid = vm_info["uuid"] + vm = CloudVM.get_by_uuid(vm_uuid, nodes=nodes) + if vm is not None: + yield vm + else: + print("VM '%s' not found in the clouds" % vm_uuid) + + +def _get_vagrant_vms(group): + for name, vm_info in group.items(): + if name == "meta": + continue + vm = VagrantVM.get_by_info(name, vm_info) + if vm is not None: + yield vm + else: + print("VM '%s' not found locally" % name) + + def destroy(group_name=None): if os.path.exists(CLOUD_CONFIG_FPATH): creds_data = read_json(CLOUD_CONFIG_FPATH) @@ -549,6 +608,7 @@ def destroy(group_name=None): raise CFRUserError("No saved VMs found in '{}'".format(CLOUD_STATE_FPATH)) to_destroy = [] + group_names = None if group_name: if not group_name.startswith("@"): group_name = "@" + group_name @@ -556,85 +616,46 @@ def destroy(group_name=None): print("Group '%s' not found" % group_name) return 1 + group_names = [group_name] + else: + group_names = [key for key in vms_info.keys() if key.startswith("@")] + + ssh_config = read_json(SSH_CONFIGS_JSON_FPATH) + assert group_names is not None + for group_name in group_names: if _is_saved_group(vms_info, group_name): _delete_saved_group(vms_info, group_name) - write_json(CLOUD_STATE_FPATH, vms_info) - return 0 - - print("Destroying hosts in the '%s' group" % group_name) + continue region = vms_info[group_name]["meta"]["region"] provider = vms_info[group_name]["meta"]["provider"] - if provider not in ["aws", "gcp"]: + if provider not in ["aws", "gcp", "vagrant"]: raise CFRUserError( - "Unsupported provider '{}' encountered in '{}', only aws / gcp is supported".format( + "Unsupported provider '{}' encountered in '{}', only aws, gcp and vagrant are supported".format( provider, CLOUD_STATE_FPATH ) ) - driver = None + group = vms_info[group_name] + vms = [] if provider == "aws": - if aws_creds is None: - raise CFRExitError("Missing/incomplete AWS credentials") - driver = get_cloud_driver(Providers.AWS, aws_creds, region) + vms = _get_cloud_vms(Providers.AWS, aws_creds, region, group) if provider == "gcp": - if gcp_creds is None: - raise CFRExitError("Missing/incomplete GCP credentials") - driver = get_cloud_driver(Providers.GCP, gcp_creds, region) - assert driver is not None + vms = _get_cloud_vms(Providers.GCP, gcp_creds, region, group) + if provider == "vagrant": + vms = _get_vagrant_vms(group) - nodes = driver.list_nodes() - for name, vm_info in vms_info[group_name].items(): - if name == "meta": - continue - vm_uuid = vm_info["uuid"] - vm = VM.get_by_uuid(vm_uuid, nodes=nodes) - if vm is not None: - to_destroy.append(vm) - else: - print("VM '%s' not found in the clouds" % vm_uuid) - del vms_info[group_name] - else: - print("Destroying all hosts") - for group_name in [key for key in vms_info.keys() if key.startswith("@")]: - if _is_saved_group(vms_info, group_name): - _delete_saved_group(vms_info, group_name) - continue + for vm in vms: + to_destroy.append(vm) - region = vms_info[group_name]["meta"]["region"] - provider = vms_info[group_name]["meta"]["provider"] - if provider not in ["aws", "gcp"]: - raise CFRUserError( - "Unsupported provider '{}' encountered in '{}', only aws / gcp is supported".format( - provider, CLOUD_STATE_FPATH - ) - ) + del vms_info[group_name] - driver = None - if provider == "aws": - if aws_creds is None: - raise CFRExitError("Missing/incomplete AWS credentials") - driver = get_cloud_driver(Providers.AWS, aws_creds, region) - if provider == "gcp": - if gcp_creds is None: - raise CFRExitError("Missing/incomplete GCP credentials") - driver = get_cloud_driver(Providers.GCP, gcp_creds, region) - assert driver is not None - - nodes = driver.list_nodes() - for name, vm_info in vms_info[group_name].items(): - if name == "meta": - continue - vm_uuid = vm_info["uuid"] - vm = VM.get_by_uuid(vm_uuid, nodes=nodes) - if vm is not None: - to_destroy.append(vm) - else: - print("VM '%s' not found in the clouds" % vm_uuid) - del vms_info[group_name] + if ssh_config and group_name in ssh_config: + del ssh_config[group_name] destroy_vms(to_destroy) write_json(CLOUD_STATE_FPATH, vms_info) + write_json(SSH_CONFIGS_JSON_FPATH, ssh_config) return 0 @@ -655,6 +676,11 @@ def list_platforms(): return 0 +def list_boxes(): + result = subprocess.run(["vagrant", "box", "list"]) + return result.returncode + + def init_cloud_config(): if os.path.exists(CLOUD_CONFIG_FPATH): print("File %s already exists" % CLOUD_CONFIG_FPATH) @@ -1010,7 +1036,7 @@ def connect_cmd(hosts): raise CFRExitError("You can only connect to one host at a time") print("Opening a SSH command shell...") - r = subprocess.run(["ssh", hosts[0]]) + r = subprocess.run(["ssh", "-F", SSH_CONFIG_FPATH, hosts[0]]) if r.returncode == 0: return 0 if r.returncode < 0: diff --git a/cf_remote/main.py b/cf_remote/main.py index d208798..6684b67 100644 --- a/cf_remote/main.py +++ b/cf_remote/main.py @@ -224,12 +224,15 @@ def _get_arg_parser(): sp.add_argument( "--list-platforms", help="List supported platforms", action="store_true" ) + sp.add_argument( + "--list-boxes", help="List installed vagrant boxes", action="store_true" + ) sp.add_argument( "--init-config", help="Initialize configuration file for spawn functionality", action="store_true", ) - sp.add_argument("--platform", help="Platform to use", type=str) + sp.add_argument("--platform", help="Platform or vagrant box to use", type=str) sp.add_argument("--count", default=1, help="How many hosts to spawn", type=int) sp.add_argument( "--role", help="Role of the hosts", choices=["hub", "hubs", "client", "clients"] @@ -242,8 +245,24 @@ def _get_arg_parser(): help="Append the new VMs to a pre-existing group", action="store_true", ) - sp.add_argument("--aws", help="Spawn VMs in AWS (default)", action="store_true") - sp.add_argument("--gcp", help="Spawn VMs in GCP", action="store_true") + sp.add_argument( + "--provider", + help="VM provider", + type=str, + default="aws", + choices=["aws", "gcp", "vagrant"], + ) + sp.add_argument("--cpus", help="Number of CPUs of the vagrant instances", type=int) + sp.add_argument( + "--sync-root", + help="Root folder of synchronized folders of vagrant instance", + type=str, + ) + sp.add_argument( + "--provision", + help="full path to provision shell script for Vagrant VM", + type=str, + ) sp.add_argument("--size", help="Size/type of the instances", type=str) sp.add_argument( "--network", help="network/subnet to assign the VMs to (GCP only)", type=str @@ -360,24 +379,39 @@ def run_command_with_args(command, args) -> int: elif command == "spawn": if args.list_platforms: return commands.list_platforms() + if args.list_boxes: + return commands.list_boxes() if args.init_config: return commands.init_cloud_config() if args.name and "," in args.name: raise CFRExitError("Group --name may not contain commas") - if args.aws and args.gcp: - raise CFRExitError("--aws and --gcp cannot be used at the same time") if args.role.endswith("s"): # role should be singular args.role = args.role[:-1] - if args.gcp: + if args.provider == "gcp": provider = Providers.GCP - else: - # AWS is currently also the default + elif args.provider == "aws": provider = Providers.AWS if args.network: raise CFRExitError("--network not supported for AWS") if args.no_public_ip: raise CFRExitError("--no-public-ip not supported for AWS") + else: + assert args.provider == "vagrant" + provider = Providers.VAGRANT + + if provider != Providers.VAGRANT: + if args.cpus: + raise CFRExitError("--cpus not supported for {}".format(args.provider)) + if args.sync_root: + raise CFRExitError( + "--sync-root not supported for {}".format(args.provider) + ) + if args.provision: + raise CFRExitError( + "--provision not supported for {}".format(args.provider) + ) + if args.network and (args.network.count("/") != 1): raise CFRExitError( "Invalid network specified, needs to be in the network/subnet format" @@ -393,6 +427,9 @@ def run_command_with_args(command, args) -> int: network=args.network, public_ip=not args.no_public_ip, extend_group=args.append, + vagrant_cpus=args.cpus, + vagrant_sync=args.sync_root, + vagrant_provision=args.provision, ) elif command == "show": return commands.show(args.ansible_inventory) @@ -457,7 +494,12 @@ def validate_command(command, args): ) args.remote_command = args.remote_command[0] - if command == "spawn" and not args.list_platforms and not args.init_config: + if ( + command == "spawn" + and not args.list_platforms + and not args.init_config + and not args.list_boxes + ): # The above options don't require any other options/arguments (TODO: # --provider), but otherwise all have to be given if not args.platform: @@ -524,23 +566,27 @@ def get_cloud_hosts(name, bootstrap_ips=False): if name == "meta": continue log.debug("found name '{}' in state, info='{}'".format(name, info)) - hosts.append(info) + hosts.append((name, info)) else: if name in state: # host_name given and exists at the top level - hosts.append(state[name]) + hosts.append((name, state[name])) else: for group_name in [key for key in state.keys() if key.startswith("@")]: if name in state[group_name]: - hosts.append(state[group_name][name]) + hosts.append((name, state[group_name][name])) ret = [] - for host in hosts: + for name, host in hosts: if bootstrap_ips and "private_ips" in host: key = "private_ips" else: key = "public_ips" + if "vmdir" in host: + ret.append(name) + continue + ips = host.get(key, []) if len(ips) > 0: if host.get("user"): diff --git a/cf_remote/paths.py b/cf_remote/paths.py index 9886f24..39f96cb 100644 --- a/cf_remote/paths.py +++ b/cf_remote/paths.py @@ -42,3 +42,7 @@ def cf_remote_packages_dir(subdir=None): CLOUD_CONFIG_FPATH = cf_remote_file(CLOUD_CONFIG_FNAME) CLOUD_STATE_FNAME = "cloud_state.json" CLOUD_STATE_FPATH = cf_remote_file(CLOUD_STATE_FNAME) +SSH_CONFIG_FNAME = "cf_remote_ssh_config" +SSH_CONFIG_FPATH = cf_remote_file(SSH_CONFIG_FNAME) +SSH_CONFIGS_JSON_FNAME = "ssh_configs.json" +SSH_CONFIGS_JSON_FPATH = cf_remote_file(SSH_CONFIGS_JSON_FNAME) diff --git a/cf_remote/spawn.py b/cf_remote/spawn.py index 19d7cfb..f2806a5 100644 --- a/cf_remote/spawn.py +++ b/cf_remote/spawn.py @@ -1,9 +1,15 @@ from datetime import datetime +from posixpath import dirname, join import string import random +import os +import subprocess +import json +import shutil from collections import namedtuple from enum import Enum from multiprocessing.dummy import Pool +from pathlib import Path from libcloud.common.types import InvalidCredsError from libcloud.compute.types import Provider @@ -13,10 +19,12 @@ from libcloud.compute.drivers.gce import GCENodeDriver from cf_remote.cloud_data import aws_image_criteria, aws_defaults -from cf_remote.utils import whoami +from cf_remote.paths import cf_remote_dir, CLOUD_STATE_FPATH +from cf_remote.utils import whoami, copy_file, canonify, read_json from cf_remote import log from cf_remote import cloud_data +VAGRANT_VM_IP_START = "192.168.56.9" _NAME_RANDOM_PART_LENGTH = 4 AWSCredentials = namedtuple("AWSCredentials", ["key", "secret", "token"]) @@ -33,6 +41,7 @@ class Providers(Enum): AWS = 1 GCP = 2 + VAGRANT = 3 def __str__(self): return self.name.lower() @@ -43,6 +52,55 @@ class MissingInfoError(ValueError): class VM: + + def __init__(self, name, role, platform, size, user, provider): + self._name = name + self._platform = platform + self._size = size + self._user = user + self._provider = provider + self.role = role + + @property + def platform(self): + return self._platform + + @property + def name(self): + return self._name + + @property + def size(self): + return self._size + + @property + def user(self): + return self._user + + @property + def provider(self): + return self._provider + + @property + def info(self): + ret = { + "platform": self.platform, + "size": self.size, + } + if self.user: + ret["user"] = self.user + if self.role: + ret["role"] = self.role + if self.provider: + ret["provider"] = str(self.provider) + + return ret + + def __str__(self): + return "%s: %s" % (self.name, self.info) + + +class CloudVM(VM): def __init__( self, name, @@ -56,16 +114,11 @@ def __init__( user=None, provider=None, ): - self._name = name + super().__init__(name, role, platform, size, user, provider) self._driver = driver self._node = node - self._platform = platform - self._size = size self._key_pair = key_pair self._sec_groups = security_groups - self._user = user - self.role = role - self._provider = provider @classmethod def get_by_ip(cls, ip, driver=None, nodes=None): @@ -150,10 +203,6 @@ def get_by_info(cls, driver, vm_info, nodes=None): ) return None - @property - def name(self): - return self._name - @property def uuid(self): assert self._node is not None @@ -163,10 +212,6 @@ def uuid(self): def driver(self): return self._driver - @property - def platform(self): - return self._platform - @property def region(self): try: @@ -184,10 +229,6 @@ def region(self): else: return str(region) - @property - def size(self): - return self._size - @property def key_pair(self): return self._key_pair @@ -196,14 +237,6 @@ def key_pair(self): def security_groups(self): return self._sec_groups - @property - def user(self): - return self._user - - @property - def provider(self): - return self._provider - @property def _data(self): # We need to refresh this every time to get fresh data because @@ -244,24 +277,17 @@ def private_ips(self): @property def info(self): - ret = { - "platform": self.platform, + ret = super().info + ret |= { "region": self.region, - "size": self.size, "private_ips": self.private_ips, "public_ips": self.public_ips, "uuid": self.uuid, } - if self.user: - ret["user"] = self.user - if self.role: - ret["role"] = self.role if self.key_pair: ret["key_pair"] = self.key_pair if self.security_groups: ret["security_groups"] = self.security_groups - if self.provider: - ret["provider"] = str(self.provider) return ret def __str__(self): @@ -464,7 +490,7 @@ def spawn_vm_in_aws( % (platform, ami, size, e) ) - return VM( + return CloudVM( name, driver, node, @@ -517,7 +543,9 @@ def spawn_vm_in_gcp( assert isinstance(driver, GCENodeDriver) node = driver.create_node(name, size, platform, **kwargs) - return VM(name, driver, node, role, platform, size, None, None, None, Providers.GCP) + return CloudVM( + name, driver, node, role, platform, size, None, None, None, Providers.GCP + ) class GCPSpawnTask: @@ -556,8 +584,11 @@ def spawn_vms( network=None, role=None, spawned_cb=None, + vagrant_cpus=None, + vagrant_sync=None, + vagrant_provision=None, ): - if provider not in (Providers.AWS, Providers.GCP): + if provider not in (Providers.AWS, Providers.GCP, Providers.VAGRANT): raise ValueError("Unsupported provider %s" % provider) if (provider == Providers.AWS) and (key_pair is None): @@ -581,6 +612,17 @@ def spawn_vms( if spawned_cb is not None: spawned_cb(vm) ret.append(vm) + elif provider == Providers.VAGRANT: + ret = spawn_vm_in_vagrant( + vm_requests[0].name, + vm_requests[0].platform, + len(vm_requests), + role, + cpus=vagrant_cpus, + memory=size, + sync_folder=vagrant_sync, + provision_script=vagrant_provision, + ) else: tasks = [ GCPSpawnTask( @@ -611,9 +653,18 @@ def spawn_vms( def destroy_vms(vms): if not vms: return + + folders = set(vm.vmdir for vm in vms if getattr(vm, "vmdir", False)) + with Pool(len(vms)) as pool: pool.map(lambda vm: vm.destroy(), vms) + try: + for f in folders: + shutil.rmtree(f) + except: + pass + def dump_vms_info(vms): current_time = datetime.now().astimezone().replace(microsecond=0).isoformat() @@ -635,3 +686,173 @@ def dump_vms_info(vms): del info[key] ret[vm.name] = info return ret + + +class VagrantVM(VM): + + def __init__(self, name, ip, vmdir, platform, role, size, cpus, sync_folder): + super().__init__(name, role, platform, size, "vagrant", Providers.VAGRANT) + + self.public_ips = [ip] + self.region = None + self.vmdir = vmdir + self.cpus = cpus + self.sync_folder = sync_folder + + log.debug( + "Created VM with the following information: \n\t- {}\n\t- {}\n\t- {}\n\t- {}".format( + name, ip, vmdir, sync_folder + ) + ) + + @property + def info(self): + ret = super().info + ret |= { + "private_ips": [], + "public_ips": self.public_ips, + "vmdir": self.vmdir, + "cpus": self.cpus, + "region": self.region, + } + if self.sync_folder: + ret["sync_folder"] = self.sync_folder + elif os.environ.get("CF_REMOTE_SYNC_ROOT"): + ret["sync_folder"] = os.environ.get("CF_REMOTE_SYNC_ROOT") + else: + ret["sync_folder"] = "/northern.tech" + + return ret + + @classmethod + def get_by_info(cls, name, info): + return cls( + name, + info["public_ips"][0], + info["vmdir"], + info["platform"], + info["role"], + info["size"], + info["cpus"], + info["sync_folder"], + ) + + def destroy(self): + + vagrant_env = os.environ.copy() + vagrant_env["VAGRANT_CWD"] = self.vmdir + + return subprocess.run( + ["vagrant", "destroy", "-f", self.name], env=vagrant_env + ).returncode + + +def get_last_vagrant_ip_address(): + state = read_json(CLOUD_STATE_FPATH) + + if not state: + return VAGRANT_VM_IP_START + + ip = VAGRANT_VM_IP_START + + for group in state.values(): + if group["meta"]["provider"] != "vagrant": + continue + for host, info in group.items(): + if host == "meta": + continue + + ip = min(ip, info["public_ips"][0]) + + return ip + + +def spawn_vm_in_vagrant( + name, + box, + count, + role, + cpus=None, + memory=None, + sync_folder=None, + provision_script=None, +): + name = canonify(name).replace("_", "-") + vagrantdir = cf_remote_dir(os.path.join("vagrant", name)) + os.makedirs(vagrantdir, exist_ok=True) + + # Copy Vagrantfile to .cfengine/cf-remote/vagrant + vagrantfile = join(dirname(__file__), "Vagrantfile") + copy_file(vagrantfile, os.path.join(vagrantdir, "Vagrantfile")) + + if cpus is None: + cpus = 1 + if memory is None: + memory = 1024 + if sync_folder is None: + sync_folder = os.getenv("CF_REMOTE_SYNC_ROOT", "/tmp") + if sync_folder == "/tmp": + log.warning( + "The synched folder has not been specified. The default synched folder will be set to '/tmp'. \ + \nYou can override it with CF_REMOTE_SYNC_ROOT" + ) + + bootstrap = os.path.join(vagrantdir, "bootstrap.sh") + if provision_script is None: + Path(bootstrap).touch(exist_ok=True) + else: + copy_file(provision_script, bootstrap) + + config = { + "box": box, + "count": count, + "memory": memory, + "cpus": cpus, + "provision": bootstrap, + "name": name, + "sync_folder": sync_folder, + } + + log.debug("Saving the vagrant VM config") + log.debug("Config: {}".format(json.dumps(config, indent=2))) + with open(os.path.join(vagrantdir, "config.json"), "w") as f: + json.dump(config, f, indent=4) + + # Prepare command + command_args = ["vagrant", "up"] + vagrant_env = os.environ.copy() + vagrant_env["VAGRANT_CWD"] = vagrantdir + + if provision_script is None: + command_args.append("--no-provision") + + log.debug("Starting the VM(s)") + result = subprocess.run(command_args, env=vagrant_env, stderr=subprocess.PIPE) + + if result.returncode != 0: + raise Exception(result.stderr.decode()) + + log.debug("Copying vagrant ssh config") + + ssh_config = os.path.join(vagrantdir, "vagrant-ssh-config") + with open(ssh_config, "w") as f: + subprocess.run(["vagrant", "ssh-config"], env=vagrant_env, stdout=f) + + # Calculate IP addresses + base_ip = get_last_vagrant_ip_address() + start, end = base_ip.rsplit(".", maxsplit=1) + end = int(end) + 1 + + return [ + VagrantVM( + "{}-{}".format(name, i + 1), + "{}.{}".format(start, end % 255), + vagrantdir, + box, + role, + memory, + cpus, + sync_folder, + ) + for i in range(count) + ] diff --git a/cf_remote/ssh.py b/cf_remote/ssh.py index 01bc66d..a7ef5ab 100644 --- a/cf_remote/ssh.py +++ b/cf_remote/ssh.py @@ -9,8 +9,9 @@ from cf_remote import aramid from cf_remote import log from cf_remote import paths -from cf_remote.utils import whoami +from cf_remote.utils import whoami, read_json from cf_remote.aramid import ExecutionResult +from cf_remote.paths import SSH_CONFIG_FPATH, SSH_CONFIGS_JSON_FPATH class LocalConnection: @@ -118,9 +119,22 @@ def __exit__(self, *args, **kwargs): pass +def _build_ssh_config(): + configs = read_json(SSH_CONFIGS_JSON_FPATH) + + with open(SSH_CONFIG_FPATH, "w") as f: + if configs: + for config in configs.values(): + f.write(config) + + def connect(host, users=None): log.debug("Connecting to '%s'" % host) log.debug("users= '%s'" % users) + + log.debug("Building config file") + _build_ssh_config() + parts = urlparse("ssh://%s" % host) host = parts.hostname if not users and parts.username: