Skip to content

Instantly share code, notes, and snippets.

@iseki0
Last active March 21, 2023 17:38
Show Gist options
  • Save iseki0/ef9fd5e4f9529ce0f6bde859fc602cca to your computer and use it in GitHub Desktop.
Save iseki0/ef9fd5e4f9529ce0f6bde859fc602cca to your computer and use it in GitHub Desktop.
SOCKS5 SocketFactory simple implementation
import java.io.EOFException
import java.io.InputStream
import java.io.OutputStream
import java.lang.IllegalArgumentException
import java.net.Inet4Address
import java.net.Inet6Address
import java.net.InetAddress
import java.net.Socket
import javax.net.SocketFactory
/**
 * One SOCKS5 authentication method.
 *
 * [id] is the method code the client lists during method negotiation; [handle] runs the
 * method-specific exchange on the proxy [Socket] after the proxy has selected this method.
 */
internal interface Socks5AuthHandler {
    /** Method code sent in the client's METHODS list. */
    val id: Int

    /** Performs this method's exchange over [socket]; returns normally when it succeeds. */
    fun handle(socket: Socket)
}
/**
 * SOCKS5 method `0x00` ("no authentication required"): nothing is exchanged after the
 * proxy selects it.
 */
internal object NoAuthHandler : Socks5AuthHandler {
    override val id: Int get() = 0

    // Intentionally empty: this method has no sub-negotiation.
    override fun handle(socket: Socket) = Unit
}
private const val RFC1929_EOF_MESSAGE = "Unexpected EOF during username/password authentication"

/**
 * SOCKS5 username/password authentication handler (method `0x02`).
 *
 * Since the RFC doesn't specify a charset, the username and password are encoded as
 * ISO-8859-1; characters that charset cannot represent are replaced by the encoder (with `?`).
 *
 * The RFC requires the username and password to be 1..255 bytes long. For compatibility,
 * zero-length values are accepted as well (cURL and wget allow them too).
 *
 * Spec: [RFC1929](https://www.rfc-editor.org/rfc/rfc1929)
 *
 * @param[username] the username; must encode to at most 255 bytes
 * @param[password] the password; must encode to at most 255 bytes
 * @throws IllegalArgumentException if either value encodes to more than 255 bytes
 */
internal class UsernamePasswordAuthHandler(
    username: String,
    password: String,
) : Socks5AuthHandler {
    // Encode once so the one-byte length prefix always equals the number of bytes actually sent.
    // String.length counts UTF-16 units, which can exceed the encoded size (a surrogate pair is
    // encoded as a single replacement byte); a mismatch would desynchronize the exchange.
    private val usernameBytes = username.toByteArray(Charsets.ISO_8859_1)
    private val passwordBytes = password.toByteArray(Charsets.ISO_8859_1)

    init {
        require(usernameBytes.size <= 255) {
            "SOCKS5 username must encode to at most 255 bytes, got ${usernameBytes.size}"
        }
        require(passwordBytes.size <= 255) {
            "SOCKS5 password must encode to at most 255 bytes, got ${passwordBytes.size}"
        }
    }

    override val id: Int
        get() = 2

    /**
     * Sends the RFC 1929 request (VER=1, ULEN, UNAME, PLEN, PASSWD) and checks the reply STATUS.
     *
     * @throws EOFException if the stream ends before the two-byte reply is read
     * @throws Socks5UsernamePasswordAuthenticationException if STATUS is non-zero
     */
    override fun handle(socket: Socket) {
        val request = byteArrayOf(1, usernameBytes.size.toByte()) + usernameBytes +
            byteArrayOf(passwordBytes.size.toByte()) + passwordBytes
        // A single write instead of one write per field/byte.
        with(socket.getOutputStream()) {
            write(request)
            flush()
        }
        val input = socket.getInputStream()
        // safe to ignore reply version
        if (input.read() == -1) {
            throw EOFException(RFC1929_EOF_MESSAGE)
        }
        when (val status = input.read()) {
            -1 -> throw EOFException(RFC1929_EOF_MESSAGE)
            0 -> return // authentication successful
            else -> throw Socks5UsernamePasswordAuthenticationException(status)
        }
    }
}
/**
 * The proxy rejected the username/password credentials.
 *
 * @property status the non-zero STATUS byte from the RFC 1929 reply
 */
class Socks5UsernamePasswordAuthenticationException(val status: Int) :
    Socks5AuthenticationException("Socks5 username/password authentication failed with status $status")

private const val EOF_NEGOTIATION_MESSAGE = "Unexpected EOF during SOCKS5 negotiation"
private const val EOF_CONNECT_MESSAGE = "Unexpected EOF during SOCKS5 connection"
private const val EOF_DECODE_BND_MESSAGE = "Unexpected EOF during SOCKS5 decode BND"

/**
 * Base class for SOCKS5 protocol failures (authentication failures, non-zero CONNECT replies).
 * A premature end of stream is reported separately as [java.io.EOFException].
 *
 * @param message detail message, or `null`
 */
open class Socks5Exception(message: String? = null) : RuntimeException(message)

/**
 * Authentication with the SOCKS5 proxy failed (no acceptable method, or the chosen method failed).
 *
 * The message is also handed to the superclass so the base [Throwable] detail message is
 * populated instead of being left `null`.
 */
open class Socks5AuthenticationException(override val message: String) : Socks5Exception(message)
/**
 * A [SocketFactory] whose sockets are tunnelled through a SOCKS5 proxy using the CONNECT command.
 *
 * Each `createSocket(host, port)` call opens a connection to [socks5Host]:[socks5Port] through
 * [underlyingFactory], performs method negotiation and authentication, then sends CONNECT for the
 * requested destination. The returned socket is positioned just after the proxy's reply, so its
 * streams carry the tunnelled connection.
 *
 * Username/password authentication (RFC 1929) is offered only when [socks5Username] or
 * [socks5Password] is non-empty; "no authentication" is always offered.
 * Binding a local address/port is not supported.
 *
 * Reference: [RFC1928](https://www.rfc-editor.org/rfc/rfc1928)
 *
 * @param underlyingFactory factory used to open the TCP connection to the proxy itself
 * @param socks5Host host of the SOCKS5 proxy
 * @param socks5Port port of the SOCKS5 proxy
 * @param socks5Username username for RFC 1929 authentication (empty, with an empty password, to skip it)
 * @param socks5Password password for RFC 1929 authentication
 */
class Socks5SocketFactory(
    private val underlyingFactory: SocketFactory,
    private val socks5Host: String,
    private val socks5Port: Int,
    private val socks5Username: String,
    private val socks5Password: String,
) : SocketFactory() {

    // Methods listed in the negotiation request; the proxy selects one of them.
    private val authenticationHandlers = buildList {
        if (socks5Username.isNotEmpty() || socks5Password.isNotEmpty()) {
            add(UsernamePasswordAuthHandler(socks5Username, socks5Password))
        }
        add(NoAuthHandler)
    }

    override fun createSocket(host: String, port: Int): Socket = connectToServer(null, host, port)

    override fun createSocket(host: String?, port: Int, localHost: InetAddress?, localPort: Int): Socket {
        throw UnsupportedOperationException("create socket and bind local host and port is not supported")
    }

    override fun createSocket(host: InetAddress, port: Int): Socket = connectToServer(host, null, port)

    override fun createSocket(address: InetAddress?, port: Int, localAddress: InetAddress?, localPort: Int): Socket {
        throw UnsupportedOperationException("create socket and bind local host and port is not supported")
    }

    /**
     * Opens a proxy connection and tunnels it to [address] or, if that is `null`, to [host].
     * The proxy socket is closed if any step fails.
     */
    private fun connectToServer(address: InetAddress? = null, host: String? = null, port: Int): Socket {
        // Validate and encode the destination before spending a connection on the proxy.
        val connectRequest = buildConnectRequest(address, host, port)
        // Platform type from Java; fail fast if a misbehaving factory returns null.
        val socket = underlyingFactory.createSocket(socks5Host, socks5Port)!!
        try {
            authNegotiation(socket)
            with(socket.getOutputStream()) {
                write(connectRequest)
                flush()
            }
            readConnectReply(socket.getInputStream())
            return socket
        } catch (th: Throwable) {
            // Don't leak the proxy connection; a failing close() is attached to the original error.
            runCatching { socket.close() }.onFailure { th.addSuppressed(it) }
            throw th
        }
    }

    /**
     * Builds the CONNECT request: VER=5, CMD=1, RSV=0, ATYP + DST.ADDR, DST.PORT (big-endian).
     *
     * @throws IllegalArgumentException if [port] is outside 0..65535 or the destination cannot be encoded
     */
    private fun buildConnectRequest(address: InetAddress?, host: String?, port: Int): ByteArray {
        require(port in 0..0xffff) { "port out of range: $port" }
        val destination = if (address != null) encodeInetAddress(address) else encodeHost(requireNotNull(host))
        return byteArrayOf(5, 1, 0) + destination + byteArrayOf((port ushr 8).toByte(), port.toByte())
    }

    /** Encodes a domain-name destination: ATYP=3, a one-byte length, then the name bytes. */
    private fun encodeHost(host: String): ByteArray {
        val name = host.toByteArray(defaultCharset)
        // The length prefix is one byte, so the *encoded* name (not the char count) must fit.
        require(name.size <= 255) { "SOCKS5 domain name must encode to at most 255 bytes, got ${name.size}" }
        return byteArrayOf(3, name.size.toByte()) + name
    }

    /** Encodes a literal destination: ATYP=1 + 4 bytes for IPv4, ATYP=4 + 16 bytes for IPv6. */
    private fun encodeInetAddress(address: InetAddress): ByteArray = when (address) {
        is Inet4Address -> byteArrayOf(1) + address.address
        is Inet6Address -> byteArrayOf(4) + address.address
        else -> throw IllegalArgumentException("Unsupported InetAddress")
    }

    /**
     * Method negotiation: offers every handler in [authenticationHandlers] and runs the one the
     * proxy selects.
     *
     * @throws Socks5AuthenticationException if the proxy accepts none of the offered methods or
     *   selects one that was not offered
     * @throws EOFException if the stream ends before the selection is read
     */
    private fun authNegotiation(socket: Socket) {
        val methods = authenticationHandlers.map { it.id.toByte() }.toByteArray()
        with(socket.getOutputStream()) {
            // VER=5, NMETHODS, METHODS
            write(byteArrayOf(5, methods.size.toByte()) + methods)
            flush()
        }
        val input = socket.getInputStream()
        if (input.read() == -1) throw EOFException(EOF_NEGOTIATION_MESSAGE) // safe to ignore version code
        val handler = when (val chosenMethod = input.read()) {
            -1 -> throw EOFException(EOF_NEGOTIATION_MESSAGE)
            0xff -> throw Socks5AuthenticationException("No accepted authentication method")
            else -> authenticationHandlers.find { it.id == chosenMethod }
                ?: throw Socks5AuthenticationException("Server chosen a unknown authentication method $chosenMethod")
        }
        handler.handle(socket)
    }

    /**
     * Consumes the CONNECT reply: VER, REP, RSV, ATYP, BND.ADDR, BND.PORT.
     *
     * BND.ADDR and BND.PORT are discarded, but they must be read in full: any byte left in the
     * stream would be handed to the caller as if it were tunnelled data.
     *
     * @throws Socks5Exception if REP reports a failure or ATYP is not 1, 3 or 4
     * @throws EOFException if the stream ends before the reply is complete
     */
    private fun readConnectReply(input: InputStream) {
        if (input.read() == -1) throw EOFException(EOF_CONNECT_MESSAGE) // ignore version
        val rep = input.read()
        if (rep == -1) throw EOFException(EOF_CONNECT_MESSAGE)
        if (rep > 0) throw Socks5Exception(errorMessage(rep))
        if (input.read() == -1) throw EOFException(EOF_CONNECT_MESSAGE) // ignore RSV
        val bndAddrLength = when (val atyp = input.read()) {
            -1 -> throw EOFException(EOF_CONNECT_MESSAGE)
            1 -> 4 // IPv4 address
            4 -> 16 // IPv6 address
            3 -> {
                // domain name: one length byte followed by that many bytes
                val nameLength = input.read()
                if (nameLength == -1) throw EOFException(EOF_DECODE_BND_MESSAGE)
                nameLength
            }
            else -> throw Socks5Exception("Unsupported address type in SOCKS5 reply: $atyp")
        }
        // BND.ADDR followed by the two-byte BND.PORT
        discardExactly(input, bndAddrLength + 2)
    }

    /** Reads and discards exactly [count] bytes; throws [EOFException] if the stream ends first. */
    private fun discardExactly(input: InputStream, count: Int) {
        // readNBytes returns fewer bytes (rather than throwing) when EOF is reached early.
        if (input.readNBytes(count).size < count) throw EOFException(EOF_DECODE_BND_MESSAGE)
    }

    /** Human-readable text for a non-zero REP code from RFC 1928 section 6. */
    private fun errorMessage(code: Int) = when (code) {
        1 -> "General SOCKS server failure"
        2 -> "Connection not allowed by ruleset"
        3 -> "Network unreachable"
        4 -> "Host unreachable"
        5 -> "Connection refused"
        6 -> "TTL expired"
        7 -> "Command not supported"
        8 -> "Address type not supported"
        else -> {
            check(code in 9..255)
            "Unknown SOCKS5 error: $code"
        }
    }

    companion object {
        private val defaultCharset = Charsets.ISO_8859_1
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment