-
-
Save SuperQ/6a74834ff4c4855f1446c2ffeb117c84 to your computer and use it in GitHub Desktop.
A first pass at making a packet tracking object.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package ping | |
import ( | |
"sync" | |
"time" | |
"github.com/google/uuid" | |
) | |
type PacketTracker struct { | |
currentUUID uuid.UUID | |
packets map[uuid.UUID]PacketSequence | |
sequence int | |
nextSequence int | |
timeout time.Duration | |
timeoutCh chan *inFlightPacket | |
mutex sync.RWMutex | |
} | |
type PacketSequence struct { | |
packets map[uint]inFlightPacket | |
} | |
type inFlightPacket struct { | |
timeoutTimer *time.Timer | |
} | |
func NewPacketTracker(t time.Duration) *PacketTracker { | |
firstUUID := uuid.New() | |
var firstSequence = map[uuid.UUID]map[int]struct{}{} | |
firstSequence[firstUUID] = make(map[int]struct{}) | |
return &PacketTracker{ | |
packets: map[uuid.UUID]PacketSequence{}, | |
sequence: 0, | |
timeout: t, | |
} | |
} | |
func (t *PacketTracker) AddPacket() int { | |
t.mutex.Lock() | |
defer t.mutex.Unlock() | |
if t.nextSequence > 65535 { | |
newUUID := uuid.New() | |
t.packets[newUUID] = PacketSequence{} | |
t.currentUUID = newUUID | |
t.nextSequence = 0 | |
} | |
t.sequence = t.nextSequence | |
t.packets[t.currentUUID][t.sequence] = inFlightPacket{} | |
// if t.timeout > 0 { | |
// t.packets[t.currentUUID][t.sequence].timeoutTimer = time.Timer(t.timeout) | |
// } | |
t.nextSequence++ | |
return t.sequence | |
} | |
// DeletePacket removes a packet from the tracker. | |
func (t *PacketTracker) DeletePacket(u uuid.UUID, seq int) { | |
t.mutex.Lock() | |
defer t.mutex.Unlock() | |
if t.hasPacket(u, seq) { | |
if t.packets[u][seq] != nil { | |
t.packets[u][seq].timeoutTimer.Stop() | |
} | |
delete(t.packets[u], seq) | |
} | |
} | |
func (t *PacketTracker) hasPacket(u uuid.UUID, seq int) bool { | |
_, inflight := t.packets[u][seq] | |
return inflight | |
} | |
// HasPacket checks the tracker to see if it's currently tracking a packet. | |
func (t *PacketTracker) HasPacket(u uuid.UUID, seq int) bool { | |
t.mutex.RLock() | |
defer t.mutex.Unlock() | |
return t.hasPacket(u, seq) | |
} | |
func (t *PacketTracker) HasUUID(u uuid.UUID) bool { | |
_, hasUUID := t.packets[u] | |
return hasUUID | |
} | |
func (t *PacketTracker) CurrentUUID() uuid.UUID { | |
t.mutex.RLock() | |
defer t.mutex.Unlock() | |
return t.currentUUID | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/ping.go b/ping.go | |
index e1c09ee..0abb789 100644 | |
--- a/ping.go | |
+++ b/ping.go | |
@@ -87,9 +87,6 @@ var ( | |
// New returns a new Pinger struct pointer. | |
func New(addr string) *Pinger { | |
r := rand.New(rand.NewSource(getSeed())) | |
- firstUUID := uuid.New() | |
- var firstSequence = map[uuid.UUID]map[int]struct{}{} | |
- firstSequence[firstUUID] = make(map[int]struct{}) | |
return &Pinger{ | |
Count: -1, | |
Interval: time.Second, | |
@@ -97,17 +94,15 @@ func New(addr string) *Pinger { | |
Size: timeSliceLength + trackerLength, | |
Timeout: time.Duration(math.MaxInt64), | |
- addr: addr, | |
- done: make(chan interface{}), | |
- id: r.Intn(math.MaxUint16), | |
- trackerUUIDs: []uuid.UUID{firstUUID}, | |
- ipaddr: nil, | |
- ipv4: false, | |
- network: "ip", | |
- protocol: "udp", | |
- awaitingSequences: firstSequence, | |
- TTL: 64, | |
- logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())}, | |
+ addr: addr, | |
+ done: make(chan interface{}), | |
+ id: r.Intn(math.MaxUint16), | |
+ ipaddr: nil, | |
+ ipv4: false, | |
+ network: "ip", | |
+ protocol: "udp", | |
+ TTL: 64, | |
+ logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())}, | |
} | |
} | |
@@ -143,6 +138,9 @@ type Pinger struct { | |
// Number of duplicate packets received | |
PacketsRecvDuplicates int | |
+ // Per-packet timeout | |
+ PacketTimeout time.Duration | |
+ | |
// Round trip time statistics | |
minRtt time.Duration | |
maxRtt time.Duration | |
@@ -189,14 +187,11 @@ type Pinger struct { | |
ipaddr *net.IPAddr | |
addr string | |
- // trackerUUIDs is the list of UUIDs being used for sending packets. | |
- trackerUUIDs []uuid.UUID | |
- | |
ipv4 bool | |
id int | |
sequence int | |
- // awaitingSequences are in-flight sequence numbers we keep track of to help remove duplicate receipts | |
- awaitingSequences map[uuid.UUID]map[int]struct{} | |
+ // tracker is a PacketTracker of UUIDs and sequence numbers. | |
+ tracker *PacketTracker | |
// network is one of "ip", "ip4", or "ip6". | |
network string | |
// protocol is "icmp" or "udp". | |
@@ -413,6 +408,9 @@ func (p *Pinger) Run() error { | |
if err != nil { | |
return err | |
} | |
+ | |
+ p.tracker = NewPacketTracker(p.PacketTimeout) | |
+ | |
if conn, err = p.listen(); err != nil { | |
return err | |
} | |
@@ -616,19 +614,12 @@ func (p *Pinger) getPacketUUID(pkt []byte) (*uuid.UUID, error) { | |
return nil, fmt.Errorf("error decoding tracking UUID: %w", err) | |
} | |
- for _, item := range p.trackerUUIDs { | |
- if item == packetUUID { | |
- return &packetUUID, nil | |
- } | |
+ if p.tracker.HasUUID(packetUUID) { | |
+ return &packetUUID, nil | |
} | |
return nil, nil | |
} | |
-// getCurrentTrackerUUID grabs the latest tracker UUID. | |
-func (p *Pinger) getCurrentTrackerUUID() uuid.UUID { | |
- return p.trackerUUIDs[len(p.trackerUUIDs)-1] | |
-} | |
- | |
func (p *Pinger) processPacket(recv *packet) error { | |
receivedAt := time.Now() | |
var proto int | |
@@ -677,15 +668,15 @@ func (p *Pinger) processPacket(recv *packet) error { | |
inPkt.Rtt = receivedAt.Sub(timestamp) | |
inPkt.Seq = pkt.Seq | |
// If we've already received this sequence, ignore it. | |
- if _, inflight := p.awaitingSequences[*pktUUID][pkt.Seq]; !inflight { | |
+ if !p.tracker.HasPacket(*pktUUID, pkt.Seq) { | |
p.PacketsRecvDuplicates++ | |
if p.OnDuplicateRecv != nil { | |
p.OnDuplicateRecv(inPkt) | |
} | |
return nil | |
} | |
- // remove it from the list of sequences we're waiting for so we don't get duplicates. | |
- delete(p.awaitingSequences[*pktUUID], pkt.Seq) | |
+ // Remove it from the list of sequences we're waiting for so we don't get duplicates. | |
+ p.tracker.DeletePacket(*pktUUID, pkt.Seq) | |
p.updateStatistics(inPkt) | |
default: | |
// Very bad, not sure how this can happen | |
@@ -706,7 +697,7 @@ func (p *Pinger) sendICMP(conn packetConn) error { | |
dst = &net.UDPAddr{IP: p.ipaddr.IP, Zone: p.ipaddr.Zone} | |
} | |
- currentUUID := p.getCurrentTrackerUUID() | |
+ currentUUID := p.tracker.CurrentUUID() | |
uuidEncoded, err := currentUUID.MarshalBinary() | |
if err != nil { | |
return fmt.Errorf("unable to marshal UUID binary: %w", err) | |
@@ -754,15 +745,8 @@ func (p *Pinger) sendICMP(conn packetConn) error { | |
handler(outPkt) | |
} | |
// mark this sequence as in-flight | |
- p.awaitingSequences[currentUUID][p.sequence] = struct{}{} | |
+ p.sequence = p.tracker.AddPacket() | |
p.PacketsSent++ | |
- p.sequence++ | |
- if p.sequence > 65535 { | |
- newUUID := uuid.New() | |
- p.trackerUUIDs = append(p.trackerUUIDs, newUUID) | |
- p.awaitingSequences[newUUID] = make(map[int]struct{}) | |
- p.sequence = 0 | |
- } | |
break | |
} | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment