Skip to content

Instantly share code, notes, and snippets.

@bitprophet
Created August 20, 2011 00:25
Show Gist options
  • Save bitprophet/34a60a4fbe86e6866eb5 to your computer and use it in GitHub Desktop.
diff --git a/.gitignore b/.gitignore
index 3af1e5b..a78c1c0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,3 +12,5 @@ dist
build/
tags
TAGS
+.project
+.pydevproject
diff --git a/fabric/forward_ssh.py b/fabric/forward_ssh.py
new file mode 100644
index 0000000..8b25113
--- /dev/null
+++ b/fabric/forward_ssh.py
@@ -0,0 +1,62 @@
+import getpass
+import paramiko as ssh
+from paramiko.resource import ResourceManager
+
+
+class ForwardSSHClient(ssh.SSHClient):
+    """
+    ``ssh.SSHClient`` that runs its SSH session over a caller-supplied,
+    already-connected socket-like object (e.g. a gateway channel) instead of
+    creating a socket of its own.
+
+    NOTE: ``connect`` takes ``sock`` as its second positional parameter, so
+    positional calls written for ``ssh.SSHClient.connect`` do not fit it.
+    """
+    def connect(self, hostname, sock, port=22, username=None, password=None, pkey=None,
+                key_filename=None, timeout=None, allow_agent=True, look_for_keys=True):
+        """
+        Start an SSH client session over ``sock``, verify the server's host
+        key, then authenticate as ``username`` (default: the local user).
+
+        ``hostname`` is used only for host-key lookup and error reporting;
+        ``port`` and ``timeout`` are accepted but not used by this method.
+
+        Raises ``ssh.BadHostKeyException`` when the server key differs from the
+        known key; exceptions from the missing-host-key policy and from
+        authentication propagate unchanged.
+        """
+        t = self._transport = ssh.Transport(sock)
+
+        if self._log_channel is not None:
+            t.set_log_channel(self._log_channel)
+
+        t.start_client()
+        ResourceManager.register(self, t)
+
+        server_key = t.get_remote_server_key()
+        keytype = server_key.get_name()
+
+        # Known key: system host keys first, then the client's own host keys.
+        our_server_key = self._system_host_keys.get(hostname, {}).get(keytype, None)
+        if our_server_key is None:
+            our_server_key = self._host_keys.get(hostname, {}).get(keytype, None)
+        if our_server_key is None:
+            # will raise exception if the key is rejected; let that fall out
+            self._policy.missing_host_key(self, hostname, server_key)
+            # if the callback returns, assume the key is ok
+            our_server_key = server_key
+
+        if server_key != our_server_key:
+            raise ssh.BadHostKeyException(hostname, server_key, our_server_key)
+
+        if username is None:
+            username = getpass.getuser()
+
+        # key_filename may be None, a single path string, or a list of paths.
+        if key_filename is None:
+            key_filenames = []
+        elif isinstance(key_filename, basestring):
+            key_filenames = [key_filename]
+        else:
+            key_filenames = key_filename
+        self._auth(username, password, pkey, key_filenames, allow_agent, look_for_keys)
diff --git a/fabric/network.py b/fabric/network.py
index 139c75b..e321a68 100644
--- a/fabric/network.py
+++ b/fabric/network.py
@@ -29,7 +29,6 @@ Please make sure all dependencies are installed and importable.""" % e
host_pattern = r'((?P<user>.+)@)?(?P<host>[^:]+)(:(?P<port>\d+))?'
host_regex = re.compile(host_pattern)
-
class HostConnectionCache(dict):
"""
Dict subclass allowing for caching of host connections/clients.
@@ -61,14 +60,32 @@ class HostConnectionCache(dict):
     two different connections to the same host being made. If no port is given,
     22 is assumed, so ``example.com`` is equivalent to ``example.com:22``.
     """
+
+    def initialize_gateway(self):
+        """
+        Connect to ``env.gateway`` once (if set), as ``state.gateway_connection``.
+        """
+        from fabric import state
+        gateway_key = state.env.get('gateway')
+        if gateway_key and state.gateway_connection is None:
+            # Normalize given key (i.e. obtain username and port, if not given)
+            gateway_user, gateway_host, gateway_port = normalize(gateway_key)
+            state.gateway_connection = connect(gateway_user, gateway_host, gateway_port)
+
     def __getitem__(self, key):
+        from fabric import state
         # Normalize given key (i.e. obtain username and port, if not given)
         user, host, port = normalize(key)
         # Recombine for use as a key.
         real_key = join_host_strings(user, host, port)
         # If not found, create new connection and store it
         if real_key not in self:
-            self[real_key] = connect(user, host, port)
+            # Only set up the gateway when a new connection is actually needed
+            self.initialize_gateway()
+            if state.gateway_connection is None:
+                self[real_key] = connect(user, host, port)
+            else:
+                self[real_key] = connect_forward(state.gateway_connection, host, port, user)
         # Return the value either way
         return dict.__getitem__(self, real_key)
 
@@ -265,6 +282,56 @@ def connect(user, host, port):
                 host, e[1])
             )
 
+def connect_forward(gw, host, port, user):
+    """
+    Return a client for ``user@host:port`` tunnelled through gateway client ``gw``.
+    """
+    from fabric.state import env
+    from fabric.forward_ssh import ForwardSSHClient
+    client = ForwardSSHClient()
+    # Load known host keys (e.g. ~/.ssh/known_hosts) unless user says not to.
+    if not env.disable_known_hosts:
+        client.load_system_host_keys()
+    # Unless user specified not to, accept/add new, unknown host keys
+    if not env.reject_unknown_hosts:
+        client.set_missing_host_key_policy(ssh.AutoAddPolicy())
+    while True:
+        sock = gw.get_transport().open_channel('direct-tcpip', (host, int(port)), ('', 0))
+        try:
+            client.connect(host, sock, int(port), user, env.password,
+                key_filename=env.key_filename, timeout=10)
+            client._sock_ = sock
+            return client
+        # Must precede the SSHException clause (it is a subclass of it).
+        except ssh.BadHostKeyException:
+            abort("Host key for %s did not match pre-existing key!" % host)
+        except (
+            ssh.AuthenticationException,
+            ssh.PasswordRequiredException,
+            ssh.SSHException
+        ), e:
+            if e.__class__ is ssh.SSHException and env.password:
+                abort(str(e))
+            client.close()
+            sock.close()
+            # Default prompt: passing env.password here would show it as the prompt.
+            env.password = prompt_for_password()
+        except (EOFError, TypeError):
+            # Print a newline (in case user was sitting at prompt)
+            print('')
+            sys.exit(0)
+        # Handle timeouts
+        except socket.timeout:
+            abort('Timed out trying to connect to %s' % host)
+        # Handle DNS error / name lookup failure
+        except socket.gaierror:
+            abort('Name lookup failed for %s' % host)
+        # Handle generic network-related errors
+        # NOTE: In 2.6, socket.error subclasses IOError
+        except socket.error, e:
+            abort('Low level socket error connecting to host %s: %s' % (
+                host, e[1])
+            )
 
def prompt_for_password(prompt=None, no_colon=False, stream=None):
"""
@@ -358,7 +425,8 @@ def disconnect_all():
     Used at the end of ``fab``'s main loop, and also intended for use by
     library users.
     """
+    from fabric import state
     from fabric.state import connections, output
     # Explicitly disconnect from all servers
     for key in connections.keys():
         if output.status:
@@ -367,3 +435,11 @@ def disconnect_all():
         del connections[key]
         if output.status:
             print "done."
+    if state.gateway_connection is not None:
+        if output.status:
+            print "Disconnecting from gateway...",
+        state.gateway_connection.close()
+        # Reset fabric.state's attribute (not a local) so it can be reopened.
+        state.gateway_connection = None
+        if output.status:
+            print "done."
diff --git a/fabric/state.py b/fabric/state.py
index c59f158..e25f6b6 100644
--- a/fabric/state.py
+++ b/fabric/state.py
@@ -139,6 +139,11 @@ env_options = [
         help="comma-separated list of hosts to operate on"
     ),
 
+    make_option('-G', '--gateway',
+        default=None,
+        help="SSH gateway ([user@]host[:port]) to tunnel connections through"
+    ),
+
     make_option('-R', '--roles',
         default=[],
         help="comma-separated list of roles to operate on"
@@ -286,6 +291,13 @@ commands = {}
 
 connections = HostConnectionCache()
 
+
+#
+# Connection to the SSH gateway (env.gateway); opened lazily by fabric.network
+#
+
+gateway_connection = None
+
 
 def default_channel():
     """
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment