Skip to content

Commit 1bdd4ef

Browse files
authored
Sftp tests (#273)
* Added sftp tests for errors opening local and remote files and directories and sftp init. * Added disconnect exception test. * Test cleanup * Moved polling code to base client * Added libssh client tests * Updated tests for faster test running (60-70% speedup) * Updated setup.cfg
1 parent 6ac89bb commit 1bdd4ef

File tree

11 files changed

+377
-232
lines changed

11 files changed

+377
-232
lines changed

pssh/clients/base/single.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727

2828
from gevent import sleep, socket, Timeout as GTimeout
2929
from gevent.hub import Hub
30-
from gevent.select import poll
30+
from gevent.select import poll, POLLIN, POLLOUT
31+
3132
from ssh2.utils import find_eol
3233
from ssh2.exceptions import AgentConnectionError, AgentListIdentitiesError, \
3334
AgentAuthenticationError, AgentGetIdentityError
@@ -244,9 +245,12 @@ def open_shell(self, encoding='utf-8', read_timeout=None):
244245
def _shell(self, channel):
245246
raise NotImplementedError
246247

248+
def _disconnect_eagain(self):
249+
self._eagain(self.session.disconnect)
250+
247251
def _connect_init_session_retry(self, retries):
248252
try:
249-
self.session.disconnect()
253+
self._disconnect_eagain()
250254
except Exception:
251255
pass
252256
self.session = None
@@ -664,3 +668,15 @@ def _poll_socket(self, events, timeout=None):
664668
poller = poll()
665669
poller.register(self.sock, eventmask=events)
666670
poller.poll(timeout=timeout)
671+
672+
def _poll_errcodes(self, directions_func, inbound, outbound, timeout=None):
673+
timeout = self.timeout if timeout is None else timeout
674+
directions = directions_func()
675+
if directions == 0:
676+
return
677+
events = 0
678+
if directions & inbound:
679+
events = POLLIN
680+
if directions & outbound:
681+
events |= POLLOUT
682+
self._poll_socket(events, timeout=timeout)

pssh/clients/native/single.py

Lines changed: 37 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from warnings import warn
2222

2323
from gevent import sleep, spawn, get_hub
24-
from gevent.select import POLLIN, POLLOUT
2524
from ssh2.error_codes import LIBSSH2_ERROR_EAGAIN
2625
from ssh2.exceptions import SFTPHandleError, SFTPProtocolError, \
2726
Timeout as SSH2Timeout
@@ -163,11 +162,14 @@ def _connect_proxy(self, proxy_host, proxy_port, proxy_pkey,
163162
return proxy_local_port
164163

165164
def disconnect(self):
166-
"""Disconnect session, close socket if needed."""
165+
"""Attempt to disconnect session.
166+
167+
Any errors on calling disconnect are suppressed by this function.
168+
"""
167169
self._keepalive_greenlet = None
168170
if self.session is not None:
169171
try:
170-
self._eagain(self.session.disconnect)
172+
self._disconnect_eagain()
171173
except Exception:
172174
pass
173175
self.session = None
@@ -316,10 +318,13 @@ def close_channel(self, channel):
316318
def _eagain(self, func, *args, **kwargs):
317319
return self._eagain_errcode(func, LIBSSH2_ERROR_EAGAIN, *args, **kwargs)
318320

321+
def _make_sftp_eagain(self):
322+
return self._eagain(self.session.sftp_init)
323+
319324
def _make_sftp(self):
320325
"""Make SFTP client from open transport"""
321326
try:
322-
sftp = self._eagain(self.session.sftp_init)
327+
sftp = self._make_sftp_eagain()
323328
except Exception as ex:
324329
raise SFTPError(ex)
325330
return sftp
@@ -486,6 +491,27 @@ def copy_remote_file(self, remote_file, local_file, recurse=False,
486491
logger.info("Copied local file %s from remote destination %s:%s",
487492
local_file, self.host, remote_file)
488493

494+
def _scp_recv_recursive(self, remote_file, local_file, sftp, encoding='utf-8'):
495+
try:
496+
self._eagain(sftp.stat, remote_file)
497+
except (SFTPHandleError, SFTPProtocolError):
498+
msg = "Remote file or directory %s does not exist"
499+
logger.error(msg, remote_file)
500+
raise SCPError(msg, remote_file)
501+
try:
502+
dir_h = self._sftp_openfh(sftp.opendir, remote_file)
503+
except SFTPError:
504+
# remote_file is not a dir, scp file
505+
return self.scp_recv(remote_file, local_file, encoding=encoding)
506+
try:
507+
os.makedirs(local_file)
508+
except OSError:
509+
pass
510+
file_list = self._sftp_readdir(dir_h)
511+
return self._scp_recv_dir(file_list, remote_file,
512+
local_file, sftp,
513+
encoding=encoding)
514+
489515
def scp_recv(self, remote_file, local_file, recurse=False, sftp=None,
490516
encoding='utf-8'):
491517
"""Copy remote file to local host via SCP.
@@ -505,33 +531,13 @@ def scp_recv(self, remote_file, local_file, recurse=False, sftp=None,
505531
enabled.
506532
:type encoding: str
507533
508-
:raises: :py:class:`pssh.exceptions.SCPError` when a directory is
509-
supplied to ``local_file`` and ``recurse`` is not set.
510534
:raises: :py:class:`pssh.exceptions.SCPError` on errors copying file.
511535
:raises: :py:class:`IOError` on local file IO errors.
512536
:raises: :py:class:`OSError` on local OS errors like permission denied.
513537
"""
514538
if recurse:
515539
sftp = self._make_sftp() if sftp is None else sftp
516-
try:
517-
self._eagain(sftp.stat, remote_file)
518-
except (SFTPHandleError, SFTPProtocolError):
519-
msg = "Remote file or directory %s does not exist"
520-
logger.error(msg, remote_file)
521-
raise SCPError(msg, remote_file)
522-
try:
523-
dir_h = self._sftp_openfh(sftp.opendir, remote_file)
524-
except SFTPError:
525-
pass
526-
else:
527-
try:
528-
os.makedirs(local_file)
529-
except OSError:
530-
pass
531-
file_list = self._sftp_readdir(dir_h)
532-
return self._scp_recv_dir(file_list, remote_file,
533-
local_file, sftp,
534-
encoding=encoding)
540+
return self._scp_recv_recursive(remote_file, local_file, sftp, encoding=encoding)
535541
elif local_file.endswith('/'):
536542
remote_filename = remote_file.rsplit('/')[-1]
537543
local_file += remote_filename
@@ -561,11 +567,6 @@ def _scp_recv(self, remote_file, local_file):
561567
continue
562568
total += size
563569
local_fh.write(data)
564-
if total != fileinfo.st_size:
565-
msg = "Error copying data from remote file %s on host %s. " \
566-
"Copied %s out of %s total bytes"
567-
raise SCPError(msg, remote_file, self.host, total,
568-
fileinfo.st_size)
569570
finally:
570571
local_fh.close()
571572
file_chan.close()
@@ -690,16 +691,12 @@ def poll(self, timeout=None):
690691
Blocks current greenlet only if socket has pending read or write operations
691692
in the appropriate direction.
692693
"""
693-
timeout = self.timeout if timeout is None else timeout
694-
directions = self.session.block_directions()
695-
if directions == 0:
696-
return
697-
events = 0
698-
if directions & LIBSSH2_SESSION_BLOCK_INBOUND:
699-
events = POLLIN
700-
if directions & LIBSSH2_SESSION_BLOCK_OUTBOUND:
701-
events |= POLLOUT
702-
self._poll_socket(events, timeout=timeout)
694+
self._poll_errcodes(
695+
self.session.block_directions,
696+
LIBSSH2_SESSION_BLOCK_INBOUND,
697+
LIBSSH2_SESSION_BLOCK_OUTBOUND,
698+
timeout=timeout,
699+
)
703700

704701
def _eagain_write(self, write_func, data, timeout=None):
705702
"""Write data with given write_func for an ssh2-python session while

pssh/clients/ssh/single.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import logging
1919

2020
from gevent import sleep, spawn, Timeout as GTimeout, joinall
21-
from gevent.select import POLLIN, POLLOUT
2221
from ssh import options
2322
from ssh.session import Session, SSH_READ_PENDING, SSH_WRITE_PENDING
2423
from ssh.key import import_privkey_file, import_cert_file, copy_cert_to_privkey
@@ -145,7 +144,7 @@ def _init_session(self, retries=1):
145144
self.session.set_socket(self.sock)
146145
logger.debug("Session started, connecting with existing socket")
147146
try:
148-
self.session.connect()
147+
self._session_connect()
149148
except Exception as ex:
150149
if retries < self.num_retries:
151150
return self._connect_init_session_retry(retries=retries+1)
@@ -155,6 +154,9 @@ def _init_session(self, retries=1):
155154
ex.port = self.port
156155
raise ex
157156

157+
def _session_connect(self):
158+
self.session.connect()
159+
158160
def auth(self):
159161
if self.gssapi_auth or (self.gssapi_server_identity or self.gssapi_client_identity):
160162
try:
@@ -166,8 +168,6 @@ def auth(self):
166168
return super(SSHClient, self).auth()
167169

168170
def _password_auth(self):
169-
if not self.password:
170-
raise AuthenticationError("All authentication methods failed")
171171
try:
172172
self.session.userauth_password(self.password)
173173
except Exception as ex:
@@ -270,7 +270,6 @@ def wait_finished(self, host_output, timeout=None):
270270
with GTimeout(seconds=timeout, exception=Timeout):
271271
joinall((host_output.buffers.stdout.reader, host_output.buffers.stderr.reader))
272272
logger.debug("Readers finished, closing channel")
273-
# Close channel
274273
self.close_channel(channel)
275274

276275
def finished(self, channel):
@@ -305,16 +304,12 @@ def close_channel(self, channel):
305304

306305
def poll(self, timeout=None):
307306
"""ssh-python based co-operative gevent poll on session socket."""
308-
timeout = self.timeout if timeout is None else timeout
309-
directions = self.session.get_poll_flags()
310-
if directions == 0:
311-
return
312-
events = 0
313-
if directions & SSH_READ_PENDING:
314-
events = POLLIN
315-
if directions & SSH_WRITE_PENDING:
316-
events |= POLLOUT
317-
self._poll_socket(events, timeout=timeout)
307+
self._poll_errcodes(
308+
self.session.get_poll_flags,
309+
SSH_READ_PENDING,
310+
SSH_WRITE_PENDING,
311+
timeout=timeout,
312+
)
318313

319314
def _eagain(self, func, *args, **kwargs):
320315
"""Run function given and handle EAGAIN for an ssh-python session"""

setup.cfg

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,8 @@ universal = 1
99

1010
[flake8]
1111
max-line-length = 100
12+
13+
[tool:pytest]
14+
# addopts=--cov=pssh --cov-append --cov-report=term --cov-report=term-missing
15+
testpaths =
16+
tests

tests/native/base_ssh2_case.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def setUpClass(cls):
6060
pkey=PKEY_FILENAME,
6161
num_retries=1,
6262
identity_auth=False,
63+
retry_delay=.1,
6364
)
6465

6566
@classmethod

0 commit comments

Comments
 (0)