Skip to content

Instantly share code, notes, and snippets.

@SuperQ
Last active February 28, 2022 17:28
Show Gist options
  • Save SuperQ/6a74834ff4c4855f1446c2ffeb117c84 to your computer and use it in GitHub Desktop.
Save SuperQ/6a74834ff4c4855f1446c2ffeb117c84 to your computer and use it in GitHub Desktop.
A first pass at making a packet tracking object.
package ping
import (
"sync"
"time"
"github.com/google/uuid"
)
// PacketTracker keeps track of in-flight packets, grouped by the tracking
// UUID they were registered under and identified by a sequence number.
type PacketTracker struct {
// currentUUID is the UUID under which AddPacket registers new packets; it
// is what CurrentUUID returns. AddPacket replaces it with a fresh UUID once
// nextSequence exceeds 65535.
currentUUID uuid.UUID
// packets holds one PacketSequence per tracking UUID.
packets map[uuid.UUID]PacketSequence
// sequence is the sequence number most recently handed out by AddPacket.
sequence int
// nextSequence is the sequence number the next AddPacket call will use.
nextSequence int
// timeout is the value passed to NewPacketTracker.
// NOTE(review): only referenced from commented-out code in AddPacket;
// presumably intended as a per-packet timeout - confirm.
timeout time.Duration
// timeoutCh is neither initialized nor used by the code in this file.
// NOTE(review): presumably meant to report timed-out packets - confirm.
timeoutCh chan *inFlightPacket
// mutex is taken by AddPacket, DeletePacket, HasPacket and CurrentUUID
// around their access to the fields above.
mutex sync.RWMutex
}
// PacketSequence is the set of in-flight packets registered under a single
// tracking UUID.
type PacketSequence struct {
// packets maps a sequence number to its in-flight packet record.
// NOTE(review): the key type is uint while the PacketTracker methods take
// and return int sequence numbers - confirm which type is intended.
packets map[uint]inFlightPacket
}
// inFlightPacket is the per-packet record stored by the tracker.
type inFlightPacket struct {
// timeoutTimer is stopped by DeletePacket when the packet is removed.
// NOTE(review): nothing in this file assigns it yet (the only assignment
// is commented out in AddPacket), so it is nil for every record.
timeoutTimer *time.Timer
}
// NewPacketTracker returns a PacketTracker whose packet timeout is t.
//
// The tracker starts with a freshly generated tracking UUID that already has
// an (empty) sequence map, so CurrentUUID and HasUUID report it immediately
// and AddPacket has a map to write into.
func NewPacketTracker(t time.Duration) *PacketTracker {
	firstUUID := uuid.New()
	return &PacketTracker{
		currentUUID: firstUUID,
		packets: map[uuid.UUID]PacketSequence{
			firstUUID: {packets: make(map[uint]inFlightPacket)},
		},
		timeout: t,
	}
}
// AddPacket registers a new in-flight packet under the current tracking UUID
// and returns the sequence number assigned to it.
//
// Sequence numbers run from 0 to 65535. Once that range is exhausted a new
// tracking UUID is generated and numbering restarts at 0, so a
// (UUID, sequence) pair is never reused.
func (t *PacketTracker) AddPacket() int {
	t.mutex.Lock()
	defer t.mutex.Unlock()

	// Writing to a nil map panics; tolerate a tracker that was not built by
	// NewPacketTracker.
	if t.packets == nil {
		t.packets = make(map[uuid.UUID]PacketSequence)
	}

	if t.nextSequence > 65535 {
		t.currentUUID = uuid.New()
		t.nextSequence = 0
	}

	// Make sure the current UUID has a usable sequence map. This covers both
	// a freshly rotated UUID and a UUID that was never registered.
	seqs, ok := t.packets[t.currentUUID]
	if !ok || seqs.packets == nil {
		seqs = PacketSequence{packets: make(map[uint]inFlightPacket)}
		t.packets[t.currentUUID] = seqs
	}

	t.sequence = t.nextSequence
	// nextSequence is always within [0, 65535] here, so the conversion to the
	// map's uint key type is lossless.
	seqs.packets[uint(t.sequence)] = inFlightPacket{}
	// TODO: when t.timeout > 0, arm a per-packet timeoutTimer here.

	t.nextSequence++
	return t.sequence
}
// DeletePacket removes the packet identified by tracking UUID u and sequence
// number seq from the tracker, stopping its timeout timer if one was armed.
// It is a no-op when the packet is not being tracked.
func (t *PacketTracker) DeletePacket(u uuid.UUID, seq int) {
	t.mutex.Lock()
	defer t.mutex.Unlock()

	// Negative sequence numbers are never issued; reject them before the
	// conversion to the map's uint key type wraps them around.
	if seq < 0 {
		return
	}
	key := uint(seq)

	// Reading a missing UUID yields a PacketSequence with a nil map, on which
	// lookups simply report "not found".
	inflight := t.packets[u].packets
	pkt, ok := inflight[key]
	if !ok {
		return
	}
	if pkt.timeoutTimer != nil {
		pkt.timeoutTimer.Stop()
	}
	delete(inflight, key)
}
// hasPacket reports whether the packet identified by tracking UUID u and
// sequence number seq is currently tracked. It does no locking of its own;
// callers must hold t.mutex.
func (t *PacketTracker) hasPacket(u uuid.UUID, seq int) bool {
	// Negative sequence numbers are never issued; reject them before the
	// conversion to the map's uint key type wraps them around.
	if seq < 0 {
		return false
	}
	// A missing UUID yields a PacketSequence with a nil map, and a lookup in
	// a nil map reports "not found".
	_, inflight := t.packets[u].packets[uint(seq)]
	return inflight
}
// HasPacket reports whether the tracker is currently tracking the packet
// identified by tracking UUID u and sequence number seq.
func (t *PacketTracker) HasPacket(u uuid.UUID, seq int) bool {
	t.mutex.RLock()
	// A read lock must be released with RUnlock; calling Unlock on a mutex
	// that is only read-locked is a fatal runtime error.
	defer t.mutex.RUnlock()
	return t.hasPacket(u, seq)
}
// HasUUID reports whether u is a tracking UUID known to this tracker.
func (t *PacketTracker) HasUUID(u uuid.UUID) bool {
	// AddPacket writes to t.packets under the mutex, so this read needs the
	// read lock too; an unguarded concurrent map read is a data race.
	t.mutex.RLock()
	defer t.mutex.RUnlock()
	_, hasUUID := t.packets[u]
	return hasUUID
}
// CurrentUUID returns the tracking UUID under which AddPacket currently
// registers new packets.
func (t *PacketTracker) CurrentUUID() uuid.UUID {
	t.mutex.RLock()
	// A read lock must be released with RUnlock; calling Unlock on a mutex
	// that is only read-locked is a fatal runtime error.
	defer t.mutex.RUnlock()
	return t.currentUUID
}
diff --git a/ping.go b/ping.go
index e1c09ee..0abb789 100644
--- a/ping.go
+++ b/ping.go
@@ -87,9 +87,6 @@ var (
// New returns a new Pinger struct pointer.
func New(addr string) *Pinger {
r := rand.New(rand.NewSource(getSeed()))
- firstUUID := uuid.New()
- var firstSequence = map[uuid.UUID]map[int]struct{}{}
- firstSequence[firstUUID] = make(map[int]struct{})
return &Pinger{
Count: -1,
Interval: time.Second,
@@ -97,17 +94,15 @@ func New(addr string) *Pinger {
Size: timeSliceLength + trackerLength,
Timeout: time.Duration(math.MaxInt64),
- addr: addr,
- done: make(chan interface{}),
- id: r.Intn(math.MaxUint16),
- trackerUUIDs: []uuid.UUID{firstUUID},
- ipaddr: nil,
- ipv4: false,
- network: "ip",
- protocol: "udp",
- awaitingSequences: firstSequence,
- TTL: 64,
- logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())},
+ addr: addr,
+ done: make(chan interface{}),
+ id: r.Intn(math.MaxUint16),
+ ipaddr: nil,
+ ipv4: false,
+ network: "ip",
+ protocol: "udp",
+ TTL: 64,
+ logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())},
}
}
@@ -143,6 +138,9 @@ type Pinger struct {
// Number of duplicate packets received
PacketsRecvDuplicates int
+ // Per-packet timeout
+ PacketTimeout time.Duration
+
// Round trip time statistics
minRtt time.Duration
maxRtt time.Duration
@@ -189,14 +187,11 @@ type Pinger struct {
ipaddr *net.IPAddr
addr string
- // trackerUUIDs is the list of UUIDs being used for sending packets.
- trackerUUIDs []uuid.UUID
-
ipv4 bool
id int
sequence int
- // awaitingSequences are in-flight sequence numbers we keep track of to help remove duplicate receipts
- awaitingSequences map[uuid.UUID]map[int]struct{}
+ // tracker is a PacketTracker of UUIDs and sequence numbers.
+ tracker *PacketTracker
// network is one of "ip", "ip4", or "ip6".
network string
// protocol is "icmp" or "udp".
@@ -413,6 +408,9 @@ func (p *Pinger) Run() error {
if err != nil {
return err
}
+
+ p.tracker = NewPacketTracker(p.PacketTimeout)
+
if conn, err = p.listen(); err != nil {
return err
}
@@ -616,19 +614,12 @@ func (p *Pinger) getPacketUUID(pkt []byte) (*uuid.UUID, error) {
return nil, fmt.Errorf("error decoding tracking UUID: %w", err)
}
- for _, item := range p.trackerUUIDs {
- if item == packetUUID {
- return &packetUUID, nil
- }
+ if p.tracker.HasUUID(packetUUID) {
+ return &packetUUID, nil
}
return nil, nil
}
-// getCurrentTrackerUUID grabs the latest tracker UUID.
-func (p *Pinger) getCurrentTrackerUUID() uuid.UUID {
- return p.trackerUUIDs[len(p.trackerUUIDs)-1]
-}
-
func (p *Pinger) processPacket(recv *packet) error {
receivedAt := time.Now()
var proto int
@@ -677,15 +668,15 @@ func (p *Pinger) processPacket(recv *packet) error {
inPkt.Rtt = receivedAt.Sub(timestamp)
inPkt.Seq = pkt.Seq
// If we've already received this sequence, ignore it.
- if _, inflight := p.awaitingSequences[*pktUUID][pkt.Seq]; !inflight {
+ if !p.tracker.HasPacket(*pktUUID, pkt.Seq) {
p.PacketsRecvDuplicates++
if p.OnDuplicateRecv != nil {
p.OnDuplicateRecv(inPkt)
}
return nil
}
- // remove it from the list of sequences we're waiting for so we don't get duplicates.
- delete(p.awaitingSequences[*pktUUID], pkt.Seq)
+ // Remove it from the list of sequences we're waiting for so we don't get duplicates.
+ p.tracker.DeletePacket(*pktUUID, pkt.Seq)
p.updateStatistics(inPkt)
default:
// Very bad, not sure how this can happen
@@ -706,7 +697,7 @@ func (p *Pinger) sendICMP(conn packetConn) error {
dst = &net.UDPAddr{IP: p.ipaddr.IP, Zone: p.ipaddr.Zone}
}
- currentUUID := p.getCurrentTrackerUUID()
+ currentUUID := p.tracker.CurrentUUID()
uuidEncoded, err := currentUUID.MarshalBinary()
if err != nil {
return fmt.Errorf("unable to marshal UUID binary: %w", err)
@@ -754,15 +745,8 @@ func (p *Pinger) sendICMP(conn packetConn) error {
handler(outPkt)
}
// mark this sequence as in-flight
- p.awaitingSequences[currentUUID][p.sequence] = struct{}{}
+ p.sequence = p.tracker.AddPacket()
p.PacketsSent++
- p.sequence++
- if p.sequence > 65535 {
- newUUID := uuid.New()
- p.trackerUUIDs = append(p.trackerUUIDs, newUUID)
- p.awaitingSequences[newUUID] = make(map[int]struct{})
- p.sequence = 0
- }
break
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment