Skip to content

Instantly share code, notes, and snippets.

@zombiezen
Last active November 30, 2020 06:01
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zombiezen/85095566a5c025e35817a0b399fec235 to your computer and use it in GitHub Desktop.
Save zombiezen/85095566a5c025e35817a0b399fec235 to your computer and use it in GitHub Desktop.
Reading while respecting Context.Done
// ctxReader adds timeouts and cancellation to a reader.
//
// Read consults ctx on every call; assign a new Context to ctx to change the
// cancellation signal and deadline used by subsequent reads.
//
// NOTE(review): the internal state below is accessed without any locking, so
// a ctxReader is presumably meant to be used by one goroutine at a time.
type ctxReader struct {
// r is the underlying reader that data is pulled from.
r io.Reader
ctx context.Context // set to change Context
// internal state
// result is non-nil while a read started by leakyRead has not yet had its
// outcome collected; the outcome arrives on this channel.
result chan readResult
// buf[pos:n] holds bytes that were read but not yet returned by Read.
pos, n int
// err is the error that accompanied the buffered bytes; Read reports it
// once the buffer has been drained.
err error
// buf receives the data of reads performed by leakyRead.
buf [1024]byte
}
// readResult carries the outcome of one Read call on the underlying reader
// from the goroutine started by leakyRead back to the ctxReader.
//
// leakyRead builds it with a positional literal (readResult{n, err}), so the
// field order must not change.
type readResult struct {
// n is the number of bytes the read placed in ctxReader.buf.
n int
// err is the error returned by the read, if any.
err error
}
// Read reads into p. It makes a best effort to respect the Done signal
// in cr.ctx.
//
// Data or an error left over from an earlier call is returned first. If a
// read started by leakyRead is still in flight, Read waits for either its
// result or cr.ctx to be done. Otherwise a new read is started: when the
// underlying reader supports SetReadDeadline the read runs on the calling
// goroutine and is interrupted through the deadline; when it does not, Read
// falls back to leakyRead.
//
// Read returns the number of bytes copied into p and any error encountered.
// When the read is given up because cr.ctx is done, the error is
// cr.ctx.Err().
func (cr *ctxReader) Read(p []byte) (int, error) {
	if cr.pos < cr.n {
		// Buffered from previous read.
		n := copy(p, cr.buf[cr.pos:cr.n])
		cr.pos += n
		if cr.pos == cr.n && cr.err != nil {
			// Buffer drained: report the error that came with the data.
			err := cr.err
			cr.err = nil
			return n, err
		}
		return n, nil
	}
	if cr.err != nil {
		// The buffer is empty but an error is still pending, e.g. wait
		// collected a read that returned no bytes and an error. Report
		// it instead of silently starting a new read.
		err := cr.err
		cr.err = nil
		return 0, err
	}
	if cr.result != nil {
		// Read in progress.
		select {
		case r := <-cr.result:
			cr.result = nil
			cr.n = r.n
			cr.pos = copy(p, cr.buf[:cr.n])
			if cr.pos < cr.n {
				// p was too small for the whole result: hold the
				// error back until the rest has been returned.
				cr.err = r.err
				return cr.pos, nil
			}
			return cr.pos, r.err
		case <-cr.ctx.Done():
			return 0, cr.ctx.Err()
		}
	}
	// Check for early cancel.
	select {
	case <-cr.ctx.Done():
		return 0, cr.ctx.Err()
	default:
	}
	// Check for timeout support.
	rd, ok := cr.r.(interface {
		SetReadDeadline(time.Time) error
	})
	if !ok {
		return cr.leakyRead(p)
	}
	// Probe with a real deadline: a reader can have the method and still
	// reject deadlines at run time (e.g. an *os.File that does not support
	// them), in which case the goroutine-based fallback is used.
	if err := rd.SetReadDeadline(time.Now()); err != nil {
		return cr.leakyRead(p)
	}
	// Replace the probe deadline with the Context's deadline; the zero
	// time.Time clears the deadline when the Context has none.
	var deadline time.Time
	if d, ok := cr.ctx.Deadline(); ok {
		deadline = d
	}
	if err := rd.SetReadDeadline(deadline); err != nil {
		// The already-expired probe deadline is still installed, so a
		// read could only fail with a misleading timeout.
		return 0, err
	}
	// Start separate goroutine to wait on Context.Done.
	readDone := make(chan struct{})
	listenDone := make(chan struct{})
	go func() {
		defer close(listenDone)
		select {
		case <-cr.ctx.Done():
			// Interrupt the blocked read. This is best effort: there
			// is nothing useful to do if it fails.
			_ = rd.SetReadDeadline(time.Now())
		case <-readDone:
		}
	}()
	// Read from reader.
	n, err := cr.r.Read(p)
	close(readDone)
	// Wait for the listener so it cannot touch the deadline after Read
	// has returned.
	<-listenDone
	if err != nil {
		if ctxErr := cr.ctx.Err(); ctxErr != nil {
			// The only deadline in effect is the one derived from
			// the Context, so a timeout while the Context is done
			// is reported as the Context's error, like every other
			// cancellation path in this type.
			if te, ok := err.(interface{ Timeout() bool }); ok && te.Timeout() {
				return n, ctxErr
			}
		}
	}
	return n, err
}
// leakyRead reads from the underlying reader in a separate goroutine.
// If the Context is Done before the read completes, leakyRead returns
// cr.ctx.Err() and leaves cr.result set; the goroutine stays alive until the
// underlying Read returns, and its result is kept on cr.result until a later
// Read or cr.wait() collects it. The result is written to cr.buf to avoid
// retaining p past the end of leakyRead.
//
// At most len(cr.buf) bytes are read per call. leakyRead returns the number
// of bytes copied into p and the error from the underlying Read, or
// 0 and cr.ctx.Err() when the Context is done first.
func (cr *ctxReader) leakyRead(p []byte) (int, error) {
	// A capacity of one lets the goroutine hand off its result and exit
	// even if nobody ever receives it (e.g. the reader is abandoned after
	// cancellation). The channel is captured in a local so the goroutine
	// does not depend on the cr.result field.
	result := make(chan readResult, 1)
	cr.result = result
	limit := len(p)
	if limit > len(cr.buf) {
		limit = len(cr.buf)
	}
	go func() {
		n, err := cr.r.Read(cr.buf[:limit])
		result <- readResult{n, err}
	}()
	select {
	case r := <-result:
		cr.result = nil
		// limit <= len(p), so everything that was read fits in p.
		copy(p, cr.buf[:r.n])
		return r.n, r.err
	case <-cr.ctx.Done():
		return 0, cr.ctx.Err()
	}
}
// wait blocks until the read started by leakyRead, if one is still
// outstanding, has delivered its result. The bytes that read placed in
// cr.buf and its error are recorded as buffered state. wait returns
// immediately when no read is outstanding.
func (cr *ctxReader) wait() {
	pending := cr.result
	if pending == nil {
		// No read in flight.
		return
	}
	res := <-pending
	cr.result = nil
	// Expose the collected bytes as cr.buf[0:res.n], together with the
	// error that came with them.
	cr.err = res.err
	cr.pos = 0
	cr.n = res.n
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment