Last active
July 10, 2024 10:57
-
-
Save fintanmm/025b8546637590e7bedf9ed1d0a48701 to your computer and use it in GitHub Desktop.
Installs and configures the AWS SSM Agent on on-premises Linux servers using SSH and JBang
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
///usr/bin/env jbang "$0" "$@" ; exit $?
//DEPS ch.qos.reload4j:reload4j:1.2.25
//DEPS info.picocli:picocli:4.7.6
//DEPS info.picocli:picocli-codegen:4.7.6
//DEPS com.github.mwiede:jsch:0.2.18
//DEPS com.google.code.gson:gson:2.8.9
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import com.jcraft.jsch.ChannelExec;
import com.jcraft.jsch.JSch;
import com.jcraft.jsch.JSchException;
import com.jcraft.jsch.Session;
import org.apache.log4j.*;
import org.apache.log4j.spi.LoggingEvent;
import picocli.CommandLine;
import picocli.CommandLine.ArgGroup;
import picocli.CommandLine.Command;
import picocli.CommandLine.Option;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.HashSet;
import java.util.Optional;
import java.util.Scanner;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@Command(name = "SsmAgentInstall", mixinStandardHelpOptions = true, version = "SsmAgentInstall 0.5", description = "Installs and configures AWS SSM Agent for On-Premise Linux Servers using SSH.") | |
public class SsmAgentInstall implements Runnable { | |
public static final String ON_PORT = " on port "; | |
public static final int TIMEOUT = 30000; | |
public static final String UBUNTU = "Ubuntu"; | |
public static final String DEBIAN = "Debian"; | |
public static final String CONNECTIVITY_CHECK_FAILED_FOR = "Connectivity check failed for "; | |
public static final String FAILED_TO_INSTALL_THE_SSM_AGENT = "Failed to install the SSM Agent"; | |
private static final String FAILED_TO_CONNECT_TO_SERVER = "Failed to connect to server: %s"; | |
private static final String FAILED_TO_ENABLE_THE_SSM_AGENT = "Failed to enable the SSM Agent"; | |
private static final String FAILED_TO_RESTART_THE_SSM_AGENT = "Failed to restart the SSM Agent"; | |
@Option(names = {"-l", "--log-file"}, description = "The log file to write the output to") | |
private static Path logFile; | |
private final Set<String> failedServers = new HashSet<>(); | |
private final Logger logger = Logger.getLogger(SsmAgentInstall.class); | |
@ArgGroup(exclusive = false, multiplicity = "1") | |
AuthenticationGroup authenticationGroup; | |
@ArgGroup(exclusive = false, multiplicity = "0..1") | |
AwsOptions awsOptions; | |
@Option(names = {"-s", "--servers"}, description = "A list seperated by commas of On-Premise Servers that the SSM Agent will be installed and configured on", split = ",", required = true) | |
private String[] servers; | |
@Option(names = {"-P", "--port"}, description = "The port to use to connect to the server", defaultValue = "22") | |
private int port; | |
@Option(names = {"-v", "--verbose"}, description = "Be verbose.", defaultValue = "false") | |
private boolean verbose; | |
@Option(names = {"-t", "--test-connectivity"}, description = "Test connectivity to each server but do not install", defaultValue = "false") | |
private boolean testConnectivity; | |
@Option(names = {"-u", "--username"}, description = "The username to use to connect to the server") | |
private String username; | |
@Option(names = {"-S", "--strict"}, description = "Don't use strict host key checking", defaultValue = "true") | |
private boolean strict; | |
@Option(names = {"-D", "--sudo"}, description = "Use sudo to install the SSM Agent", defaultValue = "false") | |
private boolean sudo; | |
@Option(names = {"-T", "--threads"}, description = "The number of threads to use to install the SSM Agent, defaults to 10", defaultValue = "10") | |
private int threads; | |
@Option(names = {"-f", "--force"}, description = "Force the re-registration of a server", defaultValue = "false") | |
private boolean force; | |
public static void main(String... args) { | |
int exitCode = new CommandLine(new SsmAgentInstall()).execute(args); | |
System.exit(exitCode); | |
} | |
/** | |
* Sleep the current thread for the specified delay | |
* | |
* @param delay the delay in milliseconds | |
*/ | |
private static void thread(int delay) { | |
try { | |
Thread.sleep(delay); | |
} catch (InterruptedException ie) { | |
Thread.currentThread().interrupt(); | |
} | |
} | |
@Override | |
public void run() { | |
// Create a new color console appender | |
ColorConsoleAppender colorAppender = new ColorConsoleAppender(); | |
colorAppender.setLayout(new PatternLayout("%d{yyyy-MM-dd HH:mm:ss} %-5p [%c{1}] %m%n")); | |
colorAppender.activateOptions(); | |
// Get the root logger and add your appender to it | |
Logger rootLogger = Logger.getRootLogger(); | |
rootLogger.addAppender(colorAppender); | |
BasicConfigurator.configure(colorAppender); | |
if (logFile != null) { | |
try { | |
FileAppender fileAppender = new FileAppender(new PatternLayout("%d{yyyy-MM-dd HH:mm:ss} %-5p [%c{1}] %m%n"), logFile.toAbsolutePath().toString()); | |
BasicConfigurator.configure(fileAppender); | |
} catch (IOException e) { | |
logger.error("Failed to create log file", e); | |
System.exit(1); | |
} | |
} | |
if (verbose) { | |
logger.setLevel(Level.DEBUG); | |
logger.info("Verbose output enabled"); | |
logger.info("Installing and configuring the SSM Agent on On-Premise Servers"); | |
logger.info("AWS Region: %s".formatted(awsOptions.region)); | |
logger.debug("Code: %s".formatted(awsOptions.code)); | |
logger.debug("ID: %s".formatted(awsOptions.id)); | |
} | |
checkAndInstallSSMAgent(servers); | |
} | |
/** | |
* Check connectivity to the servers and install the SSM Agent | |
* | |
* @param servers the servers to connect to | |
*/ | |
private void checkAndInstallSSMAgent(String[] servers) { | |
try (ExecutorService executor = Executors.newFixedThreadPool(threads)) { | |
for (String server : servers) { | |
executor.submit(() -> { | |
String s = server.trim(); | |
if (testConnectivity) { | |
if (testConnectivity(s)) { | |
testConnectivityJCraft(s); | |
} | |
return; | |
} | |
Session session = createSession(s); | |
if (session != null) { | |
if (!testConnectivity) { | |
installSSMAgent(session); | |
} | |
session.disconnect(); | |
} else { | |
failedServers.add(s); | |
} | |
}); | |
} | |
} catch (Exception e) { | |
logger.error(FAILED_TO_INSTALL_THE_SSM_AGENT, e); | |
} | |
if (!failedServers.isEmpty()) { | |
logger.error("Failed to process the following servers: %s".formatted(String.join(", ", failedServers))); | |
} | |
} | |
/** | |
* Download the SSM Agent from the AWS S3 bucket | |
* | |
* @param withSession the withSession to the server | |
*/ | |
private void installSSMAgent(Session withSession) { | |
String forDistro = checkLinuxVersion(withSession); | |
String andArch = checkLinuxArch(withSession); | |
String url = ""; | |
if (forDistro == null || andArch == null) { | |
logger.error("Failed to determine the Linux distribution or Architecture: %s".formatted(withSession.getHost())); | |
failedServers.add(withSession.getHost()); | |
Thread.currentThread().interrupt(); | |
return; | |
} | |
url = getDownloadUrl(forDistro, andArch); | |
if (url == null) { | |
logger.error("Failed to determine the Linux distribution or Architecture: %s %s".formatted(forDistro, andArch)); | |
failedServers.add(withSession.getHost()); | |
return; | |
} | |
String fileName = forDistro.contains(UBUNTU) || forDistro.contains(DEBIAN) ? "amazon-ssm-agent.deb" | |
: "amazon-ssm-agent.rpm"; | |
if (verbose) { | |
logger.info("Downloading %s from %s".formatted(fileName, url)); | |
} | |
downloadFile(withSession, url, fileName); | |
install(fileName, withSession, forDistro); | |
} | |
/** | |
* Get the download URL for the SSM Agent | |
* | |
* @param distro the Linux distribution | |
* @param arch the Linux architecture | |
* @return String the download URL | |
*/ | |
private String getDownloadUrl(String distro, String arch) { | |
arch = arch.equals("x86_64") ? "amd64" : "arm64"; | |
String url = "https://s3.amazonaws.com/ec2-downloads-windows/SSMAgent/latest/"; | |
if (distro.contains(UBUNTU) || distro.contains(DEBIAN)) { | |
return url + "debian_" + arch + "/amazon-ssm-agent.deb"; | |
} else if (distro.contains("CentOS") || distro.contains("RHEL") || distro.contains("Fedora") | |
|| distro.contains("SUSE")) { | |
return url + "linux_" + arch + "/amazon-ssm-agent.deb"; | |
} | |
// handle other distributions | |
return null; | |
} | |
/** | |
* Download the SSM Agent from the specified URL using curl | |
* | |
* @param session the session to the server | |
* @param url the URL to download the SSM Agent from | |
* @param fileName the name of the file to download | |
*/ | |
private void downloadFile(Session session, String url, String fileName) { | |
try { | |
// Check if the file already exists | |
String checkFileExistsCommand = String.format("test -f /tmp/%s", fileName); | |
ChannelExec checkFileChannel = (ChannelExec) session.openChannel("exec"); | |
checkFileChannel.setCommand(checkFileExistsCommand); | |
checkFileChannel.connect(); | |
// Wait for the command to finish | |
while (!checkFileChannel.isClosed()) { | |
Thread.sleep(100); | |
} | |
// If the file exists, skip downloading | |
if (checkFileChannel.getExitStatus() == 0) { | |
logger.info(String.format("File %s already exists. Skipping download.", fileName)); | |
checkFileChannel.disconnect(); | |
return; | |
} | |
checkFileChannel.disconnect(); | |
// File does not exist, proceed with download | |
logger.info("Downloading the file: " + fileName); | |
String command = String.format("curl -o /tmp/%s %s", fileName, url); | |
ChannelExec channelExec = (ChannelExec) session.openChannel("exec"); | |
channelExec.setCommand(command); | |
channelExec.connect(); | |
while (!channelExec.isClosed()) { | |
Thread.sleep(100); | |
} | |
if (channelExec.getExitStatus() != 0) { | |
logger.error("Failed to download the file: " + channelExec.getExitStatus()); | |
failedServers.add(session.getHost()); | |
return; | |
} | |
logger.info("File downloaded successfully: " + fileName); | |
channelExec.disconnect(); | |
} catch (JSchException | InterruptedException e) { | |
logger.error("Failed to download the file", e); | |
failedServers.add(session.getHost()); | |
Thread.currentThread().interrupt(); | |
} | |
} | |
/** | |
* Check the Linux architecture of the server | |
* | |
* @param session the session to the server | |
* @return String the Linux architecture | |
*/ | |
private String checkLinuxArch(Session session) { | |
StringBuilder arch = new StringBuilder(); | |
try { | |
ChannelExec channel = (ChannelExec) session.openChannel("exec"); | |
channel.setCommand("uname -m"); | |
// Get the input stream before connecting | |
InputStream in = channel.getInputStream(); | |
channel.connect(); | |
// Read the command output | |
arch.append(new String(in.readAllBytes()).trim()); | |
} catch (JSchException | IOException e) { | |
logger.error("Failed to check the Linux architecture for: %s".formatted(session.getHost())); | |
failedServers.add(session.getHost()); | |
Thread.currentThread().interrupt(); | |
return null; | |
} | |
logger.info("Checking the Linux architecture for %s: %s".formatted(session.getHost(), arch.toString().trim())); | |
return arch.toString().trim(); | |
} | |
/** | |
* Check the Linux distribution of the server | |
* | |
* @param session the session to the server | |
* @return String the Linux distribution | |
*/ | |
private String checkLinuxVersion(Session session) { | |
try { | |
ChannelExec channel = (ChannelExec) session.openChannel("exec"); | |
channel.setCommand("cat /etc/os-release"); | |
// Get the input stream before connecting | |
InputStream in = channel.getInputStream(); | |
channel.connect(); | |
// Use a Scanner to read the command output | |
try (Scanner scanner = new Scanner(in)) { | |
while (scanner.hasNextLine()) { | |
String line = scanner.nextLine(); | |
if (line.startsWith("NAME=")) { | |
String distro = line.split("=")[1].replace("\"", "").trim(); | |
logger.info("Checking the Linux distribution for %s: %s".formatted(session.getHost(), distro)); | |
return distro; | |
} | |
} | |
} | |
failedServers.add(session.getHost()); | |
return null; | |
} catch (JSchException | IOException e) { | |
logger.error("Failed to check the Linux version for: %s".formatted(session.getHost())); | |
failedServers.add(session.getHost()); | |
return null; | |
} | |
} | |
/** | |
* Install the SSM Agent on the server | |
* | |
* @param fileName the name of the file to install | |
* @param session the session to the server | |
* @param distro the Linux distribution | |
*/ | |
private void install(String fileName, Session session, String distro) { | |
// if the password is null use the picocli library to prompt for the password | |
// if (authenticationGroup.password == null) { | |
// CommandLine cmd = new CommandLine(new SsmAgentInstall()); | |
// cmd.setExecutionStrategy(new CommandLine.RunLast()); | |
// cmd.execute("password"); | |
// if (authenticationGroup.password == null) { | |
// logger.error("\u001B[31mFailed to install the SSM Agent: No password provided\u001B[0m"); | |
// return; | |
// } | |
// } | |
String password = new String(authenticationGroup.password); // Assuming authenticationGroup is accessible here | |
try { | |
ChannelExec channel = (ChannelExec) session.openChannel("exec"); | |
if (distro.contains(UBUNTU) || distro.contains(DEBIAN)) { | |
channel.setCommand("echo \"%s\" | sudo -S dpkg -i /tmp/%s".formatted(password, fileName)); | |
} else if (distro.contains("CentOS") || distro.contains("RHEL") || distro.contains("Fedora") | |
|| distro.contains("SUSE")) { | |
channel.setCommand("echo \"%s\" | sudo -S rpm -i /tmp/%s".formatted(password, fileName)); | |
} | |
channel.connect(); | |
// Wait for the command to finish | |
while (!channel.isClosed()) { | |
Thread.sleep(100); | |
} | |
// Check the exit status | |
if (channel.getExitStatus() != 0) { | |
logger.error((FAILED_TO_INSTALL_THE_SSM_AGENT + ": %d").formatted(channel.getExitStatus())); | |
failedServers.add(session.getHost()); | |
throw new JSchException(FAILED_TO_INSTALL_THE_SSM_AGENT); | |
} | |
channel.disconnect(); | |
enableAndRestartSSMAgentService(session); | |
configureSSMAgent(session); | |
} catch (JSchException e) { | |
logger.error(FAILED_TO_INSTALL_THE_SSM_AGENT, e); | |
failedServers.add(session.getHost()); | |
session.disconnect(); | |
} catch (InterruptedException e) { | |
logger.error(FAILED_TO_INSTALL_THE_SSM_AGENT, e); | |
failedServers.add(session.getHost()); | |
session.disconnect(); | |
Thread.currentThread().interrupt(); | |
} finally { | |
session.disconnect(); | |
} | |
} | |
/** | |
* Restart the SSM Agent on the server | |
* | |
* @param session the session to the server | |
*/ | |
private void enableAndRestartSSMAgentService(Session session) { | |
enableSsmAgentService(session); | |
restartSsmAgentService(session); | |
} | |
/** | |
* Enable the SSM Agent on the server | |
* | |
* @param session the session to the server | |
*/ | |
private void enableSsmAgentService(Session session) { | |
String password = new String(authenticationGroup.password); | |
try { | |
ChannelExec channel = (ChannelExec) session.openChannel("exec"); | |
channel.setCommand("echo \"%s\" | sudo -S systemctl enable amazon-ssm-agent.service".formatted(password)); | |
channel.connect(); | |
// Wait for the command to finish | |
while (!channel.isClosed()) { | |
Thread.sleep(100); | |
} | |
if (channel.getExitStatus() != 0) { | |
logger.error((FAILED_TO_ENABLE_THE_SSM_AGENT + ": %d").formatted(channel.getExitStatus())); | |
failedServers.add(session.getHost()); | |
throw new JSchException(FAILED_TO_ENABLE_THE_SSM_AGENT); | |
} | |
channel.disconnect(); | |
logger.info("SSM Agent enabled successfully"); | |
} catch (JSchException | InterruptedException e) { | |
logger.error(FAILED_TO_ENABLE_THE_SSM_AGENT, e); | |
failedServers.add(session.getHost()); | |
return; | |
} | |
} | |
/** | |
* Restart the SSM Agent on the server | |
* | |
* @param session the session to the server | |
*/ | |
private void restartSsmAgentService(Session session) { | |
String password = new String(authenticationGroup.password); | |
try { | |
ChannelExec channel = (ChannelExec) session.openChannel("exec"); | |
channel.setCommand("echo \"%s\" | sudo -S systemctl restart amazon-ssm-agent.service".formatted(password)); | |
channel.connect(); | |
// Wait for the command to finish | |
while (!channel.isClosed()) { | |
Thread.sleep(100); | |
} | |
if (channel.getExitStatus() != 0) { | |
logger.error((FAILED_TO_RESTART_THE_SSM_AGENT + ": %d").formatted(channel.getExitStatus())); | |
failedServers.add(session.getHost()); | |
throw new JSchException(FAILED_TO_RESTART_THE_SSM_AGENT); | |
} | |
channel.disconnect(); | |
logger.info("SSM Agent restarted successfully"); | |
} catch (JSchException | InterruptedException e) { | |
logger.error(FAILED_TO_RESTART_THE_SSM_AGENT, e); | |
failedServers.add(session.getHost()); | |
} | |
} | |
/** | |
* Configure the SSM Agent to communicate with the AWS Systems Manager Service | |
* | |
* @param session the session to the server | |
*/ | |
private void configureSSMAgent(Session session) { | |
String password = new String(authenticationGroup.password); | |
try { | |
ChannelExec channel = (ChannelExec) session.openChannel("exec"); | |
// Construct the base command for SSM Agent registration | |
String baseCommand = String.format("sudo -S amazon-ssm-agent -register -code \"%s\" -id \"%s\" -region \"%s\"", awsOptions.code, awsOptions.id, awsOptions.region); | |
// If force is true, prepend the command to automatically answer "Yes" | |
String agentCommand = force ? "echo \"Yes\" | %s".formatted(baseCommand) : baseCommand; | |
// Include password input for sudo | |
String fullCommand = String.format("echo \"%s\" | %s", password, agentCommand); | |
channel.setCommand(fullCommand); | |
// Get the input and output streams | |
InputStream in = channel.getInputStream(); | |
channel.connect(); | |
// Use a Scanner to read the command output | |
try (Scanner scanner = new Scanner(in)) { | |
while (scanner.hasNextLine()) { | |
String line = scanner.nextLine().trim(); | |
if (verbose) { | |
logger.info("Command output: %s".formatted(line)); | |
} | |
if (logSsmAgentStatus(session, line)) return; | |
if (channel.isClosed()) { | |
if (scanner.hasNextLine()) continue; | |
logger.info("exit-status: " + channel.getExitStatus()); | |
break; | |
} | |
} | |
} | |
channel.disconnect(); | |
verifyDiagnosticsFromSSMAgent(session); | |
} catch (JSchException | IOException e) { | |
logger.error("Failed to configure the SSM Agent on %s".formatted(session.getHost()), e); | |
failedServers.add(session.getHost()); | |
session.disconnect(); | |
Thread.currentThread().interrupt(); | |
} | |
} | |
/** | |
* Log the SSM Agent status | |
* | |
* @param session the session to the server | |
* @param line the command output line | |
* @return boolean true if the command output contains an error message | |
*/ | |
private boolean logSsmAgentStatus(Session session, String line) { | |
// Check if the command line contains the error message | |
if (line.contains("ActivationExpired:")) { | |
// Log an error message and add the server to the failedServers list | |
logger.error("Registration failed due to expired activation code for server: %s".formatted(session.getHost())); | |
failedServers.add(session.getHost()); | |
return true; | |
} | |
if (line.contains("ValidationException:")) { | |
// Log an error message and add the server to the failedServers list | |
logger.error("Registration failed due to invalid activation code for server: %s".formatted(session.getHost())); | |
failedServers.add(session.getHost()); | |
return true; | |
} | |
if (line.contains("EOF")) { | |
// Log an error message and add the server to the failedServers list | |
logger.error("Registration failed due to EOF: %s".formatted(session.getHost())); | |
failedServers.add(session.getHost()); | |
return true; | |
} | |
if (line.contains("Successfully")) { | |
logger.info("Registration successful for server: %s".formatted(session.getHost())); | |
logger.info(line); | |
} | |
return false; | |
} | |
/** | |
* Verify the diagnostics from the SSM Agent | |
* | |
* @param session the session to the server | |
*/ | |
private void verifyDiagnosticsFromSSMAgent(Session session) { | |
String password = new String(authenticationGroup.password); | |
try { | |
ChannelExec channel = (ChannelExec) session.openChannel("exec"); | |
String diagnosticsCommand = "echo \"%s\" | sudo -S ssm-cli get-diagnostics --output json".formatted(password); | |
channel.setCommand(diagnosticsCommand); | |
// Get the input stream | |
try (InputStream in = channel.getInputStream()) { | |
channel.connect(); | |
// Read the command output | |
String output = new String(in.readAllBytes()).trim(); | |
handleDiagnosticsOutput(output); | |
} | |
channel.disconnect(); | |
} catch (JSchException | IOException e) { | |
logger.error("Failed to run diagnostics command on %s: %s".formatted(session.getHost(), e.getMessage()), e); | |
failedServers.add(session.getHost()); | |
session.disconnect(); | |
Thread.currentThread().interrupt(); | |
} | |
} | |
/** | |
* Handle the diagnostics output from the SSM Agent | |
* | |
* @param output the diagnostics output | |
*/ | |
private void handleDiagnosticsOutput(String output) { | |
// Check if the output is not empty and is valid JSON | |
if (output != null && !output.isEmpty()) { | |
// Parse the JSON output | |
JsonObject jsonOutput = JsonParser.parseString(output).getAsJsonObject(); | |
// Check if the JSON output contains the 'DiagnosticsOutput' field | |
if (jsonOutput.has("DiagnosticsOutput")) { | |
// Get the 'DiagnosticsOutput' array | |
JsonArray diagnosticsOutput = jsonOutput.getAsJsonArray("DiagnosticsOutput"); | |
// Iterate over each object in the array | |
diagnosticsOutput.forEach(this::handleDiagnosticElement); | |
} else { | |
logger.error("Diagnostics output does not contain 'DiagnosticsOutput' field"); | |
} | |
} else { | |
logger.error("Diagnostics command did not return any output"); | |
} | |
} | |
/** | |
* Handle a diagnostic element | |
* | |
* @param element the diagnostic element | |
*/ | |
private void handleDiagnosticElement(JsonElement element) { | |
JsonObject diagnostic = element.getAsJsonObject(); | |
// Check if the object contains the 'Check' and 'Status' fields | |
if (diagnostic.has("Check") && diagnostic.has("Status") && verbose) { | |
String check = diagnostic.get("Check").getAsString(); | |
String status = diagnostic.get("Status").getAsString(); | |
if (status.equals("Failed")) { | |
logger.error("Check: %s, Status: %s".formatted(check, status)); | |
} else { | |
logger.info("Check: %s, Status: %s".formatted(check, status)); | |
} | |
} else { | |
String check = diagnostic.get("Check").getAsString(); | |
String status = diagnostic.get("Status").getAsString(); | |
if (!status.equals("Failed") && check.contains("Hybrid")) { | |
logger.info("Check: %s, Status: %s".formatted(check, status)); | |
} else if (status.equals("Failed") && check.contains("Hybrid")) { | |
logger.error("Check: %s, Status: %s".formatted(check, status)); | |
} | |
} | |
} | |
/** | |
* Create a session to the server | |
* | |
* @param server the server to connect to | |
* @return Session the session to the server | |
*/ | |
private Session createSession(String server) { | |
Session session = createJSchSession(server); | |
if (session != null && session.isConnected()) { | |
logger.info("Connected to server: %s".formatted(server)); | |
} else { | |
logger.error(FAILED_TO_CONNECT_TO_SERVER.formatted(server)); | |
Thread.currentThread().interrupt(); | |
return null; | |
} | |
return session; | |
} | |
/** | |
* Create a JSch session to the server | |
* | |
* @param server the server to connect to | |
* @return Session the JSch session | |
*/ | |
private Session createJSchSession(String server) { | |
JSch jsch = new JSch(); | |
Optional<Session> session = Optional.empty(); | |
int maxRetries = 3; // Maximum number of retries | |
int retries = 0; // Current retry count | |
int delay = 3000; // Delay in milliseconds | |
while (retries < maxRetries) { | |
try { | |
if (verbose) { | |
configureJschLogging(jsch); | |
} | |
session = Optional.ofNullable(configureSession(jsch, server)); | |
session.get().connect(30000); | |
return session.orElse(null); | |
} catch (JSchException e) { | |
retries++; | |
session = handleJSchException(e, server, session.orElse(null), retries, maxRetries, delay); | |
if (session.isEmpty() && retries < maxRetries) { | |
thread(delay); | |
} | |
} | |
} | |
return session.orElse(null); | |
} | |
private Session configureSession(JSch jsch, String server) throws JSchException { | |
Session session; | |
if (!strict) { | |
JSch.setConfig("StrictHostKeyChecking", "no"); | |
} | |
session = jsch.getSession(username, server, port); | |
if (authenticationGroup.password != null) { | |
session.setConfig("PreferredAuthentications", "password,keyboard-interactive"); | |
MyUserInfo userInfo = new MyUserInfo(authenticationGroup.password); | |
session.setUserInfo(userInfo); | |
} else if (authenticationGroup.privateKey != null) { | |
session.setConfig("PreferredAuthentications", "publickey"); | |
jsch.addIdentity(authenticationGroup.privateKey.toAbsolutePath().toString()); | |
} else { | |
logger.error("No authentication method provided"); | |
} | |
return session; | |
} | |
/** | |
* Handle a JSch exception | |
* | |
* @param e the JSch exception | |
* @param server the server | |
* @param session the session | |
* @param retries the current retry count | |
* @param maxRetries the maximum retry count | |
* @param delay the delay in milliseconds | |
* @return Session the session | |
*/ | |
private Optional<Session> handleJSchException(JSchException e, String server, Session session, int retries, int maxRetries, int delay) { | |
String message = e.getMessage(); | |
if (message.contains("Auth fail")) { | |
logger.error("Authentication failed for server: %s%n ".formatted(server)); | |
} else if (message.contains("timeout")) { | |
logger.error("Connection timed out for server: %s%n ".formatted(server)); | |
} else if (message.contains("Too many authentication failures")) { | |
logger.error("Too many authentication failures for user: %s on server: %s".formatted(username, server)); | |
logger.info("Retrying connection to server: %s, attempt number: %d".formatted(server, retries)); | |
retries++; | |
delay *= 2; | |
if (logFailures(server, retries, maxRetries, delay)) return Optional.of(session); | |
} else if (message.contains("Algorithm negotiation fail")) { | |
logger.error("Algorithm negotiation failed for server: %s".formatted(server)); | |
logger.info("Retrying connection to server: %s, attempt number: %d".formatted(server, retries)); | |
// Extract the server's proposed algorithm from the error message | |
Optional<String> proposedAlgorithm = extractProposedAlgorithm(message); | |
// Add the proposed algorithm to the preferred list | |
session.setConfig("PreferredAuthentications", proposedAlgorithm.orElse("")); | |
// Retry the connection | |
retries = 2; | |
if (logFailures(server, retries, maxRetries, delay)) return Optional.of(session); | |
return Optional.of(session); | |
} else { | |
retries = 3; | |
if (retries == maxRetries && session == null) { | |
logger.error(FAILED_TO_CONNECT_TO_SERVER.formatted(server)); | |
failedServers.add(server); | |
return Optional.ofNullable(session); | |
} | |
} | |
return Optional.ofNullable(session); | |
} | |
/** | |
* Log the failures and add the server to the failed servers list | |
* | |
* @param server the server | |
* @param retries the current retry count | |
* @param maxRetries the maximum retry count | |
* @param delay the delay in milliseconds | |
* @return boolean true if the connection failed | |
*/ | |
private boolean logFailures(String server, int retries, int maxRetries, int delay) { | |
if (retries == maxRetries) { | |
logger.error(FAILED_TO_CONNECT_TO_SERVER.formatted(server)); | |
logger.info("Adding server to the failed servers list: %s".formatted(server)); | |
failedServers.add(server); | |
Thread.currentThread().interrupt(); | |
return true; | |
} | |
thread(delay); | |
return false; | |
} | |
/** | |
* Extract the server's proposed algorithm from the error message | |
* | |
* @param message the error message | |
* @return the server's proposed algorithm | |
*/ | |
private Optional<String> extractProposedAlgorithm(String message) { | |
String[] parts = message.split("="); | |
if (parts.length > 1) { | |
return Optional.of(parts[1]); | |
} | |
return Optional.empty(); | |
} | |
/** | |
* Set the logging level for JCraft | |
* | |
* @param jsch the JSch instance | |
*/ | |
private void configureJschLogging(JSch jsch) { | |
jsch.getInstanceLogger().isEnabled(Logger.getRootLogger().getLevel().toInt()); | |
JSch.setLogger(new com.jcraft.jsch.Logger() { | |
public boolean isEnabled(int level) { | |
return verbose; | |
} | |
public void log(int level, String message) { | |
if (level == com.jcraft.jsch.Logger.DEBUG) { | |
logger.debug("%s".formatted(message)); | |
} else { | |
logger.info("%s".formatted(message)); | |
} | |
} | |
}); | |
} | |
/** | |
* Test connectivity to the server on the specified port | |
* | |
* @param server the server to connect to | |
*/ | |
private boolean testConnectivity(String server) { | |
try (Socket socket = new Socket()) { | |
socket.connect(new InetSocketAddress(server, port), TIMEOUT); | |
logger.info("Connectivity check successful for %s on port %d".formatted(server, port)); | |
return true; | |
} catch (IOException e) { | |
logger.error("%s%s%s%d ".formatted(CONNECTIVITY_CHECK_FAILED_FOR, server, ON_PORT, port)); | |
failedServers.add(server); | |
} | |
return false; | |
} | |
/** | |
* Test connectivity to the server using JCraft | |
* | |
* @param server the server to connect to | |
*/ | |
private void testConnectivityJCraft(String server) { | |
logger.info("Attempting to log into server: %s".formatted(server)); | |
Session session = createJSchSession(server); | |
if (session != null) { | |
logger.info("Successfully logged into server: %s".formatted(server)); | |
session.disconnect(); | |
failedServers.remove(server); | |
} else { | |
logger.error("Failed to log into server: %s".formatted(server)); | |
} | |
} | |
/** | |
* A custom implementation of the UserInfo interface | |
*/ | |
static class MyUserInfo implements com.jcraft.jsch.UserInfo, com.jcraft.jsch.UIKeyboardInteractive { | |
private final char[] password; | |
private final Logger logger = Logger.getLogger(MyUserInfo.class); | |
public MyUserInfo(char[] password) { | |
this.password = password; | |
// Remove the console appender so it duplicates the logging output | |
Logger.getRootLogger().removeAppender("console"); | |
} | |
@Override | |
public String getPassword() { | |
return new String(password); | |
} | |
@Override | |
public boolean promptYesNo(String str) { | |
return true; | |
} | |
@Override | |
public String getPassphrase() { | |
return null; | |
} | |
@Override | |
public boolean promptPassphrase(String message) { | |
return true; | |
} | |
@Override | |
public boolean promptPassword(String message) { | |
return true; | |
} | |
@Override | |
public void showMessage(String message) { | |
logger.info("Message: %s".formatted(message)); | |
} | |
@Override | |
public String[] promptKeyboardInteractive(String destination, String name, String instruction, String[] prompt, boolean[] echo) { | |
String[] response = new String[prompt.length]; | |
for (int i = 0; i < prompt.length; i++) { | |
logger.info("Prompt: %s".formatted(prompt[i])); | |
if (prompt[i].toLowerCase().contains("password")) { | |
response[i] = new String(password); | |
} | |
} | |
return response; | |
} | |
} | |
public static class AuthenticationGroup { | |
@Option(names = {"-p", "--password"}, interactive = true, description = "The password to use to connect to the server", arity = "0..1") | |
char[] password; | |
@Option(names = {"-k", "--private-key"}, description = "The private key to use to connect to the server") | |
Path privateKey; | |
} | |
public static class AwsOptions { | |
@Option(names = {"-r", "--region"}, description = "The AWS Region where the server is located") | |
private String region; | |
@Option(names = {"-c", "--code"}, description = "The IAM Role ARN that the SSM Agent will assume to communicate with the AWS Systems Manager Service") | |
private String code; | |
@Option(names = {"-i", "--id"}, description = "The IAM Role ARN that the SSM Agent will assume to communicate with the AWS Systems Manager Service") | |
private String id; | |
} | |
/** | |
* A custom appender to colorize the log output | |
*/ | |
public class ColorConsoleAppender extends ConsoleAppender { | |
private static final String ESC_START = "\u001B["; | |
private static final String ESC_END = "\u001B[0m"; | |
private static final String RED = "31m"; | |
private static final String GREEN = "32m"; | |
private static final String BLUE = "34m"; | |
@Override | |
protected void subAppend(LoggingEvent event) { | |
switch (event.getLevel().toInt()) { | |
case Level.ERROR_INT: | |
this.qw.write(ESC_START + RED); | |
break; | |
case Priority.INFO_INT: | |
this.qw.write(ESC_START + GREEN); | |
break; | |
case Level.DEBUG_INT: | |
this.qw.write(ESC_START + BLUE); | |
break; | |
default: | |
break; | |
} | |
super.subAppend(event); | |
this.qw.write(ESC_END); | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment