Skip to content

Instantly share code, notes, and snippets.

@fintanmm
Last active July 10, 2024 10:57
Show Gist options
  • Save fintanmm/025b8546637590e7bedf9ed1d0a48701 to your computer and use it in GitHub Desktop.
Save fintanmm/025b8546637590e7bedf9ed1d0a48701 to your computer and use it in GitHub Desktop.
Installs and configures AWS SSM Agent for On-Premise Linux Servers using SSH and JBang
///usr/bin/env jbang "$0" "$@" ; exit $?
//DEPS ch.qos.reload4j:reload4j:1.2.25
//DEPS info.picocli:picocli:4.7.6
//DEPS info.picocli:picocli-codegen:4.7.6
//DEPS com.github.mwiede:jsch:0.2.18
//DEPS com.google.code.gson:gson:2.8.9
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParseException;
import com.google.gson.JsonParser;
import com.jcraft.jsch.ChannelExec;
import com.jcraft.jsch.JSch;
import com.jcraft.jsch.JSchException;
import com.jcraft.jsch.Session;
import org.apache.log4j.*;
import org.apache.log4j.spi.LoggingEvent;
import picocli.CommandLine;
import picocli.CommandLine.ArgGroup;
import picocli.CommandLine.Command;
import picocli.CommandLine.Option;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Scanner;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@Command(name = "SsmAgentInstall", mixinStandardHelpOptions = true, version = "SsmAgentInstall 0.5", description = "Installs and configures AWS SSM Agent for On-Premise Linux Servers using SSH.")
public class SsmAgentInstall implements Runnable {
public static final String ON_PORT = " on port ";
public static final int TIMEOUT = 30000;
public static final String UBUNTU = "Ubuntu";
public static final String DEBIAN = "Debian";
public static final String CONNECTIVITY_CHECK_FAILED_FOR = "Connectivity check failed for ";
public static final String FAILED_TO_INSTALL_THE_SSM_AGENT = "Failed to install the SSM Agent";
private static final String FAILED_TO_CONNECT_TO_SERVER = "Failed to connect to server: %s";
private static final String FAILED_TO_ENABLE_THE_SSM_AGENT = "Failed to enable the SSM Agent";
private static final String FAILED_TO_RESTART_THE_SSM_AGENT = "Failed to restart the SSM Agent";
@Option(names = {"-l", "--log-file"}, description = "The log file to write the output to")
private static Path logFile;
private final Set<String> failedServers = new HashSet<>();
private final Logger logger = Logger.getLogger(SsmAgentInstall.class);
@ArgGroup(exclusive = false, multiplicity = "1")
AuthenticationGroup authenticationGroup;
@ArgGroup(exclusive = false, multiplicity = "0..1")
AwsOptions awsOptions;
@Option(names = {"-s", "--servers"}, description = "A list seperated by commas of On-Premise Servers that the SSM Agent will be installed and configured on", split = ",", required = true)
private String[] servers;
@Option(names = {"-P", "--port"}, description = "The port to use to connect to the server", defaultValue = "22")
private int port;
@Option(names = {"-v", "--verbose"}, description = "Be verbose.", defaultValue = "false")
private boolean verbose;
@Option(names = {"-t", "--test-connectivity"}, description = "Test connectivity to each server but do not install", defaultValue = "false")
private boolean testConnectivity;
@Option(names = {"-u", "--username"}, description = "The username to use to connect to the server")
private String username;
@Option(names = {"-S", "--strict"}, description = "Don't use strict host key checking", defaultValue = "true")
private boolean strict;
@Option(names = {"-D", "--sudo"}, description = "Use sudo to install the SSM Agent", defaultValue = "false")
private boolean sudo;
@Option(names = {"-T", "--threads"}, description = "The number of threads to use to install the SSM Agent, defaults to 10", defaultValue = "10")
private int threads;
@Option(names = {"-f", "--force"}, description = "Force the re-registration of a server", defaultValue = "false")
private boolean force;
public static void main(String... args) {
int exitCode = new CommandLine(new SsmAgentInstall()).execute(args);
System.exit(exitCode);
}
/**
* Sleep the current thread for the specified delay
*
* @param delay the delay in milliseconds
*/
private static void thread(int delay) {
try {
Thread.sleep(delay);
} catch (InterruptedException ie) {
Thread.currentThread().interrupt();
}
}
@Override
public void run() {
// Create a new color console appender
ColorConsoleAppender colorAppender = new ColorConsoleAppender();
colorAppender.setLayout(new PatternLayout("%d{yyyy-MM-dd HH:mm:ss} %-5p [%c{1}] %m%n"));
colorAppender.activateOptions();
// Get the root logger and add your appender to it
Logger rootLogger = Logger.getRootLogger();
rootLogger.addAppender(colorAppender);
BasicConfigurator.configure(colorAppender);
if (logFile != null) {
try {
FileAppender fileAppender = new FileAppender(new PatternLayout("%d{yyyy-MM-dd HH:mm:ss} %-5p [%c{1}] %m%n"), logFile.toAbsolutePath().toString());
BasicConfigurator.configure(fileAppender);
} catch (IOException e) {
logger.error("Failed to create log file", e);
System.exit(1);
}
}
if (verbose) {
logger.setLevel(Level.DEBUG);
logger.info("Verbose output enabled");
logger.info("Installing and configuring the SSM Agent on On-Premise Servers");
logger.info("AWS Region: %s".formatted(awsOptions.region));
logger.debug("Code: %s".formatted(awsOptions.code));
logger.debug("ID: %s".formatted(awsOptions.id));
}
checkAndInstallSSMAgent(servers);
}
/**
* Check connectivity to the servers and install the SSM Agent
*
* @param servers the servers to connect to
*/
private void checkAndInstallSSMAgent(String[] servers) {
try (ExecutorService executor = Executors.newFixedThreadPool(threads)) {
for (String server : servers) {
executor.submit(() -> {
String s = server.trim();
if (testConnectivity) {
if (testConnectivity(s)) {
testConnectivityJCraft(s);
}
return;
}
Session session = createSession(s);
if (session != null) {
if (!testConnectivity) {
installSSMAgent(session);
}
session.disconnect();
} else {
failedServers.add(s);
}
});
}
} catch (Exception e) {
logger.error(FAILED_TO_INSTALL_THE_SSM_AGENT, e);
}
if (!failedServers.isEmpty()) {
logger.error("Failed to process the following servers: %s".formatted(String.join(", ", failedServers)));
}
}
/**
* Download the SSM Agent from the AWS S3 bucket
*
* @param withSession the withSession to the server
*/
private void installSSMAgent(Session withSession) {
String forDistro = checkLinuxVersion(withSession);
String andArch = checkLinuxArch(withSession);
String url = "";
if (forDistro == null || andArch == null) {
logger.error("Failed to determine the Linux distribution or Architecture: %s".formatted(withSession.getHost()));
failedServers.add(withSession.getHost());
Thread.currentThread().interrupt();
return;
}
url = getDownloadUrl(forDistro, andArch);
if (url == null) {
logger.error("Failed to determine the Linux distribution or Architecture: %s %s".formatted(forDistro, andArch));
failedServers.add(withSession.getHost());
return;
}
String fileName = forDistro.contains(UBUNTU) || forDistro.contains(DEBIAN) ? "amazon-ssm-agent.deb"
: "amazon-ssm-agent.rpm";
if (verbose) {
logger.info("Downloading %s from %s".formatted(fileName, url));
}
downloadFile(withSession, url, fileName);
install(fileName, withSession, forDistro);
}
/**
* Get the download URL for the SSM Agent
*
* @param distro the Linux distribution
* @param arch the Linux architecture
* @return String the download URL
*/
private String getDownloadUrl(String distro, String arch) {
arch = arch.equals("x86_64") ? "amd64" : "arm64";
String url = "https://s3.amazonaws.com/ec2-downloads-windows/SSMAgent/latest/";
if (distro.contains(UBUNTU) || distro.contains(DEBIAN)) {
return url + "debian_" + arch + "/amazon-ssm-agent.deb";
} else if (distro.contains("CentOS") || distro.contains("RHEL") || distro.contains("Fedora")
|| distro.contains("SUSE")) {
return url + "linux_" + arch + "/amazon-ssm-agent.deb";
}
// handle other distributions
return null;
}
/**
* Download the SSM Agent from the specified URL using curl
*
* @param session the session to the server
* @param url the URL to download the SSM Agent from
* @param fileName the name of the file to download
*/
private void downloadFile(Session session, String url, String fileName) {
try {
// Check if the file already exists
String checkFileExistsCommand = String.format("test -f /tmp/%s", fileName);
ChannelExec checkFileChannel = (ChannelExec) session.openChannel("exec");
checkFileChannel.setCommand(checkFileExistsCommand);
checkFileChannel.connect();
// Wait for the command to finish
while (!checkFileChannel.isClosed()) {
Thread.sleep(100);
}
// If the file exists, skip downloading
if (checkFileChannel.getExitStatus() == 0) {
logger.info(String.format("File %s already exists. Skipping download.", fileName));
checkFileChannel.disconnect();
return;
}
checkFileChannel.disconnect();
// File does not exist, proceed with download
logger.info("Downloading the file: " + fileName);
String command = String.format("curl -o /tmp/%s %s", fileName, url);
ChannelExec channelExec = (ChannelExec) session.openChannel("exec");
channelExec.setCommand(command);
channelExec.connect();
while (!channelExec.isClosed()) {
Thread.sleep(100);
}
if (channelExec.getExitStatus() != 0) {
logger.error("Failed to download the file: " + channelExec.getExitStatus());
failedServers.add(session.getHost());
return;
}
logger.info("File downloaded successfully: " + fileName);
channelExec.disconnect();
} catch (JSchException | InterruptedException e) {
logger.error("Failed to download the file", e);
failedServers.add(session.getHost());
Thread.currentThread().interrupt();
}
}
/**
* Check the Linux architecture of the server
*
* @param session the session to the server
* @return String the Linux architecture
*/
private String checkLinuxArch(Session session) {
StringBuilder arch = new StringBuilder();
try {
ChannelExec channel = (ChannelExec) session.openChannel("exec");
channel.setCommand("uname -m");
// Get the input stream before connecting
InputStream in = channel.getInputStream();
channel.connect();
// Read the command output
arch.append(new String(in.readAllBytes()).trim());
} catch (JSchException | IOException e) {
logger.error("Failed to check the Linux architecture for: %s".formatted(session.getHost()));
failedServers.add(session.getHost());
Thread.currentThread().interrupt();
return null;
}
logger.info("Checking the Linux architecture for %s: %s".formatted(session.getHost(), arch.toString().trim()));
return arch.toString().trim();
}
/**
* Check the Linux distribution of the server
*
* @param session the session to the server
* @return String the Linux distribution
*/
private String checkLinuxVersion(Session session) {
try {
ChannelExec channel = (ChannelExec) session.openChannel("exec");
channel.setCommand("cat /etc/os-release");
// Get the input stream before connecting
InputStream in = channel.getInputStream();
channel.connect();
// Use a Scanner to read the command output
try (Scanner scanner = new Scanner(in)) {
while (scanner.hasNextLine()) {
String line = scanner.nextLine();
if (line.startsWith("NAME=")) {
String distro = line.split("=")[1].replace("\"", "").trim();
logger.info("Checking the Linux distribution for %s: %s".formatted(session.getHost(), distro));
return distro;
}
}
}
failedServers.add(session.getHost());
return null;
} catch (JSchException | IOException e) {
logger.error("Failed to check the Linux version for: %s".formatted(session.getHost()));
failedServers.add(session.getHost());
return null;
}
}
/**
* Install the SSM Agent on the server
*
* @param fileName the name of the file to install
* @param session the session to the server
* @param distro the Linux distribution
*/
private void install(String fileName, Session session, String distro) {
// if the password is null use the picocli library to prompt for the password
// if (authenticationGroup.password == null) {
// CommandLine cmd = new CommandLine(new SsmAgentInstall());
// cmd.setExecutionStrategy(new CommandLine.RunLast());
// cmd.execute("password");
// if (authenticationGroup.password == null) {
// logger.error("\u001B[31mFailed to install the SSM Agent: No password provided\u001B[0m");
// return;
// }
// }
String password = new String(authenticationGroup.password); // Assuming authenticationGroup is accessible here
try {
ChannelExec channel = (ChannelExec) session.openChannel("exec");
if (distro.contains(UBUNTU) || distro.contains(DEBIAN)) {
channel.setCommand("echo \"%s\" | sudo -S dpkg -i /tmp/%s".formatted(password, fileName));
} else if (distro.contains("CentOS") || distro.contains("RHEL") || distro.contains("Fedora")
|| distro.contains("SUSE")) {
channel.setCommand("echo \"%s\" | sudo -S rpm -i /tmp/%s".formatted(password, fileName));
}
channel.connect();
// Wait for the command to finish
while (!channel.isClosed()) {
Thread.sleep(100);
}
// Check the exit status
if (channel.getExitStatus() != 0) {
logger.error((FAILED_TO_INSTALL_THE_SSM_AGENT + ": %d").formatted(channel.getExitStatus()));
failedServers.add(session.getHost());
throw new JSchException(FAILED_TO_INSTALL_THE_SSM_AGENT);
}
channel.disconnect();
enableAndRestartSSMAgentService(session);
configureSSMAgent(session);
} catch (JSchException e) {
logger.error(FAILED_TO_INSTALL_THE_SSM_AGENT, e);
failedServers.add(session.getHost());
session.disconnect();
} catch (InterruptedException e) {
logger.error(FAILED_TO_INSTALL_THE_SSM_AGENT, e);
failedServers.add(session.getHost());
session.disconnect();
Thread.currentThread().interrupt();
} finally {
session.disconnect();
}
}
/**
* Restart the SSM Agent on the server
*
* @param session the session to the server
*/
private void enableAndRestartSSMAgentService(Session session) {
enableSsmAgentService(session);
restartSsmAgentService(session);
}
/**
* Enable the SSM Agent on the server
*
* @param session the session to the server
*/
private void enableSsmAgentService(Session session) {
String password = new String(authenticationGroup.password);
try {
ChannelExec channel = (ChannelExec) session.openChannel("exec");
channel.setCommand("echo \"%s\" | sudo -S systemctl enable amazon-ssm-agent.service".formatted(password));
channel.connect();
// Wait for the command to finish
while (!channel.isClosed()) {
Thread.sleep(100);
}
if (channel.getExitStatus() != 0) {
logger.error((FAILED_TO_ENABLE_THE_SSM_AGENT + ": %d").formatted(channel.getExitStatus()));
failedServers.add(session.getHost());
throw new JSchException(FAILED_TO_ENABLE_THE_SSM_AGENT);
}
channel.disconnect();
logger.info("SSM Agent enabled successfully");
} catch (JSchException | InterruptedException e) {
logger.error(FAILED_TO_ENABLE_THE_SSM_AGENT, e);
failedServers.add(session.getHost());
return;
}
}
/**
* Restart the SSM Agent on the server
*
* @param session the session to the server
*/
private void restartSsmAgentService(Session session) {
String password = new String(authenticationGroup.password);
try {
ChannelExec channel = (ChannelExec) session.openChannel("exec");
channel.setCommand("echo \"%s\" | sudo -S systemctl restart amazon-ssm-agent.service".formatted(password));
channel.connect();
// Wait for the command to finish
while (!channel.isClosed()) {
Thread.sleep(100);
}
if (channel.getExitStatus() != 0) {
logger.error((FAILED_TO_RESTART_THE_SSM_AGENT + ": %d").formatted(channel.getExitStatus()));
failedServers.add(session.getHost());
throw new JSchException(FAILED_TO_RESTART_THE_SSM_AGENT);
}
channel.disconnect();
logger.info("SSM Agent restarted successfully");
} catch (JSchException | InterruptedException e) {
logger.error(FAILED_TO_RESTART_THE_SSM_AGENT, e);
failedServers.add(session.getHost());
}
}
/**
* Configure the SSM Agent to communicate with the AWS Systems Manager Service
*
* @param session the session to the server
*/
private void configureSSMAgent(Session session) {
String password = new String(authenticationGroup.password);
try {
ChannelExec channel = (ChannelExec) session.openChannel("exec");
// Construct the base command for SSM Agent registration
String baseCommand = String.format("sudo -S amazon-ssm-agent -register -code \"%s\" -id \"%s\" -region \"%s\"", awsOptions.code, awsOptions.id, awsOptions.region);
// If force is true, prepend the command to automatically answer "Yes"
String agentCommand = force ? "echo \"Yes\" | %s".formatted(baseCommand) : baseCommand;
// Include password input for sudo
String fullCommand = String.format("echo \"%s\" | %s", password, agentCommand);
channel.setCommand(fullCommand);
// Get the input and output streams
InputStream in = channel.getInputStream();
channel.connect();
// Use a Scanner to read the command output
try (Scanner scanner = new Scanner(in)) {
while (scanner.hasNextLine()) {
String line = scanner.nextLine().trim();
if (verbose) {
logger.info("Command output: %s".formatted(line));
}
if (logSsmAgentStatus(session, line)) return;
if (channel.isClosed()) {
if (scanner.hasNextLine()) continue;
logger.info("exit-status: " + channel.getExitStatus());
break;
}
}
}
channel.disconnect();
verifyDiagnosticsFromSSMAgent(session);
} catch (JSchException | IOException e) {
logger.error("Failed to configure the SSM Agent on %s".formatted(session.getHost()), e);
failedServers.add(session.getHost());
session.disconnect();
Thread.currentThread().interrupt();
}
}
/**
* Log the SSM Agent status
*
* @param session the session to the server
* @param line the command output line
* @return boolean true if the command output contains an error message
*/
private boolean logSsmAgentStatus(Session session, String line) {
// Check if the command line contains the error message
if (line.contains("ActivationExpired:")) {
// Log an error message and add the server to the failedServers list
logger.error("Registration failed due to expired activation code for server: %s".formatted(session.getHost()));
failedServers.add(session.getHost());
return true;
}
if (line.contains("ValidationException:")) {
// Log an error message and add the server to the failedServers list
logger.error("Registration failed due to invalid activation code for server: %s".formatted(session.getHost()));
failedServers.add(session.getHost());
return true;
}
if (line.contains("EOF")) {
// Log an error message and add the server to the failedServers list
logger.error("Registration failed due to EOF: %s".formatted(session.getHost()));
failedServers.add(session.getHost());
return true;
}
if (line.contains("Successfully")) {
logger.info("Registration successful for server: %s".formatted(session.getHost()));
logger.info(line);
}
return false;
}
/**
* Verify the diagnostics from the SSM Agent
*
* @param session the session to the server
*/
private void verifyDiagnosticsFromSSMAgent(Session session) {
String password = new String(authenticationGroup.password);
try {
ChannelExec channel = (ChannelExec) session.openChannel("exec");
String diagnosticsCommand = "echo \"%s\" | sudo -S ssm-cli get-diagnostics --output json".formatted(password);
channel.setCommand(diagnosticsCommand);
// Get the input stream
try (InputStream in = channel.getInputStream()) {
channel.connect();
// Read the command output
String output = new String(in.readAllBytes()).trim();
handleDiagnosticsOutput(output);
}
channel.disconnect();
} catch (JSchException | IOException e) {
logger.error("Failed to run diagnostics command on %s: %s".formatted(session.getHost(), e.getMessage()), e);
failedServers.add(session.getHost());
session.disconnect();
Thread.currentThread().interrupt();
}
}
/**
* Handle the diagnostics output from the SSM Agent
*
* @param output the diagnostics output
*/
private void handleDiagnosticsOutput(String output) {
// Check if the output is not empty and is valid JSON
if (output != null && !output.isEmpty()) {
// Parse the JSON output
JsonObject jsonOutput = JsonParser.parseString(output).getAsJsonObject();
// Check if the JSON output contains the 'DiagnosticsOutput' field
if (jsonOutput.has("DiagnosticsOutput")) {
// Get the 'DiagnosticsOutput' array
JsonArray diagnosticsOutput = jsonOutput.getAsJsonArray("DiagnosticsOutput");
// Iterate over each object in the array
diagnosticsOutput.forEach(this::handleDiagnosticElement);
} else {
logger.error("Diagnostics output does not contain 'DiagnosticsOutput' field");
}
} else {
logger.error("Diagnostics command did not return any output");
}
}
/**
* Handle a diagnostic element
*
* @param element the diagnostic element
*/
private void handleDiagnosticElement(JsonElement element) {
JsonObject diagnostic = element.getAsJsonObject();
// Check if the object contains the 'Check' and 'Status' fields
if (diagnostic.has("Check") && diagnostic.has("Status") && verbose) {
String check = diagnostic.get("Check").getAsString();
String status = diagnostic.get("Status").getAsString();
if (status.equals("Failed")) {
logger.error("Check: %s, Status: %s".formatted(check, status));
} else {
logger.info("Check: %s, Status: %s".formatted(check, status));
}
} else {
String check = diagnostic.get("Check").getAsString();
String status = diagnostic.get("Status").getAsString();
if (!status.equals("Failed") && check.contains("Hybrid")) {
logger.info("Check: %s, Status: %s".formatted(check, status));
} else if (status.equals("Failed") && check.contains("Hybrid")) {
logger.error("Check: %s, Status: %s".formatted(check, status));
}
}
}
/**
* Create a session to the server
*
* @param server the server to connect to
* @return Session the session to the server
*/
private Session createSession(String server) {
Session session = createJSchSession(server);
if (session != null && session.isConnected()) {
logger.info("Connected to server: %s".formatted(server));
} else {
logger.error(FAILED_TO_CONNECT_TO_SERVER.formatted(server));
Thread.currentThread().interrupt();
return null;
}
return session;
}
/**
* Create a JSch session to the server
*
* @param server the server to connect to
* @return Session the JSch session
*/
private Session createJSchSession(String server) {
JSch jsch = new JSch();
Optional<Session> session = Optional.empty();
int maxRetries = 3; // Maximum number of retries
int retries = 0; // Current retry count
int delay = 3000; // Delay in milliseconds
while (retries < maxRetries) {
try {
if (verbose) {
configureJschLogging(jsch);
}
session = Optional.ofNullable(configureSession(jsch, server));
session.get().connect(30000);
return session.orElse(null);
} catch (JSchException e) {
retries++;
session = handleJSchException(e, server, session.orElse(null), retries, maxRetries, delay);
if (session.isEmpty() && retries < maxRetries) {
thread(delay);
}
}
}
return session.orElse(null);
}
private Session configureSession(JSch jsch, String server) throws JSchException {
Session session;
if (!strict) {
JSch.setConfig("StrictHostKeyChecking", "no");
}
session = jsch.getSession(username, server, port);
if (authenticationGroup.password != null) {
session.setConfig("PreferredAuthentications", "password,keyboard-interactive");
MyUserInfo userInfo = new MyUserInfo(authenticationGroup.password);
session.setUserInfo(userInfo);
} else if (authenticationGroup.privateKey != null) {
session.setConfig("PreferredAuthentications", "publickey");
jsch.addIdentity(authenticationGroup.privateKey.toAbsolutePath().toString());
} else {
logger.error("No authentication method provided");
}
return session;
}
/**
* Handle a JSch exception
*
* @param e the JSch exception
* @param server the server
* @param session the session
* @param retries the current retry count
* @param maxRetries the maximum retry count
* @param delay the delay in milliseconds
* @return Session the session
*/
private Optional<Session> handleJSchException(JSchException e, String server, Session session, int retries, int maxRetries, int delay) {
String message = e.getMessage();
if (message.contains("Auth fail")) {
logger.error("Authentication failed for server: %s%n ".formatted(server));
} else if (message.contains("timeout")) {
logger.error("Connection timed out for server: %s%n ".formatted(server));
} else if (message.contains("Too many authentication failures")) {
logger.error("Too many authentication failures for user: %s on server: %s".formatted(username, server));
logger.info("Retrying connection to server: %s, attempt number: %d".formatted(server, retries));
retries++;
delay *= 2;
if (logFailures(server, retries, maxRetries, delay)) return Optional.of(session);
} else if (message.contains("Algorithm negotiation fail")) {
logger.error("Algorithm negotiation failed for server: %s".formatted(server));
logger.info("Retrying connection to server: %s, attempt number: %d".formatted(server, retries));
// Extract the server's proposed algorithm from the error message
Optional<String> proposedAlgorithm = extractProposedAlgorithm(message);
// Add the proposed algorithm to the preferred list
session.setConfig("PreferredAuthentications", proposedAlgorithm.orElse(""));
// Retry the connection
retries = 2;
if (logFailures(server, retries, maxRetries, delay)) return Optional.of(session);
return Optional.of(session);
} else {
retries = 3;
if (retries == maxRetries && session == null) {
logger.error(FAILED_TO_CONNECT_TO_SERVER.formatted(server));
failedServers.add(server);
return Optional.ofNullable(session);
}
}
return Optional.ofNullable(session);
}
/**
* Log the failures and add the server to the failed servers list
*
* @param server the server
* @param retries the current retry count
* @param maxRetries the maximum retry count
* @param delay the delay in milliseconds
* @return boolean true if the connection failed
*/
private boolean logFailures(String server, int retries, int maxRetries, int delay) {
if (retries == maxRetries) {
logger.error(FAILED_TO_CONNECT_TO_SERVER.formatted(server));
logger.info("Adding server to the failed servers list: %s".formatted(server));
failedServers.add(server);
Thread.currentThread().interrupt();
return true;
}
thread(delay);
return false;
}
/**
* Extract the server's proposed algorithm from the error message
*
* @param message the error message
* @return the server's proposed algorithm
*/
private Optional<String> extractProposedAlgorithm(String message) {
String[] parts = message.split("=");
if (parts.length > 1) {
return Optional.of(parts[1]);
}
return Optional.empty();
}
/**
* Set the logging level for JCraft
*
* @param jsch the JSch instance
*/
private void configureJschLogging(JSch jsch) {
jsch.getInstanceLogger().isEnabled(Logger.getRootLogger().getLevel().toInt());
JSch.setLogger(new com.jcraft.jsch.Logger() {
public boolean isEnabled(int level) {
return verbose;
}
public void log(int level, String message) {
if (level == com.jcraft.jsch.Logger.DEBUG) {
logger.debug("%s".formatted(message));
} else {
logger.info("%s".formatted(message));
}
}
});
}
/**
* Test connectivity to the server on the specified port
*
* @param server the server to connect to
*/
private boolean testConnectivity(String server) {
try (Socket socket = new Socket()) {
socket.connect(new InetSocketAddress(server, port), TIMEOUT);
logger.info("Connectivity check successful for %s on port %d".formatted(server, port));
return true;
} catch (IOException e) {
logger.error("%s%s%s%d ".formatted(CONNECTIVITY_CHECK_FAILED_FOR, server, ON_PORT, port));
failedServers.add(server);
}
return false;
}
/**
* Test connectivity to the server using JCraft
*
* @param server the server to connect to
*/
private void testConnectivityJCraft(String server) {
logger.info("Attempting to log into server: %s".formatted(server));
Session session = createJSchSession(server);
if (session != null) {
logger.info("Successfully logged into server: %s".formatted(server));
session.disconnect();
failedServers.remove(server);
} else {
logger.error("Failed to log into server: %s".formatted(server));
}
}
/**
* A custom implementation of the UserInfo interface
*/
static class MyUserInfo implements com.jcraft.jsch.UserInfo, com.jcraft.jsch.UIKeyboardInteractive {
private final char[] password;
private final Logger logger = Logger.getLogger(MyUserInfo.class);
public MyUserInfo(char[] password) {
this.password = password;
// Remove the console appender so it duplicates the logging output
Logger.getRootLogger().removeAppender("console");
}
@Override
public String getPassword() {
return new String(password);
}
@Override
public boolean promptYesNo(String str) {
return true;
}
@Override
public String getPassphrase() {
return null;
}
@Override
public boolean promptPassphrase(String message) {
return true;
}
@Override
public boolean promptPassword(String message) {
return true;
}
@Override
public void showMessage(String message) {
logger.info("Message: %s".formatted(message));
}
@Override
public String[] promptKeyboardInteractive(String destination, String name, String instruction, String[] prompt, boolean[] echo) {
String[] response = new String[prompt.length];
for (int i = 0; i < prompt.length; i++) {
logger.info("Prompt: %s".formatted(prompt[i]));
if (prompt[i].toLowerCase().contains("password")) {
response[i] = new String(password);
}
}
return response;
}
}
public static class AuthenticationGroup {
@Option(names = {"-p", "--password"}, interactive = true, description = "The password to use to connect to the server", arity = "0..1")
char[] password;
@Option(names = {"-k", "--private-key"}, description = "The private key to use to connect to the server")
Path privateKey;
}
public static class AwsOptions {
@Option(names = {"-r", "--region"}, description = "The AWS Region where the server is located")
private String region;
@Option(names = {"-c", "--code"}, description = "The IAM Role ARN that the SSM Agent will assume to communicate with the AWS Systems Manager Service")
private String code;
@Option(names = {"-i", "--id"}, description = "The IAM Role ARN that the SSM Agent will assume to communicate with the AWS Systems Manager Service")
private String id;
}
/**
* A custom appender to colorize the log output
*/
public class ColorConsoleAppender extends ConsoleAppender {
private static final String ESC_START = "\u001B[";
private static final String ESC_END = "\u001B[0m";
private static final String RED = "31m";
private static final String GREEN = "32m";
private static final String BLUE = "34m";
@Override
protected void subAppend(LoggingEvent event) {
switch (event.getLevel().toInt()) {
case Level.ERROR_INT:
this.qw.write(ESC_START + RED);
break;
case Priority.INFO_INT:
this.qw.write(ESC_START + GREEN);
break;
case Level.DEBUG_INT:
this.qw.write(ESC_START + BLUE);
break;
default:
break;
}
super.subAppend(event);
this.qw.write(ESC_END);
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment