Skip to content

Instantly share code, notes, and snippets.

@spikeekips
Created August 26, 2021 17:41
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save spikeekips/85470cfa3bef72bdfcefe491b1d0d3ff to your computer and use it in GitHub Desktop.
diff --git a/base/node/errors.go b/base/node/errors.go
new file mode 100644
index 0000000..c66df3b
--- /dev/null
+++ b/base/node/errors.go
@@ -0,0 +1,42 @@
+package node
+
+import (
+ "fmt"
+
+ "github.com/spikeekips/mitum/base"
+)
+
+// NodeError attributes an error to the node it happened with, so callers such
+// as Nodepool.Broadcast can report which node failed.
+type NodeError struct {
+ err error
+ node base.Address
+}
+
+// NewNodeError returns a NodeError that attributes err to the node no.
+func NewNodeError(no base.Address, err error) NodeError {
+ return NodeError{
+ node: no,
+ err: err,
+ }
+}
+
+// Node returns the address of the node the error is attributed to.
+func (er NodeError) Node() base.Address {
+ return er.node
+}
+
+// Error prefixes the wrapped error message with the node address when one is
+// set; a nil wrapped error is rendered as "<nil>" instead of panicking.
+func (er NodeError) Error() string {
+ if er.node == nil {
+ return fmt.Sprint(er.err)
+ }
+
+ return fmt.Sprintf("node %s: %v", er.node, er.err)
+}
+
+// Unwrap returns the wrapped error so errors.Is and errors.As can inspect it.
+func (er NodeError) Unwrap() error {
+ return er.err
+}
diff --git a/launch/deploy/context.go b/launch/deploy/context.go
index f5173d0..c775cb7 100644
--- a/launch/deploy/context.go
+++ b/launch/deploy/context.go
@@ -9,6 +9,7 @@ import (
var (
ContextValueDeployKeyStorage util.ContextKey = "deploy_key_storage"
ContextValueBlockDataCleaner util.ContextKey = "blockdata_cleaner"
+ ContextValueDeployHandler util.ContextKey = "deploy_handler"
)
func LoadDeployKeyStorageContextValue(ctx context.Context, l **DeployKeyStorage) error {
@@ -18,3 +19,17 @@ func LoadDeployKeyStorageContextValue(ctx context.Context, l **DeployKeyStorage)
func LoadBlockDataCleanerContextValue(ctx context.Context, l **BlockDataCleaner) error {
return util.LoadFromContextValue(ctx, ContextValueBlockDataCleaner, l)
}
+
+// LoadDeployHandlersContextValue loads the *DeployHandlers stored in ctx under
+// ContextValueDeployHandler.
+func LoadDeployHandlersContextValue(ctx context.Context, l **DeployHandlers) error {
+ return util.LoadFromContextValue(ctx, ContextValueDeployHandler, l)
+}
+
+// LoadDeployHandler is the former name of LoadDeployHandlersContextValue.
+//
+// Deprecated: use LoadDeployHandlersContextValue, which matches the naming of
+// the other Load...ContextValue helpers.
+func LoadDeployHandler(ctx context.Context, l **DeployHandlers) error {
+ return LoadDeployHandlersContextValue(ctx, l)
+}
diff --git a/launch/deploy/deploy_handlers.go b/launch/deploy/deploy_handlers.go
index 08db598..f2b55b4 100644
--- a/launch/deploy/deploy_handlers.go
+++ b/launch/deploy/deploy_handlers.go
@@ -9,7 +9,6 @@ import (
"github.com/rs/zerolog"
"github.com/spikeekips/mitum/launch/config"
"github.com/spikeekips/mitum/launch/process"
- "github.com/spikeekips/mitum/storage"
"github.com/spikeekips/mitum/util"
"github.com/spikeekips/mitum/util/encoder"
jsonenc "github.com/spikeekips/mitum/util/encoder/json"
@@ -21,7 +20,9 @@ var QuicHandlerPathSetBlockDataMaps = "/_deploy/blockdatamaps"
var RateLimitHandlerNameSetBlockDataMaps = "set-blockdatamaps"
-type baseDeployHandler struct {
+// BaseDeployHandler holds what the deploy handler sets share: logging, the
+// route factory (handler), rate-limit rules keyed by name, the keys (ks) and enc.
+type BaseDeployHandler struct {
*logging.Logging
handler func(string) *mux.Route
handlerMap map[string][]process.RateLimitRule
@@ -30,11 +31,13 @@ type baseDeployHandler struct {
enc encoder.Encoder
}
-func newBaseDeployHandler(
+// NewBaseDeployHandler builds a BaseDeployHandler that logs under the module
+// name and creates its routes with handler.
+func NewBaseDeployHandler(
ctx context.Context,
name string,
handler func(string) *mux.Route,
-) (*baseDeployHandler, error) {
+) (*BaseDeployHandler, error) {
var log *logging.Logging
if err := config.LoadLogContextValue(ctx, &log); err != nil {
return nil, err
@@ -60,7 +63,7 @@ func newBaseDeployHandler(
return nil, err
}
- dh := &baseDeployHandler{
+ dh := &BaseDeployHandler{
Logging: logging.NewLogging(func(c zerolog.Context) zerolog.Context {
return c.Str("module", name)
}),
@@ -76,7 +79,9 @@ func newBaseDeployHandler(
return dh, nil
}
-func (dh *baseDeployHandler) rateLimit(name string, handler http.Handler) http.Handler {
+// RateLimit wraps handler with the rate-limit rules registered for name and
+// returns handler unchanged when no rules are registered for name.
+func (dh *BaseDeployHandler) RateLimit(name string, handler http.Handler) http.Handler {
i, found := dh.handlerMap[name]
if !found {
return handler
@@ -87,62 +92,33 @@ func (dh *baseDeployHandler) rateLimit(name string, handler http.Handler) http.H
).Middleware(handler)
}
-type deployHandlers struct {
- *baseDeployHandler
- db storage.Database
- bc *BlockDataCleaner
+// DeployHandlers registers deploy routes behind the DeployByKeyMiddleware.
+type DeployHandlers struct {
+ *BaseDeployHandler
mw *DeployByKeyMiddleware
}
-func newDeployHandlers(ctx context.Context, handler func(string) *mux.Route) (*deployHandlers, error) {
- base, err := newBaseDeployHandler(ctx, "deploy-handlers", handler)
+// NewDeployHandlers creates DeployHandlers whose middleware is built from the
+// base handler's keys (NewDeployByKeyMiddleware(base.ks)).
+func NewDeployHandlers(ctx context.Context, handler func(string) *mux.Route) (*DeployHandlers, error) {
+ base, err := NewBaseDeployHandler(ctx, "deploy-handlers", handler)
if err != nil {
return nil, err
}
- var db storage.Database
- if err := process.LoadDatabaseContextValue(ctx, &db); err != nil {
- return nil, err
- }
-
- var bc *BlockDataCleaner
- if err := LoadBlockDataCleanerContextValue(ctx, &bc); err != nil {
- return nil, err
- }
-
mw := NewDeployByKeyMiddleware(base.ks)
- dh := &deployHandlers{
- baseDeployHandler: base,
- db: db,
- bc: bc,
+ dh := &DeployHandlers{
+ BaseDeployHandler: base,
mw: mw,
}
return dh, nil
}
-func (dh *deployHandlers) setHandlers() error {
- setter := []func() error{
- dh.setSetBlockDataMaps,
- }
-
- for i := range setter {
- if err := setter[i](); err != nil {
- return err
- }
- }
-
- return nil
-}
-
-func (dh *deployHandlers) setSetBlockDataMaps() error {
- handler := dh.rateLimit(
- RateLimitHandlerNameSetBlockDataMaps,
- http.HandlerFunc(NewSetBlockDataMapsHandler(dh.enc, dh.db, dh.bc)),
- )
-
- _ = dh.handler(QuicHandlerPathSetBlockDataMaps).Handler(handler)
-
- return nil
+// SetHandler registers handler at the given path behind the deploy-key
+// middleware and returns the route. It does not rate-limit; wrap handler with
+// RateLimit first when needed.
+func (dh *DeployHandlers) SetHandler(prefix string, handler http.Handler) *mux.Route {
+ return dh.handler(prefix).Handler(dh.mw.Middleware(handler))
}
diff --git a/launch/deploy/deploy_key_handlers.go b/launch/deploy/deploy_key_handlers.go
index 5bf7100..f9b23d2 100644
--- a/launch/deploy/deploy_key_handlers.go
+++ b/launch/deploy/deploy_key_handlers.go
@@ -33,13 +33,13 @@ var (
)
type deployKeyHandlers struct {
- *baseDeployHandler
+ *BaseDeployHandler // provides Logging, handler, handlerMap (rate limits), ks and enc
cache cache.Cache
mw *DeployKeyByTokenMiddleware
}
func newDeployKeyHandlers(ctx context.Context, handler func(string) *mux.Route) (*deployKeyHandlers, error) {
- base, err := newBaseDeployHandler(ctx, "deploy-key-handlers", handler)
+ base, err := NewBaseDeployHandler(ctx, "deploy-key-handlers", handler)
if err != nil {
return nil, err
}
@@ -62,7 +62,7 @@ func newDeployKeyHandlers(ctx context.Context, handler func(string) *mux.Route)
mw := NewDeployKeyByTokenMiddleware(c, local.Privatekey().Publickey(), policy.NetworkID())
dh := &deployKeyHandlers{
- baseDeployHandler: base,
+ BaseDeployHandler: base,
cache: c,
mw: mw,
}
@@ -88,7 +88,7 @@ func (dh *deployKeyHandlers) setHandlers() error {
}
func (dh *deployKeyHandlers) setTokenHandler() error {
- handler := dh.rateLimit(
+ handler := dh.RateLimit(
RateLimitHandlerNameDeployKeyToken,
NewDeployKeyTokenHandler(dh.cache, DefaultDeployKeyTokenExpired),
)
@@ -119,7 +119,7 @@ func (dh *deployKeyHandlers) setKeyHandler() error {
}
func (dh *deployKeyHandlers) setKeysHandler() error {
- handler := dh.rateLimit(RateLimitHandlerNameDeployKeyKeys, http.HandlerFunc(NewDeployKeyKeysHandler(dh.ks, dh.enc)))
+ handler := dh.RateLimit(RateLimitHandlerNameDeployKeyKeys, http.HandlerFunc(NewDeployKeyKeysHandler(dh.ks, dh.enc))) // dh.mw wraps it below, so the token check runs before rate limiting
_ = dh.handler(QuicHandlerPathDeployKeyKeys).Handler(dh.mw.Middleware(handler))
@@ -127,7 +127,7 @@ func (dh *deployKeyHandlers) setKeysHandler() error {
}
func (dh *deployKeyHandlers) setKeyNewHandler() error {
- handler := dh.rateLimit(RateLimitHandlerNameDeployKeyNew, http.HandlerFunc(NewDeployKeyNewHandler(dh.ks, dh.enc)))
+ handler := dh.RateLimit(RateLimitHandlerNameDeployKeyNew, http.HandlerFunc(NewDeployKeyNewHandler(dh.ks, dh.enc)))
_ = dh.handler(QuicHandlerPathDeployKeyNew).Handler(dh.mw.Middleware(handler))
@@ -136,10 +136,10 @@ func (dh *deployKeyHandlers) setKeyNewHandler() error {
func (dh *deployKeyHandlers) keyHandler() network.HTTPHandlerFunc {
handler := NewDeployKeyKeyHandler(dh.ks, dh.enc)
- return dh.rateLimit(RateLimitHandlerNameDeployKeyKey, http.HandlerFunc(handler)).ServeHTTP
+ return dh.RateLimit(RateLimitHandlerNameDeployKeyKey, http.HandlerFunc(handler)).ServeHTTP
}
func (dh *deployKeyHandlers) keyRevokeHandler() network.HTTPHandlerFunc {
handler := NewDeployKeyRevokeHandler(dh.ks)
- return dh.rateLimit(RateLimitHandlerNameDeployKeyRevoke, http.HandlerFunc(handler)).ServeHTTP
+ return dh.RateLimit(RateLimitHandlerNameDeployKeyRevoke, http.HandlerFunc(handler)).ServeHTTP
}
diff --git a/launch/deploy/hook_deploy_handlers.go b/launch/deploy/hook_deploy_handlers.go
index dc645d2..e70e35b 100644
--- a/launch/deploy/hook_deploy_handlers.go
+++ b/launch/deploy/hook_deploy_handlers.go
@@ -3,11 +3,13 @@ package deploy
import (
"context"
"fmt"
+ "net/http"
"github.com/spikeekips/mitum/launch/config"
"github.com/spikeekips/mitum/launch/process"
"github.com/spikeekips/mitum/network"
quicnetwork "github.com/spikeekips/mitum/network/quic"
+ "github.com/spikeekips/mitum/storage"
"github.com/spikeekips/mitum/util/logging"
)
@@ -38,11 +40,30 @@ func HookDeployHandlers(ctx context.Context) (context.Context, error) {
return ctx, err
}
- if i, err := newDeployHandlers(ctx, qnt.Handler); err != nil {
+ return hookDefaultDeployHandlers(ctx, qnt)
+}
+
+func hookDefaultDeployHandlers(ctx context.Context, qnt *quicnetwork.Server) (context.Context, error) { // registers the set-blockdatamaps route and stores dh in the returned ctx
+ dh, err := NewDeployHandlers(ctx, qnt.Handler)
+ if err != nil {
+ return ctx, err
+ }
+
+ var db storage.Database
+ if err := process.LoadDatabaseContextValue(ctx, &db); err != nil {
return ctx, err
- } else if err := i.setHandlers(); err != nil {
+ }
+
+ var bc *BlockDataCleaner
+ if err := LoadBlockDataCleanerContextValue(ctx, &bc); err != nil {
return ctx, err
}
- return ctx, nil
+ setBlockDataMapsHandler := http.HandlerFunc(NewSetBlockDataMapsHandler(qnt.Encoder(), db, bc)) // NOTE(review): encoder now comes from qnt.Encoder(); the old code used the deploy handler's enc — confirm they match
+ _ = dh.SetHandler(
+ QuicHandlerPathSetBlockDataMaps,
+ dh.RateLimit(RateLimitHandlerNameSetBlockDataMaps, setBlockDataMapsHandler),
+ )
+
+ return context.WithValue(ctx, ContextValueDeployHandler, dh), nil // loadable again with LoadDeployHandler
}
diff --git a/launch/process/process_network.go b/launch/process/process_network.go
index f1c7c28..1d2ba52 100644
--- a/launch/process/process_network.go
+++ b/launch/process/process_network.go
@@ -48,6 +48,11 @@ func ProcessQuicNetwork(ctx context.Context) (context.Context, error) {
return ctx, err
}
+ var nodepool *network.Nodepool // its Passthroughs method is handed to the quic server
+ if err := LoadNodepoolContextValue(ctx, &nodepool); err != nil {
+ return ctx, err
+ }
+
var l *logging.Logging
if err := config.LoadLogContextValue(ctx, &l); err != nil {
return ctx, err
@@ -63,7 +68,7 @@ func ProcessQuicNetwork(ctx context.Context) (context.Context, error) {
return ctx, err
}
- nt, err := NewNetworkServer(conf.Bind().Host, conf.Certs(), encs, ca, httpLog)
+ nt, err := NewNetworkServer(conf.Bind().Host, conf.Certs(), encs, ca, conf.ConnInfo(), nodepool, httpLog)
if err != nil {
return ctx, err
}
@@ -81,6 +86,8 @@ func NewNetworkServer(
certs []tls.Certificate,
encs *encoder.Encoders,
ca cache.Cache,
+ connInfo network.ConnInfo, // sent as the sender ConnInfo when the server relays seals
+ nodepool *network.Nodepool, // must not be nil: its Passthroughs is called for incoming seals
httpLog *logging.Logging,
) (network.Server, error) {
je, err := encs.Encoder(jsonenc.JSONEncoderType, "")
@@ -90,7 +97,7 @@ func NewNetworkServer(
if qs, err := quicnetwork.NewPrimitiveQuicServer(bind, certs, httpLog); err != nil {
return nil, err
- } else if nqs, err := quicnetwork.NewServer(qs, encs, je, ca); err != nil {
+ } else if nqs, err := quicnetwork.NewServer(qs, encs, je, ca, connInfo, nodepool.Passthroughs); err != nil {
return nil, err
} else if err := nqs.Initialize(); err != nil {
return nil, err
diff --git a/network/dummy_channel.go b/network/dummy_channel.go
index 1f77393..6787e6a 100644
--- a/network/dummy_channel.go
+++ b/network/dummy_channel.go
@@ -34,7 +34,7 @@ func (lc *DummyChannel) ConnInfo() ConnInfo {
return lc.connInfo
}
-func (lc *DummyChannel) SendSeal(_ context.Context, sl seal.Seal) error {
+func (lc *DummyChannel) SendSeal(_ context.Context, _ ConnInfo, sl seal.Seal) error { // the sender ConnInfo is ignored
if lc.newSealHandler == nil {
return lc.notSupported()
}
diff --git a/network/gochan/channel.go b/network/gochan/channel.go
index 23ba68c..478f4f9 100644
--- a/network/gochan/channel.go
+++ b/network/gochan/channel.go
@@ -18,7 +18,7 @@ import (
type Channel struct {
*logging.Logging
connInfo network.ConnInfo
- recvChan chan seal.Seal
+ recvChan chan network.PassthroughedSeal // each seal tagged with its sender's ConnInfo string
getSealHandler network.GetSealsHandler
getState network.GetStateHandler
nodeInfo network.NodeInfoHandler
@@ -32,7 +32,7 @@ func NewChannel(bufsize uint, connInfo network.ConnInfo) *Channel {
return c.Str("module", "chan-network")
}),
connInfo: connInfo,
- recvChan: make(chan seal.Seal, bufsize),
+ recvChan: make(chan network.PassthroughedSeal, bufsize),
}
}
@@ -52,13 +52,13 @@ func (ch *Channel) Seals(_ context.Context, h []valuehash.Hash) ([]seal.Seal, er
return ch.getSealHandler(h)
}
-func (ch *Channel) SendSeal(_ context.Context, sl seal.Seal) error {
- ch.recvChan <- sl
+func (ch *Channel) SendSeal(_ context.Context, ci network.ConnInfo, sl seal.Seal) error { // ci may be nil; it is then recorded as an empty sender
+ ch.recvChan <- network.NewPassthroughedSealFromConnInfo(sl, ci) // blocks while the buffer is full; the context is not consulted
return nil
}
-func (ch *Channel) ReceiveSeal() <-chan seal.Seal {
+func (ch *Channel) ReceiveSeal() <-chan network.PassthroughedSeal { // receivers get the wrapper; its embedded Seal is the original seal
return ch.recvChan
}
diff --git a/network/gochan/channel_test.go b/network/gochan/channel_test.go
index 266bf4c..4ec76bc 100644
--- a/network/gochan/channel_test.go
+++ b/network/gochan/channel_test.go
@@ -27,7 +27,7 @@ func (t *testNetworkChanChannel) TestSendReceive() {
sl := seal.NewDummySeal(t.pk)
go func() {
- _ = gs.SendSeal(context.TODO(), sl)
+ _ = gs.SendSeal(context.TODO(), nil, sl) // no sender ConnInfo
}()
rsl := <-gs.ReceiveSeal()
diff --git a/network/gochan/server.go b/network/gochan/server.go
index 182620d..b4f60e1 100644
--- a/network/gochan/server.go
+++ b/network/gochan/server.go
@@ -2,6 +2,7 @@ package channetwork
import (
"context"
+ "time"
"github.com/rs/zerolog"
"github.com/spikeekips/mitum/base/seal"
@@ -15,14 +16,22 @@ type Server struct {
*util.ContextDaemon
newSealHandler network.NewSealHandler
ch *Channel
+ passthroughs func(context.Context, network.PassthroughedSeal, func(seal.Seal, network.Channel)) error
}
-func NewServer(ch *Channel) *Server {
+// NewServer returns a Server that handles the seals received by ch.
+// passthroughs may be nil; when set (its signature matches
+// network.Nodepool.Passthroughs), every received seal is also forwarded through it.
+func NewServer(
+ ch *Channel,
+ passthroughs func(context.Context, network.PassthroughedSeal, func(seal.Seal, network.Channel)) error,
+) *Server {
sv := &Server{
Logging: logging.NewLogging(func(c zerolog.Context) zerolog.Context {
return c.Str("module", "network-chan-server")
}),
ch: ch,
+ passthroughs: passthroughs,
}
sv.ContextDaemon = util.NewContextDaemon("network-chan-server", sv.run)
@@ -59,7 +68,17 @@ end:
case <-ctx.Done():
break end
case sl := <-sv.ch.ReceiveSeal():
- go func(sl seal.Seal) {
+ go func(psl network.PassthroughedSeal) {
+ // NOTE hand the unwrapped seal to newSealHandler below, so it sees the
+ // original seal type rather than the PassthroughedSeal wrapper.
+ sl := psl.Seal
+
+ go func() {
+ if err := sv.doPassthroughs(ctx, psl); err != nil {
+ sv.Log().Error().Err(err).Msg("failed to passthroughs")
+ }
+ }()
+
if sv.newSealHandler == nil {
sv.Log().Error().Msg("no NewSealHandler")
return
@@ -77,3 +96,25 @@ end:
return nil
}
+
+// doPassthroughs forwards sl through sv.passthroughs, naming this channel's
+// ConnInfo as the sender. ctx is the server's run context, so pending sends
+// are canceled when the server stops.
+func (sv *Server) doPassthroughs(ctx context.Context, sl network.PassthroughedSeal) error {
+ if sv.passthroughs == nil {
+ return nil
+ }
+
+ return sv.passthroughs(
+ ctx,
+ sl,
+ func(sl seal.Seal, ch network.Channel) {
+ sctx, cancel := context.WithTimeout(ctx, time.Second*5)
+ defer cancel()
+
+ if err := ch.SendSeal(sctx, sv.ch.ConnInfo(), sl); err != nil {
+ sv.Log().Error().Err(err).Stringer("remote", ch.ConnInfo()).Msg("failed to send seal")
+ }
+ },
+ )
+}
diff --git a/network/gochan/server_test.go b/network/gochan/server_test.go
index 509e64f..6e110b7 100644
--- a/network/gochan/server_test.go
+++ b/network/gochan/server_test.go
@@ -12,7 +12,7 @@ type testChanServer struct {
}
func (t *testChanServer) TestNew() {
- s := NewServer(nil)
+ s := NewServer(nil, nil) // no channel and no passthroughs
t.Implements((*network.Server)(nil), s)
}
diff --git a/network/network.go b/network/network.go
index 4f67564..b6f8874 100644
--- a/network/network.go
+++ b/network/network.go
@@ -52,7 +52,7 @@ type Channel interface {
util.Initializer
ConnInfo() ConnInfo
Seals(context.Context, []valuehash.Hash) ([]seal.Seal, error)
- SendSeal(context.Context, seal.Seal) error
+ SendSeal(context.Context, ConnInfo /* sender's ConnInfo; may be nil */, seal.Seal) error
NodeInfo(context.Context) (NodeInfo, error)
BlockDataMaps(context.Context, []base.Height) ([]block.BlockDataMap, error)
BlockData(context.Context, block.BlockDataMapItem) (io.ReadCloser, error)
diff --git a/network/nodepool.go b/network/nodepool.go
index 8acf212..ec3075c 100644
--- a/network/nodepool.go
+++ b/network/nodepool.go
@@ -1,12 +1,17 @@
package network
import (
+ "context"
+ "fmt"
"sync"
+ "time"
"github.com/pkg/errors"
"github.com/spikeekips/mitum/base"
"github.com/spikeekips/mitum/base/node"
+ "github.com/spikeekips/mitum/base/seal"
"github.com/spikeekips/mitum/util"
+ "golang.org/x/sync/semaphore"
)
// Nodepool contains all the known nodes including local node.
@@ -16,6 +21,8 @@ type Nodepool struct {
localch Channel
nodes map[string]base.Node
chs map[string]Channel
+ pts map[string]Channel // NOTE passthrough channels, keyed by ConnInfo().String()
+ ptsfilter map[string]func(seal.Seal) bool // NOTE filter per passthrough key; never nil
}
func NewNodepool(local *node.Local, ch Channel) *Nodepool {
@@ -29,6 +36,8 @@ func NewNodepool(local *node.Local, ch Channel) *Nodepool {
chs: map[string]Channel{
addr: ch,
},
+ pts: map[string]Channel{},
+ ptsfilter: map[string]func(seal.Seal) bool{},
}
}
@@ -173,18 +182,212 @@ func (np *Nodepool) TraverseAliveRemotes(callback func(base.Node, Channel) bool)
}
}
-func (np *Nodepool) Addresses() []base.Address {
- nodes := make([]base.Address, np.Len())
+// Broadcast sends sl to every alive remote node accepted by filter (a nil
+// filter accepts all) and to the passthrough channels, skipping a passthrough
+// channel whose ConnInfo matches the local channel's. Sends run concurrently,
+// at most 100 remote sends at a time, each bounded by a 5 second timeout
+// derived from ctx. Broadcast returns after every send has finished, with the
+// collected failures; failures for known nodes are wrapped in node.NodeError.
+func (np *Nodepool) Broadcast(
+ ctx context.Context,
+ sl seal.Seal,
+ filter func(base.Node) bool,
+) []error {
+ var localci ConnInfo
+ if ch := np.LocalChannel(); ch != nil {
+ if ci := ch.ConnInfo(); ci != nil {
+ localci = ci
+ }
+ }
- var i int
- np.Traverse(func(n base.Node, _ Channel) bool {
- nodes[i] = n.Address()
- i++
+ var (
+ wg sync.WaitGroup
+ mu sync.Mutex
+ errs []error
+ )
+
+ // addErr may be called from several goroutines at once.
+ addErr := func(err error) {
+ mu.Lock()
+ defer mu.Unlock()
+
+ errs = append(errs, err)
+ }
+
+ // send delivers sl to ch and records a failure; no is nil for passthrough channels.
+ send := func(no base.Node, ch Channel) {
+ nctx, cancel := context.WithTimeout(ctx, time.Second*5)
+ defer cancel()
+
+ err := ch.SendSeal(nctx, localci, sl)
+ if err == nil {
+ return
+ }
+
+ if no != nil {
+ addErr(node.NewNodeError(no.Address(), err))
+
+ return
+ }
+
+ addErr(fmt.Errorf("failed to passthrough to %q: %w", ch.ConnInfo(), err))
+ }
+
+ // NOTE bounds the number of concurrent sends to remote nodes
+ sem := semaphore.NewWeighted(100)
+
+ np.TraverseAliveRemotes(func(no base.Node, ch Channel) bool {
+ if filter != nil && !filter(no) {
+ return true
+ }
+
+ if err := sem.Acquire(ctx, 1); err != nil {
+ addErr(err)
+
+ return false
+ }
+
+ wg.Add(1)
+
+ go func(no base.Node, ch Channel) {
+ defer wg.Done()
+ defer sem.Release(1)
+
+ send(no, ch)
+ }(no, ch)
return true
})
- return nodes
+ // NOTE passthrough
+ if err := np.Passthroughs(
+ ctx,
+ NewPassthroughedSealFromConnInfo(sl, localci),
+ func(_ seal.Seal, ch Channel) {
+ send(nil, ch)
+ },
+ ); err != nil {
+ addErr(err)
+ }
+
+ // NOTE Passthroughs already waited for its callbacks; wait for the remote sends.
+ wg.Wait()
+
+ return errs
+}
+
+// SetPassthrough registers ch as a passthrough channel: Passthroughs hands it
+// every seal that filter accepts, and a nil filter accepts all seals. Channels
+// are keyed by ConnInfo().String(); for an already registered key only the
+// filter is replaced. A channel whose ConnInfo equals one of the node channels'
+// is rejected with util.FoundError.
+func (np *Nodepool) SetPassthrough(ch Channel, filter func(seal.Seal) bool) error {
+ if ch == nil {
+ return errors.Errorf("nil channel")
+ }
+
+ ci := ch.ConnInfo()
+ if ci == nil {
+ return fmt.Errorf("nil ConnInfo")
+ }
+
+ // NOTE defaulted before the update path too, so a stored filter is never nil
+ if filter == nil {
+ filter = func(seal.Seal) bool { return true }
+ }
+
+ np.Lock()
+ defer np.Unlock()
+
+ key := ci.String()
+
+ if _, found := np.pts[key]; found {
+ // NOTE update filter
+ np.ptsfilter[key] = filter
+
+ return nil
+ }
+
+ for i := range np.chs {
+ if np.chs[i] == nil {
+ continue
+ }
+
+ eci := np.chs[i].ConnInfo()
+ if eci == nil {
+ continue
+ }
+
+ if eci.Equal(ci) {
+ return util.FoundError.Errorf("already in nodes")
+ }
+ }
+
+ np.pts[key] = ch
+ np.ptsfilter[key] = filter
+
+ return nil
+}
+
+// RemovePassthrough unregisters the passthrough channel keyed by s (its
+// ConnInfo().String()); it returns util.NotFoundError when none is registered.
+func (np *Nodepool) RemovePassthrough(s string) error {
+ np.Lock()
+ defer np.Unlock()
+
+ if _, found := np.pts[s]; !found {
+ return util.NotFoundError.Call()
+ }
+
+ delete(np.pts, s)
+ delete(np.ptsfilter, s)
+
+ return nil
+}
+
+// Passthroughs calls callback for every passthrough channel whose filter
+// accepts the seal, skipping the channel whose ConnInfo string equals
+// sl.FromConnInfo(). Filters and callbacks get the unwrapped sl.Seal, so type
+// checks and encoders see the original seal. Callbacks run concurrently, at
+// most 100 at a time, and the read lock is held until all of them have
+// returned, so callback must not call SetPassthrough or RemovePassthrough.
+// A non-nil error means ctx ended before every callback could be started.
+func (np *Nodepool) Passthroughs(ctx context.Context, sl PassthroughedSeal, callback func(seal.Seal, Channel)) error {
+ np.RLock()
+ defer np.RUnlock()
+
+ from := sl.FromConnInfo()
+ sem := semaphore.NewWeighted(100)
+
+ var wg sync.WaitGroup
+ // NOTE deferred after RUnlock, so it runs first: started callbacks always
+ // finish before Passthroughs returns and the read lock is released.
+ defer wg.Wait()
+
+ for k, ch := range np.pts {
+ if ch.ConnInfo().String() == from {
+ continue
+ }
+
+ if filter := np.ptsfilter[k]; filter != nil && !filter(sl.Seal) {
+ continue
+ }
+
+ if err := sem.Acquire(ctx, 1); err != nil {
+ return err
+ }
+
+ wg.Add(1)
+
+ go func(ch Channel) {
+ defer wg.Done()
+ defer sem.Release(1)
+
+ callback(sl.Seal, ch)
+ }(ch)
+ }
+
+ return nil
}
func (np *Nodepool) exists(address base.Address) bool {
@@ -224,3 +427,33 @@ func (np *Nodepool) nc(filterLocal bool) ([]base.Node, []Channel) {
return nodes, channels
}
+
+// PassthroughedSeal is a seal.Seal tagged with the String() of the ConnInfo it
+// was received from; the tag is empty when the sender is unknown.
+type PassthroughedSeal struct {
+ seal.Seal
+ fromconnInfo string
+}
+
+// NewPassthroughedSealFromConnInfo tags sl with ci.String(), or "" when ci is nil.
+func NewPassthroughedSealFromConnInfo(sl seal.Seal, ci ConnInfo) PassthroughedSeal {
+ var s string
+ if ci != nil {
+ s = ci.String()
+ }
+
+ return NewPassthroughedSeal(sl, s)
+}
+
+// NewPassthroughedSeal tags sl with ci, the sender's ConnInfo string.
+func NewPassthroughedSeal(sl seal.Seal, ci string) PassthroughedSeal {
+ return PassthroughedSeal{
+ Seal: sl,
+ fromconnInfo: ci,
+ }
+}
+
+// FromConnInfo returns the sender's ConnInfo string; it is empty when unknown.
+func (sl PassthroughedSeal) FromConnInfo() string {
+ return sl.fromconnInfo
+}
diff --git a/network/nodepool_test.go b/network/nodepool_test.go
index d1d67d7..3e51c0b 100644
--- a/network/nodepool_test.go
+++ b/network/nodepool_test.go
@@ -1,11 +1,15 @@
package network
import (
+ "context"
+ "fmt"
"testing"
"github.com/pkg/errors"
"github.com/spikeekips/mitum/base"
+ "github.com/spikeekips/mitum/base/key"
"github.com/spikeekips/mitum/base/node"
+ "github.com/spikeekips/mitum/base/seal"
"github.com/spikeekips/mitum/util"
"github.com/stretchr/testify/suite"
)
@@ -140,6 +144,120 @@ func (t *testNodepool) TestTraverse() {
}
}
+func (t *testNodepool) TestAddPassthrough() {
+ ns := NewNodepool(t.local, nil)
+
+ for i := 0; i < 10; i++ {
+ ch := NilConnInfoChannel(fmt.Sprintf("ch%d", i))
+ t.NoError(ns.SetPassthrough(ch, nil))
+ }
+}
+
+func (t *testNodepool) TestAddPassthroughButInNodes() {
+ ns := NewNodepool(t.local, nil)
+ n0 := node.RandomLocal("n0")
+ ch0 := NilConnInfoChannel("n0")
+
+ t.NoError(ns.Add(n0, ch0))
+
+ err := ns.SetPassthrough(ch0, nil)
+ t.True(errors.Is(err, util.FoundError))
+}
+
+func (t *testNodepool) TestAddPassthroughButNoConnInfo() {
+ ns := NewNodepool(t.local, nil)
+ ch0 := NewDummyChannel(nil)
+
+ err := ns.SetPassthrough(ch0, nil)
+ t.Require().Error(err) // stop here instead of panicking on err.Error() below
+ t.Contains(err.Error(), "nil ConnInfo")
+}
+
+func (t *testNodepool) TestRemovePassthrough() {
+ ns := NewNodepool(t.local, nil)
+
+ chs := make([]Channel, 10)
+ for i := 0; i < 10; i++ {
+ ch := NilConnInfoChannel(fmt.Sprintf("ch%d", i))
+ t.NoError(ns.SetPassthrough(ch, nil))
+
+ chs[i] = ch
+ }
+
+ t.NoError(ns.RemovePassthrough(chs[3].ConnInfo().String()))
+ err := ns.RemovePassthrough(chs[3].ConnInfo().String())
+ t.True(errors.Is(err, util.NotFoundError))
+
+ t.Equal(9, len(ns.pts))
+}
+
+func (t *testNodepool) TestPassthroughs() {
+ ns := NewNodepool(t.local, nil)
+
+ ch0 := NilConnInfoChannel("n0")
+ t.NoError(ns.SetPassthrough(ch0, nil))
+
+ ch1 := NilConnInfoChannel("n1")
+ t.NoError(ns.SetPassthrough(ch1, nil))
+
+ pk := key.MustNewBTCPrivatekey()
+ sl := NewPassthroughedSeal(seal.NewDummySeal(pk), "")
+
+ passedch := make(chan [2]interface{}, 2)
+ t.NoError(ns.Passthroughs(context.Background(), sl, func(sl seal.Seal, ch Channel) {
+ passedch <- [2]interface{}{ch.ConnInfo().String(), sl}
+ }))
+
+ close(passedch)
+
+ passed := map[string]seal.Seal{}
+ for i := range passedch {
+ pci := i[0].(string)
+ psl := i[1].(seal.Seal)
+
+ passed[pci] = psl
+ }
+
+ t.Equal(2, len(passed))
+ t.True(passed[ch0.ConnInfo().String()].Hash().Equal(sl.Hash()))
+ t.True(passed[ch1.ConnInfo().String()].Hash().Equal(sl.Hash()))
+}
+
+func (t *testNodepool) TestPassthroughsFilter() {
+ ns := NewNodepool(t.local, nil)
+
+ ch0 := NilConnInfoChannel("n0")
+ t.NoError(ns.SetPassthrough(ch0, nil))
+
+ ch1 := NilConnInfoChannel("n1")
+ t.NoError(ns.SetPassthrough(ch1, func(sl seal.Seal) bool {
+ return false
+ }))
+
+ pk := key.MustNewBTCPrivatekey()
+ sl := NewPassthroughedSeal(seal.NewDummySeal(pk), "")
+
+ passedch := make(chan [2]interface{}, 2)
+ t.NoError(ns.Passthroughs(context.Background(), sl, func(sl seal.Seal, ch Channel) {
+ passedch <- [2]interface{}{ch.ConnInfo().String(), sl}
+ }))
+
+ close(passedch)
+
+ passed := map[string]seal.Seal{}
+ for i := range passedch {
+ pci := i[0].(string)
+ psl := i[1].(seal.Seal)
+
+ passed[pci] = psl
+ }
+
+ t.Equal(1, len(passed))
+ t.True(passed[ch0.ConnInfo().String()].Hash().Equal(sl.Hash()))
+ _, found := passed[ch1.ConnInfo().String()]
+ t.False(found)
+}
+
func TestNodepool(t *testing.T) {
suite.Run(t, new(testNodepool))
}
diff --git a/network/quic/channel.go b/network/quic/channel.go
index 9225196..f50a6a6 100644
--- a/network/quic/channel.go
+++ b/network/quic/channel.go
@@ -120,7 +120,7 @@ func (ch *Channel) Seals(ctx context.Context, hs []valuehash.Hash) ([]seal.Seal,
return seals, nil
}
-func (ch *Channel) SendSeal(ctx context.Context, sl seal.Seal) error {
+func (ch *Channel) SendSeal(ctx context.Context, ci network.ConnInfo, sl seal.Seal) error { // ci is the sender's ConnInfo and may be nil
timeout := network.ChannelTimeoutSendSeal
ctx, cancel := ch.timeoutContext(ctx, timeout)
defer cancel()
@@ -134,6 +134,9 @@ func (ch *Channel) SendSeal(ctx context.Context, sl seal.Seal) error {
headers := http.Header{}
headers.Set(QuicEncoderHintHeader, ch.enc.Hint().String())
+ if ci != nil { // lets the receiving server's passthroughs skip sending the seal back to ci
+ headers.Set(SendSealFromConnInfoHeader, ci.String())
+ }
res, err := ch.client.Send(ctx, timeout*2, ch.sendSealURL, b, headers)
if err != nil {
diff --git a/network/quic/primitive_server.go b/network/quic/primitive_server.go
index 1eef89d..6cfde32 100644
--- a/network/quic/primitive_server.go
+++ b/network/quic/primitive_server.go
@@ -17,7 +17,7 @@ import (
"github.com/spikeekips/mitum/util/logging"
)
-const QuicEncoderHintHeader string = "x-mitum-encoder-hint"
+const QuicEncoderHintHeader string = "X-MITUM-ENCODER-HINT" // http.Header Set/Get canonicalize keys, so the case only matters for direct map access
type PrimitiveQuicServer struct {
*logging.Logging
diff --git a/network/quic/server.go b/network/quic/server.go
index 6465cd1..283528a 100644
--- a/network/quic/server.go
+++ b/network/quic/server.go
@@ -2,6 +2,7 @@ package quicnetwork
import (
"bytes"
+ "context"
"io"
"net/http"
"net/url"
@@ -41,6 +42,10 @@ var LimitRequestByHeights = 20 // max number of reqeust heights
var cacheKeyNodeInfo = [2]byte{0x00, 0x00}
+const (
+ SendSealFromConnInfoHeader string = "X-MITUM-FROM-CONNINFO" // sender's ConnInfo string, set by Channel.SendSeal when its ci is not nil
+)
+
type Server struct {
*logging.Logging
*PrimitiveQuicServer
@@ -54,12 +59,16 @@ type Server struct {
blockDataHandler network.BlockDataHandler
cache cache.Cache
rg *singleflight.Group
+ connInfo network.ConnInfo // advertised as the sender when relaying seals
+ passthroughs func(context.Context, network.PassthroughedSeal, func(seal.Seal, network.Channel)) error // nil disables relaying
}
func NewServer(
prim *PrimitiveQuicServer,
encs *encoder.Encoders, enc encoder.Encoder,
ca cache.Cache,
+ connInfo network.ConnInfo,
+ passthroughs func(context.Context, network.PassthroughedSeal, func(seal.Seal, network.Channel)) error,
) (*Server, error) {
if ca == nil {
ca = cache.Dummy{}
@@ -74,6 +83,8 @@ func NewServer(
enc: enc,
cache: ca,
rg: &singleflight.Group{},
+ connInfo: connInfo,
+ passthroughs: passthroughs,
}
nqs.setHandlers()
@@ -192,11 +203,6 @@ func (sv *Server) handleGetSeals(w http.ResponseWriter, r *http.Request) {
}
func (sv *Server) handleNewSeal(w http.ResponseWriter, r *http.Request) {
- if sv.newSealHandler == nil {
- network.HTTPError(w, http.StatusInternalServerError)
- return
- }
-
body := &bytes.Buffer{}
if _, err := io.Copy(body, r.Body); err != nil {
sv.Log().Error().Err(err).Msg("failed to read post body")
@@ -220,6 +226,17 @@ func (sv *Server) handleNewSeal(w http.ResponseWriter, r *http.Request) {
return
}
+ go func() { // NOTE(review): relays before the hasSealHandler check below, so already-received seals are relayed again — confirm intended
+ if err := sv.doPassthroughs(r, sl); err != nil {
+ sv.Log().Error().Err(err).Msg("failed to passthroughs")
+ }
+ }()
+
+ if sv.newSealHandler == nil { // checked after relaying, so a node without a handler still relays
+ network.HTTPError(w, http.StatusInternalServerError)
+ return
+ }
+
// NOTE if already received, returns 200
if sv.hasSealHandler != nil {
if found, err := sv.hasSealHandler(sl.Hash()); err != nil {
@@ -424,6 +441,25 @@ func (sv *Server) logNilHanders() {
sv.Log().Debug().Strs("enabled", enables).Strs("disabled", disables).Msg("check handler")
}
+func (sv *Server) doPassthroughs(r *http.Request, sl seal.Seal) error { // relays sl, tagged with the request's SendSealFromConnInfoHeader, through sv.passthroughs
+ if sv.passthroughs == nil {
+ return nil
+ }
+
+ return sv.passthroughs(
+ context.Background(), // not r.Context(): the caller runs this in a goroutine that can outlive the request
+ network.NewPassthroughedSeal(sl, strings.TrimSpace(r.Header.Get(SendSealFromConnInfoHeader))),
+ func(sl seal.Seal, ch network.Channel) {
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+ defer cancel()
+
+ if err := ch.SendSeal(ctx, sv.connInfo, sl); err != nil { // NOTE(review): sl is what Passthroughs passes in; if that is the PassthroughedSeal wrapper, confirm the channel encodes it like the inner seal
+ sv.Log().Error().Err(err).Stringer("remote", ch.ConnInfo()).Msg("failed to passthrough seal")
+ }
+ },
+ )
+}
+
func mustQuicURL(u, p string) (string, *url.URL) {
uu, err := network.ParseURL(u, false)
if err != nil {
diff --git a/network/quic/server_test.go b/network/quic/server_test.go
index b1795db..6a3173e 100644
--- a/network/quic/server_test.go
+++ b/network/quic/server_test.go
@@ -76,7 +76,7 @@ func (t *testQuicServer) readyServer() *Server {
ca, err := cache.NewGCache("lru", 100, time.Second*3)
t.NoError(err)
- qn, err := NewServer(qs, t.encs, t.enc, ca)
+ qn, err := NewServer(qs, t.encs, t.enc, ca, t.connInfo, nil)
t.NoError(err)
t.NoError(qn.Start())
@@ -106,7 +106,7 @@ func (t *testQuicServer) TestNew() {
qs, err := NewPrimitiveQuicServer(t.bind, t.certs, nil)
t.NoError(err)
- qn, err := NewServer(qs, t.encs, t.enc, nil)
+ qn, err := NewServer(qs, t.encs, t.enc, nil, t.connInfo, nil)
t.NoError(err)
t.Implements((*network.Server)(nil), qn)
@@ -129,7 +129,7 @@ func (t *testQuicServer) TestSendSeal() {
sl := seal.NewDummySeal(key.MustNewBTCPrivatekey())
- t.NoError(qc.SendSeal(context.TODO(), sl))
+ t.NoError(qc.SendSeal(context.TODO(), nil, sl))
select {
case <-time.After(time.Second):
@@ -148,7 +148,7 @@ func (t *testQuicServer) TestSendSeal() {
return true, nil
})
- t.NoError(qc.SendSeal(context.TODO(), sl))
+ t.NoError(qc.SendSeal(context.TODO(), nil, sl))
}
func (t *testQuicServer) TestGetSeals() {
@@ -390,6 +390,80 @@ func (t *testQuicServer) TestGetBlockData() {
t.Equal(data, b)
}
+func (t *testQuicServer) TestPassthroughs() {
+ qn := t.readyServer()
+ defer qn.Stop()
+
+ qc, err := NewChannel(t.connInfo, 2, nil, t.encs, t.enc)
+ t.NoError(err)
+ t.Implements((*network.Channel)(nil), qc)
+
+ // attach Nodepool
+ local := node.RandomLocal("local")
+ ns := network.NewNodepool(local, qc)
+
+ ch0 := network.NilConnInfoChannel("n0")
+ t.NoError(ns.SetPassthrough(ch0, nil))
+
+ passedch := make(chan seal.Seal, 10)
+ ch0.SetNewSealHandler(func(sl seal.Seal) error {
+ passedch <- sl
+ return nil
+ })
+
+ qn.passthroughs = ns.Passthroughs // relay seals received by qn through the nodepool
+
+ sl := seal.NewDummySeal(key.MustNewBTCPrivatekey())
+
+ t.NoError(qc.SendSeal(context.TODO(), t.connInfo, sl)) // sender differs from ch0, so ch0 should receive the relayed seal
+
+ select {
+ case <-time.After(time.Second):
+ t.NoError(errors.Errorf("failed to receive respond"))
+ case r := <-passedch:
+ t.Equal(sl.Hint(), r.Hint())
+ t.True(sl.Hash().Equal(r.Hash()))
+ t.True(sl.BodyHash().Equal(r.BodyHash()))
+ t.True(sl.Signer().Equal(r.Signer()))
+ t.Equal(sl.Signature(), r.Signature())
+ t.True(localtime.Equal(sl.SignedAt(), r.SignedAt()))
+ }
+}
+
+func (t *testQuicServer) TestPassthroughsFilterFrom() {
+ qn := t.readyServer()
+ defer qn.Stop()
+
+ qc, err := NewChannel(t.connInfo, 2, nil, t.encs, t.enc)
+ t.NoError(err)
+ t.Implements((*network.Channel)(nil), qc)
+
+ // attach Nodepool
+ local := node.RandomLocal("local")
+ ns := network.NewNodepool(local, qc)
+
+ ch0 := network.NilConnInfoChannel("n0")
+ t.NoError(ns.SetPassthrough(ch0, nil))
+
+ passedch := make(chan seal.Seal, 10)
+ ch0.SetNewSealHandler(func(sl seal.Seal) error {
+ passedch <- sl
+ return nil
+ })
+
+ qn.passthroughs = ns.Passthroughs // relay seals received by qn through the nodepool
+
+ sl := seal.NewDummySeal(key.MustNewBTCPrivatekey())
+
+ t.NoError(qc.SendSeal(context.TODO(), ch0.ConnInfo(), sl)) // sender is ch0, so the relay must skip ch0
+
+ select {
+ case <-time.After(time.Second):
+ case <-passedch:
+ t.NoError(errors.Errorf("seal should be filtered"))
+ }
+}
+
func TestQuicServer(t *testing.T) {
suite.Run(t, new(testQuicServer))
}
diff --git a/network/tests.go b/network/tests.go
index cd60fcd..fe63b64 100644
--- a/network/tests.go
+++ b/network/tests.go
@@ -44,3 +44,8 @@ func CompareNodeInfo(t *testing.T, a, b NodeInfo) {
assert.Equal(t, as[i].Insecure, bs[i].Insecure)
}
}
+
+// NilConnInfoChannel returns a DummyChannel whose ConnInfo is NewNilConnInfo(s).
+func NilConnInfoChannel(s string) *DummyChannel {
+ return NewDummyChannel(NewNilConnInfo(s))
+}
diff --git a/states/basic/state.go b/states/basic/state.go
index e3aa31d..efe7d24 100644
--- a/states/basic/state.go
+++ b/states/basic/state.go
@@ -131,7 +131,9 @@ func (st *BaseState) BroadcastBallot(blt ballot.Ballot, toLocal bool) error {
return st.broadcastSealsFunc(blt, toLocal)
}
- return st.States.BroadcastBallot(blt, toLocal)
+ st.States.BroadcastBallot(blt, toLocal) // runs in the background; failures are logged there, not returned
+
+ return nil
}
func (st *BaseState) Timers() *localtime.Timers {
diff --git a/states/basic/states.go b/states/basic/states.go
index 5566ddd..6d36907 100644
--- a/states/basic/states.go
+++ b/states/basic/states.go
@@ -255,22 +255,12 @@ func (ss *States) NewProposal(proposal ballot.Proposal) {
// BroadcastBallot broadcast seal to the known nodes,
// - suffrage nodes
// - if toLocal is true, sends to local
-func (ss *States) BroadcastBallot(blt ballot.Ballot, toLocal bool) error {
- return ss.broadcast(blt, toLocal, func(node base.Node) bool {
+func (ss *States) BroadcastBallot(blt ballot.Ballot, toLocal bool) { // returns at once; the broadcast runs in its own goroutine
+ go ss.broadcast(blt, toLocal, func(node base.Node) bool {
return ss.suffrage.IsInside(node.Address())
})
}
-// BroadcastSeals broadcast seal to the known nodes,
-// - suffrage nodes
-// - and other nodes
-// - if toLocal is true, sends to local
-func (ss *States) BroadcastSeals(sl seal.Seal, toLocal bool) error {
- return ss.broadcast(sl, toLocal, func(base.Node) bool {
- return true
- })
-}
-
func (ss *States) BlockSavedHook() *pm.Hooks {
return ss.blockSavedHook
}
@@ -823,18 +813,16 @@ func (ss *States) broadcastOperationSealToSuffrageNodes(sl operation.Seal) {
return
}
- if err := ss.broadcast(sl, false, func(node base.Node) bool {
+ go ss.broadcast(sl, false, func(node base.Node) bool {
return ss.suffrage.IsInside(node.Address())
- }); err != nil {
- ss.Log().Error().Err(err).Msg("problem to broadcast operation.Seal to suffrage nodes")
- }
+ })
}
func (ss *States) broadcast(
sl seal.Seal,
toLocal bool,
filter func(node base.Node) bool,
-) error {
+) {
l := ss.Log().With().Stringer("seal_hash", sl.Hash()).Logger()
l.Debug().Dict("seal", LogSeal(sl)).Bool("to_local", toLocal).Msg("broadcasting seal")
@@ -848,32 +836,11 @@ func (ss *States) broadcast(
}()
}
- if ss.nodepool.LenRemoteAlives() < 1 {
- return nil
- }
-
// NOTE broadcast nodes of Nodepool, including suffrage nodes
- var targets int
- ss.nodepool.TraverseAliveRemotes(func(no base.Node, ch network.Channel) bool {
- if !filter(no) {
- return true
- }
-
- targets++
-
- go func(no base.Node, ch network.Channel) {
- ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
- defer cancel()
-
- if err := ch.SendSeal(ctx, sl); err != nil {
- l.Error().Err(err).Stringer("target_node", no.Address()).Msg("failed to broadcast")
- }
- }(no, ch)
-
- return true
- })
-
+ switch failed := ss.nodepool.Broadcast(context.Background(), sl, filter); { // Broadcast waits for its sends; failures are only logged
+ case len(failed) > 0:
+ l.Error().Errs("failed_nodes", failed).Msg("something wrong to broadcast seal")
+ default:
l.Debug().Msg("seal broadcasted")
-
- return nil
+ }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment