Skip to content

Instantly share code, notes, and snippets.

@offbyone
Created May 16, 2018 15:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save offbyone/2f1181293104466af02ad3209f559741 to your computer and use it in GitHub Desktop.
Save offbyone/2f1181293104466af02ad3209f559741 to your computer and use it in GitHub Desktop.
diff --git a/paramiko/__init__.py b/paramiko/__init__.py
index 4d80bf2..5429cfd 100644
--- a/paramiko/__init__.py
+++ b/paramiko/__init__.py
@@ -21,15 +21,22 @@ import sys
from paramiko._version import __version__, __version_info__
from paramiko.transport import SecurityOptions, Transport
from paramiko.client import (
- SSHClient, MissingHostKeyPolicy, AutoAddPolicy, RejectPolicy,
+ SSHClient,
+ MissingHostKeyPolicy,
+ AutoAddPolicy,
+ RejectPolicy,
WarningPolicy,
)
from paramiko.auth_handler import AuthHandler
from paramiko.ssh_gss import GSSAuth, GSS_AUTH_AVAILABLE, GSS_EXCEPTIONS
from paramiko.channel import Channel, ChannelFile
from paramiko.ssh_exception import (
- SSHException, PasswordRequiredException, BadAuthenticationType,
- ChannelException, BadHostKeyException, AuthenticationException,
+ SSHException,
+ PasswordRequiredException,
+ BadAuthenticationType,
+ ChannelException,
+ BadHostKeyException,
+ AuthenticationException,
ProxyCommandFailure,
)
from paramiko.server import ServerInterface, SubsystemHandler, InteractiveQuery
@@ -54,14 +61,25 @@ from paramiko.config import SSHConfig
from paramiko.proxy import ProxyCommand
from paramiko.common import (
- AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED, OPEN_SUCCEEDED,
- OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, OPEN_FAILED_CONNECT_FAILED,
- OPEN_FAILED_UNKNOWN_CHANNEL_TYPE, OPEN_FAILED_RESOURCE_SHORTAGE,
+ AUTH_SUCCESSFUL,
+ AUTH_PARTIALLY_SUCCESSFUL,
+ AUTH_FAILED,
+ OPEN_SUCCEEDED,
+ OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED,
+ OPEN_FAILED_CONNECT_FAILED,
+ OPEN_FAILED_UNKNOWN_CHANNEL_TYPE,
+ OPEN_FAILED_RESOURCE_SHORTAGE,
)
from paramiko.sftp import (
- SFTP_OK, SFTP_EOF, SFTP_NO_SUCH_FILE, SFTP_PERMISSION_DENIED, SFTP_FAILURE,
- SFTP_BAD_MESSAGE, SFTP_NO_CONNECTION, SFTP_CONNECTION_LOST,
+ SFTP_OK,
+ SFTP_EOF,
+ SFTP_NO_SUCH_FILE,
+ SFTP_PERMISSION_DENIED,
+ SFTP_FAILURE,
+ SFTP_BAD_MESSAGE,
+ SFTP_NO_CONNECTION,
+ SFTP_CONNECTION_LOST,
SFTP_OP_UNSUPPORTED,
)
@@ -72,43 +90,43 @@ __author__ = "Jeff Forcier <jeff@bitprophet.org>"
__license__ = "GNU Lesser General Public License (LGPL)"
__all__ = [
- 'Transport',
- 'SSHClient',
- 'MissingHostKeyPolicy',
- 'AutoAddPolicy',
- 'RejectPolicy',
- 'WarningPolicy',
- 'SecurityOptions',
- 'SubsystemHandler',
- 'Channel',
- 'PKey',
- 'RSAKey',
- 'DSSKey',
- 'ECDSAKey',
- 'Ed25519Key',
- 'Message',
- 'SSHException',
- 'AuthenticationException',
- 'PasswordRequiredException',
- 'BadAuthenticationType',
- 'ChannelException',
- 'BadHostKeyException',
- 'ProxyCommand',
- 'ProxyCommandFailure',
- 'SFTP',
- 'SFTPFile',
- 'SFTPHandle',
- 'SFTPClient',
- 'SFTPServer',
- 'SFTPError',
- 'SFTPAttributes',
- 'SFTPServerInterface',
- 'ServerInterface',
- 'BufferedFile',
- 'Agent',
- 'AgentKey',
- 'HostKeys',
- 'SSHConfig',
- 'util',
- 'io_sleep',
+ "Transport",
+ "SSHClient",
+ "MissingHostKeyPolicy",
+ "AutoAddPolicy",
+ "RejectPolicy",
+ "WarningPolicy",
+ "SecurityOptions",
+ "SubsystemHandler",
+ "Channel",
+ "PKey",
+ "RSAKey",
+ "DSSKey",
+ "ECDSAKey",
+ "Ed25519Key",
+ "Message",
+ "SSHException",
+ "AuthenticationException",
+ "PasswordRequiredException",
+ "BadAuthenticationType",
+ "ChannelException",
+ "BadHostKeyException",
+ "ProxyCommand",
+ "ProxyCommandFailure",
+ "SFTP",
+ "SFTPFile",
+ "SFTPHandle",
+ "SFTPClient",
+ "SFTPServer",
+ "SFTPError",
+ "SFTPAttributes",
+ "SFTPServerInterface",
+ "ServerInterface",
+ "BufferedFile",
+ "Agent",
+ "AgentKey",
+ "HostKeys",
+ "SSHConfig",
+ "util",
+ "io_sleep",
]
diff --git a/paramiko/_version.py b/paramiko/_version.py
index a5db89b..95e86b6 100644
--- a/paramiko/_version.py
+++ b/paramiko/_version.py
@@ -1,2 +1,2 @@
__version_info__ = (2, 4, 1)
-__version__ = '.'.join(map(str, __version_info__))
+__version__ = ".".join(map(str, __version_info__))
diff --git a/paramiko/_winapi.py b/paramiko/_winapi.py
index a13d7e8..c996ec4 100644
--- a/paramiko/_winapi.py
+++ b/paramiko/_winapi.py
@@ -15,6 +15,7 @@ from paramiko.py3compat import u, builtins
######################
# jaraco.windows.error
+
def format_system_message(errno):
"""
Call FormatMessage with a system error number to retrieve
@@ -77,7 +78,7 @@ class WindowsError(builtins.WindowsError):
return self.message
def __repr__(self):
- return '{self.__class__.__name__}({self.winerror})'.format(**vars())
+ return "{self.__class__.__name__}({self.winerror})".format(**vars())
def handle_nonzero_success(result):
@@ -124,11 +125,7 @@ UnmapViewOfFile = ctypes.windll.kernel32.UnmapViewOfFile
UnmapViewOfFile.argtypes = ctypes.wintypes.HANDLE,
RtlMoveMemory = ctypes.windll.kernel32.RtlMoveMemory
-RtlMoveMemory.argtypes = (
- ctypes.c_void_p,
- ctypes.c_void_p,
- ctypes.c_size_t,
-)
+RtlMoveMemory.argtypes = (ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t)
ctypes.windll.kernel32.LocalFree.argtypes = ctypes.wintypes.HLOCAL,
@@ -140,6 +137,7 @@ class MemoryMap(object):
"""
A memory map object which can have security attributes overridden.
"""
+
def __init__(self, name, length, security_attributes=None):
self.name = name
self.length = length
@@ -149,14 +147,20 @@ class MemoryMap(object):
def __enter__(self):
p_SA = (
ctypes.byref(self.security_attributes)
- if self.security_attributes else None
+ if self.security_attributes
+ else None
)
INVALID_HANDLE_VALUE = -1
PAGE_READWRITE = 0x4
FILE_MAP_WRITE = 0x2
filemap = ctypes.windll.kernel32.CreateFileMappingW(
- INVALID_HANDLE_VALUE, p_SA, PAGE_READWRITE, 0, self.length,
- u(self.name))
+ INVALID_HANDLE_VALUE,
+ p_SA,
+ PAGE_READWRITE,
+ 0,
+ self.length,
+ u(self.name),
+ )
handle_nonzero_success(filemap)
if filemap == INVALID_HANDLE_VALUE:
raise Exception("Failed to create file mapping")
@@ -220,41 +224,45 @@ POLICY_LOOKUP_NAMES = 0x00000800
POLICY_NOTIFICATION = 0x00001000
POLICY_ALL_ACCESS = (
- STANDARD_RIGHTS_REQUIRED |
- POLICY_VIEW_LOCAL_INFORMATION |
- POLICY_VIEW_AUDIT_INFORMATION |
- POLICY_GET_PRIVATE_INFORMATION |
- POLICY_TRUST_ADMIN |
- POLICY_CREATE_ACCOUNT |
- POLICY_CREATE_SECRET |
- POLICY_CREATE_PRIVILEGE |
- POLICY_SET_DEFAULT_QUOTA_LIMITS |
- POLICY_SET_AUDIT_REQUIREMENTS |
- POLICY_AUDIT_LOG_ADMIN |
- POLICY_SERVER_ADMIN |
- POLICY_LOOKUP_NAMES)
+ STANDARD_RIGHTS_REQUIRED
+ | POLICY_VIEW_LOCAL_INFORMATION
+ | POLICY_VIEW_AUDIT_INFORMATION
+ | POLICY_GET_PRIVATE_INFORMATION
+ | POLICY_TRUST_ADMIN
+ | POLICY_CREATE_ACCOUNT
+ | POLICY_CREATE_SECRET
+ | POLICY_CREATE_PRIVILEGE
+ | POLICY_SET_DEFAULT_QUOTA_LIMITS
+ | POLICY_SET_AUDIT_REQUIREMENTS
+ | POLICY_AUDIT_LOG_ADMIN
+ | POLICY_SERVER_ADMIN
+ | POLICY_LOOKUP_NAMES
+)
POLICY_READ = (
- STANDARD_RIGHTS_READ |
- POLICY_VIEW_AUDIT_INFORMATION |
- POLICY_GET_PRIVATE_INFORMATION)
+ STANDARD_RIGHTS_READ
+ | POLICY_VIEW_AUDIT_INFORMATION
+ | POLICY_GET_PRIVATE_INFORMATION
+)
POLICY_WRITE = (
- STANDARD_RIGHTS_WRITE |
- POLICY_TRUST_ADMIN |
- POLICY_CREATE_ACCOUNT |
- POLICY_CREATE_SECRET |
- POLICY_CREATE_PRIVILEGE |
- POLICY_SET_DEFAULT_QUOTA_LIMITS |
- POLICY_SET_AUDIT_REQUIREMENTS |
- POLICY_AUDIT_LOG_ADMIN |
- POLICY_SERVER_ADMIN)
+ STANDARD_RIGHTS_WRITE
+ | POLICY_TRUST_ADMIN
+ | POLICY_CREATE_ACCOUNT
+ | POLICY_CREATE_SECRET
+ | POLICY_CREATE_PRIVILEGE
+ | POLICY_SET_DEFAULT_QUOTA_LIMITS
+ | POLICY_SET_AUDIT_REQUIREMENTS
+ | POLICY_AUDIT_LOG_ADMIN
+ | POLICY_SERVER_ADMIN
+)
POLICY_EXECUTE = (
- STANDARD_RIGHTS_EXECUTE |
- POLICY_VIEW_LOCAL_INFORMATION |
- POLICY_LOOKUP_NAMES)
+ STANDARD_RIGHTS_EXECUTE
+ | POLICY_VIEW_LOCAL_INFORMATION
+ | POLICY_LOOKUP_NAMES
+)
class TokenAccess:
@@ -268,8 +276,7 @@ class TokenInformationClass:
class TOKEN_USER(ctypes.Structure):
num = 1
_fields_ = [
- ('SID', ctypes.c_void_p),
- ('ATTRIBUTES', ctypes.wintypes.DWORD),
+ ("SID", ctypes.c_void_p), ("ATTRIBUTES", ctypes.wintypes.DWORD)
]
@@ -290,13 +297,13 @@ class SECURITY_DESCRIPTOR(ctypes.Structure):
REVISION = 1
_fields_ = [
- ('Revision', ctypes.c_ubyte),
- ('Sbz1', ctypes.c_ubyte),
- ('Control', SECURITY_DESCRIPTOR_CONTROL),
- ('Owner', ctypes.c_void_p),
- ('Group', ctypes.c_void_p),
- ('Sacl', ctypes.c_void_p),
- ('Dacl', ctypes.c_void_p),
+ ("Revision", ctypes.c_ubyte),
+ ("Sbz1", ctypes.c_ubyte),
+ ("Control", SECURITY_DESCRIPTOR_CONTROL),
+ ("Owner", ctypes.c_void_p),
+ ("Group", ctypes.c_void_p),
+ ("Sacl", ctypes.c_void_p),
+ ("Dacl", ctypes.c_void_p),
]
@@ -309,9 +316,9 @@ class SECURITY_ATTRIBUTES(ctypes.Structure):
} SECURITY_ATTRIBUTES;
"""
_fields_ = [
- ('nLength', ctypes.wintypes.DWORD),
- ('lpSecurityDescriptor', ctypes.c_void_p),
- ('bInheritHandle', ctypes.wintypes.BOOL),
+ ("nLength", ctypes.wintypes.DWORD),
+ ("lpSecurityDescriptor", ctypes.c_void_p),
+ ("bInheritHandle", ctypes.wintypes.BOOL),
]
def __init__(self, *args, **kwargs):
@@ -329,9 +336,7 @@ class SECURITY_ATTRIBUTES(ctypes.Structure):
ctypes.windll.advapi32.SetSecurityDescriptorOwner.argtypes = (
- ctypes.POINTER(SECURITY_DESCRIPTOR),
- ctypes.c_void_p,
- ctypes.wintypes.BOOL,
+ ctypes.POINTER(SECURITY_DESCRIPTOR), ctypes.c_void_p, ctypes.wintypes.BOOL
)
#########################
@@ -343,21 +348,30 @@ def GetTokenInformation(token, information_class):
Given a token, get the token information for it.
"""
data_size = ctypes.wintypes.DWORD()
- ctypes.windll.advapi32.GetTokenInformation(token, information_class.num,
- 0, 0, ctypes.byref(data_size))
+ ctypes.windll.advapi32.GetTokenInformation(
+ token, information_class.num, 0, 0, ctypes.byref(data_size)
+ )
data = ctypes.create_string_buffer(data_size.value)
- handle_nonzero_success(ctypes.windll.advapi32.GetTokenInformation(token,
- information_class.num,
- ctypes.byref(data), ctypes.sizeof(data),
- ctypes.byref(data_size)))
+ handle_nonzero_success(
+ ctypes.windll.advapi32.GetTokenInformation(
+ token,
+ information_class.num,
+ ctypes.byref(data),
+ ctypes.sizeof(data),
+ ctypes.byref(data_size),
+ )
+ )
return ctypes.cast(data, ctypes.POINTER(TOKEN_USER)).contents
def OpenProcessToken(proc_handle, access):
result = ctypes.wintypes.HANDLE()
proc_handle = ctypes.wintypes.HANDLE(proc_handle)
- handle_nonzero_success(ctypes.windll.advapi32.OpenProcessToken(
- proc_handle, access, ctypes.byref(result)))
+ handle_nonzero_success(
+ ctypes.windll.advapi32.OpenProcessToken(
+ proc_handle, access, ctypes.byref(result)
+ )
+ )
return result
@@ -366,8 +380,7 @@ def get_current_user():
Return a TOKEN_USER for the owner of this process.
"""
process = OpenProcessToken(
- ctypes.windll.kernel32.GetCurrentProcess(),
- TokenAccess.TOKEN_QUERY,
+ ctypes.windll.kernel32.GetCurrentProcess(), TokenAccess.TOKEN_QUERY
)
return GetTokenInformation(process, TOKEN_USER)
@@ -389,8 +402,10 @@ def get_security_attributes_for_user(user=None):
SA.descriptor = SD
SA.bInheritHandle = 1
- ctypes.windll.advapi32.InitializeSecurityDescriptor(ctypes.byref(SD),
- SECURITY_DESCRIPTOR.REVISION)
- ctypes.windll.advapi32.SetSecurityDescriptorOwner(ctypes.byref(SD),
- user.SID, 0)
+ ctypes.windll.advapi32.InitializeSecurityDescriptor(
+ ctypes.byref(SD), SECURITY_DESCRIPTOR.REVISION
+ )
+ ctypes.windll.advapi32.SetSecurityDescriptorOwner(
+ ctypes.byref(SD), user.SID, 0
+ )
return SA
diff --git a/paramiko/agent.py b/paramiko/agent.py
index 7a4dde2..00baf85 100644
--- a/paramiko/agent.py
+++ b/paramiko/agent.py
@@ -43,8 +43,8 @@ cSSH2_AGENTC_SIGN_REQUEST = byte_chr(13)
SSH2_AGENT_SIGN_RESPONSE = 14
-
class AgentSSH(object):
+
def __init__(self):
self._conn = None
self._keys = ()
@@ -65,7 +65,7 @@ class AgentSSH(object):
self._conn = conn
ptype, result = self._send_message(cSSH2_AGENTC_REQUEST_IDENTITIES)
if ptype != SSH2_AGENT_IDENTITIES_ANSWER:
- raise SSHException('could not get keys from ssh-agent')
+ raise SSHException("could not get keys from ssh-agent")
keys = []
for i in range(result.get_int()):
keys.append(AgentKey(self, result.get_binary()))
@@ -80,19 +80,19 @@ class AgentSSH(object):
def _send_message(self, msg):
msg = asbytes(msg)
- self._conn.send(struct.pack('>I', len(msg)) + msg)
+ self._conn.send(struct.pack(">I", len(msg)) + msg)
l = self._read_all(4)
- msg = Message(self._read_all(struct.unpack('>I', l)[0]))
+ msg = Message(self._read_all(struct.unpack(">I", l)[0]))
return ord(msg.get_byte()), msg
def _read_all(self, wanted):
result = self._conn.recv(wanted)
while len(result) < wanted:
if len(result) == 0:
- raise SSHException('lost ssh-agent')
+ raise SSHException("lost ssh-agent")
extra = self._conn.recv(wanted - len(result))
if len(extra) == 0:
- raise SSHException('lost ssh-agent')
+ raise SSHException("lost ssh-agent")
result += extra
return result
@@ -101,6 +101,7 @@ class AgentProxyThread(threading.Thread):
"""
Class in charge of communication between two channels.
"""
+
def __init__(self, agent):
threading.Thread.__init__(self, target=self.run)
self._agent = agent
@@ -116,10 +117,10 @@ class AgentProxyThread(threading.Thread):
self.__addr = addr
self._agent.connect()
if (
- not isinstance(self._agent, int) and
- (
- self._agent._conn is None or
- not hasattr(self._agent._conn, 'fileno')
+ not isinstance(self._agent, int)
+ and (
+ self._agent._conn is None
+ or not hasattr(self._agent._conn, "fileno")
)
):
raise AuthenticationException("Unable to connect to SSH agent")
@@ -130,6 +131,7 @@ class AgentProxyThread(threading.Thread):
def _communicate(self):
import fcntl
+
oldflags = fcntl.fcntl(self.__inr, fcntl.F_GETFL)
fcntl.fcntl(self.__inr, fcntl.F_SETFL, oldflags | os.O_NONBLOCK)
while not self._exit:
@@ -162,6 +164,7 @@ class AgentLocalProxy(AgentProxyThread):
Class to be used when wanting to ask a local SSH Agent being
asked from a remote fake agent (so use a unix socket for ex.)
"""
+
def __init__(self, agent):
AgentProxyThread.__init__(self, agent)
@@ -185,6 +188,7 @@ class AgentRemoteProxy(AgentProxyThread):
"""
Class to be used when wanting to ask a remote SSH Agent
"""
+
def __init__(self, agent, chan):
AgentProxyThread.__init__(self, agent)
self.__chan = chan
@@ -205,6 +209,7 @@ class AgentClientProxy(object):
the remote fake agent and the local agent
#. Communication occurs ...
"""
+
def __init__(self, chanRemote):
self._conn = None
self.__chanR = chanRemote
@@ -218,16 +223,18 @@ class AgentClientProxy(object):
"""
Method automatically called by ``AgentProxyThread.run``.
"""
- if ('SSH_AUTH_SOCK' in os.environ) and (sys.platform != 'win32'):
+ if ("SSH_AUTH_SOCK" in os.environ) and (sys.platform != "win32"):
conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
retry_on_signal(
- lambda: conn.connect(os.environ['SSH_AUTH_SOCK']))
+ lambda: conn.connect(os.environ["SSH_AUTH_SOCK"])
+ )
except:
# probably a dangling env var: the ssh agent is gone
return
- elif sys.platform == 'win32':
+ elif sys.platform == "win32":
import paramiko.win_pageant as win_pageant
+
if win_pageant.can_talk_to_agent():
conn = win_pageant.PageantConnection()
else:
@@ -255,12 +262,13 @@ class AgentServerProxy(AgentSSH):
:raises: `.SSHException` -- mostly if we lost the agent
"""
+
def __init__(self, t):
AgentSSH.__init__(self)
self.__t = t
- self._dir = tempfile.mkdtemp('sshproxy')
+ self._dir = tempfile.mkdtemp("sshproxy")
os.chmod(self._dir, stat.S_IRWXU)
- self._file = self._dir + '/sshproxy.ssh'
+ self._file = self._dir + "/sshproxy.ssh"
self.thread = AgentLocalProxy(self)
self.thread.start()
@@ -270,8 +278,8 @@ class AgentServerProxy(AgentSSH):
def connect(self):
conn_sock = self.__t.open_forward_agent_channel()
if conn_sock is None:
- raise SSHException('lost ssh-agent')
- conn_sock.set_name('auth-agent')
+ raise SSHException("lost ssh-agent")
+ conn_sock.set_name("auth-agent")
self._connect(conn_sock)
def close(self):
@@ -292,7 +300,7 @@ class AgentServerProxy(AgentSSH):
:return:
a dict containing the ``SSH_AUTH_SOCK`` environnement variables
"""
- return {'SSH_AUTH_SOCK': self._get_filename()}
+ return {"SSH_AUTH_SOCK": self._get_filename()}
def _get_filename(self):
return self._file
@@ -319,6 +327,7 @@ class AgentRequestHandler(object):
# the remote end.
session.exec_command("git clone https://my.git.repository/")
"""
+
def __init__(self, chanClient):
self._conn = None
self.__chanC = chanClient
@@ -350,18 +359,20 @@ class Agent(AgentSSH):
:raises: `.SSHException` --
if an SSH agent is found, but speaks an incompatible protocol
"""
+
def __init__(self):
AgentSSH.__init__(self)
- if ('SSH_AUTH_SOCK' in os.environ) and (sys.platform != 'win32'):
+ if ("SSH_AUTH_SOCK" in os.environ) and (sys.platform != "win32"):
conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
- conn.connect(os.environ['SSH_AUTH_SOCK'])
+ conn.connect(os.environ["SSH_AUTH_SOCK"])
except:
# probably a dangling env var: the ssh agent is gone
return
- elif sys.platform == 'win32':
+ elif sys.platform == "win32":
from . import win_pageant
+
if win_pageant.can_talk_to_agent():
conn = win_pageant.PageantConnection()
else:
@@ -384,6 +395,7 @@ class AgentKey(PKey):
authenticating to a remote server (signing). Most other key operations
work as expected.
"""
+
def __init__(self, agent, blob):
self.agent = agent
self.blob = blob
@@ -407,5 +419,5 @@ class AgentKey(PKey):
msg.add_int(0)
ptype, result = self.agent._send_message(msg)
if ptype != SSH2_AGENT_SIGN_RESPONSE:
- raise SSHException('key cannot be used for signing')
+ raise SSHException("key cannot be used for signing")
return result.get_binary()
diff --git a/paramiko/auth_handler.py b/paramiko/auth_handler.py
index 3b894de..1bea21b 100644
--- a/paramiko/auth_handler.py
+++ b/paramiko/auth_handler.py
@@ -24,31 +24,55 @@ import weakref
import time
from paramiko.common import (
- cMSG_SERVICE_REQUEST, cMSG_DISCONNECT, DISCONNECT_SERVICE_NOT_AVAILABLE,
- DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE, cMSG_USERAUTH_REQUEST,
- cMSG_SERVICE_ACCEPT, DEBUG, AUTH_SUCCESSFUL, INFO, cMSG_USERAUTH_SUCCESS,
- cMSG_USERAUTH_FAILURE, AUTH_PARTIALLY_SUCCESSFUL,
- cMSG_USERAUTH_INFO_REQUEST, WARNING, AUTH_FAILED, cMSG_USERAUTH_PK_OK,
- cMSG_USERAUTH_INFO_RESPONSE, MSG_SERVICE_REQUEST, MSG_SERVICE_ACCEPT,
- MSG_USERAUTH_REQUEST, MSG_USERAUTH_SUCCESS, MSG_USERAUTH_FAILURE,
- MSG_USERAUTH_BANNER, MSG_USERAUTH_INFO_REQUEST, MSG_USERAUTH_INFO_RESPONSE,
- cMSG_USERAUTH_GSSAPI_RESPONSE, cMSG_USERAUTH_GSSAPI_TOKEN,
- cMSG_USERAUTH_GSSAPI_MIC, MSG_USERAUTH_GSSAPI_RESPONSE,
- MSG_USERAUTH_GSSAPI_TOKEN, MSG_USERAUTH_GSSAPI_ERROR,
- MSG_USERAUTH_GSSAPI_ERRTOK, MSG_USERAUTH_GSSAPI_MIC, MSG_NAMES,
- cMSG_USERAUTH_BANNER
+ cMSG_SERVICE_REQUEST,
+ cMSG_DISCONNECT,
+ DISCONNECT_SERVICE_NOT_AVAILABLE,
+ DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE,
+ cMSG_USERAUTH_REQUEST,
+ cMSG_SERVICE_ACCEPT,
+ DEBUG,
+ AUTH_SUCCESSFUL,
+ INFO,
+ cMSG_USERAUTH_SUCCESS,
+ cMSG_USERAUTH_FAILURE,
+ AUTH_PARTIALLY_SUCCESSFUL,
+ cMSG_USERAUTH_INFO_REQUEST,
+ WARNING,
+ AUTH_FAILED,
+ cMSG_USERAUTH_PK_OK,
+ cMSG_USERAUTH_INFO_RESPONSE,
+ MSG_SERVICE_REQUEST,
+ MSG_SERVICE_ACCEPT,
+ MSG_USERAUTH_REQUEST,
+ MSG_USERAUTH_SUCCESS,
+ MSG_USERAUTH_FAILURE,
+ MSG_USERAUTH_BANNER,
+ MSG_USERAUTH_INFO_REQUEST,
+ MSG_USERAUTH_INFO_RESPONSE,
+ cMSG_USERAUTH_GSSAPI_RESPONSE,
+ cMSG_USERAUTH_GSSAPI_TOKEN,
+ cMSG_USERAUTH_GSSAPI_MIC,
+ MSG_USERAUTH_GSSAPI_RESPONSE,
+ MSG_USERAUTH_GSSAPI_TOKEN,
+ MSG_USERAUTH_GSSAPI_ERROR,
+ MSG_USERAUTH_GSSAPI_ERRTOK,
+ MSG_USERAUTH_GSSAPI_MIC,
+ MSG_NAMES,
+ cMSG_USERAUTH_BANNER,
)
from paramiko.message import Message
from paramiko.py3compat import b
from paramiko.ssh_exception import (
- SSHException, AuthenticationException, BadAuthenticationType,
+ SSHException,
+ AuthenticationException,
+ BadAuthenticationType,
PartialAuthentication,
)
from paramiko.server import InteractiveQuery
from paramiko.ssh_gss import GSSAuth, GSS_EXCEPTIONS
-class AuthHandler (object):
+class AuthHandler(object):
"""
Internal class to handle the mechanics of authentication.
"""
@@ -58,7 +82,7 @@ class AuthHandler (object):
self.username = None
self.authenticated = False
self.auth_event = None
- self.auth_method = ''
+ self.auth_method = ""
self.banner = None
self.password = None
self.private_key = None
@@ -87,7 +111,7 @@ class AuthHandler (object):
self.transport.lock.acquire()
try:
self.auth_event = event
- self.auth_method = 'none'
+ self.auth_method = "none"
self.username = username
self._request_auth()
finally:
@@ -97,7 +121,7 @@ class AuthHandler (object):
self.transport.lock.acquire()
try:
self.auth_event = event
- self.auth_method = 'publickey'
+ self.auth_method = "publickey"
self.username = username
self.private_key = key
self._request_auth()
@@ -108,21 +132,21 @@ class AuthHandler (object):
self.transport.lock.acquire()
try:
self.auth_event = event
- self.auth_method = 'password'
+ self.auth_method = "password"
self.username = username
self.password = password
self._request_auth()
finally:
self.transport.lock.release()
- def auth_interactive(self, username, handler, event, submethods=''):
+ def auth_interactive(self, username, handler, event, submethods=""):
"""
response_list = handler(title, instructions, prompt_list)
"""
self.transport.lock.acquire()
try:
self.auth_event = event
- self.auth_method = 'keyboard-interactive'
+ self.auth_method = "keyboard-interactive"
self.username = username
self.interactive_handler = handler
self.submethods = submethods
@@ -134,7 +158,7 @@ class AuthHandler (object):
self.transport.lock.acquire()
try:
self.auth_event = event
- self.auth_method = 'gssapi-with-mic'
+ self.auth_method = "gssapi-with-mic"
self.username = username
self.gss_host = gss_host
self.gss_deleg_creds = gss_deleg_creds
@@ -146,7 +170,7 @@ class AuthHandler (object):
self.transport.lock.acquire()
try:
self.auth_event = event
- self.auth_method = 'gssapi-keyex'
+ self.auth_method = "gssapi-keyex"
self.username = username
self._request_auth()
finally:
@@ -161,15 +185,15 @@ class AuthHandler (object):
def _request_auth(self):
m = Message()
m.add_byte(cMSG_SERVICE_REQUEST)
- m.add_string('ssh-userauth')
+ m.add_string("ssh-userauth")
self.transport._send_message(m)
def _disconnect_service_not_available(self):
m = Message()
m.add_byte(cMSG_DISCONNECT)
m.add_int(DISCONNECT_SERVICE_NOT_AVAILABLE)
- m.add_string('Service not available')
- m.add_string('en')
+ m.add_string("Service not available")
+ m.add_string("en")
self.transport._send_message(m)
self.transport.close()
@@ -177,8 +201,8 @@ class AuthHandler (object):
m = Message()
m.add_byte(cMSG_DISCONNECT)
m.add_int(DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE)
- m.add_string('No more auth methods available')
- m.add_string('en')
+ m.add_string("No more auth methods available")
+ m.add_string("en")
self.transport._send_message(m)
self.transport.close()
@@ -188,7 +212,7 @@ class AuthHandler (object):
m.add_byte(cMSG_USERAUTH_REQUEST)
m.add_string(username)
m.add_string(service)
- m.add_string('publickey')
+ m.add_string("publickey")
m.add_boolean(True)
# Use certificate contents, if available, plain pubkey otherwise
if key.public_blob:
@@ -208,17 +232,17 @@ class AuthHandler (object):
if not self.transport.is_active():
e = self.transport.get_exception()
if (e is None) or issubclass(e.__class__, EOFError):
- e = AuthenticationException('Authentication failed.')
+ e = AuthenticationException("Authentication failed.")
raise e
if event.is_set():
break
if max_ts is not None and max_ts <= time.time():
- raise AuthenticationException('Authentication timeout.')
+ raise AuthenticationException("Authentication timeout.")
if not self.is_authenticated():
e = self.transport.get_exception()
if e is None:
- e = AuthenticationException('Authentication failed.')
+ e = AuthenticationException("Authentication failed.")
# this is horrible. Python Exception isn't yet descended from
# object, so type(e) won't work. :(
if issubclass(e.__class__, PartialAuthentication):
@@ -228,7 +252,7 @@ class AuthHandler (object):
def _parse_service_request(self, m):
service = m.get_text()
- if self.transport.server_mode and (service == 'ssh-userauth'):
+ if self.transport.server_mode and (service == "ssh-userauth"):
# accepted
m = Message()
m.add_byte(cMSG_SERVICE_ACCEPT)
@@ -247,18 +271,18 @@ class AuthHandler (object):
def _parse_service_accept(self, m):
service = m.get_text()
- if service == 'ssh-userauth':
- self._log(DEBUG, 'userauth is OK')
+ if service == "ssh-userauth":
+ self._log(DEBUG, "userauth is OK")
m = Message()
m.add_byte(cMSG_USERAUTH_REQUEST)
m.add_string(self.username)
- m.add_string('ssh-connection')
+ m.add_string("ssh-connection")
m.add_string(self.auth_method)
- if self.auth_method == 'password':
+ if self.auth_method == "password":
m.add_boolean(False)
password = b(self.password)
m.add_string(password)
- elif self.auth_method == 'publickey':
+ elif self.auth_method == "publickey":
m.add_boolean(True)
# Use certificate contents, if available, plain pubkey
# otherwise
@@ -269,11 +293,12 @@ class AuthHandler (object):
m.add_string(self.private_key.get_name())
m.add_string(self.private_key)
blob = self._get_session_blob(
- self.private_key, 'ssh-connection', self.username)
+ self.private_key, "ssh-connection", self.username
+ )
sig = self.private_key.sign_ssh_data(blob)
m.add_string(sig)
- elif self.auth_method == 'keyboard-interactive':
- m.add_string('')
+ elif self.auth_method == "keyboard-interactive":
+ m.add_string("")
m.add_string(self.submethods)
elif self.auth_method == "gssapi-with-mic":
sshgss = GSSAuth(self.auth_method, self.gss_deleg_creds)
@@ -292,10 +317,11 @@ class AuthHandler (object):
m = Message()
m.add_byte(cMSG_USERAUTH_GSSAPI_TOKEN)
try:
- m.add_string(sshgss.ssh_init_sec_context(
- self.gss_host,
- mech,
- self.username,))
+ m.add_string(
+ sshgss.ssh_init_sec_context(
+ self.gss_host, mech, self.username
+ )
+ )
except GSS_EXCEPTIONS as e:
return self._handle_local_gss_failure(e)
self.transport._send_message(m)
@@ -308,7 +334,8 @@ class AuthHandler (object):
self.gss_host,
mech,
self.username,
- srv_token)
+ srv_token,
+ )
except GSS_EXCEPTIONS as e:
return self._handle_local_gss_failure(e)
# After this step the GSSAPI should not return any
@@ -340,48 +367,55 @@ class AuthHandler (object):
min_status = m.get_int()
err_msg = m.get_string()
m.get_string() # Lang tag - discarded
- raise SSHException("""GSS-API Error:
+ raise SSHException(
+ """GSS-API Error:
Major Status: {}
Minor Status: {}
Error Message: {}
-""".format(maj_status, min_status, err_msg))
+""".format(
+ maj_status, min_status, err_msg
+ )
+ )
elif ptype == MSG_USERAUTH_FAILURE:
self._parse_userauth_failure(m)
return
else:
raise SSHException(
- "Received Package: {}".format(MSG_NAMES[ptype]))
+ "Received Package: {}".format(MSG_NAMES[ptype])
+ )
elif (
- self.auth_method == 'gssapi-keyex' and
- self.transport.gss_kex_used
+ self.auth_method == "gssapi-keyex"
+ and self.transport.gss_kex_used
):
kexgss = self.transport.kexgss_ctxt
kexgss.set_username(self.username)
mic_token = kexgss.ssh_get_mic(self.transport.session_id)
m.add_string(mic_token)
- elif self.auth_method == 'none':
+ elif self.auth_method == "none":
pass
else:
raise SSHException(
- 'Unknown auth method "{}"'.format(self.auth_method))
+ 'Unknown auth method "{}"'.format(self.auth_method)
+ )
self.transport._send_message(m)
else:
self._log(
- DEBUG,
- 'Service request "{}" accepted (?)'.format(service))
+ DEBUG, 'Service request "{}" accepted (?)'.format(service)
+ )
def _send_auth_result(self, username, method, result):
# okay, send result
m = Message()
if result == AUTH_SUCCESSFUL:
- self._log(INFO, 'Auth granted ({}).'.format(method))
+ self._log(INFO, "Auth granted ({}).".format(method))
m.add_byte(cMSG_USERAUTH_SUCCESS)
self.authenticated = True
else:
- self._log(INFO, 'Auth rejected ({}).'.format(method))
+ self._log(INFO, "Auth rejected ({}).".format(method))
m.add_byte(cMSG_USERAUTH_FAILURE)
m.add_string(
- self.transport.server_object.get_allowed_auths(username))
+ self.transport.server_object.get_allowed_auths(username)
+ )
if result == AUTH_PARTIALLY_SUCCESSFUL:
m.add_boolean(True)
else:
@@ -411,7 +445,7 @@ Error Message: {}
# er, uh... what?
m = Message()
m.add_byte(cMSG_USERAUTH_FAILURE)
- m.add_string('none')
+ m.add_string("none")
m.add_boolean(False)
self.transport._send_message(m)
return
@@ -423,18 +457,20 @@ Error Message: {}
method = m.get_text()
self._log(
DEBUG,
- 'Auth request (type={}) service={}, username={}'.format(
+ "Auth request (type={}) service={}, username={}".format(
method, service, username
- )
+ ),
)
- if service != 'ssh-connection':
+ if service != "ssh-connection":
self._disconnect_service_not_available()
return
- if ((self.auth_username is not None) and
- (self.auth_username != username)):
+ if (
+ (self.auth_username is not None)
+ and (self.auth_username != username)
+ ):
self._log(
WARNING,
- 'Auth rejected because the client attempted to change username in mid-flight' # noqa
+ "Auth rejected because the client attempted to change username in mid-flight", # noqa
)
self._disconnect_no_more_auth()
return
@@ -442,13 +478,13 @@ Error Message: {}
# check if GSS-API authentication is enabled
gss_auth = self.transport.server_object.enable_auth_gssapi()
- if method == 'none':
+ if method == "none":
result = self.transport.server_object.check_auth_none(username)
- elif method == 'password':
+ elif method == "password":
changereq = m.get_boolean()
password = m.get_binary()
try:
- password = password.decode('UTF-8')
+ password = password.decode("UTF-8")
except UnicodeError:
# some clients/servers expect non-utf-8 passwords!
# in this case, just return the raw byte string.
@@ -457,31 +493,28 @@ Error Message: {}
# always treated as failure, since we don't support changing
# passwords, but collect the list of valid auth types from
# the callback anyway
- self._log(
- DEBUG,
- 'Auth request to change passwords (rejected)')
+ self._log(DEBUG, "Auth request to change passwords (rejected)")
newpassword = m.get_binary()
try:
- newpassword = newpassword.decode('UTF-8', 'replace')
+ newpassword = newpassword.decode("UTF-8", "replace")
except UnicodeError:
pass
result = AUTH_FAILED
else:
result = self.transport.server_object.check_auth_password(
- username, password)
- elif method == 'publickey':
+ username, password
+ )
+ elif method == "publickey":
sig_attached = m.get_boolean()
keytype = m.get_text()
keyblob = m.get_binary()
try:
key = self.transport._key_info[keytype](Message(keyblob))
except SSHException as e:
- self._log(
- INFO,
- 'Auth rejected: public key: {}'.format(str(e)))
+ self._log(INFO, "Auth rejected: public key: {}".format(str(e)))
key = None
except Exception as e:
- msg = 'Auth rejected: unsupported or mangled public key ({}: {})' # noqa
+ msg = "Auth rejected: unsupported or mangled public key ({}: {})" # noqa
self._log(INFO, msg.format(e.__class__.__name__, e))
key = None
if key is None:
@@ -489,7 +522,8 @@ Error Message: {}
return
# first check if this key is okay... if not, we can skip the verify
result = self.transport.server_object.check_auth_publickey(
- username, key)
+ username, key
+ )
if result != AUTH_FAILED:
# key is okay, verify it
if not sig_attached:
@@ -504,14 +538,13 @@ Error Message: {}
sig = Message(m.get_binary())
blob = self._get_session_blob(key, service, username)
if not key.verify_ssh_sig(blob, sig):
- self._log(
- INFO,
- 'Auth rejected: invalid signature')
+ self._log(INFO, "Auth rejected: invalid signature")
result = AUTH_FAILED
- elif method == 'keyboard-interactive':
+ elif method == "keyboard-interactive":
submethods = m.get_string()
result = self.transport.server_object.check_auth_interactive(
- username, submethods)
+ username, submethods
+ )
if isinstance(result, InteractiveQuery):
# make interactive query instead of response
self._interactive_query(result)
@@ -527,7 +560,8 @@ Error Message: {}
if mechs > 1:
self._log(
INFO,
- 'Disconnect: Received more than one GSS-API OID mechanism')
+ "Disconnect: Received more than one GSS-API OID mechanism",
+ )
self._disconnect_no_more_auth()
desired_mech = m.get_string()
mech_ok = sshgss.ssh_check_mech(desired_mech)
@@ -535,7 +569,8 @@ Error Message: {}
if not mech_ok:
self._log(
INFO,
- 'Disconnect: Received an invalid GSS-API OID mechanism')
+ "Disconnect: Received an invalid GSS-API OID mechanism",
+ )
self._disconnect_no_more_auth()
# send the Kerberos V5 GSSAPI OID to the client
supported_mech = sshgss.ssh_gss_oids("server")
@@ -544,11 +579,14 @@ Error Message: {}
m = Message()
m.add_byte(cMSG_USERAUTH_GSSAPI_RESPONSE)
m.add_bytes(supported_mech)
- self.transport.auth_handler = GssapiWithMicAuthHandler(self,
- sshgss)
- self.transport._expected_packet = (MSG_USERAUTH_GSSAPI_TOKEN,
- MSG_USERAUTH_REQUEST,
- MSG_SERVICE_REQUEST)
+ self.transport.auth_handler = GssapiWithMicAuthHandler(
+ self, sshgss
+ )
+ self.transport._expected_packet = (
+ MSG_USERAUTH_GSSAPI_TOKEN,
+ MSG_USERAUTH_REQUEST,
+ MSG_SERVICE_REQUEST,
+ )
self.transport._send_message(m)
return
elif method == "gssapi-keyex" and gss_auth:
@@ -559,16 +597,17 @@ Error Message: {}
result = AUTH_FAILED
self._send_auth_result(username, method, result)
try:
- sshgss.ssh_check_mic(mic_token,
- self.transport.session_id,
- self.auth_username)
+ sshgss.ssh_check_mic(
+ mic_token, self.transport.session_id, self.auth_username
+ )
except Exception:
result = AUTH_FAILED
self._send_auth_result(username, method, result)
raise
result = AUTH_SUCCESSFUL
self.transport.server_object.check_auth_gssapi_keyex(
- username, result)
+ username, result
+ )
else:
result = self.transport.server_object.check_auth_none(username)
# okay, send result
@@ -576,8 +615,8 @@ Error Message: {}
def _parse_userauth_success(self, m):
self._log(
- INFO,
- 'Authentication ({}) successful!'.format(self.auth_method))
+ INFO, "Authentication ({}) successful!".format(self.auth_method)
+ )
self.authenticated = True
self.transport._auth_trigger()
if self.auth_event is not None:
@@ -587,24 +626,23 @@ Error Message: {}
authlist = m.get_list()
partial = m.get_boolean()
if partial:
- self._log(INFO, 'Authentication continues...')
- self._log(DEBUG, 'Methods: ' + str(authlist))
+ self._log(INFO, "Authentication continues...")
+ self._log(DEBUG, "Methods: " + str(authlist))
self.transport.saved_exception = PartialAuthentication(authlist)
elif self.auth_method not in authlist:
for msg in (
- 'Authentication type ({}) not permitted.'.format(
+ "Authentication type ({}) not permitted.".format(
self.auth_method
),
- 'Allowed methods: {}'.format(authlist),
+ "Allowed methods: {}".format(authlist),
):
self._log(DEBUG, msg)
self.transport.saved_exception = BadAuthenticationType(
- 'Bad authentication type', authlist
+ "Bad authentication type", authlist
)
else:
self._log(
- INFO,
- 'Authentication ({}) failed.'.format(self.auth_method)
+ INFO, "Authentication ({}) failed.".format(self.auth_method)
)
self.authenticated = False
self.username = None
@@ -614,12 +652,12 @@ Error Message: {}
def _parse_userauth_banner(self, m):
banner = m.get_string()
self.banner = banner
- self._log(INFO, 'Auth banner: {}'.format(banner))
+ self._log(INFO, "Auth banner: {}".format(banner))
# who cares.
def _parse_userauth_info_request(self, m):
- if self.auth_method != 'keyboard-interactive':
- raise SSHException('Illegal info request from server')
+ if self.auth_method != "keyboard-interactive":
+ raise SSHException("Illegal info request from server")
title = m.get_text()
instructions = m.get_text()
m.get_binary() # lang
@@ -628,7 +666,8 @@ Error Message: {}
for i in range(prompts):
prompt_list.append((m.get_text(), m.get_boolean()))
response_list = self.interactive_handler(
- title, instructions, prompt_list)
+ title, instructions, prompt_list
+ )
m = Message()
m.add_byte(cMSG_USERAUTH_INFO_RESPONSE)
@@ -639,25 +678,26 @@ Error Message: {}
def _parse_userauth_info_response(self, m):
if not self.transport.server_mode:
- raise SSHException('Illegal info response from server')
+ raise SSHException("Illegal info response from server")
n = m.get_int()
responses = []
for i in range(n):
responses.append(m.get_text())
result = self.transport.server_object.check_auth_interactive_response(
- responses)
+ responses
+ )
if isinstance(result, InteractiveQuery):
# make interactive query instead of response
self._interactive_query(result)
return
self._send_auth_result(
- self.auth_username, 'keyboard-interactive', result)
+ self.auth_username, "keyboard-interactive", result
+ )
def _handle_local_gss_failure(self, e):
self.transport.saved_exception = e
self._log(DEBUG, "GSSAPI failure: {}".format(e))
- self._log(INFO, 'Authentication ({}) failed.'.format(
- self.auth_method))
+ self._log(INFO, "Authentication ({}) failed.".format(self.auth_method))
self.authenticated = False
self.username = None
if self.auth_event is not None:
@@ -718,9 +758,9 @@ class GssapiWithMicAuthHandler(object):
# context.
sshgss = self.sshgss
try:
- token = sshgss.ssh_accept_sec_context(self.gss_host,
- client_token,
- self.auth_username)
+ token = sshgss.ssh_accept_sec_context(
+ self.gss_host, client_token, self.auth_username
+ )
except Exception as e:
self.transport.saved_exception = e
result = AUTH_FAILED
@@ -731,9 +771,11 @@ class GssapiWithMicAuthHandler(object):
m = Message()
m.add_byte(cMSG_USERAUTH_GSSAPI_TOKEN)
m.add_string(token)
- self.transport._expected_packet = (MSG_USERAUTH_GSSAPI_TOKEN,
- MSG_USERAUTH_GSSAPI_MIC,
- MSG_USERAUTH_REQUEST)
+ self.transport._expected_packet = (
+ MSG_USERAUTH_GSSAPI_TOKEN,
+ MSG_USERAUTH_GSSAPI_MIC,
+ MSG_USERAUTH_REQUEST,
+ )
self.transport._send_message(m)
def _parse_userauth_gssapi_mic(self, m):
@@ -742,9 +784,9 @@ class GssapiWithMicAuthHandler(object):
username = self.auth_username
self._restore_delegate_auth_handler()
try:
- sshgss.ssh_check_mic(mic_token,
- self.transport.session_id,
- username)
+ sshgss.ssh_check_mic(
+ mic_token, self.transport.session_id, username
+ )
except Exception as e:
self.transport.saved_exception = e
result = AUTH_FAILED
@@ -754,8 +796,9 @@ class GssapiWithMicAuthHandler(object):
# The OpenSSH server is able to create a TGT with the delegated
# client credentials, but this is not supported by GSS-API.
result = AUTH_SUCCESSFUL
- self.transport.server_object.check_auth_gssapi_with_mic(username,
- result)
+ self.transport.server_object.check_auth_gssapi_with_mic(
+ username, result
+ )
# okay, send result
self._send_auth_result(username, self.method, result)
diff --git a/paramiko/ber.py b/paramiko/ber.py
index 876347e..fb6ee71 100644
--- a/paramiko/ber.py
+++ b/paramiko/ber.py
@@ -21,7 +21,7 @@ from paramiko.py3compat import b, byte_ord, byte_chr, long
import paramiko.util as util
-class BERException (Exception):
+class BERException(Exception):
pass
@@ -41,7 +41,7 @@ class BER(object):
return self.asbytes()
def __repr__(self):
- return 'BER(\'' + repr(self.content) + '\')'
+ return "BER('" + repr(self.content) + "')"
def decode(self):
return self.decode_next()
@@ -71,13 +71,12 @@ class BER(object):
t = size & 0x7f
if self.idx + t > len(self.content):
return None
- size = util.inflate_long(
- self.content[self.idx: self.idx + t], True)
+ size = util.inflate_long(self.content[self.idx:self.idx + t], True)
self.idx += t
if self.idx + size > len(self.content):
# can't fit
return None
- data = self.content[self.idx: self.idx + size]
+ data = self.content[self.idx:self.idx + size]
self.idx += size
# now switch on id
if ident == 0x30:
@@ -88,7 +87,7 @@ class BER(object):
return util.inflate_long(data)
else:
# 1: boolean (00 false, otherwise true)
- msg = 'Unknown ber encoding type {:d} (robey is lazy)'
+ msg = "Unknown ber encoding type {:d} (robey is lazy)"
raise BERException(msg.format(ident))
@staticmethod
@@ -126,7 +125,7 @@ class BER(object):
self.encode_tlv(0x30, self.encode_sequence(x))
else:
raise BERException(
- 'Unknown type for encoding: {!r}'.format(type(x))
+ "Unknown type for encoding: {!r}".format(type(x))
)
@staticmethod
diff --git a/paramiko/buffered_pipe.py b/paramiko/buffered_pipe.py
index d9f5149..48f5aae 100644
--- a/paramiko/buffered_pipe.py
+++ b/paramiko/buffered_pipe.py
@@ -28,14 +28,14 @@ import time
from paramiko.py3compat import PY2, b
-class PipeTimeout (IOError):
+class PipeTimeout(IOError):
"""
Indicates that a timeout was reached on a read from a `.BufferedPipe`.
"""
pass
-class BufferedPipe (object):
+class BufferedPipe(object):
"""
A buffer that obeys normal read (with timeout) & close semantics for a
file or socket, but is fed data from another thread. This is used by
@@ -46,16 +46,19 @@ class BufferedPipe (object):
self._lock = threading.Lock()
self._cv = threading.Condition(self._lock)
self._event = None
- self._buffer = array.array('B')
+ self._buffer = array.array("B")
self._closed = False
if PY2:
+
def _buffer_frombytes(self, data):
self._buffer.fromstring(data)
def _buffer_tobytes(self, limit=None):
return self._buffer[:limit].tostring()
+
else:
+
def _buffer_frombytes(self, data):
self._buffer.frombytes(data)
diff --git a/paramiko/channel.py b/paramiko/channel.py
index 91a8f0d..00f86d6 100644
--- a/paramiko/channel.py
+++ b/paramiko/channel.py
@@ -25,14 +25,22 @@ import os
import socket
import time
import threading
+
# TODO: switch as much of py3compat.py to 'six' as possible, then use six.wraps
from functools import wraps
from paramiko import util
from paramiko.common import (
- cMSG_CHANNEL_REQUEST, cMSG_CHANNEL_WINDOW_ADJUST, cMSG_CHANNEL_DATA,
- cMSG_CHANNEL_EXTENDED_DATA, DEBUG, ERROR, cMSG_CHANNEL_SUCCESS,
- cMSG_CHANNEL_FAILURE, cMSG_CHANNEL_EOF, cMSG_CHANNEL_CLOSE,
+ cMSG_CHANNEL_REQUEST,
+ cMSG_CHANNEL_WINDOW_ADJUST,
+ cMSG_CHANNEL_DATA,
+ cMSG_CHANNEL_EXTENDED_DATA,
+ DEBUG,
+ ERROR,
+ cMSG_CHANNEL_SUCCESS,
+ cMSG_CHANNEL_FAILURE,
+ cMSG_CHANNEL_EOF,
+ cMSG_CHANNEL_CLOSE,
)
from paramiko.message import Message
from paramiko.py3compat import bytes_types
@@ -51,20 +59,22 @@ def open_only(func):
`.SSHException` -- If the wrapped method is called on an unopened
`.Channel`.
"""
+
@wraps(func)
def _check(self, *args, **kwds):
if (
- self.closed or
- self.eof_received or
- self.eof_sent or
- not self.active
+ self.closed
+ or self.eof_received
+ or self.eof_sent
+ or not self.active
):
- raise SSHException('Channel is not open')
+ raise SSHException("Channel is not open")
return func(self, *args, **kwds)
+
return _check
-class Channel (ClosingContextManager):
+class Channel(ClosingContextManager):
"""
A secure tunnel across an SSH `.Transport`. A Channel is meant to behave
like a socket, and has an API that should be indistinguishable from the
@@ -117,7 +127,7 @@ class Channel (ClosingContextManager):
self.in_window_sofar = 0
self.status_event = threading.Event()
self._name = str(chanid)
- self.logger = util.get_logger('paramiko.transport')
+ self.logger = util.get_logger("paramiko.transport")
self._pipe = None
self.event = threading.Event()
self.event_ready = False
@@ -135,24 +145,30 @@ class Channel (ClosingContextManager):
"""
Return a string representation of this object, for debugging.
"""
- out = '<paramiko.Channel {}'.format(self.chanid)
+ out = "<paramiko.Channel {}".format(self.chanid)
if self.closed:
- out += ' (closed)'
+ out += " (closed)"
elif self.active:
if self.eof_received:
- out += ' (EOF received)'
+ out += " (EOF received)"
if self.eof_sent:
- out += ' (EOF sent)'
- out += ' (open) window={}'.format(self.out_window_size)
+ out += " (EOF sent)"
+ out += " (open) window={}".format(self.out_window_size)
if len(self.in_buffer) > 0:
- out += ' in-buffer={}'.format(len(self.in_buffer))
- out += ' -> ' + repr(self.transport)
- out += '>'
+ out += " in-buffer={}".format(len(self.in_buffer))
+ out += " -> " + repr(self.transport)
+ out += ">"
return out
@open_only
- def get_pty(self, term='vt100', width=80, height=24, width_pixels=0,
- height_pixels=0):
+ def get_pty(
+ self,
+ term="vt100",
+ width=80,
+ height=24,
+ width_pixels=0,
+ height_pixels=0,
+ ):
"""
Request a pseudo-terminal from the server. This is usually used right
after creating a client channel, to ask the server to provide some
@@ -174,7 +190,7 @@ class Channel (ClosingContextManager):
m = Message()
m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
- m.add_string('pty-req')
+ m.add_string("pty-req")
m.add_boolean(True)
m.add_string(term)
m.add_int(width)
@@ -207,7 +223,7 @@ class Channel (ClosingContextManager):
m = Message()
m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
- m.add_string('shell')
+ m.add_string("shell")
m.add_boolean(True)
self._event_pending()
self.transport._send_user_message(m)
@@ -233,7 +249,7 @@ class Channel (ClosingContextManager):
m = Message()
m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
- m.add_string('exec')
+ m.add_string("exec")
m.add_boolean(True)
m.add_string(command)
self._event_pending()
@@ -259,7 +275,7 @@ class Channel (ClosingContextManager):
m = Message()
m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
- m.add_string('subsystem')
+ m.add_string("subsystem")
m.add_boolean(True)
m.add_string(subsystem)
self._event_pending()
@@ -284,7 +300,7 @@ class Channel (ClosingContextManager):
m = Message()
m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
- m.add_string('window-change')
+ m.add_string("window-change")
m.add_boolean(False)
m.add_int(width)
m.add_int(height)
@@ -315,7 +331,7 @@ class Channel (ClosingContextManager):
try:
self.set_environment_variable(name, value)
except SSHException as e:
- err = "Failed to set environment variable \"{}\"."
+ err = 'Failed to set environment variable "{}".'
raise SSHException(err.format(name), e)
@open_only
@@ -339,7 +355,7 @@ class Channel (ClosingContextManager):
m = Message()
m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
- m.add_string('env')
+ m.add_string("env")
m.add_boolean(False)
m.add_string(name)
m.add_string(value)
@@ -403,19 +419,19 @@ class Channel (ClosingContextManager):
m = Message()
m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
- m.add_string('exit-status')
+ m.add_string("exit-status")
m.add_boolean(False)
m.add_int(status)
self.transport._send_user_message(m)
@open_only
def request_x11(
- self,
- screen_number=0,
- auth_protocol=None,
- auth_cookie=None,
- single_connection=False,
- handler=None
+ self,
+ screen_number=0,
+ auth_protocol=None,
+ auth_cookie=None,
+ single_connection=False,
+ handler=None,
):
"""
Request an x11 session on this channel. If the server allows it,
@@ -456,14 +472,14 @@ class Channel (ClosingContextManager):
:return: the auth_cookie used
"""
if auth_protocol is None:
- auth_protocol = 'MIT-MAGIC-COOKIE-1'
+ auth_protocol = "MIT-MAGIC-COOKIE-1"
if auth_cookie is None:
auth_cookie = binascii.hexlify(os.urandom(16))
m = Message()
m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
- m.add_string('x11-req')
+ m.add_string("x11-req")
m.add_boolean(True)
m.add_boolean(single_connection)
m.add_string(auth_protocol)
@@ -493,7 +509,7 @@ class Channel (ClosingContextManager):
m = Message()
m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
- m.add_string('auth-agent-req@openssh.com')
+ m.add_string("auth-agent-req@openssh.com")
m.add_boolean(False)
self.transport._send_user_message(m)
self.transport._set_forward_agent_handler(handler)
@@ -808,7 +824,6 @@ class Channel (ClosingContextManager):
m.add_int(1)
return self._send(s, m)
-
def sendall(self, s):
"""
Send data to the channel, without allowing partial results. Unlike
@@ -976,7 +991,7 @@ class Channel (ClosingContextManager):
# a window update
self.in_window_threshold = window_size // 10
self.in_window_sofar = 0
- self._log(DEBUG, 'Max packet in: {} bytes'.format(max_packet_size))
+ self._log(DEBUG, "Max packet in: {} bytes".format(max_packet_size))
def _set_remote_channel(self, chanid, window_size, max_packet_size):
self.remote_chanid = chanid
@@ -985,11 +1000,12 @@ class Channel (ClosingContextManager):
max_packet_size
)
self.active = 1
- self._log(DEBUG,
- 'Max packet out: {} bytes'.format(self.out_max_packet_size))
+ self._log(
+ DEBUG, "Max packet out: {} bytes".format(self.out_max_packet_size)
+ )
def _request_success(self, m):
- self._log(DEBUG, 'Sesch channel {} request ok'.format(self.chanid))
+ self._log(DEBUG, "Sesch channel {} request ok".format(self.chanid))
self.event_ready = True
self.event.set()
return
@@ -1017,8 +1033,7 @@ class Channel (ClosingContextManager):
s = m.get_binary()
if code != 1:
self._log(
- ERROR,
- 'unknown extended_data type {}; discarding'.format(code)
+ ERROR, "unknown extended_data type {}; discarding".format(code)
)
return
if self.combine_stderr:
@@ -1031,7 +1046,7 @@ class Channel (ClosingContextManager):
self.lock.acquire()
try:
if self.ultra_debug:
- self._log(DEBUG, 'window up {}'.format(nbytes))
+ self._log(DEBUG, "window up {}".format(nbytes))
self.out_window_size += nbytes
self.out_buffer_cv.notifyAll()
finally:
@@ -1042,14 +1057,14 @@ class Channel (ClosingContextManager):
want_reply = m.get_boolean()
server = self.transport.server_object
ok = False
- if key == 'exit-status':
+ if key == "exit-status":
self.exit_status = m.get_int()
self.status_event.set()
ok = True
- elif key == 'xon-xoff':
+ elif key == "xon-xoff":
# ignore
ok = True
- elif key == 'pty-req':
+ elif key == "pty-req":
term = m.get_string()
width = m.get_int()
height = m.get_int()
@@ -1060,39 +1075,33 @@ class Channel (ClosingContextManager):
ok = False
else:
ok = server.check_channel_pty_request(
- self,
- term,
- width,
- height,
- pixelwidth,
- pixelheight,
- modes
+ self, term, width, height, pixelwidth, pixelheight, modes
)
- elif key == 'shell':
+ elif key == "shell":
if server is None:
ok = False
else:
ok = server.check_channel_shell_request(self)
- elif key == 'env':
+ elif key == "env":
name = m.get_string()
value = m.get_string()
if server is None:
ok = False
else:
ok = server.check_channel_env_request(self, name, value)
- elif key == 'exec':
+ elif key == "exec":
cmd = m.get_string()
if server is None:
ok = False
else:
ok = server.check_channel_exec_request(self, cmd)
- elif key == 'subsystem':
+ elif key == "subsystem":
name = m.get_text()
if server is None:
ok = False
else:
ok = server.check_channel_subsystem_request(self, name)
- elif key == 'window-change':
+ elif key == "window-change":
width = m.get_int()
height = m.get_int()
pixelwidth = m.get_int()
@@ -1101,8 +1110,9 @@ class Channel (ClosingContextManager):
ok = False
else:
ok = server.check_channel_window_change_request(
- self, width, height, pixelwidth, pixelheight)
- elif key == 'x11-req':
+ self, width, height, pixelwidth, pixelheight
+ )
+ elif key == "x11-req":
single_connection = m.get_boolean()
auth_proto = m.get_text()
auth_cookie = m.get_binary()
@@ -1115,9 +1125,9 @@ class Channel (ClosingContextManager):
single_connection,
auth_proto,
auth_cookie,
- screen_number
+ screen_number,
)
- elif key == 'auth-agent-req@openssh.com':
+ elif key == "auth-agent-req@openssh.com":
if server is None:
ok = False
else:
@@ -1145,7 +1155,7 @@ class Channel (ClosingContextManager):
self._pipe.set_forever()
finally:
self.lock.release()
- self._log(DEBUG, 'EOF received ({})'.format(self._name))
+ self._log(DEBUG, "EOF received ({})".format(self._name))
def _handle_close(self, m):
self.lock.acquire()
@@ -1167,7 +1177,7 @@ class Channel (ClosingContextManager):
if self.closed:
# this doesn't seem useful, but it is the documented behavior
# of Socket
- raise socket.error('Socket is closed')
+ raise socket.error("Socket is closed")
size = self._wait_for_send_window(size)
if size == 0:
# eof or similar
@@ -1194,7 +1204,7 @@ class Channel (ClosingContextManager):
return
e = self.transport.get_exception()
if e is None:
- e = SSHException('Channel closed.')
+ e = SSHException("Channel closed.")
raise e
def _set_closed(self):
@@ -1217,7 +1227,7 @@ class Channel (ClosingContextManager):
m.add_byte(cMSG_CHANNEL_EOF)
m.add_int(self.remote_chanid)
self.eof_sent = True
- self._log(DEBUG, 'EOF sent ({})'.format(self._name))
+ self._log(DEBUG, "EOF sent ({})".format(self._name))
return m
def _close_internal(self):
@@ -1251,13 +1261,14 @@ class Channel (ClosingContextManager):
if self.closed or self.eof_received or not self.active:
return 0
if self.ultra_debug:
- self._log(DEBUG, 'addwindow {}'.format(n))
+ self._log(DEBUG, "addwindow {}".format(n))
self.in_window_sofar += n
if self.in_window_sofar <= self.in_window_threshold:
return 0
if self.ultra_debug:
- self._log(DEBUG,
- 'addwindow send {}'.format(self.in_window_sofar))
+ self._log(
+ DEBUG, "addwindow send {}".format(self.in_window_sofar)
+ )
out = self.in_window_sofar
self.in_window_sofar = 0
return out
@@ -1300,11 +1311,11 @@ class Channel (ClosingContextManager):
size = self.out_max_packet_size - 64
self.out_window_size -= size
if self.ultra_debug:
- self._log(DEBUG, 'window down to {}'.format(self.out_window_size))
+ self._log(DEBUG, "window down to {}".format(self.out_window_size))
return size
-class ChannelFile (BufferedFile):
+class ChannelFile(BufferedFile):
"""
A file-like wrapper around `.Channel`. A ChannelFile is created by calling
`Channel.makefile`.
@@ -1317,7 +1328,7 @@ class ChannelFile (BufferedFile):
flush the buffer.
"""
- def __init__(self, channel, mode='r', bufsize=-1):
+ def __init__(self, channel, mode="r", bufsize=-1):
self.channel = channel
BufferedFile.__init__(self)
self._set_mode(mode, bufsize)
@@ -1326,7 +1337,7 @@ class ChannelFile (BufferedFile):
"""
Returns a string representation of this object, for debugging.
"""
- return '<paramiko.ChannelFile from ' + repr(self.channel) + '>'
+ return "<paramiko.ChannelFile from " + repr(self.channel) + ">"
def _read(self, size):
return self.channel.recv(size)
@@ -1336,8 +1347,9 @@ class ChannelFile (BufferedFile):
return len(data)
-class ChannelStderrFile (ChannelFile):
- def __init__(self, channel, mode='r', bufsize=-1):
+class ChannelStderrFile(ChannelFile):
+
+ def __init__(self, channel, mode="r", bufsize=-1):
ChannelFile.__init__(self, channel, mode, bufsize)
def _read(self, size):
diff --git a/paramiko/client.py b/paramiko/client.py
index de0a495..2fb6e62 100644
--- a/paramiko/client.py
+++ b/paramiko/client.py
@@ -38,13 +38,15 @@ from paramiko.hostkeys import HostKeys
from paramiko.py3compat import string_types
from paramiko.rsakey import RSAKey
from paramiko.ssh_exception import (
- SSHException, BadHostKeyException, NoValidConnectionsError
+ SSHException,
+ BadHostKeyException,
+ NoValidConnectionsError,
)
from paramiko.transport import Transport
from paramiko.util import retry_on_signal, ClosingContextManager
-class SSHClient (ClosingContextManager):
+class SSHClient(ClosingContextManager):
"""
A high-level representation of a session with an SSH server. This class
wraps `.Transport`, `.Channel`, and `.SFTPClient` to take care of most
@@ -97,7 +99,7 @@ class SSHClient (ClosingContextManager):
"""
if filename is None:
# try the user's .ssh key file, and mask exceptions
- filename = os.path.expanduser('~/.ssh/known_hosts')
+ filename = os.path.expanduser("~/.ssh/known_hosts")
try:
self._system_host_keys.load(filename)
except IOError:
@@ -140,12 +142,14 @@ class SSHClient (ClosingContextManager):
if self._host_keys_filename is not None:
self.load_host_keys(self._host_keys_filename)
- with open(filename, 'w') as f:
+ with open(filename, "w") as f:
for hostname, keys in self._host_keys.items():
for keytype, key in keys.items():
- f.write('{} {} {}\n'.format(
- hostname, keytype, key.get_base64()
- ))
+ f.write(
+ "{} {} {}\n".format(
+ hostname, keytype, key.get_base64()
+ )
+ )
def get_host_keys(self):
"""
@@ -197,7 +201,8 @@ class SSHClient (ClosingContextManager):
"""
guess = True
addrinfos = socket.getaddrinfo(
- hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM)
+ hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM
+ )
for (family, socktype, proto, canonname, sockaddr) in addrinfos:
if socktype == socket.SOCK_STREAM:
yield family, sockaddr
@@ -419,8 +424,16 @@ class SSHClient (ClosingContextManager):
key_filenames = key_filename
self._auth(
- username, password, pkey, key_filenames, allow_agent,
- look_for_keys, gss_auth, gss_kex, gss_deleg_creds, t.gss_host,
+ username,
+ password,
+ pkey,
+ key_filenames,
+ allow_agent,
+ look_for_keys,
+ gss_auth,
+ gss_kex,
+ gss_deleg_creds,
+ t.gss_host,
passphrase,
)
@@ -490,13 +503,20 @@ class SSHClient (ClosingContextManager):
if environment:
chan.update_environment(environment)
chan.exec_command(command)
- stdin = chan.makefile('wb', bufsize)
- stdout = chan.makefile('r', bufsize)
- stderr = chan.makefile_stderr('r', bufsize)
+ stdin = chan.makefile("wb", bufsize)
+ stdout = chan.makefile("r", bufsize)
+ stderr = chan.makefile_stderr("r", bufsize)
return stdin, stdout, stderr
- def invoke_shell(self, term='vt100', width=80, height=24, width_pixels=0,
- height_pixels=0, environment=None):
+ def invoke_shell(
+ self,
+ term="vt100",
+ width=80,
+ height=24,
+ width_pixels=0,
+ height_pixels=0,
+ environment=None,
+ ):
"""
Start an interactive shell session on the SSH server. A new `.Channel`
is opened and connected to a pseudo-terminal using the requested
@@ -545,7 +565,7 @@ class SSHClient (ClosingContextManager):
- Otherwise, the filename is assumed to be a private key, and the
matching public cert will be loaded if it exists.
"""
- cert_suffix = '-cert.pub'
+ cert_suffix = "-cert.pub"
# Assume privkey, not cert, by default
if filename.endswith(cert_suffix):
key_path = filename[:-len(cert_suffix)]
@@ -559,7 +579,7 @@ class SSHClient (ClosingContextManager):
# when #387 is released, since this is a critical log message users are
# likely testing/filtering for (bah.)
msg = "Trying discovered key {} in {}".format(
- hexlify(key.get_fingerprint()), key_path,
+ hexlify(key.get_fingerprint()), key_path
)
self._log(DEBUG, msg)
# Attempt to load cert if it exists.
@@ -569,8 +589,17 @@ class SSHClient (ClosingContextManager):
return key
def _auth(
- self, username, password, pkey, key_filenames, allow_agent,
- look_for_keys, gss_auth, gss_kex, gss_deleg_creds, gss_host,
+ self,
+ username,
+ password,
+ pkey,
+ key_filenames,
+ allow_agent,
+ look_for_keys,
+ gss_auth,
+ gss_kex,
+ gss_deleg_creds,
+ gss_host,
passphrase,
):
"""
@@ -589,7 +618,7 @@ class SSHClient (ClosingContextManager):
saved_exception = None
two_factor = False
allowed_types = set()
- two_factor_types = {'keyboard-interactive', 'password'}
+ two_factor_types = {"keyboard-interactive", "password"}
if passphrase is None and password is not None:
passphrase = password
@@ -609,7 +638,7 @@ class SSHClient (ClosingContextManager):
if gss_auth:
try:
return self._transport.auth_gssapi_with_mic(
- username, gss_host, gss_deleg_creds,
+ username, gss_host, gss_deleg_creds
)
except Exception as e:
saved_exception = e
@@ -618,10 +647,13 @@ class SSHClient (ClosingContextManager):
try:
self._log(
DEBUG,
- 'Trying SSH key {}'.format(hexlify(pkey.get_fingerprint()))
+ "Trying SSH key {}".format(
+ hexlify(pkey.get_fingerprint())
+ ),
)
allowed_types = set(
- self._transport.auth_publickey(username, pkey))
+ self._transport.auth_publickey(username, pkey)
+ )
two_factor = (allowed_types & two_factor_types)
if not two_factor:
return
@@ -633,10 +665,11 @@ class SSHClient (ClosingContextManager):
for pkey_class in (RSAKey, DSSKey, ECDSAKey, Ed25519Key):
try:
key = self._key_from_filepath(
- key_filename, pkey_class, passphrase,
+ key_filename, pkey_class, passphrase
)
allowed_types = set(
- self._transport.auth_publickey(username, key))
+ self._transport.auth_publickey(username, key)
+ )
two_factor = (allowed_types & two_factor_types)
if not two_factor:
return
@@ -651,11 +684,12 @@ class SSHClient (ClosingContextManager):
for key in self._agent.get_keys():
try:
id_ = hexlify(key.get_fingerprint())
- self._log(DEBUG, 'Trying SSH agent key {}'.format(id_))
+ self._log(DEBUG, "Trying SSH agent key {}".format(id_))
# for 2-factor auth a successfully auth'd key password
# will return an allowed 2fac auth method
allowed_types = set(
- self._transport.auth_publickey(username, key))
+ self._transport.auth_publickey(username, key)
+ )
two_factor = (allowed_types & two_factor_types)
if not two_factor:
return
@@ -680,8 +714,8 @@ class SSHClient (ClosingContextManager):
if os.path.isfile(full_path):
# TODO: only do this append if below did not run
keyfiles.append((keytype, full_path))
- if os.path.isfile(full_path + '-cert.pub'):
- keyfiles.append((keytype, full_path + '-cert.pub'))
+ if os.path.isfile(full_path + "-cert.pub"):
+ keyfiles.append((keytype, full_path + "-cert.pub"))
if not look_for_keys:
keyfiles = []
@@ -689,12 +723,13 @@ class SSHClient (ClosingContextManager):
for pkey_class, filename in keyfiles:
try:
key = self._key_from_filepath(
- filename, pkey_class, passphrase,
+ filename, pkey_class, passphrase
)
# for 2-factor auth a successfully auth'd key will result
# in ['password']
allowed_types = set(
- self._transport.auth_publickey(username, key))
+ self._transport.auth_publickey(username, key)
+ )
two_factor = (allowed_types & two_factor_types)
if not two_factor:
return
@@ -718,13 +753,13 @@ class SSHClient (ClosingContextManager):
# if we got an auth-failed exception earlier, re-raise it
if saved_exception is not None:
raise saved_exception
- raise SSHException('No authentication methods available')
+ raise SSHException("No authentication methods available")
def _log(self, level, msg):
self._transport._log(level, msg)
-class MissingHostKeyPolicy (object):
+class MissingHostKeyPolicy(object):
"""
Interface for defining the policy that `.SSHClient` should use when the
SSH server's hostname is not in either the system host keys or the
@@ -745,7 +780,7 @@ class MissingHostKeyPolicy (object):
pass
-class AutoAddPolicy (MissingHostKeyPolicy):
+class AutoAddPolicy(MissingHostKeyPolicy):
"""
Policy for automatically adding the hostname and new host key to the
local `.HostKeys` object, and saving it. This is used by `.SSHClient`.
@@ -755,32 +790,41 @@ class AutoAddPolicy (MissingHostKeyPolicy):
client._host_keys.add(hostname, key.get_name(), key)
if client._host_keys_filename is not None:
client.save_host_keys(client._host_keys_filename)
- client._log(DEBUG, 'Adding {} host key for {}: {}'.format(
- key.get_name(), hostname, hexlify(key.get_fingerprint()),
- ))
+ client._log(
+ DEBUG,
+ "Adding {} host key for {}: {}".format(
+ key.get_name(), hostname, hexlify(key.get_fingerprint())
+ ),
+ )
-class RejectPolicy (MissingHostKeyPolicy):
+class RejectPolicy(MissingHostKeyPolicy):
"""
Policy for automatically rejecting the unknown hostname & key. This is
used by `.SSHClient`.
"""
def missing_host_key(self, client, hostname, key):
- client._log(DEBUG, 'Rejecting {} host key for {}: {}'.format(
- key.get_name(), hostname, hexlify(key.get_fingerprint()),
- ))
+ client._log(
+ DEBUG,
+ "Rejecting {} host key for {}: {}".format(
+ key.get_name(), hostname, hexlify(key.get_fingerprint())
+ ),
+ )
raise SSHException(
- 'Server {!r} not found in known_hosts'.format(hostname)
+ "Server {!r} not found in known_hosts".format(hostname)
)
-class WarningPolicy (MissingHostKeyPolicy):
+class WarningPolicy(MissingHostKeyPolicy):
"""
Policy for logging a Python-style warning for an unknown host key, but
accepting it. This is used by `.SSHClient`.
"""
+
def missing_host_key(self, client, hostname, key):
- warnings.warn('Unknown {} host key for {}: {}'.format(
- key.get_name(), hostname, hexlify(key.get_fingerprint()),
- ))
+ warnings.warn(
+ "Unknown {} host key for {}: {}".format(
+ key.get_name(), hostname, hexlify(key.get_fingerprint())
+ )
+ )
diff --git a/paramiko/common.py b/paramiko/common.py
index eab6647..843c0ce 100644
--- a/paramiko/common.py
+++ b/paramiko/common.py
@@ -22,22 +22,24 @@ Common constants and global variables.
import logging
from paramiko.py3compat import byte_chr, PY2, long, b
-MSG_DISCONNECT, MSG_IGNORE, MSG_UNIMPLEMENTED, MSG_DEBUG, \
- MSG_SERVICE_REQUEST, MSG_SERVICE_ACCEPT = range(1, 7)
+MSG_DISCONNECT, MSG_IGNORE, MSG_UNIMPLEMENTED, MSG_DEBUG, MSG_SERVICE_REQUEST, MSG_SERVICE_ACCEPT = range(
+ 1, 7
+)
MSG_KEXINIT, MSG_NEWKEYS = range(20, 22)
-MSG_USERAUTH_REQUEST, MSG_USERAUTH_FAILURE, MSG_USERAUTH_SUCCESS, \
- MSG_USERAUTH_BANNER = range(50, 54)
+MSG_USERAUTH_REQUEST, MSG_USERAUTH_FAILURE, MSG_USERAUTH_SUCCESS, MSG_USERAUTH_BANNER = range(
+ 50, 54
+)
MSG_USERAUTH_PK_OK = 60
MSG_USERAUTH_INFO_REQUEST, MSG_USERAUTH_INFO_RESPONSE = range(60, 62)
MSG_USERAUTH_GSSAPI_RESPONSE, MSG_USERAUTH_GSSAPI_TOKEN = range(60, 62)
-MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE, MSG_USERAUTH_GSSAPI_ERROR,\
- MSG_USERAUTH_GSSAPI_ERRTOK, MSG_USERAUTH_GSSAPI_MIC = range(63, 67)
+MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE, MSG_USERAUTH_GSSAPI_ERROR, MSG_USERAUTH_GSSAPI_ERRTOK, MSG_USERAUTH_GSSAPI_MIC = range(
+ 63, 67
+)
HIGHEST_USERAUTH_MESSAGE_ID = 79
MSG_GLOBAL_REQUEST, MSG_REQUEST_SUCCESS, MSG_REQUEST_FAILURE = range(80, 83)
-MSG_CHANNEL_OPEN, MSG_CHANNEL_OPEN_SUCCESS, MSG_CHANNEL_OPEN_FAILURE, \
- MSG_CHANNEL_WINDOW_ADJUST, MSG_CHANNEL_DATA, MSG_CHANNEL_EXTENDED_DATA, \
- MSG_CHANNEL_EOF, MSG_CHANNEL_CLOSE, MSG_CHANNEL_REQUEST, \
- MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE = range(90, 101)
+MSG_CHANNEL_OPEN, MSG_CHANNEL_OPEN_SUCCESS, MSG_CHANNEL_OPEN_FAILURE, MSG_CHANNEL_WINDOW_ADJUST, MSG_CHANNEL_DATA, MSG_CHANNEL_EXTENDED_DATA, MSG_CHANNEL_EOF, MSG_CHANNEL_CLOSE, MSG_CHANNEL_REQUEST, MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE = range(
+ 90, 101
+)
cMSG_DISCONNECT = byte_chr(MSG_DISCONNECT)
cMSG_IGNORE = byte_chr(MSG_IGNORE)
@@ -56,8 +58,9 @@ cMSG_USERAUTH_INFO_REQUEST = byte_chr(MSG_USERAUTH_INFO_REQUEST)
cMSG_USERAUTH_INFO_RESPONSE = byte_chr(MSG_USERAUTH_INFO_RESPONSE)
cMSG_USERAUTH_GSSAPI_RESPONSE = byte_chr(MSG_USERAUTH_GSSAPI_RESPONSE)
cMSG_USERAUTH_GSSAPI_TOKEN = byte_chr(MSG_USERAUTH_GSSAPI_TOKEN)
-cMSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE = \
- byte_chr(MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE)
+cMSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE = byte_chr(
+ MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE
+)
cMSG_USERAUTH_GSSAPI_ERROR = byte_chr(MSG_USERAUTH_GSSAPI_ERROR)
cMSG_USERAUTH_GSSAPI_ERRTOK = byte_chr(MSG_USERAUTH_GSSAPI_ERRTOK)
cMSG_USERAUTH_GSSAPI_MIC = byte_chr(MSG_USERAUTH_GSSAPI_MIC)
@@ -78,47 +81,47 @@ cMSG_CHANNEL_FAILURE = byte_chr(MSG_CHANNEL_FAILURE)
# for debugging:
MSG_NAMES = {
- MSG_DISCONNECT: 'disconnect',
- MSG_IGNORE: 'ignore',
- MSG_UNIMPLEMENTED: 'unimplemented',
- MSG_DEBUG: 'debug',
- MSG_SERVICE_REQUEST: 'service-request',
- MSG_SERVICE_ACCEPT: 'service-accept',
- MSG_KEXINIT: 'kexinit',
- MSG_NEWKEYS: 'newkeys',
- 30: 'kex30',
- 31: 'kex31',
- 32: 'kex32',
- 33: 'kex33',
- 34: 'kex34',
- 40: 'kex40',
- 41: 'kex41',
- MSG_USERAUTH_REQUEST: 'userauth-request',
- MSG_USERAUTH_FAILURE: 'userauth-failure',
- MSG_USERAUTH_SUCCESS: 'userauth-success',
- MSG_USERAUTH_BANNER: 'userauth--banner',
- MSG_USERAUTH_PK_OK: 'userauth-60(pk-ok/info-request)',
- MSG_USERAUTH_INFO_RESPONSE: 'userauth-info-response',
- MSG_GLOBAL_REQUEST: 'global-request',
- MSG_REQUEST_SUCCESS: 'request-success',
- MSG_REQUEST_FAILURE: 'request-failure',
- MSG_CHANNEL_OPEN: 'channel-open',
- MSG_CHANNEL_OPEN_SUCCESS: 'channel-open-success',
- MSG_CHANNEL_OPEN_FAILURE: 'channel-open-failure',
- MSG_CHANNEL_WINDOW_ADJUST: 'channel-window-adjust',
- MSG_CHANNEL_DATA: 'channel-data',
- MSG_CHANNEL_EXTENDED_DATA: 'channel-extended-data',
- MSG_CHANNEL_EOF: 'channel-eof',
- MSG_CHANNEL_CLOSE: 'channel-close',
- MSG_CHANNEL_REQUEST: 'channel-request',
- MSG_CHANNEL_SUCCESS: 'channel-success',
- MSG_CHANNEL_FAILURE: 'channel-failure',
- MSG_USERAUTH_GSSAPI_RESPONSE: 'userauth-gssapi-response',
- MSG_USERAUTH_GSSAPI_TOKEN: 'userauth-gssapi-token',
- MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE: 'userauth-gssapi-exchange-complete',
- MSG_USERAUTH_GSSAPI_ERROR: 'userauth-gssapi-error',
- MSG_USERAUTH_GSSAPI_ERRTOK: 'userauth-gssapi-error-token',
- MSG_USERAUTH_GSSAPI_MIC: 'userauth-gssapi-mic'
+ MSG_DISCONNECT: "disconnect",
+ MSG_IGNORE: "ignore",
+ MSG_UNIMPLEMENTED: "unimplemented",
+ MSG_DEBUG: "debug",
+ MSG_SERVICE_REQUEST: "service-request",
+ MSG_SERVICE_ACCEPT: "service-accept",
+ MSG_KEXINIT: "kexinit",
+ MSG_NEWKEYS: "newkeys",
+ 30: "kex30",
+ 31: "kex31",
+ 32: "kex32",
+ 33: "kex33",
+ 34: "kex34",
+ 40: "kex40",
+ 41: "kex41",
+ MSG_USERAUTH_REQUEST: "userauth-request",
+ MSG_USERAUTH_FAILURE: "userauth-failure",
+ MSG_USERAUTH_SUCCESS: "userauth-success",
+ MSG_USERAUTH_BANNER: "userauth--banner",
+ MSG_USERAUTH_PK_OK: "userauth-60(pk-ok/info-request)",
+ MSG_USERAUTH_INFO_RESPONSE: "userauth-info-response",
+ MSG_GLOBAL_REQUEST: "global-request",
+ MSG_REQUEST_SUCCESS: "request-success",
+ MSG_REQUEST_FAILURE: "request-failure",
+ MSG_CHANNEL_OPEN: "channel-open",
+ MSG_CHANNEL_OPEN_SUCCESS: "channel-open-success",
+ MSG_CHANNEL_OPEN_FAILURE: "channel-open-failure",
+ MSG_CHANNEL_WINDOW_ADJUST: "channel-window-adjust",
+ MSG_CHANNEL_DATA: "channel-data",
+ MSG_CHANNEL_EXTENDED_DATA: "channel-extended-data",
+ MSG_CHANNEL_EOF: "channel-eof",
+ MSG_CHANNEL_CLOSE: "channel-close",
+ MSG_CHANNEL_REQUEST: "channel-request",
+ MSG_CHANNEL_SUCCESS: "channel-success",
+ MSG_CHANNEL_FAILURE: "channel-failure",
+ MSG_USERAUTH_GSSAPI_RESPONSE: "userauth-gssapi-response",
+ MSG_USERAUTH_GSSAPI_TOKEN: "userauth-gssapi-token",
+ MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE: "userauth-gssapi-exchange-complete",
+ MSG_USERAUTH_GSSAPI_ERROR: "userauth-gssapi-error",
+ MSG_USERAUTH_GSSAPI_ERRTOK: "userauth-gssapi-error-token",
+ MSG_USERAUTH_GSSAPI_MIC: "userauth-gssapi-mic",
}
@@ -127,23 +130,26 @@ AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED = range(3)
# channel request failed reasons:
-(OPEN_SUCCEEDED,
- OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED,
- OPEN_FAILED_CONNECT_FAILED,
- OPEN_FAILED_UNKNOWN_CHANNEL_TYPE,
- OPEN_FAILED_RESOURCE_SHORTAGE) = range(0, 5)
+(
+ OPEN_SUCCEEDED,
+ OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED,
+ OPEN_FAILED_CONNECT_FAILED,
+ OPEN_FAILED_UNKNOWN_CHANNEL_TYPE,
+ OPEN_FAILED_RESOURCE_SHORTAGE,
+) = range(
+ 0, 5
+)
CONNECTION_FAILED_CODE = {
- 1: 'Administratively prohibited',
- 2: 'Connect failed',
- 3: 'Unknown channel type',
- 4: 'Resource shortage'
+ 1: "Administratively prohibited",
+ 2: "Connect failed",
+ 3: "Unknown channel type",
+ 4: "Resource shortage",
}
-DISCONNECT_SERVICE_NOT_AVAILABLE, DISCONNECT_AUTH_CANCELLED_BY_USER, \
- DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE = 7, 13, 14
+DISCONNECT_SERVICE_NOT_AVAILABLE, DISCONNECT_AUTH_CANCELLED_BY_USER, DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE = 7, 13, 14
zero_byte = byte_chr(0)
one_byte = byte_chr(1)
diff --git a/paramiko/compress.py b/paramiko/compress.py
index 5073109..cea2b30 100644
--- a/paramiko/compress.py
+++ b/paramiko/compress.py
@@ -23,7 +23,8 @@ Compression implementations for a Transport.
import zlib
-class ZlibCompressor (object):
+class ZlibCompressor(object):
+
def __init__(self):
# Use the default level of zlib compression
self.z = zlib.compressobj()
@@ -32,7 +33,8 @@ class ZlibCompressor (object):
return self.z.compress(data) + self.z.flush(zlib.Z_FULL_FLUSH)
-class ZlibDecompressor (object):
+class ZlibDecompressor(object):
+
def __init__(self):
self.z = zlib.decompressobj()
diff --git a/paramiko/config.py b/paramiko/config.py
index 35a5017..d0f4899 100644
--- a/paramiko/config.py
+++ b/paramiko/config.py
@@ -30,7 +30,7 @@ import socket
SSH_PORT = 22
-class SSHConfig (object):
+class SSHConfig(object):
"""
Representation of config information as stored in the format used by
OpenSSH. Queries can be made via `lookup`. The format is described in
@@ -41,7 +41,7 @@ class SSHConfig (object):
.. versionadded:: 1.6
"""
- SETTINGS_REGEX = re.compile(r'(\w+)(?:\s*=\s*|\s+)(.+)')
+ SETTINGS_REGEX = re.compile(r"(\w+)(?:\s*=\s*|\s+)(.+)")
def __init__(self):
"""
@@ -55,12 +55,12 @@ class SSHConfig (object):
:param file_obj: a file-like object to read the config file from
"""
- host = {"host": ['*'], "config": {}}
+ host = {"host": ["*"], "config": {}}
for line in file_obj:
# Strip any leading or trailing whitespace from the line.
# Refer to https://github.com/paramiko/paramiko/issues/499
line = line.strip()
- if not line or line.startswith('#'):
+ if not line or line.startswith("#"):
continue
match = re.match(self.SETTINGS_REGEX, line)
@@ -69,17 +69,14 @@ class SSHConfig (object):
key = match.group(1).lower()
value = match.group(2)
- if key == 'host':
+ if key == "host":
self._config.append(host)
- host = {
- 'host': self._get_hosts(value),
- 'config': {}
- }
- elif key == 'proxycommand' and value.lower() == 'none':
+ host = {"host": self._get_hosts(value), "config": {}}
+ elif key == "proxycommand" and value.lower() == "none":
# Store 'none' as None; prior to 3.x, it will get stripped out
# at the end (for compatibility with issue #415). After 3.x, it
# will simply not get stripped, leaving a nice explicit marker.
- host['config'][key] = None
+ host["config"][key] = None
else:
if value.startswith('"') and value.endswith('"'):
value = value[1:-1]
@@ -87,13 +84,13 @@ class SSHConfig (object):
# identityfile, localforward, remoteforward keys are special
# cases, since they are allowed to be specified multiple times
# and they should be tried in order of specification.
- if key in ['identityfile', 'localforward', 'remoteforward']:
- if key in host['config']:
- host['config'][key].append(value)
+ if key in ["identityfile", "localforward", "remoteforward"]:
+ if key in host["config"]:
+ host["config"][key].append(value)
else:
- host['config'][key] = [value]
- elif key not in host['config']:
- host['config'][key] = value
+ host["config"][key] = [value]
+ elif key not in host["config"]:
+ host["config"][key] = value
self._config.append(host)
def lookup(self, hostname):
@@ -117,25 +114,26 @@ class SSHConfig (object):
:param str hostname: the hostname to lookup
"""
matches = [
- config for config in self._config
- if self._allowed(config['host'], hostname)
+ config
+ for config in self._config
+ if self._allowed(config["host"], hostname)
]
ret = SSHConfigDict()
for match in matches:
- for key, value in match['config'].items():
+ for key, value in match["config"].items():
if key not in ret:
# Create a copy of the original value,
# else it will reference the original list
# in self._config and update that value too
# when the extend() is being called.
ret[key] = value[:] if value is not None else value
- elif key == 'identityfile':
+ elif key == "identityfile":
ret[key].extend(value)
ret = self._expand_variables(ret, hostname)
# TODO: remove in 3.x re #670
- if 'proxycommand' in ret and ret['proxycommand'] is None:
- del ret['proxycommand']
+ if "proxycommand" in ret and ret["proxycommand"] is None:
+ del ret["proxycommand"]
return ret
def get_hostnames(self):
@@ -145,13 +143,13 @@ class SSHConfig (object):
"""
hosts = set()
for entry in self._config:
- hosts.update(entry['host'])
+ hosts.update(entry["host"])
return hosts
def _allowed(self, hosts, hostname):
match = False
for host in hosts:
- if host.startswith('!') and fnmatch.fnmatch(hostname, host[1:]):
+ if host.startswith("!") and fnmatch.fnmatch(hostname, host[1:]):
return False
elif fnmatch.fnmatch(hostname, host):
match = True
@@ -169,52 +167,50 @@ class SSHConfig (object):
:param str hostname: the hostname that the config belongs to
"""
- if 'hostname' in config:
- config['hostname'] = config['hostname'].replace('%h', hostname)
+ if "hostname" in config:
+ config["hostname"] = config["hostname"].replace("%h", hostname)
else:
- config['hostname'] = hostname
+ config["hostname"] = hostname
- if 'port' in config:
- port = config['port']
+ if "port" in config:
+ port = config["port"]
else:
port = SSH_PORT
- user = os.getenv('USER')
- if 'user' in config:
- remoteuser = config['user']
+ user = os.getenv("USER")
+ if "user" in config:
+ remoteuser = config["user"]
else:
remoteuser = user
- host = socket.gethostname().split('.')[0]
+ host = socket.gethostname().split(".")[0]
fqdn = LazyFqdn(config, host)
- homedir = os.path.expanduser('~')
- replacements = {'controlpath':
- [
- ('%h', config['hostname']),
- ('%l', fqdn),
- ('%L', host),
- ('%n', hostname),
- ('%p', port),
- ('%r', remoteuser),
- ('%u', user)
- ],
- 'identityfile':
- [
- ('~', homedir),
- ('%d', homedir),
- ('%h', config['hostname']),
- ('%l', fqdn),
- ('%u', user),
- ('%r', remoteuser)
- ],
- 'proxycommand':
- [
- ('~', homedir),
- ('%h', config['hostname']),
- ('%p', port),
- ('%r', remoteuser)
- ]
- }
+ homedir = os.path.expanduser("~")
+ replacements = {
+ "controlpath": [
+ ("%h", config["hostname"]),
+ ("%l", fqdn),
+ ("%L", host),
+ ("%n", hostname),
+ ("%p", port),
+ ("%r", remoteuser),
+ ("%u", user),
+ ],
+ "identityfile": [
+ ("~", homedir),
+ ("%d", homedir),
+ ("%h", config["hostname"]),
+ ("%l", fqdn),
+ ("%u", user),
+ ("%r", remoteuser),
+ ],
+ "proxycommand": [
+ ("~", homedir),
+ ("%h", config["hostname"]),
+ ("%p", port),
+ ("%r", remoteuser),
+ ],
+ }
for k in config:
if config[k] is None:
@@ -265,11 +261,11 @@ class LazyFqdn(object):
# Handle specific option
fqdn = None
- address_family = self.config.get('addressfamily', 'any').lower()
- if address_family != 'any':
+ address_family = self.config.get("addressfamily", "any").lower()
+ if address_family != "any":
try:
family = socket.AF_INET6
- if address_family == 'inet':
+ if address_family == "inet":
socket.AF_INET
results = socket.getaddrinfo(
self.host,
@@ -277,11 +273,11 @@ class LazyFqdn(object):
family,
socket.SOCK_DGRAM,
socket.IPPROTO_IP,
- socket.AI_CANONNAME
+ socket.AI_CANONNAME,
)
for res in results:
af, socktype, proto, canonname, sa = res
- if canonname and '.' in canonname:
+ if canonname and "." in canonname:
fqdn = canonname
break
# giaerror -> socket.getaddrinfo() can't resolve self.host
@@ -315,7 +311,7 @@ class SSHConfigDict(dict):
val = self[key]
if isinstance(val, bool):
return val
- return val.lower() == 'yes'
+ return val.lower() == "yes"
def as_int(self, key):
"""Express the key as a true integer, if possible. Raises an Error otherwise
diff --git a/paramiko/dsskey.py b/paramiko/dsskey.py
index ac1d4c2..fc9e7b5 100644
--- a/paramiko/dsskey.py
+++ b/paramiko/dsskey.py
@@ -25,7 +25,8 @@ from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import dsa
from cryptography.hazmat.primitives.asymmetric.utils import (
- decode_dss_signature, encode_dss_signature
+ decode_dss_signature,
+ encode_dss_signature,
)
from paramiko import util
@@ -42,8 +43,15 @@ class DSSKey(PKey):
data.
"""
- def __init__(self, msg=None, data=None, filename=None, password=None,
- vals=None, file_obj=None):
+ def __init__(
+ self,
+ msg=None,
+ data=None,
+ filename=None,
+ password=None,
+ vals=None,
+ file_obj=None,
+ ):
self.p = None
self.q = None
self.g = None
@@ -63,8 +71,8 @@ class DSSKey(PKey):
else:
self._check_type_and_load_cert(
msg=msg,
- key_type='ssh-dss',
- cert_type='ssh-dss-cert-v01@openssh.com',
+ key_type="ssh-dss",
+ cert_type="ssh-dss-cert-v01@openssh.com",
)
self.p = msg.get_mpint()
self.q = msg.get_mpint()
@@ -74,7 +82,7 @@ class DSSKey(PKey):
def asbytes(self):
m = Message()
- m.add_string('ssh-dss')
+ m.add_string("ssh-dss")
m.add_mpint(self.p)
m.add_mpint(self.q)
m.add_mpint(self.g)
@@ -88,7 +96,7 @@ class DSSKey(PKey):
return hash((self.get_name(), self.p, self.q, self.g, self.y))
def get_name(self):
- return 'ssh-dss'
+ return "ssh-dss"
def get_bits(self):
return self.size
@@ -102,17 +110,17 @@ class DSSKey(PKey):
public_numbers=dsa.DSAPublicNumbers(
y=self.y,
parameter_numbers=dsa.DSAParameterNumbers(
- p=self.p,
- q=self.q,
- g=self.g
- )
- )
- ).private_key(backend=default_backend())
+ p=self.p, q=self.q, g=self.g
+ ),
+ ),
+ ).private_key(
+ backend=default_backend()
+ )
sig = key.sign(data, hashes.SHA1())
r, s = decode_dss_signature(sig)
m = Message()
- m.add_string('ssh-dss')
+ m.add_string("ssh-dss")
# apparently, in rare cases, r or s may be shorter than 20 bytes!
rstr = util.deflate_long(r, 0)
sstr = util.deflate_long(s, 0)
@@ -129,7 +137,7 @@ class DSSKey(PKey):
sig = msg.asbytes()
else:
kind = msg.get_text()
- if kind != 'ssh-dss':
+ if kind != "ssh-dss":
return 0
sig = msg.get_binary()
@@ -142,11 +150,11 @@ class DSSKey(PKey):
key = dsa.DSAPublicNumbers(
y=self.y,
parameter_numbers=dsa.DSAParameterNumbers(
- p=self.p,
- q=self.q,
- g=self.g
- )
- ).public_key(backend=default_backend())
+ p=self.p, q=self.q, g=self.g
+ ),
+ ).public_key(
+ backend=default_backend()
+ )
try:
key.verify(signature, data, hashes.SHA1())
except InvalidSignature:
@@ -160,18 +168,18 @@ class DSSKey(PKey):
public_numbers=dsa.DSAPublicNumbers(
y=self.y,
parameter_numbers=dsa.DSAParameterNumbers(
- p=self.p,
- q=self.q,
- g=self.g
- )
- )
- ).private_key(backend=default_backend())
+ p=self.p, q=self.q, g=self.g
+ ),
+ ),
+ ).private_key(
+ backend=default_backend()
+ )
self._write_private_key_file(
filename,
key,
serialization.PrivateFormat.TraditionalOpenSSL,
- password=password
+ password=password,
)
def write_private_key(self, file_obj, password=None):
@@ -180,18 +188,18 @@ class DSSKey(PKey):
public_numbers=dsa.DSAPublicNumbers(
y=self.y,
parameter_numbers=dsa.DSAParameterNumbers(
- p=self.p,
- q=self.q,
- g=self.g
- )
- )
- ).private_key(backend=default_backend())
+ p=self.p, q=self.q, g=self.g
+ ),
+ ),
+ ).private_key(
+ backend=default_backend()
+ )
self._write_private_key(
file_obj,
key,
serialization.PrivateFormat.TraditionalOpenSSL,
- password=password
+ password=password,
)
@staticmethod
@@ -207,23 +215,25 @@ class DSSKey(PKey):
numbers = dsa.generate_private_key(
bits, backend=default_backend()
).private_numbers()
- key = DSSKey(vals=(
- numbers.public_numbers.parameter_numbers.p,
- numbers.public_numbers.parameter_numbers.q,
- numbers.public_numbers.parameter_numbers.g,
- numbers.public_numbers.y
- ))
+ key = DSSKey(
+ vals=(
+ numbers.public_numbers.parameter_numbers.p,
+ numbers.public_numbers.parameter_numbers.q,
+ numbers.public_numbers.parameter_numbers.g,
+ numbers.public_numbers.y,
+ )
+ )
key.x = numbers.x
return key
# ...internals...
def _from_private_key_file(self, filename, password):
- data = self._read_private_key_file('DSA', filename, password)
+ data = self._read_private_key_file("DSA", filename, password)
self._decode_key(data)
def _from_private_key(self, file_obj, password):
- data = self._read_private_key('DSA', file_obj, password)
+ data = self._read_private_key("DSA", file_obj, password)
self._decode_key(data)
def _decode_key(self, data):
@@ -232,14 +242,11 @@ class DSSKey(PKey):
try:
keylist = BER(data).decode()
except BERException as e:
- raise SSHException('Unable to parse key file: ' + str(e))
- if (
- type(keylist) is not list or
- len(keylist) < 6 or
- keylist[0] != 0
- ):
+ raise SSHException("Unable to parse key file: " + str(e))
+ if type(keylist) is not list or len(keylist) < 6 or keylist[0] != 0:
raise SSHException(
- 'not a valid DSA private key file (bad ber encoding)')
+ "not a valid DSA private key file (bad ber encoding)"
+ )
self.p = keylist[1]
self.q = keylist[2]
self.g = keylist[3]
diff --git a/paramiko/ecdsakey.py b/paramiko/ecdsakey.py
index 92e01a7..4b7984c 100644
--- a/paramiko/ecdsakey.py
+++ b/paramiko/ecdsakey.py
@@ -25,7 +25,8 @@ from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.asymmetric.utils import (
- decode_dss_signature, encode_dss_signature
+ decode_dss_signature,
+ encode_dss_signature,
)
from paramiko.common import four_byte
@@ -43,6 +44,7 @@ class _ECDSACurve(object):
the proper hash function. Also grabs the proper curve from the 'ecdsa'
package.
"""
+
def __init__(self, curve_class, nist_name):
self.nist_name = nist_name
self.key_length = curve_class.key_size
@@ -67,6 +69,7 @@ class _ECDSACurveSet(object):
format identifier. The two ways in which ECDSAKey needs to be able to look
up curves.
"""
+
def __init__(self, ecdsa_curves):
self.ecdsa_curves = ecdsa_curves
@@ -95,14 +98,24 @@ class ECDSAKey(PKey):
data.
"""
- _ECDSA_CURVES = _ECDSACurveSet([
- _ECDSACurve(ec.SECP256R1, 'nistp256'),
- _ECDSACurve(ec.SECP384R1, 'nistp384'),
- _ECDSACurve(ec.SECP521R1, 'nistp521'),
- ])
-
- def __init__(self, msg=None, data=None, filename=None, password=None,
- vals=None, file_obj=None, validate_point=True):
+ _ECDSA_CURVES = _ECDSACurveSet(
+ [
+ _ECDSACurve(ec.SECP256R1, "nistp256"),
+ _ECDSACurve(ec.SECP384R1, "nistp384"),
+ _ECDSACurve(ec.SECP521R1, "nistp521"),
+ ]
+ )
+
+ def __init__(
+ self,
+ msg=None,
+ data=None,
+ filename=None,
+ password=None,
+ vals=None,
+ file_obj=None,
+ validate_point=True,
+ ):
self.verifying_key = None
self.signing_key = None
self.public_blob = None
@@ -126,7 +139,7 @@ class ECDSAKey(PKey):
# identifier, so strip out any cert business. (NOTE: could push
# that into _ECDSACurveSet.get_by_key_format_identifier(), but it
# feels more correct to do it here?)
- suffix = '-cert-v01@openssh.com'
+ suffix = "-cert-v01@openssh.com"
if key_type.endswith(suffix):
key_type = key_type[:-len(suffix)]
self.ecdsa_curve = self._ECDSA_CURVES.get_by_key_format_identifier(
@@ -134,13 +147,10 @@ class ECDSAKey(PKey):
)
key_types = self._ECDSA_CURVES.get_key_format_identifier_list()
cert_types = [
- '{}-cert-v01@openssh.com'.format(x)
- for x in key_types
+ "{}-cert-v01@openssh.com".format(x) for x in key_types
]
self._check_type_and_load_cert(
- msg=msg,
- key_type=key_types,
- cert_type=cert_types,
+ msg=msg, key_type=key_types, cert_type=cert_types
)
curvename = msg.get_text()
if curvename != self.ecdsa_curve.nist_name:
@@ -172,10 +182,10 @@ class ECDSAKey(PKey):
key_size_bytes = (key.curve.key_size + 7) // 8
x_bytes = deflate_long(numbers.x, add_sign_padding=False)
- x_bytes = b'\x00' * (key_size_bytes - len(x_bytes)) + x_bytes
+ x_bytes = b"\x00" * (key_size_bytes - len(x_bytes)) + x_bytes
y_bytes = deflate_long(numbers.y, add_sign_padding=False)
- y_bytes = b'\x00' * (key_size_bytes - len(y_bytes)) + y_bytes
+ y_bytes = b"\x00" * (key_size_bytes - len(y_bytes)) + y_bytes
point_str = four_byte + x_bytes + y_bytes
m.add_string(point_str)
@@ -185,8 +195,13 @@ class ECDSAKey(PKey):
return self.asbytes()
def __hash__(self):
- return hash((self.get_name(), self.verifying_key.public_numbers().x,
- self.verifying_key.public_numbers().y))
+ return hash(
+ (
+ self.get_name(),
+ self.verifying_key.public_numbers().x,
+ self.verifying_key.public_numbers().y,
+ )
+ )
def get_name(self):
return self.ecdsa_curve.key_format_identifier
@@ -228,7 +243,7 @@ class ECDSAKey(PKey):
filename,
self.signing_key,
serialization.PrivateFormat.TraditionalOpenSSL,
- password=password
+ password=password,
)
def write_private_key(self, file_obj, password=None):
@@ -236,7 +251,7 @@ class ECDSAKey(PKey):
file_obj,
self.signing_key,
serialization.PrivateFormat.TraditionalOpenSSL,
- password=password
+ password=password,
)
@classmethod
@@ -260,11 +275,11 @@ class ECDSAKey(PKey):
# ...internals...
def _from_private_key_file(self, filename, password):
- data = self._read_private_key_file('EC', filename, password)
+ data = self._read_private_key_file("EC", filename, password)
self._decode_key(data)
def _from_private_key(self, file_obj, password):
- data = self._read_private_key('EC', file_obj, password)
+ data = self._read_private_key("EC", file_obj, password)
self._decode_key(data)
def _decode_key(self, data):
diff --git a/paramiko/ed25519key.py b/paramiko/ed25519key.py
index 8ad71d0..c8f6dd3 100644
--- a/paramiko/ed25519key.py
+++ b/paramiko/ed25519key.py
@@ -56,8 +56,10 @@ class Ed25519Key(PKey):
.. versionchanged:: 2.3
Added a ``file_obj`` parameter to match other key classes.
"""
- def __init__(self, msg=None, data=None, filename=None, password=None,
- file_obj=None):
+
+ def __init__(
+ self, msg=None, data=None, filename=None, password=None, file_obj=None
+ ):
self.public_blob = None
verifying_key = signing_key = None
if msg is None and data is not None:
@@ -86,6 +88,7 @@ class Ed25519Key(PKey):
def _parse_signing_key_data(self, data, password):
from paramiko.transport import Transport
+
# We may eventually want this to be usable for other key types, as
# OpenSSH moves to it, but for now this is just for Ed25519 keys.
# This format is described here:
@@ -144,7 +147,7 @@ class Ed25519Key(PKey):
decryptor = Cipher(
cipher["class"](key[:cipher["key-size"]]),
cipher["mode"](key[cipher["key-size"]:]),
- backend=default_backend()
+ backend=default_backend(),
).decryptor()
private_data = (
decryptor.update(private_ciphertext) + decryptor.finalize()
@@ -166,8 +169,10 @@ class Ed25519Key(PKey):
signing_key = nacl.signing.SigningKey(key_data[:32])
# Verify that all the public keys are the same...
assert (
- signing_key.verify_key.encode() == public == public_keys[i] ==
- key_data[32:]
+ signing_key.verify_key.encode()
+ == public
+ == public_keys[i]
+ == key_data[32:]
)
signing_keys.append(signing_key)
# Comment, ignore.
diff --git a/paramiko/file.py b/paramiko/file.py
index df9cdac..62686b5 100644
--- a/paramiko/file.py
+++ b/paramiko/file.py
@@ -16,14 +16,18 @@
# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
from paramiko.common import (
- linefeed_byte_value, crlf, cr_byte, linefeed_byte, cr_byte_value,
+ linefeed_byte_value,
+ crlf,
+ cr_byte,
+ linefeed_byte,
+ cr_byte_value,
)
from paramiko.py3compat import BytesIO, PY2, u, bytes_types, text_type
from paramiko.util import ClosingContextManager
-class BufferedFile (ClosingContextManager):
+class BufferedFile(ClosingContextManager):
"""
Reusable base class to implement Python-style file buffering around a
simpler stream.
@@ -70,7 +74,7 @@ class BufferedFile (ClosingContextManager):
:raises: ``ValueError`` -- if the file is closed.
"""
if self._closed:
- raise ValueError('I/O operation on closed file')
+ raise ValueError("I/O operation on closed file")
return self
def close(self):
@@ -90,6 +94,7 @@ class BufferedFile (ClosingContextManager):
return
if PY2:
+
def next(self):
"""
Returns the next line from the input, or raises
@@ -104,7 +109,9 @@ class BufferedFile (ClosingContextManager):
if not line:
raise StopIteration
return line
+
else:
+
def __next__(self):
"""
Returns the next line from the input, or raises ``StopIteration``
@@ -180,9 +187,9 @@ class BufferedFile (ClosingContextManager):
encountered immediately
"""
if self._closed:
- raise IOError('File is closed')
+ raise IOError("File is closed")
if not (self._flags & self.FLAG_READ):
- raise IOError('File is not open for reading')
+ raise IOError("File is not open for reading")
if (size is None) or (size < 0):
# go for broke
result = self._rbuffer
@@ -245,16 +252,16 @@ class BufferedFile (ClosingContextManager):
"""
# it's almost silly how complex this function is.
if self._closed:
- raise IOError('File is closed')
+ raise IOError("File is closed")
if not (self._flags & self.FLAG_READ):
- raise IOError('File not open for reading')
+ raise IOError("File not open for reading")
line = self._rbuffer
truncated = False
while True:
if (
- self._at_trailing_cr and
- self._flags & self.FLAG_UNIVERSAL_NEWLINE and
- len(line) > 0
+ self._at_trailing_cr
+ and self._flags & self.FLAG_UNIVERSAL_NEWLINE
+ and len(line) > 0
):
# edge case: the newline may be '\r\n' and we may have read
# only the first '\r' last time.
@@ -277,10 +284,10 @@ class BufferedFile (ClosingContextManager):
else:
n = self._bufsize
if (
- linefeed_byte in line or
- (
- self._flags & self.FLAG_UNIVERSAL_NEWLINE and
- cr_byte in line
+ linefeed_byte in line
+ or (
+ self._flags & self.FLAG_UNIVERSAL_NEWLINE
+ and cr_byte in line
)
):
break
@@ -306,9 +313,9 @@ class BufferedFile (ClosingContextManager):
return line if self._flags & self.FLAG_BINARY else u(line)
xpos = pos + 1
if (
- line[pos] == cr_byte_value and
- xpos < len(line) and
- line[xpos] == linefeed_byte_value
+ line[pos] == cr_byte_value
+ and xpos < len(line)
+ and line[xpos] == linefeed_byte_value
):
xpos += 1
# if the string was truncated, _rbuffer needs to have the string after
@@ -370,7 +377,7 @@ class BufferedFile (ClosingContextManager):
:raises: ``IOError`` -- if the file doesn't support random access.
"""
- raise IOError('File does not support seeking.')
+ raise IOError("File does not support seeking.")
def tell(self):
"""
@@ -393,11 +400,11 @@ class BufferedFile (ClosingContextManager):
"""
if isinstance(data, text_type):
# Accept text and encode as utf-8 for compatibility only.
- data = data.encode('utf-8')
+ data = data.encode("utf-8")
if self._closed:
- raise IOError('File is closed')
+ raise IOError("File is closed")
if not (self._flags & self.FLAG_WRITE):
- raise IOError('File not open for writing')
+ raise IOError("File not open for writing")
if not (self._flags & self.FLAG_BUFFERED):
self._write_all(data)
return
@@ -457,7 +464,7 @@ class BufferedFile (ClosingContextManager):
(subclass override)
Write data into the stream.
"""
- raise IOError('write not implemented')
+ raise IOError("write not implemented")
def _get_size(self):
"""
@@ -472,7 +479,7 @@ class BufferedFile (ClosingContextManager):
# ...internals...
- def _set_mode(self, mode='r', bufsize=-1):
+ def _set_mode(self, mode="r", bufsize=-1):
"""
Subclasses call this method to initialize the BufferedFile.
"""
@@ -495,17 +502,17 @@ class BufferedFile (ClosingContextManager):
# unbuffered
self._flags &= ~(self.FLAG_BUFFERED | self.FLAG_LINE_BUFFERED)
- if ('r' in mode) or ('+' in mode):
+ if ("r" in mode) or ("+" in mode):
self._flags |= self.FLAG_READ
- if ('w' in mode) or ('+' in mode):
+ if ("w" in mode) or ("+" in mode):
self._flags |= self.FLAG_WRITE
- if 'a' in mode:
+ if "a" in mode:
self._flags |= self.FLAG_WRITE | self.FLAG_APPEND
self._size = self._get_size()
self._pos = self._realpos = self._size
- if 'b' in mode:
+ if "b" in mode:
self._flags |= self.FLAG_BINARY
- if 'U' in mode:
+ if "U" in mode:
self._flags |= self.FLAG_UNIVERSAL_NEWLINE
# built-in file objects have this attribute to store which kinds of
# line terminations they've seen:
@@ -535,8 +542,7 @@ class BufferedFile (ClosingContextManager):
if self.newlines is None:
self.newlines = newline
elif (
- self.newlines != newline and
- isinstance(self.newlines, bytes_types)
+ self.newlines != newline and isinstance(self.newlines, bytes_types)
):
self.newlines = (self.newlines, newline)
elif newline not in self.newlines:
diff --git a/paramiko/hostkeys.py b/paramiko/hostkeys.py
index ca18527..c4d611d 100644
--- a/paramiko/hostkeys.py
+++ b/paramiko/hostkeys.py
@@ -34,7 +34,7 @@ from paramiko.ed25519key import Ed25519Key
from paramiko.ssh_exception import SSHException
-class HostKeys (MutableMapping):
+class HostKeys(MutableMapping):
"""
Representation of an OpenSSH-style "known hosts" file. Host keys can be
read from one or more files, and then individual hosts can be looked up to
@@ -88,10 +88,10 @@ class HostKeys (MutableMapping):
:raises: ``IOError`` -- if there was an error reading the file
"""
- with open(filename, 'r') as f:
+ with open(filename, "r") as f:
for lineno, line in enumerate(f, 1):
line = line.strip()
- if (len(line) == 0) or (line[0] == '#'):
+ if (len(line) == 0) or (line[0] == "#"):
continue
try:
e = HostKeyEntry.from_line(line, lineno)
@@ -118,7 +118,7 @@ class HostKeys (MutableMapping):
.. versionadded:: 1.6.1
"""
- with open(filename, 'w') as f:
+ with open(filename, "w") as f:
for e in self._entries:
line = e.to_line()
if line:
@@ -134,7 +134,9 @@ class HostKeys (MutableMapping):
:return: dict of `str` -> `.PKey` keys associated with this host
(or ``None``)
"""
- class SubDict (MutableMapping):
+
+ class SubDict(MutableMapping):
+
def __init__(self, hostname, entries, hostkeys):
self._hostname = hostname
self._entries = entries
@@ -176,7 +178,8 @@ class HostKeys (MutableMapping):
def keys(self):
return [
- e.key.get_name() for e in self._entries
+ e.key.get_name()
+ for e in self._entries
if e.key is not None
]
@@ -196,10 +199,10 @@ class HostKeys (MutableMapping):
"""
for h in entry.hostnames:
if (
- h == hostname or
- h.startswith('|1|') and
- not hostname.startswith('|1|') and
- constant_time_bytes_eq(self.hash_host(hostname, h), h)
+ h == hostname
+ or h.startswith("|1|")
+ and not hostname.startswith("|1|")
+ and constant_time_bytes_eq(self.hash_host(hostname, h), h)
):
return True
return False
@@ -295,16 +298,17 @@ class HostKeys (MutableMapping):
if salt is None:
salt = os.urandom(sha1().digest_size)
else:
- if salt.startswith('|1|'):
- salt = salt.split('|')[2]
+ if salt.startswith("|1|"):
+ salt = salt.split("|")[2]
salt = decodebytes(b(salt))
assert len(salt) == sha1().digest_size
hmac = HMAC(salt, b(hostname), sha1).digest()
- hostkey = '|1|{}|{}'.format(u(encodebytes(salt)), u(encodebytes(hmac)))
- return hostkey.replace('\n', '')
+ hostkey = "|1|{}|{}".format(u(encodebytes(salt)), u(encodebytes(hmac)))
+ return hostkey.replace("\n", "")
class InvalidHostKey(Exception):
+
def __init__(self, line, exc):
self.line = line
self.exc = exc
@@ -334,8 +338,8 @@ class HostKeyEntry:
:param str line: a line from an OpenSSH known_hosts file
"""
- log = get_logger('paramiko.hostkeys')
- fields = line.split(' ')
+ log = get_logger("paramiko.hostkeys")
+ fields = line.split(" ")
if len(fields) < 3:
# Bad number of fields
msg = "Not enough fields found in known_hosts in line {} ({!r})"
@@ -344,19 +348,19 @@ class HostKeyEntry:
fields = fields[:3]
names, keytype, key = fields
- names = names.split(',')
+ names = names.split(",")
# Decide what kind of key we're looking at and create an object
# to hold it accordingly.
try:
key = b(key)
- if keytype == 'ssh-rsa':
+ if keytype == "ssh-rsa":
key = RSAKey(data=decodebytes(key))
- elif keytype == 'ssh-dss':
+ elif keytype == "ssh-dss":
key = DSSKey(data=decodebytes(key))
elif keytype in ECDSAKey.supported_key_format_identifiers():
key = ECDSAKey(data=decodebytes(key), validate_point=False)
- elif keytype == 'ssh-ed25519':
+ elif keytype == "ssh-ed25519":
key = Ed25519Key(data=decodebytes(key))
else:
log.info("Unable to handle key of type {}".format(keytype))
@@ -374,12 +378,12 @@ class HostKeyEntry:
included.
"""
if self.valid:
- return '{} {} {}\n'.format(
- ','.join(self.hostnames),
+ return "{} {} {}\n".format(
+ ",".join(self.hostnames),
self.key.get_name(),
self.key.get_base64(),
)
return None
def __repr__(self):
- return '<HostKeyEntry {!r}: {!r}>'.format(self.hostnames, self.key)
+ return "<HostKeyEntry {!r}: {!r}>".format(self.hostnames, self.key)
diff --git a/paramiko/kex_ecdh_nist.py b/paramiko/kex_ecdh_nist.py
index 4e8ff35..ca32404 100644
--- a/paramiko/kex_ecdh_nist.py
+++ b/paramiko/kex_ecdh_nist.py
@@ -46,7 +46,7 @@ class KexNistp256():
elif not self.transport.server_mode and (ptype == _MSG_KEXECDH_REPLY):
return self._parse_kexecdh_reply(m)
raise SSHException(
- 'KexECDH asked to handle packet type {:d}'.format(ptype)
+ "KexECDH asked to handle packet type {:d}".format(ptype)
)
def _generate_key_pair(self):
@@ -66,8 +66,12 @@ class KexNistp256():
K = long(hexlify(K), 16)
# compute exchange hash
hm = Message()
- hm.add(self.transport.remote_version, self.transport.local_version,
- self.transport.remote_kex_init, self.transport.local_kex_init)
+ hm.add(
+ self.transport.remote_version,
+ self.transport.local_version,
+ self.transport.remote_kex_init,
+ self.transport.local_kex_init,
+ )
hm.add_string(K_S)
hm.add_string(Q_C_bytes)
# SEC1: V2.0 2.3.3 Elliptic-Curve-Point-to-Octet-String Conversion
@@ -96,8 +100,12 @@ class KexNistp256():
K = long(hexlify(K), 16)
# compute exchange hash and verify signature
hm = Message()
- hm.add(self.transport.local_version, self.transport.remote_version,
- self.transport.local_kex_init, self.transport.remote_kex_init)
+ hm.add(
+ self.transport.local_version,
+ self.transport.remote_version,
+ self.transport.local_kex_init,
+ self.transport.remote_kex_init,
+ )
hm.add_string(K_S)
# SEC1: V2.0 2.3.3 Elliptic-Curve-Point-to-Octet-String Conversion
hm.add_string(self.Q_C.public_numbers().encode_point())
diff --git a/paramiko/kex_gex.py b/paramiko/kex_gex.py
index 4403056..6d10912 100644
--- a/paramiko/kex_gex.py
+++ b/paramiko/kex_gex.py
@@ -32,17 +32,18 @@ from paramiko.py3compat import byte_chr, byte_ord, byte_mask
from paramiko.ssh_exception import SSHException
-_MSG_KEXDH_GEX_REQUEST_OLD, _MSG_KEXDH_GEX_GROUP, _MSG_KEXDH_GEX_INIT, \
- _MSG_KEXDH_GEX_REPLY, _MSG_KEXDH_GEX_REQUEST = range(30, 35)
+_MSG_KEXDH_GEX_REQUEST_OLD, _MSG_KEXDH_GEX_GROUP, _MSG_KEXDH_GEX_INIT, _MSG_KEXDH_GEX_REPLY, _MSG_KEXDH_GEX_REQUEST = range(
+ 30, 35
+)
-c_MSG_KEXDH_GEX_REQUEST_OLD, c_MSG_KEXDH_GEX_GROUP, c_MSG_KEXDH_GEX_INIT, \
- c_MSG_KEXDH_GEX_REPLY, c_MSG_KEXDH_GEX_REQUEST = \
- [byte_chr(c) for c in range(30, 35)]
+c_MSG_KEXDH_GEX_REQUEST_OLD, c_MSG_KEXDH_GEX_GROUP, c_MSG_KEXDH_GEX_INIT, c_MSG_KEXDH_GEX_REPLY, c_MSG_KEXDH_GEX_REQUEST = [
+ byte_chr(c) for c in range(30, 35)
+]
-class KexGex (object):
+class KexGex(object):
- name = 'diffie-hellman-group-exchange-sha1'
+ name = "diffie-hellman-group-exchange-sha1"
min_bits = 1024
max_bits = 8192
preferred_bits = 2048
@@ -61,7 +62,8 @@ class KexGex (object):
def start_kex(self, _test_old_style=False):
if self.transport.server_mode:
self.transport._expect_packet(
- _MSG_KEXDH_GEX_REQUEST, _MSG_KEXDH_GEX_REQUEST_OLD)
+ _MSG_KEXDH_GEX_REQUEST, _MSG_KEXDH_GEX_REQUEST_OLD
+ )
return
# request a bit range: we accept (min_bits) to (max_bits), but prefer
# (preferred_bits). according to the spec, we shouldn't pull the
@@ -137,13 +139,12 @@ class KexGex (object):
# generate prime
pack = self.transport._get_modulus_pack()
if pack is None:
- raise SSHException(
- 'Can\'t do server-side gex with no modulus pack')
+ raise SSHException("Can't do server-side gex with no modulus pack")
self.transport._log(
DEBUG,
- 'Picking p ({} <= {} <= {} bits)'.format(
- minbits, preferredbits, maxbits,
- )
+ "Picking p ({} <= {} <= {} bits)".format(
+ minbits, preferredbits, maxbits
+ ),
)
self.g, self.p = pack.get_modulus(minbits, preferredbits, maxbits)
m = Message()
@@ -165,13 +166,13 @@ class KexGex (object):
# generate prime
pack = self.transport._get_modulus_pack()
if pack is None:
- raise SSHException(
- 'Can\'t do server-side gex with no modulus pack')
+ raise SSHException("Can't do server-side gex with no modulus pack")
self.transport._log(
- DEBUG, 'Picking p (~ {} bits)'.format(self.preferred_bits)
+ DEBUG, "Picking p (~ {} bits)".format(self.preferred_bits)
)
self.g, self.p = pack.get_modulus(
- self.min_bits, self.preferred_bits, self.max_bits)
+ self.min_bits, self.preferred_bits, self.max_bits
+ )
m = Message()
m.add_byte(c_MSG_KEXDH_GEX_GROUP)
m.add_mpint(self.p)
@@ -187,9 +188,10 @@ class KexGex (object):
bitlen = util.bit_length(self.p)
if (bitlen < 1024) or (bitlen > 8192):
raise SSHException(
- 'Server-generated gex p (don\'t ask) is out of range '
- '({} bits)'.format(bitlen))
- self.transport._log(DEBUG, 'Got server p ({} bits)'.format(bitlen))
+ "Server-generated gex p (don't ask) is out of range "
+ "({} bits)".format(bitlen)
+ )
+ self.transport._log(DEBUG, "Got server p ({} bits)".format(bitlen))
self._generate_x()
# now compute e = g^x mod p
self.e = pow(self.g, self.x, self.p)
@@ -210,9 +212,13 @@ class KexGex (object):
# okay, build up the hash H of
# (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K) # noqa
hm = Message()
- hm.add(self.transport.remote_version, self.transport.local_version,
- self.transport.remote_kex_init, self.transport.local_kex_init,
- key)
+ hm.add(
+ self.transport.remote_version,
+ self.transport.local_version,
+ self.transport.remote_kex_init,
+ self.transport.local_kex_init,
+ key,
+ )
if not self.old_style:
hm.add_int(self.min_bits)
hm.add_int(self.preferred_bits)
@@ -246,9 +252,13 @@ class KexGex (object):
# okay, build up the hash H of
# (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K) # noqa
hm = Message()
- hm.add(self.transport.local_version, self.transport.remote_version,
- self.transport.local_kex_init, self.transport.remote_kex_init,
- host_key)
+ hm.add(
+ self.transport.local_version,
+ self.transport.remote_version,
+ self.transport.local_kex_init,
+ self.transport.remote_kex_init,
+ host_key,
+ )
if not self.old_style:
hm.add_int(self.min_bits)
hm.add_int(self.preferred_bits)
@@ -265,5 +275,5 @@ class KexGex (object):
class KexGexSHA256(KexGex):
- name = 'diffie-hellman-group-exchange-sha256'
+ name = "diffie-hellman-group-exchange-sha256"
hash_algo = sha256
diff --git a/paramiko/kex_group1.py b/paramiko/kex_group1.py
index 1bebd37..904835d 100644
--- a/paramiko/kex_group1.py
+++ b/paramiko/kex_group1.py
@@ -44,7 +44,7 @@ class KexGroup1(object):
P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF # noqa
G = 2
- name = 'diffie-hellman-group1-sha1'
+ name = "diffie-hellman-group1-sha1"
hash_algo = sha1
def __init__(self, transport):
@@ -88,8 +88,10 @@ class KexGroup1(object):
while 1:
x_bytes = os.urandom(128)
x_bytes = byte_mask(x_bytes[0], 0x7f) + x_bytes[1:]
- if (x_bytes[:8] != b7fffffffffffffff and
- x_bytes[:8] != b0000000000000000):
+ if (
+ x_bytes[:8] != b7fffffffffffffff
+ and x_bytes[:8] != b0000000000000000
+ ):
break
self.x = util.inflate_long(x_bytes)
@@ -104,8 +106,12 @@ class KexGroup1(object):
# okay, build up the hash H of
# (V_C || V_S || I_C || I_S || K_S || e || f || K)
hm = Message()
- hm.add(self.transport.local_version, self.transport.remote_version,
- self.transport.local_kex_init, self.transport.remote_kex_init)
+ hm.add(
+ self.transport.local_version,
+ self.transport.remote_version,
+ self.transport.local_kex_init,
+ self.transport.remote_kex_init,
+ )
hm.add_string(host_key)
hm.add_mpint(self.e)
hm.add_mpint(self.f)
@@ -124,8 +130,12 @@ class KexGroup1(object):
# okay, build up the hash H of
# (V_C || V_S || I_C || I_S || K_S || e || f || K)
hm = Message()
- hm.add(self.transport.remote_version, self.transport.local_version,
- self.transport.remote_kex_init, self.transport.local_kex_init)
+ hm.add(
+ self.transport.remote_version,
+ self.transport.local_version,
+ self.transport.remote_kex_init,
+ self.transport.local_kex_init,
+ )
hm.add_string(key)
hm.add_mpint(self.e)
hm.add_mpint(self.f)
diff --git a/paramiko/kex_group14.py b/paramiko/kex_group14.py
index 22955e3..0df302e 100644
--- a/paramiko/kex_group14.py
+++ b/paramiko/kex_group14.py
@@ -31,5 +31,5 @@ class KexGroup14(KexGroup1):
P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF # noqa
G = 2
- name = 'diffie-hellman-group14-sha1'
+ name = "diffie-hellman-group14-sha1"
hash_algo = sha1
diff --git a/paramiko/kex_gss.py b/paramiko/kex_gss.py
index e21620f..e9f6480 100644
--- a/paramiko/kex_gss.py
+++ b/paramiko/kex_gss.py
@@ -47,13 +47,13 @@ from paramiko.py3compat import byte_chr, byte_mask, byte_ord
from paramiko.ssh_exception import SSHException
-MSG_KEXGSS_INIT, MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE, MSG_KEXGSS_HOSTKEY,\
- MSG_KEXGSS_ERROR = range(30, 35)
+MSG_KEXGSS_INIT, MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE, MSG_KEXGSS_HOSTKEY, MSG_KEXGSS_ERROR = range(
+ 30, 35
+)
MSG_KEXGSS_GROUPREQ, MSG_KEXGSS_GROUP = range(40, 42)
-c_MSG_KEXGSS_INIT, c_MSG_KEXGSS_CONTINUE, c_MSG_KEXGSS_COMPLETE,\
- c_MSG_KEXGSS_HOSTKEY, c_MSG_KEXGSS_ERROR = [
- byte_chr(c) for c in range(30, 35)
- ]
+c_MSG_KEXGSS_INIT, c_MSG_KEXGSS_CONTINUE, c_MSG_KEXGSS_COMPLETE, c_MSG_KEXGSS_HOSTKEY, c_MSG_KEXGSS_ERROR = [
+ byte_chr(c) for c in range(30, 35)
+]
c_MSG_KEXGSS_GROUPREQ, c_MSG_KEXGSS_GROUP = [
byte_chr(c) for c in range(40, 42)
]
@@ -98,10 +98,12 @@ class KexGSSGroup1(object):
m.add_string(self.kexgss.ssh_init_sec_context(target=self.gss_host))
m.add_mpint(self.e)
self.transport._send_message(m)
- self.transport._expect_packet(MSG_KEXGSS_HOSTKEY,
- MSG_KEXGSS_CONTINUE,
- MSG_KEXGSS_COMPLETE,
- MSG_KEXGSS_ERROR)
+ self.transport._expect_packet(
+ MSG_KEXGSS_HOSTKEY,
+ MSG_KEXGSS_CONTINUE,
+ MSG_KEXGSS_COMPLETE,
+ MSG_KEXGSS_ERROR,
+ )
def parse_next(self, ptype, m):
"""
@@ -120,7 +122,7 @@ class KexGSSGroup1(object):
return self._parse_kexgss_complete(m)
elif ptype == MSG_KEXGSS_ERROR:
return self._parse_kexgss_error(m)
- msg = 'GSS KexGroup1 asked to handle packet type {:d}'
+ msg = "GSS KexGroup1 asked to handle packet type {:d}"
raise SSHException(msg.format(ptype))
# ## internals...
@@ -152,8 +154,7 @@ class KexGSSGroup1(object):
self.transport.host_key = host_key
sig = m.get_string()
self.transport._verify_key(host_key, sig)
- self.transport._expect_packet(MSG_KEXGSS_CONTINUE,
- MSG_KEXGSS_COMPLETE)
+ self.transport._expect_packet(MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE)
def _parse_kexgss_continue(self, m):
"""
@@ -166,13 +167,14 @@ class KexGSSGroup1(object):
srv_token = m.get_string()
m = Message()
m.add_byte(c_MSG_KEXGSS_CONTINUE)
- m.add_string(self.kexgss.ssh_init_sec_context(
- target=self.gss_host, recv_token=srv_token))
+ m.add_string(
+ self.kexgss.ssh_init_sec_context(
+ target=self.gss_host, recv_token=srv_token
+ )
+ )
self.transport.send_message(m)
self.transport._expect_packet(
- MSG_KEXGSS_CONTINUE,
- MSG_KEXGSS_COMPLETE,
- MSG_KEXGSS_ERROR
+ MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE, MSG_KEXGSS_ERROR
)
else:
pass
@@ -200,8 +202,12 @@ class KexGSSGroup1(object):
# okay, build up the hash H of
# (V_C || V_S || I_C || I_S || K_S || e || f || K)
hm = Message()
- hm.add(self.transport.local_version, self.transport.remote_version,
- self.transport.local_kex_init, self.transport.remote_kex_init)
+ hm.add(
+ self.transport.local_version,
+ self.transport.remote_version,
+ self.transport.local_kex_init,
+ self.transport.remote_kex_init,
+ )
hm.add_string(self.transport.host_key.__str__())
hm.add_mpint(self.e)
hm.add_mpint(self.f)
@@ -209,8 +215,9 @@ class KexGSSGroup1(object):
H = sha1(str(hm)).digest()
self.transport._set_K_H(K, H)
if srv_token is not None:
- self.kexgss.ssh_init_sec_context(target=self.gss_host,
- recv_token=srv_token)
+ self.kexgss.ssh_init_sec_context(
+ target=self.gss_host, recv_token=srv_token
+ )
self.kexgss.ssh_check_mic(mic_token, H)
else:
self.kexgss.ssh_check_mic(mic_token, H)
@@ -234,20 +241,26 @@ class KexGSSGroup1(object):
# okay, build up the hash H of
# (V_C || V_S || I_C || I_S || K_S || e || f || K)
hm = Message()
- hm.add(self.transport.remote_version, self.transport.local_version,
- self.transport.remote_kex_init, self.transport.local_kex_init)
+ hm.add(
+ self.transport.remote_version,
+ self.transport.local_version,
+ self.transport.remote_kex_init,
+ self.transport.local_kex_init,
+ )
hm.add_string(key)
hm.add_mpint(self.e)
hm.add_mpint(self.f)
hm.add_mpint(K)
H = sha1(hm.asbytes()).digest()
self.transport._set_K_H(K, H)
- srv_token = self.kexgss.ssh_accept_sec_context(self.gss_host,
- client_token)
+ srv_token = self.kexgss.ssh_accept_sec_context(
+ self.gss_host, client_token
+ )
m = Message()
if self.kexgss._gss_srv_ctxt_status:
- mic_token = self.kexgss.ssh_get_mic(self.transport.session_id,
- gss_kex=True)
+ mic_token = self.kexgss.ssh_get_mic(
+ self.transport.session_id, gss_kex=True
+ )
m.add_byte(c_MSG_KEXGSS_COMPLETE)
m.add_mpint(self.f)
m.add_string(mic_token)
@@ -263,9 +276,9 @@ class KexGSSGroup1(object):
m.add_byte(c_MSG_KEXGSS_CONTINUE)
m.add_string(srv_token)
self.transport._send_message(m)
- self.transport._expect_packet(MSG_KEXGSS_CONTINUE,
- MSG_KEXGSS_COMPLETE,
- MSG_KEXGSS_ERROR)
+ self.transport._expect_packet(
+ MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE, MSG_KEXGSS_ERROR
+ )
def _parse_kexgss_error(self, m):
"""
@@ -281,12 +294,16 @@ class KexGSSGroup1(object):
maj_status = m.get_int()
min_status = m.get_int()
err_msg = m.get_string()
- m.get_string() # we don't care about the language!
- raise SSHException("""GSS-API Error:
+ m.get_string() # we don't care about the language!
+ raise SSHException(
+ """GSS-API Error:
Major Status: {}
Minor Status: {}
Error Message: {}
-""".format(maj_status, min_status, err_msg))
+""".format(
+ maj_status, min_status, err_msg
+ )
+ )
class KexGSSGroup14(KexGSSGroup1):
@@ -362,7 +379,7 @@ class KexGSSGex(object):
return self._parse_kexgss_complete(m)
elif ptype == MSG_KEXGSS_ERROR:
return self._parse_kexgss_error(m)
- msg = 'KexGex asked to handle packet type {:d}'
+ msg = "KexGex asked to handle packet type {:d}"
raise SSHException(msg.format(ptype))
# ## internals...
@@ -414,13 +431,12 @@ class KexGSSGex(object):
# generate prime
pack = self.transport._get_modulus_pack()
if pack is None:
- raise SSHException(
- 'Can\'t do server-side gex with no modulus pack')
+ raise SSHException("Can't do server-side gex with no modulus pack")
self.transport._log(
DEBUG, # noqa
- 'Picking p ({} <= {} <= {} bits)'.format(
- minbits, preferredbits, maxbits,
- )
+ "Picking p ({} <= {} <= {} bits)".format(
+ minbits, preferredbits, maxbits
+ ),
)
self.g, self.p = pack.get_modulus(minbits, preferredbits, maxbits)
m = Message()
@@ -442,9 +458,12 @@ class KexGSSGex(object):
bitlen = util.bit_length(self.p)
if (bitlen < 1024) or (bitlen > 8192):
raise SSHException(
- 'Server-generated gex p (don\'t ask) is out of range '
- '({} bits)'.format(bitlen))
- self.transport._log(DEBUG, 'Got server p ({} bits)'.format(bitlen)) # noqa
+ "Server-generated gex p (don't ask) is out of range "
+ "({} bits)".format(bitlen)
+ )
+ self.transport._log(
+ DEBUG, "Got server p ({} bits)".format(bitlen)
+ ) # noqa
self._generate_x()
# now compute e = g^x mod p
self.e = pow(self.g, self.x, self.p)
@@ -453,10 +472,12 @@ class KexGSSGex(object):
m.add_string(self.kexgss.ssh_init_sec_context(target=self.gss_host))
m.add_mpint(self.e)
self.transport._send_message(m)
- self.transport._expect_packet(MSG_KEXGSS_HOSTKEY,
- MSG_KEXGSS_CONTINUE,
- MSG_KEXGSS_COMPLETE,
- MSG_KEXGSS_ERROR)
+ self.transport._expect_packet(
+ MSG_KEXGSS_HOSTKEY,
+ MSG_KEXGSS_CONTINUE,
+ MSG_KEXGSS_COMPLETE,
+ MSG_KEXGSS_ERROR,
+ )
def _parse_kexgss_gex_init(self, m):
"""
@@ -476,9 +497,13 @@ class KexGSSGex(object):
# okay, build up the hash H of
# (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K) # noqa
hm = Message()
- hm.add(self.transport.remote_version, self.transport.local_version,
- self.transport.remote_kex_init, self.transport.local_kex_init,
- key)
+ hm.add(
+ self.transport.remote_version,
+ self.transport.local_version,
+ self.transport.remote_kex_init,
+ self.transport.local_kex_init,
+ key,
+ )
hm.add_int(self.min_bits)
hm.add_int(self.preferred_bits)
hm.add_int(self.max_bits)
@@ -489,12 +514,14 @@ class KexGSSGex(object):
hm.add_mpint(K)
H = sha1(hm.asbytes()).digest()
self.transport._set_K_H(K, H)
- srv_token = self.kexgss.ssh_accept_sec_context(self.gss_host,
- client_token)
+ srv_token = self.kexgss.ssh_accept_sec_context(
+ self.gss_host, client_token
+ )
m = Message()
if self.kexgss._gss_srv_ctxt_status:
- mic_token = self.kexgss.ssh_get_mic(self.transport.session_id,
- gss_kex=True)
+ mic_token = self.kexgss.ssh_get_mic(
+ self.transport.session_id, gss_kex=True
+ )
m.add_byte(c_MSG_KEXGSS_COMPLETE)
m.add_mpint(self.f)
m.add_string(mic_token)
@@ -510,9 +537,9 @@ class KexGSSGex(object):
m.add_byte(c_MSG_KEXGSS_CONTINUE)
m.add_string(srv_token)
self.transport._send_message(m)
- self.transport._expect_packet(MSG_KEXGSS_CONTINUE,
- MSG_KEXGSS_COMPLETE,
- MSG_KEXGSS_ERROR)
+ self.transport._expect_packet(
+ MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE, MSG_KEXGSS_ERROR
+ )
def _parse_kexgss_hostkey(self, m):
"""
@@ -525,8 +552,7 @@ class KexGSSGex(object):
self.transport.host_key = host_key
sig = m.get_string()
self.transport._verify_key(host_key, sig)
- self.transport._expect_packet(MSG_KEXGSS_CONTINUE,
- MSG_KEXGSS_COMPLETE)
+ self.transport._expect_packet(MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE)
def _parse_kexgss_continue(self, m):
"""
@@ -538,12 +564,15 @@ class KexGSSGex(object):
srv_token = m.get_string()
m = Message()
m.add_byte(c_MSG_KEXGSS_CONTINUE)
- m.add_string(self.kexgss.ssh_init_sec_context(target=self.gss_host,
- recv_token=srv_token))
+ m.add_string(
+ self.kexgss.ssh_init_sec_context(
+ target=self.gss_host, recv_token=srv_token
+ )
+ )
self.transport.send_message(m)
- self.transport._expect_packet(MSG_KEXGSS_CONTINUE,
- MSG_KEXGSS_COMPLETE,
- MSG_KEXGSS_ERROR)
+ self.transport._expect_packet(
+ MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE, MSG_KEXGSS_ERROR
+ )
else:
pass
@@ -568,9 +597,13 @@ class KexGSSGex(object):
# okay, build up the hash H of
# (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K) # noqa
hm = Message()
- hm.add(self.transport.local_version, self.transport.remote_version,
- self.transport.local_kex_init, self.transport.remote_kex_init,
- self.transport.host_key.__str__())
+ hm.add(
+ self.transport.local_version,
+ self.transport.remote_version,
+ self.transport.local_kex_init,
+ self.transport.remote_kex_init,
+ self.transport.host_key.__str__(),
+ )
if not self.old_style:
hm.add_int(self.min_bits)
hm.add_int(self.preferred_bits)
@@ -584,8 +617,9 @@ class KexGSSGex(object):
H = sha1(hm.asbytes()).digest()
self.transport._set_K_H(K, H)
if srv_token is not None:
- self.kexgss.ssh_init_sec_context(target=self.gss_host,
- recv_token=srv_token)
+ self.kexgss.ssh_init_sec_context(
+ target=self.gss_host, recv_token=srv_token
+ )
self.kexgss.ssh_check_mic(mic_token, H)
else:
self.kexgss.ssh_check_mic(mic_token, H)
@@ -606,12 +640,16 @@ class KexGSSGex(object):
maj_status = m.get_int()
min_status = m.get_int()
err_msg = m.get_string()
- m.get_string() # we don't care about the language (lang_tag)!
- raise SSHException("""GSS-API Error:
+ m.get_string() # we don't care about the language (lang_tag)!
+ raise SSHException(
+ """GSS-API Error:
Major Status: {}
Minor Status: {}
Error Message: {}
-""".format(maj_status, min_status, err_msg))
+""".format(
+ maj_status, min_status, err_msg
+ )
+ )
class NullHostKey(object):
@@ -620,6 +658,7 @@ class NullHostKey(object):
in `RFC 4462 Section 5
<https://tools.ietf.org/html/rfc4462.html#section-5>`_
"""
+
def __init__(self):
self.key = ""
diff --git a/paramiko/message.py b/paramiko/message.py
index 9af841d..dead350 100644
--- a/paramiko/message.py
+++ b/paramiko/message.py
@@ -27,7 +27,7 @@ from paramiko.common import zero_byte, max_byte, one_byte, asbytes
from paramiko.py3compat import long, BytesIO, u, integer_types
-class Message (object):
+class Message(object):
"""
An SSH2 message is a stream of bytes that encodes some combination of
strings, integers, bools, and infinite-precision integers (known in Python
@@ -63,7 +63,7 @@ class Message (object):
"""
Returns a string representation of this object, for debugging.
"""
- return 'paramiko.Message(' + repr(self.packet.getvalue()) + ')'
+ return "paramiko.Message(" + repr(self.packet.getvalue()) + ")"
def asbytes(self):
"""
@@ -139,13 +139,13 @@ class Message (object):
if byte == max_byte:
return util.inflate_long(self.get_binary())
byte += self.get_bytes(3)
- return struct.unpack('>I', byte)[0]
+ return struct.unpack(">I", byte)[0]
def get_int(self):
"""
Fetch an int from the stream.
"""
- return struct.unpack('>I', self.get_bytes(4))[0]
+ return struct.unpack(">I", self.get_bytes(4))[0]
def get_int64(self):
"""
@@ -153,7 +153,7 @@ class Message (object):
:return: a 64-bit unsigned integer (`long`).
"""
- return struct.unpack('>Q', self.get_bytes(8))[0]
+ return struct.unpack(">Q", self.get_bytes(8))[0]
def get_mpint(self):
"""
@@ -191,7 +191,7 @@ class Message (object):
These are trivially encoded as comma-separated values in a string.
"""
- return self.get_text().split(',')
+ return self.get_text().split(",")
def add_bytes(self, b):
"""
@@ -229,7 +229,7 @@ class Message (object):
:param int n: integer to add
"""
- self.packet.write(struct.pack('>I', n))
+ self.packet.write(struct.pack(">I", n))
return self
def add_adaptive_int(self, n):
@@ -242,7 +242,7 @@ class Message (object):
self.packet.write(max_byte)
self.add_string(util.deflate_long(n))
else:
- self.packet.write(struct.pack('>I', n))
+ self.packet.write(struct.pack(">I", n))
return self
def add_int64(self, n):
@@ -251,7 +251,7 @@ class Message (object):
:param long n: long int to add
"""
- self.packet.write(struct.pack('>Q', n))
+ self.packet.write(struct.pack(">Q", n))
return self
def add_mpint(self, z):
@@ -283,7 +283,7 @@ class Message (object):
:param l: list of strings to add
"""
- self.add_string(','.join(l))
+ self.add_string(",".join(l))
return self
def _add(self, i):
diff --git a/paramiko/packet.py b/paramiko/packet.py
index 2a1e91e..20b37ac 100644
--- a/paramiko/packet.py
+++ b/paramiko/packet.py
@@ -30,7 +30,12 @@ from hmac import HMAC
from paramiko import util
from paramiko.common import (
- linefeed_byte, cr_byte_value, asbytes, MSG_NAMES, DEBUG, xffffffff,
+ linefeed_byte,
+ cr_byte_value,
+ asbytes,
+ MSG_NAMES,
+ DEBUG,
+ xffffffff,
zero_byte,
)
from paramiko.py3compat import u, byte_ord
@@ -42,7 +47,7 @@ def compute_hmac(key, message, digest_class):
return HMAC(key, message, digest_class).digest()
-class NeedRekeyException (Exception):
+class NeedRekeyException(Exception):
"""
Exception indicating a rekey is needed.
"""
@@ -56,7 +61,7 @@ def first_arg(e):
return arg
-class Packetizer (object):
+class Packetizer(object):
"""
Implementation of the base SSH packet protocol.
"""
@@ -128,8 +133,15 @@ class Packetizer (object):
"""
self.__logger = log
- def set_outbound_cipher(self, block_engine, block_size, mac_engine,
- mac_size, mac_key, sdctr=False):
+ def set_outbound_cipher(
+ self,
+ block_engine,
+ block_size,
+ mac_engine,
+ mac_size,
+ mac_key,
+ sdctr=False,
+ ):
"""
Switch outbound data cipher.
"""
@@ -149,7 +161,8 @@ class Packetizer (object):
self.__need_rekey = False
def set_inbound_cipher(
- self, block_engine, block_size, mac_engine, mac_size, mac_key):
+ self, block_engine, block_size, mac_engine, mac_size, mac_key
+ ):
"""
Switch inbound data cipher.
"""
@@ -368,7 +381,7 @@ class Packetizer (object):
if cmd in MSG_NAMES:
cmd_name = MSG_NAMES[cmd]
else:
- cmd_name = '${:x}'.format(cmd)
+ cmd_name = "${:x}".format(cmd)
orig_len = len(data)
self.__write_lock.acquire()
try:
@@ -378,9 +391,9 @@ class Packetizer (object):
if self.__dump_packets:
self._log(
DEBUG,
- 'Write packet <{}>, length {}'.format(cmd_name, orig_len)
+ "Write packet <{}>, length {}".format(cmd_name, orig_len),
)
- self._log(DEBUG, util.format_binary(packet, 'OUT: '))
+ self._log(DEBUG, util.format_binary(packet, "OUT: "))
if self.__block_engine_out is not None:
out = self.__block_engine_out.update(packet)
else:
@@ -388,27 +401,30 @@ class Packetizer (object):
# + mac
if self.__block_engine_out is not None:
payload = struct.pack(
- '>I', self.__sequence_number_out) + packet
+ ">I", self.__sequence_number_out
+ ) + packet
out += compute_hmac(
- self.__mac_key_out,
- payload,
- self.__mac_engine_out)[:self.__mac_size_out]
- self.__sequence_number_out = \
- (self.__sequence_number_out + 1) & xffffffff
+ self.__mac_key_out, payload, self.__mac_engine_out
+ )[
+ :self.__mac_size_out
+ ]
+ self.__sequence_number_out = (
+ self.__sequence_number_out + 1
+ ) & xffffffff
self.write_all(out)
self.__sent_bytes += len(out)
self.__sent_packets += 1
sent_too_much = (
- self.__sent_packets >= self.REKEY_PACKETS or
- self.__sent_bytes >= self.REKEY_BYTES
+ self.__sent_packets >= self.REKEY_PACKETS
+ or self.__sent_bytes >= self.REKEY_BYTES
)
if sent_too_much and not self.__need_rekey:
# only ask once for rekeying
msg = "Rekeying (hit {} packets, {} bytes sent)"
- self._log(DEBUG, msg.format(
- self.__sent_packets, self.__sent_bytes,
- ))
+ self._log(
+ DEBUG, msg.format(self.__sent_packets, self.__sent_bytes)
+ )
self.__received_bytes_overflow = 0
self.__received_packets_overflow = 0
self._trigger_rekey()
@@ -427,41 +443,43 @@ class Packetizer (object):
if self.__block_engine_in is not None:
header = self.__block_engine_in.update(header)
if self.__dump_packets:
- self._log(DEBUG, util.format_binary(header, 'IN: '))
- packet_size = struct.unpack('>I', header[:4])[0]
+ self._log(DEBUG, util.format_binary(header, "IN: "))
+ packet_size = struct.unpack(">I", header[:4])[0]
# leftover contains decrypted bytes from the first block (after the
# length field)
leftover = header[4:]
if (packet_size - len(leftover)) % self.__block_size_in != 0:
- raise SSHException('Invalid packet blocking')
+ raise SSHException("Invalid packet blocking")
buf = self.read_all(packet_size + self.__mac_size_in - len(leftover))
packet = buf[:packet_size - len(leftover)]
post_packet = buf[packet_size - len(leftover):]
if self.__block_engine_in is not None:
packet = self.__block_engine_in.update(packet)
if self.__dump_packets:
- self._log(DEBUG, util.format_binary(packet, 'IN: '))
+ self._log(DEBUG, util.format_binary(packet, "IN: "))
packet = leftover + packet
if self.__mac_size_in > 0:
mac = post_packet[:self.__mac_size_in]
mac_payload = struct.pack(
- '>II', self.__sequence_number_in, packet_size) + packet
+ ">II", self.__sequence_number_in, packet_size
+ ) + packet
my_mac = compute_hmac(
- self.__mac_key_in,
- mac_payload,
- self.__mac_engine_in)[:self.__mac_size_in]
+ self.__mac_key_in, mac_payload, self.__mac_engine_in
+ )[
+ :self.__mac_size_in
+ ]
if not util.constant_time_bytes_eq(my_mac, mac):
- raise SSHException('Mismatched MAC')
+ raise SSHException("Mismatched MAC")
padding = byte_ord(packet[0])
payload = packet[1:packet_size - padding]
if self.__dump_packets:
self._log(
DEBUG,
- 'Got payload ({} bytes, {} padding)'.format(
+ "Got payload ({} bytes, {} padding)".format(
packet_size, padding
- )
+ ),
)
if self.__compress_engine_in is not None:
@@ -480,19 +498,28 @@ class Packetizer (object):
# dropping the connection
self.__received_bytes_overflow += raw_packet_size
self.__received_packets_overflow += 1
- if (self.__received_packets_overflow >=
- self.REKEY_PACKETS_OVERFLOW_MAX) or \
- (self.__received_bytes_overflow >=
- self.REKEY_BYTES_OVERFLOW_MAX):
+ if (
+ (
+ self.__received_packets_overflow
+ >= self.REKEY_PACKETS_OVERFLOW_MAX
+ )
+ or (
+ self.__received_bytes_overflow
+ >= self.REKEY_BYTES_OVERFLOW_MAX
+ )
+ ):
raise SSHException(
- 'Remote transport is ignoring rekey requests')
- elif (self.__received_packets >= self.REKEY_PACKETS) or \
- (self.__received_bytes >= self.REKEY_BYTES):
+ "Remote transport is ignoring rekey requests"
+ )
+ elif (self.__received_packets >= self.REKEY_PACKETS) or (
+ self.__received_bytes >= self.REKEY_BYTES
+ ):
# only ask once for rekeying
err = "Rekeying (hit {} packets, {} bytes received)"
- self._log(DEBUG, err.format(
- self.__received_packets, self.__received_bytes,
- ))
+ self._log(
+ DEBUG,
+ err.format(self.__received_packets, self.__received_bytes),
+ )
self.__received_bytes_overflow = 0
self.__received_packets_overflow = 0
self._trigger_rekey()
@@ -501,11 +528,11 @@ class Packetizer (object):
if cmd in MSG_NAMES:
cmd_name = MSG_NAMES[cmd]
else:
- cmd_name = '${:x}'.format(cmd)
+ cmd_name = "${:x}".format(cmd)
if self.__dump_packets:
self._log(
DEBUG,
- 'Read packet <{}>, length {}'.format(cmd_name, len(payload))
+ "Read packet <{}>, length {}".format(cmd_name, len(payload)),
)
return cmd, msg
@@ -522,9 +549,9 @@ class Packetizer (object):
def _check_keepalive(self):
if (
- not self.__keepalive_interval or
- not self.__block_engine_out or
- self.__need_rekey
+ not self.__keepalive_interval
+ or not self.__block_engine_out
+ or self.__need_rekey
):
# wait till we're encrypting, and not in the middle of rekeying
return
@@ -559,7 +586,7 @@ class Packetizer (object):
# pad up at least 4 bytes, to nearest block-size (usually 8)
bsize = self.__block_size_out
padding = 3 + bsize - ((len(payload) + 8) % bsize)
- packet = struct.pack('>IB', len(payload) + padding + 1, padding)
+ packet = struct.pack(">IB", len(payload) + padding + 1, padding)
packet += payload
if self.__sdctr_out or self.__block_engine_out is None:
# cute trick i caught openssh doing: if we're not encrypting or
diff --git a/paramiko/pipe.py b/paramiko/pipe.py
index 6ca3770..e88a5e4 100644
--- a/paramiko/pipe.py
+++ b/paramiko/pipe.py
@@ -31,14 +31,15 @@ import socket
def make_pipe():
- if sys.platform[:3] != 'win':
+ if sys.platform[:3] != "win":
p = PosixPipe()
else:
p = WindowsPipe()
return p
-class PosixPipe (object):
+class PosixPipe(object):
+
def __init__(self):
self._rfd, self._wfd = os.pipe()
self._set = False
@@ -64,26 +65,27 @@ class PosixPipe (object):
if self._set or self._closed:
return
self._set = True
- os.write(self._wfd, b'*')
+ os.write(self._wfd, b"*")
def set_forever(self):
self._forever = True
self.set()
-class WindowsPipe (object):
+class WindowsPipe(object):
"""
On Windows, only an OS-level "WinSock" may be used in select(), but reads
and writes must be to the actual socket object.
"""
+
def __init__(self):
serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- serv.bind(('127.0.0.1', 0))
+ serv.bind(("127.0.0.1", 0))
serv.listen(1)
# need to save sockets in _rsock/_wsock so they don't get closed
self._rsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- self._rsock.connect(('127.0.0.1', serv.getsockname()[1]))
+ self._rsock.connect(("127.0.0.1", serv.getsockname()[1]))
self._wsock, addr = serv.accept()
serv.close()
@@ -110,14 +112,15 @@ class WindowsPipe (object):
if self._set or self._closed:
return
self._set = True
- self._wsock.send(b'*')
+ self._wsock.send(b"*")
def set_forever(self):
self._forever = True
self.set()
-class OrPipe (object):
+class OrPipe(object):
+
def __init__(self, pipe):
self._set = False
self._partner = None
diff --git a/paramiko/pkey.py b/paramiko/pkey.py
index 808215f..a01d4fd 100644
--- a/paramiko/pkey.py
+++ b/paramiko/pkey.py
@@ -43,23 +43,23 @@ class PKey(object):
# known encryption types for private key files:
_CIPHER_TABLE = {
- 'AES-128-CBC': {
- 'cipher': algorithms.AES,
- 'keysize': 16,
- 'blocksize': 16,
- 'mode': modes.CBC
+ "AES-128-CBC": {
+ "cipher": algorithms.AES,
+ "keysize": 16,
+ "blocksize": 16,
+ "mode": modes.CBC,
},
- 'AES-256-CBC': {
- 'cipher': algorithms.AES,
- 'keysize': 32,
- 'blocksize': 16,
- 'mode': modes.CBC
+ "AES-256-CBC": {
+ "cipher": algorithms.AES,
+ "keysize": 32,
+ "blocksize": 16,
+ "mode": modes.CBC,
},
- 'DES-EDE3-CBC': {
- 'cipher': algorithms.TripleDES,
- 'keysize': 24,
- 'blocksize': 8,
- 'mode': modes.CBC
+ "DES-EDE3-CBC": {
+ "cipher": algorithms.TripleDES,
+ "keysize": 24,
+ "blocksize": 8,
+ "mode": modes.CBC,
},
}
@@ -107,7 +107,7 @@ class PKey(object):
hs = hash(self)
ho = hash(other)
if hs != ho:
- return cmp(hs, ho) # noqa
+ return cmp(hs, ho) # noqa
return cmp(self.asbytes(), other.asbytes()) # noqa
def __eq__(self, other):
@@ -121,7 +121,7 @@ class PKey(object):
name of this private key type, in SSH terminology, as a `str` (for
example, ``"ssh-rsa"``).
"""
- return ''
+ return ""
def get_bits(self):
"""
@@ -158,7 +158,7 @@ class PKey(object):
:return: a base64 `string <str>` containing the public part of the key.
"""
- return u(encodebytes(self.asbytes())).replace('\n', '')
+ return u(encodebytes(self.asbytes())).replace("\n", "")
def sign_ssh_data(self, data):
"""
@@ -239,7 +239,7 @@ class PKey(object):
:raises: ``IOError`` -- if there was an error writing the file
:raises: `.SSHException` -- if the key is invalid
"""
- raise Exception('Not implemented in PKey')
+ raise Exception("Not implemented in PKey")
def write_private_key(self, file_obj, password=None):
"""
@@ -252,7 +252,7 @@ class PKey(object):
:raises: ``IOError`` -- if there was an error writing to the file
:raises: `.SSHException` -- if the key is invalid
"""
- raise Exception('Not implemented in PKey')
+ raise Exception("Not implemented in PKey")
def _read_private_key_file(self, tag, filename, password=None):
"""
@@ -275,60 +275,61 @@ class PKey(object):
encrypted, and ``password`` is ``None``.
:raises: `.SSHException` -- if the key file is invalid.
"""
- with open(filename, 'r') as f:
+ with open(filename, "r") as f:
data = self._read_private_key(tag, f, password)
return data
def _read_private_key(self, tag, f, password=None):
lines = f.readlines()
start = 0
- beginning_of_key = '-----BEGIN ' + tag + ' PRIVATE KEY-----'
+ beginning_of_key = "-----BEGIN " + tag + " PRIVATE KEY-----"
while start < len(lines) and lines[start].strip() != beginning_of_key:
start += 1
if start >= len(lines):
- raise SSHException('not a valid ' + tag + ' private key file')
+ raise SSHException("not a valid " + tag + " private key file")
# parse any headers first
headers = {}
start += 1
while start < len(lines):
- l = lines[start].split(': ')
+ l = lines[start].split(": ")
if len(l) == 1:
break
headers[l[0].lower()] = l[1].strip()
start += 1
# find end
end = start
- ending_of_key = '-----END ' + tag + ' PRIVATE KEY-----'
+ ending_of_key = "-----END " + tag + " PRIVATE KEY-----"
while end < len(lines) and lines[end].strip() != ending_of_key:
end += 1
# if we trudged to the end of the file, just try to cope.
try:
- data = decodebytes(b(''.join(lines[start:end])))
+ data = decodebytes(b("".join(lines[start:end])))
except base64.binascii.Error as e:
- raise SSHException('base64 decoding error: ' + str(e))
- if 'proc-type' not in headers:
+ raise SSHException("base64 decoding error: " + str(e))
+ if "proc-type" not in headers:
# unencryped: done
return data
# encrypted keyfile: will need a password
- proc_type = headers['proc-type']
- if proc_type != '4,ENCRYPTED':
+ proc_type = headers["proc-type"]
+ if proc_type != "4,ENCRYPTED":
raise SSHException(
'Unknown private key structure "{}"'.format(proc_type)
)
try:
- encryption_type, saltstr = headers['dek-info'].split(',')
+ encryption_type, saltstr = headers["dek-info"].split(",")
except:
raise SSHException("Can't parse DEK-info in private key file")
if encryption_type not in self._CIPHER_TABLE:
raise SSHException(
- 'Unknown private key cipher "{}"'.format(encryption_type))
+ 'Unknown private key cipher "{}"'.format(encryption_type)
+ )
# if no password was passed in,
# raise an exception pointing out that we need one
if password is None:
- raise PasswordRequiredException('Private key file is encrypted')
- cipher = self._CIPHER_TABLE[encryption_type]['cipher']
- keysize = self._CIPHER_TABLE[encryption_type]['keysize']
- mode = self._CIPHER_TABLE[encryption_type]['mode']
+ raise PasswordRequiredException("Private key file is encrypted")
+ cipher = self._CIPHER_TABLE[encryption_type]["cipher"]
+ keysize = self._CIPHER_TABLE[encryption_type]["keysize"]
+ mode = self._CIPHER_TABLE[encryption_type]["mode"]
salt = unhexlify(b(saltstr))
key = util.generate_key_bytes(md5, salt, password, keysize)
decryptor = Cipher(
@@ -351,7 +352,7 @@ class PKey(object):
:raises: ``IOError`` -- if there was an error writing the file.
"""
- with open(filename, 'w') as f:
+ with open(filename, "w") as f:
os.chmod(filename, o600)
self._write_private_key(f, key, format, password=password)
@@ -361,11 +362,11 @@ class PKey(object):
else:
encryption = serialization.BestAvailableEncryption(b(password))
- f.write(key.private_bytes(
- serialization.Encoding.PEM,
- format,
- encryption
- ).decode())
+ f.write(
+ key.private_bytes(
+ serialization.Encoding.PEM, format, encryption
+ ).decode()
+ )
def _check_type_and_load_cert(self, msg, key_type, cert_type):
"""
@@ -388,7 +389,7 @@ class PKey(object):
cert_types = [cert_types]
# Can't do much with no message, that should've been handled elsewhere
if msg is None:
- raise SSHException('Key object may not be empty')
+ raise SSHException("Key object may not be empty")
# First field is always key type, in either kind of object. (make sure
# we rewind before grabbing it - sometimes caller had to do their own
# introspection first!)
@@ -411,7 +412,7 @@ class PKey(object):
# (requires going back into per-type subclasses.)
msg.get_string()
else:
- err = 'Invalid key (class: {}, data type: {}'
+ err = "Invalid key (class: {}, data type: {}"
raise SSHException(err.format(self.__class__.__name__, type_))
def load_certificate(self, value):
@@ -434,11 +435,11 @@ class PKey(object):
successfully.
"""
if isinstance(value, Message):
- constructor = 'from_message'
+ constructor = "from_message"
elif os.path.isfile(value):
- constructor = 'from_file'
+ constructor = "from_file"
else:
- constructor = 'from_string'
+ constructor = "from_string"
blob = getattr(PublicBlob, constructor)(value)
if not blob.key_type.startswith(self.get_name()):
err = "PublicBlob type {} incompatible with key type {}"
@@ -464,6 +465,7 @@ class PublicBlob(object):
`from_message` for useful instantiation, the main constructor is
basically "I should be using ``attrs`` for this."
"""
+
def __init__(self, type_, blob, comment=None):
"""
Create a new public blob of given type and contents.
@@ -505,7 +507,7 @@ class PublicBlob(object):
m = Message(key_blob)
blob_type = m.get_text()
if blob_type != key_type:
- msg = "Invalid PublicBlob contents: key type={!r}, but blob type={!r}" # noqa
+ msg = "Invalid PublicBlob contents: key type={!r}, but blob type={!r}" # noqa
raise ValueError(msg.format(key_type, blob_type))
# All good? All good.
return cls(type_=key_type, blob=key_blob, comment=comment)
@@ -522,7 +524,7 @@ class PublicBlob(object):
return cls(type_=type_, blob=message.asbytes())
def __str__(self):
- ret = '{} public key/certificate'.format(self.key_type)
+ ret = "{} public key/certificate".format(self.key_type)
if self.comment:
ret += "- {}".format(self.comment)
return ret
diff --git a/paramiko/primes.py b/paramiko/primes.py
index ca8f9be..dba1c7e 100644
--- a/paramiko/primes.py
+++ b/paramiko/primes.py
@@ -49,7 +49,7 @@ def _roll_random(n):
return num
-class ModulusPack (object):
+class ModulusPack(object):
"""
convenience object for holding the contents of the /etc/ssh/moduli file,
on systems that have such a file.
@@ -61,8 +61,7 @@ class ModulusPack (object):
self.discarded = []
def _parse_modulus(self, line):
- timestamp, mod_type, tests, tries, size, generator, modulus = \
- line.split()
+ timestamp, mod_type, tests, tries, size, generator, modulus = line.split()
mod_type = int(mod_type)
tests = int(tests)
tries = int(tries)
@@ -75,12 +74,13 @@ class ModulusPack (object):
# test 4 (more than just a small-prime sieve)
# tries < 100 if test & 4 (at least 100 tries of miller-rabin)
if (
- mod_type < 2 or
- tests < 4 or
- (tests & 4 and tests < 8 and tries < 100)
+ mod_type < 2
+ or tests < 4
+ or (tests & 4 and tests < 8 and tries < 100)
):
self.discarded.append(
- (modulus, 'does not meet basic requirements'))
+ (modulus, "does not meet basic requirements")
+ )
return
if generator == 0:
generator = 2
@@ -91,7 +91,8 @@ class ModulusPack (object):
bl = util.bit_length(modulus)
if (bl != size) and (bl != size + 1):
self.discarded.append(
- (modulus, 'incorrectly reported bit length {}'.format(size)))
+ (modulus, "incorrectly reported bit length {}".format(size))
+ )
return
if bl not in self.pack:
self.pack[bl] = []
@@ -102,10 +103,10 @@ class ModulusPack (object):
:raises IOError: passed from any file operations that fail.
"""
self.pack = {}
- with open(filename, 'r') as f:
+ with open(filename, "r") as f:
for line in f:
line = line.strip()
- if (len(line) == 0) or (line[0] == '#'):
+ if (len(line) == 0) or (line[0] == "#"):
continue
try:
self._parse_modulus(line)
@@ -115,7 +116,7 @@ class ModulusPack (object):
def get_modulus(self, min, prefer, max):
bitsizes = sorted(self.pack.keys())
if len(bitsizes) == 0:
- raise SSHException('no moduli available')
+ raise SSHException("no moduli available")
good = -1
# find nearest bitsize >= preferred
for b in bitsizes:
diff --git a/paramiko/proxy.py b/paramiko/proxy.py
index c4ec627..d0ef878 100644
--- a/paramiko/proxy.py
+++ b/paramiko/proxy.py
@@ -39,6 +39,7 @@ class ProxyCommand(ClosingContextManager):
Instances of this class may be used as context managers.
"""
+
def __init__(self, command_line):
"""
Create a new CommandProxy instance. The instance created by this
@@ -50,9 +51,11 @@ class ProxyCommand(ClosingContextManager):
# NOTE: subprocess import done lazily so platforms without it (e.g.
# GAE) can still import us during overall Paramiko load.
from subprocess import Popen, PIPE
+
self.cmd = shlsplit(command_line)
- self.process = Popen(self.cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE,
- bufsize=0)
+ self.process = Popen(
+ self.cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, bufsize=0
+ )
self.timeout = None
def send(self, content):
@@ -69,7 +72,7 @@ class ProxyCommand(ClosingContextManager):
# died and we can't proceed. The best option here is to
# raise an exception informing the user that the informed
# ProxyCommand is not working.
- raise ProxyCommandFailure(' '.join(self.cmd), e.strerror)
+ raise ProxyCommandFailure(" ".join(self.cmd), e.strerror)
return len(content)
def recv(self, size):
@@ -81,7 +84,7 @@ class ProxyCommand(ClosingContextManager):
:return: the string of bytes read, which may be shorter than requested
"""
try:
- buffer = b''
+ buffer = b""
start = time.time()
while len(buffer) < size:
select_timeout = None
@@ -91,11 +94,11 @@ class ProxyCommand(ClosingContextManager):
raise socket.timeout()
select_timeout = self.timeout - elapsed
- r, w, x = select(
- [self.process.stdout], [], [], select_timeout)
+ r, w, x = select([self.process.stdout], [], [], select_timeout)
if r and r[0] == self.process.stdout:
buffer += os.read(
- self.process.stdout.fileno(), size - len(buffer))
+ self.process.stdout.fileno(), size - len(buffer)
+ )
return buffer
except socket.timeout:
if buffer:
@@ -103,7 +106,7 @@ class ProxyCommand(ClosingContextManager):
return buffer
raise # socket.timeout is a subclass of IOError
except IOError as e:
- raise ProxyCommandFailure(' '.join(self.cmd), e.strerror)
+ raise ProxyCommandFailure(" ".join(self.cmd), e.strerror)
def close(self):
os.kill(self.process.pid, signal.SIGTERM)
diff --git a/paramiko/py3compat.py b/paramiko/py3compat.py
index 67c0f20..e1f33fe 100644
--- a/paramiko/py3compat.py
+++ b/paramiko/py3compat.py
@@ -2,10 +2,28 @@ import sys
import base64
__all__ = [
- 'BytesIO', 'MAXSIZE', 'PY2', 'StringIO', 'b', 'b2s', 'builtins',
- 'byte_chr', 'byte_mask', 'byte_ord', 'bytes', 'bytes_types', 'decodebytes',
- 'encodebytes', 'input', 'integer_types', 'is_callable', 'long', 'next',
- 'string_types', 'text_type', 'u',
+ "BytesIO",
+ "MAXSIZE",
+ "PY2",
+ "StringIO",
+ "b",
+ "b2s",
+ "builtins",
+ "byte_chr",
+ "byte_mask",
+ "byte_ord",
+ "bytes",
+ "bytes_types",
+ "decodebytes",
+ "encodebytes",
+ "input",
+ "integer_types",
+ "is_callable",
+ "long",
+ "next",
+ "string_types",
+ "text_type",
+ "u",
]
PY2 = sys.version_info[0] < 3
@@ -23,16 +41,13 @@ if PY2:
import __builtin__ as builtins
-
byte_ord = ord # NOQA
byte_chr = chr # NOQA
-
def byte_mask(c, mask):
return chr(ord(c) & mask)
-
- def b(s, encoding='utf8'): # NOQA
+ def b(s, encoding="utf8"): # NOQA
"""cast unicode or bytes to bytes"""
if isinstance(s, str):
return s
@@ -43,8 +58,7 @@ if PY2:
else:
raise TypeError("Expected unicode or bytes, got {!r}".format(s))
-
- def u(s, encoding='utf8'): # NOQA
+ def u(s, encoding="utf8"): # NOQA
"""cast bytes or unicode to unicode"""
if isinstance(s, str):
return s.decode(encoding)
@@ -55,53 +69,52 @@ if PY2:
else:
raise TypeError("Expected unicode or bytes, got {!r}".format(s))
-
def b2s(s):
return s
-
import cStringIO
+
StringIO = cStringIO.StringIO
BytesIO = StringIO
-
def is_callable(c): # NOQA
return callable(c)
-
def get_next(c): # NOQA
return c.next
-
def next(c):
return c.next()
# It's possible to have sizeof(long) != sizeof(Py_ssize_t).
class X(object):
+
def __len__(self):
return 1 << 31
-
try:
len(X())
except OverflowError:
# 32-bit
- MAXSIZE = int((1 << 31) - 1) # NOQA
+ MAXSIZE = int((1 << 31) - 1) # NOQA
else:
# 64-bit
- MAXSIZE = int((1 << 63) - 1) # NOQA
+ MAXSIZE = int((1 << 63) - 1) # NOQA
del X
else:
import collections
import struct
import builtins
+
string_types = str
text_type = str
bytes = bytes
bytes_types = bytes
integer_types = int
+
class long(int):
pass
+
input = input
decodebytes = base64.decodebytes
encodebytes = base64.encodebytes
@@ -114,13 +127,13 @@ else:
def byte_chr(c):
assert isinstance(c, int)
- return struct.pack('B', c)
+ return struct.pack("B", c)
def byte_mask(c, mask):
assert isinstance(c, int)
- return struct.pack('B', c & mask)
+ return struct.pack("B", c & mask)
- def b(s, encoding='utf8'):
+ def b(s, encoding="utf8"):
"""cast unicode or bytes to bytes"""
if isinstance(s, bytes):
return s
@@ -129,7 +142,7 @@ else:
else:
raise TypeError("Expected unicode or bytes, got {!r}".format(s))
- def u(s, encoding='utf8'):
+ def u(s, encoding="utf8"):
"""cast bytes or unicode to unicode"""
if isinstance(s, bytes):
return s.decode(encoding)
@@ -142,8 +155,9 @@ else:
return s.decode() if isinstance(s, bytes) else s
import io
- StringIO = io.StringIO # NOQA
- BytesIO = io.BytesIO # NOQA
+
+ StringIO = io.StringIO # NOQA
+ BytesIO = io.BytesIO # NOQA
def is_callable(c):
return isinstance(c, collections.Callable)
@@ -153,4 +167,4 @@ else:
next = next
- MAXSIZE = sys.maxsize # NOQA
+ MAXSIZE = sys.maxsize # NOQA
diff --git a/paramiko/rsakey.py b/paramiko/rsakey.py
index 8dfcfb0..b0fce1f 100644
--- a/paramiko/rsakey.py
+++ b/paramiko/rsakey.py
@@ -37,8 +37,15 @@ class RSAKey(PKey):
data.
"""
- def __init__(self, msg=None, data=None, filename=None, password=None,
- key=None, file_obj=None):
+ def __init__(
+ self,
+ msg=None,
+ data=None,
+ filename=None,
+ password=None,
+ key=None,
+ file_obj=None,
+ ):
self.key = None
self.public_blob = None
if file_obj is not None:
@@ -54,12 +61,14 @@ class RSAKey(PKey):
else:
self._check_type_and_load_cert(
msg=msg,
- key_type='ssh-rsa',
- cert_type='ssh-rsa-cert-v01@openssh.com',
+ key_type="ssh-rsa",
+ cert_type="ssh-rsa-cert-v01@openssh.com",
)
self.key = rsa.RSAPublicNumbers(
e=msg.get_mpint(), n=msg.get_mpint()
- ).public_key(default_backend())
+ ).public_key(
+ default_backend()
+ )
@property
def size(self):
@@ -74,7 +83,7 @@ class RSAKey(PKey):
def asbytes(self):
m = Message()
- m.add_string('ssh-rsa')
+ m.add_string("ssh-rsa")
m.add_mpint(self.public_numbers.e)
m.add_mpint(self.public_numbers.n)
return m.asbytes()
@@ -89,14 +98,15 @@ class RSAKey(PKey):
# tries stuffing it into ASCII for whatever godforsaken reason
return self.asbytes()
else:
- return self.asbytes().decode('utf8', errors='ignore')
+ return self.asbytes().decode("utf8", errors="ignore")
def __hash__(self):
- return hash((self.get_name(), self.public_numbers.e,
- self.public_numbers.n))
+ return hash(
+ (self.get_name(), self.public_numbers.e, self.public_numbers.n)
+ )
def get_name(self):
- return 'ssh-rsa'
+ return "ssh-rsa"
def get_bits(self):
return self.size
@@ -106,18 +116,16 @@ class RSAKey(PKey):
def sign_ssh_data(self, data):
sig = self.key.sign(
- data,
- padding=padding.PKCS1v15(),
- algorithm=hashes.SHA1(),
+ data, padding=padding.PKCS1v15(), algorithm=hashes.SHA1()
)
m = Message()
- m.add_string('ssh-rsa')
+ m.add_string("ssh-rsa")
m.add_string(sig)
return m
def verify_ssh_sig(self, data, msg):
- if msg.get_text() != 'ssh-rsa':
+ if msg.get_text() != "ssh-rsa":
return False
key = self.key
if isinstance(key, rsa.RSAPrivateKey):
@@ -137,7 +145,7 @@ class RSAKey(PKey):
filename,
self.key,
serialization.PrivateFormat.TraditionalOpenSSL,
- password=password
+ password=password,
)
def write_private_key(self, file_obj, password=None):
@@ -145,7 +153,7 @@ class RSAKey(PKey):
file_obj,
self.key,
serialization.PrivateFormat.TraditionalOpenSSL,
- password=password
+ password=password,
)
@staticmethod
@@ -166,11 +174,11 @@ class RSAKey(PKey):
# ...internals...
def _from_private_key_file(self, filename, password):
- data = self._read_private_key_file('RSA', filename, password)
+ data = self._read_private_key_file("RSA", filename, password)
self._decode_key(data)
def _from_private_key(self, file_obj, password):
- data = self._read_private_key('RSA', file_obj, password)
+ data = self._read_private_key("RSA", file_obj, password)
self._decode_key(data)
def _decode_key(self, data):
diff --git a/paramiko/server.py b/paramiko/server.py
index a711781..2fe9cc1 100644
--- a/paramiko/server.py
+++ b/paramiko/server.py
@@ -23,13 +23,16 @@
import threading
from paramiko import util
from paramiko.common import (
- DEBUG, ERROR, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, AUTH_FAILED,
+ DEBUG,
+ ERROR,
+ OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED,
+ AUTH_FAILED,
AUTH_SUCCESSFUL,
)
from paramiko.py3compat import string_types
-class ServerInterface (object):
+class ServerInterface(object):
"""
This class defines an interface for controlling the behavior of Paramiko
in server mode.
@@ -99,7 +102,7 @@ class ServerInterface (object):
:param str username: the username requesting authentication.
:return: a comma-separated `str` of authentication types
"""
- return 'password'
+ return "password"
def check_auth_none(self, username):
"""
@@ -233,9 +236,9 @@ class ServerInterface (object):
"""
return AUTH_FAILED
- def check_auth_gssapi_with_mic(self, username,
- gss_authenticated=AUTH_FAILED,
- cc_file=None):
+ def check_auth_gssapi_with_mic(
+ self, username, gss_authenticated=AUTH_FAILED, cc_file=None
+ ):
"""
Authenticate the given user to the server if he is a valid krb5
principal.
@@ -263,9 +266,9 @@ class ServerInterface (object):
return AUTH_SUCCESSFUL
return AUTH_FAILED
- def check_auth_gssapi_keyex(self, username,
- gss_authenticated=AUTH_FAILED,
- cc_file=None):
+ def check_auth_gssapi_keyex(
+ self, username, gss_authenticated=AUTH_FAILED, cc_file=None
+ ):
"""
Authenticate the given user to the server if he is a valid krb5
principal and GSS-API Key Exchange was performed.
@@ -372,8 +375,8 @@ class ServerInterface (object):
# ...Channel requests...
def check_channel_pty_request(
- self, channel, term, width, height, pixelwidth, pixelheight,
- modes):
+ self, channel, term, width, height, pixelwidth, pixelheight, modes
+ ):
"""
Determine if a pseudo-terminal of the given dimensions (usually
requested for shell access) can be provided on the given channel.
@@ -460,7 +463,8 @@ class ServerInterface (object):
return True
def check_channel_window_change_request(
- self, channel, width, height, pixelwidth, pixelheight):
+ self, channel, width, height, pixelwidth, pixelheight
+ ):
"""
Determine if the pseudo-terminal on the given channel can be resized.
This only makes sense if a pty was previously allocated on it.
@@ -479,8 +483,13 @@ class ServerInterface (object):
return False
def check_channel_x11_request(
- self, channel, single_connection, auth_protocol, auth_cookie,
- screen_number):
+ self,
+ channel,
+ single_connection,
+ auth_protocol,
+ auth_cookie,
+ screen_number,
+ ):
"""
Determine if the client will be provided with an X11 session. If this
method returns ``True``, X11 applications should be routed through new
@@ -584,12 +593,13 @@ class ServerInterface (object):
"""
return (None, None)
-class InteractiveQuery (object):
+
+class InteractiveQuery(object):
"""
A query (set of prompts) for a user during interactive authentication.
"""
- def __init__(self, name='', instructions='', *prompts):
+ def __init__(self, name="", instructions="", *prompts):
"""
Create a new interactive query to send to the client. The name and
instructions are optional, but are generally displayed to the end
@@ -623,7 +633,7 @@ class InteractiveQuery (object):
self.prompts.append((prompt, echo))
-class SubsystemHandler (threading.Thread):
+class SubsystemHandler(threading.Thread):
"""
Handler for a subsytem in server mode. If you create a subclass of this
class and pass it to `.Transport.set_subsystem_handler`, an object of this
@@ -637,6 +647,7 @@ class SubsystemHandler (threading.Thread):
``MP3Handler`` will be created, and `start_subsystem` will be called on
it from a new thread.
"""
+
def __init__(self, channel, name, server):
"""
Create a new handler for a channel. This is used by `.ServerInterface`
@@ -667,7 +678,7 @@ class SubsystemHandler (threading.Thread):
def _run(self):
try:
self.__transport._log(
- DEBUG, 'Starting handler for subsystem {}'.format(self.__name)
+ DEBUG, "Starting handler for subsystem {}".format(self.__name)
)
self.start_subsystem(self.__name, self.__transport, self.__channel)
except Exception as e:
@@ -675,7 +686,7 @@ class SubsystemHandler (threading.Thread):
ERROR,
'Exception in subsystem handler for "{}": {}'.format(
self.__name, e
- )
+ ),
)
self.__transport._log(ERROR, util.tb_strings())
try:
diff --git a/paramiko/sftp.py b/paramiko/sftp.py
index e6786d1..57a1bea 100644
--- a/paramiko/sftp.py
+++ b/paramiko/sftp.py
@@ -26,27 +26,28 @@ from paramiko.message import Message
from paramiko.py3compat import byte_chr, byte_ord
-CMD_INIT, CMD_VERSION, CMD_OPEN, CMD_CLOSE, CMD_READ, CMD_WRITE, CMD_LSTAT, \
- CMD_FSTAT, CMD_SETSTAT, CMD_FSETSTAT, CMD_OPENDIR, CMD_READDIR, \
- CMD_REMOVE, CMD_MKDIR, CMD_RMDIR, CMD_REALPATH, CMD_STAT, CMD_RENAME, \
- CMD_READLINK, CMD_SYMLINK = range(1, 21)
+CMD_INIT, CMD_VERSION, CMD_OPEN, CMD_CLOSE, CMD_READ, CMD_WRITE, CMD_LSTAT, CMD_FSTAT, CMD_SETSTAT, CMD_FSETSTAT, CMD_OPENDIR, CMD_READDIR, CMD_REMOVE, CMD_MKDIR, CMD_RMDIR, CMD_REALPATH, CMD_STAT, CMD_RENAME, CMD_READLINK, CMD_SYMLINK = range(
+ 1, 21
+)
CMD_STATUS, CMD_HANDLE, CMD_DATA, CMD_NAME, CMD_ATTRS = range(101, 106)
CMD_EXTENDED, CMD_EXTENDED_REPLY = range(200, 202)
SFTP_OK = 0
-SFTP_EOF, SFTP_NO_SUCH_FILE, SFTP_PERMISSION_DENIED, SFTP_FAILURE, \
- SFTP_BAD_MESSAGE, SFTP_NO_CONNECTION, SFTP_CONNECTION_LOST, \
- SFTP_OP_UNSUPPORTED = range(1, 9)
-
-SFTP_DESC = ['Success',
- 'End of file',
- 'No such file',
- 'Permission denied',
- 'Failure',
- 'Bad message',
- 'No connection',
- 'Connection lost',
- 'Operation unsupported']
+SFTP_EOF, SFTP_NO_SUCH_FILE, SFTP_PERMISSION_DENIED, SFTP_FAILURE, SFTP_BAD_MESSAGE, SFTP_NO_CONNECTION, SFTP_CONNECTION_LOST, SFTP_OP_UNSUPPORTED = range(
+ 1, 9
+)
+
+SFTP_DESC = [
+ "Success",
+ "End of file",
+ "No such file",
+ "Permission denied",
+ "Failure",
+ "Bad message",
+ "No connection",
+ "Connection lost",
+ "Operation unsupported",
+]
SFTP_FLAG_READ = 0x1
SFTP_FLAG_WRITE = 0x2
@@ -60,54 +61,55 @@ _VERSION = 3
# for debugging
CMD_NAMES = {
- CMD_INIT: 'init',
- CMD_VERSION: 'version',
- CMD_OPEN: 'open',
- CMD_CLOSE: 'close',
- CMD_READ: 'read',
- CMD_WRITE: 'write',
- CMD_LSTAT: 'lstat',
- CMD_FSTAT: 'fstat',
- CMD_SETSTAT: 'setstat',
- CMD_FSETSTAT: 'fsetstat',
- CMD_OPENDIR: 'opendir',
- CMD_READDIR: 'readdir',
- CMD_REMOVE: 'remove',
- CMD_MKDIR: 'mkdir',
- CMD_RMDIR: 'rmdir',
- CMD_REALPATH: 'realpath',
- CMD_STAT: 'stat',
- CMD_RENAME: 'rename',
- CMD_READLINK: 'readlink',
- CMD_SYMLINK: 'symlink',
- CMD_STATUS: 'status',
- CMD_HANDLE: 'handle',
- CMD_DATA: 'data',
- CMD_NAME: 'name',
- CMD_ATTRS: 'attrs',
- CMD_EXTENDED: 'extended',
- CMD_EXTENDED_REPLY: 'extended_reply'
+ CMD_INIT: "init",
+ CMD_VERSION: "version",
+ CMD_OPEN: "open",
+ CMD_CLOSE: "close",
+ CMD_READ: "read",
+ CMD_WRITE: "write",
+ CMD_LSTAT: "lstat",
+ CMD_FSTAT: "fstat",
+ CMD_SETSTAT: "setstat",
+ CMD_FSETSTAT: "fsetstat",
+ CMD_OPENDIR: "opendir",
+ CMD_READDIR: "readdir",
+ CMD_REMOVE: "remove",
+ CMD_MKDIR: "mkdir",
+ CMD_RMDIR: "rmdir",
+ CMD_REALPATH: "realpath",
+ CMD_STAT: "stat",
+ CMD_RENAME: "rename",
+ CMD_READLINK: "readlink",
+ CMD_SYMLINK: "symlink",
+ CMD_STATUS: "status",
+ CMD_HANDLE: "handle",
+ CMD_DATA: "data",
+ CMD_NAME: "name",
+ CMD_ATTRS: "attrs",
+ CMD_EXTENDED: "extended",
+ CMD_EXTENDED_REPLY: "extended_reply",
}
-class SFTPError (Exception):
+class SFTPError(Exception):
pass
-class BaseSFTP (object):
+class BaseSFTP(object):
+
def __init__(self):
- self.logger = util.get_logger('paramiko.sftp')
+ self.logger = util.get_logger("paramiko.sftp")
self.sock = None
self.ultra_debug = False
# ...internals...
def _send_version(self):
- self._send_packet(CMD_INIT, struct.pack('>I', _VERSION))
+ self._send_packet(CMD_INIT, struct.pack(">I", _VERSION))
t, data = self._read_packet()
if t != CMD_VERSION:
- raise SFTPError('Incompatible sftp protocol')
- version = struct.unpack('>I', data[:4])[0]
+ raise SFTPError("Incompatible sftp protocol")
+ version = struct.unpack(">I", data[:4])[0]
# if version != _VERSION:
# raise SFTPError('Incompatible sftp protocol')
return version
@@ -117,10 +119,10 @@ class BaseSFTP (object):
# client finishes sending INIT.
t, data = self._read_packet()
if t != CMD_INIT:
- raise SFTPError('Incompatible sftp protocol')
- version = struct.unpack('>I', data[:4])[0]
+ raise SFTPError("Incompatible sftp protocol")
+ version = struct.unpack(">I", data[:4])[0]
# advertise that we support "check-file"
- extension_pairs = ['check-file', 'md5,sha1']
+ extension_pairs = ["check-file", "md5,sha1"]
msg = Message()
msg.add_int(_VERSION)
msg.add(*extension_pairs)
@@ -165,9 +167,9 @@ class BaseSFTP (object):
def _send_packet(self, t, packet):
packet = asbytes(packet)
- out = struct.pack('>I', len(packet) + 1) + byte_chr(t) + packet
+ out = struct.pack(">I", len(packet) + 1) + byte_chr(t) + packet
if self.ultra_debug:
- self._log(DEBUG, util.format_binary(out, 'OUT: '))
+ self._log(DEBUG, util.format_binary(out, "OUT: "))
self._write_all(out)
def _read_packet(self):
@@ -175,11 +177,11 @@ class BaseSFTP (object):
# most sftp servers won't accept packets larger than about 32k, so
# anything with the high byte set (> 16MB) is just garbage.
if byte_ord(x[0]):
- raise SFTPError('Garbage packet received')
- size = struct.unpack('>I', x)[0]
+ raise SFTPError("Garbage packet received")
+ size = struct.unpack(">I", x)[0]
data = self._read_all(size)
if self.ultra_debug:
- self._log(DEBUG, util.format_binary(data, 'IN: '))
+ self._log(DEBUG, util.format_binary(data, "IN: "))
if size > 0:
t = byte_ord(data[0])
return t, data[1:]
diff --git a/paramiko/sftp_attr.py b/paramiko/sftp_attr.py
index ea12b2f..8e48373 100644
--- a/paramiko/sftp_attr.py
+++ b/paramiko/sftp_attr.py
@@ -22,7 +22,7 @@ from paramiko.common import x80000000, o700, o70, xffffffff
from paramiko.py3compat import long, b
-class SFTPAttributes (object):
+class SFTPAttributes(object):
"""
Representation of the attributes of a file (or proxied file) for SFTP in
client or server mode. It attemps to mirror the object returned by
@@ -82,7 +82,7 @@ class SFTPAttributes (object):
return attr
def __repr__(self):
- return '<SFTPAttributes: {}>'.format(self._debug_str())
+ return "<SFTPAttributes: {}>".format(self._debug_str())
# ...internals...
@classmethod
@@ -144,29 +144,29 @@ class SFTPAttributes (object):
return
def _debug_str(self):
- out = '[ '
+ out = "[ "
if self.st_size is not None:
- out += 'size={} '.format(self.st_size)
+ out += "size={} ".format(self.st_size)
if (self.st_uid is not None) and (self.st_gid is not None):
- out += 'uid={} gid={} '.format(self.st_uid, self.st_gid)
+ out += "uid={} gid={} ".format(self.st_uid, self.st_gid)
if self.st_mode is not None:
- out += 'mode=' + oct(self.st_mode) + ' '
+ out += "mode=" + oct(self.st_mode) + " "
if (self.st_atime is not None) and (self.st_mtime is not None):
- out += 'atime={} mtime={} '.format(self.st_atime, self.st_mtime)
+ out += "atime={} mtime={} ".format(self.st_atime, self.st_mtime)
for k, v in self.attr.items():
out += '"{}"={!r} '.format(str(k), v)
- out += ']'
+ out += "]"
return out
@staticmethod
def _rwx(n, suid, sticky=False):
if suid:
suid = 2
- out = '-r'[n >> 2] + '-w'[(n >> 1) & 1]
+ out = "-r"[n >> 2] + "-w"[(n >> 1) & 1]
if sticky:
- out += '-xTt'[suid + (n & 1)]
+ out += "-xTt"[suid + (n & 1)]
else:
- out += '-xSs'[suid + (n & 1)]
+ out += "-xSs"[suid + (n & 1)]
return out
def __str__(self):
@@ -174,42 +174,47 @@ class SFTPAttributes (object):
if self.st_mode is not None:
kind = stat.S_IFMT(self.st_mode)
if kind == stat.S_IFIFO:
- ks = 'p'
+ ks = "p"
elif kind == stat.S_IFCHR:
- ks = 'c'
+ ks = "c"
elif kind == stat.S_IFDIR:
- ks = 'd'
+ ks = "d"
elif kind == stat.S_IFBLK:
- ks = 'b'
+ ks = "b"
elif kind == stat.S_IFREG:
- ks = '-'
+ ks = "-"
elif kind == stat.S_IFLNK:
- ks = 'l'
+ ks = "l"
elif kind == stat.S_IFSOCK:
- ks = 's'
+ ks = "s"
else:
- ks = '?'
+ ks = "?"
ks += self._rwx(
- (self.st_mode & o700) >> 6, self.st_mode & stat.S_ISUID)
+ (self.st_mode & o700) >> 6, self.st_mode & stat.S_ISUID
+ )
ks += self._rwx(
- (self.st_mode & o70) >> 3, self.st_mode & stat.S_ISGID)
+ (self.st_mode & o70) >> 3, self.st_mode & stat.S_ISGID
+ )
ks += self._rwx(
- self.st_mode & 7, self.st_mode & stat.S_ISVTX, True)
+ self.st_mode & 7, self.st_mode & stat.S_ISVTX, True
+ )
else:
- ks = '?---------'
+ ks = "?---------"
# compute display date
if (self.st_mtime is None) or (self.st_mtime == xffffffff):
# shouldn't really happen
- datestr = '(unknown date)'
+ datestr = "(unknown date)"
else:
if abs(time.time() - self.st_mtime) > 15552000:
# (15552000 = 6 months)
datestr = time.strftime(
- '%d %b %Y', time.localtime(self.st_mtime))
+ "%d %b %Y", time.localtime(self.st_mtime)
+ )
else:
datestr = time.strftime(
- '%d %b %H:%M', time.localtime(self.st_mtime))
- filename = getattr(self, 'filename', '?')
+ "%d %b %H:%M", time.localtime(self.st_mtime)
+ )
+ filename = getattr(self, "filename", "?")
# not all servers support uid/gid
uid = self.st_uid
@@ -225,8 +230,8 @@ class SFTPAttributes (object):
# TODO: not sure this actually worked as expected beforehand, leaving
# it untouched for the time being, re: .format() upgrade, until someone
# has time to doublecheck
- return '%s 1 %-8d %-8d %8d %-12s %s' % (
- ks, uid, gid, size, datestr, filename,
+ return "%s 1 %-8d %-8d %8d %-12s %s" % (
+ ks, uid, gid, size, datestr, filename
)
def asbytes(self):
diff --git a/paramiko/sftp_client.py b/paramiko/sftp_client.py
index 31dc234..de3f9f5 100644
--- a/paramiko/sftp_client.py
+++ b/paramiko/sftp_client.py
@@ -30,12 +30,37 @@ from paramiko.message import Message
from paramiko.common import INFO, DEBUG, o777
from paramiko.py3compat import b, u, long
from paramiko.sftp import (
- BaseSFTP, CMD_OPENDIR, CMD_HANDLE, SFTPError, CMD_READDIR, CMD_NAME,
- CMD_CLOSE, SFTP_FLAG_READ, SFTP_FLAG_WRITE, SFTP_FLAG_CREATE,
- SFTP_FLAG_TRUNC, SFTP_FLAG_APPEND, SFTP_FLAG_EXCL, CMD_OPEN, CMD_REMOVE,
- CMD_RENAME, CMD_MKDIR, CMD_RMDIR, CMD_STAT, CMD_ATTRS, CMD_LSTAT,
- CMD_SYMLINK, CMD_SETSTAT, CMD_READLINK, CMD_REALPATH, CMD_STATUS,
- CMD_EXTENDED, SFTP_OK, SFTP_EOF, SFTP_NO_SUCH_FILE, SFTP_PERMISSION_DENIED,
+ BaseSFTP,
+ CMD_OPENDIR,
+ CMD_HANDLE,
+ SFTPError,
+ CMD_READDIR,
+ CMD_NAME,
+ CMD_CLOSE,
+ SFTP_FLAG_READ,
+ SFTP_FLAG_WRITE,
+ SFTP_FLAG_CREATE,
+ SFTP_FLAG_TRUNC,
+ SFTP_FLAG_APPEND,
+ SFTP_FLAG_EXCL,
+ CMD_OPEN,
+ CMD_REMOVE,
+ CMD_RENAME,
+ CMD_MKDIR,
+ CMD_RMDIR,
+ CMD_STAT,
+ CMD_ATTRS,
+ CMD_LSTAT,
+ CMD_SYMLINK,
+ CMD_SETSTAT,
+ CMD_READLINK,
+ CMD_REALPATH,
+ CMD_STATUS,
+ CMD_EXTENDED,
+ SFTP_OK,
+ SFTP_EOF,
+ SFTP_NO_SUCH_FILE,
+ SFTP_PERMISSION_DENIED,
)
from paramiko.sftp_attr import SFTPAttributes
@@ -51,15 +76,15 @@ def _to_unicode(s):
probably doesn't know the filename's encoding.
"""
try:
- return s.encode('ascii')
+ return s.encode("ascii")
except (UnicodeError, AttributeError):
try:
- return s.decode('utf-8')
+ return s.decode("utf-8")
except UnicodeError:
return s
-b_slash = b'/'
+b_slash = b"/"
class SFTPClient(BaseSFTP, ClosingContextManager):
@@ -71,6 +96,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
Instances of this class may be used as context managers.
"""
+
def __init__(self, sock):
"""
Create an SFTP client from an existing `.Channel`. The channel
@@ -97,15 +123,18 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
# override default logger
transport = self.sock.get_transport()
self.logger = util.get_logger(
- transport.get_log_channel() + '.sftp')
+ transport.get_log_channel() + ".sftp"
+ )
self.ultra_debug = transport.get_hexdump()
try:
server_version = self._send_version()
except EOFError:
- raise SSHException('EOF during negotiation')
+ raise SSHException("EOF during negotiation")
self._log(
INFO,
- 'Opened sftp connection (server version {})'.format(server_version)
+ "Opened sftp connection (server version {})".format(
+ server_version
+ ),
)
@classmethod
@@ -132,11 +161,12 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
.. versionchanged:: 1.15
Added the ``window_size`` and ``max_packet_size`` arguments.
"""
- chan = t.open_session(window_size=window_size,
- max_packet_size=max_packet_size)
+ chan = t.open_session(
+ window_size=window_size, max_packet_size=max_packet_size
+ )
if chan is None:
return None
- chan.invoke_subsystem('sftp')
+ chan.invoke_subsystem("sftp")
return cls(chan)
def _log(self, level, msg, *args):
@@ -148,10 +178,12 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
# logging.Logger.log() explicitly requires it. Grump.
# escape '%' in msg (they could come from file or directory names)
# before logging
- msg = msg.replace('%', '%%')
+ msg = msg.replace("%", "%%")
super(SFTPClient, self)._log(
level,
- "[chan %s] " + msg, *([self.sock.get_name()] + list(args)))
+ "[chan %s] " + msg,
+ *([self.sock.get_name()] + list(args))
+ )
def close(self):
"""
@@ -159,7 +191,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
.. versionadded:: 1.4
"""
- self._log(INFO, 'sftp session closed.')
+ self._log(INFO, "sftp session closed.")
self.sock.close()
def get_channel(self):
@@ -171,7 +203,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
"""
return self.sock
- def listdir(self, path='.'):
+ def listdir(self, path="."):
"""
Return a list containing the names of the entries in the given
``path``.
@@ -185,7 +217,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
"""
return [f.filename for f in self.listdir_attr(path)]
- def listdir_attr(self, path='.'):
+ def listdir_attr(self, path="."):
"""
Return a list containing `.SFTPAttributes` objects corresponding to
files in the given ``path``. The list is in arbitrary order. It does
@@ -203,10 +235,10 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
.. versionadded:: 1.2
"""
path = self._adjust_cwd(path)
- self._log(DEBUG, 'listdir({!r})'.format(path))
+ self._log(DEBUG, "listdir({!r})".format(path))
t, msg = self._request(CMD_OPENDIR, path)
if t != CMD_HANDLE:
- raise SFTPError('Expected handle')
+ raise SFTPError("Expected handle")
handle = msg.get_binary()
filelist = []
while True:
@@ -216,18 +248,18 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
# done with handle
break
if t != CMD_NAME:
- raise SFTPError('Expected name response')
+ raise SFTPError("Expected name response")
count = msg.get_int()
for i in range(count):
filename = msg.get_text()
longname = msg.get_text()
attr = SFTPAttributes._from_msg(msg, filename, longname)
- if (filename != '.') and (filename != '..'):
+ if (filename != ".") and (filename != ".."):
filelist.append(attr)
self._request(CMD_CLOSE, handle)
return filelist
- def listdir_iter(self, path='.', read_aheads=50):
+ def listdir_iter(self, path=".", read_aheads=50):
"""
Generator version of `.listdir_attr`.
@@ -242,11 +274,11 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
.. versionadded:: 1.15
"""
path = self._adjust_cwd(path)
- self._log(DEBUG, 'listdir({!r})'.format(path))
+ self._log(DEBUG, "listdir({!r})".format(path))
t, msg = self._request(CMD_OPENDIR, path)
if t != CMD_HANDLE:
- raise SFTPError('Expected handle')
+ raise SFTPError("Expected handle")
handle = msg.get_string()
@@ -261,7 +293,6 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
num = self._async_request(type(None), CMD_READDIR, handle)
nums.append(num)
-
# For each of our sent requests
# Read and parse the corresponding packets
# If we're at the end of our queued requests, then fire off
@@ -280,8 +311,9 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
filename = msg.get_text()
longname = msg.get_text()
attr = SFTPAttributes._from_msg(
- msg, filename, longname)
- if (filename != '.') and (filename != '..'):
+ msg, filename, longname
+ )
+ if (filename != ".") and (filename != ".."):
yield attr
# If we've hit the end of our queued requests, reset nums.
@@ -291,8 +323,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
self._request(CMD_CLOSE, handle)
return
-
- def open(self, filename, mode='r', bufsize=-1):
+ def open(self, filename, mode="r", bufsize=-1):
"""
Open a file on the remote server. The arguments are the same as for
Python's built-in `python:file` (aka `python:open`). A file-like
@@ -325,26 +356,28 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
:raises: ``IOError`` -- if the file could not be opened.
"""
filename = self._adjust_cwd(filename)
- self._log(DEBUG, 'open({!r}, {!r})'.format(filename, mode))
+ self._log(DEBUG, "open({!r}, {!r})".format(filename, mode))
imode = 0
- if ('r' in mode) or ('+' in mode):
+ if ("r" in mode) or ("+" in mode):
imode |= SFTP_FLAG_READ
- if ('w' in mode) or ('+' in mode) or ('a' in mode):
+ if ("w" in mode) or ("+" in mode) or ("a" in mode):
imode |= SFTP_FLAG_WRITE
- if 'w' in mode:
+ if "w" in mode:
imode |= SFTP_FLAG_CREATE | SFTP_FLAG_TRUNC
- if 'a' in mode:
+ if "a" in mode:
imode |= SFTP_FLAG_CREATE | SFTP_FLAG_APPEND
- if 'x' in mode:
+ if "x" in mode:
imode |= SFTP_FLAG_CREATE | SFTP_FLAG_EXCL
attrblock = SFTPAttributes()
t, msg = self._request(CMD_OPEN, filename, imode, attrblock)
if t != CMD_HANDLE:
- raise SFTPError('Expected handle')
+ raise SFTPError("Expected handle")
handle = msg.get_binary()
self._log(
DEBUG,
- 'open({!r}, {!r}) -> {}'.format(filename, mode, u(hexlify(handle)))
+ "open({!r}, {!r}) -> {}".format(
+ filename, mode, u(hexlify(handle))
+ ),
)
return SFTPFile(self, handle, mode, bufsize)
@@ -361,7 +394,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
:raises: ``IOError`` -- if the path refers to a folder (directory)
"""
path = self._adjust_cwd(path)
- self._log(DEBUG, 'remove({!r})'.format(path))
+ self._log(DEBUG, "remove({!r})".format(path))
self._request(CMD_REMOVE, path)
unlink = remove
@@ -386,7 +419,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
"""
oldpath = self._adjust_cwd(oldpath)
newpath = self._adjust_cwd(newpath)
- self._log(DEBUG, 'rename({!r}, {!r})'.format(oldpath, newpath))
+ self._log(DEBUG, "rename({!r}, {!r})".format(oldpath, newpath))
self._request(CMD_RENAME, oldpath, newpath)
def posix_rename(self, oldpath, newpath):
@@ -406,7 +439,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
"""
oldpath = self._adjust_cwd(oldpath)
newpath = self._adjust_cwd(newpath)
- self._log(DEBUG, 'posix_rename({!r}, {!r})'.format(oldpath, newpath))
+ self._log(DEBUG, "posix_rename({!r}, {!r})".format(oldpath, newpath))
self._request(
CMD_EXTENDED, "posix-rename@openssh.com", oldpath, newpath
)
@@ -421,7 +454,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
:param int mode: permissions (posix-style) for the newly-created folder
"""
path = self._adjust_cwd(path)
- self._log(DEBUG, 'mkdir({!r}, {!r})'.format(path, mode))
+ self._log(DEBUG, "mkdir({!r}, {!r})".format(path, mode))
attr = SFTPAttributes()
attr.st_mode = mode
self._request(CMD_MKDIR, path, attr)
@@ -433,7 +466,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
:param str path: name of the folder to remove
"""
path = self._adjust_cwd(path)
- self._log(DEBUG, 'rmdir({!r})'.format(path))
+ self._log(DEBUG, "rmdir({!r})".format(path))
self._request(CMD_RMDIR, path)
def stat(self, path):
@@ -456,10 +489,10 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
file
"""
path = self._adjust_cwd(path)
- self._log(DEBUG, 'stat({!r})'.format(path))
+ self._log(DEBUG, "stat({!r})".format(path))
t, msg = self._request(CMD_STAT, path)
if t != CMD_ATTRS:
- raise SFTPError('Expected attributes')
+ raise SFTPError("Expected attributes")
return SFTPAttributes._from_msg(msg)
def lstat(self, path):
@@ -474,10 +507,10 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
file
"""
path = self._adjust_cwd(path)
- self._log(DEBUG, 'lstat({!r})'.format(path))
+ self._log(DEBUG, "lstat({!r})".format(path))
t, msg = self._request(CMD_LSTAT, path)
if t != CMD_ATTRS:
- raise SFTPError('Expected attributes')
+ raise SFTPError("Expected attributes")
return SFTPAttributes._from_msg(msg)
def symlink(self, source, dest):
@@ -488,7 +521,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
:param str dest: path of the newly created symlink
"""
dest = self._adjust_cwd(dest)
- self._log(DEBUG, 'symlink({!r}, {!r})'.format(source, dest))
+ self._log(DEBUG, "symlink({!r}, {!r})".format(source, dest))
source = b(source)
self._request(CMD_SYMLINK, source, dest)
@@ -502,7 +535,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
:param int mode: new permissions
"""
path = self._adjust_cwd(path)
- self._log(DEBUG, 'chmod({!r}, {!r})'.format(path, mode))
+ self._log(DEBUG, "chmod({!r}, {!r})".format(path, mode))
attr = SFTPAttributes()
attr.st_mode = mode
self._request(CMD_SETSTAT, path, attr)
@@ -519,7 +552,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
:param int gid: new group id
"""
path = self._adjust_cwd(path)
- self._log(DEBUG, 'chown({!r}, {!r}, {!r})'.format(path, uid, gid))
+ self._log(DEBUG, "chown({!r}, {!r}, {!r})".format(path, uid, gid))
attr = SFTPAttributes()
attr.st_uid, attr.st_gid = uid, gid
self._request(CMD_SETSTAT, path, attr)
@@ -541,7 +574,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
path = self._adjust_cwd(path)
if times is None:
times = (time.time(), time.time())
- self._log(DEBUG, 'utime({!r}, {!r})'.format(path, times))
+ self._log(DEBUG, "utime({!r}, {!r})".format(path, times))
attr = SFTPAttributes()
attr.st_atime, attr.st_mtime = times
self._request(CMD_SETSTAT, path, attr)
@@ -556,7 +589,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
:param int size: the new size of the file
"""
path = self._adjust_cwd(path)
- self._log(DEBUG, 'truncate({!r}, {!r})'.format(path, size))
+ self._log(DEBUG, "truncate({!r}, {!r})".format(path, size))
attr = SFTPAttributes()
attr.st_size = size
self._request(CMD_SETSTAT, path, attr)
@@ -571,15 +604,15 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
:return: target path, as a `str`
"""
path = self._adjust_cwd(path)
- self._log(DEBUG, 'readlink({!r})'.format(path))
+ self._log(DEBUG, "readlink({!r})".format(path))
t, msg = self._request(CMD_READLINK, path)
if t != CMD_NAME:
- raise SFTPError('Expected name response')
+ raise SFTPError("Expected name response")
count = msg.get_int()
if count == 0:
return None
if count != 1:
- raise SFTPError('Readlink returned {} results'.format(count))
+ raise SFTPError("Readlink returned {} results".format(count))
return _to_unicode(msg.get_string())
def normalize(self, path):
@@ -595,13 +628,13 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
:raises: ``IOError`` -- if the path can't be resolved on the server
"""
path = self._adjust_cwd(path)
- self._log(DEBUG, 'normalize({!r})'.format(path))
+ self._log(DEBUG, "normalize({!r})".format(path))
t, msg = self._request(CMD_REALPATH, path)
if t != CMD_NAME:
- raise SFTPError('Expected name response')
+ raise SFTPError("Expected name response")
count = msg.get_int()
if count != 1:
- raise SFTPError('Realpath returned {} results'.format(count))
+ raise SFTPError("Realpath returned {} results".format(count))
return msg.get_text()
def chdir(self, path=None):
@@ -625,9 +658,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
return
if not stat.S_ISDIR(self.stat(path).st_mode):
code = errno.ENOTDIR
- raise SFTPError(
- code, "{}: {}".format(os.strerror(code), path)
- )
+ raise SFTPError(code, "{}: {}".format(os.strerror(code), path))
self._cwd = b(self.normalize(path))
def getcwd(self):
@@ -680,7 +711,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
.. versionadded:: 1.10
"""
- with self.file(remotepath, 'wb') as fr:
+ with self.file(remotepath, "wb") as fr:
fr.set_pipelined(True)
size = self._transfer_with_callback(
reader=fl, writer=fr, file_size=file_size, callback=callback
@@ -689,7 +720,8 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
s = self.stat(remotepath)
if s.st_size != size:
raise IOError(
- 'size mismatch in put! {} != {}'.format(s.st_size, size))
+ "size mismatch in put! {} != {}".format(s.st_size, size)
+ )
else:
s = SFTPAttributes()
return s
@@ -723,7 +755,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
``confirm`` param added.
"""
file_size = os.stat(localpath).st_size
- with open(localpath, 'rb') as fl:
+ with open(localpath, "rb") as fl:
return self.putfo(fl, remotepath, file_size, callback, confirm)
def getfo(self, remotepath, fl, callback=None):
@@ -744,7 +776,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
.. versionadded:: 1.10
"""
file_size = self.stat(remotepath).st_size
- with self.open(remotepath, 'rb') as fr:
+ with self.open(remotepath, "rb") as fr:
fr.prefetch(file_size)
return self._transfer_with_callback(
reader=fr, writer=fl, file_size=file_size, callback=callback
@@ -766,12 +798,13 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
.. versionchanged:: 1.7.4
Added the ``callback`` param
"""
- with open(localpath, 'wb') as fl:
+ with open(localpath, "wb") as fl:
size = self.getfo(remotepath, fl, callback)
s = os.stat(localpath)
if s.st_size != size:
raise IOError(
- 'size mismatch in get! {} != {}'.format(s.st_size, size))
+ "size mismatch in get! {} != {}".format(s.st_size, size)
+ )
# ...internals...
@@ -809,7 +842,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
try:
t, data = self._read_packet()
except EOFError as e:
- raise SSHException('Server connection dropped: {}'.format(e))
+ raise SSHException("Server connection dropped: {}".format(e))
msg = Message(data)
num = msg.get_int()
self._lock.acquire()
@@ -817,7 +850,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
if num not in self._expecting:
# might be response for a file that was closed before
# responses came back
- self._log(DEBUG, 'Unexpected response #{}'.format(num))
+ self._log(DEBUG, "Unexpected response #{}".format(num))
if waitfor is None:
# just doing a single check
break
diff --git a/paramiko/sftp_file.py b/paramiko/sftp_file.py
index 52f2bde..049e804 100644
--- a/paramiko/sftp_file.py
+++ b/paramiko/sftp_file.py
@@ -32,13 +32,21 @@ from paramiko.common import DEBUG
from paramiko.file import BufferedFile
from paramiko.py3compat import u, long
from paramiko.sftp import (
- CMD_CLOSE, CMD_READ, CMD_DATA, SFTPError, CMD_WRITE, CMD_STATUS, CMD_FSTAT,
- CMD_ATTRS, CMD_FSETSTAT, CMD_EXTENDED,
+ CMD_CLOSE,
+ CMD_READ,
+ CMD_DATA,
+ SFTPError,
+ CMD_WRITE,
+ CMD_STATUS,
+ CMD_FSTAT,
+ CMD_ATTRS,
+ CMD_FSETSTAT,
+ CMD_EXTENDED,
)
from paramiko.sftp_attr import SFTPAttributes
-class SFTPFile (BufferedFile):
+class SFTPFile(BufferedFile):
"""
Proxy object for a file on the remote server, in client mode SFTP.
@@ -50,7 +58,7 @@ class SFTPFile (BufferedFile):
# this size.
MAX_REQUEST_SIZE = 32768
- def __init__(self, sftp, handle, mode='r', bufsize=-1):
+ def __init__(self, sftp, handle, mode="r", bufsize=-1):
BufferedFile.__init__(self)
self.sftp = sftp
self.handle = handle
@@ -83,7 +91,7 @@ class SFTPFile (BufferedFile):
# __del__.)
if self._closed:
return
- self.sftp._log(DEBUG, 'close({})'.format(u(hexlify(self.handle))))
+ self.sftp._log(DEBUG, "close({})".format(u(hexlify(self.handle))))
if self.pipelined:
self.sftp._finish_responses(self)
BufferedFile.close(self)
@@ -102,8 +110,9 @@ class SFTPFile (BufferedFile):
pass
def _data_in_prefetch_requests(self, offset, size):
- k = [x for x in list(self._prefetch_extents.values())
- if x[0] <= offset]
+ k = [
+ x for x in list(self._prefetch_extents.values()) if x[0] <= offset
+ ]
if len(k) == 0:
return False
k.sort(key=lambda x: x[0])
@@ -117,8 +126,8 @@ class SFTPFile (BufferedFile):
# well, we have part of the request. see if another chunk has
# the rest.
return self._data_in_prefetch_requests(
- buf_offset + buf_size,
- offset + size - buf_offset - buf_size)
+ buf_offset + buf_size, offset + size - buf_offset - buf_size
+ )
def _data_in_prefetch_buffers(self, offset):
"""
@@ -174,13 +183,10 @@ class SFTPFile (BufferedFile):
if data is not None:
return data
t, msg = self.sftp._request(
- CMD_READ,
- self.handle,
- long(self._realpos),
- int(size)
+ CMD_READ, self.handle, long(self._realpos), int(size)
)
if t != CMD_DATA:
- raise SFTPError('Expected data')
+ raise SFTPError("Expected data")
return msg.get_string()
def _write(self, data):
@@ -191,18 +197,18 @@ class SFTPFile (BufferedFile):
CMD_WRITE,
self.handle,
long(self._realpos),
- data[:chunk]
+ data[:chunk],
)
self._reqs.append(sftp_async_request)
if (
- not self.pipelined or
- (len(self._reqs) > 100 and self.sftp.sock.recv_ready())
+ not self.pipelined
+ or (len(self._reqs) > 100 and self.sftp.sock.recv_ready())
):
while len(self._reqs):
req = self._reqs.popleft()
t, msg = self.sftp._read_response(req)
if t != CMD_STATUS:
- raise SFTPError('Expected status')
+ raise SFTPError("Expected status")
# convert_status already called
return chunk
@@ -277,7 +283,7 @@ class SFTPFile (BufferedFile):
"""
t, msg = self.sftp._request(CMD_FSTAT, self.handle)
if t != CMD_ATTRS:
- raise SFTPError('Expected attributes')
+ raise SFTPError("Expected attributes")
return SFTPAttributes._from_msg(msg)
def chmod(self, mode):
@@ -288,8 +294,9 @@ class SFTPFile (BufferedFile):
:param int mode: new permissions
"""
- self.sftp._log(DEBUG, 'chmod({}, {!r})'.format(
- hexlify(self.handle), mode))
+ self.sftp._log(
+ DEBUG, "chmod({}, {!r})".format(hexlify(self.handle), mode)
+ )
attr = SFTPAttributes()
attr.st_mode = mode
self.sftp._request(CMD_FSETSTAT, self.handle, attr)
@@ -306,7 +313,8 @@ class SFTPFile (BufferedFile):
"""
self.sftp._log(
DEBUG,
- 'chown({}, {!r}, {!r})'.format(hexlify(self.handle), uid, gid))
+ "chown({}, {!r}, {!r})".format(hexlify(self.handle), uid, gid),
+ )
attr = SFTPAttributes()
attr.st_uid, attr.st_gid = uid, gid
self.sftp._request(CMD_FSETSTAT, self.handle, attr)
@@ -326,8 +334,9 @@ class SFTPFile (BufferedFile):
"""
if times is None:
times = (time.time(), time.time())
- self.sftp._log(DEBUG, 'utime({}, {!r})'.format(
- hexlify(self.handle), times))
+ self.sftp._log(
+ DEBUG, "utime({}, {!r})".format(hexlify(self.handle), times)
+ )
attr = SFTPAttributes()
attr.st_atime, attr.st_mtime = times
self.sftp._request(CMD_FSETSTAT, self.handle, attr)
@@ -341,8 +350,8 @@ class SFTPFile (BufferedFile):
:param size: the new size of the file
"""
self.sftp._log(
- DEBUG,
- 'truncate({}, {!r})'.format(hexlify(self.handle), size))
+ DEBUG, "truncate({}, {!r})".format(hexlify(self.handle), size)
+ )
attr = SFTPAttributes()
attr.st_size = size
self.sftp._request(CMD_FSETSTAT, self.handle, attr)
@@ -394,8 +403,14 @@ class SFTPFile (BufferedFile):
.. versionadded:: 1.4
"""
t, msg = self.sftp._request(
- CMD_EXTENDED, 'check-file', self.handle,
- hash_algorithm, long(offset), long(length), block_size)
+ CMD_EXTENDED,
+ "check-file",
+ self.handle,
+ hash_algorithm,
+ long(offset),
+ long(length),
+ block_size,
+ )
msg.get_text() # ext
msg.get_text() # alg
data = msg.get_remainder()
@@ -475,15 +490,16 @@ class SFTPFile (BufferedFile):
.. versionadded:: 1.5.4
"""
- self.sftp._log(DEBUG, 'readv({}, {!r})'.format(
- hexlify(self.handle), chunks))
+ self.sftp._log(
+ DEBUG, "readv({}, {!r})".format(hexlify(self.handle), chunks)
+ )
read_chunks = []
for offset, size in chunks:
# don't fetch data that's already in the prefetch buffer
if (
- self._data_in_prefetch_buffers(offset) or
- self._data_in_prefetch_requests(offset, size)
+ self._data_in_prefetch_buffers(offset)
+ or self._data_in_prefetch_requests(offset, size)
):
continue
@@ -521,11 +537,8 @@ class SFTPFile (BufferedFile):
# a lot of them, so it may block.
for offset, length in chunks:
num = self.sftp._async_request(
- self,
- CMD_READ,
- self.handle,
- long(offset),
- int(length))
+ self, CMD_READ, self.handle, long(offset), int(length)
+ )
with self._prefetch_lock:
self._prefetch_extents[num] = (offset, length)
@@ -538,7 +551,7 @@ class SFTPFile (BufferedFile):
self._saved_exception = e
return
if t != CMD_DATA:
- raise SFTPError('Expected data')
+ raise SFTPError("Expected data")
data = msg.get_string()
while True:
with self._prefetch_lock:
diff --git a/paramiko/sftp_handle.py b/paramiko/sftp_handle.py
index ca47390..a7e22f0 100644
--- a/paramiko/sftp_handle.py
+++ b/paramiko/sftp_handle.py
@@ -25,7 +25,7 @@ from paramiko.sftp import SFTP_OP_UNSUPPORTED, SFTP_OK
from paramiko.util import ClosingContextManager
-class SFTPHandle (ClosingContextManager):
+class SFTPHandle(ClosingContextManager):
"""
Abstract object representing a handle to an open file (or folder) in an
SFTP server implementation. Each handle has a string representation used
@@ -36,6 +36,7 @@ class SFTPHandle (ClosingContextManager):
Instances of this class may be used as context managers.
"""
+
def __init__(self, flags=0):
"""
Create a new file handle representing a local file being served over
@@ -63,10 +64,10 @@ class SFTPHandle (ClosingContextManager):
using the default implementations of `read` and `write`, this
method's default implementation should be fine also.
"""
- readfile = getattr(self, 'readfile', None)
+ readfile = getattr(self, "readfile", None)
if readfile is not None:
readfile.close()
- writefile = getattr(self, 'writefile', None)
+ writefile = getattr(self, "writefile", None)
if writefile is not None:
writefile.close()
@@ -88,7 +89,7 @@ class SFTPHandle (ClosingContextManager):
:param int length: number of bytes to attempt to read.
:return: data read from the file, or an SFTP error code, as a `str`.
"""
- readfile = getattr(self, 'readfile', None)
+ readfile = getattr(self, "readfile", None)
if readfile is None:
return SFTP_OP_UNSUPPORTED
try:
@@ -122,7 +123,7 @@ class SFTPHandle (ClosingContextManager):
:param str data: data to write into the file.
:return: an SFTP error code like ``SFTP_OK``.
"""
- writefile = getattr(self, 'writefile', None)
+ writefile = getattr(self, "writefile", None)
if writefile is None:
return SFTP_OP_UNSUPPORTED
try:
diff --git a/paramiko/sftp_server.py b/paramiko/sftp_server.py
index f8c4f72..8265df9 100644
--- a/paramiko/sftp_server.py
+++ b/paramiko/sftp_server.py
@@ -27,7 +27,11 @@ from hashlib import md5, sha1
from paramiko import util
from paramiko.sftp import (
- BaseSFTP, Message, SFTP_FAILURE, SFTP_PERMISSION_DENIED, SFTP_NO_SUCH_FILE,
+ BaseSFTP,
+ Message,
+ SFTP_FAILURE,
+ SFTP_PERMISSION_DENIED,
+ SFTP_NO_SUCH_FILE,
)
from paramiko.sftp_si import SFTPServerInterface
from paramiko.sftp_attr import SFTPAttributes
@@ -38,30 +42,64 @@ from paramiko.server import SubsystemHandler
# known hash algorithms for the "check-file" extension
from paramiko.sftp import (
- CMD_HANDLE, SFTP_DESC, CMD_STATUS, SFTP_EOF, CMD_NAME, SFTP_BAD_MESSAGE,
- CMD_EXTENDED_REPLY, SFTP_FLAG_READ, SFTP_FLAG_WRITE, SFTP_FLAG_APPEND,
- SFTP_FLAG_CREATE, SFTP_FLAG_TRUNC, SFTP_FLAG_EXCL, CMD_NAMES, CMD_OPEN,
- CMD_CLOSE, SFTP_OK, CMD_READ, CMD_DATA, CMD_WRITE, CMD_REMOVE, CMD_RENAME,
- CMD_MKDIR, CMD_RMDIR, CMD_OPENDIR, CMD_READDIR, CMD_STAT, CMD_ATTRS,
- CMD_LSTAT, CMD_FSTAT, CMD_SETSTAT, CMD_FSETSTAT, CMD_READLINK, CMD_SYMLINK,
- CMD_REALPATH, CMD_EXTENDED, SFTP_OP_UNSUPPORTED,
+ CMD_HANDLE,
+ SFTP_DESC,
+ CMD_STATUS,
+ SFTP_EOF,
+ CMD_NAME,
+ SFTP_BAD_MESSAGE,
+ CMD_EXTENDED_REPLY,
+ SFTP_FLAG_READ,
+ SFTP_FLAG_WRITE,
+ SFTP_FLAG_APPEND,
+ SFTP_FLAG_CREATE,
+ SFTP_FLAG_TRUNC,
+ SFTP_FLAG_EXCL,
+ CMD_NAMES,
+ CMD_OPEN,
+ CMD_CLOSE,
+ SFTP_OK,
+ CMD_READ,
+ CMD_DATA,
+ CMD_WRITE,
+ CMD_REMOVE,
+ CMD_RENAME,
+ CMD_MKDIR,
+ CMD_RMDIR,
+ CMD_OPENDIR,
+ CMD_READDIR,
+ CMD_STAT,
+ CMD_ATTRS,
+ CMD_LSTAT,
+ CMD_FSTAT,
+ CMD_SETSTAT,
+ CMD_FSETSTAT,
+ CMD_READLINK,
+ CMD_SYMLINK,
+ CMD_REALPATH,
+ CMD_EXTENDED,
+ SFTP_OP_UNSUPPORTED,
)
-_hash_class = {
- 'sha1': sha1,
- 'md5': md5,
-}
+_hash_class = {"sha1": sha1, "md5": md5}
-class SFTPServer (BaseSFTP, SubsystemHandler):
+class SFTPServer(BaseSFTP, SubsystemHandler):
"""
Server-side SFTP subsystem support. Since this is a `.SubsystemHandler`,
it can be (and is meant to be) set as the handler for ``"sftp"`` requests.
Use `.Transport.set_subsystem_handler` to activate this class.
"""
- def __init__(self, channel, name, server, sftp_si=SFTPServerInterface,
- *largs, **kwargs):
+ def __init__(
+ self,
+ channel,
+ name,
+ server,
+ sftp_si=SFTPServerInterface,
+ *largs,
+ **kwargs
+ ):
"""
The constructor for SFTPServer is meant to be called from within the
`.Transport` as a subsystem handler. ``server`` and any additional
@@ -79,7 +117,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
BaseSFTP.__init__(self)
SubsystemHandler.__init__(self, channel, name, server)
transport = channel.get_transport()
- self.logger = util.get_logger(transport.get_log_channel() + '.sftp')
+ self.logger = util.get_logger(transport.get_log_channel() + ".sftp")
self.ultra_debug = transport.get_hexdump()
self.next_handle = 1
# map of handle-string to SFTPHandle for files & folders:
@@ -91,26 +129,26 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
if issubclass(type(msg), list):
for m in msg:
super(SFTPServer, self)._log(
- level,
- "[chan " + self.sock.get_name() + "] " + m)
+ level, "[chan " + self.sock.get_name() + "] " + m
+ )
else:
super(SFTPServer, self)._log(
- level,
- "[chan " + self.sock.get_name() + "] " + msg)
+ level, "[chan " + self.sock.get_name() + "] " + msg
+ )
def start_subsystem(self, name, transport, channel):
self.sock = channel
- self._log(DEBUG, 'Started sftp server on channel {!r}'.format(channel))
+ self._log(DEBUG, "Started sftp server on channel {!r}".format(channel))
self._send_server_version()
self.server.session_started()
while True:
try:
t, data = self._read_packet()
except EOFError:
- self._log(DEBUG, 'EOF -- end of session')
+ self._log(DEBUG, "EOF -- end of session")
return
except Exception as e:
- self._log(DEBUG, 'Exception on channel: ' + str(e))
+ self._log(DEBUG, "Exception on channel: " + str(e))
self._log(DEBUG, util.tb_strings())
return
msg = Message(data)
@@ -118,7 +156,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
try:
self._process(t, request_number, msg)
except Exception as e:
- self._log(DEBUG, 'Exception in server processing: ' + str(e))
+ self._log(DEBUG, "Exception in server processing: " + str(e))
self._log(DEBUG, util.tb_strings())
# send some kind of failure message, at least
try:
@@ -172,7 +210,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
name of the file to alter (should usually be an absolute path).
:param .SFTPAttributes attr: attributes to change.
"""
- if sys.platform != 'win32':
+ if sys.platform != "win32":
# mode operations are meaningless on win32
if attr._flags & attr.FLAG_PERMISSIONS:
os.chmod(filename, attr.st_mode)
@@ -181,7 +219,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
if attr._flags & attr.FLAG_AMTIME:
os.utime(filename, (attr.st_atime, attr.st_mtime))
if attr._flags & attr.FLAG_SIZE:
- with open(filename, 'w+') as f:
+ with open(filename, "w+") as f:
f.truncate(attr.st_size)
# ...internals...
@@ -200,8 +238,8 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
item._pack(msg)
else:
raise Exception(
- 'unknown type for {!r} type {!r}'.format(
- item, type(item)))
+ "unknown type for {!r} type {!r}".format(item, type(item))
+ )
self._send_packet(t, msg)
def _send_handle_response(self, request_number, handle, folder=False):
@@ -209,7 +247,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
# must be error code
self._send_status(request_number, handle)
return
- handle._set_name(b('hx{:d}'.format(self.next_handle)))
+ handle._set_name(b("hx{:d}".format(self.next_handle)))
self.next_handle += 1
if folder:
self.folder_table[handle._get_name()] = handle
@@ -222,10 +260,10 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
try:
desc = SFTP_DESC[code]
except IndexError:
- desc = 'Unknown'
+ desc = "Unknown"
# some clients expect a "langauge" tag at the end
# (but don't mind it being blank)
- self._response(request_number, CMD_STATUS, code, desc, '')
+ self._response(request_number, CMD_STATUS, code, desc, "")
def _open_folder(self, request_number, path):
resp = self.server.list_folder(path)
@@ -264,7 +302,8 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
block_size = msg.get_int()
if handle not in self.file_table:
self._send_status(
- request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
+ request_number, SFTP_BAD_MESSAGE, "Invalid handle"
+ )
return
f = self.file_table[handle]
for x in alg_list:
@@ -274,19 +313,21 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
break
else:
self._send_status(
- request_number, SFTP_FAILURE, 'No supported hash types found')
+ request_number, SFTP_FAILURE, "No supported hash types found"
+ )
return
if length == 0:
st = f.stat()
if not issubclass(type(st), SFTPAttributes):
- self._send_status(request_number, st, 'Unable to stat file')
+ self._send_status(request_number, st, "Unable to stat file")
return
length = st.st_size - start
if block_size == 0:
block_size = length
if block_size < 256:
self._send_status(
- request_number, SFTP_FAILURE, 'Block size too small')
+ request_number, SFTP_FAILURE, "Block size too small"
+ )
return
sum_out = bytes()
@@ -301,7 +342,8 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
data = f.read(offset, chunklen)
if not isinstance(data, bytes_types):
self._send_status(
- request_number, data, 'Unable to hash file')
+ request_number, data, "Unable to hash file"
+ )
return
hash_obj.update(data)
count += len(data)
@@ -310,7 +352,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
msg = Message()
msg.add_int(request_number)
- msg.add_string('check-file')
+ msg.add_string("check-file")
msg.add_string(algname)
msg.add_bytes(sum_out)
self._send_packet(CMD_EXTENDED_REPLY, msg)
@@ -334,13 +376,14 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
return flags
def _process(self, t, request_number, msg):
- self._log(DEBUG, 'Request: {}'.format(CMD_NAMES[t]))
+ self._log(DEBUG, "Request: {}".format(CMD_NAMES[t]))
if t == CMD_OPEN:
path = msg.get_text()
flags = self._convert_pflags(msg.get_int())
attr = SFTPAttributes._from_msg(msg)
self._send_handle_response(
- request_number, self.server.open(path, flags, attr))
+ request_number, self.server.open(path, flags, attr)
+ )
elif t == CMD_CLOSE:
handle = msg.get_binary()
if handle in self.folder_table:
@@ -353,14 +396,16 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
self._send_status(request_number, SFTP_OK)
return
self._send_status(
- request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
+ request_number, SFTP_BAD_MESSAGE, "Invalid handle"
+ )
elif t == CMD_READ:
handle = msg.get_binary()
offset = msg.get_int64()
length = msg.get_int()
if handle not in self.file_table:
self._send_status(
- request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
+ request_number, SFTP_BAD_MESSAGE, "Invalid handle"
+ )
return
data = self.file_table[handle].read(offset, length)
if isinstance(data, (bytes_types, string_types)):
@@ -376,10 +421,12 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
data = msg.get_binary()
if handle not in self.file_table:
self._send_status(
- request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
+ request_number, SFTP_BAD_MESSAGE, "Invalid handle"
+ )
return
self._send_status(
- request_number, self.file_table[handle].write(offset, data))
+ request_number, self.file_table[handle].write(offset, data)
+ )
elif t == CMD_REMOVE:
path = msg.get_text()
self._send_status(request_number, self.server.remove(path))
@@ -387,7 +434,8 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
oldpath = msg.get_text()
newpath = msg.get_text()
self._send_status(
- request_number, self.server.rename(oldpath, newpath))
+ request_number, self.server.rename(oldpath, newpath)
+ )
elif t == CMD_MKDIR:
path = msg.get_text()
attr = SFTPAttributes._from_msg(msg)
@@ -403,7 +451,8 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
handle = msg.get_binary()
if handle not in self.folder_table:
self._send_status(
- request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
+ request_number, SFTP_BAD_MESSAGE, "Invalid handle"
+ )
return
folder = self.folder_table[handle]
self._read_folder(request_number, folder)
@@ -425,7 +474,8 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
handle = msg.get_binary()
if handle not in self.file_table:
self._send_status(
- request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
+ request_number, SFTP_BAD_MESSAGE, "Invalid handle"
+ )
return
resp = self.file_table[handle].stat()
if issubclass(type(resp), SFTPAttributes):
@@ -441,16 +491,19 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
attr = SFTPAttributes._from_msg(msg)
if handle not in self.file_table:
self._response(
- request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
+ request_number, SFTP_BAD_MESSAGE, "Invalid handle"
+ )
return
self._send_status(
- request_number, self.file_table[handle].chattr(attr))
+ request_number, self.file_table[handle].chattr(attr)
+ )
elif t == CMD_READLINK:
path = msg.get_text()
resp = self.server.readlink(path)
if isinstance(resp, (bytes_types, string_types)):
self._response(
- request_number, CMD_NAME, 1, resp, '', SFTPAttributes())
+ request_number, CMD_NAME, 1, resp, "", SFTPAttributes()
+ )
else:
self._send_status(request_number, resp)
elif t == CMD_SYMLINK:
@@ -459,17 +512,19 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
target_path = msg.get_text()
path = msg.get_text()
self._send_status(
- request_number, self.server.symlink(target_path, path))
+ request_number, self.server.symlink(target_path, path)
+ )
elif t == CMD_REALPATH:
path = msg.get_text()
rpath = self.server.canonicalize(path)
self._response(
- request_number, CMD_NAME, 1, rpath, '', SFTPAttributes())
+ request_number, CMD_NAME, 1, rpath, "", SFTPAttributes()
+ )
elif t == CMD_EXTENDED:
tag = msg.get_text()
- if tag == 'check-file':
+ if tag == "check-file":
self._check_file(request_number, msg)
- elif tag == 'posix-rename@openssh.com':
+ elif tag == "posix-rename@openssh.com":
oldpath = msg.get_text()
newpath = msg.get_text()
self._send_status(
diff --git a/paramiko/sftp_si.py b/paramiko/sftp_si.py
index b43b582..40dc561 100644
--- a/paramiko/sftp_si.py
+++ b/paramiko/sftp_si.py
@@ -25,7 +25,7 @@ import sys
from paramiko.sftp import SFTP_OP_UNSUPPORTED
-class SFTPServerInterface (object):
+class SFTPServerInterface(object):
"""
This class defines an interface for controlling the behavior of paramiko
when using the `.SFTPServer` subsystem to provide an SFTP server.
@@ -39,6 +39,7 @@ class SFTPServerInterface (object):
All paths are in string form instead of unicode because not all SFTP
clients & servers obey the requirement that paths be encoded in UTF-8.
"""
+
def __init__(self, server, *largs, **kwargs):
"""
Create a new SFTPServerInterface object. This method does nothing by
@@ -281,10 +282,10 @@ class SFTPServerInterface (object):
if os.path.isabs(path):
out = os.path.normpath(path)
else:
- out = os.path.normpath('/' + path)
- if sys.platform == 'win32':
+ out = os.path.normpath("/" + path)
+ if sys.platform == "win32":
# on windows, normalize backslashes to sftp/posix format
- out = out.replace('\\', '/')
+ out = out.replace("\\", "/")
return out
def readlink(self, path):
diff --git a/paramiko/ssh_exception.py b/paramiko/ssh_exception.py
index 2df84b6..c1276c6 100644
--- a/paramiko/ssh_exception.py
+++ b/paramiko/ssh_exception.py
@@ -19,14 +19,14 @@
import socket
-class SSHException (Exception):
+class SSHException(Exception):
"""
Exception raised by failures in SSH2 protocol negotiation or logic errors.
"""
pass
-class AuthenticationException (SSHException):
+class AuthenticationException(SSHException):
"""
Exception raised when authentication failed for some reason. It may be
possible to retry with different credentials. (Other classes specify more
@@ -37,14 +37,14 @@ class AuthenticationException (SSHException):
pass
-class PasswordRequiredException (AuthenticationException):
+class PasswordRequiredException(AuthenticationException):
"""
Exception raised when a password is needed to unlock a private key file.
"""
pass
-class BadAuthenticationType (AuthenticationException):
+class BadAuthenticationType(AuthenticationException):
"""
Exception raised when an authentication type (like password) is used, but
the server isn't allowing that type. (It may only allow public-key, for
@@ -60,28 +60,28 @@ class BadAuthenticationType (AuthenticationException):
AuthenticationException.__init__(self, explanation)
self.allowed_types = types
# for unpickling
- self.args = (explanation, types, )
+ self.args = (explanation, types)
def __str__(self):
- return '{} (allowed_types={!r})'.format(
+ return "{} (allowed_types={!r})".format(
SSHException.__str__(self), self.allowed_types
)
-class PartialAuthentication (AuthenticationException):
+class PartialAuthentication(AuthenticationException):
"""
An internal exception thrown in the case of partial authentication.
"""
allowed_types = []
def __init__(self, types):
- AuthenticationException.__init__(self, 'partial authentication')
+ AuthenticationException.__init__(self, "partial authentication")
self.allowed_types = types
# for unpickling
- self.args = (types, )
+ self.args = (types,)
-class ChannelException (SSHException):
+class ChannelException(SSHException):
"""
Exception raised when an attempt to open a new `.Channel` fails.
@@ -89,14 +89,15 @@ class ChannelException (SSHException):
.. versionadded:: 1.6
"""
+
def __init__(self, code, text):
SSHException.__init__(self, text)
self.code = code
# for unpickling
- self.args = (code, text, )
+ self.args = (code, text)
-class BadHostKeyException (SSHException):
+class BadHostKeyException(SSHException):
"""
The host key given by the SSH server did not match what we were expecting.
@@ -106,35 +107,38 @@ class BadHostKeyException (SSHException):
.. versionadded:: 1.6
"""
+
def __init__(self, hostname, got_key, expected_key):
- message = 'Host key for server {} does not match: got {}, expected {}' # noqa
+ message = "Host key for server {} does not match: got {}, expected {}" # noqa
message = message.format(
- hostname, got_key.get_base64(),
- expected_key.get_base64())
+ hostname, got_key.get_base64(), expected_key.get_base64()
+ )
SSHException.__init__(self, message)
self.hostname = hostname
self.key = got_key
self.expected_key = expected_key
# for unpickling
- self.args = (hostname, got_key, expected_key, )
+ self.args = (hostname, got_key, expected_key)
-class ProxyCommandFailure (SSHException):
+class ProxyCommandFailure(SSHException):
"""
The "ProxyCommand" found in the .ssh/config file returned an error.
:param str command: The command line that is generating this exception.
:param str error: The error captured from the proxy command output.
"""
+
def __init__(self, command, error):
- SSHException.__init__(self,
+ SSHException.__init__(
+ self,
'"ProxyCommand ({})" returned non-zero exit status: {}'.format(
command, error
- )
+ ),
)
self.error = error
# for unpickling
- self.args = (command, error, )
+ self.args = (command, error)
class NoValidConnectionsError(socket.error):
@@ -159,23 +163,23 @@ class NoValidConnectionsError(socket.error):
.. versionadded:: 1.16
"""
+
def __init__(self, errors):
"""
:param dict errors:
The errors dict to store, as described by class docstring.
"""
addrs = sorted(errors.keys())
- body = ', '.join([x[0] for x in addrs[:-1]])
+ body = ", ".join([x[0] for x in addrs[:-1]])
tail = addrs[-1][0]
if body:
msg = "Unable to connect to port {0} on {1} or {2}"
else:
msg = "Unable to connect to port {0} on {2}"
super(NoValidConnectionsError, self).__init__(
- None, # stand-in for errno
- msg.format(addrs[0][1], body, tail)
+ None, msg.format(addrs[0][1], body, tail) # stand-in for errno
)
self.errors = errors
def __reduce__(self):
- return (self.__class__, (self.errors, ))
+ return (self.__class__, (self.errors,))
diff --git a/paramiko/ssh_gss.py b/paramiko/ssh_gss.py
index aa7cc74..cf5ec7e 100644
--- a/paramiko/ssh_gss.py
+++ b/paramiko/ssh_gss.py
@@ -42,19 +42,19 @@ GSS_AUTH_AVAILABLE = True
GSS_EXCEPTIONS = ()
-
-
#: :var str _API: Constraint for the used API
_API = "MIT"
try:
import gssapi
+
GSS_EXCEPTIONS = (gssapi.GSSException,)
except (ImportError, OSError):
try:
import pywintypes
import sspicon
import sspi
+
_API = "SSPI"
GSS_EXCEPTIONS = (pywintypes.error,)
except ImportError:
@@ -99,6 +99,7 @@ class _SSH_GSSAuth(object):
Contains the shared variables and methods of `._SSH_GSSAPI` and
`._SSH_SSPI`.
"""
+
def __init__(self, auth_method, gss_deleg_creds):
"""
:param str auth_method: The name of the SSH authentication mechanism
@@ -160,6 +161,7 @@ class _SSH_GSSAuth(object):
"""
from pyasn1.type.univ import ObjectIdentifier
from pyasn1.codec.der import encoder
+
OIDs = self._make_uint32(1)
krb5_OID = encoder.encode(ObjectIdentifier(self._krb5_mech))
OID_len = self._make_uint32(len(krb5_OID))
@@ -175,6 +177,7 @@ class _SSH_GSSAuth(object):
:return: ``True`` if the given OID is supported, otherwise C{False}
"""
from pyasn1.codec.der import decoder
+
mech, __ = decoder.decode(desired_mech)
if mech.__str__() != self._krb5_mech:
return False
@@ -210,7 +213,7 @@ class _SSH_GSSAuth(object):
"""
mic = self._make_uint32(len(session_id))
mic += session_id
- mic += struct.pack('B', MSG_USERAUTH_REQUEST)
+ mic += struct.pack("B", MSG_USERAUTH_REQUEST)
mic += self._make_uint32(len(username))
mic += username.encode()
mic += self._make_uint32(len(service))
@@ -226,6 +229,7 @@ class _SSH_GSSAPI(_SSH_GSSAuth):
:see: `.GSSAuth`
"""
+
def __init__(self, auth_method, gss_deleg_creds):
"""
:param str auth_method: The name of the SSH authentication mechanism
@@ -235,17 +239,22 @@ class _SSH_GSSAPI(_SSH_GSSAuth):
_SSH_GSSAuth.__init__(self, auth_method, gss_deleg_creds)
if self._gss_deleg_creds:
- self._gss_flags = (gssapi.C_PROT_READY_FLAG,
- gssapi.C_INTEG_FLAG,
- gssapi.C_MUTUAL_FLAG,
- gssapi.C_DELEG_FLAG)
+ self._gss_flags = (
+ gssapi.C_PROT_READY_FLAG,
+ gssapi.C_INTEG_FLAG,
+ gssapi.C_MUTUAL_FLAG,
+ gssapi.C_DELEG_FLAG,
+ )
else:
- self._gss_flags = (gssapi.C_PROT_READY_FLAG,
- gssapi.C_INTEG_FLAG,
- gssapi.C_MUTUAL_FLAG)
+ self._gss_flags = (
+ gssapi.C_PROT_READY_FLAG,
+ gssapi.C_INTEG_FLAG,
+ gssapi.C_MUTUAL_FLAG,
+ )
- def ssh_init_sec_context(self, target, desired_mech=None,
- username=None, recv_token=None):
+ def ssh_init_sec_context(
+ self, target, desired_mech=None, username=None, recv_token=None
+ ):
"""
Initialize a GSS-API context.
@@ -262,10 +271,12 @@ class _SSH_GSSAPI(_SSH_GSSAuth):
``None`` if no token was returned
"""
from pyasn1.codec.der import decoder
+
self._username = username
self._gss_host = target
- targ_name = gssapi.Name("host@" + self._gss_host,
- gssapi.C_NT_HOSTBASED_SERVICE)
+ targ_name = gssapi.Name(
+ "host@" + self._gss_host, gssapi.C_NT_HOSTBASED_SERVICE
+ )
ctx = gssapi.Context()
ctx.flags = self._gss_flags
if desired_mech is None:
@@ -279,15 +290,16 @@ class _SSH_GSSAPI(_SSH_GSSAuth):
token = None
try:
if recv_token is None:
- self._gss_ctxt = gssapi.InitContext(peer_name=targ_name,
- mech_type=krb5_mech,
- req_flags=ctx.flags)
+ self._gss_ctxt = gssapi.InitContext(
+ peer_name=targ_name,
+ mech_type=krb5_mech,
+ req_flags=ctx.flags,
+ )
token = self._gss_ctxt.step(token)
else:
token = self._gss_ctxt.step(recv_token)
except gssapi.GSSException:
- message = "{} Target: {}".format(
- sys.exc_info()[1], self._gss_host)
+ message = "{} Target: {}".format(sys.exc_info()[1], self._gss_host)
raise gssapi.GSSException(message)
self._gss_ctxt_status = self._gss_ctxt.established
return token
@@ -307,10 +319,12 @@ class _SSH_GSSAPI(_SSH_GSSAuth):
"""
self._session_id = session_id
if not gss_kex:
- mic_field = self._ssh_build_mic(self._session_id,
- self._username,
- self._service,
- self._auth_method)
+ mic_field = self._ssh_build_mic(
+ self._session_id,
+ self._username,
+ self._service,
+ self._auth_method,
+ )
mic_token = self._gss_ctxt.get_mic(mic_field)
else:
# for key exchange with gssapi-keyex
@@ -351,16 +365,17 @@ class _SSH_GSSAPI(_SSH_GSSAuth):
self._username = username
if self._username is not None:
# server mode
- mic_field = self._ssh_build_mic(self._session_id,
- self._username,
- self._service,
- self._auth_method)
+ mic_field = self._ssh_build_mic(
+ self._session_id,
+ self._username,
+ self._service,
+ self._auth_method,
+ )
self._gss_srv_ctxt.verify_mic(mic_field, mic_token)
else:
# for key exchange with gssapi-keyex
# client mode
- self._gss_ctxt.verify_mic(self._session_id,
- mic_token)
+ self._gss_ctxt.verify_mic(self._session_id, mic_token)
@property
def credentials_delegated(self):
@@ -393,6 +408,7 @@ class _SSH_SSPI(_SSH_GSSAuth):
:see: `.GSSAuth`
"""
+
def __init__(self, auth_method, gss_deleg_creds):
"""
:param str auth_method: The name of the SSH authentication mechanism
@@ -403,18 +419,18 @@ class _SSH_SSPI(_SSH_GSSAuth):
if self._gss_deleg_creds:
self._gss_flags = (
- sspicon.ISC_REQ_INTEGRITY |
- sspicon.ISC_REQ_MUTUAL_AUTH |
- sspicon.ISC_REQ_DELEGATE
+ sspicon.ISC_REQ_INTEGRITY
+ | sspicon.ISC_REQ_MUTUAL_AUTH
+ | sspicon.ISC_REQ_DELEGATE
)
else:
self._gss_flags = (
- sspicon.ISC_REQ_INTEGRITY |
- sspicon.ISC_REQ_MUTUAL_AUTH
+ sspicon.ISC_REQ_INTEGRITY | sspicon.ISC_REQ_MUTUAL_AUTH
)
- def ssh_init_sec_context(self, target, desired_mech=None,
- username=None, recv_token=None):
+ def ssh_init_sec_context(
+ self, target, desired_mech=None, username=None, recv_token=None
+ ):
"""
Initialize a SSPI context.
@@ -431,6 +447,7 @@ class _SSH_SSPI(_SSH_GSSAuth):
no token was returned
"""
from pyasn1.codec.der import decoder
+
self._username = username
self._gss_host = target
error = 0
@@ -441,9 +458,9 @@ class _SSH_SSPI(_SSH_GSSAuth):
raise SSHException("Unsupported mechanism OID.")
try:
if recv_token is None:
- self._gss_ctxt = sspi.ClientAuth("Kerberos",
- scflags=self._gss_flags,
- targetspn=targ_name)
+ self._gss_ctxt = sspi.ClientAuth(
+ "Kerberos", scflags=self._gss_flags, targetspn=targ_name
+ )
error, token = self._gss_ctxt.authorize(recv_token)
token = token[0].Buffer
except pywintypes.error as e:
@@ -478,10 +495,12 @@ class _SSH_SSPI(_SSH_GSSAuth):
"""
self._session_id = session_id
if not gss_kex:
- mic_field = self._ssh_build_mic(self._session_id,
- self._username,
- self._service,
- self._auth_method)
+ mic_field = self._ssh_build_mic(
+ self._session_id,
+ self._username,
+ self._service,
+ self._auth_method,
+ )
mic_token = self._gss_ctxt.sign(mic_field)
else:
# for key exchange with gssapi-keyex
@@ -524,10 +543,12 @@ class _SSH_SSPI(_SSH_GSSAuth):
self._username = username
if username is not None:
# server mode
- mic_field = self._ssh_build_mic(self._session_id,
- self._username,
- self._service,
- self._auth_method)
+ mic_field = self._ssh_build_mic(
+ self._session_id,
+ self._username,
+ self._service,
+ self._auth_method,
+ )
# Verifies data and its signature. If verification fails, an
# sspi.error will be raised.
self._gss_srv_ctxt.verify(mic_field, mic_token)
@@ -546,8 +567,8 @@ class _SSH_SSPI(_SSH_GSSAuth):
:return: ``True`` if credentials are delegated, otherwise ``False``
"""
return (
- self._gss_flags & sspicon.ISC_REQ_DELEGATE and
- (self._gss_srv_ctxt_status or self._gss_flags)
+ self._gss_flags & sspicon.ISC_REQ_DELEGATE
+ and (self._gss_srv_ctxt_status or self._gss_flags)
)
def save_client_creds(self, client_token):
diff --git a/paramiko/transport.py b/paramiko/transport.py
index ddcb291..ed9f886 100644
--- a/paramiko/transport.py
+++ b/paramiko/transport.py
@@ -39,18 +39,49 @@ from paramiko.auth_handler import AuthHandler
from paramiko.ssh_gss import GSSAuth
from paramiko.channel import Channel
from paramiko.common import (
- xffffffff, cMSG_CHANNEL_OPEN, cMSG_IGNORE, cMSG_GLOBAL_REQUEST, DEBUG,
- MSG_KEXINIT, MSG_IGNORE, MSG_DISCONNECT, MSG_DEBUG, ERROR, WARNING,
- cMSG_UNIMPLEMENTED, INFO, cMSG_KEXINIT, cMSG_NEWKEYS, MSG_NEWKEYS,
- cMSG_REQUEST_SUCCESS, cMSG_REQUEST_FAILURE, CONNECTION_FAILED_CODE,
- OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, OPEN_SUCCEEDED,
- cMSG_CHANNEL_OPEN_FAILURE, cMSG_CHANNEL_OPEN_SUCCESS, MSG_GLOBAL_REQUEST,
- MSG_REQUEST_SUCCESS, MSG_REQUEST_FAILURE, MSG_CHANNEL_OPEN_SUCCESS,
- MSG_CHANNEL_OPEN_FAILURE, MSG_CHANNEL_OPEN, MSG_CHANNEL_SUCCESS,
- MSG_CHANNEL_FAILURE, MSG_CHANNEL_DATA, MSG_CHANNEL_EXTENDED_DATA,
- MSG_CHANNEL_WINDOW_ADJUST, MSG_CHANNEL_REQUEST, MSG_CHANNEL_EOF,
- MSG_CHANNEL_CLOSE, MIN_WINDOW_SIZE, MIN_PACKET_SIZE, MAX_WINDOW_SIZE,
- DEFAULT_WINDOW_SIZE, DEFAULT_MAX_PACKET_SIZE, HIGHEST_USERAUTH_MESSAGE_ID,
+ xffffffff,
+ cMSG_CHANNEL_OPEN,
+ cMSG_IGNORE,
+ cMSG_GLOBAL_REQUEST,
+ DEBUG,
+ MSG_KEXINIT,
+ MSG_IGNORE,
+ MSG_DISCONNECT,
+ MSG_DEBUG,
+ ERROR,
+ WARNING,
+ cMSG_UNIMPLEMENTED,
+ INFO,
+ cMSG_KEXINIT,
+ cMSG_NEWKEYS,
+ MSG_NEWKEYS,
+ cMSG_REQUEST_SUCCESS,
+ cMSG_REQUEST_FAILURE,
+ CONNECTION_FAILED_CODE,
+ OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED,
+ OPEN_SUCCEEDED,
+ cMSG_CHANNEL_OPEN_FAILURE,
+ cMSG_CHANNEL_OPEN_SUCCESS,
+ MSG_GLOBAL_REQUEST,
+ MSG_REQUEST_SUCCESS,
+ MSG_REQUEST_FAILURE,
+ MSG_CHANNEL_OPEN_SUCCESS,
+ MSG_CHANNEL_OPEN_FAILURE,
+ MSG_CHANNEL_OPEN,
+ MSG_CHANNEL_SUCCESS,
+ MSG_CHANNEL_FAILURE,
+ MSG_CHANNEL_DATA,
+ MSG_CHANNEL_EXTENDED_DATA,
+ MSG_CHANNEL_WINDOW_ADJUST,
+ MSG_CHANNEL_REQUEST,
+ MSG_CHANNEL_EOF,
+ MSG_CHANNEL_CLOSE,
+ MIN_WINDOW_SIZE,
+ MIN_PACKET_SIZE,
+ MAX_WINDOW_SIZE,
+ DEFAULT_WINDOW_SIZE,
+ DEFAULT_MAX_PACKET_SIZE,
+ HIGHEST_USERAUTH_MESSAGE_ID,
)
from paramiko.compress import ZlibCompressor, ZlibDecompressor
from paramiko.dsskey import DSSKey
@@ -69,7 +100,10 @@ from paramiko.ecdsakey import ECDSAKey
from paramiko.server import ServerInterface
from paramiko.sftp_client import SFTPClient
from paramiko.ssh_exception import (
- SSHException, BadAuthenticationType, ChannelException, ProxyCommandFailure,
+ SSHException,
+ BadAuthenticationType,
+ ChannelException,
+ ProxyCommandFailure,
)
from paramiko.util import retry_on_signal, ClosingContextManager, clamp_value
@@ -77,12 +111,14 @@ from paramiko.util import retry_on_signal, ClosingContextManager, clamp_value
# for thread cleanup
_active_threads = []
+
def _join_lingering_threads():
for thr in _active_threads:
thr.stop_thread()
import atexit
+
atexit.register(_join_lingering_threads)
@@ -99,160 +135,161 @@ class Transport(threading.Thread, ClosingContextManager):
_ENCRYPT = object()
_DECRYPT = object()
- _PROTO_ID = '2.0'
- _CLIENT_ID = 'paramiko_{}'.format(paramiko.__version__)
+ _PROTO_ID = "2.0"
+ _CLIENT_ID = "paramiko_{}".format(paramiko.__version__)
# These tuples of algorithm identifiers are in preference order; do not
# reorder without reason!
_preferred_ciphers = (
- 'aes128-ctr',
- 'aes192-ctr',
- 'aes256-ctr',
- 'aes128-cbc',
- 'aes192-cbc',
- 'aes256-cbc',
- 'blowfish-cbc',
- '3des-cbc',
+ "aes128-ctr",
+ "aes192-ctr",
+ "aes256-ctr",
+ "aes128-cbc",
+ "aes192-cbc",
+ "aes256-cbc",
+ "blowfish-cbc",
+ "3des-cbc",
)
_preferred_macs = (
- 'hmac-sha2-256',
- 'hmac-sha2-512',
- 'hmac-sha1',
- 'hmac-md5',
- 'hmac-sha1-96',
- 'hmac-md5-96',
+ "hmac-sha2-256",
+ "hmac-sha2-512",
+ "hmac-sha1",
+ "hmac-md5",
+ "hmac-sha1-96",
+ "hmac-md5-96",
)
_preferred_keys = (
- 'ssh-ed25519',
- 'ecdsa-sha2-nistp256',
- 'ecdsa-sha2-nistp384',
- 'ecdsa-sha2-nistp521',
- 'ssh-rsa',
- 'ssh-dss',
+ "ssh-ed25519",
+ "ecdsa-sha2-nistp256",
+ "ecdsa-sha2-nistp384",
+ "ecdsa-sha2-nistp521",
+ "ssh-rsa",
+ "ssh-dss",
)
_preferred_kex = (
- 'ecdh-sha2-nistp256',
- 'ecdh-sha2-nistp384',
- 'ecdh-sha2-nistp521',
- 'diffie-hellman-group-exchange-sha256',
- 'diffie-hellman-group-exchange-sha1',
- 'diffie-hellman-group14-sha1',
- 'diffie-hellman-group1-sha1',
+ "ecdh-sha2-nistp256",
+ "ecdh-sha2-nistp384",
+ "ecdh-sha2-nistp521",
+ "diffie-hellman-group-exchange-sha256",
+ "diffie-hellman-group-exchange-sha1",
+ "diffie-hellman-group14-sha1",
+ "diffie-hellman-group1-sha1",
)
_preferred_gsskex = (
- 'gss-gex-sha1-toWM5Slw5Ew8Mqkay+al2g==',
- 'gss-group14-sha1-toWM5Slw5Ew8Mqkay+al2g==',
- 'gss-group1-sha1-toWM5Slw5Ew8Mqkay+al2g==',
+ "gss-gex-sha1-toWM5Slw5Ew8Mqkay+al2g==",
+ "gss-group14-sha1-toWM5Slw5Ew8Mqkay+al2g==",
+ "gss-group1-sha1-toWM5Slw5Ew8Mqkay+al2g==",
)
- _preferred_compression = ('none',)
+ _preferred_compression = ("none",)
_cipher_info = {
- 'aes128-ctr': {
- 'class': algorithms.AES,
- 'mode': modes.CTR,
- 'block-size': 16,
- 'key-size': 16
+ "aes128-ctr": {
+ "class": algorithms.AES,
+ "mode": modes.CTR,
+ "block-size": 16,
+ "key-size": 16,
},
- 'aes192-ctr': {
- 'class': algorithms.AES,
- 'mode': modes.CTR,
- 'block-size': 16,
- 'key-size': 24
+ "aes192-ctr": {
+ "class": algorithms.AES,
+ "mode": modes.CTR,
+ "block-size": 16,
+ "key-size": 24,
},
- 'aes256-ctr': {
- 'class': algorithms.AES,
- 'mode': modes.CTR,
- 'block-size': 16,
- 'key-size': 32
+ "aes256-ctr": {
+ "class": algorithms.AES,
+ "mode": modes.CTR,
+ "block-size": 16,
+ "key-size": 32,
},
- 'blowfish-cbc': {
- 'class': algorithms.Blowfish,
- 'mode': modes.CBC,
- 'block-size': 8,
- 'key-size': 16
+ "blowfish-cbc": {
+ "class": algorithms.Blowfish,
+ "mode": modes.CBC,
+ "block-size": 8,
+ "key-size": 16,
},
- 'aes128-cbc': {
- 'class': algorithms.AES,
- 'mode': modes.CBC,
- 'block-size': 16,
- 'key-size': 16
+ "aes128-cbc": {
+ "class": algorithms.AES,
+ "mode": modes.CBC,
+ "block-size": 16,
+ "key-size": 16,
},
- 'aes192-cbc': {
- 'class': algorithms.AES,
- 'mode': modes.CBC,
- 'block-size': 16,
- 'key-size': 24
+ "aes192-cbc": {
+ "class": algorithms.AES,
+ "mode": modes.CBC,
+ "block-size": 16,
+ "key-size": 24,
},
- 'aes256-cbc': {
- 'class': algorithms.AES,
- 'mode': modes.CBC,
- 'block-size': 16,
- 'key-size': 32
+ "aes256-cbc": {
+ "class": algorithms.AES,
+ "mode": modes.CBC,
+ "block-size": 16,
+ "key-size": 32,
},
- '3des-cbc': {
- 'class': algorithms.TripleDES,
- 'mode': modes.CBC,
- 'block-size': 8,
- 'key-size': 24
+ "3des-cbc": {
+ "class": algorithms.TripleDES,
+ "mode": modes.CBC,
+ "block-size": 8,
+ "key-size": 24,
},
}
-
_mac_info = {
- 'hmac-sha1': {'class': sha1, 'size': 20},
- 'hmac-sha1-96': {'class': sha1, 'size': 12},
- 'hmac-sha2-256': {'class': sha256, 'size': 32},
- 'hmac-sha2-512': {'class': sha512, 'size': 64},
- 'hmac-md5': {'class': md5, 'size': 16},
- 'hmac-md5-96': {'class': md5, 'size': 12},
+ "hmac-sha1": {"class": sha1, "size": 20},
+ "hmac-sha1-96": {"class": sha1, "size": 12},
+ "hmac-sha2-256": {"class": sha256, "size": 32},
+ "hmac-sha2-512": {"class": sha512, "size": 64},
+ "hmac-md5": {"class": md5, "size": 16},
+ "hmac-md5-96": {"class": md5, "size": 12},
}
_key_info = {
- 'ssh-rsa': RSAKey,
- 'ssh-rsa-cert-v01@openssh.com': RSAKey,
- 'ssh-dss': DSSKey,
- 'ssh-dss-cert-v01@openssh.com': DSSKey,
- 'ecdsa-sha2-nistp256': ECDSAKey,
- 'ecdsa-sha2-nistp256-cert-v01@openssh.com': ECDSAKey,
- 'ecdsa-sha2-nistp384': ECDSAKey,
- 'ecdsa-sha2-nistp384-cert-v01@openssh.com': ECDSAKey,
- 'ecdsa-sha2-nistp521': ECDSAKey,
- 'ecdsa-sha2-nistp521-cert-v01@openssh.com': ECDSAKey,
- 'ssh-ed25519': Ed25519Key,
- 'ssh-ed25519-cert-v01@openssh.com': Ed25519Key,
+ "ssh-rsa": RSAKey,
+ "ssh-rsa-cert-v01@openssh.com": RSAKey,
+ "ssh-dss": DSSKey,
+ "ssh-dss-cert-v01@openssh.com": DSSKey,
+ "ecdsa-sha2-nistp256": ECDSAKey,
+ "ecdsa-sha2-nistp256-cert-v01@openssh.com": ECDSAKey,
+ "ecdsa-sha2-nistp384": ECDSAKey,
+ "ecdsa-sha2-nistp384-cert-v01@openssh.com": ECDSAKey,
+ "ecdsa-sha2-nistp521": ECDSAKey,
+ "ecdsa-sha2-nistp521-cert-v01@openssh.com": ECDSAKey,
+ "ssh-ed25519": Ed25519Key,
+ "ssh-ed25519-cert-v01@openssh.com": Ed25519Key,
}
_kex_info = {
- 'diffie-hellman-group1-sha1': KexGroup1,
- 'diffie-hellman-group14-sha1': KexGroup14,
- 'diffie-hellman-group-exchange-sha1': KexGex,
- 'diffie-hellman-group-exchange-sha256': KexGexSHA256,
- 'gss-group1-sha1-toWM5Slw5Ew8Mqkay+al2g==': KexGSSGroup1,
- 'gss-group14-sha1-toWM5Slw5Ew8Mqkay+al2g==': KexGSSGroup14,
- 'gss-gex-sha1-toWM5Slw5Ew8Mqkay+al2g==': KexGSSGex,
- 'ecdh-sha2-nistp256': KexNistp256,
- 'ecdh-sha2-nistp384': KexNistp384,
- 'ecdh-sha2-nistp521': KexNistp521,
+ "diffie-hellman-group1-sha1": KexGroup1,
+ "diffie-hellman-group14-sha1": KexGroup14,
+ "diffie-hellman-group-exchange-sha1": KexGex,
+ "diffie-hellman-group-exchange-sha256": KexGexSHA256,
+ "gss-group1-sha1-toWM5Slw5Ew8Mqkay+al2g==": KexGSSGroup1,
+ "gss-group14-sha1-toWM5Slw5Ew8Mqkay+al2g==": KexGSSGroup14,
+ "gss-gex-sha1-toWM5Slw5Ew8Mqkay+al2g==": KexGSSGex,
+ "ecdh-sha2-nistp256": KexNistp256,
+ "ecdh-sha2-nistp384": KexNistp384,
+ "ecdh-sha2-nistp521": KexNistp521,
}
_compression_info = {
# zlib@openssh.com is just zlib, but only turned on after a successful
# authentication. openssh servers may only offer this type because
# they've had troubles with security holes in zlib in the past.
- 'zlib@openssh.com': (ZlibCompressor, ZlibDecompressor),
- 'zlib': (ZlibCompressor, ZlibDecompressor),
- 'none': (None, None),
+ "zlib@openssh.com": (ZlibCompressor, ZlibDecompressor),
+ "zlib": (ZlibCompressor, ZlibDecompressor),
+ "none": (None, None),
}
_modulus_pack = None
_active_check_timeout = 0.1
- def __init__(self,
- sock,
- default_window_size=DEFAULT_WINDOW_SIZE,
- default_max_packet_size=DEFAULT_MAX_PACKET_SIZE,
- gss_kex=False,
- gss_deleg_creds=True):
+ def __init__(
+ self,
+ sock,
+ default_window_size=DEFAULT_WINDOW_SIZE,
+ default_max_packet_size=DEFAULT_MAX_PACKET_SIZE,
+ gss_kex=False,
+ gss_deleg_creds=True,
+ ):
"""
Create a new SSH session over an existing socket, or socket-like
object. This only creates the `.Transport` object; it doesn't begin
@@ -302,7 +339,7 @@ class Transport(threading.Thread, ClosingContextManager):
if isinstance(sock, string_types):
# convert "host:port" into (host, port)
- hl = sock.split(':', 1)
+ hl = sock.split(":", 1)
self.hostname = hl[0]
if len(hl) == 1:
sock = (hl[0], 22)
@@ -312,7 +349,7 @@ class Transport(threading.Thread, ClosingContextManager):
# connect to the given (host, port)
hostname, port = sock
self.hostname = hostname
- reason = 'No suitable address family'
+ reason = "No suitable address family"
addrinfos = socket.getaddrinfo(
hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM
)
@@ -329,7 +366,8 @@ class Transport(threading.Thread, ClosingContextManager):
break
else:
raise SSHException(
- 'Unable to connect to {}: {}'.format(hostname, reason))
+ "Unable to connect to {}: {}".format(hostname, reason)
+ )
# okay, normal socket-ish flow here...
threading.Thread.__init__(self)
self.setDaemon(True)
@@ -340,9 +378,9 @@ class Transport(threading.Thread, ClosingContextManager):
# negotiated crypto parameters
self.packetizer = Packetizer(sock)
- self.local_version = 'SSH-' + self._PROTO_ID + '-' + self._CLIENT_ID
- self.remote_version = ''
- self.local_cipher = self.remote_cipher = ''
+ self.local_version = "SSH-" + self._PROTO_ID + "-" + self._CLIENT_ID
+ self.remote_version = ""
+ self.local_cipher = self.remote_cipher = ""
self.local_kex_init = self.remote_kex_init = None
self.local_mac = self.remote_mac = None
self.local_compression = self.remote_compression = None
@@ -374,8 +412,8 @@ class Transport(threading.Thread, ClosingContextManager):
# tracking open channels
self._channels = ChannelMap()
- self.channel_events = {} # (id -> Event)
- self.channels_seen = {} # (id -> True)
+ self.channel_events = {} # (id -> Event)
+ self.channels_seen = {} # (id -> True)
self._channel_counter = 0
self.default_max_packet_size = default_max_packet_size
self.default_window_size = default_window_size
@@ -387,7 +425,7 @@ class Transport(threading.Thread, ClosingContextManager):
self.clear_to_send = threading.Event()
self.clear_to_send_lock = threading.Lock()
self.clear_to_send_timeout = 30.0
- self.log_name = 'paramiko.transport'
+ self.log_name = "paramiko.transport"
self.logger = util.get_logger(self.log_name)
self.packetizer.set_log(self.logger)
self.auth_handler = None
@@ -416,23 +454,24 @@ class Transport(threading.Thread, ClosingContextManager):
Returns a string representation of this object, for debugging.
"""
id_ = hex(long(id(self)) & xffffffff)
- out = '<paramiko.Transport at {}'.format(id_)
+ out = "<paramiko.Transport at {}".format(id_)
if not self.active:
- out += ' (unconnected)'
+ out += " (unconnected)"
else:
- if self.local_cipher != '':
- out += ' (cipher {}, {:d} bits)'.format(
+ if self.local_cipher != "":
+ out += " (cipher {}, {:d} bits)".format(
self.local_cipher,
- self._cipher_info[self.local_cipher]['key-size'] * 8
+ self._cipher_info[self.local_cipher]["key-size"] * 8,
)
if self.is_authenticated():
- out += ' (active; {} open channel(s))'.format(
- len(self._channels))
+ out += " (active; {} open channel(s))".format(
+ len(self._channels)
+ )
elif self.initial_kex_done:
- out += ' (connected; awaiting auth)'
+ out += " (connected; awaiting auth)"
else:
- out += ' (connecting)'
- out += '>'
+ out += " (connecting)"
+ out += ">"
return out
def atfork(self):
@@ -543,10 +582,10 @@ class Transport(threading.Thread, ClosingContextManager):
e = self.get_exception()
if e is not None:
raise e
- raise SSHException('Negotiation failed.')
+ raise SSHException("Negotiation failed.")
if (
- event.is_set() or
- (timeout is not None and time.time() >= max_time)
+ event.is_set()
+ or (timeout is not None and time.time() >= max_time)
):
break
@@ -612,7 +651,7 @@ class Transport(threading.Thread, ClosingContextManager):
e = self.get_exception()
if e is not None:
raise e
- raise SSHException('Negotiation failed.')
+ raise SSHException("Negotiation failed.")
if event.is_set():
break
@@ -679,7 +718,7 @@ class Transport(threading.Thread, ClosingContextManager):
"""
Transport._modulus_pack = ModulusPack()
# places to look for the openssh "moduli" file
- file_list = ['/etc/ssh/moduli', '/usr/local/etc/moduli']
+ file_list = ["/etc/ssh/moduli", "/usr/local/etc/moduli"]
if filename is not None:
file_list.insert(0, filename)
for fn in file_list:
@@ -717,7 +756,7 @@ class Transport(threading.Thread, ClosingContextManager):
:return: public key (`.PKey`) of the remote server
"""
if (not self.active) or (not self.initial_kex_done):
- raise SSHException('No existing session')
+ raise SSHException("No existing session")
return self.host_key
def is_active(self):
@@ -731,10 +770,7 @@ class Transport(threading.Thread, ClosingContextManager):
return self.active
def open_session(
- self,
- window_size=None,
- max_packet_size=None,
- timeout=None,
+ self, window_size=None, max_packet_size=None, timeout=None
):
"""
Request a new channel to the server, of type ``"session"``. This is
@@ -761,10 +797,12 @@ class Transport(threading.Thread, ClosingContextManager):
.. versionchanged:: 1.15
Added the ``window_size`` and ``max_packet_size`` arguments.
"""
- return self.open_channel('session',
- window_size=window_size,
- max_packet_size=max_packet_size,
- timeout=timeout)
+ return self.open_channel(
+ "session",
+ window_size=window_size,
+ max_packet_size=max_packet_size,
+ timeout=timeout,
+ )
def open_x11_channel(self, src_addr=None):
"""
@@ -780,7 +818,7 @@ class Transport(threading.Thread, ClosingContextManager):
`.SSHException` -- if the request is rejected or the session ends
prematurely
"""
- return self.open_channel('x11', src_addr=src_addr)
+ return self.open_channel("x11", src_addr=src_addr)
def open_forward_agent_channel(self):
"""
@@ -794,7 +832,7 @@ class Transport(threading.Thread, ClosingContextManager):
:raises: `.SSHException` --
if the request is rejected or the session ends prematurely
"""
- return self.open_channel('auth-agent@openssh.com')
+ return self.open_channel("auth-agent@openssh.com")
def open_forwarded_tcpip_channel(self, src_addr, dest_addr):
"""
@@ -806,15 +844,17 @@ class Transport(threading.Thread, ClosingContextManager):
:param src_addr: originator's address
:param dest_addr: local (server) connected address
"""
- return self.open_channel('forwarded-tcpip', dest_addr, src_addr)
+ return self.open_channel("forwarded-tcpip", dest_addr, src_addr)
- def open_channel(self,
- kind,
- dest_addr=None,
- src_addr=None,
- window_size=None,
- max_packet_size=None,
- timeout=None):
+ def open_channel(
+ self,
+ kind,
+ dest_addr=None,
+ src_addr=None,
+ window_size=None,
+ max_packet_size=None,
+ timeout=None,
+ ):
"""
Request a new channel to the server. `Channels <.Channel>` are
socket-like objects used for the actual transfer of data across the
@@ -851,7 +891,7 @@ class Transport(threading.Thread, ClosingContextManager):
Added the ``window_size`` and ``max_packet_size`` arguments.
"""
if not self.active:
- raise SSHException('SSH session not active')
+ raise SSHException("SSH session not active")
timeout = 3600 if timeout is None else timeout
self.lock.acquire()
try:
@@ -864,12 +904,12 @@ class Transport(threading.Thread, ClosingContextManager):
m.add_int(chanid)
m.add_int(window_size)
m.add_int(max_packet_size)
- if (kind == 'forwarded-tcpip') or (kind == 'direct-tcpip'):
+ if (kind == "forwarded-tcpip") or (kind == "direct-tcpip"):
m.add_string(dest_addr[0])
m.add_int(dest_addr[1])
m.add_string(src_addr[0])
m.add_int(src_addr[1])
- elif kind == 'x11':
+ elif kind == "x11":
m.add_string(src_addr[0])
m.add_int(src_addr[1])
chan = Channel(chanid)
@@ -887,18 +927,18 @@ class Transport(threading.Thread, ClosingContextManager):
if not self.active:
e = self.get_exception()
if e is None:
- e = SSHException('Unable to open channel.')
+ e = SSHException("Unable to open channel.")
raise e
if event.is_set():
break
elif start_ts + timeout < time.time():
- raise SSHException('Timeout opening channel.')
+ raise SSHException("Timeout opening channel.")
chan = self._channels.get(chanid)
if chan is not None:
return chan
e = self.get_exception()
if e is None:
- e = SSHException('Unable to open channel.')
+ e = SSHException("Unable to open channel.")
raise e
def request_port_forward(self, address, port, handler=None):
@@ -935,20 +975,22 @@ class Transport(threading.Thread, ClosingContextManager):
`.SSHException` -- if the server refused the TCP forward request
"""
if not self.active:
- raise SSHException('SSH session not active')
+ raise SSHException("SSH session not active")
port = int(port)
response = self.global_request(
- 'tcpip-forward', (address, port), wait=True
+ "tcpip-forward", (address, port), wait=True
)
if response is None:
- raise SSHException('TCP forwarding request denied')
+ raise SSHException("TCP forwarding request denied")
if port == 0:
port = response.get_int()
if handler is None:
+
def default_handler(channel, src_addr, dest_addr_port):
# src_addr, src_port = src_addr_port
# dest_addr, dest_port = dest_addr_port
self._queue_incoming_channel(channel)
+
handler = default_handler
self._tcp_handler = handler
return port
@@ -965,7 +1007,7 @@ class Transport(threading.Thread, ClosingContextManager):
if not self.active:
return
self._tcp_handler = None
- self.global_request('cancel-tcpip-forward', (address, port), wait=True)
+ self.global_request("cancel-tcpip-forward", (address, port), wait=True)
def open_sftp_client(self):
"""
@@ -1018,7 +1060,7 @@ class Transport(threading.Thread, ClosingContextManager):
e = self.get_exception()
if e is not None:
raise e
- raise SSHException('Negotiation failed.')
+ raise SSHException("Negotiation failed.")
if self.completion_event.is_set():
break
return
@@ -1034,8 +1076,10 @@ class Transport(threading.Thread, ClosingContextManager):
seconds to wait before sending a keepalive packet (or
0 to disable keepalives).
"""
+
def _request(x=weakref.proxy(self)):
- return x.global_request('keepalive@lag.net', wait=False)
+ return x.global_request("keepalive@lag.net", wait=False)
+
self.packetizer.set_keepalive(interval, _request)
def global_request(self, kind, data=None, wait=True):
@@ -1103,7 +1147,7 @@ class Transport(threading.Thread, ClosingContextManager):
def connect(
self,
hostkey=None,
- username='',
+ username="",
password=None,
pkey=None,
gss_host=None,
@@ -1177,34 +1221,43 @@ class Transport(threading.Thread, ClosingContextManager):
if (hostkey is not None) and not gss_kex:
key = self.get_remote_server_key()
if (
- key.get_name() != hostkey.get_name() or
- key.asbytes() != hostkey.asbytes()
+ key.get_name() != hostkey.get_name()
+ or key.asbytes() != hostkey.asbytes()
):
- self._log(DEBUG, 'Bad host key from server')
- self._log(DEBUG, 'Expected: {}: {}'.format(
- hostkey.get_name(), repr(hostkey.asbytes()),
- ))
- self._log(DEBUG, 'Got : {}: {}'.format(
- key.get_name(), repr(key.asbytes()),
- ))
- raise SSHException('Bad host key from server')
- self._log(DEBUG, 'Host key verified ({})'.format(
- hostkey.get_name()))
+ self._log(DEBUG, "Bad host key from server")
+ self._log(
+ DEBUG,
+ "Expected: {}: {}".format(
+ hostkey.get_name(), repr(hostkey.asbytes())
+ ),
+ )
+ self._log(
+ DEBUG,
+ "Got : {}: {}".format(
+ key.get_name(), repr(key.asbytes())
+ ),
+ )
+ raise SSHException("Bad host key from server")
+ self._log(
+ DEBUG, "Host key verified ({})".format(hostkey.get_name())
+ )
if (pkey is not None) or (password is not None) or gss_auth or gss_kex:
if gss_auth:
- self._log(DEBUG, 'Attempting GSS-API auth... (gssapi-with-mic)') # noqa
+ self._log(
+ DEBUG, "Attempting GSS-API auth... (gssapi-with-mic)"
+ ) # noqa
self.auth_gssapi_with_mic(
- username, self.gss_host, gss_deleg_creds,
+ username, self.gss_host, gss_deleg_creds
)
elif gss_kex:
- self._log(DEBUG, 'Attempting GSS-API auth... (gssapi-keyex)')
+ self._log(DEBUG, "Attempting GSS-API auth... (gssapi-keyex)")
self.auth_gssapi_keyex(username)
elif pkey is not None:
- self._log(DEBUG, 'Attempting public-key auth...')
+ self._log(DEBUG, "Attempting public-key auth...")
self.auth_publickey(username, pkey)
else:
- self._log(DEBUG, 'Attempting password auth...')
+ self._log(DEBUG, "Attempting password auth...")
self.auth_password(username, password)
return
@@ -1259,9 +1312,9 @@ class Transport(threading.Thread, ClosingContextManager):
closed.
"""
return (
- self.active and
- self.auth_handler is not None and
- self.auth_handler.is_authenticated()
+ self.active
+ and self.auth_handler is not None
+ and self.auth_handler.is_authenticated()
)
def get_username(self):
@@ -1311,7 +1364,7 @@ class Transport(threading.Thread, ClosingContextManager):
.. versionadded:: 1.5
"""
if (not self.active) or (not self.initial_kex_done):
- raise SSHException('No existing session')
+ raise SSHException("No existing session")
my_event = threading.Event()
self.auth_handler = AuthHandler(self)
self.auth_handler.auth_none(username, my_event)
@@ -1367,7 +1420,7 @@ class Transport(threading.Thread, ClosingContextManager):
if (not self.active) or (not self.initial_kex_done):
# we should never try to send the password unless we're on a secure
# link
- raise SSHException('No existing session')
+ raise SSHException("No existing session")
if event is None:
my_event = threading.Event()
else:
@@ -1382,12 +1435,13 @@ class Transport(threading.Thread, ClosingContextManager):
except BadAuthenticationType as e:
# if password auth isn't allowed, but keyboard-interactive *is*,
# try to fudge it
- if not fallback or ('keyboard-interactive' not in e.allowed_types):
+ if not fallback or ("keyboard-interactive" not in e.allowed_types):
raise
try:
+
def handler(title, instructions, fields):
if len(fields) > 1:
- raise SSHException('Fallback authentication failed.')
+ raise SSHException("Fallback authentication failed.")
if len(fields) == 0:
# for some reason, at least on os x, a 2nd request will
# be made with zero fields requested. maybe it's just
@@ -1395,6 +1449,7 @@ class Transport(threading.Thread, ClosingContextManager):
# type we're doing here. *shrug* :)
return []
return [password]
+
return self.auth_interactive(username, handler)
except SSHException:
# attempt failed; just raise the original exception
@@ -1437,7 +1492,7 @@ class Transport(threading.Thread, ClosingContextManager):
"""
if (not self.active) or (not self.initial_kex_done):
# we should never try to authenticate unless we're on a secure link
- raise SSHException('No existing session')
+ raise SSHException("No existing session")
if event is None:
my_event = threading.Event()
else:
@@ -1449,7 +1504,7 @@ class Transport(threading.Thread, ClosingContextManager):
return []
return self.auth_handler.wait_for_response(my_event)
- def auth_interactive(self, username, handler, submethods=''):
+ def auth_interactive(self, username, handler, submethods=""):
"""
Authenticate to the server interactively. A handler is used to answer
arbitrary questions from the server. On many servers, this is just a
@@ -1494,7 +1549,7 @@ class Transport(threading.Thread, ClosingContextManager):
"""
if (not self.active) or (not self.initial_kex_done):
# we should never try to authenticate unless we're on a secure link
- raise SSHException('No existing session')
+ raise SSHException("No existing session")
my_event = threading.Event()
self.auth_handler = AuthHandler(self)
self.auth_handler.auth_interactive(
@@ -1502,7 +1557,7 @@ class Transport(threading.Thread, ClosingContextManager):
)
return self.auth_handler.wait_for_response(my_event)
- def auth_interactive_dumb(self, username, handler=None, submethods=''):
+ def auth_interactive_dumb(self, username, handler=None, submethods=""):
"""
Autenticate to the server interactively but dumber.
Just print the prompt and / or instructions to stdout and send back
@@ -1511,6 +1566,7 @@ class Transport(threading.Thread, ClosingContextManager):
"""
if not handler:
+
def handler(title, instructions, prompt_list):
answers = []
if title:
@@ -1518,9 +1574,10 @@ class Transport(threading.Thread, ClosingContextManager):
if instructions:
print(instructions.strip())
for prompt, show_input in prompt_list:
- print(prompt.strip(), end=' ')
+ print(prompt.strip(), end=" ")
answers.append(input())
return answers
+
return self.auth_interactive(username, handler, submethods)
def auth_gssapi_with_mic(self, username, gss_host, gss_deleg_creds):
@@ -1541,7 +1598,7 @@ class Transport(threading.Thread, ClosingContextManager):
"""
if (not self.active) or (not self.initial_kex_done):
# we should never try to authenticate unless we're on a secure link
- raise SSHException('No existing session')
+ raise SSHException("No existing session")
my_event = threading.Event()
self.auth_handler = AuthHandler(self)
self.auth_handler.auth_gssapi_with_mic(
@@ -1566,7 +1623,7 @@ class Transport(threading.Thread, ClosingContextManager):
"""
if (not self.active) or (not self.initial_kex_done):
# we should never try to authenticate unless we're on a secure link
- raise SSHException('No existing session')
+ raise SSHException("No existing session")
my_event = threading.Event()
self.auth_handler = AuthHandler(self)
self.auth_handler.auth_gssapi_keyex(username, my_event)
@@ -1633,9 +1690,9 @@ class Transport(threading.Thread, ClosingContextManager):
.. versionadded:: 1.5.2
"""
if compress:
- self._preferred_compression = ('zlib@openssh.com', 'zlib', 'none')
+ self._preferred_compression = ("zlib@openssh.com", "zlib", "none")
else:
- self._preferred_compression = ('none',)
+ self._preferred_compression = ("none",)
def getpeername(self):
"""
@@ -1649,9 +1706,9 @@ class Transport(threading.Thread, ClosingContextManager):
the address of the remote host, if known, as a ``(str, int)``
tuple.
"""
- gp = getattr(self.sock, 'getpeername', None)
+ gp = getattr(self.sock, "getpeername", None)
if gp is None:
- return 'unknown', 0
+ return "unknown", 0
return gp()
def stop_thread(self):
@@ -1670,10 +1727,10 @@ class Transport(threading.Thread, ClosingContextManager):
# our socket and packetizer are both closed (but where we'd
# otherwise be sitting forever on that recv()).
while (
- self.is_alive() and
- self is not threading.current_thread() and
- not self.sock._closed and
- not self.packetizer.closed
+ self.is_alive()
+ and self is not threading.current_thread()
+ and not self.sock._closed
+ and not self.packetizer.closed
):
self.join(0.1)
@@ -1715,14 +1772,18 @@ class Transport(threading.Thread, ClosingContextManager):
while True:
self.clear_to_send.wait(0.1)
if not self.active:
- self._log(DEBUG, 'Dropping user packet because connection is dead.') # noqa
+ self._log(
+ DEBUG, "Dropping user packet because connection is dead."
+ ) # noqa
return
self.clear_to_send_lock.acquire()
if self.clear_to_send.is_set():
break
self.clear_to_send_lock.release()
if time.time() > start + self.clear_to_send_timeout:
- raise SSHException('Key-exchange timed out waiting for key negotiation') # noqa
+ raise SSHException(
+ "Key-exchange timed out waiting for key negotiation"
+ ) # noqa
try:
self._send_message(data)
finally:
@@ -1746,9 +1807,13 @@ class Transport(threading.Thread, ClosingContextManager):
def _verify_key(self, host_key, sig):
key = self._key_info[self.host_key_type](Message(host_key))
if key is None:
- raise SSHException('Unknown host key type')
+ raise SSHException("Unknown host key type")
if not key.verify_ssh_sig(self.H, Message(sig)):
- raise SSHException('Signature verification ({}) failed.'.format(self.host_key_type)) # noqa
+ raise SSHException(
+ "Signature verification ({}) failed.".format(
+ self.host_key_type
+ )
+ ) # noqa
self.host_key = key
def _compute_key(self, id, nbytes):
@@ -1760,16 +1825,16 @@ class Transport(threading.Thread, ClosingContextManager):
m.add_bytes(self.session_id)
# Fallback to SHA1 for kex engines that fail to specify a hex
# algorithm, or for e.g. transport tests that don't run kexinit.
- hash_algo = getattr(self.kex_engine, 'hash_algo', None)
+ hash_algo = getattr(self.kex_engine, "hash_algo", None)
hash_select_msg = "kex engine {} specified hash_algo {!r}".format(
- self.kex_engine.__class__.__name__, hash_algo,
+ self.kex_engine.__class__.__name__, hash_algo
)
if hash_algo is None:
hash_algo = sha1
hash_select_msg += ", falling back to sha1"
- if not hasattr(self, '_logged_hash_selection'):
+ if not hasattr(self, "_logged_hash_selection"):
self._log(DEBUG, hash_select_msg)
- setattr(self, '_logged_hash_selection', True)
+ setattr(self, "_logged_hash_selection", True)
out = sofar = hash_algo(m.asbytes()).digest()
while len(out) < nbytes:
m = Message()
@@ -1783,11 +1848,11 @@ class Transport(threading.Thread, ClosingContextManager):
def _get_cipher(self, name, key, iv, operation):
if name not in self._cipher_info:
- raise SSHException('Unknown client cipher ' + name)
+ raise SSHException("Unknown client cipher " + name)
else:
cipher = Cipher(
- self._cipher_info[name]['class'](key),
- self._cipher_info[name]['mode'](iv),
+ self._cipher_info[name]["class"](key),
+ self._cipher_info[name]["mode"](iv),
backend=default_backend(),
)
if operation is self._ENCRYPT:
@@ -1797,8 +1862,10 @@ class Transport(threading.Thread, ClosingContextManager):
def _set_forward_agent_handler(self, handler):
if handler is None:
+
def default_handler(channel):
self._queue_incoming_channel(channel)
+
self._forward_agent_handler = default_handler
else:
self._forward_agent_handler = handler
@@ -1809,6 +1876,7 @@ class Transport(threading.Thread, ClosingContextManager):
# by default, use the same mechanism as accept()
def default_handler(channel, src_addr_port):
self._queue_incoming_channel(channel)
+
self._x11_handler = default_handler
else:
self._x11_handler = handler
@@ -1842,9 +1910,9 @@ class Transport(threading.Thread, ClosingContextManager):
Otherwise (client mode, authed, or pre-auth message) returns None.
"""
if (
- not self.server_mode or
- ptype <= HIGHEST_USERAUTH_MESSAGE_ID or
- self.is_authenticated()
+ not self.server_mode
+ or ptype <= HIGHEST_USERAUTH_MESSAGE_ID
+ or self.is_authenticated()
):
return None
# WELP. We must be dealing with someone trying to do non-auth things
@@ -1855,13 +1923,13 @@ class Transport(threading.Thread, ClosingContextManager):
reply.add_byte(cMSG_REQUEST_FAILURE)
# Channel opens let us reject w/ a specific type + message.
elif ptype == MSG_CHANNEL_OPEN:
- kind = message.get_text() # noqa
+ kind = message.get_text() # noqa
chanid = message.get_int()
reply.add_byte(cMSG_CHANNEL_OPEN_FAILURE)
reply.add_int(chanid)
reply.add_int(OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED)
- reply.add_string('')
- reply.add_string('en')
+ reply.add_string("")
+ reply.add_string("en")
# NOTE: Post-open channel messages do not need checking; the above will
# reject attemps to open channels, meaning that even if a malicious
# user tries to send a MSG_CHANNEL_REQUEST, it will simply fall under
@@ -1883,13 +1951,16 @@ class Transport(threading.Thread, ClosingContextManager):
_active_threads.append(self)
tid = hex(long(id(self)) & xffffffff)
if self.server_mode:
- self._log(DEBUG, 'starting thread (server mode): {}'.format(tid))
+ self._log(DEBUG, "starting thread (server mode): {}".format(tid))
else:
- self._log(DEBUG, 'starting thread (client mode): {}'.format(tid))
+ self._log(DEBUG, "starting thread (client mode): {}".format(tid))
try:
try:
- self.packetizer.write_all(b(self.local_version + '\r\n'))
- self._log(DEBUG, 'Local version/idstring: {}'.format(self.local_version)) # noqa
+ self.packetizer.write_all(b(self.local_version + "\r\n"))
+ self._log(
+ DEBUG,
+ "Local version/idstring: {}".format(self.local_version),
+ ) # noqa
self._check_banner()
# The above is actually very much part of the handshake, but
# sometimes the banner can be read but the machine is not
@@ -1919,7 +1990,11 @@ class Transport(threading.Thread, ClosingContextManager):
continue
if len(self._expected_packet) > 0:
if ptype not in self._expected_packet:
- raise SSHException('Expecting packet from {!r}, got {:d}'.format(self._expected_packet, ptype)) # noqa
+ raise SSHException(
+ "Expecting packet from {!r}, got {:d}".format(
+ self._expected_packet, ptype
+ )
+ ) # noqa
self._expected_packet = tuple()
if (ptype >= 30) and (ptype <= 41):
self.kex_engine.parse_next(ptype, m)
@@ -1937,20 +2012,30 @@ class Transport(threading.Thread, ClosingContextManager):
if chan is not None:
self._channel_handler_table[ptype](chan, m)
elif chanid in self.channels_seen:
- self._log(DEBUG, 'Ignoring message for dead channel {:d}'.format(chanid)) # noqa
+ self._log(
+ DEBUG,
+ "Ignoring message for dead channel {:d}".format(
+ chanid
+ ),
+ ) # noqa
else:
- self._log(ERROR, 'Channel request for unknown channel {:d}'.format(chanid)) # noqa
+ self._log(
+ ERROR,
+ "Channel request for unknown channel {:d}".format(
+ chanid
+ ),
+ ) # noqa
break
elif (
- self.auth_handler is not None and
- ptype in self.auth_handler._handler_table
+ self.auth_handler is not None
+ and ptype in self.auth_handler._handler_table
):
handler = self.auth_handler._handler_table[ptype]
handler(self.auth_handler, m)
if len(self._expected_packet) > 0:
continue
else:
- err = 'Oops, unhandled type {:d}'.format(ptype)
+ err = "Oops, unhandled type {:d}".format(ptype)
self._log(WARNING, err)
msg = Message()
msg.add_byte(cMSG_UNIMPLEMENTED)
@@ -1958,24 +2043,24 @@ class Transport(threading.Thread, ClosingContextManager):
self._send_message(msg)
self.packetizer.complete_handshake()
except SSHException as e:
- self._log(ERROR, 'Exception: ' + str(e))
+ self._log(ERROR, "Exception: " + str(e))
self._log(ERROR, util.tb_strings())
self.saved_exception = e
except EOFError as e:
- self._log(DEBUG, 'EOF in transport thread')
+ self._log(DEBUG, "EOF in transport thread")
self.saved_exception = e
except socket.error as e:
if type(e.args) is tuple:
if e.args:
- emsg = '{} ({:d})'.format(e.args[1], e.args[0])
+ emsg = "{} ({:d})".format(e.args[1], e.args[0])
else: # empty tuple, e.g. socket.timeout
emsg = str(e) or repr(e)
else:
emsg = e.args
- self._log(ERROR, 'Socket exception: ' + emsg)
+ self._log(ERROR, "Socket exception: " + emsg)
self.saved_exception = e
except Exception as e:
- self._log(ERROR, 'Unknown exception: ' + str(e))
+ self._log(ERROR, "Unknown exception: " + str(e))
self._log(ERROR, util.tb_strings())
self.saved_exception = e
_active_threads.remove(self)
@@ -2004,7 +2089,6 @@ class Transport(threading.Thread, ClosingContextManager):
if self.sys.modules is not None:
raise
-
def _log_agreement(self, which, local, remote):
# Log useful, non-duplicative line re: an agreed-upon algorithm.
# Old code implied algorithms could be asymmetrical (different for
@@ -2046,32 +2130,32 @@ class Transport(threading.Thread, ClosingContextManager):
raise
except Exception as e:
raise SSHException(
- 'Error reading SSH protocol banner' + str(e)
+ "Error reading SSH protocol banner" + str(e)
)
- if buf[:4] == 'SSH-':
+ if buf[:4] == "SSH-":
break
- self._log(DEBUG, 'Banner: ' + buf)
- if buf[:4] != 'SSH-':
+ self._log(DEBUG, "Banner: " + buf)
+ if buf[:4] != "SSH-":
raise SSHException('Indecipherable protocol version "' + buf + '"')
# save this server version string for later
self.remote_version = buf
- self._log(DEBUG, 'Remote version/idstring: {}'.format(buf))
+ self._log(DEBUG, "Remote version/idstring: {}".format(buf))
# pull off any attached comment
# NOTE: comment used to be stored in a variable and then...never used.
# since 2003. ca 877cd974b8182d26fa76d566072917ea67b64e67
- i = buf.find(' ')
+ i = buf.find(" ")
if i >= 0:
buf = buf[:i]
# parse out version string and make sure it matches
- segs = buf.split('-', 2)
+ segs = buf.split("-", 2)
if len(segs) < 3:
- raise SSHException('Invalid SSH banner')
+ raise SSHException("Invalid SSH banner")
version = segs[1]
client = segs[2]
- if version != '1.99' and version != '2.0':
- msg = 'Incompatible version ({} instead of 2.0)'
+ if version != "1.99" and version != "2.0":
+ msg = "Incompatible version ({} instead of 2.0)"
raise SSHException(msg.format(version))
- msg = 'Connected (version {}, client {})'.format(version, client)
+ msg = "Connected (version {}, client {})".format(version, client)
self._log(INFO, msg)
def _send_kex_init(self):
@@ -2087,25 +2171,27 @@ class Transport(threading.Thread, ClosingContextManager):
self.gss_kex_used = False
self.in_kex = True
if self.server_mode:
- mp_required_prefix = 'diffie-hellman-group-exchange-sha'
+ mp_required_prefix = "diffie-hellman-group-exchange-sha"
kex_mp = [
- k for k
- in self._preferred_kex
+ k
+ for k in self._preferred_kex
if k.startswith(mp_required_prefix)
]
if (self._modulus_pack is None) and (len(kex_mp) > 0):
# can't do group-exchange if we don't have a pack of potential
# primes
pkex = [
- k for k
- in self.get_security_options().kex
+ k
+ for k in self.get_security_options().kex
if not k.startswith(mp_required_prefix)
]
self.get_security_options().kex = pkex
- available_server_keys = list(filter(
- list(self.server_key_dict.keys()).__contains__,
- self._preferred_keys
- ))
+ available_server_keys = list(
+ filter(
+ list(self.server_key_dict.keys()).__contains__,
+ self._preferred_keys,
+ )
+ )
else:
available_server_keys = self._preferred_keys
@@ -2129,7 +2215,7 @@ class Transport(threading.Thread, ClosingContextManager):
self._send_message(m)
def _parse_kex_init(self, m):
- m.get_bytes(16) # cookie, discarded
+ m.get_bytes(16) # cookie, discarded
kex_algo_list = m.get_list()
server_key_algo_list = m.get_list()
client_encrypt_algo_list = m.get_list()
@@ -2141,20 +2227,32 @@ class Transport(threading.Thread, ClosingContextManager):
client_lang_list = m.get_list()
server_lang_list = m.get_list()
kex_follows = m.get_boolean()
- m.get_int() # unused
-
- self._log(DEBUG,
- 'kex algos:' + str(kex_algo_list) +
- ' server key:' + str(server_key_algo_list) +
- ' client encrypt:' + str(client_encrypt_algo_list) +
- ' server encrypt:' + str(server_encrypt_algo_list) +
- ' client mac:' + str(client_mac_algo_list) +
- ' server mac:' + str(server_mac_algo_list) +
- ' client compress:' + str(client_compress_algo_list) +
- ' server compress:' + str(server_compress_algo_list) +
- ' client lang:' + str(client_lang_list) +
- ' server lang:' + str(server_lang_list) +
- ' kex follows?' + str(kex_follows)
+ m.get_int() # unused
+
+ self._log(
+ DEBUG,
+ "kex algos:"
+ + str(kex_algo_list)
+ + " server key:"
+ + str(server_key_algo_list)
+ + " client encrypt:"
+ + str(client_encrypt_algo_list)
+ + " server encrypt:"
+ + str(server_encrypt_algo_list)
+ + " client mac:"
+ + str(client_mac_algo_list)
+ + " server mac:"
+ + str(server_mac_algo_list)
+ + " client compress:"
+ + str(client_compress_algo_list)
+ + " server compress:"
+ + str(server_compress_algo_list)
+ + " client lang:"
+ + str(client_lang_list)
+ + " server lang:"
+ + str(server_lang_list)
+ + " kex follows?"
+ + str(kex_follows),
)
# as a server, we pick the first item in the client's list that we
@@ -2162,122 +2260,149 @@ class Transport(threading.Thread, ClosingContextManager):
# as a client, we pick the first item in our list that the server
# supports.
if self.server_mode:
- agreed_kex = list(filter(
- self._preferred_kex.__contains__,
- kex_algo_list
- ))
+ agreed_kex = list(
+ filter(self._preferred_kex.__contains__, kex_algo_list)
+ )
else:
- agreed_kex = list(filter(
- kex_algo_list.__contains__,
- self._preferred_kex
- ))
+ agreed_kex = list(
+ filter(kex_algo_list.__contains__, self._preferred_kex)
+ )
if len(agreed_kex) == 0:
- raise SSHException('Incompatible ssh peer (no acceptable kex algorithm)') # noqa
+ raise SSHException(
+ "Incompatible ssh peer (no acceptable kex algorithm)"
+ ) # noqa
self.kex_engine = self._kex_info[agreed_kex[0]](self)
self._log(DEBUG, "Kex agreed: {}".format(agreed_kex[0]))
if self.server_mode:
- available_server_keys = list(filter(
- list(self.server_key_dict.keys()).__contains__,
- self._preferred_keys
- ))
- agreed_keys = list(filter(
- available_server_keys.__contains__, server_key_algo_list
- ))
+ available_server_keys = list(
+ filter(
+ list(self.server_key_dict.keys()).__contains__,
+ self._preferred_keys,
+ )
+ )
+ agreed_keys = list(
+ filter(
+ available_server_keys.__contains__, server_key_algo_list
+ )
+ )
else:
- agreed_keys = list(filter(
- server_key_algo_list.__contains__, self._preferred_keys
- ))
+ agreed_keys = list(
+ filter(server_key_algo_list.__contains__, self._preferred_keys)
+ )
if len(agreed_keys) == 0:
- raise SSHException('Incompatible ssh peer (no acceptable host key)') # noqa
+ raise SSHException(
+ "Incompatible ssh peer (no acceptable host key)"
+ ) # noqa
self.host_key_type = agreed_keys[0]
if self.server_mode and (self.get_server_key() is None):
- raise SSHException('Incompatible ssh peer (can\'t match requested host key type)') # noqa
- self._log_agreement(
- 'HostKey', agreed_keys[0], agreed_keys[0]
- )
+ raise SSHException(
+ "Incompatible ssh peer (can't match requested host key type)"
+ ) # noqa
+ self._log_agreement("HostKey", agreed_keys[0], agreed_keys[0])
if self.server_mode:
- agreed_local_ciphers = list(filter(
- self._preferred_ciphers.__contains__,
- server_encrypt_algo_list
- ))
- agreed_remote_ciphers = list(filter(
- self._preferred_ciphers.__contains__,
- client_encrypt_algo_list
- ))
+ agreed_local_ciphers = list(
+ filter(
+ self._preferred_ciphers.__contains__,
+ server_encrypt_algo_list,
+ )
+ )
+ agreed_remote_ciphers = list(
+ filter(
+ self._preferred_ciphers.__contains__,
+ client_encrypt_algo_list,
+ )
+ )
else:
- agreed_local_ciphers = list(filter(
- client_encrypt_algo_list.__contains__,
- self._preferred_ciphers
- ))
- agreed_remote_ciphers = list(filter(
- server_encrypt_algo_list.__contains__,
- self._preferred_ciphers
- ))
+ agreed_local_ciphers = list(
+ filter(
+ client_encrypt_algo_list.__contains__,
+ self._preferred_ciphers,
+ )
+ )
+ agreed_remote_ciphers = list(
+ filter(
+ server_encrypt_algo_list.__contains__,
+ self._preferred_ciphers,
+ )
+ )
if len(agreed_local_ciphers) == 0 or len(agreed_remote_ciphers) == 0:
- raise SSHException('Incompatible ssh server (no acceptable ciphers)') # noqa
+ raise SSHException(
+ "Incompatible ssh server (no acceptable ciphers)"
+ ) # noqa
self.local_cipher = agreed_local_ciphers[0]
self.remote_cipher = agreed_remote_ciphers[0]
self._log_agreement(
- 'Cipher', local=self.local_cipher, remote=self.remote_cipher
+ "Cipher", local=self.local_cipher, remote=self.remote_cipher
)
if self.server_mode:
- agreed_remote_macs = list(filter(
- self._preferred_macs.__contains__, client_mac_algo_list
- ))
- agreed_local_macs = list(filter(
- self._preferred_macs.__contains__, server_mac_algo_list
- ))
+ agreed_remote_macs = list(
+ filter(self._preferred_macs.__contains__, client_mac_algo_list)
+ )
+ agreed_local_macs = list(
+ filter(self._preferred_macs.__contains__, server_mac_algo_list)
+ )
else:
- agreed_local_macs = list(filter(
- client_mac_algo_list.__contains__, self._preferred_macs
- ))
- agreed_remote_macs = list(filter(
- server_mac_algo_list.__contains__, self._preferred_macs
- ))
+ agreed_local_macs = list(
+ filter(client_mac_algo_list.__contains__, self._preferred_macs)
+ )
+ agreed_remote_macs = list(
+ filter(server_mac_algo_list.__contains__, self._preferred_macs)
+ )
if (len(agreed_local_macs) == 0) or (len(agreed_remote_macs) == 0):
- raise SSHException('Incompatible ssh server (no acceptable macs)')
+ raise SSHException("Incompatible ssh server (no acceptable macs)")
self.local_mac = agreed_local_macs[0]
self.remote_mac = agreed_remote_macs[0]
self._log_agreement(
- 'MAC', local=self.local_mac, remote=self.remote_mac
+ "MAC", local=self.local_mac, remote=self.remote_mac
)
if self.server_mode:
- agreed_remote_compression = list(filter(
- self._preferred_compression.__contains__,
- client_compress_algo_list
- ))
- agreed_local_compression = list(filter(
- self._preferred_compression.__contains__,
- server_compress_algo_list
- ))
+ agreed_remote_compression = list(
+ filter(
+ self._preferred_compression.__contains__,
+ client_compress_algo_list,
+ )
+ )
+ agreed_local_compression = list(
+ filter(
+ self._preferred_compression.__contains__,
+ server_compress_algo_list,
+ )
+ )
else:
- agreed_local_compression = list(filter(
- client_compress_algo_list.__contains__,
- self._preferred_compression
- ))
- agreed_remote_compression = list(filter(
- server_compress_algo_list.__contains__,
- self._preferred_compression
- ))
+ agreed_local_compression = list(
+ filter(
+ client_compress_algo_list.__contains__,
+ self._preferred_compression,
+ )
+ )
+ agreed_remote_compression = list(
+ filter(
+ server_compress_algo_list.__contains__,
+ self._preferred_compression,
+ )
+ )
if (
- len(agreed_local_compression) == 0 or
- len(agreed_remote_compression) == 0
+ len(agreed_local_compression) == 0
+ or len(agreed_remote_compression) == 0
):
- msg = 'Incompatible ssh server (no acceptable compression) {!r} {!r} {!r}' # noqa
- raise SSHException(msg.format(
- agreed_local_compression, agreed_remote_compression,
- self._preferred_compression,
- ))
+ msg = "Incompatible ssh server (no acceptable compression) {!r} {!r} {!r}" # noqa
+ raise SSHException(
+ msg.format(
+ agreed_local_compression,
+ agreed_remote_compression,
+ self._preferred_compression,
+ )
+ )
self.local_compression = agreed_local_compression[0]
self.remote_compression = agreed_remote_compression[0]
self._log_agreement(
- 'Compression',
+ "Compression",
local=self.local_compression,
- remote=self.remote_compression
+ remote=self.remote_compression,
)
# save for computing hash later...
@@ -2290,40 +2415,40 @@ class Transport(threading.Thread, ClosingContextManager):
def _activate_inbound(self):
"""switch on newly negotiated encryption parameters for
inbound traffic"""
- block_size = self._cipher_info[self.remote_cipher]['block-size']
+ block_size = self._cipher_info[self.remote_cipher]["block-size"]
if self.server_mode:
- IV_in = self._compute_key('A', block_size)
+ IV_in = self._compute_key("A", block_size)
key_in = self._compute_key(
- 'C', self._cipher_info[self.remote_cipher]['key-size']
+ "C", self._cipher_info[self.remote_cipher]["key-size"]
)
else:
- IV_in = self._compute_key('B', block_size)
+ IV_in = self._compute_key("B", block_size)
key_in = self._compute_key(
- 'D', self._cipher_info[self.remote_cipher]['key-size']
+ "D", self._cipher_info[self.remote_cipher]["key-size"]
)
engine = self._get_cipher(
self.remote_cipher, key_in, IV_in, self._DECRYPT
)
- mac_size = self._mac_info[self.remote_mac]['size']
- mac_engine = self._mac_info[self.remote_mac]['class']
+ mac_size = self._mac_info[self.remote_mac]["size"]
+ mac_engine = self._mac_info[self.remote_mac]["class"]
# initial mac keys are done in the hash's natural size (not the
# potentially truncated transmission size)
if self.server_mode:
- mac_key = self._compute_key('E', mac_engine().digest_size)
+ mac_key = self._compute_key("E", mac_engine().digest_size)
else:
- mac_key = self._compute_key('F', mac_engine().digest_size)
+ mac_key = self._compute_key("F", mac_engine().digest_size)
self.packetizer.set_inbound_cipher(
engine, block_size, mac_engine, mac_size, mac_key
)
compress_in = self._compression_info[self.remote_compression][1]
if (
- compress_in is not None and
- (
- self.remote_compression != 'zlib@openssh.com' or
- self.authenticated
+ compress_in is not None
+ and (
+ self.remote_compression != "zlib@openssh.com"
+ or self.authenticated
)
):
- self._log(DEBUG, 'Switching on inbound compression ...')
+ self._log(DEBUG, "Switching on inbound compression ...")
self.packetizer.set_inbound_compressor(compress_in())
def _activate_outbound(self):
@@ -2332,37 +2457,41 @@ class Transport(threading.Thread, ClosingContextManager):
m = Message()
m.add_byte(cMSG_NEWKEYS)
self._send_message(m)
- block_size = self._cipher_info[self.local_cipher]['block-size']
+ block_size = self._cipher_info[self.local_cipher]["block-size"]
if self.server_mode:
- IV_out = self._compute_key('B', block_size)
+ IV_out = self._compute_key("B", block_size)
key_out = self._compute_key(
- 'D', self._cipher_info[self.local_cipher]['key-size'])
+ "D", self._cipher_info[self.local_cipher]["key-size"]
+ )
else:
- IV_out = self._compute_key('A', block_size)
+ IV_out = self._compute_key("A", block_size)
key_out = self._compute_key(
- 'C', self._cipher_info[self.local_cipher]['key-size'])
+ "C", self._cipher_info[self.local_cipher]["key-size"]
+ )
engine = self._get_cipher(
- self.local_cipher, key_out, IV_out, self._ENCRYPT)
- mac_size = self._mac_info[self.local_mac]['size']
- mac_engine = self._mac_info[self.local_mac]['class']
+ self.local_cipher, key_out, IV_out, self._ENCRYPT
+ )
+ mac_size = self._mac_info[self.local_mac]["size"]
+ mac_engine = self._mac_info[self.local_mac]["class"]
# initial mac keys are done in the hash's natural size (not the
# potentially truncated transmission size)
if self.server_mode:
- mac_key = self._compute_key('F', mac_engine().digest_size)
+ mac_key = self._compute_key("F", mac_engine().digest_size)
else:
- mac_key = self._compute_key('E', mac_engine().digest_size)
- sdctr = self.local_cipher.endswith('-ctr')
+ mac_key = self._compute_key("E", mac_engine().digest_size)
+ sdctr = self.local_cipher.endswith("-ctr")
self.packetizer.set_outbound_cipher(
- engine, block_size, mac_engine, mac_size, mac_key, sdctr)
+ engine, block_size, mac_engine, mac_size, mac_key, sdctr
+ )
compress_out = self._compression_info[self.local_compression][0]
if (
- compress_out is not None and
- (
- self.local_compression != 'zlib@openssh.com' or
- self.authenticated
+ compress_out is not None
+ and (
+ self.local_compression != "zlib@openssh.com"
+ or self.authenticated
)
):
- self._log(DEBUG, 'Switching on outbound compression ...')
+ self._log(DEBUG, "Switching on outbound compression ...")
self.packetizer.set_outbound_compressor(compress_out())
if not self.packetizer.need_rekey():
self.in_kex = False
@@ -2372,17 +2501,17 @@ class Transport(threading.Thread, ClosingContextManager):
def _auth_trigger(self):
self.authenticated = True
# delayed initiation of compression
- if self.local_compression == 'zlib@openssh.com':
+ if self.local_compression == "zlib@openssh.com":
compress_out = self._compression_info[self.local_compression][0]
- self._log(DEBUG, 'Switching on outbound compression ...')
+ self._log(DEBUG, "Switching on outbound compression ...")
self.packetizer.set_outbound_compressor(compress_out())
- if self.remote_compression == 'zlib@openssh.com':
+ if self.remote_compression == "zlib@openssh.com":
compress_in = self._compression_info[self.remote_compression][1]
- self._log(DEBUG, 'Switching on inbound compression ...')
+ self._log(DEBUG, "Switching on inbound compression ...")
self.packetizer.set_inbound_compressor(compress_in())
def _parse_newkeys(self, m):
- self._log(DEBUG, 'Switch to new keys ...')
+ self._log(DEBUG, "Switch to new keys ...")
self._activate_inbound()
# can also free a bunch of stuff here
self.local_kex_init = self.remote_kex_init = None
@@ -2410,7 +2539,7 @@ class Transport(threading.Thread, ClosingContextManager):
def _parse_disconnect(self, m):
code = m.get_int()
desc = m.get_text()
- self._log(INFO, 'Disconnect (code {:d}): {}'.format(code, desc))
+ self._log(INFO, "Disconnect (code {:d}): {}".format(code, desc))
def _parse_global_request(self, m):
kind = m.get_text()
@@ -2419,16 +2548,16 @@ class Transport(threading.Thread, ClosingContextManager):
if not self.server_mode:
self._log(
DEBUG,
- 'Rejecting "{}" global request from server.'.format(kind)
+ 'Rejecting "{}" global request from server.'.format(kind),
)
ok = False
- elif kind == 'tcpip-forward':
+ elif kind == "tcpip-forward":
address = m.get_text()
port = m.get_int()
ok = self.server_object.check_port_forward_request(address, port)
if ok:
ok = (ok,)
- elif kind == 'cancel-tcpip-forward':
+ elif kind == "cancel-tcpip-forward":
address = m.get_text()
port = m.get_int()
self.server_object.cancel_port_forward_request(address, port)
@@ -2449,13 +2578,13 @@ class Transport(threading.Thread, ClosingContextManager):
self._send_message(msg)
def _parse_request_success(self, m):
- self._log(DEBUG, 'Global request successful.')
+ self._log(DEBUG, "Global request successful.")
self.global_response = m
if self.completion_event is not None:
self.completion_event.set()
def _parse_request_failure(self, m):
- self._log(DEBUG, 'Global request denied.')
+ self._log(DEBUG, "Global request denied.")
self.global_response = None
if self.completion_event is not None:
self.completion_event.set()
@@ -2467,13 +2596,14 @@ class Transport(threading.Thread, ClosingContextManager):
server_max_packet_size = m.get_int()
chan = self._channels.get(chanid)
if chan is None:
- self._log(WARNING, 'Success for unrequested channel! [??]')
+ self._log(WARNING, "Success for unrequested channel! [??]")
return
self.lock.acquire()
try:
chan._set_remote_channel(
- server_chanid, server_window_size, server_max_packet_size)
- self._log(DEBUG, 'Secsh channel {:d} opened.'.format(chanid))
+ server_chanid, server_window_size, server_max_packet_size
+ )
+ self._log(DEBUG, "Secsh channel {:d} opened.".format(chanid))
if chanid in self.channel_events:
self.channel_events[chanid].set()
del self.channel_events[chanid]
@@ -2486,12 +2616,12 @@ class Transport(threading.Thread, ClosingContextManager):
reason = m.get_int()
reason_str = m.get_text()
m.get_text() # ignored language
- reason_text = CONNECTION_FAILED_CODE.get(reason, '(unknown code)')
+ reason_text = CONNECTION_FAILED_CODE.get(reason, "(unknown code)")
self._log(
ERROR,
- 'Secsh channel {:d} open FAILED: {}: {}'.format(
- chanid, reason_str, reason_text,
- )
+ "Secsh channel {:d} open FAILED: {}: {}".format(
+ chanid, reason_str, reason_text
+ ),
)
self.lock.acquire()
try:
@@ -2512,39 +2642,39 @@ class Transport(threading.Thread, ClosingContextManager):
max_packet_size = m.get_int()
reject = False
if (
- kind == 'auth-agent@openssh.com' and
- self._forward_agent_handler is not None
+ kind == "auth-agent@openssh.com"
+ and self._forward_agent_handler is not None
):
- self._log(DEBUG, 'Incoming forward agent connection')
+ self._log(DEBUG, "Incoming forward agent connection")
self.lock.acquire()
try:
my_chanid = self._next_channel()
finally:
self.lock.release()
- elif (kind == 'x11') and (self._x11_handler is not None):
+ elif (kind == "x11") and (self._x11_handler is not None):
origin_addr = m.get_text()
origin_port = m.get_int()
self._log(
DEBUG,
- 'Incoming x11 connection from {}:{:d}'.format(
- origin_addr, origin_port,
- )
+ "Incoming x11 connection from {}:{:d}".format(
+ origin_addr, origin_port
+ ),
)
self.lock.acquire()
try:
my_chanid = self._next_channel()
finally:
self.lock.release()
- elif (kind == 'forwarded-tcpip') and (self._tcp_handler is not None):
+ elif (kind == "forwarded-tcpip") and (self._tcp_handler is not None):
server_addr = m.get_text()
server_port = m.get_int()
origin_addr = m.get_text()
origin_port = m.get_int()
self._log(
DEBUG,
- 'Incoming tcp forwarded connection from {}:{:d}'.format(
- origin_addr, origin_port,
- )
+ "Incoming tcp forwarded connection from {}:{:d}".format(
+ origin_addr, origin_port
+ ),
)
self.lock.acquire()
try:
@@ -2554,7 +2684,8 @@ class Transport(threading.Thread, ClosingContextManager):
elif not self.server_mode:
self._log(
DEBUG,
- 'Rejecting "{}" channel request from server.'.format(kind))
+ 'Rejecting "{}" channel request from server.'.format(kind),
+ )
reject = True
reason = OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
else:
@@ -2563,7 +2694,7 @@ class Transport(threading.Thread, ClosingContextManager):
my_chanid = self._next_channel()
finally:
self.lock.release()
- if kind == 'direct-tcpip':
+ if kind == "direct-tcpip":
# handle direct-tcpip requests coming from the client
dest_addr = m.get_text()
dest_port = m.get_int()
@@ -2572,23 +2703,25 @@ class Transport(threading.Thread, ClosingContextManager):
reason = self.server_object.check_channel_direct_tcpip_request(
my_chanid,
(origin_addr, origin_port),
- (dest_addr, dest_port)
+ (dest_addr, dest_port),
)
else:
reason = self.server_object.check_channel_request(
- kind, my_chanid)
+ kind, my_chanid
+ )
if reason != OPEN_SUCCEEDED:
self._log(
DEBUG,
- 'Rejecting "{}" channel request from client.'.format(kind))
+ 'Rejecting "{}" channel request from client.'.format(kind),
+ )
reject = True
if reject:
msg = Message()
msg.add_byte(cMSG_CHANNEL_OPEN_FAILURE)
msg.add_int(chanid)
msg.add_int(reason)
- msg.add_string('')
- msg.add_string('en')
+ msg.add_string("")
+ msg.add_string("en")
self._send_message(msg)
return
@@ -2599,9 +2732,11 @@ class Transport(threading.Thread, ClosingContextManager):
self.channels_seen[my_chanid] = True
chan._set_transport(self)
chan._set_window(
- self.default_window_size, self.default_max_packet_size)
+ self.default_window_size, self.default_max_packet_size
+ )
chan._set_remote_channel(
- chanid, initial_window_size, max_packet_size)
+ chanid, initial_window_size, max_packet_size
+ )
finally:
self.lock.release()
m = Message()
@@ -2611,19 +2746,17 @@ class Transport(threading.Thread, ClosingContextManager):
m.add_int(self.default_window_size)
m.add_int(self.default_max_packet_size)
self._send_message(m)
- self._log(DEBUG,
- 'Secsh channel {:d} ({}) opened.'.format(my_chanid, kind)
+ self._log(
+ DEBUG, "Secsh channel {:d} ({}) opened.".format(my_chanid, kind)
)
- if kind == 'auth-agent@openssh.com':
+ if kind == "auth-agent@openssh.com":
self._forward_agent_handler(chan)
- elif kind == 'x11':
+ elif kind == "x11":
self._x11_handler(chan, (origin_addr, origin_port))
- elif kind == 'forwarded-tcpip':
+ elif kind == "forwarded-tcpip":
chan.origin_addr = (origin_addr, origin_port)
self._tcp_handler(
- chan,
- (origin_addr, origin_port),
- (server_addr, server_port)
+ chan, (origin_addr, origin_port), (server_addr, server_port)
)
else:
self._queue_incoming_channel(chan)
@@ -2632,7 +2765,7 @@ class Transport(threading.Thread, ClosingContextManager):
m.get_boolean() # always_display
msg = m.get_string()
m.get_string() # language
- self._log(DEBUG, 'Debug msg: {}'.format(util.safe_string(msg)))
+ self._log(DEBUG, "Debug msg: {}".format(util.safe_string(msg)))
def _get_subsystem_handler(self, name):
try:
@@ -2666,7 +2799,7 @@ class Transport(threading.Thread, ClosingContextManager):
}
-class SecurityOptions (object):
+class SecurityOptions(object):
"""
Simple object containing the security preferences of an ssh transport.
These are tuples of acceptable ciphers, digests, key types, and key
@@ -2678,7 +2811,7 @@ class SecurityOptions (object):
``ValueError`` will be raised. If you try to assign something besides a
tuple to one of the fields, ``TypeError`` will be raised.
"""
- __slots__ = '_transport'
+ __slots__ = "_transport"
def __init__(self, transport):
self._transport = transport
@@ -2687,17 +2820,17 @@ class SecurityOptions (object):
"""
Returns a string representation of this object, for debugging.
"""
- return '<paramiko.SecurityOptions for {!r}>'.format(self._transport)
+ return "<paramiko.SecurityOptions for {!r}>".format(self._transport)
def _set(self, name, orig, x):
if type(x) is list:
x = tuple(x)
if type(x) is not tuple:
- raise TypeError('expected tuple or list')
+ raise TypeError("expected tuple or list")
possible = list(getattr(self._transport, orig).keys())
forbidden = [n for n in x if n not in possible]
if len(forbidden) > 0:
- raise ValueError('unknown cipher')
+ raise ValueError("unknown cipher")
setattr(self._transport, name, x)
@property
@@ -2707,7 +2840,7 @@ class SecurityOptions (object):
@ciphers.setter
def ciphers(self, x):
- self._set('_preferred_ciphers', '_cipher_info', x)
+ self._set("_preferred_ciphers", "_cipher_info", x)
@property
def digests(self):
@@ -2716,7 +2849,7 @@ class SecurityOptions (object):
@digests.setter
def digests(self, x):
- self._set('_preferred_macs', '_mac_info', x)
+ self._set("_preferred_macs", "_mac_info", x)
@property
def key_types(self):
@@ -2725,8 +2858,7 @@ class SecurityOptions (object):
@key_types.setter
def key_types(self, x):
- self._set('_preferred_keys', '_key_info', x)
-
+ self._set("_preferred_keys", "_key_info", x)
@property
def kex(self):
@@ -2735,7 +2867,7 @@ class SecurityOptions (object):
@kex.setter
def kex(self, x):
- self._set('_preferred_kex', '_kex_info', x)
+ self._set("_preferred_kex", "_kex_info", x)
@property
def compression(self):
@@ -2744,10 +2876,11 @@ class SecurityOptions (object):
@compression.setter
def compression(self, x):
- self._set('_preferred_compression', '_compression_info', x)
+ self._set("_preferred_compression", "_compression_info", x)
+
+class ChannelMap(object):
-class ChannelMap (object):
def __init__(self):
# (id -> Channel)
self._map = weakref.WeakValueDictionary()
diff --git a/paramiko/util.py b/paramiko/util.py
index 2854ef9..c60c040 100644
--- a/paramiko/util.py
+++ b/paramiko/util.py
@@ -49,7 +49,7 @@ def inflate_long(s, always_positive=False):
# noinspection PyAugmentAssignment
s = filler * (4 - len(s) % 4) + s
for i in range(0, len(s), 4):
- out = (out << 32) + struct.unpack('>I', s[i:i + 4])[0]
+ out = (out << 32) + struct.unpack(">I", s[i:i + 4])[0]
if negative:
out -= (long(1) << (8 * len(s)))
return out
@@ -66,7 +66,7 @@ def deflate_long(n, add_sign_padding=True):
s = bytes()
n = long(n)
while (n != 0) and (n != -1):
- s = struct.pack('>I', n & xffffffff) + s
+ s = struct.pack(">I", n & xffffffff) + s
n >>= 32
# strip off leading zeros, FFs
for i in enumerate(s):
@@ -90,7 +90,7 @@ def deflate_long(n, add_sign_padding=True):
return s
-def format_binary(data, prefix=''):
+def format_binary(data, prefix=""):
x = 0
out = []
while len(data) > x + 16:
@@ -102,22 +102,21 @@ def format_binary(data, prefix=''):
def format_binary_line(data):
- left = ' '.join(['{:02X}'.format(byte_ord(c)) for c in data])
- right = ''.join([
- '.{:c}..'.format(byte_ord(c))[(byte_ord(c) + 63) // 95]
- for c in data
- ])
- return '{:50s} {}'.format(left, right)
+ left = " ".join(["{:02X}".format(byte_ord(c)) for c in data])
+ right = "".join(
+ [".{:c}..".format(byte_ord(c))[(byte_ord(c) + 63) // 95] for c in data]
+ )
+ return "{:50s} {}".format(left, right)
def safe_string(s):
- out = b''
+ out = b""
for c in s:
i = byte_ord(c)
if 32 <= i <= 127:
out += byte_chr(i)
else:
- out += b('%{:02X}'.format(i))
+ out += b("%{:02X}".format(i))
return out
@@ -137,7 +136,7 @@ def bit_length(n):
def tb_strings():
- return ''.join(traceback.format_exception(*sys.exc_info())).split('\n')
+ return "".join(traceback.format_exception(*sys.exc_info())).split("\n")
def generate_key_bytes(hash_alg, salt, key, nbytes):
@@ -188,6 +187,7 @@ def load_host_keys(filename):
nested dict of `.PKey` objects, indexed by hostname and then keytype
"""
from paramiko.hostkeys import HostKeys
+
return HostKeys(filename)
@@ -249,15 +249,16 @@ def log_to_file(filename, level=DEBUG):
if len(l.handlers) > 0:
return
l.setLevel(level)
- f = open(filename, 'a')
+ f = open(filename, "a")
lh = logging.StreamHandler(f)
- frm = '%(levelname)-.3s [%(asctime)s.%(msecs)03d] thr=%(_threadid)-3d %(name)s: %(message)s' # noqa
- lh.setFormatter(logging.Formatter(frm, '%Y%m%d-%H:%M:%S'))
+ frm = "%(levelname)-.3s [%(asctime)s.%(msecs)03d] thr=%(_threadid)-3d %(name)s: %(message)s" # noqa
+ lh.setFormatter(logging.Formatter(frm, "%Y%m%d-%H:%M:%S"))
l.addHandler(lh)
# make only one filter object, so it doesn't get applied more than once
-class PFilter (object):
+class PFilter(object):
+
def filter(self, record):
record._threadid = get_thread_id()
return True
@@ -293,6 +294,7 @@ def constant_time_bytes_eq(a, b):
class ClosingContextManager(object):
+
def __enter__(self):
return self
diff --git a/paramiko/win_pageant.py b/paramiko/win_pageant.py
index 661ba57..2bba789 100644
--- a/paramiko/win_pageant.py
+++ b/paramiko/win_pageant.py
@@ -44,7 +44,7 @@ win32con_WM_COPYDATA = 74
def _get_pageant_window_object():
- return ctypes.windll.user32.FindWindowA(b'Pageant', b'Pageant')
+ return ctypes.windll.user32.FindWindowA(b"Pageant", b"Pageant")
def can_talk_to_agent():
@@ -57,7 +57,7 @@ def can_talk_to_agent():
return bool(_get_pageant_window_object())
-if platform.architecture()[0] == '64bit':
+if platform.architecture()[0] == "64bit":
ULONG_PTR = ctypes.c_uint64
else:
ULONG_PTR = ctypes.c_uint32
@@ -69,9 +69,9 @@ class COPYDATASTRUCT(ctypes.Structure):
http://msdn.microsoft.com/en-us/library/windows/desktop/ms649010%28v=vs.85%29.aspx
"""
_fields_ = [
- ('num_data', ULONG_PTR),
- ('data_size', ctypes.wintypes.DWORD),
- ('data_loc', ctypes.c_void_p),
+ ("num_data", ULONG_PTR),
+ ("data_size", ctypes.wintypes.DWORD),
+ ("data_loc", ctypes.c_void_p),
]
@@ -86,27 +86,29 @@ def _query_pageant(msg):
return None
# create a name for the mmap
- map_name = 'PageantRequest%08x' % thread.get_ident()
+ map_name = "PageantRequest%08x" % thread.get_ident()
- pymap = _winapi.MemoryMap(map_name, _AGENT_MAX_MSGLEN,
- _winapi.get_security_attributes_for_user(),
- )
+ pymap = _winapi.MemoryMap(
+ map_name, _AGENT_MAX_MSGLEN, _winapi.get_security_attributes_for_user()
+ )
with pymap:
pymap.write(msg)
# Create an array buffer containing the mapped filename
char_buffer = array.array("b", b(map_name) + zero_byte) # noqa
char_buffer_address, char_buffer_size = char_buffer.buffer_info()
# Create a string to use for the SendMessage function call
- cds = COPYDATASTRUCT(_AGENT_COPYDATA_ID, char_buffer_size,
- char_buffer_address)
+ cds = COPYDATASTRUCT(
+ _AGENT_COPYDATA_ID, char_buffer_size, char_buffer_address
+ )
- response = ctypes.windll.user32.SendMessageA(hwnd,
- win32con_WM_COPYDATA, ctypes.sizeof(cds), ctypes.byref(cds))
+ response = ctypes.windll.user32.SendMessageA(
+ hwnd, win32con_WM_COPYDATA, ctypes.sizeof(cds), ctypes.byref(cds)
+ )
if response > 0:
pymap.seek(0)
datalen = pymap.read(4)
- retlen = struct.unpack('>I', datalen)[0]
+ retlen = struct.unpack(">I", datalen)[0]
return datalen + pymap.read(retlen)
return None
@@ -127,10 +129,10 @@ class PageantConnection(object):
def recv(self, n):
if self._response is None:
- return ''
+ return ""
ret = self._response[:n]
self._response = self._response[n:]
- if self._response == '':
+ if self._response == "":
self._response = None
return ret
diff --git a/tests/conftest.py b/tests/conftest.py
index d1967a7..2b509c5 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -19,7 +19,7 @@ from .util import _support
# presenting it on error/failure. (But also allow turning it off when doing
# very pinpoint debugging - e.g. using breakpoints, so you don't want output
# hiding enabled, but also don't want all the logging to gum up the terminal.)
-if not os.environ.get('DISABLE_LOGGING', False):
+if not os.environ.get("DISABLE_LOGGING", False):
logging.basicConfig(
level=logging.DEBUG,
# Also make sure to set up timestamping for more sanity when debugging.
@@ -43,7 +43,7 @@ def make_sftp_folder():
# TODO: if we want to lock ourselves even harder into localhost-only
# testing (probably not?) could use tempdir modules for this for improved
# safety. Then again...why would someone have such a folder???
- path = os.environ.get('TEST_FOLDER', 'paramiko-test-target')
+ path = os.environ.get("TEST_FOLDER", "paramiko-test-target")
# Forcibly nuke this directory locally, since at the moment, the below
# fixtures only ever run with a locally scoped stub test server.
shutil.rmtree(path, ignore_errors=True)
@@ -52,7 +52,7 @@ def make_sftp_folder():
return path
-@pytest.fixture#(scope='session')
+@pytest.fixture # (scope='session')
def sftp_server():
"""
Set up an in-memory SFTP server thread. Yields the client Transport/socket.
@@ -69,17 +69,17 @@ def sftp_server():
tc = Transport(sockc)
ts = Transport(socks)
# Auth
- host_key = RSAKey.from_private_key_file(_support('test_rsa.key'))
+ host_key = RSAKey.from_private_key_file(_support("test_rsa.key"))
ts.add_server_key(host_key)
# Server setup
event = threading.Event()
server = StubServer()
- ts.set_subsystem_handler('sftp', SFTPServer, StubSFTPServer)
+ ts.set_subsystem_handler("sftp", SFTPServer, StubSFTPServer)
ts.start_server(event, server)
# Wait (so client has time to connect? Not sure. Old.)
event.wait(1.0)
# Make & yield connection.
- tc.connect(username='slowdive', password='pygmalion')
+ tc.connect(username="slowdive", password="pygmalion")
yield tc
# TODO: any need for shutdown? Why didn't old suite do so? Or was that the
# point of the "join all threads from threading module" crap in test.py?
diff --git a/tests/loop.py b/tests/loop.py
index 6c43286..dd1f5a0 100644
--- a/tests/loop.py
+++ b/tests/loop.py
@@ -22,13 +22,13 @@ import threading
from paramiko.common import asbytes
-class LoopSocket (object):
+class LoopSocket(object):
"""
A LoopSocket looks like a normal socket, but all data written to it is
delivered on the read-end of another LoopSocket, and vice versa. It's
like a software "socketpair".
"""
-
+
def __init__(self):
self.__in_buffer = bytes()
self.__lock = threading.Lock()
@@ -84,7 +84,7 @@ class LoopSocket (object):
self.__cv.notifyAll()
finally:
self.__lock.release()
-
+
def __unlink(self):
m = None
self.__lock.acquire()
@@ -96,5 +96,3 @@ class LoopSocket (object):
self.__lock.release()
if m is not None:
m.__unlink()
-
-
diff --git a/tests/stub_sftp.py b/tests/stub_sftp.py
index 1954586..170304a 100644
--- a/tests/stub_sftp.py
+++ b/tests/stub_sftp.py
@@ -24,13 +24,21 @@ import os
import sys
from paramiko import (
- ServerInterface, SFTPServerInterface, SFTPServer, SFTPAttributes,
- SFTPHandle, SFTP_OK, SFTP_FAILURE, AUTH_SUCCESSFUL, OPEN_SUCCEEDED,
+ ServerInterface,
+ SFTPServerInterface,
+ SFTPServer,
+ SFTPAttributes,
+ SFTPHandle,
+ SFTP_OK,
+ SFTP_FAILURE,
+ AUTH_SUCCESSFUL,
+ OPEN_SUCCEEDED,
)
from paramiko.common import o666
-class StubServer (ServerInterface):
+class StubServer(ServerInterface):
+
def check_auth_password(self, username, password):
# all are allowed
return AUTH_SUCCESSFUL
@@ -39,7 +47,8 @@ class StubServer (ServerInterface):
return OPEN_SUCCEEDED
-class StubSFTPHandle (SFTPHandle):
+class StubSFTPHandle(SFTPHandle):
+
def stat(self):
try:
return SFTPAttributes.from_stat(os.fstat(self.readfile.fileno()))
@@ -56,11 +65,11 @@ class StubSFTPHandle (SFTPHandle):
return SFTPServer.convert_errno(e.errno)
-class StubSFTPServer (SFTPServerInterface):
+class StubSFTPServer(SFTPServerInterface):
# assume current folder is a fine root
# (the tests always create and eventually delete a subfolder, so there shouldn't be any mess)
ROOT = os.getcwd()
-
+
def _realpath(self, path):
return self.ROOT + self.canonicalize(path)
@@ -70,7 +79,9 @@ class StubSFTPServer (SFTPServerInterface):
out = []
flist = os.listdir(path)
for fname in flist:
- attr = SFTPAttributes.from_stat(os.stat(os.path.join(path, fname)))
+ attr = SFTPAttributes.from_stat(
+ os.stat(os.path.join(path, fname))
+ )
attr.filename = fname
out.append(attr)
return out
@@ -94,9 +105,9 @@ class StubSFTPServer (SFTPServerInterface):
def open(self, path, flags, attr):
path = self._realpath(path)
try:
- binary_flag = getattr(os, 'O_BINARY', 0)
+ binary_flag = getattr(os, "O_BINARY", 0)
flags |= binary_flag
- mode = getattr(attr, 'st_mode', None)
+ mode = getattr(attr, "st_mode", None)
if mode is not None:
fd = os.open(path, flags, mode)
else:
@@ -110,17 +121,17 @@ class StubSFTPServer (SFTPServerInterface):
SFTPServer.set_file_attr(path, attr)
if flags & os.O_WRONLY:
if flags & os.O_APPEND:
- fstr = 'ab'
+ fstr = "ab"
else:
- fstr = 'wb'
+ fstr = "wb"
elif flags & os.O_RDWR:
if flags & os.O_APPEND:
- fstr = 'a+b'
+ fstr = "a+b"
else:
- fstr = 'r+b'
+ fstr = "r+b"
else:
# O_RDONLY (== 0)
- fstr = 'rb'
+ fstr = "rb"
try:
f = os.fdopen(fd, fstr)
except OSError as e:
@@ -159,7 +170,6 @@ class StubSFTPServer (SFTPServerInterface):
return SFTPServer.convert_errno(e.errno)
return SFTP_OK
-
def mkdir(self, path, attr):
path = self._realpath(path)
try:
@@ -188,10 +198,10 @@ class StubSFTPServer (SFTPServerInterface):
def symlink(self, target_path, path):
path = self._realpath(path)
- if (len(target_path) > 0) and (target_path[0] == '/'):
+ if (len(target_path) > 0) and (target_path[0] == "/"):
# absolute symlink
target_path = os.path.join(self.ROOT, target_path[1:])
- if target_path[:2] == '//':
+ if target_path[:2] == "//":
# bug in os.path.join
target_path = target_path[1:]
else:
@@ -199,7 +209,7 @@ class StubSFTPServer (SFTPServerInterface):
abspath = os.path.join(os.path.dirname(path), target_path)
if abspath[:len(self.ROOT)] != self.ROOT:
# this symlink isn't going to work anyway -- just break it immediately
- target_path = '<error>'
+ target_path = "<error>"
try:
os.symlink(target_path, path)
except OSError as e:
@@ -216,8 +226,8 @@ class StubSFTPServer (SFTPServerInterface):
if os.path.isabs(symlink):
if symlink[:len(self.ROOT)] == self.ROOT:
symlink = symlink[len(self.ROOT):]
- if (len(symlink) == 0) or (symlink[0] != '/'):
- symlink = '/' + symlink
+ if (len(symlink) == 0) or (symlink[0] != "/"):
+ symlink = "/" + symlink
else:
- symlink = '<error>'
+ symlink = "<error>"
return symlink
diff --git a/tests/test_auth.py b/tests/test_auth.py
index dacdd65..acabb1b 100644
--- a/tests/test_auth.py
+++ b/tests/test_auth.py
@@ -26,8 +26,13 @@ import unittest
from time import sleep
from paramiko import (
- Transport, ServerInterface, RSAKey, DSSKey, BadAuthenticationType,
- InteractiveQuery, AuthenticationException,
+ Transport,
+ ServerInterface,
+ RSAKey,
+ DSSKey,
+ BadAuthenticationType,
+ InteractiveQuery,
+ AuthenticationException,
)
from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL
from paramiko.py3compat import u
@@ -36,54 +41,57 @@ from .loop import LoopSocket
from .util import _support, slow
-_pwd = u('\u2022')
+_pwd = u("\u2022")
-class NullServer (ServerInterface):
+class NullServer(ServerInterface):
paranoid_did_password = False
paranoid_did_public_key = False
- paranoid_key = DSSKey.from_private_key_file(_support('test_dss.key'))
+ paranoid_key = DSSKey.from_private_key_file(_support("test_dss.key"))
def get_allowed_auths(self, username):
- if username == 'slowdive':
- return 'publickey,password'
- if username == 'paranoid':
- if not self.paranoid_did_password and not self.paranoid_did_public_key:
- return 'publickey,password'
+ if username == "slowdive":
+ return "publickey,password"
+ if username == "paranoid":
+ if (
+ not self.paranoid_did_password
+ and not self.paranoid_did_public_key
+ ):
+ return "publickey,password"
elif self.paranoid_did_password:
- return 'publickey'
+ return "publickey"
else:
- return 'password'
- if username == 'commie':
- return 'keyboard-interactive'
- if username == 'utf8':
- return 'password'
- if username == 'non-utf8':
- return 'password'
- return 'publickey'
+ return "password"
+ if username == "commie":
+ return "keyboard-interactive"
+ if username == "utf8":
+ return "password"
+ if username == "non-utf8":
+ return "password"
+ return "publickey"
def check_auth_password(self, username, password):
- if (username == 'slowdive') and (password == 'pygmalion'):
+ if (username == "slowdive") and (password == "pygmalion"):
return AUTH_SUCCESSFUL
- if (username == 'paranoid') and (password == 'paranoid'):
+ if (username == "paranoid") and (password == "paranoid"):
# 2-part auth (even openssh doesn't support this)
self.paranoid_did_password = True
if self.paranoid_did_public_key:
return AUTH_SUCCESSFUL
return AUTH_PARTIALLY_SUCCESSFUL
- if (username == 'utf8') and (password == _pwd):
+ if (username == "utf8") and (password == _pwd):
return AUTH_SUCCESSFUL
- if (username == 'non-utf8') and (password == '\xff'):
+ if (username == "non-utf8") and (password == "\xff"):
return AUTH_SUCCESSFUL
- if username == 'bad-server':
+ if username == "bad-server":
raise Exception("Ack!")
- if username == 'unresponsive-server':
+ if username == "unresponsive-server":
sleep(5)
return AUTH_SUCCESSFUL
return AUTH_FAILED
def check_auth_publickey(self, username, key):
- if (username == 'paranoid') and (key == self.paranoid_key):
+ if (username == "paranoid") and (key == self.paranoid_key):
# 2-part auth
self.paranoid_did_public_key = True
if self.paranoid_did_password:
@@ -92,19 +100,21 @@ class NullServer (ServerInterface):
return AUTH_FAILED
def check_auth_interactive(self, username, submethods):
- if username == 'commie':
+ if username == "commie":
self.username = username
- return InteractiveQuery('password', 'Please enter a password.', ('Password', False))
+ return InteractiveQuery(
+ "password", "Please enter a password.", ("Password", False)
+ )
return AUTH_FAILED
def check_auth_interactive_response(self, responses):
- if self.username == 'commie':
- if (len(responses) == 1) and (responses[0] == 'cat'):
+ if self.username == "commie":
+ if (len(responses) == 1) and (responses[0] == "cat"):
return AUTH_SUCCESSFUL
return AUTH_FAILED
-class AuthTest (unittest.TestCase):
+class AuthTest(unittest.TestCase):
def setUp(self):
self.socks = LoopSocket()
@@ -120,7 +130,7 @@ class AuthTest (unittest.TestCase):
self.sockc.close()
def start_server(self):
- host_key = RSAKey.from_private_key_file(_support('test_rsa.key'))
+ host_key = RSAKey.from_private_key_file(_support("test_rsa.key"))
self.public_host_key = RSAKey(data=host_key.asbytes())
self.ts.add_server_key(host_key)
self.event = threading.Event()
@@ -140,13 +150,16 @@ class AuthTest (unittest.TestCase):
"""
self.start_server()
try:
- self.tc.connect(hostkey=self.public_host_key,
- username='unknown', password='error')
+ self.tc.connect(
+ hostkey=self.public_host_key,
+ username="unknown",
+ password="error",
+ )
self.assertTrue(False)
except:
etype, evalue, etb = sys.exc_info()
self.assertEqual(BadAuthenticationType, etype)
- self.assertEqual(['publickey'], evalue.allowed_types)
+ self.assertEqual(["publickey"], evalue.allowed_types)
def test_bad_password(self):
"""
@@ -156,12 +169,12 @@ class AuthTest (unittest.TestCase):
self.start_server()
self.tc.connect(hostkey=self.public_host_key)
try:
- self.tc.auth_password(username='slowdive', password='error')
+ self.tc.auth_password(username="slowdive", password="error")
self.assertTrue(False)
except:
etype, evalue, etb = sys.exc_info()
self.assertTrue(issubclass(etype, AuthenticationException))
- self.tc.auth_password(username='slowdive', password='pygmalion')
+ self.tc.auth_password(username="slowdive", password="pygmalion")
self.verify_finished()
def test_multipart_auth(self):
@@ -170,10 +183,12 @@ class AuthTest (unittest.TestCase):
"""
self.start_server()
self.tc.connect(hostkey=self.public_host_key)
- remain = self.tc.auth_password(username='paranoid', password='paranoid')
- self.assertEqual(['publickey'], remain)
- key = DSSKey.from_private_key_file(_support('test_dss.key'))
- remain = self.tc.auth_publickey(username='paranoid', key=key)
+ remain = self.tc.auth_password(
+ username="paranoid", password="paranoid"
+ )
+ self.assertEqual(["publickey"], remain)
+ key = DSSKey.from_private_key_file(_support("test_dss.key"))
+ remain = self.tc.auth_publickey(username="paranoid", key=key)
self.assertEqual([], remain)
self.verify_finished()
@@ -188,10 +203,11 @@ class AuthTest (unittest.TestCase):
self.got_title = title
self.got_instructions = instructions
self.got_prompts = prompts
- return ['cat']
- remain = self.tc.auth_interactive('commie', handler)
- self.assertEqual(self.got_title, 'password')
- self.assertEqual(self.got_prompts, [('Password', False)])
+ return ["cat"]
+
+ remain = self.tc.auth_interactive("commie", handler)
+ self.assertEqual(self.got_title, "password")
+ self.assertEqual(self.got_prompts, [("Password", False)])
self.assertEqual([], remain)
self.verify_finished()
@@ -202,7 +218,7 @@ class AuthTest (unittest.TestCase):
"""
self.start_server()
self.tc.connect(hostkey=self.public_host_key)
- remain = self.tc.auth_password('commie', 'cat')
+ remain = self.tc.auth_password("commie", "cat")
self.assertEqual([], remain)
self.verify_finished()
@@ -212,7 +228,7 @@ class AuthTest (unittest.TestCase):
"""
self.start_server()
self.tc.connect(hostkey=self.public_host_key)
- remain = self.tc.auth_password('utf8', _pwd)
+ remain = self.tc.auth_password("utf8", _pwd)
self.assertEqual([], remain)
self.verify_finished()
@@ -223,7 +239,7 @@ class AuthTest (unittest.TestCase):
"""
self.start_server()
self.tc.connect(hostkey=self.public_host_key)
- remain = self.tc.auth_password('non-utf8', '\xff')
+ remain = self.tc.auth_password("non-utf8", "\xff")
self.assertEqual([], remain)
self.verify_finished()
@@ -235,7 +251,7 @@ class AuthTest (unittest.TestCase):
self.start_server()
self.tc.connect(hostkey=self.public_host_key)
try:
- remain = self.tc.auth_password('bad-server', 'hello')
+ remain = self.tc.auth_password("bad-server", "hello")
except:
etype, evalue, etb = sys.exc_info()
self.assertTrue(issubclass(etype, AuthenticationException))
@@ -250,8 +266,8 @@ class AuthTest (unittest.TestCase):
self.start_server()
self.tc.connect()
try:
- remain = self.tc.auth_password('unresponsive-server', 'hello')
+ remain = self.tc.auth_password("unresponsive-server", "hello")
except:
etype, evalue, etb = sys.exc_info()
self.assertTrue(issubclass(etype, AuthenticationException))
- self.assertTrue('Authentication timeout' in str(evalue))
+ self.assertTrue("Authentication timeout" in str(evalue))
diff --git a/tests/test_buffered_pipe.py b/tests/test_buffered_pipe.py
index 03616c5..9f986a5 100644
--- a/tests/test_buffered_pipe.py
+++ b/tests/test_buffered_pipe.py
@@ -30,9 +30,9 @@ from paramiko.py3compat import b
def delay_thread(p):
- p.feed('a')
+ p.feed("a")
time.sleep(0.5)
- p.feed('b')
+ p.feed("b")
p.close()
@@ -42,41 +42,42 @@ def close_thread(p):
class BufferedPipeTest(unittest.TestCase):
+
def test_1_buffered_pipe(self):
p = BufferedPipe()
self.assertTrue(not p.read_ready())
- p.feed('hello.')
+ p.feed("hello.")
self.assertTrue(p.read_ready())
data = p.read(6)
- self.assertEqual(b'hello.', data)
+ self.assertEqual(b"hello.", data)
- p.feed('plus/minus')
- self.assertEqual(b'plu', p.read(3))
- self.assertEqual(b's/m', p.read(3))
- self.assertEqual(b'inus', p.read(4))
+ p.feed("plus/minus")
+ self.assertEqual(b"plu", p.read(3))
+ self.assertEqual(b"s/m", p.read(3))
+ self.assertEqual(b"inus", p.read(4))
p.close()
self.assertTrue(not p.read_ready())
- self.assertEqual(b'', p.read(1))
+ self.assertEqual(b"", p.read(1))
def test_2_delay(self):
p = BufferedPipe()
self.assertTrue(not p.read_ready())
threading.Thread(target=delay_thread, args=(p,)).start()
- self.assertEqual(b'a', p.read(1, 0.1))
+ self.assertEqual(b"a", p.read(1, 0.1))
try:
p.read(1, 0.1)
self.assertTrue(False)
except PipeTimeout:
pass
- self.assertEqual(b'b', p.read(1, 1.0))
- self.assertEqual(b'', p.read(1))
+ self.assertEqual(b"b", p.read(1, 1.0))
+ self.assertEqual(b"", p.read(1))
def test_3_close_while_reading(self):
p = BufferedPipe()
threading.Thread(target=close_thread, args=(p,)).start()
data = p.read(1, 1.0)
- self.assertEqual(b'', data)
+ self.assertEqual(b"", data)
def test_4_or_pipe(self):
p = pipe.make_pipe()
diff --git a/tests/test_client.py b/tests/test_client.py
index 7163fdc..4a1e982 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -48,30 +48,31 @@ requires_gss_auth = unittest.skipUnless(
)
FINGERPRINTS = {
- 'ssh-dss': b'\x44\x78\xf0\xb9\xa2\x3c\xc5\x18\x20\x09\xff\x75\x5b\xc1\xd2\x6c',
- 'ssh-rsa': b'\x60\x73\x38\x44\xcb\x51\x86\x65\x7f\xde\xda\xa2\x2b\x5a\x57\xd5',
- 'ecdsa-sha2-nistp256': b'\x25\x19\xeb\x55\xe6\xa1\x47\xff\x4f\x38\xd2\x75\x6f\xa5\xd5\x60',
- 'ssh-ed25519': b'\xb3\xd5"\xaa\xf9u^\xe8\xcd\x0e\xea\x02\xb9)\xa2\x80',
+ "ssh-dss": b"\x44\x78\xf0\xb9\xa2\x3c\xc5\x18\x20\x09\xff\x75\x5b\xc1\xd2\x6c",
+ "ssh-rsa": b"\x60\x73\x38\x44\xcb\x51\x86\x65\x7f\xde\xda\xa2\x2b\x5a\x57\xd5",
+ "ecdsa-sha2-nistp256": b"\x25\x19\xeb\x55\xe6\xa1\x47\xff\x4f\x38\xd2\x75\x6f\xa5\xd5\x60",
+ "ssh-ed25519": b'\xb3\xd5"\xaa\xf9u^\xe8\xcd\x0e\xea\x02\xb9)\xa2\x80',
}
class NullServer(paramiko.ServerInterface):
+
def __init__(self, *args, **kwargs):
# Allow tests to enable/disable specific key types
- self.__allowed_keys = kwargs.pop('allowed_keys', [])
+ self.__allowed_keys = kwargs.pop("allowed_keys", [])
# And allow them to set a (single...meh) expected public blob (cert)
- self.__expected_public_blob = kwargs.pop('public_blob', None)
+ self.__expected_public_blob = kwargs.pop("public_blob", None)
super(NullServer, self).__init__(*args, **kwargs)
def get_allowed_auths(self, username):
- if username == 'slowdive':
- return 'publickey,password'
- return 'publickey'
+ if username == "slowdive":
+ return "publickey,password"
+ return "publickey"
def check_auth_password(self, username, password):
- if (username == 'slowdive') and (password == 'pygmalion'):
+ if (username == "slowdive") and (password == "pygmalion"):
return paramiko.AUTH_SUCCESSFUL
- if (username == 'slowdive') and (password == 'unresponsive-server'):
+ if (username == "slowdive") and (password == "unresponsive-server"):
time.sleep(5)
return paramiko.AUTH_SUCCESSFUL
return paramiko.AUTH_FAILED
@@ -83,13 +84,13 @@ class NullServer(paramiko.ServerInterface):
return paramiko.AUTH_FAILED
# Base check: allowed auth type & fingerprint matches
happy = (
- key.get_name() in self.__allowed_keys and
- key.get_fingerprint() == expected
+ key.get_name() in self.__allowed_keys
+ and key.get_fingerprint() == expected
)
# Secondary check: if test wants assertions about cert data
if (
- self.__expected_public_blob is not None and
- key.public_blob != self.__expected_public_blob
+ self.__expected_public_blob is not None
+ and key.public_blob != self.__expected_public_blob
):
happy = False
return paramiko.AUTH_SUCCESSFUL if happy else paramiko.AUTH_FAILED
@@ -98,31 +99,32 @@ class NullServer(paramiko.ServerInterface):
return paramiko.OPEN_SUCCEEDED
def check_channel_exec_request(self, channel, command):
- if command != b'yes':
+ if command != b"yes":
return False
return True
def check_channel_env_request(self, channel, name, value):
- if name == 'INVALID_ENV':
+ if name == "INVALID_ENV":
return False
- if not hasattr(channel, 'env'):
- setattr(channel, 'env', {})
+ if not hasattr(channel, "env"):
+ setattr(channel, "env", {})
channel.env[name] = value
return True
class ClientTest(unittest.TestCase):
+
def setUp(self):
self.sockl = socket.socket()
- self.sockl.bind(('localhost', 0))
+ self.sockl.bind(("localhost", 0))
self.sockl.listen(1)
self.addr, self.port = self.sockl.getsockname()
self.connect_kwargs = dict(
hostname=self.addr,
port=self.port,
- username='slowdive',
+ username="slowdive",
look_for_keys=False,
)
self.event = threading.Event()
@@ -130,10 +132,10 @@ class ClientTest(unittest.TestCase):
def tearDown(self):
# Shut down client Transport
- if hasattr(self, 'tc'):
+ if hasattr(self, "tc"):
self.tc.close()
# Shut down shared socket
- if hasattr(self, 'sockl'):
+ if hasattr(self, "sockl"):
# Signal to server thread that it should shut down early; it checks
# this immediately after accept(). (In scenarios where connection
# actually succeeded during the test, this becomes a no-op.)
@@ -151,7 +153,7 @@ class ClientTest(unittest.TestCase):
self.sockl.close()
def _run(
- self, allowed_keys=None, delay=0, public_blob=None, kill_event=None,
+ self, allowed_keys=None, delay=0, public_blob=None, kill_event=None
):
if allowed_keys is None:
allowed_keys = FINGERPRINTS.keys()
@@ -163,10 +165,10 @@ class ClientTest(unittest.TestCase):
self.socks.close()
return
self.ts = paramiko.Transport(self.socks)
- keypath = _support('test_rsa.key')
+ keypath = _support("test_rsa.key")
host_key = paramiko.RSAKey.from_private_key_file(keypath)
self.ts.add_server_key(host_key)
- keypath = _support('test_ecdsa_256.key')
+ keypath = _support("test_ecdsa_256.key")
host_key = paramiko.ECDSAKey.from_private_key_file(keypath)
self.ts.add_server_key(host_key)
server = NullServer(allowed_keys=allowed_keys, public_blob=public_blob)
@@ -181,17 +183,21 @@ class ClientTest(unittest.TestCase):
The exception is ``allowed_keys`` which is stripped and handed to the
``NullServer`` used for testing.
"""
- run_kwargs = {'kill_event': self.kill_event}
- for key in ('allowed_keys', 'public_blob'):
+ run_kwargs = {"kill_event": self.kill_event}
+ for key in ("allowed_keys", "public_blob"):
run_kwargs[key] = kwargs.pop(key, None)
# Server setup
threading.Thread(target=self._run, kwargs=run_kwargs).start()
- host_key = paramiko.RSAKey.from_private_key_file(_support('test_rsa.key'))
+ host_key = paramiko.RSAKey.from_private_key_file(
+ _support("test_rsa.key")
+ )
public_host_key = paramiko.RSAKey(data=host_key.asbytes())
# Client setup
self.tc = paramiko.SSHClient()
- self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key)
+ self.tc.get_host_keys().add(
+ "[%s]:%d" % (self.addr, self.port), "ssh-rsa", public_host_key
+ )
# Actual connection
self.tc.connect(**dict(self.connect_kwargs, **kwargs))
@@ -200,22 +206,22 @@ class ClientTest(unittest.TestCase):
self.event.wait(1.0)
self.assertTrue(self.event.is_set())
self.assertTrue(self.ts.is_active())
- self.assertEqual('slowdive', self.ts.get_username())
+ self.assertEqual("slowdive", self.ts.get_username())
self.assertEqual(True, self.ts.is_authenticated())
self.assertEqual(False, self.tc.get_transport().gss_kex_used)
# Command execution functions?
- stdin, stdout, stderr = self.tc.exec_command('yes')
+ stdin, stdout, stderr = self.tc.exec_command("yes")
schan = self.ts.accept(1.0)
- schan.send('Hello there.\n')
- schan.send_stderr('This is on stderr.\n')
+ schan.send("Hello there.\n")
+ schan.send_stderr("This is on stderr.\n")
schan.close()
- self.assertEqual('Hello there.\n', stdout.readline())
- self.assertEqual('', stdout.readline())
- self.assertEqual('This is on stderr.\n', stderr.readline())
- self.assertEqual('', stderr.readline())
+ self.assertEqual("Hello there.\n", stdout.readline())
+ self.assertEqual("", stdout.readline())
+ self.assertEqual("This is on stderr.\n", stderr.readline())
+ self.assertEqual("", stderr.readline())
# Cleanup
stdin.close()
@@ -224,32 +230,33 @@ class ClientTest(unittest.TestCase):
class SSHClientTest(ClientTest):
+
def test_1_client(self):
"""
verify that the SSHClient stuff works too.
"""
- self._test_connection(password='pygmalion')
+ self._test_connection(password="pygmalion")
def test_2_client_dsa(self):
"""
verify that SSHClient works with a DSA key.
"""
- self._test_connection(key_filename=_support('test_dss.key'))
+ self._test_connection(key_filename=_support("test_dss.key"))
def test_client_rsa(self):
"""
verify that SSHClient works with an RSA key.
"""
- self._test_connection(key_filename=_support('test_rsa.key'))
+ self._test_connection(key_filename=_support("test_rsa.key"))
def test_2_5_client_ecdsa(self):
"""
verify that SSHClient works with an ECDSA key.
"""
- self._test_connection(key_filename=_support('test_ecdsa_256.key'))
+ self._test_connection(key_filename=_support("test_ecdsa_256.key"))
def test_client_ed25519(self):
- self._test_connection(key_filename=_support('test_ed25519.key'))
+ self._test_connection(key_filename=_support("test_ed25519.key"))
def test_3_multiple_key_files(self):
"""
@@ -257,22 +264,20 @@ class SSHClientTest(ClientTest):
"""
# This is dumb :(
types_ = {
- 'rsa': 'ssh-rsa',
- 'dss': 'ssh-dss',
- 'ecdsa': 'ecdsa-sha2-nistp256',
+ "rsa": "ssh-rsa", "dss": "ssh-dss", "ecdsa": "ecdsa-sha2-nistp256"
}
# Various combos of attempted & valid keys
# TODO: try every possible combo using itertools functions
for attempt, accept in (
- (['rsa', 'dss'], ['dss']), # Original test #3
- (['dss', 'rsa'], ['dss']), # Ordering matters sometimes, sadly
- (['dss', 'rsa', 'ecdsa_256'], ['dss']), # Try ECDSA but fail
- (['rsa', 'ecdsa_256'], ['ecdsa']), # ECDSA success
+ (["rsa", "dss"], ["dss"]), # Original test #3
+ (["dss", "rsa"], ["dss"]), # Ordering matters sometimes, sadly
+ (["dss", "rsa", "ecdsa_256"], ["dss"]), # Try ECDSA but fail
+ (["rsa", "ecdsa_256"], ["ecdsa"]), # ECDSA success
):
try:
self._test_connection(
key_filename=[
- _support('test_{}.key'.format(x)) for x in attempt
+ _support("test_{}.key".format(x)) for x in attempt
],
allowed_keys=[types_[x] for x in accept],
)
@@ -288,10 +293,11 @@ class SSHClientTest(ClientTest):
"""
# Until #387 is fixed we have to catch a high-up exception since
# various platforms trigger different errors here >_<
- self.assertRaises(SSHException,
+ self.assertRaises(
+ SSHException,
self._test_connection,
- key_filename=[_support('test_rsa.key')],
- allowed_keys=['ecdsa-sha2-nistp256'],
+ key_filename=[_support("test_rsa.key")],
+ allowed_keys=["ecdsa-sha2-nistp256"],
)
def test_certs_allowed_as_key_filename_values(self):
@@ -299,9 +305,9 @@ class SSHClientTest(ClientTest):
# They're similar except for which path is given; the expected auth and
# server-side behavior is 100% identical.)
# NOTE: only bothered whipping up one cert per overall class/family.
- for type_ in ('rsa', 'dss', 'ecdsa_256', 'ed25519'):
- cert_name = 'test_{}.key-cert.pub'.format(type_)
- cert_path = _support(os.path.join('cert_support', cert_name))
+ for type_ in ("rsa", "dss", "ecdsa_256", "ed25519"):
+ cert_name = "test_{}.key-cert.pub".format(type_)
+ cert_path = _support(os.path.join("cert_support", cert_name))
self._test_connection(
key_filename=cert_path,
public_blob=PublicBlob.from_file(cert_path),
@@ -314,13 +320,13 @@ class SSHClientTest(ClientTest):
# about the server-side key object's public blob. Thus, we can prove
# that a specific cert was found, along with regular authorization
# succeeding proving that the overall flow works.
- for type_ in ('rsa', 'dss', 'ecdsa_256', 'ed25519'):
- key_name = 'test_{}.key'.format(type_)
- key_path = _support(os.path.join('cert_support', key_name))
+ for type_ in ("rsa", "dss", "ecdsa_256", "ed25519"):
+ key_name = "test_{}.key".format(type_)
+ key_path = _support(os.path.join("cert_support", key_name))
self._test_connection(
key_filename=key_path,
public_blob=PublicBlob.from_file(
- '{}-cert.pub'.format(key_path)
+ "{}-cert.pub".format(key_path)
),
)
@@ -335,19 +341,19 @@ class SSHClientTest(ClientTest):
verify that SSHClient's AutoAddPolicy works.
"""
threading.Thread(target=self._run).start()
- hostname = '[%s]:%d' % (self.addr, self.port)
- key_file = _support('test_ecdsa_256.key')
+ hostname = "[%s]:%d" % (self.addr, self.port)
+ key_file = _support("test_ecdsa_256.key")
public_host_key = paramiko.ECDSAKey.from_private_key_file(key_file)
self.tc = paramiko.SSHClient()
self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.assertEqual(0, len(self.tc.get_host_keys()))
- self.tc.connect(password='pygmalion', **self.connect_kwargs)
+ self.tc.connect(password="pygmalion", **self.connect_kwargs)
self.event.wait(1.0)
self.assertTrue(self.event.is_set())
self.assertTrue(self.ts.is_active())
- self.assertEqual('slowdive', self.ts.get_username())
+ self.assertEqual("slowdive", self.ts.get_username())
self.assertEqual(True, self.ts.is_authenticated())
self.assertEqual(1, len(self.tc.get_host_keys()))
new_host_key = list(self.tc.get_host_keys()[hostname].values())[0]
@@ -357,9 +363,11 @@ class SSHClientTest(ClientTest):
"""
verify that SSHClient correctly saves a known_hosts file.
"""
- warnings.filterwarnings('ignore', 'tempnam.*')
+ warnings.filterwarnings("ignore", "tempnam.*")
- host_key = paramiko.RSAKey.from_private_key_file(_support('test_rsa.key'))
+ host_key = paramiko.RSAKey.from_private_key_file(
+ _support("test_rsa.key")
+ )
public_host_key = paramiko.RSAKey(data=host_key.asbytes())
fd, localname = mkstemp()
os.close(fd)
@@ -367,11 +375,13 @@ class SSHClientTest(ClientTest):
client = paramiko.SSHClient()
self.assertEquals(0, len(client.get_host_keys()))
- host_id = '[%s]:%d' % (self.addr, self.port)
+ host_id = "[%s]:%d" % (self.addr, self.port)
- client.get_host_keys().add(host_id, 'ssh-rsa', public_host_key)
+ client.get_host_keys().add(host_id, "ssh-rsa", public_host_key)
self.assertEquals(1, len(client.get_host_keys()))
- self.assertEquals(public_host_key, client.get_host_keys()[host_id]['ssh-rsa'])
+ self.assertEquals(
+ public_host_key, client.get_host_keys()[host_id]["ssh-rsa"]
+ )
client.save_host_keys(localname)
@@ -394,7 +404,7 @@ class SSHClientTest(ClientTest):
self.tc = paramiko.SSHClient()
self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.assertEqual(0, len(self.tc.get_host_keys()))
- self.tc.connect(**dict(self.connect_kwargs, password='pygmalion'))
+ self.tc.connect(**dict(self.connect_kwargs, password="pygmalion"))
self.event.wait(1.0)
self.assertTrue(self.event.is_set())
@@ -423,7 +433,7 @@ class SSHClientTest(ClientTest):
self.tc = tc
self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.assertEquals(0, len(self.tc.get_host_keys()))
- self.tc.connect(**dict(self.connect_kwargs, password='pygmalion'))
+ self.tc.connect(**dict(self.connect_kwargs, password="pygmalion"))
self.event.wait(1.0)
self.assertTrue(self.event.is_set())
@@ -438,19 +448,19 @@ class SSHClientTest(ClientTest):
verify that the SSHClient has a configurable banner timeout.
"""
# Start the thread with a 1 second wait.
- threading.Thread(target=self._run, kwargs={'delay': 1}).start()
- host_key = paramiko.RSAKey.from_private_key_file(_support('test_rsa.key'))
+ threading.Thread(target=self._run, kwargs={"delay": 1}).start()
+ host_key = paramiko.RSAKey.from_private_key_file(
+ _support("test_rsa.key")
+ )
public_host_key = paramiko.RSAKey(data=host_key.asbytes())
self.tc = paramiko.SSHClient()
- self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key)
+ self.tc.get_host_keys().add(
+ "[%s]:%d" % (self.addr, self.port), "ssh-rsa", public_host_key
+ )
# Connect with a half second banner timeout.
kwargs = dict(self.connect_kwargs, banner_timeout=0.5)
- self.assertRaises(
- paramiko.SSHException,
- self.tc.connect,
- **kwargs
- )
+ self.assertRaises(paramiko.SSHException, self.tc.connect, **kwargs)
def test_8_auth_trickledown(self):
"""
@@ -466,9 +476,9 @@ class SSHClientTest(ClientTest):
# 'television' as per tests/test_pkey.py). NOTE: must use
# key_filename, loading the actual key here with PKey will except
# immediately; we're testing the try/except crap within Client.
- key_filename=[_support('test_rsa_password.key')],
+ key_filename=[_support("test_rsa_password.key")],
# Actual password for default 'slowdive' user
- password='pygmalion',
+ password="pygmalion",
)
self._test_connection(**kwargs)
@@ -481,7 +491,7 @@ class SSHClientTest(ClientTest):
self.assertRaises(
AuthenticationException,
self._test_connection,
- password='unresponsive-server',
+ password="unresponsive-server",
auth_timeout=0.5,
)
@@ -490,10 +500,7 @@ class SSHClientTest(ClientTest):
"""
Failed gssapi-keyex auth doesn't prevent subsequent key auth from succeeding
"""
- kwargs = dict(
- gss_kex=True,
- key_filename=[_support('test_rsa.key')],
- )
+ kwargs = dict(gss_kex=True, key_filename=[_support("test_rsa.key")])
self._test_connection(**kwargs)
@requires_gss_auth
@@ -501,10 +508,7 @@ class SSHClientTest(ClientTest):
"""
Failed gssapi-with-mic auth doesn't prevent subsequent key auth from succeeding
"""
- kwargs = dict(
- gss_auth=True,
- key_filename=[_support('test_rsa.key')],
- )
+ kwargs = dict(gss_auth=True, key_filename=[_support("test_rsa.key")])
self._test_connection(**kwargs)
def test_12_reject_policy(self):
@@ -519,7 +523,8 @@ class SSHClientTest(ClientTest):
self.assertRaises(
paramiko.SSHException,
self.tc.connect,
- password='pygmalion', **self.connect_kwargs
+ password="pygmalion",
+ **self.connect_kwargs
)
@requires_gss_auth
@@ -537,14 +542,14 @@ class SSHClientTest(ClientTest):
self.assertRaises(
paramiko.SSHException,
self.tc.connect,
- password='pygmalion',
+ password="pygmalion",
gss_kex=True,
- **self.connect_kwargs
+ **self.connect_kwargs
)
def _client_host_key_bad(self, host_key):
threading.Thread(target=self._run).start()
- hostname = '[%s]:%d' % (self.addr, self.port)
+ hostname = "[%s]:%d" % (self.addr, self.port)
self.tc = paramiko.SSHClient()
self.tc.set_missing_host_key_policy(paramiko.WarningPolicy())
@@ -554,13 +559,13 @@ class SSHClientTest(ClientTest):
self.assertRaises(
paramiko.BadHostKeyException,
self.tc.connect,
- password='pygmalion',
+ password="pygmalion",
**self.connect_kwargs
)
def _client_host_key_good(self, ktype, kfile):
threading.Thread(target=self._run).start()
- hostname = '[%s]:%d' % (self.addr, self.port)
+ hostname = "[%s]:%d" % (self.addr, self.port)
self.tc = paramiko.SSHClient()
self.tc.set_missing_host_key_policy(paramiko.RejectPolicy())
@@ -568,7 +573,7 @@ class SSHClientTest(ClientTest):
known_hosts = self.tc.get_host_keys()
known_hosts.add(hostname, host_key.get_name(), host_key)
- self.tc.connect(password='pygmalion', **self.connect_kwargs)
+ self.tc.connect(password="pygmalion", **self.connect_kwargs)
self.event.wait(1.0)
self.assertTrue(self.event.is_set())
self.assertTrue(self.ts.is_active())
@@ -583,10 +588,10 @@ class SSHClientTest(ClientTest):
self._client_host_key_bad(host_key)
def test_host_key_negotiation_3(self):
- self._client_host_key_good(paramiko.ECDSAKey, 'test_ecdsa_256.key')
+ self._client_host_key_good(paramiko.ECDSAKey, "test_ecdsa_256.key")
def test_host_key_negotiation_4(self):
- self._client_host_key_good(paramiko.RSAKey, 'test_rsa.key')
+ self._client_host_key_good(paramiko.RSAKey, "test_rsa.key")
def _setup_for_env(self):
threading.Thread(target=self._run).start()
@@ -594,7 +599,9 @@ class SSHClientTest(ClientTest):
self.tc = paramiko.SSHClient()
self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.assertEqual(0, len(self.tc.get_host_keys()))
- self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion')
+ self.tc.connect(
+ self.addr, self.port, username="slowdive", password="pygmalion"
+ )
self.event.wait(1.0)
self.assertTrue(self.event.isSet())
@@ -605,11 +612,11 @@ class SSHClientTest(ClientTest):
Verify that environment variables can be set by the client.
"""
self._setup_for_env()
- target_env = {b'A': b'B', b'C': b'd'}
+ target_env = {b"A": b"B", b"C": b"d"}
- self.tc.exec_command('yes', environment=target_env)
+ self.tc.exec_command("yes", environment=target_env)
schan = self.ts.accept(1.0)
- self.assertEqual(target_env, getattr(schan, 'env', {}))
+ self.assertEqual(target_env, getattr(schan, "env", {}))
schan.close()
@unittest.skip("Clients normally fail silently, thus so do we, for now")
@@ -617,14 +624,14 @@ class SSHClientTest(ClientTest):
self._setup_for_env()
with self.assertRaises(SSHException) as manager:
# Verify that a rejection by the server can be detected
- self.tc.exec_command('yes', environment={b'INVALID_ENV': b''})
+ self.tc.exec_command("yes", environment={b"INVALID_ENV": b""})
self.assertTrue(
- 'INVALID_ENV' in str(manager.exception),
- 'Expected variable name in error message'
+ "INVALID_ENV" in str(manager.exception),
+ "Expected variable name in error message",
)
self.assertTrue(
isinstance(manager.exception.args[1], SSHException),
- 'Expected original SSHException in exception'
+ "Expected original SSHException in exception",
)
def test_missing_key_policy_accepts_classes_or_instances(self):
@@ -652,35 +659,39 @@ class PasswordPassphraseTests(ClientTest):
def test_password_kwarg_works_for_password_auth(self):
# Straightforward / duplicate of earlier basic password test.
- self._test_connection(password='pygmalion')
+ self._test_connection(password="pygmalion")
# TODO: more granular exception pending #387; should be signaling "no auth
# methods available" because no key and no password
@raises(SSHException)
def test_passphrase_kwarg_not_used_for_password_auth(self):
# Using the "right" password in the "wrong" field shouldn't work.
- self._test_connection(passphrase='pygmalion')
+ self._test_connection(passphrase="pygmalion")
def test_passphrase_kwarg_used_for_key_passphrase(self):
# Straightforward again, with new passphrase kwarg.
self._test_connection(
- key_filename=_support('test_rsa_password.key'),
- passphrase='television',
+ key_filename=_support("test_rsa_password.key"),
+ passphrase="television",
)
- def test_password_kwarg_used_for_passphrase_when_no_passphrase_kwarg_given(self): # noqa
+ def test_password_kwarg_used_for_passphrase_when_no_passphrase_kwarg_given(
+ self
+ ): # noqa
# Backwards compatibility: passphrase in the password field.
self._test_connection(
- key_filename=_support('test_rsa_password.key'),
- password='television',
+ key_filename=_support("test_rsa_password.key"),
+ password="television",
)
- @raises(AuthenticationException) # TODO: more granular
- def test_password_kwarg_not_used_for_passphrase_when_passphrase_kwarg_given(self): # noqa
+ @raises(AuthenticationException) # TODO: more granular
+ def test_password_kwarg_not_used_for_passphrase_when_passphrase_kwarg_given(
+ self
+ ): # noqa
# Sanity: if we're given both fields, the password field is NOT used as
# a passphrase.
self._test_connection(
- key_filename=_support('test_rsa_password.key'),
- password='television',
- passphrase='wat? lol no',
+ key_filename=_support("test_rsa_password.key"),
+ password="television",
+ passphrase="wat? lol no",
)
diff --git a/tests/test_config.py b/tests/test_config.py
index 20f051f..cbd3f62 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -7,6 +7,7 @@ from paramiko import config
from paramiko.util import parse_ssh_config
from paramiko.py3compat import StringIO
+
def test_SSHConfigDict_construct_empty():
assert not config.SSHConfigDict()
@@ -71,4 +72,3 @@ Host *
f = StringIO(test_config_file)
config = parse_ssh_config(f)
assert config.lookup("anything-else").as_int("port") == 3333
-
diff --git a/tests/test_file.py b/tests/test_file.py
index 3d2c94e..d299011 100644
--- a/tests/test_file.py
+++ b/tests/test_file.py
@@ -30,18 +30,19 @@ from paramiko.py3compat import BytesIO
from .util import needs_builtin
-class LoopbackFile (BufferedFile):
+class LoopbackFile(BufferedFile):
"""
BufferedFile object that you can write data into, and then read it back.
"""
- def __init__(self, mode='r', bufsize=-1):
+
+ def __init__(self, mode="r", bufsize=-1):
BufferedFile.__init__(self)
self._set_mode(mode, bufsize)
self.buffer = BytesIO()
self.offset = 0
def _read(self, size):
- data = self.buffer.getvalue()[self.offset:self.offset+size]
+ data = self.buffer.getvalue()[self.offset:self.offset + size]
self.offset += len(data)
return data
@@ -50,44 +51,46 @@ class LoopbackFile (BufferedFile):
return len(data)
-class BufferedFileTest (unittest.TestCase):
+class BufferedFileTest(unittest.TestCase):
def test_1_simple(self):
- f = LoopbackFile('r')
+ f = LoopbackFile("r")
try:
- f.write(b'hi')
- self.assertTrue(False, 'no exception on write to read-only file')
+ f.write(b"hi")
+ self.assertTrue(False, "no exception on write to read-only file")
except:
pass
f.close()
- f = LoopbackFile('w')
+ f = LoopbackFile("w")
try:
f.read(1)
- self.assertTrue(False, 'no exception to read from write-only file')
+ self.assertTrue(False, "no exception to read from write-only file")
except:
pass
f.close()
def test_2_readline(self):
- f = LoopbackFile('r+U')
- f.write(b'First line.\nSecond line.\r\nThird line.\n' +
- b'Fourth line.\nFinal line non-terminated.')
+ f = LoopbackFile("r+U")
+ f.write(
+ b"First line.\nSecond line.\r\nThird line.\n"
+ + b"Fourth line.\nFinal line non-terminated."
+ )
- self.assertEqual(f.readline(), 'First line.\n')
+ self.assertEqual(f.readline(), "First line.\n")
# universal newline mode should convert this linefeed:
- self.assertEqual(f.readline(), 'Second line.\n')
+ self.assertEqual(f.readline(), "Second line.\n")
# truncated line:
- self.assertEqual(f.readline(7), 'Third l')
- self.assertEqual(f.readline(), 'ine.\n')
+ self.assertEqual(f.readline(7), "Third l")
+ self.assertEqual(f.readline(), "ine.\n")
# newline should be detected and only the fourth line returned
- self.assertEqual(f.readline(39), 'Fourth line.\n')
- self.assertEqual(f.readline(), 'Final line non-terminated.')
- self.assertEqual(f.readline(), '')
+ self.assertEqual(f.readline(39), "Fourth line.\n")
+ self.assertEqual(f.readline(), "Final line non-terminated.")
+ self.assertEqual(f.readline(), "")
f.close()
try:
f.readline()
- self.assertTrue(False, 'no exception on readline of closed file')
+ self.assertTrue(False, "no exception on readline of closed file")
except IOError:
pass
self.assertTrue(linefeed_byte in f.newlines)
@@ -98,11 +101,11 @@ class BufferedFileTest (unittest.TestCase):
"""
try to trick the linefeed detector.
"""
- f = LoopbackFile('r+U')
- f.write(b'First line.\r')
- self.assertEqual(f.readline(), 'First line.\n')
- f.write(b'\nSecond.\r\n')
- self.assertEqual(f.readline(), 'Second.\n')
+ f = LoopbackFile("r+U")
+ f.write(b"First line.\r")
+ self.assertEqual(f.readline(), "First line.\n")
+ f.write(b"\nSecond.\r\n")
+ self.assertEqual(f.readline(), "Second.\n")
f.close()
self.assertEqual(f.newlines, crlf)
@@ -110,51 +113,54 @@ class BufferedFileTest (unittest.TestCase):
"""
verify that write buffering is on.
"""
- f = LoopbackFile('r+', 1)
- f.write(b'Complete line.\nIncomplete line.')
- self.assertEqual(f.readline(), 'Complete line.\n')
- self.assertEqual(f.readline(), '')
- f.write('..\n')
- self.assertEqual(f.readline(), 'Incomplete line...\n')
+ f = LoopbackFile("r+", 1)
+ f.write(b"Complete line.\nIncomplete line.")
+ self.assertEqual(f.readline(), "Complete line.\n")
+ self.assertEqual(f.readline(), "")
+ f.write("..\n")
+ self.assertEqual(f.readline(), "Incomplete line...\n")
f.close()
def test_5_flush(self):
"""
verify that flush will force a write.
"""
- f = LoopbackFile('r+', 512)
- f.write('Not\nquite\n512 bytes.\n')
- self.assertEqual(f.read(1), b'')
+ f = LoopbackFile("r+", 512)
+ f.write("Not\nquite\n512 bytes.\n")
+ self.assertEqual(f.read(1), b"")
f.flush()
- self.assertEqual(f.read(5), b'Not\nq')
- self.assertEqual(f.read(10), b'uite\n512 b')
- self.assertEqual(f.read(9), b'ytes.\n')
- self.assertEqual(f.read(3), b'')
+ self.assertEqual(f.read(5), b"Not\nq")
+ self.assertEqual(f.read(10), b"uite\n512 b")
+ self.assertEqual(f.read(9), b"ytes.\n")
+ self.assertEqual(f.read(3), b"")
f.close()
def test_6_buffering(self):
"""
verify that flushing happens automatically on buffer crossing.
"""
- f = LoopbackFile('r+', 16)
- f.write(b'Too small.')
- self.assertEqual(f.read(4), b'')
- f.write(b' ')
- self.assertEqual(f.read(4), b'')
- f.write(b'Enough.')
- self.assertEqual(f.read(20), b'Too small. Enough.')
+ f = LoopbackFile("r+", 16)
+ f.write(b"Too small.")
+ self.assertEqual(f.read(4), b"")
+ f.write(b" ")
+ self.assertEqual(f.read(4), b"")
+ f.write(b"Enough.")
+ self.assertEqual(f.read(20), b"Too small. Enough.")
f.close()
def test_7_read_all(self):
"""
verify that read(-1) returns everything left in the file.
"""
- f = LoopbackFile('r+', 16)
- f.write(b'The first thing you need to do is open your eyes. ')
- f.write(b'Then, you need to close them again.\n')
+ f = LoopbackFile("r+", 16)
+ f.write(b"The first thing you need to do is open your eyes. ")
+ f.write(b"Then, you need to close them again.\n")
s = f.read(-1)
- self.assertEqual(s, b'The first thing you need to do is open your eyes. Then, you ' +
- b'need to close them again.\n')
+ self.assertEqual(
+ s,
+ b"The first thing you need to do is open your eyes. Then, you "
+ + b"need to close them again.\n",
+ )
f.close()
def test_8_buffering(self):
@@ -162,19 +168,19 @@ class BufferedFileTest (unittest.TestCase):
verify that buffered objects can be written
"""
if sys.version_info[0] == 2:
- f = LoopbackFile('r+', 16)
- f.write(buffer(b'Too small.'))
+ f = LoopbackFile("r+", 16)
+ f.write(buffer(b"Too small."))
f.close()
def test_9_readable(self):
- f = LoopbackFile('r')
+ f = LoopbackFile("r")
self.assertTrue(f.readable())
self.assertFalse(f.writable())
self.assertFalse(f.seekable())
f.close()
def test_A_writable(self):
- f = LoopbackFile('w')
+ f = LoopbackFile("w")
self.assertTrue(f.writable())
self.assertFalse(f.readable())
self.assertFalse(f.seekable())
@@ -182,48 +188,49 @@ class BufferedFileTest (unittest.TestCase):
def test_B_readinto(self):
data = bytearray(5)
- f = LoopbackFile('r+')
+ f = LoopbackFile("r+")
f._write(b"hello")
f.readinto(data)
- self.assertEqual(data, b'hello')
+ self.assertEqual(data, b"hello")
f.close()
def test_write_bad_type(self):
- with LoopbackFile('wb') as f:
+ with LoopbackFile("wb") as f:
self.assertRaises(TypeError, f.write, object())
def test_write_unicode_as_binary(self):
text = u"\xa7 why is writing text to a binary file allowed?\n"
- with LoopbackFile('rb+') as f:
+ with LoopbackFile("rb+") as f:
f.write(text)
self.assertEqual(f.read(), text.encode("utf-8"))
- @needs_builtin('memoryview')
+ @needs_builtin("memoryview")
def test_write_bytearray(self):
- with LoopbackFile('rb+') as f:
+ with LoopbackFile("rb+") as f:
f.write(bytearray(12))
self.assertEqual(f.read(), 12 * b"\0")
- @needs_builtin('buffer')
+ @needs_builtin("buffer")
def test_write_buffer(self):
data = 3 * b"pretend giant block of data\n"
offsets = range(0, len(data), 8)
- with LoopbackFile('rb+') as f:
+ with LoopbackFile("rb+") as f:
for offset in offsets:
f.write(buffer(data, offset, 8))
self.assertEqual(f.read(), data)
- @needs_builtin('memoryview')
+ @needs_builtin("memoryview")
def test_write_memoryview(self):
data = 3 * b"pretend giant block of data\n"
offsets = range(0, len(data), 8)
- with LoopbackFile('rb+') as f:
+ with LoopbackFile("rb+") as f:
view = memoryview(data)
for offset in offsets:
- f.write(view[offset:offset+8])
+ f.write(view[offset:offset + 8])
self.assertEqual(f.read(), data)
-if __name__ == '__main__':
+if __name__ == "__main__":
from unittest import main
+
main()
diff --git a/tests/test_gssapi.py b/tests/test_gssapi.py
index d4b632b..d7fbdd5 100644
--- a/tests/test_gssapi.py
+++ b/tests/test_gssapi.py
@@ -30,6 +30,7 @@ from .util import needs_gssapi
@needs_gssapi
class GSSAPITest(unittest.TestCase):
+
def setup():
# TODO: these vars should all come from os.environ or whatever the
# approved pytest method is for runtime-configuring test data.
@@ -43,6 +44,7 @@ class GSSAPITest(unittest.TestCase):
"""
from pyasn1.type.univ import ObjectIdentifier
from pyasn1.codec.der import encoder, decoder
+
oid = encoder.encode(ObjectIdentifier(self.krb5_mech))
mech, __ = decoder.decode(oid)
self.assertEquals(self.krb5_mech, mech.__str__())
@@ -57,6 +59,7 @@ class GSSAPITest(unittest.TestCase):
except ImportError:
import sspicon
import sspi
+
_API = "SSPI"
c_token = None
@@ -65,23 +68,28 @@ class GSSAPITest(unittest.TestCase):
if _API == "MIT":
if self.server_mode:
- gss_flags = (gssapi.C_PROT_READY_FLAG,
- gssapi.C_INTEG_FLAG,
- gssapi.C_MUTUAL_FLAG,
- gssapi.C_DELEG_FLAG)
+ gss_flags = (
+ gssapi.C_PROT_READY_FLAG,
+ gssapi.C_INTEG_FLAG,
+ gssapi.C_MUTUAL_FLAG,
+ gssapi.C_DELEG_FLAG,
+ )
else:
- gss_flags = (gssapi.C_PROT_READY_FLAG,
- gssapi.C_INTEG_FLAG,
- gssapi.C_DELEG_FLAG)
+ gss_flags = (
+ gssapi.C_PROT_READY_FLAG,
+ gssapi.C_INTEG_FLAG,
+ gssapi.C_DELEG_FLAG,
+ )
# Initialize a GSS-API context.
ctx = gssapi.Context()
ctx.flags = gss_flags
krb5_oid = gssapi.OID.mech_from_string(self.krb5_mech)
- target_name = gssapi.Name("host@" + self.targ_name,
- gssapi.C_NT_HOSTBASED_SERVICE)
- gss_ctxt = gssapi.InitContext(peer_name=target_name,
- mech_type=krb5_oid,
- req_flags=ctx.flags)
+ target_name = gssapi.Name(
+ "host@" + self.targ_name, gssapi.C_NT_HOSTBASED_SERVICE
+ )
+ gss_ctxt = gssapi.InitContext(
+ peer_name=target_name, mech_type=krb5_oid, req_flags=ctx.flags
+ )
if self.server_mode:
c_token = gss_ctxt.step(c_token)
gss_ctxt_status = gss_ctxt.established
@@ -108,15 +116,15 @@ class GSSAPITest(unittest.TestCase):
self.assertEquals(0, status)
else:
gss_flags = (
- sspicon.ISC_REQ_INTEGRITY |
- sspicon.ISC_REQ_MUTUAL_AUTH |
- sspicon.ISC_REQ_DELEGATE
+ sspicon.ISC_REQ_INTEGRITY
+ | sspicon.ISC_REQ_MUTUAL_AUTH
+ | sspicon.ISC_REQ_DELEGATE
)
# Initialize a GSS-API context.
target_name = "host/" + socket.getfqdn(self.targ_name)
- gss_ctxt = sspi.ClientAuth("Kerberos",
- scflags=gss_flags,
- targetspn=target_name)
+ gss_ctxt = sspi.ClientAuth(
+ "Kerberos", scflags=gss_flags, targetspn=target_name
+ )
if self.server_mode:
error, token = gss_ctxt.authorize(c_token)
c_token = token[0].Buffer
diff --git a/tests/test_hostkeys.py b/tests/test_hostkeys.py
index cd75f8a..a1b7a9e 100644
--- a/tests/test_hostkeys.py
+++ b/tests/test_hostkeys.py
@@ -54,77 +54,80 @@ Ngw3qIch/WgRmMHy4kBq1SsXMjQCte1So6HBMvBPIW5SiMTmjCfZZiw4AYHK+B/JaOwaG9yRg2Ejg\
0d54U0X/NeX5QxuYR6OMJlrkQB7oiW/P/1mwjQgE="""
-class HostKeysTest (unittest.TestCase):
+class HostKeysTest(unittest.TestCase):
def setUp(self):
- with open('hostfile.temp', 'w') as f:
+ with open("hostfile.temp", "w") as f:
f.write(test_hosts_file)
def tearDown(self):
- os.unlink('hostfile.temp')
+ os.unlink("hostfile.temp")
def test_1_load(self):
- hostdict = paramiko.HostKeys('hostfile.temp')
+ hostdict = paramiko.HostKeys("hostfile.temp")
self.assertEqual(2, len(hostdict))
self.assertEqual(1, len(list(hostdict.values())[0]))
self.assertEqual(1, len(list(hostdict.values())[1]))
- fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper()
- self.assertEqual(b'E6684DB30E109B67B70FF1DC5C7F1363', fp)
+ fp = hexlify(
+ hostdict["secure.example.com"]["ssh-rsa"].get_fingerprint()
+ ).upper()
+ self.assertEqual(b"E6684DB30E109B67B70FF1DC5C7F1363", fp)
def test_2_add(self):
- hostdict = paramiko.HostKeys('hostfile.temp')
- hh = '|1|BMsIC6cUIP2zBuXR3t2LRcJYjzM=|hpkJMysjTk/+zzUUzxQEa2ieq6c='
+ hostdict = paramiko.HostKeys("hostfile.temp")
+ hh = "|1|BMsIC6cUIP2zBuXR3t2LRcJYjzM=|hpkJMysjTk/+zzUUzxQEa2ieq6c="
key = paramiko.RSAKey(data=decodebytes(keyblob))
- hostdict.add(hh, 'ssh-rsa', key)
+ hostdict.add(hh, "ssh-rsa", key)
self.assertEqual(3, len(list(hostdict)))
- x = hostdict['foo.example.com']
- fp = hexlify(x['ssh-rsa'].get_fingerprint()).upper()
- self.assertEqual(b'7EC91BB336CB6D810B124B1353C32396', fp)
- self.assertTrue(hostdict.check('foo.example.com', key))
+ x = hostdict["foo.example.com"]
+ fp = hexlify(x["ssh-rsa"].get_fingerprint()).upper()
+ self.assertEqual(b"7EC91BB336CB6D810B124B1353C32396", fp)
+ self.assertTrue(hostdict.check("foo.example.com", key))
def test_3_dict(self):
- hostdict = paramiko.HostKeys('hostfile.temp')
- self.assertTrue('secure.example.com' in hostdict)
- self.assertTrue('not.example.com' not in hostdict)
- self.assertTrue('secure.example.com' in hostdict)
- self.assertTrue('not.example.com' not in hostdict)
- x = hostdict.get('secure.example.com', None)
+ hostdict = paramiko.HostKeys("hostfile.temp")
+ self.assertTrue("secure.example.com" in hostdict)
+ self.assertTrue("not.example.com" not in hostdict)
+ self.assertTrue("secure.example.com" in hostdict)
+ self.assertTrue("not.example.com" not in hostdict)
+ x = hostdict.get("secure.example.com", None)
self.assertTrue(x is not None)
- fp = hexlify(x['ssh-rsa'].get_fingerprint()).upper()
- self.assertEqual(b'E6684DB30E109B67B70FF1DC5C7F1363', fp)
+ fp = hexlify(x["ssh-rsa"].get_fingerprint()).upper()
+ self.assertEqual(b"E6684DB30E109B67B70FF1DC5C7F1363", fp)
i = 0
for key in hostdict:
i += 1
self.assertEqual(2, i)
-
+
def test_4_dict_set(self):
- hostdict = paramiko.HostKeys('hostfile.temp')
+ hostdict = paramiko.HostKeys("hostfile.temp")
key = paramiko.RSAKey(data=decodebytes(keyblob))
key_dss = paramiko.DSSKey(data=decodebytes(keyblob_dss))
- hostdict['secure.example.com'] = {
- 'ssh-rsa': key,
- 'ssh-dss': key_dss
- }
- hostdict['fake.example.com'] = {}
- hostdict['fake.example.com']['ssh-rsa'] = key
-
+ hostdict["secure.example.com"] = {"ssh-rsa": key, "ssh-dss": key_dss}
+ hostdict["fake.example.com"] = {}
+ hostdict["fake.example.com"]["ssh-rsa"] = key
+
self.assertEqual(3, len(hostdict))
self.assertEqual(2, len(list(hostdict.values())[0]))
self.assertEqual(1, len(list(hostdict.values())[1]))
self.assertEqual(1, len(list(hostdict.values())[2]))
- fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper()
- self.assertEqual(b'7EC91BB336CB6D810B124B1353C32396', fp)
- fp = hexlify(hostdict['secure.example.com']['ssh-dss'].get_fingerprint()).upper()
- self.assertEqual(b'4478F0B9A23CC5182009FF755BC1D26C', fp)
+ fp = hexlify(
+ hostdict["secure.example.com"]["ssh-rsa"].get_fingerprint()
+ ).upper()
+ self.assertEqual(b"7EC91BB336CB6D810B124B1353C32396", fp)
+ fp = hexlify(
+ hostdict["secure.example.com"]["ssh-dss"].get_fingerprint()
+ ).upper()
+ self.assertEqual(b"4478F0B9A23CC5182009FF755BC1D26C", fp)
def test_delitem(self):
- hostdict = paramiko.HostKeys('hostfile.temp')
- target = 'happy.example.com'
- entry = hostdict[target] # will KeyError if not present
+ hostdict = paramiko.HostKeys("hostfile.temp")
+ target = "happy.example.com"
+ entry = hostdict[target] # will KeyError if not present
del hostdict[target]
try:
entry = hostdict[target]
except KeyError:
- pass # Good
+ pass # Good
else:
assert False, "Entry was not deleted from HostKeys on delitem!"
diff --git a/tests/test_kex.py b/tests/test_kex.py
index b5808e7..41e2dea 100644
--- a/tests/test_kex.py
+++ b/tests/test_kex.py
@@ -38,29 +38,45 @@ from paramiko.kex_ecdh_nist import KexNistp256
def dummy_urandom(n):
return byte_chr(0xcc) * n
+
def dummy_generate_key_pair(obj):
private_key_value = 94761803665136558137557783047955027733968423115106677159790289642479432803037
public_key_numbers = "042bdab212fa8ba1b7c843301682a4db424d307246c7e1e6083c41d9ca7b098bf30b3d63e2ec6278488c135360456cc054b3444ecc45998c08894cbc1370f5f989"
- public_key_numbers_obj = ec.EllipticCurvePublicNumbers.from_encoded_point(ec.SECP256R1(), unhexlify(public_key_numbers))
- obj.P = ec.EllipticCurvePrivateNumbers(private_value=private_key_value, public_numbers=public_key_numbers_obj).private_key(default_backend())
+ public_key_numbers_obj = ec.EllipticCurvePublicNumbers.from_encoded_point(
+ ec.SECP256R1(), unhexlify(public_key_numbers)
+ )
+ obj.P = ec.EllipticCurvePrivateNumbers(
+ private_value=private_key_value, public_numbers=public_key_numbers_obj
+ ).private_key(
+ default_backend()
+ )
if obj.transport.server_mode:
- obj.Q_S = ec.EllipticCurvePublicNumbers.from_encoded_point(ec.SECP256R1(), unhexlify(public_key_numbers)).public_key(default_backend())
+ obj.Q_S = ec.EllipticCurvePublicNumbers.from_encoded_point(
+ ec.SECP256R1(), unhexlify(public_key_numbers)
+ ).public_key(
+ default_backend()
+ )
return
- obj.Q_C = ec.EllipticCurvePublicNumbers.from_encoded_point(ec.SECP256R1(), unhexlify(public_key_numbers)).public_key(default_backend())
+ obj.Q_C = ec.EllipticCurvePublicNumbers.from_encoded_point(
+ ec.SECP256R1(), unhexlify(public_key_numbers)
+ ).public_key(
+ default_backend()
+ )
+
+class FakeKey(object):
-class FakeKey (object):
def __str__(self):
- return 'fake-key'
+ return "fake-key"
def asbytes(self):
- return b'fake-key'
+ return b"fake-key"
def sign_ssh_data(self, H):
- return b'fake-sig'
+ return b"fake-sig"
-class FakeModulusPack (object):
+class FakeModulusPack(object):
P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF
G = 2
@@ -69,10 +85,10 @@ class FakeModulusPack (object):
class FakeTransport(object):
- local_version = 'SSH-2.0-paramiko_1.0'
- remote_version = 'SSH-2.0-lame'
- local_kex_init = 'local-kex-init'
- remote_kex_init = 'remote-kex-init'
+ local_version = "SSH-2.0-paramiko_1.0"
+ remote_version = "SSH-2.0-lame"
+ local_kex_init = "local-kex-init"
+ remote_kex_init = "remote-kex-init"
def _send_message(self, m):
self._message = m
@@ -100,7 +116,7 @@ class FakeTransport(object):
return FakeModulusPack()
-class KexTest (unittest.TestCase):
+class KexTest(unittest.TestCase):
K = 14730343317708716439807310032871972459448364195094179797249681733965528989482751523943515690110179031004049109375612685505881911274101441415545039654102474376472240501616988799699744135291070488314748284283496055223852115360852283821334858541043710301057312858051901453919067023103730011648890038847384890504
@@ -119,21 +135,23 @@ class KexTest (unittest.TestCase):
transport.server_mode = False
kex = KexGroup1(transport)
kex.start_kex()
- x = b'1E000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4'
+ x = b"1E000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4"
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
- self.assertEqual((paramiko.kex_group1._MSG_KEXDH_REPLY,), transport._expect)
+ self.assertEqual(
+ (paramiko.kex_group1._MSG_KEXDH_REPLY,), transport._expect
+ )
# fake "reply"
msg = Message()
- msg.add_string('fake-host-key')
+ msg.add_string("fake-host-key")
msg.add_mpint(69)
- msg.add_string('fake-sig')
+ msg.add_string("fake-sig")
msg.rewind()
kex.parse_next(paramiko.kex_group1._MSG_KEXDH_REPLY, msg)
- H = b'03079780F3D3AD0B3C6DB30C8D21685F367A86D2'
+ H = b"03079780F3D3AD0B3C6DB30C8D21685F367A86D2"
self.assertEqual(self.K, transport._K)
self.assertEqual(H, hexlify(transport._H).upper())
- self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify)
+ self.assertEqual((b"fake-host-key", b"fake-sig"), transport._verify)
self.assertTrue(transport._activated)
def test_2_group1_server(self):
@@ -141,14 +159,16 @@ class KexTest (unittest.TestCase):
transport.server_mode = True
kex = KexGroup1(transport)
kex.start_kex()
- self.assertEqual((paramiko.kex_group1._MSG_KEXDH_INIT,), transport._expect)
+ self.assertEqual(
+ (paramiko.kex_group1._MSG_KEXDH_INIT,), transport._expect
+ )
msg = Message()
msg.add_mpint(69)
msg.rewind()
kex.parse_next(paramiko.kex_group1._MSG_KEXDH_INIT, msg)
- H = b'B16BF34DD10945EDE84E9C1EF24A14BFDC843389'
- x = b'1F0000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967'
+ H = b"B16BF34DD10945EDE84E9C1EF24A14BFDC843389"
+ x = b"1F0000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967"
self.assertEqual(self.K, transport._K)
self.assertEqual(H, hexlify(transport._H).upper())
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
@@ -159,29 +179,33 @@ class KexTest (unittest.TestCase):
transport.server_mode = False
kex = KexGex(transport)
kex.start_kex()
- x = b'22000004000000080000002000'
+ x = b"22000004000000080000002000"
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
- self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect)
+ self.assertEqual(
+ (paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect
+ )
msg = Message()
msg.add_mpint(FakeModulusPack.P)
msg.add_mpint(FakeModulusPack.G)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg)
- x = b'20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4'
+ x = b"20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4"
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
- self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect)
+ self.assertEqual(
+ (paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect
+ )
msg = Message()
- msg.add_string('fake-host-key')
+ msg.add_string("fake-host-key")
msg.add_mpint(69)
- msg.add_string('fake-sig')
+ msg.add_string("fake-sig")
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg)
- H = b'A265563F2FA87F1A89BF007EE90D58BE2E4A4BD0'
+ H = b"A265563F2FA87F1A89BF007EE90D58BE2E4A4BD0"
self.assertEqual(self.K, transport._K)
self.assertEqual(H, hexlify(transport._H).upper())
- self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify)
+ self.assertEqual((b"fake-host-key", b"fake-sig"), transport._verify)
self.assertTrue(transport._activated)
def test_4_gex_old_client(self):
@@ -189,37 +213,47 @@ class KexTest (unittest.TestCase):
transport.server_mode = False
kex = KexGex(transport)
kex.start_kex(_test_old_style=True)
- x = b'1E00000800'
+ x = b"1E00000800"
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
- self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect)
+ self.assertEqual(
+ (paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect
+ )
msg = Message()
msg.add_mpint(FakeModulusPack.P)
msg.add_mpint(FakeModulusPack.G)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg)
- x = b'20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4'
+ x = b"20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4"
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
- self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect)
+ self.assertEqual(
+ (paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect
+ )
msg = Message()
- msg.add_string('fake-host-key')
+ msg.add_string("fake-host-key")
msg.add_mpint(69)
- msg.add_string('fake-sig')
+ msg.add_string("fake-sig")
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg)
- H = b'807F87B269EF7AC5EC7E75676808776A27D5864C'
+ H = b"807F87B269EF7AC5EC7E75676808776A27D5864C"
self.assertEqual(self.K, transport._K)
self.assertEqual(H, hexlify(transport._H).upper())
- self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify)
+ self.assertEqual((b"fake-host-key", b"fake-sig"), transport._verify)
self.assertTrue(transport._activated)
-
+
def test_5_gex_server(self):
transport = FakeTransport()
transport.server_mode = True
kex = KexGex(transport)
kex.start_kex()
- self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect)
+ self.assertEqual(
+ (
+ paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST,
+ paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD,
+ ),
+ transport._expect,
+ )
msg = Message()
msg.add_int(1024)
@@ -227,17 +261,19 @@ class KexTest (unittest.TestCase):
msg.add_int(4096)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, msg)
- x = b'1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102'
+ x = b"1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102"
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
- self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect)
+ self.assertEqual(
+ (paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect
+ )
msg = Message()
msg.add_mpint(12345)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg)
K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581
- H = b'CE754197C21BF3452863B4F44D0B3951F12516EF'
- x = b'210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967'
+ H = b"CE754197C21BF3452863B4F44D0B3951F12516EF"
+ x = b"210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967"
self.assertEqual(K, transport._K)
self.assertEqual(H, hexlify(transport._H).upper())
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
@@ -248,23 +284,31 @@ class KexTest (unittest.TestCase):
transport.server_mode = True
kex = KexGex(transport)
kex.start_kex()
- self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect)
+ self.assertEqual(
+ (
+ paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST,
+ paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD,
+ ),
+ transport._expect,
+ )
msg = Message()
msg.add_int(2048)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD, msg)
- x = b'1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102'
+ x = b"1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102"
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
- self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect)
+ self.assertEqual(
+ (paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect
+ )
msg = Message()
msg.add_mpint(12345)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg)
K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581
- H = b'B41A06B2E59043CEFC1AE16EC31F1E2D12EC455B'
- x = b'210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967'
+ H = b"B41A06B2E59043CEFC1AE16EC31F1E2D12EC455B"
+ x = b"210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967"
self.assertEqual(K, transport._K)
self.assertEqual(H, hexlify(transport._H).upper())
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
@@ -275,29 +319,33 @@ class KexTest (unittest.TestCase):
transport.server_mode = False
kex = KexGexSHA256(transport)
kex.start_kex()
- x = b'22000004000000080000002000'
+ x = b"22000004000000080000002000"
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
- self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect)
+ self.assertEqual(
+ (paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect
+ )
msg = Message()
msg.add_mpint(FakeModulusPack.P)
msg.add_mpint(FakeModulusPack.G)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg)
- x = b'20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4'
+ x = b"20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4"
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
- self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect)
+ self.assertEqual(
+ (paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect
+ )
msg = Message()
- msg.add_string('fake-host-key')
+ msg.add_string("fake-host-key")
msg.add_mpint(69)
- msg.add_string('fake-sig')
+ msg.add_string("fake-sig")
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg)
- H = b'AD1A9365A67B4496F05594AD1BF656E3CDA0851289A4C1AFF549FEAE50896DF4'
+ H = b"AD1A9365A67B4496F05594AD1BF656E3CDA0851289A4C1AFF549FEAE50896DF4"
self.assertEqual(self.K, transport._K)
self.assertEqual(H, hexlify(transport._H).upper())
- self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify)
+ self.assertEqual((b"fake-host-key", b"fake-sig"), transport._verify)
self.assertTrue(transport._activated)
def test_8_gex_sha256_old_client(self):
@@ -305,29 +353,33 @@ class KexTest (unittest.TestCase):
transport.server_mode = False
kex = KexGexSHA256(transport)
kex.start_kex(_test_old_style=True)
- x = b'1E00000800'
+ x = b"1E00000800"
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
- self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect)
+ self.assertEqual(
+ (paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect
+ )
msg = Message()
msg.add_mpint(FakeModulusPack.P)
msg.add_mpint(FakeModulusPack.G)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg)
- x = b'20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4'
+ x = b"20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4"
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
- self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect)
+ self.assertEqual(
+ (paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect
+ )
msg = Message()
- msg.add_string('fake-host-key')
+ msg.add_string("fake-host-key")
msg.add_mpint(69)
- msg.add_string('fake-sig')
+ msg.add_string("fake-sig")
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg)
- H = b'518386608B15891AE5237DEE08DCADDE76A0BCEFCE7F6DB3AD66BC41D256DFE5'
+ H = b"518386608B15891AE5237DEE08DCADDE76A0BCEFCE7F6DB3AD66BC41D256DFE5"
self.assertEqual(self.K, transport._K)
self.assertEqual(H, hexlify(transport._H).upper())
- self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify)
+ self.assertEqual((b"fake-host-key", b"fake-sig"), transport._verify)
self.assertTrue(transport._activated)
def test_9_gex_sha256_server(self):
@@ -335,7 +387,13 @@ class KexTest (unittest.TestCase):
transport.server_mode = True
kex = KexGexSHA256(transport)
kex.start_kex()
- self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect)
+ self.assertEqual(
+ (
+ paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST,
+ paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD,
+ ),
+ transport._expect,
+ )
msg = Message()
msg.add_int(1024)
@@ -343,17 +401,19 @@ class KexTest (unittest.TestCase):
msg.add_int(4096)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, msg)
- x = b'1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102'
+ x = b"1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102"
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
- self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect)
+ self.assertEqual(
+ (paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect
+ )
msg = Message()
msg.add_mpint(12345)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg)
K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581
- H = b'CCAC0497CF0ABA1DBF55E1A3995D17F4CC31824B0E8D95CDF8A06F169D050D80'
- x = b'210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967'
+ H = b"CCAC0497CF0ABA1DBF55E1A3995D17F4CC31824B0E8D95CDF8A06F169D050D80"
+ x = b"210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967"
self.assertEqual(K, transport._K)
self.assertEqual(H, hexlify(transport._H).upper())
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
@@ -364,23 +424,31 @@ class KexTest (unittest.TestCase):
transport.server_mode = True
kex = KexGexSHA256(transport)
kex.start_kex()
- self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect)
+ self.assertEqual(
+ (
+ paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST,
+ paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD,
+ ),
+ transport._expect,
+ )
msg = Message()
msg.add_int(2048)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD, msg)
- x = b'1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102'
+ x = b"1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102"
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
- self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect)
+ self.assertEqual(
+ (paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect
+ )
msg = Message()
msg.add_mpint(12345)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg)
K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581
- H = b'3DDD2AD840AD095E397BA4D0573972DC60F6461FD38A187CACA6615A5BC8ADBB'
- x = b'210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967'
+ H = b"3DDD2AD840AD095E397BA4D0573972DC60F6461FD38A187CACA6615A5BC8ADBB"
+ x = b"210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967"
self.assertEqual(K, transport._K)
self.assertEqual(H, hexlify(transport._H).upper())
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
@@ -392,20 +460,24 @@ class KexTest (unittest.TestCase):
transport.server_mode = False
kex = KexNistp256(transport)
kex.start_kex()
- self.assertEqual((paramiko.kex_ecdh_nist._MSG_KEXECDH_REPLY,), transport._expect)
+ self.assertEqual(
+ (paramiko.kex_ecdh_nist._MSG_KEXECDH_REPLY,), transport._expect
+ )
- #fake reply
+ # fake reply
msg = Message()
- msg.add_string('fake-host-key')
- Q_S = unhexlify("043ae159594ba062efa121480e9ef136203fa9ec6b6e1f8723a321c16e62b945f573f3b822258cbcd094b9fa1c125cbfe5f043280893e66863cc0cb4dccbe70210")
+ msg.add_string("fake-host-key")
+ Q_S = unhexlify(
+ "043ae159594ba062efa121480e9ef136203fa9ec6b6e1f8723a321c16e62b945f573f3b822258cbcd094b9fa1c125cbfe5f043280893e66863cc0cb4dccbe70210"
+ )
msg.add_string(Q_S)
- msg.add_string('fake-sig')
+ msg.add_string("fake-sig")
msg.rewind()
kex.parse_next(paramiko.kex_ecdh_nist._MSG_KEXECDH_REPLY, msg)
- H = b'BAF7CE243A836037EB5D2221420F35C02B9AB6C957FE3BDE3369307B9612570A'
+ H = b"BAF7CE243A836037EB5D2221420F35C02B9AB6C957FE3BDE3369307B9612570A"
self.assertEqual(K, kex.transport._K)
self.assertEqual(H, hexlify(transport._H).upper())
- self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify)
+ self.assertEqual((b"fake-host-key", b"fake-sig"), transport._verify)
self.assertTrue(transport._activated)
def test_12_kex_nistp256_server(self):
@@ -414,12 +486,16 @@ class KexTest (unittest.TestCase):
transport.server_mode = True
kex = KexNistp256(transport)
kex.start_kex()
- self.assertEqual((paramiko.kex_ecdh_nist._MSG_KEXECDH_INIT,), transport._expect)
+ self.assertEqual(
+ (paramiko.kex_ecdh_nist._MSG_KEXECDH_INIT,), transport._expect
+ )
- #fake init
- msg=Message()
- Q_C = unhexlify("043ae159594ba062efa121480e9ef136203fa9ec6b6e1f8723a321c16e62b945f573f3b822258cbcd094b9fa1c125cbfe5f043280893e66863cc0cb4dccbe70210")
- H = b'2EF4957AFD530DD3F05DBEABF68D724FACC060974DA9704F2AEE4C3DE861E7CA'
+ # fake init
+ msg = Message()
+ Q_C = unhexlify(
+ "043ae159594ba062efa121480e9ef136203fa9ec6b6e1f8723a321c16e62b945f573f3b822258cbcd094b9fa1c125cbfe5f043280893e66863cc0cb4dccbe70210"
+ )
+ H = b"2EF4957AFD530DD3F05DBEABF68D724FACC060974DA9704F2AEE4C3DE861E7CA"
msg.add_string(Q_C)
msg.rewind()
kex.parse_next(paramiko.kex_ecdh_nist._MSG_KEXECDH_INIT, msg)
diff --git a/tests/test_kex_gss.py b/tests/test_kex_gss.py
index 025d1fa..afddee0 100644
--- a/tests/test_kex_gss.py
+++ b/tests/test_kex_gss.py
@@ -34,14 +34,14 @@ import paramiko
from .util import needs_gssapi
-class NullServer (paramiko.ServerInterface):
+class NullServer(paramiko.ServerInterface):
def get_allowed_auths(self, username):
- return 'gssapi-keyex'
+ return "gssapi-keyex"
- def check_auth_gssapi_keyex(self, username,
- gss_authenticated=paramiko.AUTH_FAILED,
- cc_file=None):
+ def check_auth_gssapi_keyex(
+ self, username, gss_authenticated=paramiko.AUTH_FAILED, cc_file=None
+ ):
if gss_authenticated == paramiko.AUTH_SUCCESSFUL:
return paramiko.AUTH_SUCCESSFUL
return paramiko.AUTH_FAILED
@@ -54,13 +54,14 @@ class NullServer (paramiko.ServerInterface):
return paramiko.OPEN_SUCCEEDED
def check_channel_exec_request(self, channel, command):
- if command != 'yes':
+ if command != "yes":
return False
return True
@needs_gssapi
class GSSKexTest(unittest.TestCase):
+
@staticmethod
def init(username, hostname):
global krb5_principal, targ_name
@@ -86,13 +87,13 @@ class GSSKexTest(unittest.TestCase):
def _run(self):
self.socks, addr = self.sockl.accept()
self.ts = paramiko.Transport(self.socks, gss_kex=True)
- host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
+ host_key = paramiko.RSAKey.from_private_key_file("tests/test_rsa.key")
self.ts.add_server_key(host_key)
self.ts.set_gss_host(targ_name)
try:
self.ts.load_server_moduli()
except:
- print ('(Failed to load moduli -- gex will be unsupported.)')
+ print("(Failed to load moduli -- gex will be unsupported.)")
server = NullServer()
self.ts.start_server(self.event, server)
@@ -102,14 +103,21 @@ class GSSKexTest(unittest.TestCase):
Diffie-Hellman Key Exchange and user authentication with the GSS-API
context created during key exchange.
"""
- host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
+ host_key = paramiko.RSAKey.from_private_key_file("tests/test_rsa.key")
public_host_key = paramiko.RSAKey(data=host_key.asbytes())
self.tc = paramiko.SSHClient()
- self.tc.get_host_keys().add('[%s]:%d' % (self.hostname, self.port),
- 'ssh-rsa', public_host_key)
- self.tc.connect(self.hostname, self.port, username=self.username,
- gss_auth=True, gss_kex=True, gss_host=gss_host)
+ self.tc.get_host_keys().add(
+ "[%s]:%d" % (self.hostname, self.port), "ssh-rsa", public_host_key
+ )
+ self.tc.connect(
+ self.hostname,
+ self.port,
+ username=self.username,
+ gss_auth=True,
+ gss_kex=True,
+ gss_host=gss_host,
+ )
self.event.wait(1.0)
self.assert_(self.event.is_set())
@@ -118,19 +126,19 @@ class GSSKexTest(unittest.TestCase):
self.assertEquals(True, self.ts.is_authenticated())
self.assertEquals(True, self.tc.get_transport().gss_kex_used)
- stdin, stdout, stderr = self.tc.exec_command('yes')
+ stdin, stdout, stderr = self.tc.exec_command("yes")
schan = self.ts.accept(1.0)
if rekey:
self.tc.get_transport().renegotiate_keys()
- schan.send('Hello there.\n')
- schan.send_stderr('This is on stderr.\n')
+ schan.send("Hello there.\n")
+ schan.send_stderr("This is on stderr.\n")
schan.close()
- self.assertEquals('Hello there.\n', stdout.readline())
- self.assertEquals('', stdout.readline())
- self.assertEquals('This is on stderr.\n', stderr.readline())
- self.assertEquals('', stderr.readline())
+ self.assertEquals("Hello there.\n", stdout.readline())
+ self.assertEquals("", stdout.readline())
+ self.assertEquals("This is on stderr.\n", stderr.readline())
+ self.assertEquals("", stderr.readline())
stdin.close()
stdout.close()
diff --git a/tests/test_message.py b/tests/test_message.py
index 645b050..e6b80f3 100644
--- a/tests/test_message.py
+++ b/tests/test_message.py
@@ -26,20 +26,20 @@ from paramiko.message import Message
from paramiko.common import byte_chr, zero_byte
-class MessageTest (unittest.TestCase):
+class MessageTest(unittest.TestCase):
- __a = b'\x00\x00\x00\x17\x07\x60\xe0\x90\x00\x00\x00\x01\x71\x00\x00\x00\x05\x68\x65\x6c\x6c\x6f\x00\x00\x03\xe8' + b'x' * 1000
- __b = b'\x01\x00\xf3\x00\x3f\x00\x00\x00\x10\x68\x75\x65\x79\x2c\x64\x65\x77\x65\x79\x2c\x6c\x6f\x75\x69\x65'
- __c = b'\x00\x00\x00\x00\x00\x00\x00\x05\x00\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x01\x11\x00\x00\x00\x07\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x06\x9a\x1b\x2c\x3d\x4e\xf7'
- __d = b'\x00\x00\x00\x05\xff\x00\x00\x00\x05\x11\x22\x33\x44\x55\xff\x00\x00\x00\x0a\x00\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x63\x61\x74\x00\x00\x00\x03\x61\x2c\x62'
+ __a = b"\x00\x00\x00\x17\x07\x60\xe0\x90\x00\x00\x00\x01\x71\x00\x00\x00\x05\x68\x65\x6c\x6c\x6f\x00\x00\x03\xe8" + b"x" * 1000
+ __b = b"\x01\x00\xf3\x00\x3f\x00\x00\x00\x10\x68\x75\x65\x79\x2c\x64\x65\x77\x65\x79\x2c\x6c\x6f\x75\x69\x65"
+ __c = b"\x00\x00\x00\x00\x00\x00\x00\x05\x00\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x01\x11\x00\x00\x00\x07\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x06\x9a\x1b\x2c\x3d\x4e\xf7"
+ __d = b"\x00\x00\x00\x05\xff\x00\x00\x00\x05\x11\x22\x33\x44\x55\xff\x00\x00\x00\x0a\x00\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x63\x61\x74\x00\x00\x00\x03\x61\x2c\x62"
def test_1_encode(self):
msg = Message()
msg.add_int(23)
msg.add_int(123789456)
- msg.add_string('q')
- msg.add_string('hello')
- msg.add_string('x' * 1000)
+ msg.add_string("q")
+ msg.add_string("hello")
+ msg.add_string("x" * 1000)
self.assertEqual(msg.asbytes(), self.__a)
msg = Message()
@@ -48,7 +48,7 @@ class MessageTest (unittest.TestCase):
msg.add_byte(byte_chr(0xf3))
msg.add_bytes(zero_byte + byte_chr(0x3f))
- msg.add_list(['huey', 'dewey', 'louie'])
+ msg.add_list(["huey", "dewey", "louie"])
self.assertEqual(msg.asbytes(), self.__b)
msg = Message()
@@ -63,16 +63,16 @@ class MessageTest (unittest.TestCase):
msg = Message(self.__a)
self.assertEqual(msg.get_int(), 23)
self.assertEqual(msg.get_int(), 123789456)
- self.assertEqual(msg.get_text(), 'q')
- self.assertEqual(msg.get_text(), 'hello')
- self.assertEqual(msg.get_text(), 'x' * 1000)
+ self.assertEqual(msg.get_text(), "q")
+ self.assertEqual(msg.get_text(), "hello")
+ self.assertEqual(msg.get_text(), "x" * 1000)
msg = Message(self.__b)
self.assertEqual(msg.get_boolean(), True)
self.assertEqual(msg.get_boolean(), False)
self.assertEqual(msg.get_byte(), byte_chr(0xf3))
self.assertEqual(msg.get_bytes(2), zero_byte + byte_chr(0x3f))
- self.assertEqual(msg.get_list(), ['huey', 'dewey', 'louie'])
+ self.assertEqual(msg.get_list(), ["huey", "dewey", "louie"])
msg = Message(self.__c)
self.assertEqual(msg.get_int64(), 5)
@@ -87,8 +87,8 @@ class MessageTest (unittest.TestCase):
msg.add(0x1122334455)
msg.add(0xf00000000000000000)
msg.add(True)
- msg.add('cat')
- msg.add(['a', 'b'])
+ msg.add("cat")
+ msg.add(["a", "b"])
self.assertEqual(msg.asbytes(), self.__d)
def test_4_misc(self):
diff --git a/tests/test_packetizer.py b/tests/test_packetizer.py
index 414b7e3..dbe5993 100644
--- a/tests/test_packetizer.py
+++ b/tests/test_packetizer.py
@@ -36,19 +36,20 @@ from .loop import LoopSocket
x55 = byte_chr(0x55)
x1f = byte_chr(0x1f)
-class PacketizerTest (unittest.TestCase):
+
+class PacketizerTest(unittest.TestCase):
def test_1_write(self):
rsock = LoopSocket()
wsock = LoopSocket()
rsock.link(wsock)
p = Packetizer(wsock)
- p.set_log(util.get_logger('paramiko.transport'))
+ p.set_log(util.get_logger("paramiko.transport"))
p.set_hexdump(True)
encryptor = Cipher(
algorithms.AES(zero_byte * 16),
modes.CBC(x55 * 16),
- backend=default_backend()
+ backend=default_backend(),
).encryptor()
p.set_outbound_cipher(encryptor, 16, sha1, 12, x1f * 20)
@@ -63,22 +64,27 @@ class PacketizerTest (unittest.TestCase):
data = rsock.recv(100)
# 32 + 12 bytes of MAC = 44
self.assertEqual(44, len(data))
- self.assertEqual(b'\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0', data[:16])
+ self.assertEqual(
+ b"\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0",
+ data[:16],
+ )
def test_2_read(self):
rsock = LoopSocket()
wsock = LoopSocket()
rsock.link(wsock)
p = Packetizer(rsock)
- p.set_log(util.get_logger('paramiko.transport'))
+ p.set_log(util.get_logger("paramiko.transport"))
p.set_hexdump(True)
decryptor = Cipher(
algorithms.AES(zero_byte * 16),
modes.CBC(x55 * 16),
- backend=default_backend()
+ backend=default_backend(),
).decryptor()
p.set_inbound_cipher(decryptor, 16, sha1, 12, x1f * 20)
- wsock.send(b'\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0\x90\xd2\x16\x56\x0d\x71\x73\x61\x38\x7c\x4c\x3d\xfb\x97\x7d\xe2\x6e\x03\xb1\xa0\xc2\x1c\xd6\x41\x41\x4c\xb4\x59')
+ wsock.send(
+ b"\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0\x90\xd2\x16\x56\x0d\x71\x73\x61\x38\x7c\x4c\x3d\xfb\x97\x7d\xe2\x6e\x03\xb1\xa0\xc2\x1c\xd6\x41\x41\x4c\xb4\x59"
+ )
cmd, m = p.read_message()
self.assertEqual(100, cmd)
self.assertEqual(100, m.get_int())
@@ -86,18 +92,18 @@ class PacketizerTest (unittest.TestCase):
self.assertEqual(900, m.get_int())
def test_3_closed(self):
- if sys.platform.startswith("win"): # no SIGALRM on windows
+ if sys.platform.startswith("win"): # no SIGALRM on windows
return
rsock = LoopSocket()
wsock = LoopSocket()
rsock.link(wsock)
p = Packetizer(wsock)
- p.set_log(util.get_logger('paramiko.transport'))
+ p.set_log(util.get_logger("paramiko.transport"))
p.set_hexdump(True)
encryptor = Cipher(
algorithms.AES(zero_byte * 16),
modes.CBC(x55 * 16),
- backend=default_backend()
+ backend=default_backend(),
).encryptor()
p.set_outbound_cipher(encryptor, 16, sha1, 12, x1f * 20)
@@ -115,14 +121,17 @@ class PacketizerTest (unittest.TestCase):
import signal
class TimeoutError(Exception):
+
def __init__(self, error_message):
- if hasattr(errno, 'ETIME'):
+ if hasattr(errno, "ETIME"):
self.message = os.sterror(errno.ETIME)
else:
self.messaage = error_message
- def timeout(seconds=1, error_message='Timer expired'):
+ def timeout(seconds=1, error_message="Timer expired"):
+
def decorator(func):
+
def _handle_timeout(signum, frame):
raise TimeoutError(error_message)
@@ -138,5 +147,6 @@ class PacketizerTest (unittest.TestCase):
return wraps(func)(wrapper)
return decorator
+
send = timeout()(p.send_message)
self.assertRaises(EOFError, send, m)
diff --git a/tests/test_pkey.py b/tests/test_pkey.py
index 1827d2a..66bebc4 100644
--- a/tests/test_pkey.py
+++ b/tests/test_pkey.py
@@ -34,18 +34,18 @@ from .util import _support
# from openssh's ssh-keygen
-PUB_RSA = 'ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAIEA049W6geFpmsljTwfvI1UmKWWJPNFI74+vNKTk4dmzkQY2yAMs6FhlvhlI8ysU4oj71ZsRYMecHbBbxdN79+JRFVYTKaLqjwGENeTd+yv4q+V2PvZv3fLnzApI3l7EJCqhWwJUHJ1jAkZzqDx0tyOL4uoZpww3nmE0kb3y21tH4c='
-PUB_DSS = 'ssh-dss AAAAB3NzaC1kc3MAAACBAOeBpgNnfRzr/twmAQRu2XwWAp3CFtrVnug6s6fgwj/oLjYbVtjAy6pl/h0EKCWx2rf1IetyNsTxWrniA9I6HeDj65X1FyDkg6g8tvCnaNB8Xp/UUhuzHuGsMIipRxBxw9LF608EqZcj1E3ytktoW5B5OcjrkEoz3xG7C+rpIjYvAAAAFQDwz4UnmsGiSNu5iqjn3uTzwUpshwAAAIEAkxfFeY8P2wZpDjX0MimZl5wkoFQDL25cPzGBuB4OnB8NoUk/yjAHIIpEShw8V+LzouMK5CTJQo5+Ngw3qIch/WgRmMHy4kBq1SsXMjQCte1So6HBMvBPIW5SiMTmjCfZZiw4AYHK+B/JaOwaG9yRg2Ejg4Ok10+XFDxlqZo8Y+wAAACARmR7CCPjodxASvRbIyzaVpZoJ/Z6x7dAumV+ysrV1BVYd0lYukmnjO1kKBWApqpH1ve9XDQYN8zgxM4b16L21kpoWQnZtXrY3GZ4/it9kUgyB7+NwacIBlXa8cMDL7Q/69o0d54U0X/NeX5QxuYR6OMJlrkQB7oiW/P/1mwjQgE='
-PUB_ECDSA_256 = 'ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJSPZm3ZWkvk/Zx8WP+fZRZ5/NBBHnGQwR6uIC6XHGPDIHuWUzIjAwA0bzqkOUffEsbLe+uQgKl5kbc/L8KA/eo='
-PUB_ECDSA_384 = 'ecdsa-sha2-nistp384 AAAAE2VjZHNhLXNoYTItbmlzdHAzODQAAAAIbmlzdHAzODQAAABhBBbGibQLW9AAZiGN2hEQxWYYoFaWKwN3PKSaDJSMqmIn1Z9sgRUuw8Y/w502OGvXL/wFk0i2z50l3pWZjD7gfMH7gX5TUiCzwrQkS+Hn1U2S9aF5WJp0NcIzYxXw2r4M2A=='
-PUB_ECDSA_521 = 'ecdsa-sha2-nistp521 AAAAE2VjZHNhLXNoYTItbmlzdHA1MjEAAAAIbmlzdHA1MjEAAACFBACaOaFLZGuxa5AW16qj6VLypFbLrEWrt9AZUloCMefxO8bNLjK/O5g0rAVasar1TnyHE9qj4NwzANZASWjQNbc4MAG8vzqezFwLIn/kNyNTsXNfqEko9OgHZknlj2Z79dwTJcRAL4QLcT5aND0EHZLB2fAUDXiWIb2j4rg1mwPlBMiBXA=='
-
-FINGER_RSA = '1024 60:73:38:44:cb:51:86:65:7f:de:da:a2:2b:5a:57:d5'
-FINGER_DSS = '1024 44:78:f0:b9:a2:3c:c5:18:20:09:ff:75:5b:c1:d2:6c'
-FINGER_ECDSA_256 = '256 25:19:eb:55:e6:a1:47:ff:4f:38:d2:75:6f:a5:d5:60'
-FINGER_ECDSA_384 = '384 c1:8d:a0:59:09:47:41:8e:a8:a6:07:01:29:23:b4:65'
-FINGER_ECDSA_521 = '521 44:58:22:52:12:33:16:0e:ce:0e:be:2c:7c:7e:cc:1e'
-SIGNED_RSA = '20:d7:8a:31:21:cb:f7:92:12:f2:a4:89:37:f5:78:af:e6:16:b6:25:b9:97:3d:a2:cd:5f:ca:20:21:73:4c:ad:34:73:8f:20:77:28:e2:94:15:08:d8:91:40:7a:85:83:bf:18:37:95:dc:54:1a:9b:88:29:6c:73:ca:38:b4:04:f1:56:b9:f2:42:9d:52:1b:29:29:b4:4f:fd:c9:2d:af:47:d2:40:76:30:f3:63:45:0c:d9:1d:43:86:0f:1c:70:e2:93:12:34:f3:ac:c5:0a:2f:14:50:66:59:f1:88:ee:c1:4a:e9:d1:9c:4e:46:f0:0e:47:6f:38:74:f1:44:a8'
+PUB_RSA = "ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAIEA049W6geFpmsljTwfvI1UmKWWJPNFI74+vNKTk4dmzkQY2yAMs6FhlvhlI8ysU4oj71ZsRYMecHbBbxdN79+JRFVYTKaLqjwGENeTd+yv4q+V2PvZv3fLnzApI3l7EJCqhWwJUHJ1jAkZzqDx0tyOL4uoZpww3nmE0kb3y21tH4c="
+PUB_DSS = "ssh-dss AAAAB3NzaC1kc3MAAACBAOeBpgNnfRzr/twmAQRu2XwWAp3CFtrVnug6s6fgwj/oLjYbVtjAy6pl/h0EKCWx2rf1IetyNsTxWrniA9I6HeDj65X1FyDkg6g8tvCnaNB8Xp/UUhuzHuGsMIipRxBxw9LF608EqZcj1E3ytktoW5B5OcjrkEoz3xG7C+rpIjYvAAAAFQDwz4UnmsGiSNu5iqjn3uTzwUpshwAAAIEAkxfFeY8P2wZpDjX0MimZl5wkoFQDL25cPzGBuB4OnB8NoUk/yjAHIIpEShw8V+LzouMK5CTJQo5+Ngw3qIch/WgRmMHy4kBq1SsXMjQCte1So6HBMvBPIW5SiMTmjCfZZiw4AYHK+B/JaOwaG9yRg2Ejg4Ok10+XFDxlqZo8Y+wAAACARmR7CCPjodxASvRbIyzaVpZoJ/Z6x7dAumV+ysrV1BVYd0lYukmnjO1kKBWApqpH1ve9XDQYN8zgxM4b16L21kpoWQnZtXrY3GZ4/it9kUgyB7+NwacIBlXa8cMDL7Q/69o0d54U0X/NeX5QxuYR6OMJlrkQB7oiW/P/1mwjQgE="
+PUB_ECDSA_256 = "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJSPZm3ZWkvk/Zx8WP+fZRZ5/NBBHnGQwR6uIC6XHGPDIHuWUzIjAwA0bzqkOUffEsbLe+uQgKl5kbc/L8KA/eo="
+PUB_ECDSA_384 = "ecdsa-sha2-nistp384 AAAAE2VjZHNhLXNoYTItbmlzdHAzODQAAAAIbmlzdHAzODQAAABhBBbGibQLW9AAZiGN2hEQxWYYoFaWKwN3PKSaDJSMqmIn1Z9sgRUuw8Y/w502OGvXL/wFk0i2z50l3pWZjD7gfMH7gX5TUiCzwrQkS+Hn1U2S9aF5WJp0NcIzYxXw2r4M2A=="
+PUB_ECDSA_521 = "ecdsa-sha2-nistp521 AAAAE2VjZHNhLXNoYTItbmlzdHA1MjEAAAAIbmlzdHA1MjEAAACFBACaOaFLZGuxa5AW16qj6VLypFbLrEWrt9AZUloCMefxO8bNLjK/O5g0rAVasar1TnyHE9qj4NwzANZASWjQNbc4MAG8vzqezFwLIn/kNyNTsXNfqEko9OgHZknlj2Z79dwTJcRAL4QLcT5aND0EHZLB2fAUDXiWIb2j4rg1mwPlBMiBXA=="
+
+FINGER_RSA = "1024 60:73:38:44:cb:51:86:65:7f:de:da:a2:2b:5a:57:d5"
+FINGER_DSS = "1024 44:78:f0:b9:a2:3c:c5:18:20:09:ff:75:5b:c1:d2:6c"
+FINGER_ECDSA_256 = "256 25:19:eb:55:e6:a1:47:ff:4f:38:d2:75:6f:a5:d5:60"
+FINGER_ECDSA_384 = "384 c1:8d:a0:59:09:47:41:8e:a8:a6:07:01:29:23:b4:65"
+FINGER_ECDSA_521 = "521 44:58:22:52:12:33:16:0e:ce:0e:be:2c:7c:7e:cc:1e"
+SIGNED_RSA = "20:d7:8a:31:21:cb:f7:92:12:f2:a4:89:37:f5:78:af:e6:16:b6:25:b9:97:3d:a2:cd:5f:ca:20:21:73:4c:ad:34:73:8f:20:77:28:e2:94:15:08:d8:91:40:7a:85:83:bf:18:37:95:dc:54:1a:9b:88:29:6c:73:ca:38:b4:04:f1:56:b9:f2:42:9d:52:1b:29:29:b4:4f:fd:c9:2d:af:47:d2:40:76:30:f3:63:45:0c:d9:1d:43:86:0f:1c:70:e2:93:12:34:f3:ac:c5:0a:2f:14:50:66:59:f1:88:ee:c1:4a:e9:d1:9c:4e:46:f0:0e:47:6f:38:74:f1:44:a8"
RSA_PRIVATE_OUT = """\
-----BEGIN RSA PRIVATE KEY-----
@@ -107,10 +107,10 @@ L4QLcT5aND0EHZLB2fAUDXiWIb2j4rg1mwPlBMiBXA==
-----END EC PRIVATE KEY-----
"""
-x1234 = b'\x01\x02\x03\x04'
+x1234 = b"\x01\x02\x03\x04"
-TEST_KEY_BYTESTR_2 = '\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x01#\x00\x00\x00\x81\x00\xd3\x8fV\xea\x07\x85\xa6k%\x8d<\x1f\xbc\x8dT\x98\xa5\x96$\xf3E#\xbe>\xbc\xd2\x93\x93\x87f\xceD\x18\xdb \x0c\xb3\xa1a\x96\xf8e#\xcc\xacS\x8a#\xefVlE\x83\x1epv\xc1o\x17M\xef\xdf\x89DUXL\xa6\x8b\xaa<\x06\x10\xd7\x93w\xec\xaf\xe2\xaf\x95\xd8\xfb\xd9\xbfw\xcb\x9f0)#y{\x10\x90\xaa\x85l\tPru\x8c\t\x19\xce\xa0\xf1\xd2\xdc\x8e/\x8b\xa8f\x9c0\xdey\x84\xd2F\xf7\xcbmm\x1f\x87'
-TEST_KEY_BYTESTR_3 = '\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x01#\x00\x00\x00\x00ӏV\x07k%<\x1fT$E#>ғfD\x18 \x0cae#̬S#VlE\x1epvo\x17M߉DUXL<\x06\x10דw\u2bd5ٿw˟0)#y{\x10l\tPru\t\x19Π\u070e/f0yFmm\x1f'
+TEST_KEY_BYTESTR_2 = "\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x01#\x00\x00\x00\x81\x00\xd3\x8fV\xea\x07\x85\xa6k%\x8d<\x1f\xbc\x8dT\x98\xa5\x96$\xf3E#\xbe>\xbc\xd2\x93\x93\x87f\xceD\x18\xdb \x0c\xb3\xa1a\x96\xf8e#\xcc\xacS\x8a#\xefVlE\x83\x1epv\xc1o\x17M\xef\xdf\x89DUXL\xa6\x8b\xaa<\x06\x10\xd7\x93w\xec\xaf\xe2\xaf\x95\xd8\xfb\xd9\xbfw\xcb\x9f0)#y{\x10\x90\xaa\x85l\tPru\x8c\t\x19\xce\xa0\xf1\xd2\xdc\x8e/\x8b\xa8f\x9c0\xdey\x84\xd2F\xf7\xcbmm\x1f\x87"
+TEST_KEY_BYTESTR_3 = "\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x01#\x00\x00\x00\x00ӏV\x07k%<\x1fT$E#>ғfD\x18 \x0cae#̬S#VlE\x1epvo\x17M߉DUXL<\x06\x10דw\u2bd5ٿw˟0)#y{\x10l\tPru\t\x19Π\u070e/f0yFmm\x1f"
class KeyTest(unittest.TestCase):
@@ -127,21 +127,20 @@ class KeyTest(unittest.TestCase):
"""
with open(keyfile, "r") as fh:
self.assertEqual(
- fh.readline()[:-1],
- "-----BEGIN RSA PRIVATE KEY-----"
+ fh.readline()[:-1], "-----BEGIN RSA PRIVATE KEY-----"
)
self.assertEqual(fh.readline()[:-1], "Proc-Type: 4,ENCRYPTED")
self.assertEqual(fh.readline()[0:10], "DEK-Info: ")
def test_1_generate_key_bytes(self):
- key = util.generate_key_bytes(md5, x1234, 'happy birthday', 30)
- exp = b'\x61\xE1\xF2\x72\xF4\xC1\xC4\x56\x15\x86\xBD\x32\x24\x98\xC0\xE9\x24\x67\x27\x80\xF4\x7B\xB3\x7D\xDA\x7D\x54\x01\x9E\x64'
+ key = util.generate_key_bytes(md5, x1234, "happy birthday", 30)
+ exp = b"\x61\xE1\xF2\x72\xF4\xC1\xC4\x56\x15\x86\xBD\x32\x24\x98\xC0\xE9\x24\x67\x27\x80\xF4\x7B\xB3\x7D\xDA\x7D\x54\x01\x9E\x64"
self.assertEqual(exp, key)
def test_2_load_rsa(self):
- key = RSAKey.from_private_key_file(_support('test_rsa.key'))
- self.assertEqual('ssh-rsa', key.get_name())
- exp_rsa = b(FINGER_RSA.split()[1].replace(':', ''))
+ key = RSAKey.from_private_key_file(_support("test_rsa.key"))
+ self.assertEqual("ssh-rsa", key.get_name())
+ exp_rsa = b(FINGER_RSA.split()[1].replace(":", ""))
my_rsa = hexlify(key.get_fingerprint())
self.assertEqual(exp_rsa, my_rsa)
self.assertEqual(PUB_RSA.split()[1], key.get_base64())
@@ -155,18 +154,20 @@ class KeyTest(unittest.TestCase):
self.assertEqual(key, key2)
def test_3_load_rsa_password(self):
- key = RSAKey.from_private_key_file(_support('test_rsa_password.key'), 'television')
- self.assertEqual('ssh-rsa', key.get_name())
- exp_rsa = b(FINGER_RSA.split()[1].replace(':', ''))
+ key = RSAKey.from_private_key_file(
+ _support("test_rsa_password.key"), "television"
+ )
+ self.assertEqual("ssh-rsa", key.get_name())
+ exp_rsa = b(FINGER_RSA.split()[1].replace(":", ""))
my_rsa = hexlify(key.get_fingerprint())
self.assertEqual(exp_rsa, my_rsa)
self.assertEqual(PUB_RSA.split()[1], key.get_base64())
self.assertEqual(1024, key.get_bits())
def test_4_load_dss(self):
- key = DSSKey.from_private_key_file(_support('test_dss.key'))
- self.assertEqual('ssh-dss', key.get_name())
- exp_dss = b(FINGER_DSS.split()[1].replace(':', ''))
+ key = DSSKey.from_private_key_file(_support("test_dss.key"))
+ self.assertEqual("ssh-dss", key.get_name())
+ exp_dss = b(FINGER_DSS.split()[1].replace(":", ""))
my_dss = hexlify(key.get_fingerprint())
self.assertEqual(exp_dss, my_dss)
self.assertEqual(PUB_DSS.split()[1], key.get_base64())
@@ -180,9 +181,11 @@ class KeyTest(unittest.TestCase):
self.assertEqual(key, key2)
def test_5_load_dss_password(self):
- key = DSSKey.from_private_key_file(_support('test_dss_password.key'), 'television')
- self.assertEqual('ssh-dss', key.get_name())
- exp_dss = b(FINGER_DSS.split()[1].replace(':', ''))
+ key = DSSKey.from_private_key_file(
+ _support("test_dss_password.key"), "television"
+ )
+ self.assertEqual("ssh-dss", key.get_name())
+ exp_dss = b(FINGER_DSS.split()[1].replace(":", ""))
my_dss = hexlify(key.get_fingerprint())
self.assertEqual(exp_dss, my_dss)
self.assertEqual(PUB_DSS.split()[1], key.get_base64())
@@ -190,7 +193,7 @@ class KeyTest(unittest.TestCase):
def test_6_compare_rsa(self):
# verify that the private & public keys compare equal
- key = RSAKey.from_private_key_file(_support('test_rsa.key'))
+ key = RSAKey.from_private_key_file(_support("test_rsa.key"))
self.assertEqual(key, key)
pub = RSAKey(data=key.asbytes())
self.assertTrue(key.can_sign())
@@ -199,7 +202,7 @@ class KeyTest(unittest.TestCase):
def test_7_compare_dss(self):
# verify that the private & public keys compare equal
- key = DSSKey.from_private_key_file(_support('test_dss.key'))
+ key = DSSKey.from_private_key_file(_support("test_dss.key"))
self.assertEqual(key, key)
pub = DSSKey(data=key.asbytes())
self.assertTrue(key.can_sign())
@@ -208,77 +211,79 @@ class KeyTest(unittest.TestCase):
def test_8_sign_rsa(self):
# verify that the rsa private key can sign and verify
- key = RSAKey.from_private_key_file(_support('test_rsa.key'))
- msg = key.sign_ssh_data(b'ice weasels')
+ key = RSAKey.from_private_key_file(_support("test_rsa.key"))
+ msg = key.sign_ssh_data(b"ice weasels")
self.assertTrue(type(msg) is Message)
msg.rewind()
- self.assertEqual('ssh-rsa', msg.get_text())
- sig = bytes().join([byte_chr(int(x, 16)) for x in SIGNED_RSA.split(':')])
+ self.assertEqual("ssh-rsa", msg.get_text())
+ sig = bytes().join(
+ [byte_chr(int(x, 16)) for x in SIGNED_RSA.split(":")]
+ )
self.assertEqual(sig, msg.get_binary())
msg.rewind()
pub = RSAKey(data=key.asbytes())
- self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg))
+ self.assertTrue(pub.verify_ssh_sig(b"ice weasels", msg))
def test_9_sign_dss(self):
# verify that the dss private key can sign and verify
- key = DSSKey.from_private_key_file(_support('test_dss.key'))
- msg = key.sign_ssh_data(b'ice weasels')
+ key = DSSKey.from_private_key_file(_support("test_dss.key"))
+ msg = key.sign_ssh_data(b"ice weasels")
self.assertTrue(type(msg) is Message)
msg.rewind()
- self.assertEqual('ssh-dss', msg.get_text())
+ self.assertEqual("ssh-dss", msg.get_text())
# can't do the same test as we do for RSA, because DSS signatures
# are usually different each time. but we can test verification
# anyway so it's ok.
self.assertEqual(40, len(msg.get_binary()))
msg.rewind()
pub = DSSKey(data=key.asbytes())
- self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg))
+ self.assertTrue(pub.verify_ssh_sig(b"ice weasels", msg))
def test_A_generate_rsa(self):
key = RSAKey.generate(1024)
- msg = key.sign_ssh_data(b'jerri blank')
+ msg = key.sign_ssh_data(b"jerri blank")
msg.rewind()
- self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg))
+ self.assertTrue(key.verify_ssh_sig(b"jerri blank", msg))
def test_B_generate_dss(self):
key = DSSKey.generate(1024)
- msg = key.sign_ssh_data(b'jerri blank')
+ msg = key.sign_ssh_data(b"jerri blank")
msg.rewind()
- self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg))
+ self.assertTrue(key.verify_ssh_sig(b"jerri blank", msg))
def test_C_generate_ecdsa(self):
key = ECDSAKey.generate()
- msg = key.sign_ssh_data(b'jerri blank')
+ msg = key.sign_ssh_data(b"jerri blank")
msg.rewind()
- self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg))
+ self.assertTrue(key.verify_ssh_sig(b"jerri blank", msg))
self.assertEqual(key.get_bits(), 256)
- self.assertEqual(key.get_name(), 'ecdsa-sha2-nistp256')
+ self.assertEqual(key.get_name(), "ecdsa-sha2-nistp256")
key = ECDSAKey.generate(bits=256)
- msg = key.sign_ssh_data(b'jerri blank')
+ msg = key.sign_ssh_data(b"jerri blank")
msg.rewind()
- self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg))
+ self.assertTrue(key.verify_ssh_sig(b"jerri blank", msg))
self.assertEqual(key.get_bits(), 256)
- self.assertEqual(key.get_name(), 'ecdsa-sha2-nistp256')
+ self.assertEqual(key.get_name(), "ecdsa-sha2-nistp256")
key = ECDSAKey.generate(bits=384)
- msg = key.sign_ssh_data(b'jerri blank')
+ msg = key.sign_ssh_data(b"jerri blank")
msg.rewind()
- self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg))
+ self.assertTrue(key.verify_ssh_sig(b"jerri blank", msg))
self.assertEqual(key.get_bits(), 384)
- self.assertEqual(key.get_name(), 'ecdsa-sha2-nistp384')
+ self.assertEqual(key.get_name(), "ecdsa-sha2-nistp384")
key = ECDSAKey.generate(bits=521)
- msg = key.sign_ssh_data(b'jerri blank')
+ msg = key.sign_ssh_data(b"jerri blank")
msg.rewind()
- self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg))
+ self.assertTrue(key.verify_ssh_sig(b"jerri blank", msg))
self.assertEqual(key.get_bits(), 521)
- self.assertEqual(key.get_name(), 'ecdsa-sha2-nistp521')
+ self.assertEqual(key.get_name(), "ecdsa-sha2-nistp521")
def test_10_load_ecdsa_256(self):
- key = ECDSAKey.from_private_key_file(_support('test_ecdsa_256.key'))
- self.assertEqual('ecdsa-sha2-nistp256', key.get_name())
- exp_ecdsa = b(FINGER_ECDSA_256.split()[1].replace(':', ''))
+ key = ECDSAKey.from_private_key_file(_support("test_ecdsa_256.key"))
+ self.assertEqual("ecdsa-sha2-nistp256", key.get_name())
+ exp_ecdsa = b(FINGER_ECDSA_256.split()[1].replace(":", ""))
my_ecdsa = hexlify(key.get_fingerprint())
self.assertEqual(exp_ecdsa, my_ecdsa)
self.assertEqual(PUB_ECDSA_256.split()[1], key.get_base64())
@@ -292,9 +297,11 @@ class KeyTest(unittest.TestCase):
self.assertEqual(key, key2)
def test_11_load_ecdsa_password_256(self):
- key = ECDSAKey.from_private_key_file(_support('test_ecdsa_password_256.key'), b'television')
- self.assertEqual('ecdsa-sha2-nistp256', key.get_name())
- exp_ecdsa = b(FINGER_ECDSA_256.split()[1].replace(':', ''))
+ key = ECDSAKey.from_private_key_file(
+ _support("test_ecdsa_password_256.key"), b"television"
+ )
+ self.assertEqual("ecdsa-sha2-nistp256", key.get_name())
+ exp_ecdsa = b(FINGER_ECDSA_256.split()[1].replace(":", ""))
my_ecdsa = hexlify(key.get_fingerprint())
self.assertEqual(exp_ecdsa, my_ecdsa)
self.assertEqual(PUB_ECDSA_256.split()[1], key.get_base64())
@@ -302,7 +309,7 @@ class KeyTest(unittest.TestCase):
def test_12_compare_ecdsa_256(self):
# verify that the private & public keys compare equal
- key = ECDSAKey.from_private_key_file(_support('test_ecdsa_256.key'))
+ key = ECDSAKey.from_private_key_file(_support("test_ecdsa_256.key"))
self.assertEqual(key, key)
pub = ECDSAKey(data=key.asbytes())
self.assertTrue(key.can_sign())
@@ -311,11 +318,11 @@ class KeyTest(unittest.TestCase):
def test_13_sign_ecdsa_256(self):
# verify that the rsa private key can sign and verify
- key = ECDSAKey.from_private_key_file(_support('test_ecdsa_256.key'))
- msg = key.sign_ssh_data(b'ice weasels')
+ key = ECDSAKey.from_private_key_file(_support("test_ecdsa_256.key"))
+ msg = key.sign_ssh_data(b"ice weasels")
self.assertTrue(type(msg) is Message)
msg.rewind()
- self.assertEqual('ecdsa-sha2-nistp256', msg.get_text())
+ self.assertEqual("ecdsa-sha2-nistp256", msg.get_text())
# ECDSA signatures, like DSS signatures, tend to be different
# each time, so we can't compare against a "known correct"
# signature.
@@ -323,12 +330,12 @@ class KeyTest(unittest.TestCase):
msg.rewind()
pub = ECDSAKey(data=key.asbytes())
- self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg))
+ self.assertTrue(pub.verify_ssh_sig(b"ice weasels", msg))
def test_14_load_ecdsa_384(self):
- key = ECDSAKey.from_private_key_file(_support('test_ecdsa_384.key'))
- self.assertEqual('ecdsa-sha2-nistp384', key.get_name())
- exp_ecdsa = b(FINGER_ECDSA_384.split()[1].replace(':', ''))
+ key = ECDSAKey.from_private_key_file(_support("test_ecdsa_384.key"))
+ self.assertEqual("ecdsa-sha2-nistp384", key.get_name())
+ exp_ecdsa = b(FINGER_ECDSA_384.split()[1].replace(":", ""))
my_ecdsa = hexlify(key.get_fingerprint())
self.assertEqual(exp_ecdsa, my_ecdsa)
self.assertEqual(PUB_ECDSA_384.split()[1], key.get_base64())
@@ -342,9 +349,11 @@ class KeyTest(unittest.TestCase):
self.assertEqual(key, key2)
def test_15_load_ecdsa_password_384(self):
- key = ECDSAKey.from_private_key_file(_support('test_ecdsa_password_384.key'), b'television')
- self.assertEqual('ecdsa-sha2-nistp384', key.get_name())
- exp_ecdsa = b(FINGER_ECDSA_384.split()[1].replace(':', ''))
+ key = ECDSAKey.from_private_key_file(
+ _support("test_ecdsa_password_384.key"), b"television"
+ )
+ self.assertEqual("ecdsa-sha2-nistp384", key.get_name())
+ exp_ecdsa = b(FINGER_ECDSA_384.split()[1].replace(":", ""))
my_ecdsa = hexlify(key.get_fingerprint())
self.assertEqual(exp_ecdsa, my_ecdsa)
self.assertEqual(PUB_ECDSA_384.split()[1], key.get_base64())
@@ -352,7 +361,7 @@ class KeyTest(unittest.TestCase):
def test_16_compare_ecdsa_384(self):
# verify that the private & public keys compare equal
- key = ECDSAKey.from_private_key_file(_support('test_ecdsa_384.key'))
+ key = ECDSAKey.from_private_key_file(_support("test_ecdsa_384.key"))
self.assertEqual(key, key)
pub = ECDSAKey(data=key.asbytes())
self.assertTrue(key.can_sign())
@@ -361,11 +370,11 @@ class KeyTest(unittest.TestCase):
def test_17_sign_ecdsa_384(self):
# verify that the rsa private key can sign and verify
- key = ECDSAKey.from_private_key_file(_support('test_ecdsa_384.key'))
- msg = key.sign_ssh_data(b'ice weasels')
+ key = ECDSAKey.from_private_key_file(_support("test_ecdsa_384.key"))
+ msg = key.sign_ssh_data(b"ice weasels")
self.assertTrue(type(msg) is Message)
msg.rewind()
- self.assertEqual('ecdsa-sha2-nistp384', msg.get_text())
+ self.assertEqual("ecdsa-sha2-nistp384", msg.get_text())
# ECDSA signatures, like DSS signatures, tend to be different
# each time, so we can't compare against a "known correct"
# signature.
@@ -373,12 +382,12 @@ class KeyTest(unittest.TestCase):
msg.rewind()
pub = ECDSAKey(data=key.asbytes())
- self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg))
+ self.assertTrue(pub.verify_ssh_sig(b"ice weasels", msg))
def test_18_load_ecdsa_521(self):
- key = ECDSAKey.from_private_key_file(_support('test_ecdsa_521.key'))
- self.assertEqual('ecdsa-sha2-nistp521', key.get_name())
- exp_ecdsa = b(FINGER_ECDSA_521.split()[1].replace(':', ''))
+ key = ECDSAKey.from_private_key_file(_support("test_ecdsa_521.key"))
+ self.assertEqual("ecdsa-sha2-nistp521", key.get_name())
+ exp_ecdsa = b(FINGER_ECDSA_521.split()[1].replace(":", ""))
my_ecdsa = hexlify(key.get_fingerprint())
self.assertEqual(exp_ecdsa, my_ecdsa)
self.assertEqual(PUB_ECDSA_521.split()[1], key.get_base64())
@@ -395,9 +404,11 @@ class KeyTest(unittest.TestCase):
self.assertEqual(key, key2)
def test_19_load_ecdsa_password_521(self):
- key = ECDSAKey.from_private_key_file(_support('test_ecdsa_password_521.key'), b'television')
- self.assertEqual('ecdsa-sha2-nistp521', key.get_name())
- exp_ecdsa = b(FINGER_ECDSA_521.split()[1].replace(':', ''))
+ key = ECDSAKey.from_private_key_file(
+ _support("test_ecdsa_password_521.key"), b"television"
+ )
+ self.assertEqual("ecdsa-sha2-nistp521", key.get_name())
+ exp_ecdsa = b(FINGER_ECDSA_521.split()[1].replace(":", ""))
my_ecdsa = hexlify(key.get_fingerprint())
self.assertEqual(exp_ecdsa, my_ecdsa)
self.assertEqual(PUB_ECDSA_521.split()[1], key.get_base64())
@@ -405,7 +416,7 @@ class KeyTest(unittest.TestCase):
def test_20_compare_ecdsa_521(self):
# verify that the private & public keys compare equal
- key = ECDSAKey.from_private_key_file(_support('test_ecdsa_521.key'))
+ key = ECDSAKey.from_private_key_file(_support("test_ecdsa_521.key"))
self.assertEqual(key, key)
pub = ECDSAKey(data=key.asbytes())
self.assertTrue(key.can_sign())
@@ -414,11 +425,11 @@ class KeyTest(unittest.TestCase):
def test_21_sign_ecdsa_521(self):
# verify that the rsa private key can sign and verify
- key = ECDSAKey.from_private_key_file(_support('test_ecdsa_521.key'))
- msg = key.sign_ssh_data(b'ice weasels')
+ key = ECDSAKey.from_private_key_file(_support("test_ecdsa_521.key"))
+ msg = key.sign_ssh_data(b"ice weasels")
self.assertTrue(type(msg) is Message)
msg.rewind()
- self.assertEqual('ecdsa-sha2-nistp521', msg.get_text())
+ self.assertEqual("ecdsa-sha2-nistp521", msg.get_text())
# ECDSA signatures, like DSS signatures, tend to be different
# each time, so we can't compare against a "known correct"
# signature.
@@ -426,14 +437,14 @@ class KeyTest(unittest.TestCase):
msg.rewind()
pub = ECDSAKey(data=key.asbytes())
- self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg))
+ self.assertTrue(pub.verify_ssh_sig(b"ice weasels", msg))
def test_salt_size(self):
# Read an existing encrypted private key
- file_ = _support('test_rsa_password.key')
- password = 'television'
- newfile = file_ + '.new'
- newpassword = 'radio'
+ file_ = _support("test_rsa_password.key")
+ password = "television"
+ newfile = file_ + ".new"
+ newpassword = "radio"
key = RSAKey(filename=file_, password=password)
# Write out a newly re-encrypted copy with a new password.
# When the bug under test exists, this will ValueError.
@@ -447,20 +458,20 @@ class KeyTest(unittest.TestCase):
os.remove(newfile)
def test_stringification(self):
- key = RSAKey.from_private_key_file(_support('test_rsa.key'))
+ key = RSAKey.from_private_key_file(_support("test_rsa.key"))
comparable = TEST_KEY_BYTESTR_2 if PY2 else TEST_KEY_BYTESTR_3
self.assertEqual(str(key), comparable)
def test_ed25519(self):
- key1 = Ed25519Key.from_private_key_file(_support('test_ed25519.key'))
+ key1 = Ed25519Key.from_private_key_file(_support("test_ed25519.key"))
key2 = Ed25519Key.from_private_key_file(
- _support('test_ed25519_password.key'), b'abc123'
+ _support("test_ed25519_password.key"), b"abc123"
)
self.assertNotEqual(key1.asbytes(), key2.asbytes())
def test_ed25519_compare(self):
# verify that the private & public keys compare equal
- key = Ed25519Key.from_private_key_file(_support('test_ed25519.key'))
+ key = Ed25519Key.from_private_key_file(_support("test_ed25519.key"))
self.assertEqual(key, key)
pub = Ed25519Key(data=key.asbytes())
self.assertTrue(key.can_sign())
@@ -470,25 +481,25 @@ class KeyTest(unittest.TestCase):
def test_ed25519_nonbytes_password(self):
# https://github.com/paramiko/paramiko/issues/1039
key = Ed25519Key.from_private_key_file(
- _support('test_ed25519_password.key'),
+ _support("test_ed25519_password.key"),
# NOTE: not a bytes. Amusingly, the test above for same key DOES
# explicitly cast to bytes...code smell!
- 'abc123',
+ "abc123",
)
# No exception -> it's good. Meh.
def test_ed25519_load_from_file_obj(self):
- with open(_support('test_ed25519.key')) as pkey_fileobj:
+ with open(_support("test_ed25519.key")) as pkey_fileobj:
key = Ed25519Key.from_private_key(pkey_fileobj)
self.assertEqual(key, key)
self.assertTrue(key.can_sign())
def test_keyfile_is_actually_encrypted(self):
# Read an existing encrypted private key
- file_ = _support('test_rsa_password.key')
- password = 'television'
- newfile = file_ + '.new'
- newpassword = 'radio'
+ file_ = _support("test_rsa_password.key")
+ password = "television"
+ newfile = file_ + ".new"
+ newpassword = "radio"
key = RSAKey(filename=file_, password=password)
# Write out a newly re-encrypted copy with a new password.
# When the bug under test exists, this will ValueError.
@@ -503,19 +514,21 @@ class KeyTest(unittest.TestCase):
# test_client.py; this and nearby cert tests are more about the gritty
# details.
# PKey.load_certificate
- key_path = _support(os.path.join('cert_support', 'test_rsa.key'))
+ key_path = _support(os.path.join("cert_support", "test_rsa.key"))
key = RSAKey.from_private_key_file(key_path)
self.assertTrue(key.public_blob is None)
cert_path = _support(
- os.path.join('cert_support', 'test_rsa.key-cert.pub')
+ os.path.join("cert_support", "test_rsa.key-cert.pub")
)
key.load_certificate(cert_path)
self.assertTrue(key.public_blob is not None)
- self.assertEqual(key.public_blob.key_type, 'ssh-rsa-cert-v01@openssh.com')
- self.assertEqual(key.public_blob.comment, 'test_rsa.key.pub')
+ self.assertEqual(
+ key.public_blob.key_type, "ssh-rsa-cert-v01@openssh.com"
+ )
+ self.assertEqual(key.public_blob.comment, "test_rsa.key.pub")
# Delve into blob contents, for test purposes
msg = Message(key.public_blob.key_blob)
- self.assertEqual(msg.get_text(), 'ssh-rsa-cert-v01@openssh.com')
+ self.assertEqual(msg.get_text(), "ssh-rsa-cert-v01@openssh.com")
nonce = msg.get_string()
e = msg.get_mpint()
n = msg.get_mpint()
@@ -525,10 +538,10 @@ class KeyTest(unittest.TestCase):
self.assertEqual(msg.get_int64(), 1234)
# Prevented from loading certificate that doesn't match
- key_path = _support(os.path.join('cert_support', 'test_ed25519.key'))
+ key_path = _support(os.path.join("cert_support", "test_ed25519.key"))
key1 = Ed25519Key.from_private_key_file(key_path)
self.assertRaises(
ValueError,
key1.load_certificate,
- _support('test_rsa.key-cert.pub'),
+ _support("test_rsa.key-cert.pub"),
)
diff --git a/tests/test_sftp.py b/tests/test_sftp.py
index 09a5045..a03961d 100644
--- a/tests/test_sftp.py
+++ b/tests/test_sftp.py
@@ -45,7 +45,7 @@ from .stub_sftp import StubServer, StubSFTPServer
from .util import _support, slow
-ARTICLE = '''
+ARTICLE = """
Insulin sensitivity and liver insulin receptor structure in ducks from two
genera
@@ -70,7 +70,7 @@ receptors. Therefore the ducks from the two genera exhibit an alpha-beta-
structure for liver insulin receptors and a clear difference in the number of
liver insulin receptors. Their sensitivity to insulin is, however, similarly
decreased compared with chicken.
-'''
+"""
# Here is how unicode characters are encoded over 1 to 6 bytes in utf-8
@@ -82,32 +82,33 @@ decreased compared with chicken.
# U-04000000 - U-7FFFFFFF: 1111110x 10xxxxxx 10xxxxxx 10xxxxxx 10xxxxxx 10xxxxxx
# Note that: hex(int('11000011',2)) == '0xc3'
# Thus, the following 2-bytes sequence is not valid utf8: "invalid continuation byte"
-NON_UTF8_DATA = b'\xC3\xC3'
+NON_UTF8_DATA = b"\xC3\xC3"
-unicode_folder = u'\u00fcnic\u00f8de' if PY2 else '\u00fcnic\u00f8de'
-utf8_folder = b'/\xc3\xbcnic\xc3\xb8\x64\x65'
+unicode_folder = u"\u00fcnic\u00f8de" if PY2 else "\u00fcnic\u00f8de"
+utf8_folder = b"/\xc3\xbcnic\xc3\xb8\x64\x65"
@slow
class TestSFTP(object):
+
def test_1_file(self, sftp):
"""
verify that we can create a file.
"""
- f = sftp.open(sftp.FOLDER + '/test', 'w')
+ f = sftp.open(sftp.FOLDER + "/test", "w")
try:
assert f.stat().st_size == 0
finally:
f.close()
- sftp.remove(sftp.FOLDER + '/test')
+ sftp.remove(sftp.FOLDER + "/test")
def test_2_close(self, sftp):
"""
Verify that SFTP session close() causes a socket error on next action.
"""
sftp.close()
- with pytest.raises(socket.error, match='Socket is closed'):
- sftp.open(sftp.FOLDER + '/test2', 'w')
+ with pytest.raises(socket.error, match="Socket is closed"):
+ sftp.open(sftp.FOLDER + "/test2", "w")
def test_2_sftp_can_be_used_as_context_manager(self, sftp):
"""
@@ -115,117 +116,117 @@ class TestSFTP(object):
"""
with sftp:
pass
- with pytest.raises(socket.error, match='Socket is closed'):
- sftp.open(sftp.FOLDER + '/test2', 'w')
+ with pytest.raises(socket.error, match="Socket is closed"):
+ sftp.open(sftp.FOLDER + "/test2", "w")
def test_3_write(self, sftp):
"""
verify that a file can be created and written, and the size is correct.
"""
try:
- with sftp.open(sftp.FOLDER + '/duck.txt', 'w') as f:
+ with sftp.open(sftp.FOLDER + "/duck.txt", "w") as f:
f.write(ARTICLE)
- assert sftp.stat(sftp.FOLDER + '/duck.txt').st_size == 1483
+ assert sftp.stat(sftp.FOLDER + "/duck.txt").st_size == 1483
finally:
- sftp.remove(sftp.FOLDER + '/duck.txt')
+ sftp.remove(sftp.FOLDER + "/duck.txt")
def test_3_sftp_file_can_be_used_as_context_manager(self, sftp):
"""
verify that an opened file can be used as a context manager
"""
try:
- with sftp.open(sftp.FOLDER + '/duck.txt', 'w') as f:
+ with sftp.open(sftp.FOLDER + "/duck.txt", "w") as f:
f.write(ARTICLE)
- assert sftp.stat(sftp.FOLDER + '/duck.txt').st_size == 1483
+ assert sftp.stat(sftp.FOLDER + "/duck.txt").st_size == 1483
finally:
- sftp.remove(sftp.FOLDER + '/duck.txt')
+ sftp.remove(sftp.FOLDER + "/duck.txt")
def test_4_append(self, sftp):
"""
verify that a file can be opened for append, and tell() still works.
"""
try:
- with sftp.open(sftp.FOLDER + '/append.txt', 'w') as f:
- f.write('first line\nsecond line\n')
+ with sftp.open(sftp.FOLDER + "/append.txt", "w") as f:
+ f.write("first line\nsecond line\n")
assert f.tell() == 23
- with sftp.open(sftp.FOLDER + '/append.txt', 'a+') as f:
- f.write('third line!!!\n')
+ with sftp.open(sftp.FOLDER + "/append.txt", "a+") as f:
+ f.write("third line!!!\n")
assert f.tell() == 37
assert f.stat().st_size == 37
f.seek(-26, f.SEEK_CUR)
- assert f.readline() == 'second line\n'
+ assert f.readline() == "second line\n"
finally:
- sftp.remove(sftp.FOLDER + '/append.txt')
+ sftp.remove(sftp.FOLDER + "/append.txt")
def test_5_rename(self, sftp):
"""
verify that renaming a file works.
"""
try:
- with sftp.open(sftp.FOLDER + '/first.txt', 'w') as f:
- f.write('content!\n')
- sftp.rename(sftp.FOLDER + '/first.txt', sftp.FOLDER + '/second.txt')
- with pytest.raises(IOError, match='No such file'):
- sftp.open(sftp.FOLDER + '/first.txt', 'r')
- with sftp.open(sftp.FOLDER + '/second.txt', 'r') as f:
+ with sftp.open(sftp.FOLDER + "/first.txt", "w") as f:
+ f.write("content!\n")
+ sftp.rename(
+ sftp.FOLDER + "/first.txt", sftp.FOLDER + "/second.txt"
+ )
+ with pytest.raises(IOError, match="No such file"):
+ sftp.open(sftp.FOLDER + "/first.txt", "r")
+ with sftp.open(sftp.FOLDER + "/second.txt", "r") as f:
f.seek(-6, f.SEEK_END)
- assert u(f.read(4)) == 'tent'
+ assert u(f.read(4)) == "tent"
finally:
# TODO: this is gross, make some sort of 'remove if possible' / 'rm
# -f' a-like, jeez
try:
- sftp.remove(sftp.FOLDER + '/first.txt')
+ sftp.remove(sftp.FOLDER + "/first.txt")
except:
pass
try:
- sftp.remove(sftp.FOLDER + '/second.txt')
+ sftp.remove(sftp.FOLDER + "/second.txt")
except:
pass
-
def test_5a_posix_rename(self, sftp):
"""Test posix-rename@openssh.com protocol extension."""
try:
# first check that the normal rename works as specified
- with sftp.open(sftp.FOLDER + '/a', 'w') as f:
- f.write('one')
- sftp.rename(sftp.FOLDER + '/a', sftp.FOLDER + '/b')
- with sftp.open(sftp.FOLDER + '/a', 'w') as f:
- f.write('two')
- with pytest.raises(IOError): # actual message seems generic
- sftp.rename(sftp.FOLDER + '/a', sftp.FOLDER + '/b')
+ with sftp.open(sftp.FOLDER + "/a", "w") as f:
+ f.write("one")
+ sftp.rename(sftp.FOLDER + "/a", sftp.FOLDER + "/b")
+ with sftp.open(sftp.FOLDER + "/a", "w") as f:
+ f.write("two")
+ with pytest.raises(IOError): # actual message seems generic
+ sftp.rename(sftp.FOLDER + "/a", sftp.FOLDER + "/b")
# now check with the posix_rename
- sftp.posix_rename(sftp.FOLDER + '/a', sftp.FOLDER + '/b')
- with sftp.open(sftp.FOLDER + '/b', 'r') as f:
+ sftp.posix_rename(sftp.FOLDER + "/a", sftp.FOLDER + "/b")
+ with sftp.open(sftp.FOLDER + "/b", "r") as f:
data = u(f.read())
err = "Contents of renamed file not the same as original file"
- assert 'two' == data, err
+ assert "two" == data, err
finally:
try:
- sftp.remove(sftp.FOLDER + '/a')
+ sftp.remove(sftp.FOLDER + "/a")
except:
pass
try:
- sftp.remove(sftp.FOLDER + '/b')
+ sftp.remove(sftp.FOLDER + "/b")
except:
pass
-
def test_6_folder(self, sftp):
"""
create a temporary folder, verify that we can create a file in it, then
remove the folder and verify that we can't create a file in it anymore.
"""
- sftp.mkdir(sftp.FOLDER + '/subfolder')
- sftp.open(sftp.FOLDER + '/subfolder/test', 'w').close()
- sftp.remove(sftp.FOLDER + '/subfolder/test')
- sftp.rmdir(sftp.FOLDER + '/subfolder')
+ sftp.mkdir(sftp.FOLDER + "/subfolder")
+ sftp.open(sftp.FOLDER + "/subfolder/test", "w").close()
+ sftp.remove(sftp.FOLDER + "/subfolder/test")
+ sftp.rmdir(sftp.FOLDER + "/subfolder")
# shouldn't be able to create that file if dir removed
with pytest.raises(IOError, match="No such file"):
- sftp.open(sftp.FOLDER + '/subfolder/test')
+ sftp.open(sftp.FOLDER + "/subfolder/test")
def test_7_listdir(self, sftp):
"""
@@ -233,57 +234,57 @@ class TestSFTP(object):
it, and those files show up in sftp.listdir.
"""
try:
- sftp.open(sftp.FOLDER + '/duck.txt', 'w').close()
- sftp.open(sftp.FOLDER + '/fish.txt', 'w').close()
- sftp.open(sftp.FOLDER + '/tertiary.py', 'w').close()
+ sftp.open(sftp.FOLDER + "/duck.txt", "w").close()
+ sftp.open(sftp.FOLDER + "/fish.txt", "w").close()
+ sftp.open(sftp.FOLDER + "/tertiary.py", "w").close()
x = sftp.listdir(sftp.FOLDER)
assert len(x) == 3
- assert 'duck.txt' in x
- assert 'fish.txt' in x
- assert 'tertiary.py' in x
- assert 'random' not in x
+ assert "duck.txt" in x
+ assert "fish.txt" in x
+ assert "tertiary.py" in x
+ assert "random" not in x
finally:
- sftp.remove(sftp.FOLDER + '/duck.txt')
- sftp.remove(sftp.FOLDER + '/fish.txt')
- sftp.remove(sftp.FOLDER + '/tertiary.py')
+ sftp.remove(sftp.FOLDER + "/duck.txt")
+ sftp.remove(sftp.FOLDER + "/fish.txt")
+ sftp.remove(sftp.FOLDER + "/tertiary.py")
def test_7_5_listdir_iter(self, sftp):
"""
listdir_iter version of above test
"""
try:
- sftp.open(sftp.FOLDER + '/duck.txt', 'w').close()
- sftp.open(sftp.FOLDER + '/fish.txt', 'w').close()
- sftp.open(sftp.FOLDER + '/tertiary.py', 'w').close()
+ sftp.open(sftp.FOLDER + "/duck.txt", "w").close()
+ sftp.open(sftp.FOLDER + "/fish.txt", "w").close()
+ sftp.open(sftp.FOLDER + "/tertiary.py", "w").close()
x = [x.filename for x in sftp.listdir_iter(sftp.FOLDER)]
assert len(x) == 3
- assert 'duck.txt' in x
- assert 'fish.txt' in x
- assert 'tertiary.py' in x
- assert 'random' not in x
+ assert "duck.txt" in x
+ assert "fish.txt" in x
+ assert "tertiary.py" in x
+ assert "random" not in x
finally:
- sftp.remove(sftp.FOLDER + '/duck.txt')
- sftp.remove(sftp.FOLDER + '/fish.txt')
- sftp.remove(sftp.FOLDER + '/tertiary.py')
+ sftp.remove(sftp.FOLDER + "/duck.txt")
+ sftp.remove(sftp.FOLDER + "/fish.txt")
+ sftp.remove(sftp.FOLDER + "/tertiary.py")
def test_8_setstat(self, sftp):
"""
verify that the setstat functions (chown, chmod, utime, truncate) work.
"""
try:
- with sftp.open(sftp.FOLDER + '/special', 'w') as f:
- f.write('x' * 1024)
+ with sftp.open(sftp.FOLDER + "/special", "w") as f:
+ f.write("x" * 1024)
- stat = sftp.stat(sftp.FOLDER + '/special')
- sftp.chmod(sftp.FOLDER + '/special', (stat.st_mode & ~o777) | o600)
- stat = sftp.stat(sftp.FOLDER + '/special')
+ stat = sftp.stat(sftp.FOLDER + "/special")
+ sftp.chmod(sftp.FOLDER + "/special", (stat.st_mode & ~o777) | o600)
+ stat = sftp.stat(sftp.FOLDER + "/special")
expected_mode = o600
- if sys.platform == 'win32':
+ if sys.platform == "win32":
# chmod not really functional on windows
expected_mode = o666
- if sys.platform == 'cygwin':
+ if sys.platform == "cygwin":
# even worse.
expected_mode = o644
assert stat.st_mode & o777 == expected_mode
@@ -291,19 +292,19 @@ class TestSFTP(object):
mtime = stat.st_mtime - 3600
atime = stat.st_atime - 1800
- sftp.utime(sftp.FOLDER + '/special', (atime, mtime))
- stat = sftp.stat(sftp.FOLDER + '/special')
+ sftp.utime(sftp.FOLDER + "/special", (atime, mtime))
+ stat = sftp.stat(sftp.FOLDER + "/special")
assert stat.st_mtime == mtime
- if sys.platform not in ('win32', 'cygwin'):
+ if sys.platform not in ("win32", "cygwin"):
assert stat.st_atime == atime
# can't really test chown, since we'd have to know a valid uid.
- sftp.truncate(sftp.FOLDER + '/special', 512)
- stat = sftp.stat(sftp.FOLDER + '/special')
+ sftp.truncate(sftp.FOLDER + "/special", 512)
+ stat = sftp.stat(sftp.FOLDER + "/special")
assert stat.st_size == 512
finally:
- sftp.remove(sftp.FOLDER + '/special')
+ sftp.remove(sftp.FOLDER + "/special")
def test_9_fsetstat(self, sftp):
"""
@@ -311,19 +312,19 @@ class TestSFTP(object):
work on open files.
"""
try:
- with sftp.open(sftp.FOLDER + '/special', 'w') as f:
- f.write('x' * 1024)
+ with sftp.open(sftp.FOLDER + "/special", "w") as f:
+ f.write("x" * 1024)
- with sftp.open(sftp.FOLDER + '/special', 'r+') as f:
+ with sftp.open(sftp.FOLDER + "/special", "r+") as f:
stat = f.stat()
f.chmod((stat.st_mode & ~o777) | o600)
stat = f.stat()
expected_mode = o600
- if sys.platform == 'win32':
+ if sys.platform == "win32":
# chmod not really functional on windows
expected_mode = o666
- if sys.platform == 'cygwin':
+ if sys.platform == "cygwin":
# even worse.
expected_mode = o644
assert stat.st_mode & o777 == expected_mode
@@ -334,7 +335,7 @@ class TestSFTP(object):
f.utime((atime, mtime))
stat = f.stat()
assert stat.st_mtime == mtime
- if sys.platform not in ('win32', 'cygwin'):
+ if sys.platform not in ("win32", "cygwin"):
assert stat.st_atime == atime
# can't really test chown, since we'd have to know a valid uid.
@@ -343,7 +344,7 @@ class TestSFTP(object):
stat = f.stat()
assert stat.st_size == 512
finally:
- sftp.remove(sftp.FOLDER + '/special')
+ sftp.remove(sftp.FOLDER + "/special")
def test_A_readline_seek(self, sftp):
"""
@@ -353,10 +354,10 @@ class TestSFTP(object):
buffering is reset on 'seek'.
"""
try:
- with sftp.open(sftp.FOLDER + '/duck.txt', 'w') as f:
+ with sftp.open(sftp.FOLDER + "/duck.txt", "w") as f:
f.write(ARTICLE)
- with sftp.open(sftp.FOLDER + '/duck.txt', 'r+') as f:
+ with sftp.open(sftp.FOLDER + "/duck.txt", "r+") as f:
line_number = 0
loc = 0
pos_list = []
@@ -366,13 +367,16 @@ class TestSFTP(object):
loc = f.tell()
assert f.seekable()
f.seek(pos_list[6], f.SEEK_SET)
- assert f.readline(), 'Nouzilly == France.\n'
+ assert f.readline(), "Nouzilly == France.\n"
f.seek(pos_list[17], f.SEEK_SET)
- assert f.readline()[:4] == 'duck'
+ assert f.readline()[:4] == "duck"
f.seek(pos_list[10], f.SEEK_SET)
- assert f.readline() == 'duck types were equally resistant to exogenous insulin compared with chicken.\n'
+ assert (
+ f.readline()
+ == "duck types were equally resistant to exogenous insulin compared with chicken.\n"
+ )
finally:
- sftp.remove(sftp.FOLDER + '/duck.txt')
+ sftp.remove(sftp.FOLDER + "/duck.txt")
def test_B_write_seek(self, sftp):
"""
@@ -380,17 +384,17 @@ class TestSFTP(object):
changes worked.
"""
try:
- with sftp.open(sftp.FOLDER + '/testing.txt', 'w') as f:
- f.write('hello kitty.\n')
+ with sftp.open(sftp.FOLDER + "/testing.txt", "w") as f:
+ f.write("hello kitty.\n")
f.seek(-5, f.SEEK_CUR)
- f.write('dd')
+ f.write("dd")
- assert sftp.stat(sftp.FOLDER + '/testing.txt').st_size == 13
- with sftp.open(sftp.FOLDER + '/testing.txt', 'r') as f:
+ assert sftp.stat(sftp.FOLDER + "/testing.txt").st_size == 13
+ with sftp.open(sftp.FOLDER + "/testing.txt", "r") as f:
data = f.read(20)
- assert data == b'hello kiddy.\n'
+ assert data == b"hello kiddy.\n"
finally:
- sftp.remove(sftp.FOLDER + '/testing.txt')
+ sftp.remove(sftp.FOLDER + "/testing.txt")
def test_C_symlink(self, sftp):
"""
@@ -401,39 +405,41 @@ class TestSFTP(object):
return
try:
- with sftp.open(sftp.FOLDER + '/original.txt', 'w') as f:
- f.write('original\n')
- sftp.symlink('original.txt', sftp.FOLDER + '/link.txt')
- assert sftp.readlink(sftp.FOLDER + '/link.txt') == 'original.txt'
+ with sftp.open(sftp.FOLDER + "/original.txt", "w") as f:
+ f.write("original\n")
+ sftp.symlink("original.txt", sftp.FOLDER + "/link.txt")
+ assert sftp.readlink(sftp.FOLDER + "/link.txt") == "original.txt"
- with sftp.open(sftp.FOLDER + '/link.txt', 'r') as f:
- assert f.readlines() == ['original\n']
+ with sftp.open(sftp.FOLDER + "/link.txt", "r") as f:
+ assert f.readlines() == ["original\n"]
- cwd = sftp.normalize('.')
- if cwd[-1] == '/':
+ cwd = sftp.normalize(".")
+ if cwd[-1] == "/":
cwd = cwd[:-1]
- abs_path = cwd + '/' + sftp.FOLDER + '/original.txt'
- sftp.symlink(abs_path, sftp.FOLDER + '/link2.txt')
- assert abs_path == sftp.readlink(sftp.FOLDER + '/link2.txt')
+ abs_path = cwd + "/" + sftp.FOLDER + "/original.txt"
+ sftp.symlink(abs_path, sftp.FOLDER + "/link2.txt")
+ assert abs_path == sftp.readlink(sftp.FOLDER + "/link2.txt")
- assert sftp.lstat(sftp.FOLDER + '/link.txt').st_size == 12
- assert sftp.stat(sftp.FOLDER + '/link.txt').st_size == 9
+ assert sftp.lstat(sftp.FOLDER + "/link.txt").st_size == 12
+ assert sftp.stat(sftp.FOLDER + "/link.txt").st_size == 9
# the sftp server may be hiding extra path members from us, so the
# length may be longer than we expect:
- assert sftp.lstat(sftp.FOLDER + '/link2.txt').st_size >= len(abs_path)
- assert sftp.stat(sftp.FOLDER + '/link2.txt').st_size == 9
- assert sftp.stat(sftp.FOLDER + '/original.txt').st_size == 9
+ assert (
+ sftp.lstat(sftp.FOLDER + "/link2.txt").st_size >= len(abs_path)
+ )
+ assert sftp.stat(sftp.FOLDER + "/link2.txt").st_size == 9
+ assert sftp.stat(sftp.FOLDER + "/original.txt").st_size == 9
finally:
try:
- sftp.remove(sftp.FOLDER + '/link.txt')
+ sftp.remove(sftp.FOLDER + "/link.txt")
except:
pass
try:
- sftp.remove(sftp.FOLDER + '/link2.txt')
+ sftp.remove(sftp.FOLDER + "/link2.txt")
except:
pass
try:
- sftp.remove(sftp.FOLDER + '/original.txt')
+ sftp.remove(sftp.FOLDER + "/original.txt")
except:
pass
@@ -442,18 +448,18 @@ class TestSFTP(object):
verify that buffered writes are automatically flushed on seek.
"""
try:
- with sftp.open(sftp.FOLDER + '/happy.txt', 'w', 1) as f:
- f.write('full line.\n')
- f.write('partial')
+ with sftp.open(sftp.FOLDER + "/happy.txt", "w", 1) as f:
+ f.write("full line.\n")
+ f.write("partial")
f.seek(9, f.SEEK_SET)
- f.write('?\n')
+ f.write("?\n")
- with sftp.open(sftp.FOLDER + '/happy.txt', 'r') as f:
- assert f.readline() == u('full line?\n')
- assert f.read(7) == b'partial'
+ with sftp.open(sftp.FOLDER + "/happy.txt", "r") as f:
+ assert f.readline() == u("full line?\n")
+ assert f.read(7) == b"partial"
finally:
try:
- sftp.remove(sftp.FOLDER + '/happy.txt')
+ sftp.remove(sftp.FOLDER + "/happy.txt")
except:
pass
@@ -462,9 +468,9 @@ class TestSFTP(object):
test that realpath is returning something non-empty and not an
error.
"""
- pwd = sftp.normalize('.')
+ pwd = sftp.normalize(".")
assert len(pwd) > 0
- f = sftp.normalize('./' + sftp.FOLDER)
+ f = sftp.normalize("./" + sftp.FOLDER)
assert len(f) > 0
assert os.path.join(pwd, sftp.FOLDER) == f
@@ -472,46 +478,46 @@ class TestSFTP(object):
"""
verify that mkdir/rmdir work.
"""
- sftp.mkdir(sftp.FOLDER + '/subfolder')
- with pytest.raises(IOError): # generic msg only
- sftp.mkdir(sftp.FOLDER + '/subfolder')
- sftp.rmdir(sftp.FOLDER + '/subfolder')
+ sftp.mkdir(sftp.FOLDER + "/subfolder")
+ with pytest.raises(IOError): # generic msg only
+ sftp.mkdir(sftp.FOLDER + "/subfolder")
+ sftp.rmdir(sftp.FOLDER + "/subfolder")
with pytest.raises(IOError, match="No such file"):
- sftp.rmdir(sftp.FOLDER + '/subfolder')
+ sftp.rmdir(sftp.FOLDER + "/subfolder")
def test_G_chdir(self, sftp):
"""
verify that chdir/getcwd work.
"""
- root = sftp.normalize('.')
- if root[-1] != '/':
- root += '/'
+ root = sftp.normalize(".")
+ if root[-1] != "/":
+ root += "/"
try:
- sftp.mkdir(sftp.FOLDER + '/alpha')
- sftp.chdir(sftp.FOLDER + '/alpha')
- sftp.mkdir('beta')
- assert root + sftp.FOLDER + '/alpha' == sftp.getcwd()
- assert ['beta'] == sftp.listdir('.')
-
- sftp.chdir('beta')
- with sftp.open('fish', 'w') as f:
- f.write('hello\n')
- sftp.chdir('..')
- assert ['fish'] == sftp.listdir('beta')
- sftp.chdir('..')
- assert ['fish'] == sftp.listdir('alpha/beta')
+ sftp.mkdir(sftp.FOLDER + "/alpha")
+ sftp.chdir(sftp.FOLDER + "/alpha")
+ sftp.mkdir("beta")
+ assert root + sftp.FOLDER + "/alpha" == sftp.getcwd()
+ assert ["beta"] == sftp.listdir(".")
+
+ sftp.chdir("beta")
+ with sftp.open("fish", "w") as f:
+ f.write("hello\n")
+ sftp.chdir("..")
+ assert ["fish"] == sftp.listdir("beta")
+ sftp.chdir("..")
+ assert ["fish"] == sftp.listdir("alpha/beta")
finally:
sftp.chdir(root)
try:
- sftp.unlink(sftp.FOLDER + '/alpha/beta/fish')
+ sftp.unlink(sftp.FOLDER + "/alpha/beta/fish")
except:
pass
try:
- sftp.rmdir(sftp.FOLDER + '/alpha/beta')
+ sftp.rmdir(sftp.FOLDER + "/alpha/beta")
except:
pass
try:
- sftp.rmdir(sftp.FOLDER + '/alpha')
+ sftp.rmdir(sftp.FOLDER + "/alpha")
except:
pass
@@ -519,20 +525,21 @@ class TestSFTP(object):
"""
verify that get/put work.
"""
- warnings.filterwarnings('ignore', 'tempnam.*')
+ warnings.filterwarnings("ignore", "tempnam.*")
fd, localname = mkstemp()
os.close(fd)
- text = b'All I wanted was a plastic bunny rabbit.\n'
- with open(localname, 'wb') as f:
+ text = b"All I wanted was a plastic bunny rabbit.\n"
+ with open(localname, "wb") as f:
f.write(text)
saved_progress = []
def progress_callback(x, y):
saved_progress.append((x, y))
- sftp.put(localname, sftp.FOLDER + '/bunny.txt', progress_callback)
- with sftp.open(sftp.FOLDER + '/bunny.txt', 'rb') as f:
+ sftp.put(localname, sftp.FOLDER + "/bunny.txt", progress_callback)
+
+ with sftp.open(sftp.FOLDER + "/bunny.txt", "rb") as f:
assert text == f.read(128)
assert [(41, 41)] == saved_progress
@@ -540,14 +547,14 @@ class TestSFTP(object):
fd, localname = mkstemp()
os.close(fd)
saved_progress = []
- sftp.get(sftp.FOLDER + '/bunny.txt', localname, progress_callback)
+ sftp.get(sftp.FOLDER + "/bunny.txt", localname, progress_callback)
- with open(localname, 'rb') as f:
+ with open(localname, "rb") as f:
assert text == f.read(128)
assert [(41, 41)] == saved_progress
os.unlink(localname)
- sftp.unlink(sftp.FOLDER + '/bunny.txt')
+ sftp.unlink(sftp.FOLDER + "/bunny.txt")
def test_I_check(self, sftp):
"""
@@ -555,118 +562,132 @@ class TestSFTP(object):
(it's an sftp extension that we support, and may be the only ones who
support it.)
"""
- with sftp.open(sftp.FOLDER + '/kitty.txt', 'w') as f:
- f.write('here kitty kitty' * 64)
+ with sftp.open(sftp.FOLDER + "/kitty.txt", "w") as f:
+ f.write("here kitty kitty" * 64)
try:
- with sftp.open(sftp.FOLDER + '/kitty.txt', 'r') as f:
- sum = f.check('sha1')
- assert '91059CFC6615941378D413CB5ADAF4C5EB293402' == u(hexlify(sum)).upper()
- sum = f.check('md5', 0, 512)
- assert '93DE4788FCA28D471516963A1FE3856A' == u(hexlify(sum)).upper()
- sum = f.check('md5', 0, 0, 510)
- assert u(hexlify(sum)).upper() == 'EB3B45B8CD55A0707D99B177544A319F373183D241432BB2157AB9E46358C4AC90370B5CADE5D90336FC1716F90B36D6' # noqa
+ with sftp.open(sftp.FOLDER + "/kitty.txt", "r") as f:
+ sum = f.check("sha1")
+ assert (
+ "91059CFC6615941378D413CB5ADAF4C5EB293402"
+ == u(hexlify(sum)).upper()
+ )
+ sum = f.check("md5", 0, 512)
+ assert (
+ "93DE4788FCA28D471516963A1FE3856A"
+ == u(hexlify(sum)).upper()
+ )
+ sum = f.check("md5", 0, 0, 510)
+ assert (
+ u(hexlify(sum)).upper()
+ == "EB3B45B8CD55A0707D99B177544A319F373183D241432BB2157AB9E46358C4AC90370B5CADE5D90336FC1716F90B36D6"
+ ) # noqa
finally:
- sftp.unlink(sftp.FOLDER + '/kitty.txt')
+ sftp.unlink(sftp.FOLDER + "/kitty.txt")
def test_J_x_flag(self, sftp):
"""
verify that the 'x' flag works when opening a file.
"""
- sftp.open(sftp.FOLDER + '/unusual.txt', 'wx').close()
+ sftp.open(sftp.FOLDER + "/unusual.txt", "wx").close()
try:
try:
- sftp.open(sftp.FOLDER + '/unusual.txt', 'wx')
- self.fail('expected exception')
+ sftp.open(sftp.FOLDER + "/unusual.txt", "wx")
+ self.fail("expected exception")
except IOError:
pass
finally:
- sftp.unlink(sftp.FOLDER + '/unusual.txt')
+ sftp.unlink(sftp.FOLDER + "/unusual.txt")
def test_K_utf8(self, sftp):
"""
verify that unicode strings are encoded into utf8 correctly.
"""
- with sftp.open(sftp.FOLDER + '/something', 'w') as f:
- f.write('okay')
+ with sftp.open(sftp.FOLDER + "/something", "w") as f:
+ f.write("okay")
try:
- sftp.rename(sftp.FOLDER + '/something', sftp.FOLDER + '/' + unicode_folder)
- sftp.open(b(sftp.FOLDER) + utf8_folder, 'r')
+ sftp.rename(
+ sftp.FOLDER + "/something", sftp.FOLDER + "/" + unicode_folder
+ )
+ sftp.open(b(sftp.FOLDER) + utf8_folder, "r")
except Exception as e:
- self.fail('exception ' + str(e))
+ self.fail("exception " + str(e))
sftp.unlink(b(sftp.FOLDER) + utf8_folder)
def test_L_utf8_chdir(self, sftp):
- sftp.mkdir(sftp.FOLDER + '/' + unicode_folder)
+ sftp.mkdir(sftp.FOLDER + "/" + unicode_folder)
try:
- sftp.chdir(sftp.FOLDER + '/' + unicode_folder)
- with sftp.open('something', 'w') as f:
- f.write('okay')
- sftp.unlink('something')
+ sftp.chdir(sftp.FOLDER + "/" + unicode_folder)
+ with sftp.open("something", "w") as f:
+ f.write("okay")
+ sftp.unlink("something")
finally:
sftp.chdir()
- sftp.rmdir(sftp.FOLDER + '/' + unicode_folder)
+ sftp.rmdir(sftp.FOLDER + "/" + unicode_folder)
def test_M_bad_readv(self, sftp):
"""
verify that readv at the end of the file doesn't essplode.
"""
- sftp.open(sftp.FOLDER + '/zero', 'w').close()
+ sftp.open(sftp.FOLDER + "/zero", "w").close()
try:
- with sftp.open(sftp.FOLDER + '/zero', 'r') as f:
+ with sftp.open(sftp.FOLDER + "/zero", "r") as f:
f.readv([(0, 12)])
- with sftp.open(sftp.FOLDER + '/zero', 'r') as f:
+ with sftp.open(sftp.FOLDER + "/zero", "r") as f:
file_size = f.stat().st_size
f.prefetch(file_size)
f.read(100)
finally:
- sftp.unlink(sftp.FOLDER + '/zero')
+ sftp.unlink(sftp.FOLDER + "/zero")
def test_N_put_without_confirm(self, sftp):
"""
verify that get/put work without confirmation.
"""
- warnings.filterwarnings('ignore', 'tempnam.*')
+ warnings.filterwarnings("ignore", "tempnam.*")
fd, localname = mkstemp()
os.close(fd)
- text = b'All I wanted was a plastic bunny rabbit.\n'
- with open(localname, 'wb') as f:
+ text = b"All I wanted was a plastic bunny rabbit.\n"
+ with open(localname, "wb") as f:
f.write(text)
saved_progress = []
def progress_callback(x, y):
saved_progress.append((x, y))
- res = sftp.put(localname, sftp.FOLDER + '/bunny.txt', progress_callback, False)
+
+ res = sftp.put(
+ localname, sftp.FOLDER + "/bunny.txt", progress_callback, False
+ )
assert SFTPAttributes().attr == res.attr
- with sftp.open(sftp.FOLDER + '/bunny.txt', 'r') as f:
+ with sftp.open(sftp.FOLDER + "/bunny.txt", "r") as f:
assert text == f.read(128)
assert (41, 41) == saved_progress[-1]
os.unlink(localname)
- sftp.unlink(sftp.FOLDER + '/bunny.txt')
+ sftp.unlink(sftp.FOLDER + "/bunny.txt")
def test_O_getcwd(self, sftp):
"""
verify that chdir/getcwd work.
"""
assert sftp.getcwd() == None
- root = sftp.normalize('.')
- if root[-1] != '/':
- root += '/'
+ root = sftp.normalize(".")
+ if root[-1] != "/":
+ root += "/"
try:
- sftp.mkdir(sftp.FOLDER + '/alpha')
- sftp.chdir(sftp.FOLDER + '/alpha')
- assert sftp.getcwd() == '/' + sftp.FOLDER + '/alpha'
+ sftp.mkdir(sftp.FOLDER + "/alpha")
+ sftp.chdir(sftp.FOLDER + "/alpha")
+ assert sftp.getcwd() == "/" + sftp.FOLDER + "/alpha"
finally:
sftp.chdir(root)
try:
- sftp.rmdir(sftp.FOLDER + '/alpha')
+ sftp.rmdir(sftp.FOLDER + "/alpha")
except:
pass
@@ -677,24 +698,24 @@ class TestSFTP(object):
does not work except through paramiko. :( openssh fails.
"""
try:
- with sftp.open(sftp.FOLDER + '/append.txt', 'a') as f:
- f.write('first line\nsecond line\n')
+ with sftp.open(sftp.FOLDER + "/append.txt", "a") as f:
+ f.write("first line\nsecond line\n")
f.seek(11, f.SEEK_SET)
- f.write('third line\n')
+ f.write("third line\n")
- with sftp.open(sftp.FOLDER + '/append.txt', 'r') as f:
+ with sftp.open(sftp.FOLDER + "/append.txt", "r") as f:
assert f.stat().st_size == 34
- assert f.readline() == 'first line\n'
- assert f.readline() == 'second line\n'
- assert f.readline() == 'third line\n'
+ assert f.readline() == "first line\n"
+ assert f.readline() == "second line\n"
+ assert f.readline() == "third line\n"
finally:
- sftp.remove(sftp.FOLDER + '/append.txt')
+ sftp.remove(sftp.FOLDER + "/append.txt")
def test_putfo_empty_file(self, sftp):
"""
Send an empty file and confirm it is sent.
"""
- target = sftp.FOLDER + '/empty file.txt'
+ target = sftp.FOLDER + "/empty file.txt"
stream = StringIO()
try:
attrs = sftp.putfo(stream, target)
@@ -713,59 +734,61 @@ class TestSFTP(object):
verify that we can create a file with a '%' in the filename.
( it needs to be properly escaped by _log() )
"""
- f = sftp.open(sftp.FOLDER + '/test%file', 'w')
+ f = sftp.open(sftp.FOLDER + "/test%file", "w")
try:
assert f.stat().st_size == 0
finally:
f.close()
- sftp.remove(sftp.FOLDER + '/test%file')
+ sftp.remove(sftp.FOLDER + "/test%file")
def test_O_non_utf8_data(self, sftp):
"""Test write() and read() of non utf8 data"""
try:
- with sftp.open('%s/nonutf8data' % sftp.FOLDER, 'w') as f:
+ with sftp.open("%s/nonutf8data" % sftp.FOLDER, "w") as f:
f.write(NON_UTF8_DATA)
- with sftp.open('%s/nonutf8data' % sftp.FOLDER, 'r') as f:
+ with sftp.open("%s/nonutf8data" % sftp.FOLDER, "r") as f:
data = f.read()
assert data == NON_UTF8_DATA
- with sftp.open('%s/nonutf8data' % sftp.FOLDER, 'wb') as f:
+ with sftp.open("%s/nonutf8data" % sftp.FOLDER, "wb") as f:
f.write(NON_UTF8_DATA)
- with sftp.open('%s/nonutf8data' % sftp.FOLDER, 'rb') as f:
+ with sftp.open("%s/nonutf8data" % sftp.FOLDER, "rb") as f:
data = f.read()
assert data == NON_UTF8_DATA
finally:
- sftp.remove('%s/nonutf8data' % sftp.FOLDER)
-
+ sftp.remove("%s/nonutf8data" % sftp.FOLDER)
def test_sftp_attributes_empty_str(self, sftp):
sftp_attributes = SFTPAttributes()
- assert str(sftp_attributes) == "?--------- 1 0 0 0 (unknown date) ?"
+ assert (
+ str(sftp_attributes)
+ == "?--------- 1 0 0 0 (unknown date) ?"
+ )
- @needs_builtin('buffer')
+ @needs_builtin("buffer")
def test_write_buffer(self, sftp):
"""Test write() using a buffer instance."""
- data = 3 * b'A potentially large block of data to chunk up.\n'
+ data = 3 * b"A potentially large block of data to chunk up.\n"
try:
- with sftp.open('%s/write_buffer' % sftp.FOLDER, 'wb') as f:
+ with sftp.open("%s/write_buffer" % sftp.FOLDER, "wb") as f:
for offset in range(0, len(data), 8):
f.write(buffer(data, offset, 8))
- with sftp.open('%s/write_buffer' % sftp.FOLDER, 'rb') as f:
+ with sftp.open("%s/write_buffer" % sftp.FOLDER, "rb") as f:
assert f.read() == data
finally:
- sftp.remove('%s/write_buffer' % sftp.FOLDER)
+ sftp.remove("%s/write_buffer" % sftp.FOLDER)
- @needs_builtin('memoryview')
+ @needs_builtin("memoryview")
def test_write_memoryview(self, sftp):
"""Test write() using a memoryview instance."""
- data = 3 * b'A potentially large block of data to chunk up.\n'
+ data = 3 * b"A potentially large block of data to chunk up.\n"
try:
- with sftp.open('%s/write_memoryview' % sftp.FOLDER, 'wb') as f:
+ with sftp.open("%s/write_memoryview" % sftp.FOLDER, "wb") as f:
view = memoryview(data)
for offset in range(0, len(data), 8):
- f.write(view[offset:offset+8])
+ f.write(view[offset:offset + 8])
- with sftp.open('%s/write_memoryview' % sftp.FOLDER, 'rb') as f:
+ with sftp.open("%s/write_memoryview" % sftp.FOLDER, "rb") as f:
assert f.read() == data
finally:
- sftp.remove('%s/write_memoryview' % sftp.FOLDER)
+ sftp.remove("%s/write_memoryview" % sftp.FOLDER)
diff --git a/tests/test_sftp_big.py b/tests/test_sftp_big.py
index a659098..7f74d5f 100644
--- a/tests/test_sftp_big.py
+++ b/tests/test_sftp_big.py
@@ -37,6 +37,7 @@ from .util import slow
@slow
class TestBigSFTP(object):
+
def test_1_lots_of_files(self, sftp):
"""
create a bunch of files over the same session.
@@ -44,22 +45,24 @@ class TestBigSFTP(object):
numfiles = 100
try:
for i in range(numfiles):
- with sftp.open('%s/file%d.txt' % (sftp.FOLDER, i), 'w', 1) as f:
- f.write('this is file #%d.\n' % i)
- sftp.chmod('%s/file%d.txt' % (sftp.FOLDER, i), o660)
+ with sftp.open(
+ "%s/file%d.txt" % (sftp.FOLDER, i), "w", 1
+ ) as f:
+ f.write("this is file #%d.\n" % i)
+ sftp.chmod("%s/file%d.txt" % (sftp.FOLDER, i), o660)
# now make sure every file is there, by creating a list of filenmes
# and reading them in random order.
numlist = list(range(numfiles))
while len(numlist) > 0:
r = numlist[random.randint(0, len(numlist) - 1)]
- with sftp.open('%s/file%d.txt' % (sftp.FOLDER, r)) as f:
- assert f.readline() == 'this is file #%d.\n' % r
+ with sftp.open("%s/file%d.txt" % (sftp.FOLDER, r)) as f:
+ assert f.readline() == "this is file #%d.\n" % r
numlist.remove(r)
finally:
for i in range(numfiles):
try:
- sftp.remove('%s/file%d.txt' % (sftp.FOLDER, i))
+ sftp.remove("%s/file%d.txt" % (sftp.FOLDER, i))
except:
pass
@@ -67,52 +70,56 @@ class TestBigSFTP(object):
"""
write a 1MB file with no buffering.
"""
- kblob = (1024 * b'x')
+ kblob = (1024 * b"x")
start = time.time()
try:
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'w') as f:
+ with sftp.open("%s/hongry.txt" % sftp.FOLDER, "w") as f:
for n in range(1024):
f.write(kblob)
if n % 128 == 0:
- sys.stderr.write('.')
- sys.stderr.write(' ')
+ sys.stderr.write(".")
+ sys.stderr.write(" ")
- assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024
+ assert (
+ sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024
+ )
end = time.time()
- sys.stderr.write('%ds ' % round(end - start))
-
+ sys.stderr.write("%ds " % round(end - start))
+
start = time.time()
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'r') as f:
+ with sftp.open("%s/hongry.txt" % sftp.FOLDER, "r") as f:
for n in range(1024):
data = f.read(1024)
assert data == kblob
end = time.time()
- sys.stderr.write('%ds ' % round(end - start))
+ sys.stderr.write("%ds " % round(end - start))
finally:
- sftp.remove('%s/hongry.txt' % sftp.FOLDER)
+ sftp.remove("%s/hongry.txt" % sftp.FOLDER)
def test_3_big_file_pipelined(self, sftp):
"""
write a 1MB file, with no linefeeds, using pipelining.
"""
- kblob = bytes().join([struct.pack('>H', n) for n in range(512)])
+ kblob = bytes().join([struct.pack(">H", n) for n in range(512)])
start = time.time()
try:
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'wb') as f:
+ with sftp.open("%s/hongry.txt" % sftp.FOLDER, "wb") as f:
f.set_pipelined(True)
for n in range(1024):
f.write(kblob)
if n % 128 == 0:
- sys.stderr.write('.')
- sys.stderr.write(' ')
+ sys.stderr.write(".")
+ sys.stderr.write(" ")
- assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024
+ assert (
+ sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024
+ )
end = time.time()
- sys.stderr.write('%ds ' % round(end - start))
-
+ sys.stderr.write("%ds " % round(end - start))
+
start = time.time()
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'rb') as f:
+ with sftp.open("%s/hongry.txt" % sftp.FOLDER, "rb") as f:
file_size = f.stat().st_size
f.prefetch(file_size)
@@ -130,31 +137,35 @@ class TestBigSFTP(object):
n += chunk
end = time.time()
- sys.stderr.write('%ds ' % round(end - start))
+ sys.stderr.write("%ds " % round(end - start))
finally:
- sftp.remove('%s/hongry.txt' % sftp.FOLDER)
+ sftp.remove("%s/hongry.txt" % sftp.FOLDER)
def test_4_prefetch_seek(self, sftp):
- kblob = bytes().join([struct.pack('>H', n) for n in range(512)])
+ kblob = bytes().join([struct.pack(">H", n) for n in range(512)])
try:
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'wb') as f:
+ with sftp.open("%s/hongry.txt" % sftp.FOLDER, "wb") as f:
f.set_pipelined(True)
for n in range(1024):
f.write(kblob)
if n % 128 == 0:
- sys.stderr.write('.')
- sys.stderr.write(' ')
-
- assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024
-
+ sys.stderr.write(".")
+ sys.stderr.write(" ")
+
+ assert (
+ sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024
+ )
+
start = time.time()
k2blob = kblob + kblob
chunk = 793
for i in range(10):
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'rb') as f:
+ with sftp.open("%s/hongry.txt" % sftp.FOLDER, "rb") as f:
file_size = f.stat().st_size
f.prefetch(file_size)
- base_offset = (512 * 1024) + 17 * random.randint(1000, 2000)
+ base_offset = (512 * 1024) + 17 * random.randint(
+ 1000, 2000
+ )
offsets = [base_offset + j * chunk for j in range(100)]
# randomly seek around and read them out
for j in range(100):
@@ -166,29 +177,33 @@ class TestBigSFTP(object):
assert data == k2blob[n_offset:n_offset + chunk]
offset += chunk
end = time.time()
- sys.stderr.write('%ds ' % round(end - start))
+ sys.stderr.write("%ds " % round(end - start))
finally:
- sftp.remove('%s/hongry.txt' % sftp.FOLDER)
+ sftp.remove("%s/hongry.txt" % sftp.FOLDER)
def test_5_readv_seek(self, sftp):
- kblob = bytes().join([struct.pack('>H', n) for n in range(512)])
+ kblob = bytes().join([struct.pack(">H", n) for n in range(512)])
try:
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'wb') as f:
+ with sftp.open("%s/hongry.txt" % sftp.FOLDER, "wb") as f:
f.set_pipelined(True)
for n in range(1024):
f.write(kblob)
if n % 128 == 0:
- sys.stderr.write('.')
- sys.stderr.write(' ')
+ sys.stderr.write(".")
+ sys.stderr.write(" ")
- assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024
+ assert (
+ sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024
+ )
start = time.time()
k2blob = kblob + kblob
chunk = 793
for i in range(10):
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'rb') as f:
- base_offset = (512 * 1024) + 17 * random.randint(1000, 2000)
+ with sftp.open("%s/hongry.txt" % sftp.FOLDER, "rb") as f:
+ base_offset = (512 * 1024) + 17 * random.randint(
+ 1000, 2000
+ )
# make a bunch of offsets and put them in random order
offsets = [base_offset + j * chunk for j in range(100)]
readv_list = []
@@ -202,60 +217,64 @@ class TestBigSFTP(object):
n_offset = offset % 1024
assert next(ret) == k2blob[n_offset:n_offset + chunk]
end = time.time()
- sys.stderr.write('%ds ' % round(end - start))
+ sys.stderr.write("%ds " % round(end - start))
finally:
- sftp.remove('%s/hongry.txt' % sftp.FOLDER)
+ sftp.remove("%s/hongry.txt" % sftp.FOLDER)
def test_6_lots_of_prefetching(self, sftp):
"""
prefetch a 1MB file a bunch of times, discarding the file object
without using it, to verify that paramiko doesn't get confused.
"""
- kblob = (1024 * b'x')
+ kblob = (1024 * b"x")
try:
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'w') as f:
+ with sftp.open("%s/hongry.txt" % sftp.FOLDER, "w") as f:
f.set_pipelined(True)
for n in range(1024):
f.write(kblob)
if n % 128 == 0:
- sys.stderr.write('.')
- sys.stderr.write(' ')
+ sys.stderr.write(".")
+ sys.stderr.write(" ")
- assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024
+ assert (
+ sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024
+ )
for i in range(10):
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'r') as f:
+ with sftp.open("%s/hongry.txt" % sftp.FOLDER, "r") as f:
file_size = f.stat().st_size
f.prefetch(file_size)
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'r') as f:
+ with sftp.open("%s/hongry.txt" % sftp.FOLDER, "r") as f:
file_size = f.stat().st_size
f.prefetch(file_size)
for n in range(1024):
data = f.read(1024)
assert data == kblob
if n % 128 == 0:
- sys.stderr.write('.')
- sys.stderr.write(' ')
+ sys.stderr.write(".")
+ sys.stderr.write(" ")
finally:
- sftp.remove('%s/hongry.txt' % sftp.FOLDER)
-
+ sftp.remove("%s/hongry.txt" % sftp.FOLDER)
+
def test_7_prefetch_readv(self, sftp):
"""
verify that prefetch and readv don't conflict with each other.
"""
- kblob = bytes().join([struct.pack('>H', n) for n in range(512)])
+ kblob = bytes().join([struct.pack(">H", n) for n in range(512)])
try:
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'wb') as f:
+ with sftp.open("%s/hongry.txt" % sftp.FOLDER, "wb") as f:
f.set_pipelined(True)
for n in range(1024):
f.write(kblob)
if n % 128 == 0:
- sys.stderr.write('.')
- sys.stderr.write(' ')
-
- assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024
+ sys.stderr.write(".")
+ sys.stderr.write(" ")
+
+ assert (
+ sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024
+ )
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'rb') as f:
+ with sftp.open("%s/hongry.txt" % sftp.FOLDER, "rb") as f:
file_size = f.stat().st_size
f.prefetch(file_size)
data = f.read(1024)
@@ -264,79 +283,94 @@ class TestBigSFTP(object):
chunk_size = 793
base_offset = 512 * 1024
k2blob = kblob + kblob
- chunks = [(base_offset + (chunk_size * i), chunk_size) for i in range(20)]
+ chunks = [
+ (base_offset + (chunk_size * i), chunk_size)
+ for i in range(20)
+ ]
for data in f.readv(chunks):
offset = base_offset % 1024
assert chunk_size == len(data)
assert k2blob[offset:offset + chunk_size] == data
base_offset += chunk_size
- sys.stderr.write(' ')
+ sys.stderr.write(" ")
finally:
- sftp.remove('%s/hongry.txt' % sftp.FOLDER)
-
+ sftp.remove("%s/hongry.txt" % sftp.FOLDER)
+
def test_8_large_readv(self, sftp):
"""
verify that a very large readv is broken up correctly and still
returned as a single blob.
"""
- kblob = bytes().join([struct.pack('>H', n) for n in range(512)])
+ kblob = bytes().join([struct.pack(">H", n) for n in range(512)])
try:
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'wb') as f:
+ with sftp.open("%s/hongry.txt" % sftp.FOLDER, "wb") as f:
f.set_pipelined(True)
for n in range(1024):
f.write(kblob)
if n % 128 == 0:
- sys.stderr.write('.')
- sys.stderr.write(' ')
+ sys.stderr.write(".")
+ sys.stderr.write(" ")
+
+ assert (
+ sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024
+ )
- assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024
-
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'rb') as f:
+ with sftp.open("%s/hongry.txt" % sftp.FOLDER, "rb") as f:
data = list(f.readv([(23 * 1024, 128 * 1024)]))
assert len(data) == 1
data = data[0]
assert len(data) == 128 * 1024
-
- sys.stderr.write(' ')
+
+ sys.stderr.write(" ")
finally:
- sftp.remove('%s/hongry.txt' % sftp.FOLDER)
-
+ sftp.remove("%s/hongry.txt" % sftp.FOLDER)
+
def test_9_big_file_big_buffer(self, sftp):
"""
write a 1MB file, with no linefeeds, and a big buffer.
"""
- mblob = (1024 * 1024 * 'x')
+ mblob = (1024 * 1024 * "x")
try:
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'w', 128 * 1024) as f:
+ with sftp.open(
+ "%s/hongry.txt" % sftp.FOLDER, "w", 128 * 1024
+ ) as f:
f.write(mblob)
- assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024
+ assert (
+ sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024
+ )
finally:
- sftp.remove('%s/hongry.txt' % sftp.FOLDER)
-
+ sftp.remove("%s/hongry.txt" % sftp.FOLDER)
+
def test_A_big_file_renegotiate(self, sftp):
"""
write a 1MB file, forcing key renegotiation in the middle.
"""
t = sftp.sock.get_transport()
t.packetizer.REKEY_BYTES = 512 * 1024
- k32blob = (32 * 1024 * 'x')
+ k32blob = (32 * 1024 * "x")
try:
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'w', 128 * 1024) as f:
+ with sftp.open(
+ "%s/hongry.txt" % sftp.FOLDER, "w", 128 * 1024
+ ) as f:
for i in range(32):
f.write(k32blob)
- assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024
+ assert (
+ sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024
+ )
assert t.H != t.session_id
-
+
# try to read it too.
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'r', 128 * 1024) as f:
+ with sftp.open(
+ "%s/hongry.txt" % sftp.FOLDER, "r", 128 * 1024
+ ) as f:
file_size = f.stat().st_size
f.prefetch(file_size)
total = 0
while total < 1024 * 1024:
total += len(f.read(32 * 1024))
finally:
- sftp.remove('%s/hongry.txt' % sftp.FOLDER)
+ sftp.remove("%s/hongry.txt" % sftp.FOLDER)
t.packetizer.REKEY_BYTES = pow(2, 30)
diff --git a/tests/test_ssh_exception.py b/tests/test_ssh_exception.py
index 18f2a97..6cc5d06 100644
--- a/tests/test_ssh_exception.py
+++ b/tests/test_ssh_exception.py
@@ -4,28 +4,33 @@ import unittest
from paramiko.ssh_exception import NoValidConnectionsError
-class NoValidConnectionsErrorTest (unittest.TestCase):
+class NoValidConnectionsErrorTest(unittest.TestCase):
def test_pickling(self):
# Regression test for https://github.com/paramiko/paramiko/issues/617
- exc = NoValidConnectionsError({('127.0.0.1', '22'): Exception()})
+ exc = NoValidConnectionsError({("127.0.0.1", "22"): Exception()})
new_exc = pickle.loads(pickle.dumps(exc))
self.assertEqual(type(exc), type(new_exc))
self.assertEqual(str(exc), str(new_exc))
self.assertEqual(exc.args, new_exc.args)
def test_error_message_for_single_host(self):
- exc = NoValidConnectionsError({('127.0.0.1', '22'): Exception()})
+ exc = NoValidConnectionsError({("127.0.0.1", "22"): Exception()})
assert "Unable to connect to port 22 on 127.0.0.1" in str(exc)
def test_error_message_for_two_hosts(self):
- exc = NoValidConnectionsError({('127.0.0.1', '22'): Exception(),
- ('::1', '22'): Exception()})
+ exc = NoValidConnectionsError(
+ {("127.0.0.1", "22"): Exception(), ("::1", "22"): Exception()}
+ )
assert "Unable to connect to port 22 on 127.0.0.1 or ::1" in str(exc)
def test_error_message_for_multiple_hosts(self):
- exc = NoValidConnectionsError({('127.0.0.1', '22'): Exception(),
- ('::1', '22'): Exception(),
- ('10.0.0.42', '22'): Exception()})
+ exc = NoValidConnectionsError(
+ {
+ ("127.0.0.1", "22"): Exception(),
+ ("::1", "22"): Exception(),
+ ("10.0.0.42", "22"): Exception(),
+ }
+ )
exp = "Unable to connect to port 22 on 10.0.0.42, 127.0.0.1 or ::1"
assert exp in str(exc)
diff --git a/tests/test_ssh_gss.py b/tests/test_ssh_gss.py
index f0645e0..1e08b36 100644
--- a/tests/test_ssh_gss.py
+++ b/tests/test_ssh_gss.py
@@ -33,15 +33,13 @@ from .util import _support, needs_gssapi
from .test_client import FINGERPRINTS
-class NullServer (paramiko.ServerInterface):
+class NullServer(paramiko.ServerInterface):
+
def get_allowed_auths(self, username):
- return 'gssapi-with-mic,publickey'
+ return "gssapi-with-mic,publickey"
def check_auth_gssapi_with_mic(
- self,
- username,
- gss_authenticated=paramiko.AUTH_FAILED,
- cc_file=None,
+ self, username, gss_authenticated=paramiko.AUTH_FAILED, cc_file=None
):
if gss_authenticated == paramiko.AUTH_SUCCESSFUL:
return paramiko.AUTH_SUCCESSFUL
@@ -64,13 +62,14 @@ class NullServer (paramiko.ServerInterface):
return paramiko.OPEN_SUCCEEDED
def check_channel_exec_request(self, channel, command):
- if command != 'yes':
+ if command != "yes":
return False
return True
@needs_gssapi
class GSSAuthTest(unittest.TestCase):
+
def setUp(self):
# TODO: username and targ_name should come from os.environ or whatever
# the approved pytest method is for runtime-configuring test data.
@@ -92,7 +91,7 @@ class GSSAuthTest(unittest.TestCase):
def _run(self):
self.socks, addr = self.sockl.accept()
self.ts = paramiko.Transport(self.socks)
- host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
+ host_key = paramiko.RSAKey.from_private_key_file("tests/test_rsa.key")
self.ts.add_server_key(host_key)
server = NullServer()
self.ts.start_server(self.event, server)
@@ -103,15 +102,22 @@ class GSSAuthTest(unittest.TestCase):
The exception is ... no exception yet
"""
- host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
+ host_key = paramiko.RSAKey.from_private_key_file("tests/test_rsa.key")
public_host_key = paramiko.RSAKey(data=host_key.asbytes())
self.tc = paramiko.SSHClient()
self.tc.set_missing_host_key_policy(paramiko.WarningPolicy())
- self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port),
- 'ssh-rsa', public_host_key)
- self.tc.connect(hostname=self.addr, port=self.port, username=self.username, gss_host=self.hostname,
- gss_auth=True, **kwargs)
+ self.tc.get_host_keys().add(
+ "[%s]:%d" % (self.addr, self.port), "ssh-rsa", public_host_key
+ )
+ self.tc.connect(
+ hostname=self.addr,
+ port=self.port,
+ username=self.username,
+ gss_host=self.hostname,
+ gss_auth=True,
+ **kwargs
+ )
self.event.wait(1.0)
self.assert_(self.event.is_set())
@@ -119,17 +125,17 @@ class GSSAuthTest(unittest.TestCase):
self.assertEquals(self.username, self.ts.get_username())
self.assertEquals(True, self.ts.is_authenticated())
- stdin, stdout, stderr = self.tc.exec_command('yes')
+ stdin, stdout, stderr = self.tc.exec_command("yes")
schan = self.ts.accept(1.0)
- schan.send('Hello there.\n')
- schan.send_stderr('This is on stderr.\n')
+ schan.send("Hello there.\n")
+ schan.send_stderr("This is on stderr.\n")
schan.close()
- self.assertEquals('Hello there.\n', stdout.readline())
- self.assertEquals('', stdout.readline())
- self.assertEquals('This is on stderr.\n', stderr.readline())
- self.assertEquals('', stderr.readline())
+ self.assertEquals("Hello there.\n", stdout.readline())
+ self.assertEquals("", stdout.readline())
+ self.assertEquals("This is on stderr.\n", stderr.readline())
+ self.assertEquals("", stderr.readline())
stdin.close()
stdout.close()
@@ -140,14 +146,15 @@ class GSSAuthTest(unittest.TestCase):
Verify that Paramiko can handle SSHv2 GSS-API / SSPI authentication
(gssapi-with-mic) in client and server mode.
"""
- self._test_connection(allow_agent=False,
- look_for_keys=False)
+ self._test_connection(allow_agent=False, look_for_keys=False)
def test_2_auth_trickledown(self):
"""
Failed gssapi-with-mic auth doesn't prevent subsequent key auth from succeeding
"""
self.hostname = "this_host_does_not_exists_and_causes_a_GSSAPI-exception"
- self._test_connection(key_filename=[_support('test_rsa.key')],
- allow_agent=False,
- look_for_keys=False)
+ self._test_connection(
+ key_filename=[_support("test_rsa.key")],
+ allow_agent=False,
+ look_for_keys=False,
+ )
diff --git a/tests/test_transport.py b/tests/test_transport.py
index 9474acf..e09c5e9 100644
--- a/tests/test_transport.py
+++ b/tests/test_transport.py
@@ -32,14 +32,26 @@ from hashlib import sha1
import unittest
from paramiko import (
- Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey, SSHException,
- ChannelException, Packetizer, Channel,
+ Transport,
+ SecurityOptions,
+ ServerInterface,
+ RSAKey,
+ DSSKey,
+ SSHException,
+ ChannelException,
+ Packetizer,
+ Channel,
)
from paramiko import AUTH_FAILED, AUTH_SUCCESSFUL
from paramiko import OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
from paramiko.common import (
- MSG_KEXINIT, cMSG_CHANNEL_WINDOW_ADJUST, MIN_PACKET_SIZE, MIN_WINDOW_SIZE,
- MAX_WINDOW_SIZE, DEFAULT_WINDOW_SIZE, DEFAULT_MAX_PACKET_SIZE,
+ MSG_KEXINIT,
+ cMSG_CHANNEL_WINDOW_ADJUST,
+ MIN_PACKET_SIZE,
+ MIN_WINDOW_SIZE,
+ MAX_WINDOW_SIZE,
+ DEFAULT_WINDOW_SIZE,
+ DEFAULT_MAX_PACKET_SIZE,
)
from paramiko.py3compat import bytes
from paramiko.message import Message
@@ -61,28 +73,28 @@ Maybe.
"""
-class NullServer (ServerInterface):
+class NullServer(ServerInterface):
paranoid_did_password = False
paranoid_did_public_key = False
- paranoid_key = DSSKey.from_private_key_file(_support('test_dss.key'))
+ paranoid_key = DSSKey.from_private_key_file(_support("test_dss.key"))
def get_allowed_auths(self, username):
- if username == 'slowdive':
- return 'publickey,password'
- return 'publickey'
+ if username == "slowdive":
+ return "publickey,password"
+ return "publickey"
def check_auth_password(self, username, password):
- if (username == 'slowdive') and (password == 'pygmalion'):
+ if (username == "slowdive") and (password == "pygmalion"):
return AUTH_SUCCESSFUL
return AUTH_FAILED
def check_channel_request(self, kind, chanid):
- if kind == 'bogus':
+ if kind == "bogus":
return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
return OPEN_SUCCEEDED
def check_channel_exec_request(self, channel, command):
- if command != b'yes':
+ if command != b"yes":
return False
return True
@@ -95,9 +107,16 @@ class NullServer (ServerInterface):
# tho that's only supposed to occur if the request cannot be served.
# For now, leaving that the default unless test supplies specific
# 'acceptable' request kind
- return kind == 'acceptable'
-
- def check_channel_x11_request(self, channel, single_connection, auth_protocol, auth_cookie, screen_number):
+ return kind == "acceptable"
+
+ def check_channel_x11_request(
+ self,
+ channel,
+ single_connection,
+ auth_protocol,
+ auth_cookie,
+ screen_number,
+ ):
self._x11_single_connection = single_connection
self._x11_auth_protocol = auth_protocol
self._x11_auth_cookie = auth_cookie
@@ -106,7 +125,7 @@ class NullServer (ServerInterface):
def check_port_forward_request(self, addr, port):
self._listen = socket.socket()
- self._listen.bind(('127.0.0.1', 0))
+ self._listen.bind(("127.0.0.1", 0))
self._listen.listen(1)
return self._listen.getsockname()[1]
@@ -120,6 +139,7 @@ class NullServer (ServerInterface):
class TransportTest(unittest.TestCase):
+
def setUp(self):
self.socks = LoopSocket()
self.sockc = LoopSocket()
@@ -134,9 +154,9 @@ class TransportTest(unittest.TestCase):
self.sockc.close()
def setup_test_server(
- self, client_options=None, server_options=None, connect_kwargs=None,
+ self, client_options=None, server_options=None, connect_kwargs=None
):
- host_key = RSAKey.from_private_key_file(_support('test_rsa.key'))
+ host_key = RSAKey.from_private_key_file(_support("test_rsa.key"))
public_host_key = RSAKey(data=host_key.asbytes())
self.ts.add_server_key(host_key)
@@ -152,8 +172,8 @@ class TransportTest(unittest.TestCase):
if connect_kwargs is None:
connect_kwargs = dict(
hostkey=public_host_key,
- username='slowdive',
- password='pygmalion',
+ username="slowdive",
+ password="pygmalion",
)
self.tc.connect(**connect_kwargs)
event.wait(1.0)
@@ -163,11 +183,11 @@ class TransportTest(unittest.TestCase):
def test_1_security_options(self):
o = self.tc.get_security_options()
self.assertEqual(type(o), SecurityOptions)
- self.assertTrue(('aes256-cbc', 'blowfish-cbc') != o.ciphers)
- o.ciphers = ('aes256-cbc', 'blowfish-cbc')
- self.assertEqual(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
+ self.assertTrue(("aes256-cbc", "blowfish-cbc") != o.ciphers)
+ o.ciphers = ("aes256-cbc", "blowfish-cbc")
+ self.assertEqual(("aes256-cbc", "blowfish-cbc"), o.ciphers)
try:
- o.ciphers = ('aes256-cbc', 'made-up-cipher')
+ o.ciphers = ("aes256-cbc", "made-up-cipher")
self.assertTrue(False)
except ValueError:
pass
@@ -188,11 +208,13 @@ class TransportTest(unittest.TestCase):
def test_2_compute_key(self):
self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929
- self.tc.H = b'\x0C\x83\x07\xCD\xE6\x85\x6F\xF3\x0B\xA9\x36\x84\xEB\x0F\x04\xC2\x52\x0E\x9E\xD3'
+ self.tc.H = b"\x0C\x83\x07\xCD\xE6\x85\x6F\xF3\x0B\xA9\x36\x84\xEB\x0F\x04\xC2\x52\x0E\x9E\xD3"
self.tc.session_id = self.tc.H
- key = self.tc._compute_key('C', 32)
- self.assertEqual(b'207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995',
- hexlify(key).upper())
+ key = self.tc._compute_key("C", 32)
+ self.assertEqual(
+ b"207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995",
+ hexlify(key).upper(),
+ )
def test_3_simple(self):
"""
@@ -200,7 +222,7 @@ class TransportTest(unittest.TestCase):
loopback sockets. this is hardly "simple" but it's simpler than the
later tests. :)
"""
- host_key = RSAKey.from_private_key_file(_support('test_rsa.key'))
+ host_key = RSAKey.from_private_key_file(_support("test_rsa.key"))
public_host_key = RSAKey(data=host_key.asbytes())
self.ts.add_server_key(host_key)
event = threading.Event()
@@ -211,13 +233,14 @@ class TransportTest(unittest.TestCase):
self.assertEqual(False, self.tc.is_authenticated())
self.assertEqual(False, self.ts.is_authenticated())
self.ts.start_server(event, server)
- self.tc.connect(hostkey=public_host_key,
- username='slowdive', password='pygmalion')
+ self.tc.connect(
+ hostkey=public_host_key, username="slowdive", password="pygmalion"
+ )
event.wait(1.0)
self.assertTrue(event.is_set())
self.assertTrue(self.ts.is_active())
- self.assertEqual('slowdive', self.tc.get_username())
- self.assertEqual('slowdive', self.ts.get_username())
+ self.assertEqual("slowdive", self.tc.get_username())
+ self.assertEqual("slowdive", self.ts.get_username())
self.assertEqual(True, self.tc.is_authenticated())
self.assertEqual(True, self.ts.is_authenticated())
@@ -225,7 +248,7 @@ class TransportTest(unittest.TestCase):
"""
verify that a long banner doesn't mess up the handshake.
"""
- host_key = RSAKey.from_private_key_file(_support('test_rsa.key'))
+ host_key = RSAKey.from_private_key_file(_support("test_rsa.key"))
public_host_key = RSAKey(data=host_key.asbytes())
self.ts.add_server_key(host_key)
event = threading.Event()
@@ -233,8 +256,9 @@ class TransportTest(unittest.TestCase):
self.assertTrue(not event.is_set())
self.socks.send(LONG_BANNER)
self.ts.start_server(event, server)
- self.tc.connect(hostkey=public_host_key,
- username='slowdive', password='pygmalion')
+ self.tc.connect(
+ hostkey=public_host_key, username="slowdive", password="pygmalion"
+ )
event.wait(1.0)
self.assertTrue(event.is_set())
self.assertTrue(self.ts.is_active())
@@ -244,12 +268,14 @@ class TransportTest(unittest.TestCase):
verify that the client can demand odd handshake settings, and can
renegotiate keys in mid-stream.
"""
+
def force_algorithms(options):
- options.ciphers = ('aes256-cbc',)
- options.digests = ('hmac-md5-96',)
+ options.ciphers = ("aes256-cbc",)
+ options.digests = ("hmac-md5-96",)
+
self.setup_test_server(client_options=force_algorithms)
- self.assertEqual('aes256-cbc', self.tc.local_cipher)
- self.assertEqual('aes256-cbc', self.tc.remote_cipher)
+ self.assertEqual("aes256-cbc", self.tc.local_cipher)
+ self.assertEqual("aes256-cbc", self.tc.remote_cipher)
self.assertEqual(12, self.tc.packetizer.get_mac_size_out())
self.assertEqual(12, self.tc.packetizer.get_mac_size_in())
@@ -263,10 +289,10 @@ class TransportTest(unittest.TestCase):
verify that the keepalive will be sent.
"""
self.setup_test_server()
- self.assertEqual(None, getattr(self.server, '_global_request', None))
+ self.assertEqual(None, getattr(self.server, "_global_request", None))
self.tc.set_keepalive(1)
time.sleep(2)
- self.assertEqual('keepalive@lag.net', self.server._global_request)
+ self.assertEqual("keepalive@lag.net", self.server._global_request)
def test_6_exec_command(self):
"""
@@ -277,39 +303,41 @@ class TransportTest(unittest.TestCase):
chan = self.tc.open_session()
schan = self.ts.accept(1.0)
try:
- chan.exec_command(b'command contains \xfc and is not a valid UTF-8 string')
+ chan.exec_command(
+ b"command contains \xfc and is not a valid UTF-8 string"
+ )
self.assertTrue(False)
except SSHException:
pass
chan = self.tc.open_session()
- chan.exec_command('yes')
+ chan.exec_command("yes")
schan = self.ts.accept(1.0)
- schan.send('Hello there.\n')
- schan.send_stderr('This is on stderr.\n')
+ schan.send("Hello there.\n")
+ schan.send_stderr("This is on stderr.\n")
schan.close()
f = chan.makefile()
- self.assertEqual('Hello there.\n', f.readline())
- self.assertEqual('', f.readline())
+ self.assertEqual("Hello there.\n", f.readline())
+ self.assertEqual("", f.readline())
f = chan.makefile_stderr()
- self.assertEqual('This is on stderr.\n', f.readline())
- self.assertEqual('', f.readline())
+ self.assertEqual("This is on stderr.\n", f.readline())
+ self.assertEqual("", f.readline())
# now try it with combined stdout/stderr
chan = self.tc.open_session()
- chan.exec_command('yes')
+ chan.exec_command("yes")
schan = self.ts.accept(1.0)
- schan.send('Hello there.\n')
- schan.send_stderr('This is on stderr.\n')
+ schan.send("Hello there.\n")
+ schan.send_stderr("This is on stderr.\n")
schan.close()
chan.set_combine_stderr(True)
f = chan.makefile()
- self.assertEqual('Hello there.\n', f.readline())
- self.assertEqual('This is on stderr.\n', f.readline())
- self.assertEqual('', f.readline())
-
+ self.assertEqual("Hello there.\n", f.readline())
+ self.assertEqual("This is on stderr.\n", f.readline())
+ self.assertEqual("", f.readline())
+
def test_6a_channel_can_be_used_as_context_manager(self):
"""
verify that exec_command() does something reasonable.
@@ -318,13 +346,13 @@ class TransportTest(unittest.TestCase):
with self.tc.open_session() as chan:
with self.ts.accept(1.0) as schan:
- chan.exec_command('yes')
- schan.send('Hello there.\n')
+ chan.exec_command("yes")
+ schan.send("Hello there.\n")
schan.close()
f = chan.makefile()
- self.assertEqual('Hello there.\n', f.readline())
- self.assertEqual('', f.readline())
+ self.assertEqual("Hello there.\n", f.readline())
+ self.assertEqual("", f.readline())
def test_7_invoke_shell(self):
"""
@@ -334,11 +362,11 @@ class TransportTest(unittest.TestCase):
chan = self.tc.open_session()
chan.invoke_shell()
schan = self.ts.accept(1.0)
- chan.send('communist j. cat\n')
+ chan.send("communist j. cat\n")
f = schan.makefile()
- self.assertEqual('communist j. cat\n', f.readline())
+ self.assertEqual("communist j. cat\n", f.readline())
chan.close()
- self.assertEqual('', f.readline())
+ self.assertEqual("", f.readline())
def test_8_channel_exception(self):
"""
@@ -346,8 +374,8 @@ class TransportTest(unittest.TestCase):
"""
self.setup_test_server()
try:
- chan = self.tc.open_channel('bogus')
- self.fail('expected exception')
+ chan = self.tc.open_channel("bogus")
+ self.fail("expected exception")
except ChannelException as e:
self.assertTrue(e.code == OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED)
@@ -359,8 +387,8 @@ class TransportTest(unittest.TestCase):
chan = self.tc.open_session()
schan = self.ts.accept(1.0)
- chan.exec_command('yes')
- schan.send('Hello there.\n')
+ chan.exec_command("yes")
+ schan.send("Hello there.\n")
self.assertTrue(not chan.exit_status_ready())
# trigger an EOF
schan.shutdown_read()
@@ -369,8 +397,8 @@ class TransportTest(unittest.TestCase):
schan.close()
f = chan.makefile()
- self.assertEqual('Hello there.\n', f.readline())
- self.assertEqual('', f.readline())
+ self.assertEqual("Hello there.\n", f.readline())
+ self.assertEqual("", f.readline())
count = 0
while not chan.exit_status_ready():
time.sleep(0.1)
@@ -395,7 +423,7 @@ class TransportTest(unittest.TestCase):
self.assertEqual([], w)
self.assertEqual([], e)
- schan.send('hello\n')
+ schan.send("hello\n")
# something should be ready now (give it 1 second to appear)
for i in range(10):
@@ -407,7 +435,7 @@ class TransportTest(unittest.TestCase):
self.assertEqual([], w)
self.assertEqual([], e)
- self.assertEqual(b'hello\n', chan.recv(6))
+ self.assertEqual(b"hello\n", chan.recv(6))
# and, should be dead again now
r, w, e = select.select([chan], [], [], 0.1)
@@ -442,12 +470,12 @@ class TransportTest(unittest.TestCase):
self.setup_test_server()
self.tc.packetizer.REKEY_BYTES = 16384
chan = self.tc.open_session()
- chan.exec_command('yes')
+ chan.exec_command("yes")
schan = self.ts.accept(1.0)
self.assertEqual(self.tc.H, self.tc.session_id)
for i in range(20):
- chan.send('x' * 1024)
+ chan.send("x" * 1024)
chan.close()
# allow a few seconds for the rekeying to complete
@@ -463,18 +491,20 @@ class TransportTest(unittest.TestCase):
"""
verify that zlib compression is basically working.
"""
+
def force_compression(o):
- o.compression = ('zlib',)
+ o.compression = ("zlib",)
+
self.setup_test_server(force_compression, force_compression)
chan = self.tc.open_session()
- chan.exec_command('yes')
+ chan.exec_command("yes")
schan = self.ts.accept(1.0)
bytes = self.tc.packetizer._Packetizer__sent_bytes
- chan.send('x' * 1024)
+ chan.send("x" * 1024)
bytes2 = self.tc.packetizer._Packetizer__sent_bytes
- block_size = self.tc._cipher_info[self.tc.local_cipher]['block-size']
- mac_size = self.tc._mac_info[self.tc.local_mac]['size']
+ block_size = self.tc._cipher_info[self.tc.local_cipher]["block-size"]
+ mac_size = self.tc._mac_info[self.tc.local_mac]["size"]
# tests show this is actually compressed to *52 bytes*! including packet overhead! nice!! :)
self.assertTrue(bytes2 - bytes < 1024)
self.assertEqual(16 + block_size + mac_size, bytes2 - bytes)
@@ -488,29 +518,32 @@ class TransportTest(unittest.TestCase):
"""
self.setup_test_server()
chan = self.tc.open_session()
- chan.exec_command('yes')
+ chan.exec_command("yes")
schan = self.ts.accept(1.0)
requested = []
+
def handler(c, addr_port):
addr, port = addr_port
requested.append((addr, port))
self.tc._queue_incoming_channel(c)
- self.assertEqual(None, getattr(self.server, '_x11_screen_number', None))
+ self.assertEqual(
+ None, getattr(self.server, "_x11_screen_number", None)
+ )
cookie = chan.request_x11(0, single_connection=True, handler=handler)
self.assertEqual(0, self.server._x11_screen_number)
- self.assertEqual('MIT-MAGIC-COOKIE-1', self.server._x11_auth_protocol)
+ self.assertEqual("MIT-MAGIC-COOKIE-1", self.server._x11_auth_protocol)
self.assertEqual(cookie, self.server._x11_auth_cookie)
self.assertEqual(True, self.server._x11_single_connection)
- x11_server = self.ts.open_x11_channel(('localhost', 6093))
+ x11_server = self.ts.open_x11_channel(("localhost", 6093))
x11_client = self.tc.accept()
- self.assertEqual('localhost', requested[0][0])
+ self.assertEqual("localhost", requested[0][0])
self.assertEqual(6093, requested[0][1])
- x11_server.send('hello')
- self.assertEqual(b'hello', x11_client.recv(5))
+ x11_server.send("hello")
+ self.assertEqual(b"hello", x11_client.recv(5))
x11_server.close()
x11_client.close()
@@ -524,33 +557,36 @@ class TransportTest(unittest.TestCase):
"""
self.setup_test_server()
chan = self.tc.open_session()
- chan.exec_command('yes')
+ chan.exec_command("yes")
schan = self.ts.accept(1.0)
requested = []
+
def handler(c, origin_addr_port, server_addr_port):
requested.append(origin_addr_port)
requested.append(server_addr_port)
self.tc._queue_incoming_channel(c)
- port = self.tc.request_port_forward('127.0.0.1', 0, handler)
+ port = self.tc.request_port_forward("127.0.0.1", 0, handler)
self.assertEqual(port, self.server._listen.getsockname()[1])
cs = socket.socket()
- cs.connect(('127.0.0.1', port))
+ cs.connect(("127.0.0.1", port))
ss, _ = self.server._listen.accept()
- sch = self.ts.open_forwarded_tcpip_channel(ss.getsockname(), ss.getpeername())
+ sch = self.ts.open_forwarded_tcpip_channel(
+ ss.getsockname(), ss.getpeername()
+ )
cch = self.tc.accept()
- sch.send('hello')
- self.assertEqual(b'hello', cch.recv(5))
+ sch.send("hello")
+ self.assertEqual(b"hello", cch.recv(5))
sch.close()
cch.close()
ss.close()
cs.close()
# now cancel it.
- self.tc.cancel_port_forward('127.0.0.1', port)
+ self.tc.cancel_port_forward("127.0.0.1", port)
self.assertTrue(self.server._listen is None)
def test_F_port_forwarding(self):
@@ -560,27 +596,29 @@ class TransportTest(unittest.TestCase):
"""
self.setup_test_server()
chan = self.tc.open_session()
- chan.exec_command('yes')
+ chan.exec_command("yes")
schan = self.ts.accept(1.0)
# open a port on the "server" that the client will ask to forward to.
greeting_server = socket.socket()
- greeting_server.bind(('127.0.0.1', 0))
+ greeting_server.bind(("127.0.0.1", 0))
greeting_server.listen(1)
greeting_port = greeting_server.getsockname()[1]
- cs = self.tc.open_channel('direct-tcpip', ('127.0.0.1', greeting_port), ('', 9000))
+ cs = self.tc.open_channel(
+ "direct-tcpip", ("127.0.0.1", greeting_port), ("", 9000)
+ )
sch = self.ts.accept(1.0)
cch = socket.socket()
cch.connect(self.server._tcpip_dest)
ss, _ = greeting_server.accept()
- ss.send(b'Hello!\n')
+ ss.send(b"Hello!\n")
ss.close()
sch.send(cch.recv(8192))
sch.close()
- self.assertEqual(b'Hello!\n', cs.recv(7))
+ self.assertEqual(b"Hello!\n", cs.recv(7))
cs.close()
def test_G_stderr_select(self):
@@ -599,7 +637,7 @@ class TransportTest(unittest.TestCase):
self.assertEqual([], w)
self.assertEqual([], e)
- schan.send_stderr('hello\n')
+ schan.send_stderr("hello\n")
# something should be ready now (give it 1 second to appear)
for i in range(10):
@@ -611,7 +649,7 @@ class TransportTest(unittest.TestCase):
self.assertEqual([], w)
self.assertEqual([], e)
- self.assertEqual(b'hello\n', chan.recv_stderr(6))
+ self.assertEqual(b"hello\n", chan.recv_stderr(6))
# and, should be dead again now
r, w, e = select.select([chan], [], [], 0.1)
@@ -633,8 +671,8 @@ class TransportTest(unittest.TestCase):
self.assertEqual(chan.send_ready(), True)
total = 0
- K = '*' * 1024
- limit = 1+(64 * 2 ** 15)
+ K = "*" * 1024
+ limit = 1 + (64 * 2 ** 15)
while total < limit:
chan.send(K)
total += len(K)
@@ -696,8 +734,11 @@ class TransportTest(unittest.TestCase):
# expires, a deadlock is assumed.
class SendThread(threading.Thread):
+
def __init__(self, chan, iterations, done_event):
- threading.Thread.__init__(self, None, None, self.__class__.__name__)
+ threading.Thread.__init__(
+ self, None, None, self.__class__.__name__
+ )
self.setDaemon(True)
self.chan = chan
self.iterations = iterations
@@ -707,19 +748,22 @@ class TransportTest(unittest.TestCase):
def run(self):
try:
- for i in range(1, 1+self.iterations):
+ for i in range(1, 1 + self.iterations):
if self.done_event.is_set():
break
self.watchdog_event.set()
- #print i, "SEND"
+ # print i, "SEND"
self.chan.send("x" * 2048)
finally:
self.done_event.set()
self.watchdog_event.set()
class ReceiveThread(threading.Thread):
+
def __init__(self, chan, done_event):
- threading.Thread.__init__(self, None, None, self.__class__.__name__)
+ threading.Thread.__init__(
+ self, None, None, self.__class__.__name__
+ )
self.setDaemon(True)
self.chan = chan
self.done_event = done_event
@@ -742,7 +786,7 @@ class TransportTest(unittest.TestCase):
self.ts.packetizer.REKEY_BYTES = 2048
chan = self.tc.open_session()
- chan.exec_command('yes')
+ chan.exec_command("yes")
schan = self.ts.accept(1.0)
# Monkey patch the client's Transport._handler_table so that the client
@@ -751,21 +795,23 @@ class TransportTest(unittest.TestCase):
# on a real MSG_CHANNEL_WINDOW_ADJUST message.
self.tc._handler_table = self.tc._handler_table.copy() # copy per-class dictionary
_negotiate_keys = self.tc._handler_table[MSG_KEXINIT]
+
def _negotiate_keys_wrapper(self, m):
- if self.local_kex_init is None: # Remote side sent KEXINIT
+ if self.local_kex_init is None: # Remote side sent KEXINIT
# Simulate in-transit MSG_CHANNEL_WINDOW_ADJUST by sending it
# before responding to the incoming MSG_KEXINIT.
m2 = Message()
m2.add_byte(cMSG_CHANNEL_WINDOW_ADJUST)
m2.add_int(chan.remote_chanid)
- m2.add_int(1) # bytes to add
+ m2.add_int(1) # bytes to add
self._send_message(m2)
return _negotiate_keys(self, m)
+
self.tc._handler_table[MSG_KEXINIT] = _negotiate_keys_wrapper
# Parameters for the test
- iterations = 500 # The deadlock does not happen every time, but it
- # should after many iterations.
+ iterations = 500 # The deadlock does not happen every time, but it
+ # should after many iterations.
timeout = 5
# This event is set when the test is completed
@@ -807,18 +853,22 @@ class TransportTest(unittest.TestCase):
"""
verify that we conform to the rfc of packet and window sizes.
"""
- for val, correct in [(4095, MIN_PACKET_SIZE),
- (None, DEFAULT_MAX_PACKET_SIZE),
- (2**32, MAX_WINDOW_SIZE)]:
+ for val, correct in [
+ (4095, MIN_PACKET_SIZE),
+ (None, DEFAULT_MAX_PACKET_SIZE),
+ (2 ** 32, MAX_WINDOW_SIZE),
+ ]:
self.assertEqual(self.tc._sanitize_packet_size(val), correct)
def test_K_sanitze_window_size(self):
"""
verify that we conform to the rfc of packet and window sizes.
"""
- for val, correct in [(32767, MIN_WINDOW_SIZE),
- (None, DEFAULT_WINDOW_SIZE),
- (2**32, MAX_WINDOW_SIZE)]:
+ for val, correct in [
+ (32767, MIN_WINDOW_SIZE),
+ (None, DEFAULT_WINDOW_SIZE),
+ (2 ** 32, MAX_WINDOW_SIZE),
+ ]:
self.assertEqual(self.tc._sanitize_window_size(val), correct)
@slow
@@ -834,15 +884,17 @@ class TransportTest(unittest.TestCase):
# (Doing this on the server's transport *sounds* more 'correct' but
# actually doesn't work nearly as well for whatever reason.)
class SlowPacketizer(Packetizer):
+
def read_message(self):
time.sleep(1)
return super(SlowPacketizer, self).read_message()
+
# NOTE: prettttty sure since the replaced .packetizer Packetizer is now
# no longer doing anything with its copy of the socket...everything'll
# be fine. Even tho it's a bit squicky.
self.tc.packetizer = SlowPacketizer(self.tc.sock)
# Continue with regular test red tape.
- host_key = RSAKey.from_private_key_file(_support('test_rsa.key'))
+ host_key = RSAKey.from_private_key_file(_support("test_rsa.key"))
public_host_key = RSAKey(data=host_key.asbytes())
self.ts.add_server_key(host_key)
event = threading.Event()
@@ -850,10 +902,13 @@ class TransportTest(unittest.TestCase):
self.assertTrue(not event.is_set())
self.tc.handshake_timeout = 0.000000000001
self.ts.start_server(event, server)
- self.assertRaises(EOFError, self.tc.connect,
- hostkey=public_host_key,
- username='slowdive',
- password='pygmalion')
+ self.assertRaises(
+ EOFError,
+ self.tc.connect,
+ hostkey=public_host_key,
+ username="slowdive",
+ password="pygmalion",
+ )
def test_M_select_after_close(self):
"""
@@ -894,13 +949,13 @@ class TransportTest(unittest.TestCase):
expected = text.encode("utf-8")
self.assertEqual(sfile.read(len(expected)), expected)
- @needs_builtin('buffer')
+ @needs_builtin("buffer")
def test_channel_send_buffer(self):
"""
verify sending buffer instances to a channel
"""
self.setup_test_server()
- data = 3 * b'some test data\n whole'
+ data = 3 * b"some test data\n whole"
with self.tc.open_session() as chan:
schan = self.ts.accept(1.0)
if schan is None:
@@ -917,13 +972,13 @@ class TransportTest(unittest.TestCase):
chan.sendall(buffer(data))
self.assertEqual(sfile.read(len(data)), data)
- @needs_builtin('memoryview')
+ @needs_builtin("memoryview")
def test_channel_send_memoryview(self):
"""
verify sending memoryview instances to a channel
"""
self.setup_test_server()
- data = 3 * b'some test data\n whole'
+ data = 3 * b"some test data\n whole"
with self.tc.open_session() as chan:
schan = self.ts.accept(1.0)
if schan is None:
@@ -934,7 +989,7 @@ class TransportTest(unittest.TestCase):
sent = 0
view = memoryview(data)
while sent < len(view):
- sent += chan.send(view[sent:sent+8])
+ sent += chan.send(view[sent:sent + 8])
self.assertEqual(sfile.read(len(data)), data)
# sendall() accepts a memoryview instance
@@ -954,7 +1009,7 @@ class TransportTest(unittest.TestCase):
self.setup_test_server(connect_kwargs={})
# NOTE: this dummy global request kind would normally pass muster
# from the test server.
- self.tc.global_request('acceptable')
+ self.tc.global_request("acceptable")
# Global requests never raise exceptions, even on failure (not sure why
# this was the original design...ugh.) Best we can do to tell failure
# happened is that the client transport's global_response was set back
@@ -969,7 +1024,7 @@ class TransportTest(unittest.TestCase):
# an exception on the client side, unlike the general case...)
self.setup_test_server(connect_kwargs={})
try:
- self.tc.request_port_forward('localhost', 1234)
+ self.tc.request_port_forward("localhost", 1234)
except SSHException as e:
assert "forwarding request denied" in str(e)
else:
diff --git a/tests/test_util.py b/tests/test_util.py
index 90473f4..6431b9c 100644
--- a/tests/test_util.py
+++ b/tests/test_util.py
@@ -67,53 +67,70 @@ from paramiko import *
class UtilTest(unittest.TestCase):
+
def test_import(self):
"""
verify that all the classes can be imported from paramiko.
"""
symbols = list(globals().keys())
- self.assertTrue('Transport' in symbols)
- self.assertTrue('SSHClient' in symbols)
- self.assertTrue('MissingHostKeyPolicy' in symbols)
- self.assertTrue('AutoAddPolicy' in symbols)
- self.assertTrue('RejectPolicy' in symbols)
- self.assertTrue('WarningPolicy' in symbols)
- self.assertTrue('SecurityOptions' in symbols)
- self.assertTrue('SubsystemHandler' in symbols)
- self.assertTrue('Channel' in symbols)
- self.assertTrue('RSAKey' in symbols)
- self.assertTrue('DSSKey' in symbols)
- self.assertTrue('Message' in symbols)
- self.assertTrue('SSHException' in symbols)
- self.assertTrue('AuthenticationException' in symbols)
- self.assertTrue('PasswordRequiredException' in symbols)
- self.assertTrue('BadAuthenticationType' in symbols)
- self.assertTrue('ChannelException' in symbols)
- self.assertTrue('SFTP' in symbols)
- self.assertTrue('SFTPFile' in symbols)
- self.assertTrue('SFTPHandle' in symbols)
- self.assertTrue('SFTPClient' in symbols)
- self.assertTrue('SFTPServer' in symbols)
- self.assertTrue('SFTPError' in symbols)
- self.assertTrue('SFTPAttributes' in symbols)
- self.assertTrue('SFTPServerInterface' in symbols)
- self.assertTrue('ServerInterface' in symbols)
- self.assertTrue('BufferedFile' in symbols)
- self.assertTrue('Agent' in symbols)
- self.assertTrue('AgentKey' in symbols)
- self.assertTrue('HostKeys' in symbols)
- self.assertTrue('SSHConfig' in symbols)
- self.assertTrue('util' in symbols)
+ self.assertTrue("Transport" in symbols)
+ self.assertTrue("SSHClient" in symbols)
+ self.assertTrue("MissingHostKeyPolicy" in symbols)
+ self.assertTrue("AutoAddPolicy" in symbols)
+ self.assertTrue("RejectPolicy" in symbols)
+ self.assertTrue("WarningPolicy" in symbols)
+ self.assertTrue("SecurityOptions" in symbols)
+ self.assertTrue("SubsystemHandler" in symbols)
+ self.assertTrue("Channel" in symbols)
+ self.assertTrue("RSAKey" in symbols)
+ self.assertTrue("DSSKey" in symbols)
+ self.assertTrue("Message" in symbols)
+ self.assertTrue("SSHException" in symbols)
+ self.assertTrue("AuthenticationException" in symbols)
+ self.assertTrue("PasswordRequiredException" in symbols)
+ self.assertTrue("BadAuthenticationType" in symbols)
+ self.assertTrue("ChannelException" in symbols)
+ self.assertTrue("SFTP" in symbols)
+ self.assertTrue("SFTPFile" in symbols)
+ self.assertTrue("SFTPHandle" in symbols)
+ self.assertTrue("SFTPClient" in symbols)
+ self.assertTrue("SFTPServer" in symbols)
+ self.assertTrue("SFTPError" in symbols)
+ self.assertTrue("SFTPAttributes" in symbols)
+ self.assertTrue("SFTPServerInterface" in symbols)
+ self.assertTrue("ServerInterface" in symbols)
+ self.assertTrue("BufferedFile" in symbols)
+ self.assertTrue("Agent" in symbols)
+ self.assertTrue("AgentKey" in symbols)
+ self.assertTrue("HostKeys" in symbols)
+ self.assertTrue("SSHConfig" in symbols)
+ self.assertTrue("util" in symbols)
def test_parse_config(self):
global test_config_file
f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f)
- self.assertEqual(config._config,
- [{'host': ['*'], 'config': {}}, {'host': ['*'], 'config': {'identityfile': ['~/.ssh/id_rsa'], 'user': 'robey'}},
- {'host': ['*.example.com'], 'config': {'user': 'bjork', 'port': '3333'}},
- {'host': ['*'], 'config': {'crazy': 'something dumb'}},
- {'host': ['spoo.example.com'], 'config': {'crazy': 'something else'}}])
+ self.assertEqual(
+ config._config,
+ [
+ {"host": ["*"], "config": {}},
+ {
+ "host": ["*"],
+ "config": {
+ "identityfile": ["~/.ssh/id_rsa"], "user": "robey"
+ },
+ },
+ {
+ "host": ["*.example.com"],
+ "config": {"user": "bjork", "port": "3333"},
+ },
+ {"host": ["*"], "config": {"crazy": "something dumb"}},
+ {
+ "host": ["spoo.example.com"],
+ "config": {"crazy": "something else"},
+ },
+ ],
+ )
def test_host_config(self):
global test_config_file
@@ -121,44 +138,57 @@ class UtilTest(unittest.TestCase):
config = paramiko.util.parse_ssh_config(f)
for host, values in {
- 'irc.danger.com': {'crazy': 'something dumb',
- 'hostname': 'irc.danger.com',
- 'user': 'robey'},
- 'irc.example.com': {'crazy': 'something dumb',
- 'hostname': 'irc.example.com',
- 'user': 'robey',
- 'port': '3333'},
- 'spoo.example.com': {'crazy': 'something dumb',
- 'hostname': 'spoo.example.com',
- 'user': 'robey',
- 'port': '3333'}
+ "irc.danger.com": {
+ "crazy": "something dumb",
+ "hostname": "irc.danger.com",
+ "user": "robey",
+ },
+ "irc.example.com": {
+ "crazy": "something dumb",
+ "hostname": "irc.example.com",
+ "user": "robey",
+ "port": "3333",
+ },
+ "spoo.example.com": {
+ "crazy": "something dumb",
+ "hostname": "spoo.example.com",
+ "user": "robey",
+ "port": "3333",
+ },
}.items():
- values = dict(values,
+ values = dict(
+ values,
hostname=host,
- identityfile=[os.path.expanduser("~/.ssh/id_rsa")]
+ identityfile=[os.path.expanduser("~/.ssh/id_rsa")],
)
self.assertEqual(
- paramiko.util.lookup_ssh_host_config(host, config),
- values
+ paramiko.util.lookup_ssh_host_config(host, config), values
)
def test_generate_key_bytes(self):
- x = paramiko.util.generate_key_bytes(sha1, b'ABCDEFGH', 'This is my secret passphrase.', 64)
- hex = ''.join(['%02x' % byte_ord(c) for c in x])
- self.assertEqual(hex, '9110e2f6793b69363e58173e9436b13a5a4b339005741d5c680e505f57d871347b4239f14fb5c46e857d5e100424873ba849ac699cea98d729e57b3e84378e8b')
+ x = paramiko.util.generate_key_bytes(
+ sha1, b"ABCDEFGH", "This is my secret passphrase.", 64
+ )
+ hex = "".join(["%02x" % byte_ord(c) for c in x])
+ self.assertEqual(
+ hex,
+ "9110e2f6793b69363e58173e9436b13a5a4b339005741d5c680e505f57d871347b4239f14fb5c46e857d5e100424873ba849ac699cea98d729e57b3e84378e8b",
+ )
def test_host_keys(self):
- with open('hostfile.temp', 'w') as f:
+ with open("hostfile.temp", "w") as f:
f.write(test_hosts_file)
try:
- hostdict = paramiko.util.load_host_keys('hostfile.temp')
+ hostdict = paramiko.util.load_host_keys("hostfile.temp")
self.assertEqual(2, len(hostdict))
self.assertEqual(1, len(list(hostdict.values())[0]))
self.assertEqual(1, len(list(hostdict.values())[1]))
- fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper()
- self.assertEqual(b'E6684DB30E109B67B70FF1DC5C7F1363', fp)
+ fp = hexlify(
+ hostdict["secure.example.com"]["ssh-rsa"].get_fingerprint()
+ ).upper()
+ self.assertEqual(b"E6684DB30E109B67B70FF1DC5C7F1363", fp)
finally:
- os.unlink('hostfile.temp')
+ os.unlink("hostfile.temp")
def test_host_config_expose_issue_33(self):
test_config_file = """
@@ -173,36 +203,44 @@ Host *
"""
f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f)
- host = 'www13.example.com'
+ host = "www13.example.com"
self.assertEqual(
paramiko.util.lookup_ssh_host_config(host, config),
- {'hostname': host, 'port': '22'}
+ {"hostname": host, "port": "22"},
)
def test_eintr_retry(self):
- self.assertEqual('foo', paramiko.util.retry_on_signal(lambda: 'foo'))
+ self.assertEqual("foo", paramiko.util.retry_on_signal(lambda: "foo"))
# Variables that are set by raises_intr
intr_errors_remaining = [3]
call_count = [0]
+
def raises_intr():
call_count[0] += 1
if intr_errors_remaining[0] > 0:
intr_errors_remaining[0] -= 1
- raise IOError(errno.EINTR, 'file', 'interrupted system call')
+ raise IOError(errno.EINTR, "file", "interrupted system call")
+
self.assertTrue(paramiko.util.retry_on_signal(raises_intr) is None)
self.assertEqual(0, intr_errors_remaining[0])
self.assertEqual(4, call_count[0])
def raises_ioerror_not_eintr():
- raise IOError(errno.ENOENT, 'file', 'file not found')
- self.assertRaises(IOError,
- lambda: paramiko.util.retry_on_signal(raises_ioerror_not_eintr))
+ raise IOError(errno.ENOENT, "file", "file not found")
+
+ self.assertRaises(
+ IOError,
+ lambda: paramiko.util.retry_on_signal(raises_ioerror_not_eintr),
+ )
def raises_other_exception():
- raise AssertionError('foo')
- self.assertRaises(AssertionError,
- lambda: paramiko.util.retry_on_signal(raises_other_exception))
+ raise AssertionError("foo")
+
+ self.assertRaises(
+ AssertionError,
+ lambda: paramiko.util.retry_on_signal(raises_other_exception),
+ )
def test_proxycommand_config_equals_parsing(self):
"""
@@ -217,17 +255,18 @@ Host equals-delimited
"""
f = StringIO(conf)
config = paramiko.util.parse_ssh_config(f)
- for host in ('space-delimited', 'equals-delimited'):
+ for host in ("space-delimited", "equals-delimited"):
self.assertEqual(
- host_config(host, config)['proxycommand'],
- 'foo bar=biz baz'
+ host_config(host, config)["proxycommand"], "foo bar=biz baz"
)
def test_proxycommand_interpolation(self):
"""
ProxyCommand should perform interpolation on the value
"""
- config = paramiko.util.parse_ssh_config(StringIO("""
+ config = paramiko.util.parse_ssh_config(
+ StringIO(
+ """
Host specific
Port 37
ProxyCommand host %h port %p lol
@@ -238,28 +277,32 @@ Host portonly
Host *
Port 25
ProxyCommand host %h port %p
-"""))
+"""
+ )
+ )
for host, val in (
- ('foo.com', "host foo.com port 25"),
- ('specific', "host specific port 37 lol"),
- ('portonly', "host portonly port 155"),
+ ("foo.com", "host foo.com port 25"),
+ ("specific", "host specific port 37 lol"),
+ ("portonly", "host portonly port 155"),
):
- self.assertEqual(
- host_config(host, config)['proxycommand'],
- val
- )
+ self.assertEqual(host_config(host, config)["proxycommand"], val)
def test_proxycommand_tilde_expansion(self):
"""
Tilde (~) should be expanded inside ProxyCommand
"""
- config = paramiko.util.parse_ssh_config(StringIO("""
+ config = paramiko.util.parse_ssh_config(
+ StringIO(
+ """
Host test
ProxyCommand ssh -F ~/.ssh/test_config bastion nc %h %p
-"""))
+"""
+ )
+ )
self.assertEqual(
- 'ssh -F %s/.ssh/test_config bastion nc test 22' % os.path.expanduser('~'),
- host_config('test', config)['proxycommand']
+ "ssh -F %s/.ssh/test_config bastion nc test 22"
+ % os.path.expanduser("~"),
+ host_config("test", config)["proxycommand"],
)
def test_host_config_test_negation(self):
@@ -278,10 +321,10 @@ Host *
"""
f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f)
- host = 'www13.example.com'
+ host = "www13.example.com"
self.assertEqual(
paramiko.util.lookup_ssh_host_config(host, config),
- {'hostname': host, 'port': '8080'}
+ {"hostname": host, "port": "8080"},
)
def test_host_config_test_proxycommand(self):
@@ -296,20 +339,24 @@ Host proxy-without-equal-divisor
ProxyCommand foo=bar:%h-%p
"""
for host, values in {
- 'proxy-with-equal-divisor-and-space' :{'hostname': 'proxy-with-equal-divisor-and-space',
- 'proxycommand': 'foo=bar'},
- 'proxy-with-equal-divisor-and-no-space':{'hostname': 'proxy-with-equal-divisor-and-no-space',
- 'proxycommand': 'foo=bar'},
- 'proxy-without-equal-divisor' :{'hostname': 'proxy-without-equal-divisor',
- 'proxycommand':
- 'foo=bar:proxy-without-equal-divisor-22'}
+ "proxy-with-equal-divisor-and-space": {
+ "hostname": "proxy-with-equal-divisor-and-space",
+ "proxycommand": "foo=bar",
+ },
+ "proxy-with-equal-divisor-and-no-space": {
+ "hostname": "proxy-with-equal-divisor-and-no-space",
+ "proxycommand": "foo=bar",
+ },
+ "proxy-without-equal-divisor": {
+ "hostname": "proxy-without-equal-divisor",
+ "proxycommand": "foo=bar:proxy-without-equal-divisor-22",
+ },
}.items():
f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f)
self.assertEqual(
- paramiko.util.lookup_ssh_host_config(host, config),
- values
+ paramiko.util.lookup_ssh_host_config(host, config), values
)
def test_host_config_test_identityfile(self):
@@ -327,19 +374,21 @@ Host dsa2*
IdentityFile id_dsa22
"""
for host, values in {
- 'foo' :{'hostname': 'foo',
- 'identityfile': ['id_dsa0', 'id_dsa1']},
- 'dsa2' :{'hostname': 'dsa2',
- 'identityfile': ['id_dsa0', 'id_dsa1', 'id_dsa2', 'id_dsa22']},
- 'dsa22' :{'hostname': 'dsa22',
- 'identityfile': ['id_dsa0', 'id_dsa1', 'id_dsa22']}
+ "foo": {"hostname": "foo", "identityfile": ["id_dsa0", "id_dsa1"]},
+ "dsa2": {
+ "hostname": "dsa2",
+ "identityfile": ["id_dsa0", "id_dsa1", "id_dsa2", "id_dsa22"],
+ },
+ "dsa22": {
+ "hostname": "dsa22",
+ "identityfile": ["id_dsa0", "id_dsa1", "id_dsa22"],
+ },
}.items():
f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f)
self.assertEqual(
- paramiko.util.lookup_ssh_host_config(host, config),
- values
+ paramiko.util.lookup_ssh_host_config(host, config), values
)
def test_config_addressfamily_and_lazy_fqdn(self):
@@ -351,7 +400,9 @@ AddressFamily inet
IdentityFile something_%l_using_fqdn
"""
config = paramiko.util.parse_ssh_config(StringIO(test_config))
- assert config.lookup('meh') # will die during lookup() if bug regresses
+ assert config.lookup(
+ "meh"
+ ) # will die during lookup() if bug regresses
def test_clamp_value(self):
self.assertEqual(32768, paramiko.util.clamp_value(32767, 32768, 32769))
@@ -367,7 +418,9 @@ IdentityFile something_%l_using_fqdn
def test_get_hostnames(self):
f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f)
- self.assertEqual(config.get_hostnames(), {'*', '*.example.com', 'spoo.example.com'})
+ self.assertEqual(
+ config.get_hostnames(), {"*", "*.example.com", "spoo.example.com"}
+ )
def test_quoted_host_names(self):
test_config_file = """\
@@ -384,27 +437,23 @@ Host param4 "p a r" "p" "par" para
Port 4444
"""
res = {
- 'param pam': {'hostname': 'param pam', 'port': '1111'},
- 'param': {'hostname': 'param', 'port': '1111'},
- 'pam': {'hostname': 'pam', 'port': '1111'},
-
- 'param2': {'hostname': 'param2', 'port': '2222'},
-
- 'param3': {'hostname': 'param3', 'port': '3333'},
- 'parara': {'hostname': 'parara', 'port': '3333'},
-
- 'param4': {'hostname': 'param4', 'port': '4444'},
- 'p a r': {'hostname': 'p a r', 'port': '4444'},
- 'p': {'hostname': 'p', 'port': '4444'},
- 'par': {'hostname': 'par', 'port': '4444'},
- 'para': {'hostname': 'para', 'port': '4444'},
+ "param pam": {"hostname": "param pam", "port": "1111"},
+ "param": {"hostname": "param", "port": "1111"},
+ "pam": {"hostname": "pam", "port": "1111"},
+ "param2": {"hostname": "param2", "port": "2222"},
+ "param3": {"hostname": "param3", "port": "3333"},
+ "parara": {"hostname": "parara", "port": "3333"},
+ "param4": {"hostname": "param4", "port": "4444"},
+ "p a r": {"hostname": "p a r", "port": "4444"},
+ "p": {"hostname": "p", "port": "4444"},
+ "par": {"hostname": "par", "port": "4444"},
+ "para": {"hostname": "para", "port": "4444"},
}
f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f)
for host, values in res.items():
self.assertEquals(
- paramiko.util.lookup_ssh_host_config(host, config),
- values
+ paramiko.util.lookup_ssh_host_config(host, config), values
)
def test_quoted_params_in_config(self):
@@ -420,52 +469,44 @@ Host param3 parara
IdentityFile "test rsa key"
"""
res = {
- 'param pam': {'hostname': 'param pam', 'identityfile': ['id_rsa']},
- 'param': {'hostname': 'param', 'identityfile': ['id_rsa']},
- 'pam': {'hostname': 'pam', 'identityfile': ['id_rsa']},
-
- 'param2': {'hostname': 'param2', 'identityfile': ['test rsa key']},
-
- 'param3': {'hostname': 'param3', 'identityfile': ['id_rsa', 'test rsa key']},
- 'parara': {'hostname': 'parara', 'identityfile': ['id_rsa', 'test rsa key']},
+ "param pam": {"hostname": "param pam", "identityfile": ["id_rsa"]},
+ "param": {"hostname": "param", "identityfile": ["id_rsa"]},
+ "pam": {"hostname": "pam", "identityfile": ["id_rsa"]},
+ "param2": {"hostname": "param2", "identityfile": ["test rsa key"]},
+ "param3": {
+ "hostname": "param3",
+ "identityfile": ["id_rsa", "test rsa key"],
+ },
+ "parara": {
+ "hostname": "parara",
+ "identityfile": ["id_rsa", "test rsa key"],
+ },
}
f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f)
for host, values in res.items():
self.assertEquals(
- paramiko.util.lookup_ssh_host_config(host, config),
- values
+ paramiko.util.lookup_ssh_host_config(host, config), values
)
def test_quoted_host_in_config(self):
conf = SSHConfig()
correct_data = {
- 'param': ['param'],
- '"param"': ['param'],
-
- 'param pam': ['param', 'pam'],
- '"param" "pam"': ['param', 'pam'],
- '"param" pam': ['param', 'pam'],
- 'param "pam"': ['param', 'pam'],
-
- 'param "pam" p': ['param', 'pam', 'p'],
- '"param" pam "p"': ['param', 'pam', 'p'],
-
- '"pa ram"': ['pa ram'],
- '"pa ram" pam': ['pa ram', 'pam'],
- 'param "p a m"': ['param', 'p a m'],
+ "param": ["param"],
+ '"param"': ["param"],
+ "param pam": ["param", "pam"],
+ '"param" "pam"': ["param", "pam"],
+ '"param" pam': ["param", "pam"],
+ 'param "pam"': ["param", "pam"],
+ 'param "pam" p': ["param", "pam", "p"],
+ '"param" pam "p"': ["param", "pam", "p"],
+ '"pa ram"': ["pa ram"],
+ '"pa ram" pam': ["pa ram", "pam"],
+ 'param "p a m"': ["param", "p a m"],
}
- incorrect_data = [
- 'param"',
- '"param',
- 'param "pam',
- 'param "pam" "p a',
- ]
+ incorrect_data = ['param"', '"param', 'param "pam', 'param "pam" "p a']
for host, values in correct_data.items():
- self.assertEquals(
- conf._get_hosts(host),
- values
- )
+ self.assertEquals(conf._get_hosts(host), values)
for host in incorrect_data:
self.assertRaises(Exception, conf._get_hosts, host)
@@ -490,15 +531,18 @@ Host proxycommand-with-equals-none
ProxyCommand=None
"""
for host, values in {
- 'proxycommand-standard-none': {'hostname': 'proxycommand-standard-none'},
- 'proxycommand-with-equals-none': {'hostname': 'proxycommand-with-equals-none'}
+ "proxycommand-standard-none": {
+ "hostname": "proxycommand-standard-none"
+ },
+ "proxycommand-with-equals-none": {
+ "hostname": "proxycommand-with-equals-none"
+ },
}.items():
f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f)
self.assertEqual(
- paramiko.util.lookup_ssh_host_config(host, config),
- values
+ paramiko.util.lookup_ssh_host_config(host, config), values
)
def test_proxycommand_none_masking(self):
@@ -521,12 +565,10 @@ Host *
# backwards compatibility reasons in 1.x/2.x) appear completely blank,
# as if the host had no ProxyCommand whatsoever.
# Threw another unrelated host in there just for sanity reasons.
- self.assertFalse('proxycommand' in config.lookup('specific-host'))
+ self.assertFalse("proxycommand" in config.lookup("specific-host"))
self.assertEqual(
- config.lookup('other-host')['proxycommand'],
- 'other-proxy'
+ config.lookup("other-host")["proxycommand"], "other-proxy"
)
self.assertEqual(
- config.lookup('some-random-host')['proxycommand'],
- 'default-proxy'
+ config.lookup("some-random-host")["proxycommand"], "default-proxy"
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment