Skip to content

Instantly share code, notes, and snippets.

@bitprophet
Created August 19, 2011 01:57
Show Gist options
  • Save bitprophet/e3e96664765748151c05 to your computer and use it in GitHub Desktop.
Save bitprophet/e3e96664765748151c05 to your computer and use it in GitHub Desktop.
import sys
import socket
import getpass
import paramiko as ssh
from paramiko.resource import ResourceManager
from fabric import network
from fabric import state as s
def connect_forward(gw, host, port, user):
    """
    Connect to ``host`` through an SSH channel forwarded by gateway ``gw``.

    Rather than opening a TCP socket itself, each attempt asks the gateway's
    transport for a ``direct-tcpip`` channel to ``(host, port)`` and hands
    that channel to ``ForwardSSHClient.connect`` as its socket. After a
    failed authentication attempt the client and channel are closed and a
    fresh channel is opened for the retry.

    :param gw: connected client for the gateway; only ``gw.get_transport()``
        is used.
    :param host: target hostname, as passed to the gateway's
        ``open_channel``.
    :param port: target SSH port; anything ``int()`` accepts.
    :param user: username to authenticate as on the target.
    :returns: the connected ``ForwardSSHClient``, with the forwarding
        channel stored on it as ``_sock_``.

    On an authentication failure the user is prompted for a new password
    (written back to ``env.password``) and the connection is retried.
    ``EOFError``/``TypeError`` exit quietly with status 0; timeouts,
    name-lookup failures and other socket errors go through
    ``network.abort``. Errors raised by ``open_channel`` are not caught here.
    """
    client = ForwardSSHClient()
    # Host-key settings do not change between attempts, so apply them once.
    # Load known host keys (e.g. ~/.ssh/known_hosts) unless user says not to.
    if not s.env.disable_known_hosts:
        client.load_system_host_keys()
    # Unless user specified not to, accept/add new, unknown host keys
    if not s.env.reject_unknown_hosts:
        client.set_missing_host_key_policy(ssh.AutoAddPolicy())

    while True:
        # Open a new forwarded channel for every attempt; a channel that has
        # carried a failed handshake is closed rather than reused.
        sock = gw.get_transport().open_channel(
            'direct-tcpip', (host, int(port)), ('', 0))
        try:
            client.connect(host, sock, int(port), user, s.env.password,
                           key_filename=s.env.key_filename, timeout=10)
            client._sock_ = sock
            return client
        except (
            ssh.AuthenticationException,
            ssh.PasswordRequiredException,
            ssh.SSHException
        ) as e:
            # A plain SSHException while a password was set is treated as a
            # non-authentication problem and aborts; anything else here is
            # assumed to mean bad/missing credentials, so prompt and retry.
            if e.__class__ is ssh.SSHException and s.env.password:
                network.abort(str(e))
            # Release the failed attempt before blocking on the prompt: the
            # client's transport (started over ``sock``) and the channel.
            client.close()
            sock.close()
            # Exceptions raised inside this handler are not seen by the
            # sibling ``except (EOFError, TypeError)`` clause below, so the
            # prompt needs its own handler for the same quiet exit.
            try:
                s.env.password = network.prompt_for_password(s.env.password)
            except (EOFError, TypeError):
                print('')
                sys.exit(0)
        except (EOFError, TypeError):
            # Print a newline (in case user was sitting at prompt)
            print('')
            sys.exit(0)
        # Handle timeouts
        except socket.timeout:
            network.abort('Timed out trying to connect to %s' % host)
        # Handle DNS error / name lookup failure
        except socket.gaierror:
            network.abort('Name lookup failed for %s' % host)
        # Handle generic network-related errors
        # NOTE: In 2.6, socket.error subclasses IOError
        except socket.error as e:
            # socket.error usually carries (errno, message); fall back to the
            # exception itself when no message part is present.
            reason = getattr(e, 'strerror', None) or e
            network.abort('Low level socket error connecting to host %s: %s' % (
                host, reason)
            )
class ForwardSSHClient(ssh.SSHClient):
    """
    An ``ssh.SSHClient`` whose ``connect`` runs over a caller-supplied socket.

    The stock client opens its own connection; this subclass instead takes an
    already-open, socket-like object as the second argument to ``connect``
    and builds the SSH transport on top of it.
    """

    def connect(self, hostname, sock, port=22, username=None, password=None, pkey=None,
                key_filename=None, timeout=None, allow_agent=True, look_for_keys=True):
        """
        Start an SSH session over ``sock``, check the host key, authenticate.

        :param hostname: name under which the server's host key is looked up
            and handed to the missing-host-key policy.
        :param sock: open socket-like object the transport runs over.
        :param port: accepted for signature compatibility; not used here.
        :param username: login name; ``getpass.getuser()`` when ``None``.
        :param password: passed through to ``self._auth``.
        :param pkey: passed through to ``self._auth``.
        :param key_filename: ``None``, one filename string, or a sequence of
            filenames.
        :param timeout: accepted for signature compatibility; not used here.
        :param allow_agent: passed through to ``self._auth``.
        :param look_for_keys: passed through to ``self._auth``.
        :raises ssh.BadHostKeyException: if the server's key differs from
            the stored key for ``hostname``.

        If no key of the server's type is stored for ``hostname``, whatever
        the missing-host-key policy raises propagates unchanged.
        """
        transport = ssh.Transport(sock)
        self._transport = transport
        if self._log_channel is not None:
            transport.set_log_channel(self._log_channel)
        transport.start_client()
        # Tie the transport's lifetime to this client object.
        ResourceManager.register(self, transport)

        remote_key = transport.get_remote_server_key()
        key_type = remote_key.get_name()

        # Consult the system key store first, then this client's own store.
        expected_key = None
        for key_store in (self._system_host_keys, self._host_keys):
            expected_key = key_store.get(hostname, {}).get(key_type, None)
            if expected_key is not None:
                break
        else:
            # Nothing stored: the policy rejects by raising; if it returns,
            # the key the server presented is accepted as-is.
            self._policy.missing_host_key(self, hostname, remote_key)
            expected_key = remote_key

        if remote_key != expected_key:
            raise ssh.BadHostKeyException(hostname, remote_key, expected_key)

        login = username if username is not None else getpass.getuser()

        # Normalise key_filename into a list; a lone path string is wrapped
        # rather than iterated character by character.
        key_files = key_filename
        if key_files is None:
            key_files = []
        elif isinstance(key_files, (str, unicode)):
            key_files = [key_files]

        self._auth(login, password, pkey, key_files, allow_agent, look_for_keys)
class GatewayConnectionCache(network.HostConnectionCache):
    """
    Connection cache that tunnels new connections through ``env.gateway``.

    When ``env.gateway`` is unset this defers entirely to
    ``network.HostConnectionCache``. When it is set, a client for the gateway
    is opened on demand, and each target host not yet in the cache is
    reached with ``connect_forward`` over that gateway. Entries are keyed by
    the normalized target host string only. An entry already in the cache is
    returned as-is, whichever gateway it was opened through.
    """
    # Client for the current gateway, opened lazily on first use.
    _gw = None
    # Normalized host string ``_gw`` is connected to; lets us notice when
    # ``env.gateway`` is changed to a different host.
    _gw_key = None

    def _gateway(self, gw):
        """Return a gateway client for ``gw``, reconnecting if ``gw`` changed."""
        gw_user, gw_host, gw_port = network.normalize(gw)
        gw_key = network.join_host_strings(gw_user, gw_host, gw_port)
        if self._gw is None or gw_key != self._gw_key:
            self._gw = network.connect(gw_user, gw_host, gw_port)
            self._gw_key = gw_key
        return self._gw

    def __getitem__(self, key):
        """
        Return the cached connection for host string ``key``, connecting if needed.

        :param key: host string in any form ``network.normalize`` accepts.
        :returns: the (possibly newly created) connection for ``key``.
        """
        gw = s.env.get('gateway')
        if gw is None:
            return super(GatewayConnectionCache, self).__getitem__(key)
        gateway = self._gateway(gw)
        # Normalize given key (i.e. obtain username and port, if not given)
        user, host, port = network.normalize(key)
        # Recombine for use as a key.
        real_key = network.join_host_strings(user, host, port)
        # If not found, create new connection and store it
        if real_key not in self:
            self[real_key] = connect_forward(gateway, host, port, user)
        # Return the value either way
        return dict.__getitem__(self, real_key)
# Install a single gateway-aware cache as Fabric's connection cache.
_c = GatewayConnectionCache()
s.connections = _c
# fabric.operations also gets the same cache object; presumably it holds its
# own module-level ``connections`` name bound at import time — confirm.
from fabric import operations
operations.connections = _c
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment