Skip to content

Instantly share code, notes, and snippets.

@mnstrspeed
Created March 1, 2014 18:54
Show Gist options
  • Save mnstrspeed/9295199 to your computer and use it in GitHub Desktop.
Save mnstrspeed/9295199 to your computer and use it in GitHub Desktop.
WebSocket Client
import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.net.Socket;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
// NOTE(review): javax.xml.bind was removed from the JDK in Java 11 (JEP 320); drop this
// import once nothing in the file references DatatypeConverter.
import javax.xml.bind.DatatypeConverter;
/**
 * Minimal client-side implementation of the WebSocket protocol (RFC 6455).
 *
 * <p>Performs the HTTP/1.1 upgrade handshake, sends masked text frames, and reassembles
 * incoming (possibly fragmented) data messages into strings. Ping frames are answered with a
 * Pong, and a Close frame is echoed before {@link #read()} reports end of stream.
 *
 * <p>{@link #connect()} opens a plain TCP socket only; it does not negotiate TLS. For
 * {@code wss://} endpoints, establish the TLS connection yourself and pass its streams to
 * {@link #connect(InputStream, OutputStream)}. The class performs no internal synchronization.
 */
public class WebSocket {
    /** FIN bit of the first header byte: this frame completes the message. */
    private static final int FLAG_FINAL_MESSAGE = 0b10000000;
    /** MASK bit of the second header byte: a 4-byte masking key follows the length. */
    private static final int FLAG_MASKED = 0b10000000;
    /** Low four bits of the first header byte hold the opcode. */
    private static final int OPCODE_BITS = 0b00001111;
    /** Low seven bits of the second header byte hold the (possibly escaped) length. */
    private static final int PAYLOAD_LENGTH_BITS = 0b01111111;

    private static final int OPCODE_TEXT = 0x1;
    private static final int OPCODE_CLOSE = 0x8;
    private static final int OPCODE_PING = 0x9;
    private static final int OPCODE_PONG = 0xA;

    /** Largest payload length that fits directly in the 7-bit length field. */
    private static final int MAX_INLINE_LENGTH = 125;
    /** 7-bit length marker: the real length follows as a 16-bit unsigned integer. */
    private static final int LENGTH_16_BIT = 126;
    /** 7-bit length marker: the real length follows as a 64-bit unsigned integer. */
    private static final int LENGTH_64_BIT = 127;

    /** RFC 6455 section 5.3 requires masking keys from a strong source of entropy. */
    private static final Random RANDOM = new SecureRandom();

    private final URI uri;
    /** Request headers, insertion-ordered so the handshake is deterministic (Host first). */
    private final Map<String, String> headers;
    private Socket socket;
    private InputStream inputStream;
    private OutputStream outputStream;

    /**
     * Creates a client for {@code uri} with the standard upgrade headers pre-filled.
     * No connection is opened until {@code connect} is called.
     *
     * @param uri the endpoint; its host, port, path and query are used for the handshake
     */
    public WebSocket(URI uri) {
        this.uri = uri;
        this.headers = new LinkedHashMap<>();
        // Omit the port from Host when the URI has none, rather than sending "host:-1".
        int port = uri.getPort();
        this.headers.put("Host", port == -1 ? uri.getHost() : uri.getHost() + ":" + port);
        this.headers.put("Upgrade", "websocket");
        this.headers.put("Connection", "Upgrade");
        this.headers.put("Sec-WebSocket-Key", getRandomKey());
        this.headers.put("Sec-WebSocket-Version", "13");
    }

    /** Returns the endpoint this client was created for. */
    public URI getURI() {
        return this.uri;
    }

    /**
     * Sets (or replaces) a handshake request header. Only affects a later {@code connect}.
     *
     * @param field header name
     * @param value header value
     */
    public void setHeader(String field, String value) {
        this.headers.put(field, value);
    }

    /** Returns a base64-encoded 16-byte random nonce for {@code Sec-WebSocket-Key}. */
    private static String getRandomKey() {
        return Base64.getEncoder().encodeToString(getRandomBytes(16));
    }

    /** Returns {@code length} bytes drawn from the shared secure random source. */
    private static byte[] getRandomBytes(int length) {
        byte[] bytes = new byte[length];
        RANDOM.nextBytes(bytes);
        return bytes;
    }

    /**
     * Performs the upgrade handshake over already-open streams.
     *
     * @param inputStream  stream carrying data from the server
     * @param outputStream stream carrying data to the server
     * @return the response status line and header lines, terminated by an empty string
     * @throws EOFException if the server closes the stream before the headers end
     * @throws IOException  on any other I/O failure
     */
    public List<String> connect(InputStream inputStream, OutputStream outputStream) throws IOException {
        // A single buffered stream serves both the handshake and every later frame, so bytes
        // that arrive together with the handshake response are not lost.
        this.inputStream = new BufferedInputStream(inputStream);
        this.outputStream = outputStream;

        this.outputStream.write(buildHandshakeRequest().getBytes(StandardCharsets.ISO_8859_1));
        this.outputStream.flush();

        List<String> responseHeaders = new ArrayList<>();
        String line;
        do {
            line = readHeaderLine();
            responseHeaders.add(line);
        } while (!line.isEmpty());
        return responseHeaders;
    }

    /**
     * Opens a plain TCP socket to the URI's host and performs the handshake. When the URI has
     * no port, 80 is used (443 for {@code wss}, although no TLS is negotiated here).
     *
     * @return the response status line and header lines, terminated by an empty string
     * @throws IOException if the connection or the handshake fails; the socket is then closed
     */
    public List<String> connect() throws IOException {
        int port = this.uri.getPort() != -1 ? this.uri.getPort() : defaultPort();
        this.socket = new Socket(this.uri.getHost(), port);
        try {
            return this.connect(this.socket.getInputStream(), this.socket.getOutputStream());
        } catch (IOException e) {
            try {
                this.socket.close();
            } catch (IOException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
    }

    /**
     * Closes the streams and, if {@link #connect()} opened one, the socket. Every resource is
     * attempted even if an earlier one fails; the first failure is rethrown with the others
     * attached as suppressed exceptions.
     */
    public void close() throws IOException {
        IOException failure = null;
        for (Closeable resource : new Closeable[] {this.inputStream, this.outputStream, this.socket}) {
            if (resource == null) {
                continue;
            }
            try {
                resource.close();
            } catch (IOException e) {
                if (failure == null) {
                    failure = e;
                } else {
                    failure.addSuppressed(e);
                }
            }
        }
        if (failure != null) {
            throw failure;
        }
    }

    /** Sends {@code content} as a single text frame and flushes the stream. */
    public void send(String content) throws IOException {
        this.send(content, false);
    }

    /**
     * Sends {@code content} as a single, masked, UTF-8 text frame of any length.
     *
     * @param content the message text
     * @param delay   when true the output stream is not flushed, allowing frames to be batched
     */
    public void send(String content, boolean delay) throws IOException {
        sendFrame(OPCODE_TEXT, content.getBytes(StandardCharsets.UTF_8), delay);
    }

    /**
     * Blocks until a complete data message has been received and returns it decoded as UTF-8.
     * Fragmented messages are reassembled; text, binary and continuation payloads are all
     * decoded as UTF-8. Interleaved Ping frames are answered with a Pong and Pong frames are
     * ignored.
     *
     * @return the message text
     * @throws EOFException if the stream ends, or the server sends a Close frame (which is
     *                      echoed back before this is thrown)
     * @throws IOException  on other I/O failures or a frame longer than {@code Integer.MAX_VALUE}
     */
    public String read() throws IOException {
        ByteArrayOutputStream message = new ByteArrayOutputStream();
        while (true) {
            int first = readUnsignedByte();
            int second = readUnsignedByte();
            boolean fin = (first & FLAG_FINAL_MESSAGE) != 0;
            byte[] payload = readPayload(second);
            switch (first & OPCODE_BITS) {
                case OPCODE_CLOSE:
                    throw replyToClose(payload);
                case OPCODE_PING:
                    // Control frames may interleave with fragments; answer and keep reading.
                    sendFrame(OPCODE_PONG, payload, false);
                    break;
                case OPCODE_PONG:
                    break;
                default:
                    message.write(payload, 0, payload.length);
                    if (fin) {
                        // Decode once at the end so multi-byte characters split across
                        // fragments are reassembled correctly.
                        return new String(message.toByteArray(), StandardCharsets.UTF_8);
                    }
                    break;
            }
        }
    }

    /** Builds the HTTP/1.1 upgrade request, keeping the path percent-encoded. */
    private String buildHandshakeRequest() {
        String path = this.uri.getRawPath();
        if (path == null || path.isEmpty()) {
            path = "/";
        }
        StringBuilder request = new StringBuilder("GET ").append(path);
        String query = this.uri.getRawQuery();
        if (query != null) {
            request.append('?').append(query);
        }
        request.append(" HTTP/1.1\r\n");
        for (Map.Entry<String, String> entry : this.headers.entrySet()) {
            request.append(entry.getKey()).append(": ").append(entry.getValue()).append("\r\n");
        }
        return request.append("\r\n").toString();
    }

    /** Reads one LF-terminated header line from the raw stream, dropping a trailing CR. */
    private String readHeaderLine() throws IOException {
        ByteArrayOutputStream line = new ByteArrayOutputStream();
        int b;
        while ((b = this.inputStream.read()) != '\n') {
            if (b == -1) {
                throw new EOFException("Connection closed during WebSocket handshake");
            }
            line.write(b);
        }
        byte[] bytes = line.toByteArray();
        int length = bytes.length;
        if (length > 0 && bytes[length - 1] == '\r') {
            length--;
        }
        return new String(bytes, 0, length, StandardCharsets.ISO_8859_1);
    }

    /** Default port used by {@link #connect()} when the URI does not specify one. */
    private int defaultPort() {
        return "wss".equalsIgnoreCase(this.uri.getScheme()) ? 443 : 80;
    }

    /** Writes one final, masked frame with the given opcode, choosing the length encoding. */
    private void sendFrame(int opcode, byte[] payload, boolean delay) throws IOException {
        int length = payload.length;
        ByteArrayOutputStream frame = new ByteArrayOutputStream(length + 14);
        frame.write(FLAG_FINAL_MESSAGE | opcode);
        if (length <= MAX_INLINE_LENGTH) {
            frame.write(FLAG_MASKED | length);
        } else if (length <= 0xFFFF) {
            frame.write(FLAG_MASKED | LENGTH_16_BIT);
            frame.write(length >>> 8);
            frame.write(length);
        } else {
            frame.write(FLAG_MASKED | LENGTH_64_BIT);
            frame.write(ByteBuffer.allocate(8).putLong(length).array(), 0, 8);
        }
        // Client-to-server frames must always be masked (RFC 6455 section 5.3).
        byte[] maskingKey = getRandomBytes(4);
        frame.write(maskingKey, 0, maskingKey.length);
        byte[] body = mask(payload, maskingKey);
        frame.write(body, 0, body.length);
        this.outputStream.write(frame.toByteArray());
        if (!delay) {
            this.outputStream.flush();
        }
    }

    /** Reads the extended length, optional masking key and payload that follow header byte 2. */
    private byte[] readPayload(int second) throws IOException {
        long length = second & PAYLOAD_LENGTH_BITS;
        if (length == LENGTH_16_BIT) {
            int high = readUnsignedByte();
            int low = readUnsignedByte();
            // Unsigned: a signed short would turn lengths above 32767 negative.
            length = (high << 8) | low;
        } else if (length == LENGTH_64_BIT) {
            length = ByteBuffer.wrap(readFully(8)).getLong();
            if (length < 0 || length > Integer.MAX_VALUE) {
                throw new IOException("Unsupported WebSocket frame length: " + length);
            }
        }
        byte[] maskingKey = (second & FLAG_MASKED) != 0 ? readFully(4) : null;
        byte[] payload = readFully((int) length);
        return maskingKey == null ? payload : mask(payload, maskingKey);
    }

    /**
     * Echoes the close status code back to the server and returns the exception that
     * {@link #read()} throws. A failure to send the reply is attached as suppressed.
     */
    private EOFException replyToClose(byte[] payload) {
        EOFException closed = new EOFException("WebSocket closed by peer");
        byte[] status = payload.length >= 2 ? Arrays.copyOf(payload, 2) : new byte[0];
        try {
            sendFrame(OPCODE_CLOSE, status, false);
        } catch (IOException e) {
            closed.addSuppressed(e);
        }
        return closed;
    }

    /** Reads a single byte, throwing instead of returning -1 at end of stream. */
    private int readUnsignedByte() throws IOException {
        int b = this.inputStream.read();
        if (b == -1) {
            throw new EOFException("WebSocket connection closed");
        }
        return b;
    }

    /** Reads exactly {@code length} bytes, looping over short reads. */
    private byte[] readFully(int length) throws IOException {
        byte[] data = new byte[length];
        int offset = 0;
        while (offset < length) {
            int count = this.inputStream.read(data, offset, length - offset);
            if (count == -1) {
                throw new EOFException("WebSocket connection closed mid-frame");
            }
            offset += count;
        }
        return data;
    }

    /** XORs {@code data} with the repeating 4-byte {@code mask}; applying it twice is identity. */
    private static byte[] mask(byte[] data, byte[] mask) {
        byte[] result = new byte[data.length];
        for (int i = 0; i < data.length; i++) {
            result[i] = (byte) (data[i] ^ mask[i % 4]);
        }
        return result;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment