Skip to content

Instantly share code, notes, and snippets.

@carterpeel
Created February 7, 2022 20:32
Show Gist options
  • Save carterpeel/8e27657541baf72d2989a8f6770adfbe to your computer and use it in GitHub Desktop.
Save carterpeel/8e27657541baf72d2989a8f6770adfbe to your computer and use it in GitHub Desktop.
Async IO in Go
package asyncio
import (
"io"
"log"
"sync"
)
// AsyncMultiWriter fans every Write out to a set of named io.Writers, either
// one after another or concurrently, depending on the strategy chosen by
// checkAsyncThreshold. Create instances with NewAsyncMultiWriter: the zero
// value is not usable because mu and wg are nil pointers.
type AsyncMultiWriter struct {
	// mu is locked by the registration methods and by the write strategies
	// before they touch the fields below.
	mu *sync.Mutex

	// writers holds the registered destinations in registration order.
	writers []io.Writer
	// indexMap maps each registered name to its position in writers.
	indexMap map[string]int

	// asyncThreshold is the writer count compared against by checkAsyncThreshold.
	asyncThreshold int
	// writeFn is the active write strategy (writeSeq or writeAsync).
	writeFn func([]byte) (int, error)
	// wg waits for the per-writer goroutines started by writeAsync.
	wg *sync.WaitGroup
}
// NewAsyncMultiWriter returns an AsyncMultiWriter with no writers, an async
// threshold of 2, and the sequential write strategy selected.
func NewAsyncMultiWriter() *AsyncMultiWriter {
	w := new(AsyncMultiWriter)
	w.mu = new(sync.Mutex)
	w.wg = new(sync.WaitGroup)
	w.writers = []io.Writer{}
	w.indexMap = map[string]int{}
	w.asyncThreshold = 2
	// Start sequential; checkAsyncThreshold revisits this as writers change.
	w.writeFn = w.writeSeq
	return w
}
// SetAsyncThreshold stores a new writer-count threshold and immediately
// re-evaluates the write strategy against the current number of writers.
func (bw *AsyncMultiWriter) SetAsyncThreshold(threshold int) {
	mu := bw.mu
	mu.Lock()
	defer mu.Unlock()

	bw.asyncThreshold = threshold
	// Re-check now so the next Write already uses the matching strategy.
	bw.checkAsyncThreshold()
}
// checkAsyncThreshold selects the write strategy from the number of registered
// writers: writeAsync once the count reaches asyncThreshold, writeSeq below it.
// The caller must hold bw.mu.
func (bw *AsyncMultiWriter) checkAsyncThreshold() {
	count := len(bw.writers)
	if count < bw.asyncThreshold {
		bw.writeFn = bw.writeSeq
		return
	}
	// The mode stays asynchronous for every count at or above the threshold,
	// but we only log when the count lands exactly on it; logging on every
	// call past it would spam the log each time another writer is added.
	if count == bw.asyncThreshold {
		log.Logger.WithField("category", "Audio MultiWriter").Infof("Writer threshold reached! (Writers: %d || Threshold: %d)", len(bw.writers), bw.asyncThreshold)
		log.Logger.WithField("category", "Audio MultiWriter").Infoln("Enabling asynchronous streaming to compensate for threshold delay...")
	}
	bw.writeFn = bw.writeAsync
}
// AddWriter adds a writer and ties the writer index to the provided name.
//
// If a writer is already registered under name, it is replaced in place.
// Returns NameCannotBeOmitted when name is empty.
func (bw *AsyncMultiWriter) AddWriter(writer io.Writer, name string) error {
	if name == "" {
		return NameCannotBeOmitted
	}
	bw.mu.Lock()
	defer bw.mu.Unlock()
	// Appending under an existing name would orphan the previous writer: it
	// would keep receiving data but RemoveWriter could no longer find it.
	if index, ok := bw.indexMap[name]; ok {
		bw.writers[index] = writer
		return nil // writer count unchanged, so the strategy is unchanged too
	}
	bw.writers = append(bw.writers, writer)
	bw.indexMap[name] = len(bw.writers) - 1
	bw.checkAsyncThreshold()
	return nil
}
// RemoveWriter removes the writer corresponding with the provided name and
// shifts the indices of every writer registered after it down by one.
//
// Name cannot be omitted. Returns NameCannotBeOmitted for an empty name and
// WriterNotFound when no writer is registered under name.
func (bw *AsyncMultiWriter) RemoveWriter(name string) error {
	if name == "" {
		return NameCannotBeOmitted
	}
	bw.mu.Lock()
	defer bw.mu.Unlock()
	index, ok := bw.indexMap[name]
	if !ok {
		return WriterNotFound
	}
	// Shift the tail left over the removed slot, then nil out the vacated
	// final element so the backing array does not keep the removed writer
	// reachable.
	last := len(bw.writers) - 1
	copy(bw.writers[index:], bw.writers[index+1:])
	bw.writers[last] = nil
	bw.writers = bw.writers[:last]

	delete(bw.indexMap, name)
	for key, val := range bw.indexMap {
		if val > index {
			bw.indexMap[key] = val - 1
		}
	}
	bw.checkAsyncThreshold()
	return nil
}
// RemoveAll removes all writers referenced by (bw *AsyncMultiWriter) and
// re-evaluates the write strategy.
func (bw *AsyncMultiWriter) RemoveAll() {
	bw.mu.Lock()
	defer bw.mu.Unlock()
	// Drop the slice instead of reslicing to [:0]: reslicing keeps the old
	// backing array, and with it every removed writer, reachable.
	bw.writers = nil
	bw.indexMap = make(map[string]int)
	bw.checkAsyncThreshold()
}
// Write implements io.Writer by sending p to every registered writer using
// the currently selected strategy (writeSeq or writeAsync).
func (bw *AsyncMultiWriter) Write(p []byte) (int, error) {
	// writeFn is reassigned under bw.mu by checkAsyncThreshold whenever the
	// writer set or threshold changes, so it must be read under the lock too.
	// The lock is released before calling, because the strategy locks bw.mu itself.
	bw.mu.Lock()
	write := bw.writeFn
	bw.mu.Unlock()
	return write(p)
}
// writeAsync writes p to every registered writer concurrently, one goroutine
// per writer, and waits for all of them to finish while holding bw.mu.
// Per-writer failures, including short writes, are logged and not returned;
// the call always reports len(p), nil.
func (bw *AsyncMultiWriter) writeAsync(p []byte) (int, error) {
	bw.mu.Lock()
	defer bw.mu.Unlock()
	bw.wg.Add(len(bw.writers))
	for i, w := range bw.writers {
		go func(index int, w io.Writer) {
			defer bw.wg.Done()
			n, err := w.Write(p)
			if err == nil && n < len(p) {
				// io.Writer requires a non-nil error on short writes; surface
				// misbehaving writers the same way io.MultiWriter does.
				err = io.ErrShortWrite
			}
			if err != nil {
				log.Logger.WithField("category", "Named MultiWriter").Errorf("Error writing to writer with index %d: %v", index, err)
			}
		}(i, w)
	}
	bw.wg.Wait()
	return len(p), nil
}
// writeSeq writes p to each registered writer in registration order while
// holding bw.mu. It stops at the first writer that fails and returns that
// writer's byte count and error. A short write with a nil error is reported
// as io.ErrShortWrite, matching io.MultiWriter.
func (bw *AsyncMultiWriter) writeSeq(p []byte) (int, error) {
	bw.mu.Lock()
	defer bw.mu.Unlock()
	for _, w := range bw.writers {
		n, err := w.Write(p)
		if err != nil {
			return n, err
		}
		if n != len(p) {
			return n, io.ErrShortWrite
		}
	}
	return len(p), nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment