-
-
Save bitprophet/34a60a4fbe86e6866eb5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/.gitignore b/.gitignore | |
index 3af1e5b..a78c1c0 100644 | |
--- a/.gitignore | |
+++ b/.gitignore | |
@@ -12,3 +12,5 @@ dist | |
build/ | |
tags | |
TAGS | |
+.project | |
+.pydevproject | |
diff --git a/fabric/forward_ssh.py b/fabric/forward_ssh.py | |
new file mode 100644 | |
index 0000000..8b25113 | |
--- /dev/null | |
+++ b/fabric/forward_ssh.py | |
@@ -0,0 +1,64 @@
+import getpass
+import paramiko as ssh
+from paramiko.resource import ResourceManager
+
+
+class ForwardSSHClient(ssh.SSHClient):
+    """
+    ``ssh.SSHClient`` variant that runs over a caller-supplied socket-like
+    object (such as a channel opened through a gateway) instead of creating
+    its own TCP socket.
+    """
+    def connect(self, hostname, sock, port=22, username=None, password=None,
+                pkey=None, key_filename=None, timeout=None, allow_agent=True,
+                look_for_keys=True):
+        """
+        Start an SSH session over ``sock``, check the host key, and log in.
+
+        ``hostname`` is only used to look up the server's host key. ``port``
+        and ``timeout`` mirror ``ssh.SSHClient.connect`` but are not used by
+        this method: ``sock`` is handed straight to the transport, so there
+        is no connection to open or time out here.
+
+        Raises ``ssh.BadHostKeyException`` if the server's key differs from a
+        known one; errors from the missing-host-key policy and from
+        authentication propagate unchanged.
+        """
+        t = self._transport = ssh.Transport(sock)
+
+        if self._log_channel is not None:
+            t.set_log_channel(self._log_channel)
+
+        t.start_client()
+        # Register the transport with paramiko's ResourceManager (presumably
+        # so it is closed once this client is garbage-collected).
+        ResourceManager.register(self, t)
+
+        server_key = t.get_remote_server_key()
+        keytype = server_key.get_name()
+
+        # Check the system host keys first, then this client's own keys.
+        our_server_key = self._system_host_keys.get(hostname, {}).get(keytype, None)
+        if our_server_key is None:
+            our_server_key = self._host_keys.get(hostname, {}).get(keytype, None)
+        if our_server_key is None:
+            # will raise exception if the key is rejected; let that fall out
+            self._policy.missing_host_key(self, hostname, server_key)
+            # if the callback returns, assume the key is ok
+            our_server_key = server_key
+
+        if server_key != our_server_key:
+            raise ssh.BadHostKeyException(hostname, server_key, our_server_key)
+
+        if username is None:
+            username = getpass.getuser()
+
+        # Accept either a single key file name or a list of them.
+        if key_filename is None:
+            key_filenames = []
+        elif isinstance(key_filename, (str, unicode)):
+            key_filenames = [key_filename]
+        else:
+            key_filenames = key_filename
+        self._auth(username, password, pkey, key_filenames, allow_agent,
+                   look_for_keys)
diff --git a/fabric/network.py b/fabric/network.py | |
index 139c75b..e321a68 100644 | |
--- a/fabric/network.py | |
+++ b/fabric/network.py | |
@@ -29,7 +29,6 @@ Please make sure all dependencies are installed and importable.""" % e | |
host_pattern = r'((?P<user>.+)@)?(?P<host>[^:]+)(:(?P<port>\d+))?' | |
host_regex = re.compile(host_pattern) | |
- | |
class HostConnectionCache(dict): | |
""" | |
Dict subclass allowing for caching of host connections/clients. | |
@@ -61,14 +60,40 @@ class HostConnectionCache(dict):
     two different connections to the same host being made. If no port is given,
     22 is assumed, so ``example.com`` is equivalent to ``example.com:22``.
     """
+
+    def initialize_gateway(self):
+        """
+        Connect to the gateway named by ``env.gateway``, if one is set.
+
+        The client is stored in ``fabric.state.gateway_connection``. Does
+        nothing when no gateway is configured or a gateway connection has
+        already been stored.
+        """
+        from fabric import state
+        gateway_key = state.env.get('gateway')
+        if gateway_key is not None and state.gateway_connection is None:
+            # Normalize the gateway host string (i.e. obtain username and
+            # port, if not given), then connect to it directly.
+            gateway_user, gateway_host, gateway_port = normalize(gateway_key)
+            state.gateway_connection = connect(gateway_user, gateway_host,
+                gateway_port)
+
     def __getitem__(self, key):
+        from fabric import state
+        # Open the gateway connection first, if one is configured.
+        self.initialize_gateway()
         # Normalize given key (i.e. obtain username and port, if not given)
         user, host, port = normalize(key)
         # Recombine for use as a key.
         real_key = join_host_strings(user, host, port)
         # If not found, create new connection and store it
         if real_key not in self:
-            self[real_key] = connect(user, host, port)
+            if state.gateway_connection is None:
+                self[real_key] = connect(user, host, port)
+            else:
+                # Tunnel through the gateway instead of connecting directly.
+                self[real_key] = connect_forward(state.gateway_connection,
+                    host, port, user)
         # Return the value either way
         return dict.__getitem__(self, real_key)
 
@@ -265,6 +282,90 @@ def connect(user, host, port):
                 host, e[1])
             )
 
 
+def connect_forward(gw, host, port, user):
+    """
+    Connect to ``host``:``port`` as ``user`` by tunnelling through ``gw``.
+
+    ``gw`` is the client already connected to the SSH gateway. Each attempt
+    opens a fresh ``direct-tcpip`` channel on the gateway's transport and
+    runs a ``ForwardSSHClient`` session over it. If an attempt fails in a
+    way that looks like bad or missing credentials, the failed session and
+    its channel are closed, the user is prompted for a password, and a new
+    attempt is made; other failures abort. Returns the connected client,
+    which also keeps the channel as ``client._sock_``.
+    """
+    from fabric.state import env
+    from fabric.forward_ssh import ForwardSSHClient
+
+    client = ForwardSSHClient()
+    # Load known host keys (e.g. ~/.ssh/known_hosts) unless user says not to.
+    if not env.disable_known_hosts:
+        client.load_system_host_keys()
+    # Unless user specified not to, accept/add new, unknown host keys
+    if not env.reject_unknown_hosts:
+        client.set_missing_host_key_policy(ssh.AutoAddPolicy())
+
+    port = int(port)
+    need_password = False
+    while True:
+        sock = None
+        try:
+            # Prompt inside the try so an aborted prompt (EOFError) reaches
+            # the handler below instead of escaping as a traceback.
+            if need_password:
+                # Use the default prompt: prompt_for_password()'s first
+                # argument is the prompt text, not a previous password.
+                env.password = prompt_for_password()
+            transport = gw.get_transport()
+            if transport is None or not transport.is_active():
+                abort('Connection to the gateway is no longer active')
+            # The channel stands in for a direct TCP connection, so failures
+            # opening it go through the handlers below as well.
+            sock = transport.open_channel('direct-tcpip', (host, port),
+                ('', 0))
+            client.connect(host, sock, port, user, env.password,
+                key_filename=env.key_filename, timeout=10)
+            client._sock_ = sock
+            return client
+        # The gateway refused to open the tunnel (e.g. forwarding disabled or
+        # target unreachable); a different password cannot fix that.
+        except ssh.ChannelException, e:
+            abort('Gateway could not open a channel to %s:%s: %s' % (
+                host, port, e)
+            )
+        except (
+            ssh.AuthenticationException,
+            ssh.PasswordRequiredException,
+            ssh.SSHException
+        ), e:
+            # A plain SSHException while a password is already set is taken
+            # to be a non-authentication error; give up on it.
+            if e.__class__ is ssh.SSHException and env.password:
+                abort(str(e))
+            # Otherwise assume bad/missing credentials: tear down the failed
+            # session and channel, then retry with a newly prompted password.
+            client.close()
+            if sock is not None:
+                sock.close()
+            need_password = True
+        except (EOFError, TypeError):
+            # Print a newline (in case user was sitting at prompt)
+            print('')
+            sys.exit(0)
+        # Handle timeouts
+        except socket.timeout:
+            abort('Timed out trying to connect to %s' % host)
+        # Handle DNS error / name lookup failure
+        except socket.gaierror:
+            abort('Name lookup failed for %s' % host)
+        # Handle generic network-related errors
+        # NOTE: In 2.6, socket.error subclasses IOError
+        except socket.error, e:
+            abort('Low level socket error connecting to host %s: %s' % (
+                host, e[1])
+            )
+
+
 def prompt_for_password(prompt=None, no_colon=False, stream=None):
     """
@@ -358,7 +425,8 @@ def disconnect_all():
     Used at the end of ``fab``'s main loop, and also intended for use by
     library users.
     """
-    from fabric.state import connections, output
+    from fabric import state
+    from fabric.state import connections, output
     # Explicitly disconnect from all servers
     for key in connections.keys():
         if output.status:
@@ -367,3 +435,13 @@ def disconnect_all():
         del connections[key]
         if output.status:
             print "done."
+    # Close the gateway after every connection tunnelled through it. Reset
+    # it on the ``state`` module itself (a ``from`` import would only rebind
+    # a local name) so the next cache lookup reconnects instead of reusing it.
+    if state.gateway_connection is not None:
+        if output.status:
+            print "Disconnecting from gateway...",
+        state.gateway_connection.close()
+        state.gateway_connection = None
+        if output.status:
+            print "done."
diff --git a/fabric/state.py b/fabric/state.py | |
index c59f158..e25f6b6 100644 | |
--- a/fabric/state.py | |
+++ b/fabric/state.py | |
@@ -139,6 +139,11 @@ env_options = [
         help="comma-separated list of hosts to operate on"
     ),
 
+    make_option('-G', '--gateway',
+        default=None,
+        help="[user@]host[:port] of an SSH gateway to connect through"
+    ),
+
     make_option('-R', '--roles',
         default=[],
         help="comma-separated list of roles to operate on"
@@ -286,6 +291,13 @@ commands = {}
 
 connections = HostConnectionCache()
 
+#
+# Client for the SSH gateway given by env.gateway (-G); None unless
+# HostConnectionCache.initialize_gateway() has connected to it.
+#
+
+gateway_connection = None
+
 
 def default_channel():
     """
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.