-
-
Save Gitart/fea6734e9dc0b606c067f6cc91de2276 to your computer and use it in GitHub Desktop.
Get the number of open connections from an *http.Transport
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"sync" | |
"net/http" | |
"time" | |
"io/ioutil" | |
"fmt" | |
"net" | |
"context" | |
) | |
func main(){ | |
// here is the test code. copy/paste/use NewTransportWithConnectNum. | |
handler:=http.HandlerFunc(func(w http.ResponseWriter, req *http.Request){ | |
time.Sleep(time.Second) | |
_,_ = w.Write([]byte("success")) | |
}) | |
go func(){ | |
_ = http.ListenAndServe("127.0.0.1:10098",handler) | |
}() | |
tran:= NewTransportWithConnectNum(&http.Transport{}) | |
client:=&http.Client{ | |
Transport:tran, | |
} | |
wg:=sync.WaitGroup{} | |
for i:=0;i<10;i++{ | |
wg.Add(1) | |
go func(){ | |
resp,err:=client.Get("http://127.0.0.1:10098") | |
if err!=nil{ | |
panic(err) | |
} | |
_,_ = ioutil.ReadAll(resp.Body) | |
_ = resp.Body.Close() | |
wg.Done() | |
}() | |
} | |
wg.Wait() | |
fmt.Println(tran.GetConnectionNum()) | |
} | |
// you can wrap a *http.Transport to get open connection number. | |
// do not support if DialTLS is set. | |
func NewTransportWithConnectNum(old *http.Transport) *Transport{ | |
connectCounter:=&connectCounter_t{} | |
tran:=&Transport{ | |
Transport: old, | |
connectCounter: connectCounter, | |
} | |
oldDialer:=tran.getOldDialer() | |
tran.Transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error){ | |
conn1,err:=oldDialer(ctx,network,addr) | |
if err!=nil{ | |
return nil,err | |
} | |
connectCounter.add(1) | |
conn2:=&connWithCounter{ | |
Conn: conn1, | |
connectCounter: connectCounter, | |
} | |
return conn2,nil | |
} | |
return tran | |
} | |
type Transport struct{ | |
*http.Transport | |
connectCounter *connectCounter_t | |
} | |
func (tran *Transport) GetConnectionNum() int{ | |
return tran.connectCounter.get() | |
} | |
var zeroDialer net.Dialer | |
func (tran *Transport) getOldDialer() func(ctx context.Context, network, addr string) (net.Conn, error){ | |
if tran.DialContext!=nil{ | |
return tran.DialContext | |
} | |
if tran.Dial!=nil{ | |
return func(ctx context.Context, network, addr string) (net.Conn, error){ | |
return tran.Dial(network,addr) | |
} | |
} | |
return zeroDialer.DialContext | |
} | |
// connectCounter_t is a mutex-guarded running count of open connections.
// Its zero value is ready to use. It contains a sync.Mutex, so it must not be
// copied after first use.
type connectCounter_t struct {
	connectNum       int
	connectNumLocker sync.Mutex
}

// add changes the count by num; a negative num decrements it.
func (c *connectCounter_t) add(num int) {
	c.connectNumLocker.Lock()
	defer c.connectNumLocker.Unlock()
	c.connectNum += num
}

// get returns a snapshot of the current count.
func (c *connectCounter_t) get() int {
	c.connectNumLocker.Lock()
	defer c.connectNumLocker.Unlock()
	return c.connectNum
}
type connWithCounter struct{ | |
net.Conn | |
closeCounterSyncOnce sync.Once | |
connectCounter *connectCounter_t | |
} | |
func (conn *connWithCounter) Close() (err error){ | |
err = conn.Conn.Close() | |
conn.closeCounterSyncOnce.Do(func(){ | |
conn.connectCounter.add(-1) | |
}) | |
return err | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment