Skip to content

Instantly share code, notes, and snippets.

@dalbothek
Last active December 19, 2015 23:28
Show Gist options
  • Save dalbothek/6034575 to your computer and use it in GitHub Desktop.
Save dalbothek/6034575 to your computer and use it in GitHub Desktop.
# -*- coding: utf-8 -*-
# This source file is part of mc4p,
# the Minecraft Portable Protocol-Parsing Proxy.
#
# Copyright (C) 2011 Matthew J. McGill, Simon Marti
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License v2 as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
import re
import struct
import inspect
import collections
class Protocol(dict):
    """Collection of multiple protocol versions, keyed by version number.

    Item lookup is tolerant: asking for a version that has not been
    registered returns the closest registered *lower* version instead.
    """

    def version(self, version):
        """Return a new ProtocolVersion for `version` bound to this protocol.

        The returned object adds itself to this collection when it is used
        as a context manager and its ``with`` block exits.
        """
        return ProtocolVersion(version, self)

    def __getitem__(self, version):
        """Return the entry for `version` or for the nearest lower version.

        Raises:
            TypeError: if `version` is not an integer.
            KeyError: if neither `version` nor any lower version down to 0
                is registered. The key reported is the requested version.
        """
        # Explicit check instead of `assert`, which is stripped under -O.
        if not isinstance(version, int):
            raise TypeError(
                "Protocol version must be an int, got %r" % (version,)
            )
        candidate = version
        while candidate not in self and candidate > 0:
            candidate -= 1
        if candidate not in self:
            # Report the version the caller asked for, not the value the
            # fallback search happened to stop at.
            raise KeyError(version)
        return super(Protocol, self).__getitem__(candidate)

    def _register_version(self, protocol_version):
        """Store `protocol_version` under its own version number."""
        self[protocol_version.version] = protocol_version

    def __str__(self):
        """Describe every registered version, ordered by version number."""
        return "\n\n".join(
            str(super(Protocol, self).__getitem__(number))
            for number in sorted(self)
        )
class ProtocolVersion(object):
    """The set of messages that make up one version of the protocol.

    Intended to be used as a context manager: every Message subclass that
    is visible in the local namespace of the code containing the ``with``
    statement is collected when the block exits.
    """

    def __init__(self, version, protocol=None):
        """Create an empty version.

        Args:
            version: the version number this object describes.
            protocol: optional Protocol collection to register with once
                the ``with`` block exits.
        """
        self.version = version
        # One slot per possible unsigned-byte message id.
        self.messages = [None] * 256
        self.protocol = protocol

    @property
    def client_messages(self):
        """Registered message classes that derive from ClientMessage."""
        return [message for message in self.messages
                if message is not None and issubclass(message, ClientMessage)]

    @property
    def server_messages(self):
        """Registered message classes that derive from ServerMessage."""
        return [message for message in self.messages
                if message is not None and issubclass(message, ServerMessage)]

    def parse_message(self, stream):
        """Read one message from `stream`.

        An unsigned byte is read as the message id and the message class
        registered for that id is instantiated with the stream.

        Raises:
            ProtocolVersion.UnsupportedPacketException: if no message class
                is registered for the id.
        """
        message_id = UnsignedByte.parse(stream)
        message_class = self.messages[message_id]
        if message_class is None:
            raise self.UnsupportedPacketException(message_id)
        return message_class(stream)

    def __enter__(self):
        """Return this version so ``with ... as v`` binds something useful."""
        return self

    def __exit__(self, *args):
        """Captures all defined messages"""
        # NOTE: the frame lookup must stay directly inside __exit__ so that
        # f_back is the frame executing the `with` statement.
        potential_messages = inspect.currentframe().f_back.f_locals
        # ^ This is the part where you give up your firstborn son to Satan
        for message in list(potential_messages.values()):
            if (inspect.isclass(message) and
                    issubclass(message, Message) and
                    message not in (Message, ServerMessage, ClientMessage)):
                message._do_magic()
                self.messages[message.id] = message
        if self.protocol is not None:
            self.protocol._register_version(self)

    def __str__(self):
        """List the client messages followed by the server messages."""
        return "\n".join((
            "Protocol version %s - Client" % self.version,
            "-------------------",
            "\n\n".join(msg._str() for msg in self.client_messages),
            "",
            "Protocol version %s - Server" % self.version,
            "-------------------",
            "\n\n".join(msg._str() for msg in self.server_messages)
        ))

    class UnsupportedPacketException(Exception):
        """Raised when a message id has no registered message class.

        The offending id is available as `message_id`.
        """

        def __init__(self, message_id):
            super(ProtocolVersion.UnsupportedPacketException, self).__init__(
                "Unsupported packet id 0x%x" % message_id
            )
            self.message_id = message_id
class Message(object):
    """Base class of all protocol messages.

    Subclasses declare their payload as MessageField class attributes and
    set `id` to their message id. `_do_magic` must have been called on the
    subclass (ProtocolVersion.__exit__ does this) before instances are
    created, because it builds the `_fields` mapping used everywhere else.
    """
    # Matches any character followed by a capital letter; used to turn a
    # CamelCase class name into space separated words.
    _NAME_PATTERN = re.compile("(.)([A-Z])")
    id = None

    def __init__(self, stream=None, **kwargs):
        """Build a message either from a stream or from keyword values.

        Args:
            stream: object to parse the fields from, in declaration order.
            **kwargs: field values by name, used when no stream is given.
                Fields that are not mentioned are set to None.

        Raises:
            TypeError: if both a stream and keyword values are supplied.
        """
        # Compare against None rather than relying on truthiness: a stream
        # object that defines __len__/__nonzero__ may be falsy and would
        # otherwise be treated as "no stream", leaving every field None.
        from_stream = stream is not None
        if from_stream and kwargs:
            raise TypeError("Unexpected argument combination")
        for name, field in self._fields.items():
            if from_stream:
                setattr(self, name, field.parse(stream, self))
            else:
                setattr(self, name, kwargs.get(name))

    def emit(self):
        """Serialize the message: its id as one unsigned byte, then fields.

        Every field gets a `prepare` call before anything is emitted.
        """
        fields = list(self._fields.items())
        for name, field in fields:
            field.prepare(getattr(self, name), self)
        return (struct.pack(">B", self.id) +
                "".join(field.emit(getattr(self, name), self)
                        for name, field in fields))

    @classmethod
    def _do_magic(cls):
        """Derive the display name and the ordered field mapping of `cls`."""
        cls._name = cls._NAME_PATTERN.sub(
            lambda g: "%s %s" % (g.group(1), g.group(2)), cls.__name__
        )
        # Only attributes defined directly on the class are considered;
        # sorting by creation counter restores the declaration order.
        cls._fields = collections.OrderedDict(sorted(
            ((name, field) for name, field in cls.__dict__.items()
             if isinstance(field, MessageField)),
            key=lambda i: i[1]._order_id
        ))

    @classmethod
    def _str(cls):
        """Return a description: id and name, then one line per field."""
        if cls._fields:
            fields = "\n".join(
                " %s (%s)" % (name, field)
                for name, field in cls._fields.items()
            )
        else:
            fields = " -- empty --"
        return "\n".join((
            "0x%02x %s" % (cls.id, cls._name),
            fields
        ))
class ClientMessage(Message):
    """Marker base class for messages sent from the client to the server."""
class ServerMessage(Message):
    """Marker base class for messages sent from the server to the client."""
class MessageField(object):
    """Base class of everything that can be declared inside a Message.

    Each instance receives an increasing `_order_id` so that the
    declaration order of fields can be reconstructed later.

    Several helpers accept a "subfield" specification, which is one of:
      * a MessageField instance - parsed from / emitted to the stream,
      * a string - the name of another attribute on the message that
        holds the value (nothing is read or written),
      * a dict mapping names to subfield specifications (nested).
    """
    _NEXT_ID = 1
    # Same types that `basestring` covers on Python 2, spelled so that the
    # class body does not depend on a Python 2 only builtin.
    _STRING_TYPES = (str, type(u""))

    def __init__(self):
        self._order_id = MessageField._NEXT_ID
        MessageField._NEXT_ID += 1

    @classmethod
    def parse(cls, stream, message):
        """Read the field's value from `stream`; the base field has none."""
        return None

    @classmethod
    def prepare(cls, value, message):
        """Used to set stray length fields"""
        pass

    @classmethod
    def emit(cls, value, message):
        """Serialize `value`; the base field emits nothing."""
        return ""

    @classmethod
    def _parse_subfield(cls, field, stream, message):
        """Obtain the value described by the subfield specification `field`.

        Raises:
            NotImplementedError: for an unsupported specification type.
        """
        if isinstance(field, MessageField):
            return field.parse(stream, message)
        elif isinstance(field, cls._STRING_TYPES):
            return getattr(message, field)
        elif isinstance(field, dict):
            # Iterate over items (not bare keys) and let the classmethod
            # binding supply `cls` on the recursive call.
            return collections.OrderedDict(
                (key, cls._parse_subfield(subfield, stream, message))
                for key, subfield in field.items()
            )
        else:
            raise NotImplementedError

    @classmethod
    def _emit_subfield(cls, field, value, message):
        """Serialize `value` according to the subfield specification `field`.

        Raises:
            NotImplementedError: for an unsupported specification type.
        """
        if isinstance(field, MessageField):
            return field.emit(value, message)
        elif isinstance(field, cls._STRING_TYPES):
            # The value lives in another field of the message, which is
            # responsible for emitting it.
            return ""
        elif isinstance(field, dict):
            return "".join(
                cls._emit_subfield(subfield, value[name], message)
                for name, subfield in field.items()
            )
        else:
            raise NotImplementedError

    @classmethod
    def _set_subfield(cls, field, value, message):
        """Store `value` on `message` when `field` names an attribute."""
        if isinstance(field, cls._STRING_TYPES):
            setattr(message, field, value)

    def __repr__(self):
        return self.__class__.__name__
def simple_type_field(name, format):
    """Create a MessageField class for a single fixed-size struct value.

    Args:
        name: class name given to the generated field type.
        format: struct format code; it is interpreted big-endian.

    Returns:
        A MessageField subclass whose `parse` reads exactly as many bytes
        as the format requires and whose `emit` packs one value.
    """
    codec = struct.Struct(">" + format)

    class SimpleType(MessageField):
        @classmethod
        def parse(cls, stream, message=None):
            raw = stream.read(codec.size)
            return codec.unpack(raw)[0]

        @classmethod
        def emit(cls, value, message=None):
            return codec.pack(value)

    SimpleType.__name__ = name
    return SimpleType
class Conditional(MessageField):
    """Field that is only present when a predicate on the message holds.

    `condition` is called with the message being processed; `field` is the
    subfield specification used when it returns a true value.
    """

    def __init__(self, field, condition):
        super(Conditional, self).__init__()
        self.condition = condition
        self._field = field

    def parse(self, stream, message):
        """Parse the wrapped field, or return None if the condition fails."""
        if self.condition(message):
            return self._parse_subfield(self._field, stream, message)
        return None

    def emit(self, value, message):
        """Emit the wrapped field, or nothing if the condition fails."""
        if self.condition(message):
            return self._emit_subfield(self._field, value, message)
        return ""
class List(MessageField):
    """Sequence of identically structured entries with a size subfield.

    `size` and `field` are both subfield specifications: the former yields
    the number of entries, the latter describes a single entry.
    """

    def __init__(self, field, size):
        super(List, self).__init__()
        self._field = field
        self._size = size

    def parse(self, stream, message=None):
        """Resolve the entry count, then parse that many entries."""
        count = self._parse_subfield(self._size, stream, message)
        entries = []
        for _ in range(count):
            entries.append(
                self._parse_subfield(self._field, stream, message)
            )
        return entries

    def emit(self, value, message=None):
        """Emit the size subfield for len(value), then every entry."""
        header = self._emit_subfield(self._size, len(value), message)
        body = "".join(
            self._emit_subfield(self._field, entry, message)
            for entry in value
        )
        return header + body
class Dict(MessageField):
    """Placeholder field type.

    The constructor accepts any arguments and discards them; parsing and
    emitting fall back to the MessageField defaults.
    """

    def __init__(self, *_args, **_kwargs):
        # Arguments are accepted but not stored.
        super(Dict, self).__init__()
# Fixed-size primitive field types. The second element of each pair is the
# struct format code handed to simple_type_field.
Byte, UnsignedByte, Short, Int, Float, Double, Long = tuple(
    simple_type_field(type_name, code)
    for type_name, code in (
        ("Byte", "b"),
        ("UnsignedByte", "B"),
        ("Short", "h"),
        ("Int", "i"),
        ("Float", "f"),
        ("Double", "d"),
        ("Long", "q"),
    )
)
class Bool(Byte):
    """Boolean stored as a single signed byte."""

    @classmethod
    def parse(cls, stream, message=None):
        """Read one byte and return True only if it equals 1.

        `message` is optional, matching the other primitive field types,
        so that ``Bool.parse(stream)`` works like ``Byte.parse(stream)``.
        """
        return super(Bool, cls).parse(stream) == 1
class String(MessageField):
    """Text encoded as UTF-16 (big-endian) behind a Short length prefix.

    The prefix counts 2-byte units, not bytes: `parse` reads twice as many
    bytes as the prefix states.
    """

    @classmethod
    def parse(cls, stream, message=None):
        """Read the length prefix and decode that many 2-byte units."""
        length = Short.parse(stream)
        return stream.read(2 * length).decode("utf-16-be")

    @classmethod
    def emit(cls, value, message=None):
        """Encode `value` and prefix it with its length in 2-byte units."""
        encoded = value.encode("utf-16-be")
        # Derive the prefix from the encoded data so that it always agrees
        # with what `parse` reads back. len(value) differs from this for
        # characters encoded as surrogate pairs when the interpreter
        # counts them as one character.
        return Short.emit(len(encoded) // 2) + encoded
class Data(MessageField):
    """Raw block of data whose length is given by a size subfield."""

    def __init__(self, size):
        super(Data, self).__init__()
        self._size = size

    def parse(self, stream, message=None):
        """Resolve the size subfield and read that much from the stream."""
        length = self._parse_subfield(self._size, stream, message)
        return stream.read(length)

    def emit(self, value, message=None):
        """Emit the size subfield for len(value) followed by `value`."""
        prefix = self._emit_subfield(self._size, len(value), message)
        return prefix + value
class ItemStack(MessageField):
    """Item stack: Short item id, Byte count, Short uses, sized NBT blob.

    An item id of -1 stands for "no item"; it is represented as None and
    nothing follows it on the wire.
    """

    @classmethod
    def parse(cls, stream, message=None):
        """Return a dict with 'item_id', 'count', 'uses' and 'nbt_data'.

        Returns None when the item id is -1.
        """
        item_id = Short.parse(stream)
        if item_id == -1:
            return None
        count = Byte.parse(stream)
        uses = Short.parse(stream)
        nbt_length = Short.parse(stream)
        # The length is a signed Short. A negative value must never reach
        # stream.read(), where file-like objects treat it as "read until
        # EOF". It is mapped to None here.
        # NOTE(review): presumably the protocol uses -1 to mean "no NBT
        # data" - confirm against the protocol specification.
        nbt_data = stream.read(nbt_length) if nbt_length >= 0 else None
        return {
            "item_id": item_id,
            "count": count,
            "uses": uses,
            'nbt_data': nbt_data
        }

    @classmethod
    def emit(cls, value, message=None):
        """Serialize a dict as produced by `parse` (or None for no item).

        Raises:
            KeyError: if a required key is missing from `value`.
        """
        if value is None:
            return Short.emit(-1)
        # `parse` stores the blob under 'nbt_data'. The misspelled
        # 'nb_data' key that this method used to read is still accepted
        # for callers that built their dicts to match it.
        if 'nbt_data' in value:
            nbt_data = value['nbt_data']
        elif 'nb_data' in value:
            nbt_data = value['nb_data']
        else:
            raise KeyError('nbt_data')
        if nbt_data is None:
            # Mirror of the negative-length case handled in `parse`.
            nbt_part = Short.emit(-1)
        else:
            nbt_part = Short.emit(len(nbt_data)) + nbt_data
        return "".join((
            Short.emit(value['item_id']),
            Byte.emit(value['count']),
            Short.emit(value['uses']),
            nbt_part
        ))
class Metadata(MessageField):
    """List of typed entries, each introduced by a key byte.

    The low five bits of the key byte hold the entry index and the high
    three bits select the value type from FIELD_TYPES. A key byte of 127
    terminates the list.
    """
    # Position in this list == type number stored in the key byte. A type
    # number beyond the end of the list raises IndexError.
    FIELD_TYPES = [
        Byte,
        Short,
        Int,
        Float,
        String,
        ItemStack
    ]
    # Key byte marking the end of the entry list.
    _TERMINATOR = 127

    @classmethod
    def _key_generator(cls, stream):
        """Yield key bytes from `stream` until the terminator is read."""
        while True:
            key = UnsignedByte.parse(stream)
            if key == cls._TERMINATOR:
                return
            yield key

    @classmethod
    def parse(cls, stream, message=None):
        """Return a list of dicts with 'index', 'type' and 'value'."""
        return [{
            'index': key & 0x1f,
            'type': key >> 5,
            'value': cls.FIELD_TYPES[key >> 5].parse(stream)
        } for key in cls._key_generator(stream)]

    @classmethod
    def emit(cls, value, message=None):
        """Serialize entries as produced by `parse`, plus the terminator."""
        entries = "".join(
            UnsignedByte.emit(item['index'] | item['type'] << 5) +
            cls.FIELD_TYPES[item['type']].emit(item['value'])
            for item in value
        )
        # `parse` only stops at the terminator byte, so it has to be
        # written here for the output to be readable again.
        return entries + UnsignedByte.emit(cls._TERMINATOR)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment