Skip to content

Instantly share code, notes, and snippets.

@dovidkopel
Last active November 22, 2016 16:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save dovidkopel/7618e6ea5ccebbb0adfb to your computer and use it in GitHub Desktop.
Save dovidkopel/7618e6ea5ccebbb0adfb to your computer and use it in GitHub Desktop.
SCP Integration to Spring application
import org.apache.sshd.common.scp.ScpHelper;
import org.apache.sshd.common.scp.ScpTransferEventListener;
import org.apache.sshd.common.util.GenericUtils;
import org.apache.sshd.server.Command;
import org.apache.sshd.server.scp.ScpCommand;
import org.apache.sshd.server.session.ServerSession;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Scope;
import org.springframework.stereotype.Component;
import java.io.IOException;
import java.util.concurrent.ExecutorService;
/**
 * SCP command that only accepts uploads and delegates the transfer to a
 * Spring-managed {@link ReceivingScpHelper}.
 *
 * <p>The bean is prototype-scoped; {@code ScpService} obtains one instance per
 * SCP command string through {@code BeanFactory.getBean(Class, Object...)},
 * passing the constructor arguments explicitly.
 */
@Component
@Scope("prototype")
public class ReceivingScpCommand extends ScpCommand implements Command {
    /** Used to build a fresh {@link ReceivingScpHelper} for every run. */
    private BeanFactory beanFactory;

    /**
     * Forwards all arguments unchanged to {@link ScpCommand}.
     *
     * @param command         the raw SCP command line sent by the client
     * @param executorService executor handed to the parent command (may be {@code null} as far as this class is concerned)
     * @param shutdownOnExit  passed through to the parent command
     * @param sendSize        send buffer size, passed through to the parent command
     * @param receiveSize     receive buffer size, passed through to the parent command
     * @param eventListener   transfer listener, passed through to the parent command
     */
    public ReceivingScpCommand(String command, ExecutorService executorService, boolean shutdownOnExit, int sendSize, int receiveSize, ScpTransferEventListener eventListener) {
        super(command, executorService, shutdownOnExit, sendSize, receiveSize, eventListener);
    }

    /**
     * Injects the bean factory used to create the helper.
     *
     * @param beanFactory the owning Spring bean factory
     * @return this instance, for chaining
     */
    @Autowired
    public ReceivingScpCommand setBeanFactory(BeanFactory beanFactory) {
        this.beanFactory = beanFactory;
        return this;
    }

    /**
     * Executes the command: receives the upload when {@code optT} is set and
     * rejects every other mode with an "Unsupported mode" error.
     *
     * <p>On an {@link IOException} the error is reported back to the client via
     * {@link ScpHelper#sendResponseMessage} and logged. The exit callback, when
     * present, is always invoked with the final exit value and message.
     */
    @Override
    public void run() {
        int exitValue = ScpHelper.OK;
        String exitMessage = null;
        ScpHelper helper = beanFactory.getBean(ReceivingScpHelper.class, getServerSession(), in, out, fileSystem, listener);
        try {
            if (optT) {
                helper.receive(helper.resolveLocalPath(path), optR, optD, optP, receiveBufferSize);
            } else {
                throw new IOException("Unsupported mode");
            }
        } catch (IOException e) {
            ServerSession session = getServerSession();
            try {
                exitValue = ScpHelper.ERROR;
                exitMessage = GenericUtils.trimToEmpty(e.getMessage());
                ScpHelper.sendResponseMessage(out, exitValue, exitMessage);
            } catch (IOException e2) {
                // Report the failure of the error response itself (e2), not the
                // transfer failure (e), which is logged separately below.
                if (log.isDebugEnabled()) {
                    log.debug("run({})[{}] Failed ({}) to send error response: {}",
                            session, name, e2.getClass().getSimpleName(), e2.getMessage());
                }
                if (log.isTraceEnabled()) {
                    log.trace("run(" + session + ")[" + name + "] error response failure details", e2);
                }
            }
            if (log.isDebugEnabled()) {
                log.debug("run({})[{}] Failed ({}) to run command: {}",
                        session, name, e.getClass().getSimpleName(), e.getMessage());
            }
            if (log.isTraceEnabled()) {
                log.trace("run(" + session + ")[" + name + "] command execution failure details", e);
            }
        } finally {
            if (callback != null) {
                callback.onExit(exitValue, GenericUtils.trimToEmpty(exitMessage));
            }
        }
    }
}
import com.google.common.collect.Sets;
import org.apache.commons.io.FilenameUtils;
import org.apache.sshd.common.scp.ScpHelper;
import org.apache.sshd.common.scp.ScpTargetStreamResolver;
import org.apache.sshd.common.scp.ScpTimestamp;
import org.apache.sshd.common.scp.ScpTransferEventListener;
import org.apache.sshd.common.session.Session;
import org.apache.sshd.common.util.io.IoUtils;
import org.apache.sshd.common.util.io.LimitInputStream;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Scope;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Component;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.FileSystem;
import java.nio.file.Path;
import java.nio.file.attribute.PosixFilePermission;
import java.util.Set;
/**
 * {@link ScpHelper} variant that reads every uploaded file fully into memory
 * and passes the bytes to {@link #ingestData(byte[], Path, String)}.
 *
 * <p>NOTE(review): the whole upload is buffered in a
 * {@link ByteArrayOutputStream}, so memory use grows with the client-declared
 * file size; consider enforcing an upper bound before using this with
 * untrusted clients.
 *
 * <p>Created by dkopel on 3/9/16.
 */
@Component
@Scope("prototype")
public class ReceivingScpHelper extends ScpHelper {
    /**
     * Forwards all arguments unchanged to {@link ScpHelper}.
     *
     * @param session       the SSH session the transfer belongs to
     * @param in            stream carrying data from the client
     * @param out           stream carrying responses to the client
     * @param fileSystem    file system passed to the parent helper
     * @param eventListener transfer listener notified about file events
     */
    public ReceivingScpHelper(Session session, InputStream in, OutputStream out, FileSystem fileSystem, ScpTransferEventListener eventListener) {
        super(session, in, out, fileSystem, eventListener);
    }

    /**
     * Runs the inherited receive logic and then logs the local target path.
     *
     * @throws IOException if the inherited receive fails
     */
    @Override
    public void receive(Path local, boolean recursive, boolean shouldBeDir, boolean preserve, int bufferSize) throws IOException {
        super.receive(local, recursive, shouldBeDir, preserve, bufferSize);
        log.info("Local path is {}", local);
    }

    /**
     * Receives a single file announced by a {@code C} header, buffers its
     * content in memory and hands it to {@link #ingestData}.
     *
     * @param header     the SCP header line; parsed as {@code C<mode> <length> <name>}
     * @param resolver   supplies the path reported to the transfer listener
     * @param time       not used by this implementation
     * @param preserve   not used by this implementation
     * @param bufferSize copy buffer size; must be at least {@code MIN_RECEIVE_BUFFER_SIZE}
     * @throws IOException if the header is not a well-formed {@code C} message,
     *                     the buffer size is too small, or the transfer fails
     */
    @Override
    public void receiveStream(String header, ScpTargetStreamResolver resolver, ScpTimestamp time, boolean preserve, int bufferSize) throws IOException {
        if (!header.startsWith("C")) {
            throw new IOException("receiveStream(" + resolver + ") Expected a C message but got '" + header + "'");
        }
        if (bufferSize < MIN_RECEIVE_BUFFER_SIZE) {
            throw new IOException("receiveStream(" + resolver + ") buffer size (" + bufferSize + ") below minimum (" + MIN_RECEIVE_BUFFER_SIZE + ")");
        }

        // The header comes from the remote client. Malformed input must surface
        // as IOException so the command can send an error response, instead of
        // escaping as an unchecked exception.
        int lengthEnd = header.indexOf(' ', 6);
        if (lengthEnd < 0) {
            throw new IOException("receiveStream(" + resolver + ") Malformed C message '" + header + "'");
        }
        Set<PosixFilePermission> perms;
        long length;
        try {
            perms = parseOctalPermissions(header.substring(1, 5));
            length = Long.parseLong(header.substring(6, lengthEnd));
        } catch (NumberFormatException e) {
            throw new IOException("receiveStream(" + resolver + ") Malformed C message '" + header + "'", e);
        }
        String name = header.substring(lengthEnd + 1);
        log.info("Header {}", header);

        if (length < 0L) { // TODO consider throwing an exception...
            log.warn("receiveStream({})[{}] bad length in header: {}", this, resolver, header);
        }

        // if file size is less than buffer size allocate only expected file size
        int bufSize;
        if (length == 0L) {
            if (log.isDebugEnabled()) {
                log.debug("receiveStream({})[{}] zero file size (perhaps special file) using copy buffer size={}",
                        this, resolver, MIN_RECEIVE_BUFFER_SIZE);
            }
            bufSize = MIN_RECEIVE_BUFFER_SIZE;
        } else {
            bufSize = (int) Math.min(length, bufferSize);
        }
        if (bufSize < 0) { // TODO consider throwing an exception
            log.warn("receiveStream({})[{}] bad buffer size ({}) using default ({})",
                    this, resolver, bufSize, MIN_RECEIVE_BUFFER_SIZE);
            bufSize = MIN_RECEIVE_BUFFER_SIZE;
        }

        try (
                InputStream is = new LimitInputStream(this.in, length);
                ByteArrayOutputStream os = new ByteArrayOutputStream()
        ) {
            ack();
            Path file = resolver.getEventListenerFilePath();
            try {
                listener.startFileEvent(ScpTransferEventListener.FileOperation.RECEIVE, file, length, perms);
                IoUtils.copy(is, os, bufSize);
                ingestData(os.toByteArray(), file, name);
                listener.endFileEvent(ScpTransferEventListener.FileOperation.RECEIVE, file, length, perms, null);
            } catch (IOException | RuntimeException e) {
                listener.endFileEvent(ScpTransferEventListener.FileOperation.RECEIVE, file, length, perms, e);
                throw e;
            }
        }
        ack();
        readAck(false);
    }

    /**
     * Hook that receives the complete content of one uploaded file.
     *
     * <p>When the session carries an {@code "authentication"} property, that
     * {@link Authentication} is bound to the current thread's security context
     * for the duration of the processing and the previous value is restored
     * afterwards, so the identity does not stay attached to the thread.
     *
     * @param data the full file content
     * @param path the path reported by the stream resolver
     * @param file the file name taken from the SCP header
     */
    private void ingestData(byte[] data, Path path, String file) {
        log.info("Here is the file {} with the path {} from the user {} which is {} bytes", file, path.toAbsolutePath(), getSession().getUsername(), data.length);
        if (getSession().getProperties().containsKey("authentication")) {
            Authentication previous = SecurityContextHolder.getContext().getAuthentication();
            SecurityContextHolder.getContext().setAuthentication(
                    (Authentication) getSession().getProperties().get("authentication")
            );
            try {
                // Do stuff with your file here!
                //
                // service.process(data);
            } finally {
                SecurityContextHolder.getContext().setAuthentication(previous);
            }
        }
    }
}
import org.apache.sshd.common.file.virtualfs.VirtualFileSystemFactory;
import org.apache.sshd.common.scp.ScpTransferEventListener;
import org.apache.sshd.common.util.EventListenerUtils;
import org.apache.sshd.common.util.threads.ExecutorServiceConfigurer;
import org.apache.sshd.server.Command;
import org.apache.sshd.server.CommandFactory;
import org.apache.sshd.server.SshServer;
import org.apache.sshd.server.auth.password.PasswordAuthenticator;
import org.apache.sshd.server.auth.password.PasswordChangeRequiredException;
import org.apache.sshd.server.keyprovider.SimpleGeneratorHostKeyProvider;
import org.apache.sshd.server.session.ServerSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Service;
import javax.annotation.PostConstruct;
import java.io.IOException;
import java.nio.file.Paths;
import java.security.NoSuchAlgorithmException;
import java.security.NoSuchProviderException;
import java.util.Collection;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.ExecutorService;
/**
 * Starts an embedded SSH server that accepts SCP uploads and authenticates
 * users through {@link UserAuthenticationService}.
 *
 * <p>The server is configured and started once, from {@link #init()}, after
 * Spring has injected the dependencies.
 */
@Service
public class ScpService {
    private final Logger logger = LoggerFactory.getLogger(getClass());

    private SshServer sshd;
    private BeanFactory beanFactory;
    private UserAuthenticationService userAuthenticationService;

    /** Listening port, resolved from the {@code scp.port} property. */
    @Value("${scp.port}")
    private Integer port;

    /**
     * Injects the bean factory used to create one command bean per SCP request.
     *
     * @param beanFactory the owning Spring bean factory
     * @return this instance, for chaining
     */
    @Autowired
    public ScpService setBeanFactory(BeanFactory beanFactory) {
        this.beanFactory = beanFactory;
        return this;
    }

    /**
     * Injects the service used to verify credentials and load users.
     *
     * @param userAuthenticationService the credential checker
     * @return this instance, for chaining
     */
    @Autowired
    public ScpService setUserAuthenticationService(UserAuthenticationService userAuthenticationService) {
        this.userAuthenticationService = userAuthenticationService;
        return this;
    }

    /**
     * Configures and starts the SSH server: command factory, virtual file
     * system rooted at {@code /tmp}, host key stored in {@code scp.key},
     * password authentication and the configured port.
     *
     * <p>A failure to start is logged and otherwise ignored, so the
     * application keeps running without the SCP endpoint.
     */
    @PostConstruct
    public void init() throws NoSuchProviderException, NoSuchAlgorithmException {
        ScpReceivingCommandFactory commandFactory = new ScpReceivingCommandFactory();

        VirtualFileSystemFactory fileSystemFactory = new VirtualFileSystemFactory();
        fileSystemFactory.setDefaultHomeDir(Paths.get("/tmp"));

        sshd = SshServer.setUpDefaultServer();
        sshd.setCommandFactory(commandFactory);
        sshd.setFileSystemFactory(fileSystemFactory);

        SimpleGeneratorHostKeyProvider keyProvider = new SimpleGeneratorHostKeyProvider(Paths.get("scp.key"));
        keyProvider.setAlgorithm("RSA");
        // 1024-bit RSA is no longer considered adequate for host keys.
        // Presumably only affects newly generated keys - confirm that an existing scp.key is reused.
        keyProvider.setKeySize(2048);
        sshd.setKeyPairProvider(keyProvider);

        sshd.setPasswordAuthenticator(passwordAuthenticator());
        sshd.setPort(port);
        try {
            sshd.start();
        } catch (IOException e) {
            // Keep the cause so the reason (port in use, permissions, ...) is visible.
            logger.error("Could not start sshd session.", e);
        }
    }

    /**
     * Builds the authenticator that checks credentials against
     * {@link UserAuthenticationService} and, on success, stores the resulting
     * {@link Authentication} in the session properties under
     * {@code "authentication"}.
     */
    private PasswordAuthenticator passwordAuthenticator() {
        return new PasswordAuthenticator() {
            @Override
            public boolean authenticate(String username, String password, final ServerSession serverSession) throws PasswordChangeRequiredException {
                if (userAuthenticationService.authenticate(username, password)) {
                    User user = userAuthenticationService.getUser(username);
                    Authentication authentication = new UsernamePasswordAuthenticationToken(user, password, user.getAuthorities());
                    // NOTE(review): this binds the identity to whichever thread SSHD
                    // invokes the authenticator on and it is never cleared here;
                    // confirm whether that thread is shared between sessions.
                    SecurityContextHolder.getContext().setAuthentication(authentication);
                    serverSession.getProperties().put("authentication", authentication);
                    logger.info("Authentication successful for username {}", username);
                    return true;
                }
                logger.info("Authentication failed for username {}", username);
                return false;
            }
        };
    }

    /**
     * Command factory that asks the Spring bean factory for a {@link Command}
     * bean for each incoming command string, passing along the executor and
     * buffer settings held by this factory.
     */
    class ScpReceivingCommandFactory implements CommandFactory, ExecutorServiceConfigurer {
        /** Never assigned in this class; {@link #getDelegateCommandFactory()} returns {@code null}. */
        private CommandFactory delegate;
        private ExecutorService executors;
        private boolean shutdownExecutor;
        private int sendBufferSize = 127;
        private int receiveBufferSize = 127;
        private Collection<ScpTransferEventListener> listeners = new CopyOnWriteArraySet<>();
        /** Proxy that forwards each listener callback to the entries of {@link #listeners}. */
        private ScpTransferEventListener listenerProxy;

        public ScpReceivingCommandFactory() {
            this.listenerProxy = (ScpTransferEventListener) EventListenerUtils.proxyWrapper(ScpTransferEventListener.class, this.getClass().getClassLoader(), this.listeners);
        }

        public CommandFactory getDelegateCommandFactory() {
            return this.delegate;
        }

        public ExecutorService getExecutorService() {
            return this.executors;
        }

        public void setExecutorService(ExecutorService service) {
            this.executors = service;
        }

        public boolean isShutdownOnExit() {
            return this.shutdownExecutor;
        }

        public void setShutdownOnExit(boolean shutdown) {
            this.shutdownExecutor = shutdown;
        }

        public int getSendBufferSize() {
            return this.sendBufferSize;
        }

        public int getReceiveBufferSize() {
            return this.receiveBufferSize;
        }

        /**
         * Creates the command bean for the given command line.
         *
         * @param command the raw command string sent by the client
         * @return a {@link Command} bean built with this factory's settings
         */
        @Override
        public Command createCommand(String command) {
            return beanFactory.getBean(Command.class, command, this.getExecutorService(), this.isShutdownOnExit(), this.getSendBufferSize(), this.getReceiveBufferSize(), this.listenerProxy);
        }
    }
}
import org.springframework.security.core.userdetails.User;
/**
* Credential check and user lookup used by ScpService's password authenticator.
*/
public interface UserAuthenticationService {
/**
* Checks a username/password pair.
*
* @param username the login name supplied by the client
* @param password the password supplied by the client
* @return {@code true} when the credentials are accepted; ScpService treats
*         {@code true} as a successful login
*/
boolean authenticate(String username, String password);
/**
* Loads the user for a login name. ScpService calls this after a successful
* {@link #authenticate} and reads the authorities from the result, so
* implementations presumably must not return {@code null} for such a user.
*
* @param username the login name to look up
* @return the matching user
*/
User getUser(String username);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment