Skip to content

Commit b541b48

Browse files
committed
Implement automatic sudo password prompting (#609)
* Implement automatic sudo password prompting for the local connector. This works by executing the command, detecting the sudo failed password message output and then propmts the user for the password and re-executes the command using that. * Implement automatic sudo password prompting for the SSH connector. * Fix command debug log calls. * Add test for automatic sudo prompting in SSH connector.
1 parent d96a011 commit b541b48

File tree

4 files changed

+104
-40
lines changed

4 files changed

+104
-40
lines changed

pyinfra/api/connectors/local.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from pyinfra.api.util import get_file_io
1414

1515
from .util import (
16+
execute_command_with_sudo_retry,
1617
get_sudo_password,
1718
make_unix_command,
1819
run_local_process,
@@ -69,20 +70,25 @@ def run_shell_command(
6970
put_file=put_file,
7071
)
7172

72-
command = make_unix_command(command, state=state, **command_kwargs)
73-
actual_command = command.get_raw_value()
73+
def execute_command():
74+
unix_command = make_unix_command(command, state=state, **command_kwargs)
75+
actual_command = unix_command.get_raw_value()
7476

75-
logger.debug('--> Running command on localhost: {0}'.format(command))
77+
logger.debug('--> Running command on localhost: {0}'.format(unix_command))
7678

77-
if print_input:
78-
click.echo('{0}>>> {1}'.format(host.print_prefix, command), err=True)
79+
if print_input:
80+
click.echo('{0}>>> {1}'.format(host.print_prefix, unix_command), err=True)
7981

80-
return_code, combined_output = run_local_process(
81-
actual_command,
82-
stdin=stdin,
83-
timeout=timeout,
84-
print_output=print_output,
85-
print_prefix=host.print_prefix,
82+
return run_local_process(
83+
actual_command,
84+
stdin=stdin,
85+
timeout=timeout,
86+
print_output=print_output,
87+
print_prefix=host.print_prefix,
88+
)
89+
90+
return_code, combined_output = execute_command_with_sudo_retry(
91+
host, command_kwargs, execute_command,
8692
)
8793

8894
if success_exit_codes:

pyinfra/api/connectors/ssh.py

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
from .sshuserclient import SSHClient
3434
from .util import (
35+
execute_command_with_sudo_retry,
3536
get_sudo_password,
3637
make_unix_command,
3738
read_buffers_into_queue,
@@ -273,41 +274,48 @@ def run_shell_command(
273274
put_file=put_file,
274275
)
275276

276-
command = make_unix_command(command, state=state, **command_kwargs)
277-
actual_command = command.get_raw_value()
277+
def execute_command():
278+
unix_command = make_unix_command(command, state=state, **command_kwargs)
279+
actual_command = unix_command.get_raw_value()
278280

279-
logger.debug('Running command on {0}: (pty={1}) {2}'.format(
280-
host.name, get_pty, command,
281-
))
281+
logger.debug('Running command on {0}: (pty={1}) {2}'.format(
282+
host.name, get_pty, unix_command,
283+
))
282284

283-
if print_input:
284-
click.echo('{0}>>> {1}'.format(host.print_prefix, command), err=True)
285+
if print_input:
286+
click.echo('{0}>>> {1}'.format(host.print_prefix, unix_command), err=True)
285287

286-
# Run it! Get stdout, stderr & the underlying channel
287-
stdin_buffer, stdout_buffer, stderr_buffer = host.connection.exec_command(
288-
actual_command,
289-
get_pty=get_pty,
290-
)
288+
# Run it! Get stdout, stderr & the underlying channel
289+
stdin_buffer, stdout_buffer, stderr_buffer = host.connection.exec_command(
290+
actual_command,
291+
get_pty=get_pty,
292+
)
291293

292-
if stdin:
293-
write_stdin(stdin, stdin_buffer)
294+
if stdin:
295+
write_stdin(stdin, stdin_buffer)
294296

295-
combined_output = read_buffers_into_queue(
296-
stdout_buffer,
297-
stderr_buffer,
298-
timeout=timeout,
299-
print_output=print_output,
300-
print_prefix=host.print_prefix,
301-
)
297+
combined_output = read_buffers_into_queue(
298+
stdout_buffer,
299+
stderr_buffer,
300+
timeout=timeout,
301+
print_output=print_output,
302+
print_prefix=host.print_prefix,
303+
)
304+
305+
logger.debug('Waiting for exit status...')
306+
exit_status = stdout_buffer.channel.recv_exit_status()
307+
logger.debug('Command exit status: {0}'.format(exit_status))
302308

303-
logger.debug('Waiting for exit status...')
304-
exit_status = stdout_buffer.channel.recv_exit_status()
305-
logger.debug('Command exit status: {0}'.format(exit_status))
309+
return exit_status, combined_output
310+
311+
return_code, combined_output = execute_command_with_sudo_retry(
312+
host, command_kwargs, execute_command,
313+
)
306314

307315
if success_exit_codes:
308-
status = exit_status in success_exit_codes
316+
status = return_code in success_exit_codes
309317
else:
310-
status = exit_status == 0
318+
status = return_code == 0
311319

312320
if return_combined_output:
313321
return status, combined_output

pyinfra/api/connectors/util.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,21 @@ def _print(line):
5151
_print(line)
5252

5353

54+
def execute_command_with_sudo_retry(host, command_kwargs, execute_command):
55+
return_code, combined_output = execute_command()
56+
57+
if return_code != 0 and combined_output:
58+
last_line = combined_output[-1][1]
59+
if last_line == 'sudo: a password is required':
60+
command_kwargs['use_sudo_password'] = get_sudo_password(
61+
host,
62+
use_sudo_password=True, # ask for the password
63+
)
64+
return_code, combined_output = execute_command()
65+
66+
return return_code, combined_output
67+
68+
5469
def run_local_process(
5570
command,
5671
stdin=None,
@@ -151,11 +166,11 @@ def write_stdin(stdin, buffer):
151166
buffer.close()
152167

153168

154-
def get_sudo_password(state, host, use_sudo_password, run_shell_command, put_file):
169+
def get_sudo_password(host, use_sudo_password):
155170
sudo_askpass_uploaded = host.connector_data.get('sudo_askpass_uploaded', False)
156171
if not sudo_askpass_uploaded:
157-
put_file(state, host, get_sudo_askpass_exe(), SUDO_ASKPASS_EXE_FILENAME)
158-
run_shell_command(state, host, 'chmod +x {0}'.format(SUDO_ASKPASS_EXE_FILENAME))
172+
host.put_file(get_sudo_askpass_exe(), SUDO_ASKPASS_EXE_FILENAME)
173+
host.run_shell_command('chmod +x {0}'.format(SUDO_ASKPASS_EXE_FILENAME))
159174
host.connector_data['sudo_askpass_uploaded'] = True
160175

161176
if use_sudo_password is True:

tests/test_connectors/test_ssh.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,41 @@ def test_run_shell_command_error(self, fake_ssh_client):
463463
assert len(out) == 3
464464
assert out[0] is False
465465

466+
@patch('pyinfra.api.connectors.ssh.SSHClient')
467+
@patch('pyinfra.api.connectors.util.get_sudo_password')
468+
def test_run_shell_command_retry_for_sudo_password(
469+
self,
470+
fake_get_sudo_password,
471+
fake_ssh_client,
472+
):
473+
fake_get_sudo_password.return_value = ['FILENAME', 'PASSWORD']
474+
475+
fake_ssh = MagicMock()
476+
fake_stdin = MagicMock()
477+
fake_stdout = MagicMock()
478+
fake_stderr = ['sudo: a password is required']
479+
fake_ssh.exec_command.return_value = fake_stdin, fake_stdout, fake_stderr
480+
481+
fake_ssh_client.return_value = fake_ssh
482+
483+
inventory = make_inventory(hosts=('somehost',))
484+
state = State(inventory, Config())
485+
host = inventory.get_host('somehost')
486+
host.connect(state)
487+
488+
command = 'echo hi'
489+
return_values = [1, 0] # return 0 on the second call
490+
fake_stdout.channel.recv_exit_status.side_effect = lambda: return_values.pop(0)
491+
492+
out = host.run_shell_command(command)
493+
assert len(out) == 3
494+
assert out[0] is True
495+
assert fake_get_sudo_password.called
496+
fake_ssh.exec_command.assert_called_with(
497+
"env SUDO_ASKPASS=FILENAME PYINFRA_SUDO_PASSWORD=PASSWORD sh -c 'echo hi'",
498+
get_pty=False,
499+
)
500+
466501
@patch('pyinfra.api.connectors.ssh.SSHClient')
467502
@patch('pyinfra.api.connectors.ssh.SFTPClient')
468503
def test_put_file(self, fake_sftp_client, fake_ssh_client):

0 commit comments

Comments
 (0)