Skip to content

Instantly share code, notes, and snippets.

@ks888
Last active March 8, 2018 10:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ks888/33aecb94b84873743dfcc5446755272c to your computer and use it in GitHub Desktop.
Save ks888/33aecb94b84873743dfcc5446755272c to your computer and use it in GitHub Desktop.
ping command in golang, just for fun
package main
import (
"bytes"
"fmt"
"log"
"net"
"os"
"time"
)
// walker is implemented by message structs that can enumerate their own
// fields. walk calls visit once per field, in wire order, passing a pointer
// for fixed-width integer fields (*uint8, *uint16, *uint32) and the slice
// itself for variable-length []byte fields. Driving everything from one
// visitor lets the same field list serve unpacking, packing and printing.
type walker interface {
	walk(visit func(field interface{}))
}
// unpackStruct decodes rawMsg into the fields that data.walk enumerates, in
// visit order. Integer fields (*uint8, *uint16, *uint32) are read big-endian;
// a []byte field receives as many bytes as are available, up to its length.
// Fields of any other type are logged and skipped without consuming input.
//
// It returns the number of bytes consumed. If rawMsg runs out in the middle
// of an integer field, decoding stops there: that field and every later one
// are left untouched, the truncation is logged once, and the offset reached
// so far is returned. Previously a short buffer caused an index-out-of-range
// panic, so a single truncated packet could crash the program.
// TODO: report truncation to callers as an error.
func unpackStruct(data walker, rawMsg []byte) int {
	offset := 0
	truncated := false

	// available reports whether n more bytes can be read for field at the
	// current offset; the first shortfall is logged and ends decoding.
	available := func(n int, field interface{}) bool {
		if len(rawMsg)-offset >= n {
			return true
		}
		truncated = true
		log.Printf("truncated message: %T at offset %d needs %d bytes, %d left", field, offset, n, len(rawMsg)-offset)
		return false
	}

	data.walk(func(field interface{}) {
		if truncated {
			return
		}
		switch v := field.(type) {
		case *uint8:
			if available(1, v) {
				*v = rawMsg[offset]
				offset++
			}
		case *uint16:
			if available(2, v) {
				*v = uint16(rawMsg[offset])<<8 | uint16(rawMsg[offset+1])
				offset += 2
			}
		case *uint32:
			if available(4, v) {
				b := rawMsg[offset : offset+4]
				*v = uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
				offset += 4
			}
		case []byte:
			// offset never exceeds len(rawMsg), so this slice is always valid.
			offset += copy(v, rawMsg[offset:])
		default:
			log.Printf("unknown field type: %T", v)
		}
	})
	return offset
}
// packStruct serializes the fields enumerated by data.walk, in visit order,
// into a new byte slice. Integer fields are written big-endian and []byte
// fields are appended verbatim; fields of any other type are logged and
// skipped.
// TODO: error handling
func packStruct(data walker) []byte {
	var out bytes.Buffer
	data.walk(func(field interface{}) {
		switch v := field.(type) {
		case *uint8:
			out.WriteByte(*v)
		case *uint16:
			out.Write([]byte{byte(*v >> 8), byte(*v)})
		case *uint32:
			out.Write([]byte{byte(*v >> 24), byte(*v >> 16), byte(*v >> 8), byte(*v)})
		case []byte:
			out.Write(v)
		default:
			log.Printf("unknown field type: %T", v)
		}
	})
	return out.Bytes()
}
// rawIPv4Header mirrors the IPv4 header exactly as it appears on the wire.
// Several fields still pack more than one protocol value together;
// IPv4Header.Unpack splits them apart.
type rawIPv4Header struct {
	versionAndIHL    uint8  // version in the high nibble, IHL (header length in 32-bit words) in the low nibble
	dscpAndECN       uint8  // DSCP and ECN, still packed together
	totalLength      uint16
	identification   uint16
	fragmentSettings uint16 // DF flag at bit 14, MF flag at bit 13, fragment offset in the low 13 bits
	timeToLive       uint8
	protocol         uint8
	headerChecksum   uint16
	srcAddress       uint32
	dstAddress       uint32
	options          []byte // IHL*4 - 20 bytes; (re)allocated by walk
}
// unpack decodes the wire-format header at the start of rawMsg into hdr and
// returns how many bytes were consumed (normally 20 plus any IHL-declared
// options).
func (hdr *rawIPv4Header) unpack(rawMsg []byte) int {
	consumed := unpackStruct(hdr, rawMsg)
	return consumed
}
// walk visits the header fields in wire order. The options field has
// variable length derived from the IHL nibble of versionAndIHL, so walk reads
// versionAndIHL only after the fixed fields have been visited (during
// unpacking, that is when it has been filled in). It then allocates a fresh
// options slice of IHL*4-20 bytes and visits it.
//
// An IHL below the 5-word minimum is malformed and now yields an empty
// options slice. Previously the uint8 arithmetic underflowed; an IHL of 4,
// for example, produced a 252-byte options slice that swallowed the payload.
func (hdr *rawIPv4Header) walk(f func(field interface{})) {
	f(&hdr.versionAndIHL)
	f(&hdr.dscpAndECN)
	f(&hdr.totalLength)
	f(&hdr.identification)
	f(&hdr.fragmentSettings)
	f(&hdr.timeToLive)
	f(&hdr.protocol)
	f(&hdr.headerChecksum)
	f(&hdr.srcAddress)
	f(&hdr.dstAddress)

	const minHeaderLen = 20 // bytes covered by the fixed fields above
	optionsLen := int(hdr.versionAndIHL&0xf)*4 - minHeaderLen
	if optionsLen < 0 {
		optionsLen = 0
	}
	hdr.options = make([]byte, optionsLen)
	f(hdr.options)
}
// IPv4Header is a decoded IPv4 header: the packed wire fields of
// rawIPv4Header split into individual values.
type IPv4Header struct {
	version uint8
	// headerLenInBytes holds the raw IHL nibble, i.e. the header length in
	// 32-bit words (5 for a header without options), not a byte count as the
	// name suggests.
	headerLenInBytes uint8
	totalLen         uint16
	identification   uint16
	dontFragmentFlag bool   // bit 14 of the flags/fragment field
	moreFragmentFlag bool   // bit 13 of the flags/fragment field
	fragmentOffset   uint16 // low 13 bits of the flags/fragment field
	timeToLive       uint8
	protocol         uint8
	headerChecksum   uint16
	srcAddress       uint32 // first octet in the most significant byte
	dstAddress       uint32 // first octet in the most significant byte
	options          []byte
}
// Unpack decodes the IPv4 header at the start of rawMsg into hdr, splitting
// the packed version/IHL and flags/fragment-offset fields into separate
// values. Every field of hdr is overwritten. It returns the number of bytes
// consumed.
//
// NOTE(review): headerLenInBytes receives the raw IHL nibble, i.e. the header
// length in 32-bit words (5 for a header without options), not in bytes.
// Existing tests expect this value.
func (hdr *IPv4Header) Unpack(rawMsg []byte) int {
	var raw rawIPv4Header
	consumed := raw.unpack(rawMsg)

	frag := raw.fragmentSettings
	*hdr = IPv4Header{
		version:          (raw.versionAndIHL >> 4) & 0xf,
		headerLenInBytes: raw.versionAndIHL & 0xf,
		totalLen:         raw.totalLength,
		identification:   raw.identification,
		dontFragmentFlag: frag&(1<<14) != 0,
		moreFragmentFlag: frag&(1<<13) != 0,
		fragmentOffset:   frag & 0x1FFF,
		timeToLive:       raw.timeToLive,
		protocol:         raw.protocol,
		headerChecksum:   raw.headerChecksum,
		srcAddress:       raw.srcAddress,
		dstAddress:       raw.dstAddress,
		options:          raw.options,
	}
	return consumed
}
// SourceIPAddress formats the source address in dotted-decimal notation
// (e.g. "192.0.2.1"); the most significant byte of srcAddress is the first
// octet.
func (hdr *IPv4Header) SourceIPAddress() string {
	a := hdr.srcAddress
	return fmt.Sprintf("%d.%d.%d.%d", uint8(a>>24), uint8(a>>16), uint8(a>>8), uint8(a))
}
// ICMPMessage is a generic ICMP message: an 8-byte header (type, code,
// checksum, rest-of-header word) followed by an arbitrary payload.
type ICMPMessage struct {
	icmpType     uint8
	icmpCode     uint8
	checksum     uint16 // recomputed by Pack before serializing
	restOfHeader uint32 // meaning depends on icmpType and icmpCode
	data         []byte // payload following the 8-byte header
}
// walk visits the message fields in wire order: the four fixed header fields
// as pointers, then the data slice, which unpackStruct fills in place and
// packStruct appends verbatim.
func (msg *ICMPMessage) walk(f func(field interface{})) {
	header := [...]interface{}{&msg.icmpType, &msg.icmpCode, &msg.checksum, &msg.restOfHeader}
	for _, field := range header {
		f(field)
	}
	f(msg.data)
}
// Unpack decodes rawMsg into msg: the 8-byte ICMP header (type, code,
// checksum, rest-of-header word) followed by everything after it as data.
// It returns the number of bytes consumed, which is len(rawMsg) on success.
//
// If rawMsg is shorter than the 8-byte header, Unpack returns 0 and leaves
// msg unchanged. Previously it called make with a negative length and
// panicked.
func (msg *ICMPMessage) Unpack(rawMsg []byte) int {
	const headerLen = 8
	if len(rawMsg) < headerLen {
		return 0
	}
	msg.data = make([]byte, len(rawMsg)-headerLen)
	return unpackStruct(msg, rawMsg)
}
// Pack recomputes msg.checksum (mutating msg) and returns the message in
// wire format: the 8-byte header followed by the data.
func (msg *ICMPMessage) Pack() []byte {
	msg.setChecksum() // must precede serialization so the header is valid
	wire := packStruct(msg)
	return wire
}
// setChecksum computes the ICMP checksum and stores it in msg.checksum. The
// checksum is the 16-bit one's complement of the one's-complement sum of the
// message taken as big-endian 16-bit words. The checksum field itself counts
// as zero, and an odd trailing data byte is padded with a zero low byte.
//
// The sum is accumulated in a uint32 and then folded until no carry remains.
// The previous version folded once using uint16 addition, which dropped the
// end-around carry whenever the fold itself overflowed (a raw sum of 0x1FFFF
// yielded 0xFFFF instead of 0xFFFE).
func (msg *ICMPMessage) setChecksum() {
	sum := uint32(msg.icmpType)<<8 | uint32(msg.icmpCode)
	sum += msg.restOfHeader >> 16
	sum += msg.restOfHeader & 0xFFFF
	// msg.checksum is deliberately not added: it is computed as if zero.

	data := msg.data
	for len(data) >= 2 {
		sum += uint32(data[0])<<8 | uint32(data[1])
		data = data[2:]
	}
	if len(data) == 1 {
		sum += uint32(data[0]) << 8
	}

	// The uint32 cannot overflow for any payload that fits in an IPv4
	// datagram (at most ~32K words of 0xFFFF).
	for sum>>16 != 0 {
		sum = sum&0xFFFF + sum>>16
	}
	msg.checksum = ^uint16(sum)
}
// ToEchoMessage reinterprets msg as an Echo / Echo Reply message: the upper
// 16 bits of restOfHeader become the identifier and the lower 16 bits the
// sequence number. The result embeds a copy of *msg; its data slice still
// shares a backing array with msg.data.
func (msg *ICMPMessage) ToEchoMessage() *ICMPEchoMessage {
	echo := &ICMPEchoMessage{ICMPMessage: *msg}
	echo.icmpID = uint16(msg.restOfHeader >> 16)
	echo.icmpSeq = uint16(msg.restOfHeader & 0xFFFF)
	return echo
}
// ICMPEchoMessage is an ICMP Echo or Echo Reply message: a generic
// ICMPMessage whose rest-of-header word is split into an identifier and a
// sequence number.
type ICMPEchoMessage struct {
	ICMPMessage
	icmpID  uint16 // upper 16 bits of restOfHeader
	icmpSeq uint16 // lower 16 bits of restOfHeader
}
// Unpack decodes rawMsg as a generic ICMP message and then splits
// restOfHeader into the echo identifier (upper 16 bits) and sequence number
// (lower 16 bits). It returns the number of bytes consumed.
func (msg *ICMPEchoMessage) Unpack(rawMsg []byte) int {
	consumed := msg.ICMPMessage.Unpack(rawMsg)
	word := msg.restOfHeader
	msg.icmpID, msg.icmpSeq = uint16(word>>16), uint16(word)
	return consumed
}
// Pack stores icmpID and icmpSeq into the rest-of-header word (ID in the
// upper 16 bits, sequence number in the lower 16) and then serializes the
// message with ICMPMessage.Pack, which also recomputes the checksum.
func (msg *ICMPEchoMessage) Pack() []byte {
	id, seq := uint32(msg.icmpID), uint32(msg.icmpSeq)
	msg.restOfHeader = id<<16 | seq
	return msg.ICMPMessage.Pack()
}
// Ping sends ICMP echo requests and matches echo replies by identifier.
type Ping struct {
	icmpID  uint16 // identifier stamped on every request; replies must carry it
	icmpSeq uint16 // sequence number for the next request
}
// NewPing returns a Ping that stamps outgoing echo requests with icmpID and
// numbers them starting from sequence 1.
func NewPing(icmpID uint16) *Ping {
	p := new(Ping)
	p.icmpID = icmpID
	p.icmpSeq = 1
	return p
}
// Receive opens a raw "ip4:icmp" listener, reads a single packet into a
// 1500-byte buffer and hands it to procEchoReplyMessage, returning that
// function's error. procEchoReplyMessage expects the packet to start with
// the IPv4 header.
//
// The listener is now closed before returning. Previously every call leaked
// a raw socket, and main calls Receive in an unbounded loop.
//
// NOTE(review): because the socket exists only for the duration of one call,
// a reply that arrives between two calls is missed. Keeping one listener for
// the lifetime of the Ping would avoid that.
func (ping *Ping) Receive() error {
	ln, err := net.ListenIP("ip4:icmp", nil)
	if err != nil {
		return err
	}
	defer ln.Close()

	rawMsg := make([]byte, 1500)
	n, err := ln.Read(rawMsg)
	if err != nil {
		return err
	}
	return ping.procEchoReplyMessage(rawMsg[:n])
}
// procEchoReplyMessage parses rawMsg as an IPv4 packet carrying an ICMP
// message. If it is an Echo Reply (type 0, code 0) whose identifier matches
// ping.icmpID, it decodes the send timestamp that buildEchoMessage stored in
// the payload and prints a ping-style line with the round-trip time.
//
// It returns an error, without printing, for packets too short to hold the
// headers, packets with an invalid IHL, non-echo-reply messages, replies
// carrying another identifier, and payloads that do not decode as a
// time.Time. The length checks are new: previously malformed network input
// reached the unpackers unvalidated.
func (ping *Ping) procEchoReplyMessage(rawMsg []byte) error {
	const (
		minIPv4HeaderLen = 20 // IHL of 5 words, no options
		icmpHeaderLen    = 8  // type, code, checksum, rest-of-header word
	)
	// The unpackers index rawMsg directly, so reject packets that cannot
	// hold the headers before decoding anything.
	if len(rawMsg) < minIPv4HeaderLen {
		return fmt.Errorf("packet too short for an IPv4 header: %d bytes", len(rawMsg))
	}
	if ihlBytes := int(rawMsg[0]&0xf) * 4; ihlBytes < minIPv4HeaderLen || ihlBytes > len(rawMsg) {
		return fmt.Errorf("invalid IPv4 header length %d bytes in a %d-byte packet", ihlBytes, len(rawMsg))
	}

	ipv4Hdr := &IPv4Header{}
	offset := ipv4Hdr.Unpack(rawMsg)
	rawICMPMsg := rawMsg[offset:]
	if len(rawICMPMsg) < icmpHeaderLen {
		return fmt.Errorf("packet too short for an ICMP header: %d bytes", len(rawICMPMsg))
	}

	msg := &ICMPMessage{}
	msg.Unpack(rawICMPMsg)
	if msg.icmpCode != 0 || msg.icmpType != 0 {
		return fmt.Errorf("not echo reply message: %d %d", msg.icmpType, msg.icmpCode)
	}
	echoMsg := msg.ToEchoMessage()
	if echoMsg.icmpID != ping.icmpID {
		return fmt.Errorf("message from unknown icmp ID (%d)", echoMsg.icmpID)
	}

	var timeInMsg time.Time
	if err := timeInMsg.GobDecode(echoMsg.data); err != nil {
		return fmt.Errorf("failed to decode the data: %v", err)
	}
	rtt := time.Since(timeInMsg)
	fmt.Printf("%d bytes from %s: seq=%d, ttl=%d, rtt=%s \n", len(rawICMPMsg), ipv4Hdr.SourceIPAddress(), echoMsg.icmpSeq, ipv4Hdr.timeToLive, rtt)
	return nil
}
// Send builds the next echo request (stamped with ping's ICMP ID and the
// current sequence number, which it then increments) and writes it to
// hostname over a fresh "ip4:icmp" connection. The connection is now closed
// before returning; previously it was never closed, leaking one socket per
// call.
// Do not call this method from multiple goroutines (not thread safe): it
// updates the sequence number without synchronization.
func (ping *Ping) Send(hostname string) error {
	conn, err := net.Dial("ip4:icmp", hostname)
	if err != nil {
		return err
	}
	defer conn.Close()

	rawMsg, err := ping.buildEchoMessage()
	if err != nil {
		return err
	}
	_, err = conn.Write(rawMsg)
	return err
}
// buildEchoMessage returns the wire form of an ICMP Echo request (type 8,
// code 0) carrying ping's identifier and current sequence number. The payload
// is the gob-encoded current time, so the reply's round-trip time can be
// computed. The sequence number is advanced for the next call.
func (ping *Ping) buildEchoMessage() ([]byte, error) {
	stamp, err := time.Now().GobEncode()
	if err != nil {
		return nil, err
	}

	var echo ICMPEchoMessage
	echo.icmpType = 0x8
	echo.icmpCode = 0x0
	echo.data = stamp
	echo.icmpID = ping.icmpID
	echo.icmpSeq = ping.icmpSeq

	ping.icmpSeq++
	return echo.Pack(), nil
}
// main pings the host named by the first command-line argument (default
// "golang.org"). A background goroutine sends one echo request per second
// while the main goroutine loops forever, receiving and reporting replies.
// The low 16 bits of the process ID serve as the ICMP identifier.
//
// NOTE(review): if Receive fails immediately (for example, without privileges
// to open a raw socket), the receive loop spins and logs without pausing.
func main() {
	target := "golang.org"
	if len(os.Args) > 1 {
		target = os.Args[1]
	}
	p := NewPing(uint16(os.Getpid()))

	sendLoop := func() {
		for {
			if err := p.Send(target); err != nil {
				log.Printf("failed to send icmp message: %v", err)
			}
			time.Sleep(time.Second)
		}
	}
	go sendLoop()

	for {
		if err := p.Receive(); err != nil {
			log.Printf("failed to receive icmp message: %v", err)
		}
	}
}
package main
import (
"reflect"
"testing"
)
// TestRawIPV4Unpack decodes a 20-byte IPv4 header without options and checks
// every raw wire field.
func TestRawIPV4Unpack(t *testing.T) {
	packet := []byte{0x45, 0x0, 0x0, 0x2b, 0x35, 0x5, 0x40, 0x0, 0x3d, 0x1, 0x95, 0x72, 0xff, 0xd9, 0x1a, 0xff, 0xff, 0xff, 0x0, 0x2}
	var hdr rawIPv4Header
	consumed := hdr.unpack(packet)

	// Each entry pairs a pass condition with the message reported on failure.
	checks := []struct {
		ok     bool
		format string
		got    interface{}
	}{
		{consumed == 20, "invalid offset: %v", consumed},
		{hdr.versionAndIHL == 0x45, "invalid versionAndIHL: %v", hdr.versionAndIHL},
		{hdr.dscpAndECN == 0, "invalid dscpAndECN: %v", hdr.dscpAndECN},
		{hdr.totalLength == 0x2b, "invalid total length: %v", hdr.totalLength},
		{hdr.identification == 0x3505, "invalid identification: %v", hdr.identification},
		{hdr.fragmentSettings == 0x4000, "invalid fragmentSettings: %v", hdr.fragmentSettings},
		{hdr.timeToLive == 0x3d, "invalid timeToLive: %v", hdr.timeToLive},
		{hdr.protocol == 0x1, "invalid protocol: %v", hdr.protocol},
		{hdr.headerChecksum == 0x9572, "invalid headerChecksum: %v", hdr.headerChecksum},
		{hdr.srcAddress == 0xffd91aff, "invalid srcAddress: %v", hdr.srcAddress},
		{hdr.dstAddress == 0xffff0002, "invalid dstAddress: %v", hdr.dstAddress},
	}
	for _, c := range checks {
		if !c.ok {
			t.Errorf(c.format, c.got)
		}
	}
}
// TestRawIPV4Unpack_WithOptions decodes a header whose IHL of 6 announces one
// 4-byte options word; the trailing 0x9a lies past the header and must not be
// consumed.
func TestRawIPV4Unpack_WithOptions(t *testing.T) {
	packet := []byte{0x46, 0x0, 0x0, 0x2b, 0x35, 0x5, 0x40, 0x0, 0x3d, 0x1, 0x95, 0x72, 0xff, 0xd9, 0x1a, 0xff, 0xff, 0xff, 0x0, 0x2, 0x12, 0x34, 0x56, 0x78, 0x9a}
	hdr := new(rawIPv4Header)
	if consumed := hdr.unpack(packet); consumed != 24 {
		t.Errorf("invalid offset: %v", consumed)
	}
	wantOptions := []byte{0x12, 0x34, 0x56, 0x78}
	if !reflect.DeepEqual(wantOptions, hdr.options) {
		t.Errorf("invalid options: %v", hdr.options)
	}
}
// TestIPV4Unpack decodes the IPv4 header at the front of a full echo-reply
// packet and checks every split-out field. headerLenInBytes is expected to
// hold the raw IHL nibble (5).
func TestIPV4Unpack(t *testing.T) {
	packet := []byte{0x45, 0x0, 0x0, 0x2b, 0x35, 0x5, 0x40, 0x0, 0x3d, 0x1, 0x95, 0x72, 0xff, 0xd9, 0x1a, 0xff, 0xff, 0xff, 0x0, 0x2, 0x0, 0x0, 0xe4, 0xd0, 0x3, 0x80, 0x0, 0x1, 0x1, 0x0, 0x0, 0x0, 0xe, 0xd2, 0x2d, 0xa9, 0x8b, 0x1d, 0x11, 0x15, 0x3e, 0x0, 0x0}
	hdr := new(IPv4Header)
	consumed := hdr.Unpack(packet)

	// Each entry pairs a pass condition with the message reported on failure.
	checks := []struct {
		ok     bool
		format string
		got    interface{}
	}{
		{consumed == 20, "invalid offset: %v", consumed},
		{hdr.version == 0x4, "invalid version: %v", hdr.version},
		{hdr.headerLenInBytes == 0x5, "invalid header length: %v", hdr.headerLenInBytes},
		{hdr.totalLen == 0x2b, "invalid total length: %v", hdr.totalLen},
		{hdr.identification == 0x3505, "invalid identification: %v", hdr.identification},
		{hdr.dontFragmentFlag, "invalid dontFragmentFlag: %v", hdr.dontFragmentFlag},
		{!hdr.moreFragmentFlag, "invalid moreFragmentFlag: %v", hdr.moreFragmentFlag},
		{hdr.fragmentOffset == 0, "invalid fragmentOffset: %v", hdr.fragmentOffset},
		{hdr.timeToLive == 0x3d, "invalid timeToLive: %v", hdr.timeToLive},
		{hdr.protocol == 0x1, "invalid protocol: %v", hdr.protocol},
		{hdr.headerChecksum == 0x9572, "invalid headerChecksum: %v", hdr.headerChecksum},
		{hdr.srcAddress == 0xffd91aff, "invalid srcAddress: %v", hdr.srcAddress},
		{hdr.dstAddress == 0xffff0002, "invalid dstAddress: %v", hdr.dstAddress},
		{len(hdr.options) == 0, "invalid options: %v", hdr.options},
	}
	for _, c := range checks {
		if !c.ok {
			t.Errorf(c.format, c.got)
		}
	}
}
// TestIPV4UnpackWithOptions checks that IPv4Header.Unpack carries the options
// word through and stops at the end of the IHL-declared header.
func TestIPV4UnpackWithOptions(t *testing.T) {
	packet := []byte{0x46, 0x0, 0x0, 0x2b, 0x35, 0x5, 0x40, 0x0, 0x3d, 0x1, 0x95, 0x72, 0xff, 0xd9, 0x1a, 0xff, 0xff, 0xff, 0x0, 0x2, 0x12, 0x34, 0x56, 0x78, 0x9a}
	var hdr IPv4Header
	if consumed := hdr.Unpack(packet); consumed != 24 {
		t.Errorf("invalid offset: %v", consumed)
	}
	wantOptions := []byte{0x12, 0x34, 0x56, 0x78}
	if !reflect.DeepEqual(hdr.options, wantOptions) {
		t.Errorf("invalid options: %v", hdr.options)
	}
}
// TestICMPMessageUnpack decodes a 23-byte echo reply (8-byte header plus a
// 15-byte payload) and checks each header field and the payload length.
func TestICMPMessageUnpack(t *testing.T) {
	packet := []byte{0x0, 0x0, 0xe4, 0xd0, 0x3, 0x80, 0x0, 0x1, 0x1, 0x0, 0x0, 0x0, 0xe, 0xd2, 0x2d, 0xa9, 0x8b, 0x1d, 0x11, 0x15, 0x3e, 0x0, 0x0}
	msg := new(ICMPMessage)
	consumed := msg.Unpack(packet)

	// Each entry pairs a pass condition with the message reported on failure.
	checks := []struct {
		ok     bool
		format string
		got    interface{}
	}{
		{consumed == 23, "invalid offset: %v", consumed},
		{msg.icmpType == 0, "invalid icmpType: %v", msg.icmpType},
		{msg.icmpCode == 0, "invalid icmpCode: %v", msg.icmpCode},
		{msg.checksum == 0xe4d0, "invalid checksum: %v", msg.checksum},
		{msg.restOfHeader == 0x03800001, "invalid restOfHeader: %v", msg.restOfHeader},
		{len(msg.data) == 15, "invalid data: %v", msg.data},
	}
	for _, c := range checks {
		if !c.ok {
			t.Errorf(c.format, c.got)
		}
	}
}
// TestICMPMessagePack checks that Pack computes the checksum (0xe4d0) and
// serializes header and payload in wire order.
func TestICMPMessagePack(t *testing.T) {
	want := []byte{0x0, 0x0, 0xe4, 0xd0, 0x3, 0x80, 0x0, 0x1, 0x1, 0x0, 0x0, 0x0, 0xe, 0xd2, 0x2d, 0xa9, 0x8b, 0x1d, 0x11, 0x15, 0x3e, 0x0, 0x0}
	msg := &ICMPMessage{
		icmpType:     0,
		icmpCode:     0,
		restOfHeader: 0x03800001,
		data:         []byte{0x1, 0x0, 0x0, 0x0, 0xe, 0xd2, 0x2d, 0xa9, 0x8b, 0x1d, 0x11, 0x15, 0x3e, 0x0, 0x0},
	}
	if got := msg.Pack(); !reflect.DeepEqual(want, got) {
		t.Errorf("invalid message: %v", got)
	}
}
// TestICMPMessageToEchoMessage checks that a decoded echo reply's
// rest-of-header word is split into identifier 0x0380 and sequence 1.
func TestICMPMessageToEchoMessage(t *testing.T) {
	packet := []byte{0x0, 0x0, 0xe4, 0xd0, 0x3, 0x80, 0x0, 0x1, 0x1, 0x0, 0x0, 0x0, 0xe, 0xd2, 0x2d, 0xa9, 0x8b, 0x1d, 0x11, 0x15, 0x3e, 0x0, 0x0}
	var msg ICMPMessage
	_ = msg.Unpack(packet) // the consumed count is not under test here
	if msg.icmpType != 0 || msg.icmpCode != 0 {
		t.Errorf("invalid icmpType or icmpCode: %v, %v", msg.icmpType, msg.icmpCode)
	}
	echo := msg.ToEchoMessage()
	if echo.icmpID != 0x0380 {
		t.Errorf("invalid icmpID: %v", echo.icmpID)
	}
	if echo.icmpSeq != 0x0001 {
		t.Errorf("invalid icmpSeq: %v", echo.icmpSeq)
	}
}
// TestICMPEchoMessageUnpack checks that ICMPEchoMessage.Unpack fills in the
// identifier and sequence number directly.
func TestICMPEchoMessageUnpack(t *testing.T) {
	packet := []byte{0x0, 0x0, 0xe4, 0xd0, 0x3, 0x80, 0x0, 0x1, 0x1, 0x0, 0x0, 0x0, 0xe, 0xd2, 0x2d, 0xa9, 0x8b, 0x1d, 0x11, 0x15, 0x3e, 0x0, 0x0}
	var echo ICMPEchoMessage
	_ = echo.Unpack(packet) // the consumed count is not under test here
	if echo.icmpID != 0x0380 {
		t.Errorf("invalid icmpID: %v", echo.icmpID)
	}
	if echo.icmpSeq != 0x0001 {
		t.Errorf("invalid icmpSeq: %v", echo.icmpSeq)
	}
}
// TestICMPEchoMessagePack checks that Pack writes icmpID/icmpSeq into the
// rest-of-header word and recomputes the checksum. The preset 0xed40 is a
// stale value that must not appear in the output.
func TestICMPEchoMessagePack(t *testing.T) {
	want := []byte{0x0, 0x0, 0xe4, 0xd0, 0x3, 0x80, 0x0, 0x1, 0x1, 0x0, 0x0, 0x0, 0xe, 0xd2, 0x2d, 0xa9, 0x8b, 0x1d, 0x11, 0x15, 0x3e, 0x0, 0x0}
	echo := &ICMPEchoMessage{icmpID: 0x0380, icmpSeq: 0x0001}
	echo.icmpType = 0
	echo.icmpCode = 0
	echo.checksum = 0xed40
	echo.data = []byte{0x1, 0x0, 0x0, 0x0, 0xe, 0xd2, 0x2d, 0xa9, 0x8b, 0x1d, 0x11, 0x15, 0x3e, 0x0, 0x0}
	if got := echo.Pack(); !reflect.DeepEqual(want, got) {
		t.Errorf("invalid message: %v", got)
	}
}
// TestProcEchoReplyMessage feeds a well-formed IPv4 echo reply whose
// identifier matches the Ping and expects it to be accepted.
func TestProcEchoReplyMessage(t *testing.T) {
	packet := []byte{0x45, 0x0, 0x0, 0x2b, 0x35, 0x5, 0x40, 0x0, 0x3d, 0x1, 0x95, 0x72, 0xff, 0xd9, 0x1a, 0xff, 0xff, 0xff, 0x0, 0x2, 0x0, 0x0, 0xe4, 0xd0, 0x3, 0x80, 0x0, 0x1, 0x1, 0x0, 0x0, 0x0, 0xe, 0xd2, 0x2d, 0xa9, 0x8b, 0x1d, 0x11, 0x15, 0x3e, 0x0, 0x0}
	p := &Ping{icmpID: 0x0380}
	if err := p.procEchoReplyMessage(packet); err != nil {
		t.Errorf("failed to proc echo reply: %v", err)
	}
}
// TestProcEchoReplyMessage_NotEchoReply uses ICMP type 8 (echo request), which
// must be rejected.
func TestProcEchoReplyMessage_NotEchoReply(t *testing.T) {
	packet := []byte{0x45, 0x0, 0x0, 0x2b, 0x35, 0x5, 0x40, 0x0, 0x3d, 0x1, 0x95, 0x72, 0xff, 0xd9, 0x1a, 0xff, 0xff, 0xff, 0x0, 0x2, 0x8, 0x0, 0xe4, 0xd0, 0x3, 0x80, 0x0, 0x1, 0x1, 0x0, 0x0, 0x0, 0xe, 0xd2, 0x2d, 0xa9, 0x8b, 0x1d, 0x11, 0x15, 0x3e, 0x0, 0x0}
	p := &Ping{icmpID: 0x0380}
	if err := p.procEchoReplyMessage(packet); err == nil {
		t.Error("should not process non echo reply message")
	}
}
// TestProcEchoReplyMessage_NotExpectedICMPID sends a valid reply with ID
// 0x0380 to a Ping whose ID is zero; it must be rejected.
func TestProcEchoReplyMessage_NotExpectedICMPID(t *testing.T) {
	packet := []byte{0x45, 0x0, 0x0, 0x2b, 0x35, 0x5, 0x40, 0x0, 0x3d, 0x1, 0x95, 0x72, 0xff, 0xd9, 0x1a, 0xff, 0xff, 0xff, 0x0, 0x2, 0x0, 0x0, 0xe4, 0xd0, 0x3, 0x80, 0x0, 0x1, 0x1, 0x0, 0x0, 0x0, 0xe, 0xd2, 0x2d, 0xa9, 0x8b, 0x1d, 0x11, 0x15, 0x3e, 0x0, 0x0}
	p := new(Ping)
	if err := p.procEchoReplyMessage(packet); err == nil {
		t.Error("should not process unexpected icmp id")
	}
}
// TestProcEchoReplyMessage_InvalidData uses a packet one byte shorter than
// the valid reply, truncating the encoded timestamp; decoding must fail.
func TestProcEchoReplyMessage_InvalidData(t *testing.T) {
	packet := []byte{0x45, 0x0, 0x0, 0x2b, 0x35, 0x5, 0x40, 0x0, 0x3d, 0x1, 0x95, 0x72, 0xff, 0xd9, 0x1a, 0xff, 0xff, 0xff, 0x0, 0x2, 0x0, 0x0, 0xe4, 0xd0, 0x3, 0x80, 0x0, 0x1, 0x1, 0x0, 0x0, 0x0, 0xe, 0xd2, 0x2d, 0xa9, 0x8b, 0x1d, 0x11, 0x15, 0x3e, 0x0}
	p := &Ping{icmpID: 0x0380}
	if err := p.procEchoReplyMessage(packet); err == nil {
		t.Errorf("should return error if data is invalid: %v", err)
	}
}
// TestBuildEchoMessage checks that building an echo request succeeds.
func TestBuildEchoMessage(t *testing.T) {
	p := &Ping{icmpID: 0x0380}
	if _, err := p.buildEchoMessage(); err != nil {
		t.Errorf("failed to build echo message: %v", err)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment