Skip to content

Instantly share code, notes, and snippets.

@cstockton
Created October 15, 2017 21:02
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save cstockton/5f2fea4ace0f099b4908a115409f2da3 to your computer and use it in GitHub Desktop.
Save cstockton/5f2fea4ace0f099b4908a115409f2da3 to your computer and use it in GitHub Desktop.
issue-22197
new file mode 100644
index 0000000..365941e
--- /dev/null
+++ b/src/cmd/compile/internal/ssa/chansendn.go
@@ -0,0 +1,282 @@
+package ssa
+
+import (
+ "bytes"
+ "cmd/compile/internal/types"
+ "cmd/internal/obj"
+ "fmt"
+ "strings"
+)
+
+func (s indVar) String() string {
+ return fmt.Sprintf("indVar in %v: %v", s.entry, s.ind)
+}
+
+func (s indVar) LongString() string {
+ var b bytes.Buffer
+ for _, val := range s.entry.Values {
+ b.WriteString(" " + val.LongString() + "\n")
+ }
+ vals := b.String()
+ return fmt.Sprintf("indVar in %v:\n for %v;\n %v < %v;\n %v INC(%v) {\n%v}",
+ s.entry.LongString(), s.ind.LongString(), s.min.LongString(), s.max.LongString(),
+ s.nxt.LongString(), s.inc.LongString(), vals)
+}
+
+// chansendn takes calls to runtime.chansend1 and converts them to runtime
+// chansendn1, which will perform multiple non-blocking channel sends in a single
+// call under the same mutex. If not all values may be sent it returns the number
+// remaining, allow to transform a loop body such as:
+//
+// loop:
+// ind = (Phi min nxt)
+// sendVal = v
+// if ind < max
+// then goto enter_loop
+// else goto exit_loop
+//
+// enter_loop:
+// CallStatic chansend1 [sendVal]
+// nxt = inc + ind
+// goto loop
+// exit_loop:
+//
+// Into:
+//
+// loop:
+// ind = (Phi min nxt),
+// sendVal = v
+// if ind < max
+// then goto enter_loop
+// else goto exit_loop
+//
+// enter_loop:
+// n = CallStatic chansendn1 [sendVal]
+// ind = n
+// goto loop
+// exit_loop:
+//
+// loop:
+// n = chansendn1(...)
+// if n > 0:
+// goto loop
+//
+// Conditions:
+//
+// 1. phi name must have enabledChansendn prefix [POC constraint]
+// 2. phi must have a loop body
+// 3. loop body must have a single [POC constraint] integer induction var
+// with a maximum value > 1
+// 4. loop condition increment must be +1 [POC constraint]
+// 5. loop must contain a single blocking channel send [POC constraint] on a
+// channel with a capacity > 1.
+// 6. channel send must be the same value each call
+// (prove it does not change within the loop body)
+//
+// i.e.:
+//
+// Eligible
+//
+// for i := 0; i < count; i++ {
+// ch <- struct{}{}
+// }
+//
+// Ineligble, I don't think it _has_ to be, but currently I'm avoiding it because
+// the generated code would be more complex
+//
+// for i := 0; i < count; i++ {
+// select {
+// case ch <- struct{}{}:
+// default:
+// }
+// }
+//
+// Ineligble: (could be allowed, but I imagine this rarely occurs)
+//
+// ch <- struct{}{}
+// ch <- struct{}{}
+// ch <- struct{}{}
+//
+// Note: this is just a poc and is far from correct, served as a learning
+// experience and introduction to this code. I make a lot of assumptions based
+// on observation of a small sampling, not on correct API usage.
+func chansendn(f *Func) {
+
+ // 1. phi name must have enabledChansendn prefix [POC constraint]
+ if !strings.HasPrefix(f.Name, `enableChansendn`) {
+ return
+ }
+
+ fmt.Println(`[chansendn] name enabled chansendn:`, f.Name)
+ printFunc(f)
+
+ // 2. phi must have a loop body
+ ivList := findIndVar(f)
+ if len(ivList) == 0 {
+ fmt.Println(` (ineligible) has no induction variables`)
+ return
+ }
+
+ for _, iv := range ivList {
+ if chansendnCheckIndVar(f, iv) {
+ return
+ }
+ }
+}
+
+func chansendnCheckIndVar(f *Func, iv indVar) (rewrote bool) {
+ // 3. loop body must have a single integer induction variable
+ if !iv.ind.Type.IsInteger() {
+ fmt.Println(` (ineligible) induction var must be integer`, iv.ind.Type)
+ return
+ }
+
+ // ... with a maximum value > 1
+ if iv.max.AuxInt <= 1 {
+ fmt.Println(` (ineligible) induction var max must be +1`, iv.max.AuxInt)
+ return
+ }
+
+ // 4. loop condition increment must be +1
+ if iv.inc.AuxInt != 1 {
+ fmt.Println(` (ineligible) induction var inc must be +1`, iv.inc.AuxInt)
+ return
+ }
+
+ // 5. loop must contain a single blocking channel send [POC constraint]
+ var call, callArg, chanVal, chanEl *Value
+ for i, v := range iv.entry.Values {
+
+ // must be a call static op to runtime.chansend1
+ if !isChansendCall(v) {
+ fmt.Println(` not a chansend call:`, v.LongString())
+ continue
+ }
+
+ // must contain a single send
+ if call != nil {
+ fmt.Println(` (ineligible) multiple channel sends are not supported`)
+ return
+ }
+
+ // ... must not have an operation that could change the control flow of the
+ // I naively prove this now by just allowing the operations needed for a
+ // single call static to chansend1. This helps with 6. too.
+ switch v.Op {
+ case
+ OpStaticCall, OpVarDef, OpStore, OpAddr:
+ default:
+ fmt.Println(` (ineligible) contains an unsupported op:`, v.Op)
+ return
+ }
+
+ // can have at most 1 more op, which must be VarKill for our static call.
+ if len(iv.entry.Values) <= i+1 || iv.entry.Values[i+1].Op != OpVarKill {
+ fmt.Println(` (ineligible) contains operations after chan send`)
+ }
+ call, callArg = v, v.Args[0]
+
+ // 6. channel send must be the same value each call
+ // (prove it does not change within the loop body)
+ var ok bool
+ if chanVal, chanEl, ok = chansendProveAndFind(callArg, iv.entry.Values, i); !ok {
+ fmt.Println(` (ineligible) unable to prove call argument`, callArg)
+ return
+ }
+ break // all conditions met
+ }
+
+ if call == nil || callArg == nil || chanVal == nil || chanEl == nil {
+ fmt.Println(` (ineligible) one or more conditions not satisfied`)
+ return
+ }
+ fmt.Println(` (eligible) rewriting to chansendn`)
+
+ // Rewrite:
+ // CallStatic chansend1 [sendVal]
+ // nxt = inc + ind
+ //
+ // Into:
+ // n = CallStatic chansendn1 [sendVal]
+ // ind = n
+ //
+ // This is difficult so I'm not doing it for just a POC, the layout for the
+ // call is different and I need to learn how to properly setup a return value.
+ rewrote = true
+
+ // chsendn := f.fe.Syslook(`chansendn1`)
+ // newCallArg = ...
+ // newCall = chsendn newCallArg
+ // ??
+ f.invalidateCFG()
+ return
+}
+
+// 6. channel send must be the same value each call
+// (prove it does not change within the loop body)
+func chansendProveAndFind(v *Value, blockVals []*Value, idx int) (chanVal, chanEl *Value, ok bool) {
+ if v.Op != OpStore {
+ fmt.Println(` (ineligible) arg did not originate from a Store`)
+ return
+ }
+ if len(v.Args) != 3 {
+ fmt.Println(` (ineligible) store op did not have 3 args`)
+ return
+ }
+ if v.Args[1].Op != OpAddr || v.Args[2].Op != OpStore {
+ fmt.Println(` (ineligible) setup for CallStatic may not be safe value for multiple sends`)
+ for i, arg := range v.Args {
+ fmt.Println(` arg:`, i, arg.LongString())
+ }
+ return
+ }
+
+ // ch is the 2nd arg in the Store
+ ch := v.Args[2].Args[1]
+ typ := ch.Type
+
+ // I don't know how to access the capacity of the channel through this API yet.
+ ok = typ.IsChan() && (typ.ChanDir() == types.Csend || typ.ChanDir() == types.Cboth)
+ if !ok {
+ fmt.Println(` (ineligible) chan type is not supported:`, ch)
+ return
+ }
+
+ chanEl, chanVal = v.Args[1], ch
+ ok = true
+ return
+}
+
+func isChansendEl(v *Value) bool {
+ if v.Aux == nil || v.Op != OpStore {
+ return false
+ }
+
+ ok := v.Type.IsChan() && (v.Type.ChanDir() == types.Csend || v.Type.ChanDir() == types.Cboth)
+ // ch := v.Type.ChanType()
+
+ lsym, ok := v.Aux.(*obj.LSym)
+ return ok && lsym.Name == `runtime.chansend1`
+}
+
+func isChansendCall(v *Value) bool {
+ switch v.Op {
+ case
+ OpStaticCall:
+
+ // These are after rewrite passes, I'm not sure this code could support
+ // being run at this phase.
+ //
+ // Op386CALLstatic, OpAMD64CALLstatic, OpARMCALLstatic, OpARM64CALLstatic,
+ // OpMIPSCALLstatic, OpMIPS64CALLstatic, OpPPC64CALLstatic, OpS390XCALLstatic:
+ return true
+ }
+
+ if v.Aux == nil || len(v.Args) != 1 {
+ // not a valid chansend call
+ return false
+ }
+
+ lsym, ok := v.Aux.(*obj.LSym)
+ return ok && lsym.Name == `runtime.chansend1`
+}
diff --git a/src/cmd/compile/internal/gc/builtin.go b/src/cmd/compile/internal/gc/builtin.go
index f21a4da..77efd85 100644
--- a/src/cmd/compile/internal/gc/builtin.go
+++ b/src/cmd/compile/internal/gc/builtin.go
@@ -98,57 +98,58 @@ var runtimeDecls = [...]struct {
{"chanrecv1", funcTag, 74},
{"chanrecv2", funcTag, 75},
{"chansend1", funcTag, 77},
+ {"chansendn1", funcTag, 78},
{"closechan", funcTag, 23},
- {"writeBarrier", varTag, 79},
- {"writebarrierptr", funcTag, 80},
- {"typedmemmove", funcTag, 81},
- {"typedmemclr", funcTag, 82},
- {"typedslicecopy", funcTag, 83},
- {"selectnbsend", funcTag, 84},
- {"selectnbrecv", funcTag, 85},
- {"selectnbrecv2", funcTag, 87},
- {"newselect", funcTag, 88},
- {"selectsend", funcTag, 89},
- {"selectrecv", funcTag, 90},
+ {"writeBarrier", varTag, 80},
+ {"writebarrierptr", funcTag, 81},
+ {"typedmemmove", funcTag, 82},
+ {"typedmemclr", funcTag, 83},
+ {"typedslicecopy", funcTag, 84},
+ {"selectnbsend", funcTag, 85},
+ {"selectnbrecv", funcTag, 86},
+ {"selectnbrecv2", funcTag, 88},
+ {"newselect", funcTag, 89},
+ {"selectsend", funcTag, 90},
+ {"selectrecv", funcTag, 91},
{"selectdefault", funcTag, 56},
- {"selectgo", funcTag, 91},
+ {"selectgo", funcTag, 92},
{"block", funcTag, 5},
- {"makeslice", funcTag, 93},
- {"makeslice64", funcTag, 94},
- {"growslice", funcTag, 95},
- {"memmove", funcTag, 96},
- {"memclrNoHeapPointers", funcTag, 97},
- {"memclrHasPointers", funcTag, 97},
- {"memequal", funcTag, 98},
- {"memequal8", funcTag, 99},
- {"memequal16", funcTag, 99},
- {"memequal32", funcTag, 99},
- {"memequal64", funcTag, 99},
- {"memequal128", funcTag, 99},
- {"int64div", funcTag, 100},
- {"uint64div", funcTag, 101},
- {"int64mod", funcTag, 100},
- {"uint64mod", funcTag, 101},
- {"float64toint64", funcTag, 102},
- {"float64touint64", funcTag, 103},
- {"float64touint32", funcTag, 105},
- {"int64tofloat64", funcTag, 106},
- {"uint64tofloat64", funcTag, 107},
- {"uint32tofloat64", funcTag, 108},
- {"complex128div", funcTag, 109},
- {"racefuncenter", funcTag, 110},
+ {"makeslice", funcTag, 94},
+ {"makeslice64", funcTag, 95},
+ {"growslice", funcTag, 96},
+ {"memmove", funcTag, 97},
+ {"memclrNoHeapPointers", funcTag, 98},
+ {"memclrHasPointers", funcTag, 98},
+ {"memequal", funcTag, 99},
+ {"memequal8", funcTag, 100},
+ {"memequal16", funcTag, 100},
+ {"memequal32", funcTag, 100},
+ {"memequal64", funcTag, 100},
+ {"memequal128", funcTag, 100},
+ {"int64div", funcTag, 101},
+ {"uint64div", funcTag, 102},
+ {"int64mod", funcTag, 101},
+ {"uint64mod", funcTag, 102},
+ {"float64toint64", funcTag, 103},
+ {"float64touint64", funcTag, 104},
+ {"float64touint32", funcTag, 106},
+ {"int64tofloat64", funcTag, 107},
+ {"uint64tofloat64", funcTag, 108},
+ {"uint32tofloat64", funcTag, 109},
+ {"complex128div", funcTag, 110},
+ {"racefuncenter", funcTag, 111},
{"racefuncexit", funcTag, 5},
- {"raceread", funcTag, 110},
- {"racewrite", funcTag, 110},
- {"racereadrange", funcTag, 111},
- {"racewriterange", funcTag, 111},
- {"msanread", funcTag, 111},
- {"msanwrite", funcTag, 111},
+ {"raceread", funcTag, 111},
+ {"racewrite", funcTag, 111},
+ {"racereadrange", funcTag, 112},
+ {"racewriterange", funcTag, 112},
+ {"msanread", funcTag, 112},
+ {"msanwrite", funcTag, 112},
{"support_popcnt", varTag, 11},
}
func runtimeTypes() []*types.Type {
- var typs [112]*types.Type
+ var typs [113]*types.Type
typs[0] = types.Bytetype
typs[1] = types.NewPtr(typs[0])
typs[2] = types.Types[TANY]
@@ -227,39 +228,40 @@ func runtimeTypes() []*types.Type {
typs[75] = functype(nil, []*Node{anonfield(typs[73]), anonfield(typs[3])}, []*Node{anonfield(typs[11])})
typs[76] = types.NewChan(typs[2], types.Csend)
typs[77] = functype(nil, []*Node{anonfield(typs[76]), anonfield(typs[3])}, nil)
- typs[78] = types.NewArray(typs[0], 3)
- typs[79] = tostruct([]*Node{namedfield("enabled", typs[11]), namedfield("pad", typs[78]), namedfield("needed", typs[11]), namedfield("cgo", typs[11]), namedfield("alignme", typs[17])})
- typs[80] = functype(nil, []*Node{anonfield(typs[3]), anonfield(typs[2])}, nil)
- typs[81] = functype(nil, []*Node{anonfield(typs[1]), anonfield(typs[3]), anonfield(typs[3])}, nil)
- typs[82] = functype(nil, []*Node{anonfield(typs[1]), anonfield(typs[3])}, nil)
- typs[83] = functype(nil, []*Node{anonfield(typs[1]), anonfield(typs[2]), anonfield(typs[2])}, []*Node{anonfield(typs[32])})
- typs[84] = functype(nil, []*Node{anonfield(typs[76]), anonfield(typs[3])}, []*Node{anonfield(typs[11])})
- typs[85] = functype(nil, []*Node{anonfield(typs[3]), anonfield(typs[73])}, []*Node{anonfield(typs[11])})
- typs[86] = types.NewPtr(typs[11])
- typs[87] = functype(nil, []*Node{anonfield(typs[3]), anonfield(typs[86]), anonfield(typs[73])}, []*Node{anonfield(typs[11])})
- typs[88] = functype(nil, []*Node{anonfield(typs[1]), anonfield(typs[15]), anonfield(typs[8])}, nil)
- typs[89] = functype(nil, []*Node{anonfield(typs[1]), anonfield(typs[76]), anonfield(typs[3])}, nil)
- typs[90] = functype(nil, []*Node{anonfield(typs[1]), anonfield(typs[73]), anonfield(typs[3]), anonfield(typs[86])}, nil)
- typs[91] = functype(nil, []*Node{anonfield(typs[1])}, []*Node{anonfield(typs[32])})
- typs[92] = types.NewSlice(typs[2])
- typs[93] = functype(nil, []*Node{anonfield(typs[1]), anonfield(typs[32]), anonfield(typs[32])}, []*Node{anonfield(typs[92])})
- typs[94] = functype(nil, []*Node{anonfield(typs[1]), anonfield(typs[15]), anonfield(typs[15])}, []*Node{anonfield(typs[92])})
- typs[95] = functype(nil, []*Node{anonfield(typs[1]), anonfield(typs[92]), anonfield(typs[32])}, []*Node{anonfield(typs[92])})
- typs[96] = functype(nil, []*Node{anonfield(typs[3]), anonfield(typs[3]), anonfield(typs[49])}, nil)
- typs[97] = functype(nil, []*Node{anonfield(typs[58]), anonfield(typs[49])}, nil)
- typs[98] = functype(nil, []*Node{anonfield(typs[3]), anonfield(typs[3]), anonfield(typs[49])}, []*Node{anonfield(typs[11])})
- typs[99] = functype(nil, []*Node{anonfield(typs[3]), anonfield(typs[3])}, []*Node{anonfield(typs[11])})
- typs[100] = functype(nil, []*Node{anonfield(typs[15]), anonfield(typs[15])}, []*Node{anonfield(typs[15])})
- typs[101] = functype(nil, []*Node{anonfield(typs[17]), anonfield(typs[17])}, []*Node{anonfield(typs[17])})
- typs[102] = functype(nil, []*Node{anonfield(typs[13])}, []*Node{anonfield(typs[15])})
- typs[103] = functype(nil, []*Node{anonfield(typs[13])}, []*Node{anonfield(typs[17])})
- typs[104] = types.Types[TUINT32]
- typs[105] = functype(nil, []*Node{anonfield(typs[13])}, []*Node{anonfield(typs[104])})
- typs[106] = functype(nil, []*Node{anonfield(typs[15])}, []*Node{anonfield(typs[13])})
- typs[107] = functype(nil, []*Node{anonfield(typs[17])}, []*Node{anonfield(typs[13])})
- typs[108] = functype(nil, []*Node{anonfield(typs[104])}, []*Node{anonfield(typs[13])})
- typs[109] = functype(nil, []*Node{anonfield(typs[19]), anonfield(typs[19])}, []*Node{anonfield(typs[19])})
- typs[110] = functype(nil, []*Node{anonfield(typs[49])}, nil)
- typs[111] = functype(nil, []*Node{anonfield(typs[49]), anonfield(typs[49])}, nil)
+ typs[78] = functype(nil, []*Node{anonfield(typs[76]), anonfield(typs[3])}, []*Node{anonfield(typs[32])})
+ typs[79] = types.NewArray(typs[0], 3)
+ typs[80] = tostruct([]*Node{namedfield("enabled", typs[11]), namedfield("pad", typs[79]), namedfield("needed", typs[11]), namedfield("cgo", typs[11]), namedfield("alignme", typs[17])})
+ typs[81] = functype(nil, []*Node{anonfield(typs[3]), anonfield(typs[2])}, nil)
+ typs[82] = functype(nil, []*Node{anonfield(typs[1]), anonfield(typs[3]), anonfield(typs[3])}, nil)
+ typs[83] = functype(nil, []*Node{anonfield(typs[1]), anonfield(typs[3])}, nil)
+ typs[84] = functype(nil, []*Node{anonfield(typs[1]), anonfield(typs[2]), anonfield(typs[2])}, []*Node{anonfield(typs[32])})
+ typs[85] = functype(nil, []*Node{anonfield(typs[76]), anonfield(typs[3])}, []*Node{anonfield(typs[11])})
+ typs[86] = functype(nil, []*Node{anonfield(typs[3]), anonfield(typs[73])}, []*Node{anonfield(typs[11])})
+ typs[87] = types.NewPtr(typs[11])
+ typs[88] = functype(nil, []*Node{anonfield(typs[3]), anonfield(typs[87]), anonfield(typs[73])}, []*Node{anonfield(typs[11])})
+ typs[89] = functype(nil, []*Node{anonfield(typs[1]), anonfield(typs[15]), anonfield(typs[8])}, nil)
+ typs[90] = functype(nil, []*Node{anonfield(typs[1]), anonfield(typs[76]), anonfield(typs[3])}, nil)
+ typs[91] = functype(nil, []*Node{anonfield(typs[1]), anonfield(typs[73]), anonfield(typs[3]), anonfield(typs[87])}, nil)
+ typs[92] = functype(nil, []*Node{anonfield(typs[1])}, []*Node{anonfield(typs[32])})
+ typs[93] = types.NewSlice(typs[2])
+ typs[94] = functype(nil, []*Node{anonfield(typs[1]), anonfield(typs[32]), anonfield(typs[32])}, []*Node{anonfield(typs[93])})
+ typs[95] = functype(nil, []*Node{anonfield(typs[1]), anonfield(typs[15]), anonfield(typs[15])}, []*Node{anonfield(typs[93])})
+ typs[96] = functype(nil, []*Node{anonfield(typs[1]), anonfield(typs[93]), anonfield(typs[32])}, []*Node{anonfield(typs[93])})
+ typs[97] = functype(nil, []*Node{anonfield(typs[3]), anonfield(typs[3]), anonfield(typs[49])}, nil)
+ typs[98] = functype(nil, []*Node{anonfield(typs[58]), anonfield(typs[49])}, nil)
+ typs[99] = functype(nil, []*Node{anonfield(typs[3]), anonfield(typs[3]), anonfield(typs[49])}, []*Node{anonfield(typs[11])})
+ typs[100] = functype(nil, []*Node{anonfield(typs[3]), anonfield(typs[3])}, []*Node{anonfield(typs[11])})
+ typs[101] = functype(nil, []*Node{anonfield(typs[15]), anonfield(typs[15])}, []*Node{anonfield(typs[15])})
+ typs[102] = functype(nil, []*Node{anonfield(typs[17]), anonfield(typs[17])}, []*Node{anonfield(typs[17])})
+ typs[103] = functype(nil, []*Node{anonfield(typs[13])}, []*Node{anonfield(typs[15])})
+ typs[104] = functype(nil, []*Node{anonfield(typs[13])}, []*Node{anonfield(typs[17])})
+ typs[105] = types.Types[TUINT32]
+ typs[106] = functype(nil, []*Node{anonfield(typs[13])}, []*Node{anonfield(typs[105])})
+ typs[107] = functype(nil, []*Node{anonfield(typs[15])}, []*Node{anonfield(typs[13])})
+ typs[108] = functype(nil, []*Node{anonfield(typs[17])}, []*Node{anonfield(typs[13])})
+ typs[109] = functype(nil, []*Node{anonfield(typs[105])}, []*Node{anonfield(typs[13])})
+ typs[110] = functype(nil, []*Node{anonfield(typs[19]), anonfield(typs[19])}, []*Node{anonfield(typs[19])})
+ typs[111] = functype(nil, []*Node{anonfield(typs[49])}, nil)
+ typs[112] = functype(nil, []*Node{anonfield(typs[49]), anonfield(typs[49])}, nil)
return typs[:]
}
diff --git a/src/cmd/compile/internal/gc/builtin/runtime.go b/src/cmd/compile/internal/gc/builtin/runtime.go
index 7f4846d..14efe35 100644
--- a/src/cmd/compile/internal/gc/builtin/runtime.go
+++ b/src/cmd/compile/internal/gc/builtin/runtime.go
@@ -120,6 +120,7 @@ func makechan(chanType *byte, hint int64) (hchan chan any)
func chanrecv1(hchan <-chan any, elem *any)
func chanrecv2(hchan <-chan any, elem *any) bool
func chansend1(hchan chan<- any, elem *any)
+func chansendn1(hchan chan<- any, elem *any) int
func closechan(hchan any)
var writeBarrier struct {
diff --git a/src/cmd/compile/internal/gc/go.go b/src/cmd/compile/internal/gc/go.go
index b1ead93..f0a133a 100644
--- a/src/cmd/compile/internal/gc/go.go
+++ b/src/cmd/compile/internal/gc/go.go
@@ -287,6 +287,7 @@ var (
assertE2I2,
assertI2I,
assertI2I2,
+ chansendn1,
goschedguarded,
writeBarrier,
writebarrierptr,
diff --git a/src/cmd/compile/internal/gc/ssa.go b/src/cmd/compile/internal/gc/ssa.go
index 9c1b3ca..de3768f 100644
--- a/src/cmd/compile/internal/gc/ssa.go
+++ b/src/cmd/compile/internal/gc/ssa.go
@@ -87,6 +87,7 @@ func initssaconfig() {
assertI2I = Sysfunc("assertI2I")
assertI2I2 = Sysfunc("assertI2I2")
goschedguarded = Sysfunc("goschedguarded")
+ chansendn1 = Sysfunc("chansendn1")
writeBarrier = Sysfunc("writeBarrier")
writebarrierptr = Sysfunc("writebarrierptr")
typedmemmove = Sysfunc("typedmemmove")
@@ -5066,6 +5067,8 @@ func (e *ssafn) Syslook(name string) *obj.LSym {
switch name {
case "goschedguarded":
return goschedguarded
+ case "chansendn1":
+ return chansendn1
case "writeBarrier":
return writeBarrier
case "writebarrierptr":
diff --git a/src/cmd/compile/internal/ssa/compile.go b/src/cmd/compile/internal/ssa/compile.go
index 315416b..4195ef9 100644
--- a/src/cmd/compile/internal/ssa/compile.go
+++ b/src/cmd/compile/internal/ssa/compile.go
@@ -343,6 +343,7 @@ var passes = [...]pass{
{name: "nilcheckelim", fn: nilcheckelim},
{name: "prove", fn: prove},
{name: "loopbce", fn: loopbce},
+ {name: "chansendn", fn: chansendn, required: true},
{name: "decompose builtin", fn: decomposeBuiltIn, required: true},
{name: "dec", fn: dec, required: true},
{name: "late opt", fn: opt, required: true}, // TODO: split required rules and optimizing rules
diff --git a/src/runtime/chan.go b/src/runtime/chan.go
index 6294678..4e6693e 100644
--- a/src/runtime/chan.go
+++ b/src/runtime/chan.go
@@ -113,6 +113,115 @@ func chansend1(c *hchan, elem unsafe.Pointer) {
chansend(c, elem, true, getcallerpc(unsafe.Pointer(&c)))
}
+// entry point for a blocking chan send of the same value n times:
+//
+//	FOR { c <- x }
+//
+// implemented as:
+//
+//	FOR {
+//		for n := chansendn1(...); n > 0; n-- {
+//			chansend1(...)
+//		}
+//	}
+//
+// NOTE(review): the compiler-side declaration added by this patch
+// (func chansendn1(hchan chan<- any, elem *any) int) has no count parameter,
+// while this definition takes n. The two must be reconciled before the
+// compiler can emit calls to this function.
+//
+//go:nosplit
+func chansendn1(c *hchan, elem unsafe.Pointer, n int) int {
+	return chansendn(c, elem, n, getcallerpc(unsafe.Pointer(&c)))
+}
+
+// chansendn sends the value at ep on c up to n times while holding c.lock
+// once, and returns how many of the n sends are still outstanding.
+//
+// Waiting receivers are served first, then free buffer slots are filled. If
+// neither made progress, the call falls back to a single blocking send, so a
+// call with n > 0 always completes at least one send (or panics because the
+// channel is closed).
+//
+// Changes relative to the first draft:
+//   - the result is the remaining count, not the unmodified n;
+//   - at most n receivers are dequeued;
+//   - the blocking path no longer calls unlock after goparkunlock has
+//     already released c.lock, and it counts the send it completed;
+//   - the unlocked "would block" fast path copied from the non-blocking
+//     branch of chansend is gone, because returning n from it contradicts
+//     the "at least one send" contract and turns a caller loop into a spin;
+//   - n <= 0 returns immediately instead of sending anyway.
+func chansendn(c *hchan, ep unsafe.Pointer, n int, callerpc uintptr) int {
+	if c == nil {
+		gopark(nil, nil, "chan send (nil chan)", traceEvGoStop, 2)
+		throw("unreachable")
+	}
+	if n <= 0 {
+		// Nothing was requested, so nothing remains.
+		return 0
+	}
+	if debugChan {
+		print("chansendn: chan=", c, "\n")
+	}
+	if raceenabled {
+		racereadpc(unsafe.Pointer(c), callerpc, funcPC(chansendn))
+	}
+
+	var t0 int64
+	if blockprofilerate > 0 {
+		t0 = cputicks()
+	}
+
+	remaining := n
+	lock(&c.lock)
+	if c.closed != 0 {
+		unlock(&c.lock)
+		panic(plainError("send on closed channel"))
+	}
+
+	// Hand the value directly to waiting receivers, but never to more than n
+	// of them.
+	// NOTE(review): the no-op unlock callback means send readies the receiver
+	// while c.lock is still held - confirm this is acceptable.
+	for remaining > 0 {
+		sg := c.recvq.dequeue()
+		if sg == nil {
+			break
+		}
+		send(c, sg, ep, func() {}, 3)
+		remaining--
+	}
+
+	// Fill whatever buffer space is left.
+	for c.qcount < c.dataqsiz && remaining > 0 {
+		qp := chanbuf(c, c.sendx)
+		if raceenabled {
+			raceacquire(qp)
+			racerelease(qp)
+		}
+		typedmemmove(c.elemtype, qp, ep)
+		c.sendx++
+		if c.sendx == c.dataqsiz {
+			c.sendx = 0
+		}
+		c.qcount++
+		remaining--
+	}
+
+	if remaining < n {
+		// Progress was made without blocking; report what is left.
+		unlock(&c.lock)
+		return remaining
+	}
+
+	// We made no progress, so we will fall back to a blocking send.
+	// Block on the channel. Some receiver will complete our operation for us.
+	gp := getg()
+	mysg := acquireSudog()
+	mysg.releasetime = 0
+	if t0 != 0 {
+		mysg.releasetime = -1
+	}
+	// No stack splits between assigning elem and enqueuing mysg
+	// on gp.waiting where copystack can find it.
+	mysg.elem = ep
+	mysg.waitlink = nil
+	mysg.g = gp
+	mysg.selectdone = nil
+	mysg.c = c
+	gp.waiting = mysg
+	gp.param = nil
+	c.sendq.enqueue(mysg)
+	// goparkunlock releases c.lock; it must not be unlocked again below.
+	goparkunlock(&c.lock, "chan send", traceEvGoBlockSend, 3)
+
+	// someone woke us up.
+	if mysg != gp.waiting {
+		throw("G waiting list is corrupted")
+	}
+	gp.waiting = nil
+	if gp.param == nil {
+		if c.closed == 0 {
+			throw("chansend: spurious wakeup")
+		}
+		panic(plainError("send on closed channel"))
+	}
+	gp.param = nil
+	if mysg.releasetime > 0 {
+		blockevent(mysg.releasetime-t0, 2)
+	}
+	mysg.c = nil
+	releaseSudog(mysg)
+
+	// Exactly one value was delivered by the blocking send.
+	return n - 1
+}
+
/*
* generic single channel send/recv
* If block is not nil,
package ssa
import (
"bytes"
"cmd/compile/internal/types"
"cmd/internal/obj"
"fmt"
"strings"
)
// String returns a one-line summary of the induction variable: the value of
// s.entry followed by the value of s.ind.
func (s indVar) String() string {
	const layout = "indVar in %v: %v"
	return fmt.Sprintf(layout, s.entry, s.ind)
}
// LongString returns a multi-line, for-loop-shaped dump of the induction
// variable. The header is built from the LongString of s.entry and of the
// ind, min, max, nxt and inc values; the braces enclose one line per value in
// s.entry.Values.
func (s indVar) LongString() string {
	var body bytes.Buffer
	for i := range s.entry.Values {
		body.WriteString(" ")
		body.WriteString(s.entry.Values[i].LongString())
		body.WriteString("\n")
	}
	return fmt.Sprintf(
		"indVar in %v:\n for %v;\n %v < %v;\n %v INC(%v) {\n%v}",
		s.entry.LongString(),
		s.ind.LongString(),
		s.min.LongString(),
		s.max.LongString(),
		s.nxt.LongString(),
		s.inc.LongString(),
		body.String(),
	)
}
// chansendn takes calls to runtime.chansend1 and converts them to
// runtime.chansendn1, which will perform multiple non-blocking channel sends in
// a single call under the same mutex. If not all values can be sent it returns
// the number remaining, which allows a loop body such as the following to be
// transformed:
//
// loop:
// ind = (Phi min nxt)
// sendVal = v
// if ind < max
// then goto enter_loop
// else goto exit_loop
//
// enter_loop:
// CallStatic chansend1 [sendVal]
// nxt = inc + ind
// goto loop
// exit_loop:
//
// Into:
//
// loop:
// ind = (Phi min nxt),
// sendVal = v
// if ind < max
// then goto enter_loop
// else goto exit_loop
//
// enter_loop:
// n = CallStatic chansendn1 [sendVal]
// ind = n
// goto loop
// exit_loop:
//
// loop:
// n = chansendn1(...)
// if n > 0:
// goto loop
//
// Conditions:
//
// 1. function name must have the enableChansendn prefix [POC constraint]
// 2. function must contain a loop with an induction variable
// 3. loop body must have a single [POC constraint] integer induction var
// with a maximum value > 1
// 4. loop condition increment must be +1 [POC constraint]
// 5. loop must contain a single blocking channel send [POC constraint] on a
// channel with a capacity > 1.
// 6. channel send must be the same value each call
// (prove it does not change within the loop body)
//
// i.e.:
//
// Eligible
//
// for i := 0; i < count; i++ {
// ch <- struct{}{}
// }
//
// Ineligible: I don't think it _has_ to be, but currently I'm avoiding it
// because the generated code would be more complex
//
// for i := 0; i < count; i++ {
// select {
// case ch <- struct{}{}:
// default:
// }
// }
//
// Ineligible: (could be allowed, but I imagine this rarely occurs)
//
// ch <- struct{}{}
// ch <- struct{}{}
// ch <- struct{}{}
//
// Note: this is just a POC and is far from correct; it served as a learning
// experience and an introduction to this code. I make a lot of assumptions
// based on observation of a small sampling, not on correct API usage.
func chansendn(f *Func) {
	// 1. POC gate: only functions whose name starts with enableChansendn are
	// looked at by this pass.
	if !strings.HasPrefix(f.Name, `enableChansendn`) {
		return
	}
	fmt.Println(`[chansendn] name enabled chansendn:`, f.Name)
	printFunc(f)

	// 2. the function must have at least one induction variable
	candidates := findIndVar(f)
	if len(candidates) == 0 {
		fmt.Println(` (ineligible) has no induction variables`)
		return
	}

	// Stop at the first induction variable that is reported as rewritten.
	for i := 0; i < len(candidates); i++ {
		if rewrote := chansendnCheckIndVar(f, candidates[i]); rewrote {
			break
		}
	}
}
// chansendnCheckIndVar reports whether the loop described by iv satisfies POC
// conditions 3-6 for replacing its runtime.chansend1 call with
// runtime.chansendn1. Every rejected condition is printed to stdout.
//
// The rewrite itself is not implemented yet: when every check passes the
// function only invalidates the CFG and returns true.
func chansendnCheckIndVar(f *Func, iv indVar) (rewrote bool) {
	// 3. loop body must have a single integer induction variable
	if !iv.ind.Type.IsInteger() {
		fmt.Println(` (ineligible) induction var must be integer`, iv.ind.Type)
		return false
	}

	// ... with a maximum value > 1
	// NOTE(review): AuxInt is read without checking iv.max.Op; presumably it
	// only carries the bound when max is a constant - confirm.
	if iv.max.AuxInt <= 1 {
		fmt.Println(` (ineligible) induction var max must be > 1`, iv.max.AuxInt)
		return false
	}

	// 4. loop condition increment must be +1
	if iv.inc.AuxInt != 1 {
		fmt.Println(` (ineligible) induction var inc must be +1`, iv.inc.AuxInt)
		return false
	}

	// 5. loop must contain a single blocking channel send [POC constraint]
	var call, callArg, chanVal, chanEl *Value
	for i, v := range iv.entry.Values {
		// must be a call static op to runtime.chansend1
		if !isChansendCall(v) {
			fmt.Println(` not a chansend call:`, v.LongString())
			continue
		}

		// must contain a single send. The scan keeps going after the first
		// send has been accepted so that a second one is actually seen here;
		// breaking out early made this rejection unreachable.
		if call != nil {
			fmt.Println(` (ineligible) multiple channel sends are not supported`)
			return false
		}

		// ... must not be an operation that could change the control flow of
		// the loop. This is proved naively by only allowing the operations
		// needed for a single call static to chansend1. This helps with 6. too.
		switch v.Op {
		case OpStaticCall, OpVarDef, OpStore, OpAddr:
		default:
			fmt.Println(` (ineligible) contains an unsupported op:`, v.Op)
			return false
		}

		// the value right after the call must exist and be the VarKill for our
		// static call. This used to print the message and carry on regardless.
		if len(iv.entry.Values) <= i+1 || iv.entry.Values[i+1].Op != OpVarKill {
			fmt.Println(` (ineligible) contains operations after chan send`)
			return false
		}

		// Args[0] is inspected below, so make sure it is there.
		if len(v.Args) == 0 {
			fmt.Println(` (ineligible) unable to prove call argument`, v)
			return false
		}
		call, callArg = v, v.Args[0]

		// 6. channel send must be the same value each call
		// (prove it does not change within the loop body)
		var ok bool
		if chanVal, chanEl, ok = chansendProveAndFind(callArg, iv.entry.Values, i); !ok {
			fmt.Println(` (ineligible) unable to prove call argument`, callArg)
			return false
		}
	}

	if call == nil || callArg == nil || chanVal == nil || chanEl == nil {
		fmt.Println(` (ineligible) one or more conditions not satisfied`)
		return false
	}
	fmt.Println(` (eligible) rewriting to chansendn`)

	// Rewrite:
	//	CallStatic chansend1 [sendVal]
	//	nxt = inc + ind
	//
	// Into:
	//	n = CallStatic chansendn1 [sendVal]
	//	ind = n
	//
	// This is difficult so it is not done for just a POC: the layout for the
	// call is different and a return value has to be set up properly.
	//
	//	chsendn := f.fe.Syslook(`chansendn1`)
	//	newCallArg = ...
	//	newCall = chsendn newCallArg
	//	??
	f.invalidateCFG()
	return true
}
// chansendProveAndFind implements condition 6: the channel send must be the
// same value each call (prove it does not change within the loop body).
//
// v is expected to be a three-argument Store whose second argument is an Addr
// and whose third argument is another Store. On success it returns the second
// argument of that inner Store as chanVal, the Addr as chanEl, and ok == true;
// chanVal's type must be a channel with direction Csend or Cboth. Otherwise it
// prints the reason and returns nil, nil, false.
//
// blockVals and idx are currently unused.
func chansendProveAndFind(v *Value, blockVals []*Value, idx int) (chanVal, chanEl *Value, ok bool) {
	switch {
	case v.Op != OpStore:
		fmt.Println(` (ineligible) arg did not originate from a Store`)
		return nil, nil, false
	case len(v.Args) != 3:
		fmt.Println(` (ineligible) store op did not have 3 args`)
		return nil, nil, false
	}

	elemAddr, prevStore := v.Args[1], v.Args[2]
	if !(elemAddr.Op == OpAddr && prevStore.Op == OpStore) {
		fmt.Println(` (ineligible) setup for CallStatic may not be safe value for multiple sends`)
		for n := 0; n < len(v.Args); n++ {
			fmt.Println(` arg:`, n, v.Args[n].LongString())
		}
		return nil, nil, false
	}

	// ch is the 2nd arg in the preceding Store
	ch := prevStore.Args[1]
	chType := ch.Type

	// I don't know how to access the capacity of the channel through this API yet.
	sendable := chType.IsChan() && (chType.ChanDir() == types.Csend || chType.ChanDir() == types.Cboth)
	if !sendable {
		fmt.Println(` (ineligible) chan type is not supported:`, ch)
		return nil, nil, false
	}
	return ch, elemAddr, true
}
// isChansendEl reports whether v is a Store with a non-nil Aux whose type is a
// channel with direction Csend or Cboth and whose Aux is the *obj.LSym named
// runtime.chansend1.
//
// The channel-type check used to be computed and then overwritten by the
// result of the type assertion, so it never influenced the return value; both
// conditions are now required.
//
// NOTE(review): nothing in this file calls this helper, so the intended
// combination of the two checks could not be confirmed from a caller.
func isChansendEl(v *Value) bool {
	if v.Aux == nil || v.Op != OpStore {
		return false
	}

	if !v.Type.IsChan() {
		return false
	}
	if dir := v.Type.ChanDir(); dir != types.Csend && dir != types.Cboth {
		return false
	}
	// ch := v.Type.ChanType()

	lsym, ok := v.Aux.(*obj.LSym)
	return ok && lsym.Name == `runtime.chansend1`
}
// isChansendCall reports whether v is a static call to runtime.chansend1: an
// OpStaticCall with exactly one argument whose Aux is the *obj.LSym named
// runtime.chansend1.
//
// The previous version returned true for every OpStaticCall and only compared
// the symbol name for values of other ops, so any static call in the loop body
// was treated as a channel send.
func isChansendCall(v *Value) bool {
	// Only the generic op is matched. The architecture-specific forms
	//
	//	Op386CALLstatic, OpAMD64CALLstatic, OpARMCALLstatic, OpARM64CALLstatic,
	//	OpMIPSCALLstatic, OpMIPS64CALLstatic, OpPPC64CALLstatic, OpS390XCALLstatic
	//
	// appear after the rewrite passes, and it is not clear this code could
	// support being run at that phase.
	if v.Op != OpStaticCall {
		return false
	}

	if v.Aux == nil || len(v.Args) != 1 {
		// not a valid chansend call
		return false
	}

	lsym, ok := v.Aux.(*obj.LSym)
	return ok && lsym.Name == `runtime.chansend1`
}
Some very low quality work that served as a learning exercise to better understand the effort
involved with my request in https://github.com/golang/go/issues/22197. It makes me really appreciate
the engineering effort in the Go compiler; it would take a serious time investment for me to
create a patch with enough quality that it would not waste the reviewers' time.
package main
import "fmt"
// count is the loop bound used by the enableChansendnTesting* functions and the
// buffer capacity main gives the channel it passes to them.
const count = 8
/*
Notes:
cmd/compile/internal/gc/ssa.go
buildssa(...) -> ssa.Compile(s.f)
cmd/compile/internal/ssa/compile.go
-> ssa.Compile(s.f)
GOSSAFUNC=sendN go run main.go 2>&1 | grep pass
cd $GOROOT/src/cmd/compile
go install -a cmd/compile/internal/ssa && go install cmd/compile \
&& go run main.go -o tmp /ws/cwrk/src/github.com/cstockton/go22197/go22197/main.go
*/
// enableChansendnTestingSuccess1 sends count empty-struct values on ch, one per
// loop iteration.
//
// It is a fixture for the chansendn compiler pass: the enableChansendn name
// prefix is what that pass checks before inspecting a function, and the loop
// has the "Eligible" shape from the pass's documentation. The loop is therefore
// left exactly as written.
func enableChansendnTestingSuccess1(ch chan struct{}) {
for i := 0; i < count; i++ {
ch <- struct{}{}
}
}
// enableChansendnTestingSuccess2 sends val on ch count times, one send per loop
// iteration; the sent value does not change inside the loop.
//
// Like enableChansendnTestingSuccess1 it is a fixture for the chansendn
// compiler pass (name prefix plus a +1 counting loop around a single send), so
// the loop shape is kept exactly as written.
func enableChansendnTestingSuccess2(ch chan int, val int) {
for i := 0; i < count; i++ {
ch <- val
}
}
// // not increment of 1 each iteration
// func enableChansendnTestingFailureInc(ch chan int, val int) {
// for i := 0; i < count; i += 2 {
// ch <- val
// }
// }
//
// // Not the same value each iteration
// func enableChansendnTestingFailure1(ch chan int) {
// for i := 0; i < count; i++ {
// ch <- i
// }
// }
//
// // Not the same value each iteration
// func enableChansendnTestingFailure2(ch chan int) {
// for i := 0; i < count; i++ {
// ch <- i
// }
// }
//
// func enableChansendnTestingFailure3(ch chan struct{}) {
// for i := 0; i < count; i++ {
// select {
// case ch <- struct{}{}:
// default:
// }
// }
// }
//
// func enableChansendnTestingFailure4(ch chan struct{}) {
// ch <- struct{}{}
// ch <- struct{}{}
// ch <- struct{}{}
// }
// main fills a channel buffered to count entries through
// enableChansendnTestingSuccess2, closes it, and then drains it, printing a
// running tally for every value received and `done` at the end.
func main() {
	// ch := make(chan struct{}, count)
	// enableChansendnTestingSuccess1(ch)
	values := make(chan int, count)
	enableChansendnTestingSuccess2(values, 6)
	close(values)

	received := 0
	for {
		// The second result turns false once the closed channel is empty.
		if _, open := <-values; !open {
			break
		}
		received++
		fmt.Printf(" got value %v of %v\n", received, count)
	}
	println(`done`)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment