21
21
from warnings import warn
22
22
23
23
from gevent import sleep , spawn , get_hub
24
- from gevent .select import POLLIN , POLLOUT
25
24
from ssh2 .error_codes import LIBSSH2_ERROR_EAGAIN
26
25
from ssh2 .exceptions import SFTPHandleError , SFTPProtocolError , \
27
26
Timeout as SSH2Timeout
@@ -163,11 +162,14 @@ def _connect_proxy(self, proxy_host, proxy_port, proxy_pkey,
163
162
return proxy_local_port
164
163
165
164
def disconnect (self ):
166
- """Disconnect session, close socket if needed."""
165
+ """Attempt to disconnect session.
166
+
167
+ Any errors on calling disconnect are suppressed by this function.
168
+ """
167
169
self ._keepalive_greenlet = None
168
170
if self .session is not None :
169
171
try :
170
- self ._eagain ( self . session . disconnect )
172
+ self ._disconnect_eagain ( )
171
173
except Exception :
172
174
pass
173
175
self .session = None
@@ -316,10 +318,13 @@ def close_channel(self, channel):
316
318
def _eagain (self , func , * args , ** kwargs ):
317
319
return self ._eagain_errcode (func , LIBSSH2_ERROR_EAGAIN , * args , ** kwargs )
318
320
321
+ def _make_sftp_eagain (self ):
322
+ return self ._eagain (self .session .sftp_init )
323
+
319
324
def _make_sftp (self ):
320
325
"""Make SFTP client from open transport"""
321
326
try :
322
- sftp = self ._eagain ( self . session . sftp_init )
327
+ sftp = self ._make_sftp_eagain ( )
323
328
except Exception as ex :
324
329
raise SFTPError (ex )
325
330
return sftp
@@ -486,6 +491,27 @@ def copy_remote_file(self, remote_file, local_file, recurse=False,
486
491
logger .info ("Copied local file %s from remote destination %s:%s" ,
487
492
local_file , self .host , remote_file )
488
493
494
+ def _scp_recv_recursive (self , remote_file , local_file , sftp , encoding = 'utf-8' ):
495
+ try :
496
+ self ._eagain (sftp .stat , remote_file )
497
+ except (SFTPHandleError , SFTPProtocolError ):
498
+ msg = "Remote file or directory %s does not exist"
499
+ logger .error (msg , remote_file )
500
+ raise SCPError (msg , remote_file )
501
+ try :
502
+ dir_h = self ._sftp_openfh (sftp .opendir , remote_file )
503
+ except SFTPError :
504
+ # remote_file is not a dir, scp file
505
+ return self .scp_recv (remote_file , local_file , encoding = encoding )
506
+ try :
507
+ os .makedirs (local_file )
508
+ except OSError :
509
+ pass
510
+ file_list = self ._sftp_readdir (dir_h )
511
+ return self ._scp_recv_dir (file_list , remote_file ,
512
+ local_file , sftp ,
513
+ encoding = encoding )
514
+
489
515
def scp_recv (self , remote_file , local_file , recurse = False , sftp = None ,
490
516
encoding = 'utf-8' ):
491
517
"""Copy remote file to local host via SCP.
@@ -505,33 +531,13 @@ def scp_recv(self, remote_file, local_file, recurse=False, sftp=None,
505
531
enabled.
506
532
:type encoding: str
507
533
508
- :raises: :py:class:`pssh.exceptions.SCPError` when a directory is
509
- supplied to ``local_file`` and ``recurse`` is not set.
510
534
:raises: :py:class:`pssh.exceptions.SCPError` on errors copying file.
511
535
:raises: :py:class:`IOError` on local file IO errors.
512
536
:raises: :py:class:`OSError` on local OS errors like permission denied.
513
537
"""
514
538
if recurse :
515
539
sftp = self ._make_sftp () if sftp is None else sftp
516
- try :
517
- self ._eagain (sftp .stat , remote_file )
518
- except (SFTPHandleError , SFTPProtocolError ):
519
- msg = "Remote file or directory %s does not exist"
520
- logger .error (msg , remote_file )
521
- raise SCPError (msg , remote_file )
522
- try :
523
- dir_h = self ._sftp_openfh (sftp .opendir , remote_file )
524
- except SFTPError :
525
- pass
526
- else :
527
- try :
528
- os .makedirs (local_file )
529
- except OSError :
530
- pass
531
- file_list = self ._sftp_readdir (dir_h )
532
- return self ._scp_recv_dir (file_list , remote_file ,
533
- local_file , sftp ,
534
- encoding = encoding )
540
+ return self ._scp_recv_recursive (remote_file , local_file , sftp , encoding = encoding )
535
541
elif local_file .endswith ('/' ):
536
542
remote_filename = remote_file .rsplit ('/' )[- 1 ]
537
543
local_file += remote_filename
@@ -561,11 +567,6 @@ def _scp_recv(self, remote_file, local_file):
561
567
continue
562
568
total += size
563
569
local_fh .write (data )
564
- if total != fileinfo .st_size :
565
- msg = "Error copying data from remote file %s on host %s. " \
566
- "Copied %s out of %s total bytes"
567
- raise SCPError (msg , remote_file , self .host , total ,
568
- fileinfo .st_size )
569
570
finally :
570
571
local_fh .close ()
571
572
file_chan .close ()
@@ -690,16 +691,12 @@ def poll(self, timeout=None):
690
691
Blocks current greenlet only if socket has pending read or write operations
691
692
in the appropriate direction.
692
693
"""
693
- timeout = self .timeout if timeout is None else timeout
694
- directions = self .session .block_directions ()
695
- if directions == 0 :
696
- return
697
- events = 0
698
- if directions & LIBSSH2_SESSION_BLOCK_INBOUND :
699
- events = POLLIN
700
- if directions & LIBSSH2_SESSION_BLOCK_OUTBOUND :
701
- events |= POLLOUT
702
- self ._poll_socket (events , timeout = timeout )
694
+ self ._poll_errcodes (
695
+ self .session .block_directions ,
696
+ LIBSSH2_SESSION_BLOCK_INBOUND ,
697
+ LIBSSH2_SESSION_BLOCK_OUTBOUND ,
698
+ timeout = timeout ,
699
+ )
703
700
704
701
def _eagain_write (self , write_func , data , timeout = None ):
705
702
"""Write data with given write_func for an ssh2-python session while
0 commit comments