Skip to content

Instantly share code, notes, and snippets.

@ajalab
Created September 10, 2018 13:33
Show Gist options
  • Save ajalab/bba351bd0fdc47053f2248ebc850f920 to your computer and use it in GitHub Desktop.
Save ajalab/bba351bd0fdc47053f2248ebc850f920 to your computer and use it in GitHub Desktop.
stub io
package stubio
import (
	"errors"
	"io"
	"io/ioutil"
	"os"
)
// StubIO temporarily replaces os.Stdin, os.Stdout and os.Stderr with pipes so
// that code talking to the process's standard streams can be exercised in tests.
//
// Between Stub and Unstub, reads from os.Stdin return the contents of In (or
// EOF immediately when In is nil). Everything written to os.Stdout and
// os.Stderr is copied into Out and Err respectively, or discarded when the
// corresponding writer is nil.
//
// os.Stdin, os.Stdout and os.Stderr are package-level variables, so a StubIO
// must not be used concurrently with other code that reads or replaces them.
type StubIO struct {
	In  io.Reader
	Out io.Writer
	Err io.Writer

	// Original streams, restored by Unstub.
	stdin  *os.File
	stdout *os.File
	stderr *os.File

	// Pipe ends installed as os.Stdin, os.Stdout and os.Stderr.
	ir *os.File
	ow *os.File
	ew *os.File

	// Read ends of the output pipes; drained by background goroutines and
	// closed by Unstub once draining has finished.
	or *os.File
	er *os.File

	// done is closed by Unstub before it closes ir. The stdin feeder uses it
	// to recognize that a write failure was caused by Unstub, not a real error.
	done chan struct{}

	// errChan receives exactly one result from each copy goroutine started by
	// Stub. It is nil whenever the streams are not stubbed.
	errChan chan error
}

// Stub installs the pipes as the process's standard streams. It then starts
// copying In into the new os.Stdin, and the new os.Stdout/os.Stderr into
// Out/Err.
//
// It returns an error if the receiver is nil, if this StubIO is already
// stubbed, or if a pipe cannot be created. On error the process's standard
// streams are left untouched.
func (stub *StubIO) Stub() error {
	if stub == nil {
		return errors.New("receiver is nil")
	}
	if stub.errChan != nil {
		// Stubbing again would overwrite the saved original streams.
		return errors.New("already stubbed")
	}

	ir, iw, err := os.Pipe()
	if err != nil {
		return err
	}
	or, ow, err := os.Pipe()
	if err != nil {
		ir.Close()
		iw.Close()
		return err
	}
	er, ew, err := os.Pipe()
	if err != nil {
		ir.Close()
		iw.Close()
		or.Close()
		ow.Close()
		return err
	}

	stub.stdin, stub.stdout, stub.stderr = os.Stdin, os.Stdout, os.Stderr
	stub.ir, stub.ow, stub.ew = ir, ow, ew
	stub.or, stub.er = or, er
	stub.done = make(chan struct{})
	// One slot per copy goroutine, so none of them ever blocks on reporting.
	stub.errChan = make(chan error, 3)

	os.Stdin, os.Stdout, os.Stderr = ir, ow, ew

	// The three copies must run concurrently. Doing them one after another
	// deadlocks as soon as one pipe's buffer fills while its reader is still
	// busy with another stream.
	go feedStdin(iw, stub.In, stub.done, stub.errChan)
	go drainPipe(stub.Out, or, stub.errChan)
	go drainPipe(stub.Err, er, stub.errChan)
	return nil
}

// feedStdin copies in (when non-nil) into iw, the write end of the stdin pipe.
// It then closes iw so readers of the stubbed os.Stdin see EOF, and sends the
// outcome on errs. An error observed after done is closed is dropped: Unstub
// closed the read end and the input simply went unread.
func feedStdin(iw *os.File, in io.Reader, done <-chan struct{}, errs chan<- error) {
	var err error
	if in != nil {
		_, err = io.Copy(iw, in)
	}
	if cerr := iw.Close(); err == nil {
		err = cerr
	}
	if err != nil {
		select {
		case <-done:
			err = nil
		default:
		}
	}
	errs <- err
}

// drainPipe copies r into w until r reports EOF, discarding the data when w is
// nil, and sends the copy error on errs. If w fails, the rest of r is still
// discarded so that writers on the other end of the pipe never block.
func drainPipe(w io.Writer, r io.Reader, errs chan<- error) {
	if w == nil {
		w = ioutil.Discard
	}
	_, err := io.Copy(w, r)
	if err != nil {
		_, _ = io.Copy(ioutil.Discard, r)
	}
	errs <- err
}

// Unstub restores the original os.Stdin, os.Stdout and os.Stderr, waits until
// all captured output has been copied into Out and Err, and releases the pipes.
// Call it only after the code under test has stopped using the standard streams.
//
// It returns the first error reported by any copy. It also returns an error if
// the receiver is nil, or if the streams are not currently stubbed (Stub never
// called, or Unstub already ran). After a successful Unstub the StubIO may be
// stubbed again.
func (stub *StubIO) Unstub() error {
	if stub == nil {
		return errors.New("receiver is nil")
	}
	if stub.errChan == nil {
		return errors.New("not stubbed")
	}

	os.Stdin, os.Stdout, os.Stderr = stub.stdin, stub.stdout, stub.stderr

	// Closing the write ends lets the output drains reach EOF. Closing the
	// stdin read end unblocks a feeder stuck on input nobody read; done is
	// closed first so the feeder treats the resulting write error as expected.
	close(stub.done)
	stub.ow.Close()
	stub.ew.Close()
	stub.ir.Close()

	// Collect one result from each copy goroutine started by Stub.
	var first error
	for i := 0; i < cap(stub.errChan); i++ {
		if err := <-stub.errChan; err != nil && first == nil {
			first = err
		}
	}

	stub.or.Close()
	stub.er.Close()

	// Forget the pipes and saved streams so a second Unstub is rejected and
	// a later Stub starts from a clean state.
	*stub = StubIO{In: stub.In, Out: stub.Out, Err: stub.Err}
	return first
}
package stubio
import (
"bytes"
"fmt"
"io/ioutil"
"os"
"testing"
)
// TestStubIO checks that StubIO feeds In to os.Stdin and captures what is
// written to os.Stdout and os.Stderr into Out and Err.
func TestStubIO(t *testing.T) {
	stdin := bytes.NewBufferString("hoge")
	stdout := new(bytes.Buffer)
	stderr := new(bytes.Buffer)
	stub := &StubIO{
		In:  stdin,
		Out: stdout,
		Err: stderr,
	}
	if err := stub.Stub(); err != nil {
		// Continuing would read the real os.Stdin and call Unstub on a
		// StubIO that was never stubbed, so stop here.
		t.Fatal("stub.Stub: ", err)
	}
	input, readErr := ioutil.ReadAll(os.Stdin)
	fmt.Print("fuga")
	fmt.Fprint(os.Stderr, "piyo")
	// Unstub must run before any checks: it restores the real streams and
	// waits for the captured output to reach stdout/stderr.
	if err := stub.Unstub(); err != nil {
		t.Error("stub.Unstub: ", err)
	}
	if readErr != nil {
		t.Error("reading stubbed stdin: ", readErr)
	}
	if s := string(input); s != "hoge" {
		t.Errorf("failed to capture stdin: actual %s", s)
	}
	if s := stdout.String(); s != "fuga" {
		t.Errorf("failed to capture stdout: actual %s", s)
	}
	if s := stderr.String(); s != "piyo" {
		t.Errorf("failed to capture stderr: actual %s", s)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment