Use a single control chan to kill mockNetConn goroutines (77->54).

This commit is contained in:
Alex Bramley 2013-09-30 13:06:52 +01:00
parent 144001d109
commit adc6c2917d
1 changed files with 18 additions and 24 deletions

View File

@ -9,19 +9,12 @@ import (
"time" "time"
) )
const (
mockReadCloser = iota
mockInCloser
mockOutCloser
)
type mockNetConn struct { type mockNetConn struct {
*testing.T *testing.T
In, Out chan string In, Out chan string
in, out chan []byte in, out chan []byte
closers [3]chan bool die chan struct{}
rc chan bool
closed bool closed bool
rt, wt time.Time rt, wt time.Time
@ -29,16 +22,15 @@ type mockNetConn struct {
func MockNetConn(t *testing.T) *mockNetConn { func MockNetConn(t *testing.T) *mockNetConn {
// Our mock connection is a testing object // Our mock connection is a testing object
m := &mockNetConn{T: t} m := &mockNetConn{T: t, die: make(chan struct{})}
// buffer input // buffer input
m.In = make(chan string, 20) m.In = make(chan string, 20)
m.in = make(chan []byte) m.in = make(chan []byte)
m.closers[mockInCloser] = make(chan bool, 1)
go func() { go func() {
for { for {
select { select {
case <-m.closers[mockInCloser]: case <-m.die:
return return
case s := <-m.In: case s := <-m.In:
m.in <- []byte(s) m.in <- []byte(s)
@ -49,11 +41,10 @@ func MockNetConn(t *testing.T) *mockNetConn {
// buffer output // buffer output
m.Out = make(chan string) m.Out = make(chan string)
m.out = make(chan []byte, 20) m.out = make(chan []byte, 20)
m.closers[mockOutCloser] = make(chan bool, 1)
go func() { go func() {
for { for {
select { select {
case <-m.closers[mockOutCloser]: case <-m.die:
return return
case b := <-m.out: case b := <-m.out:
m.Out <- string(b) m.Out <- string(b)
@ -61,9 +52,6 @@ func MockNetConn(t *testing.T) *mockNetConn {
} }
}() }()
// Set up channel to force EOF to Read() on close.
m.closers[mockReadCloser] = make(chan bool, 1)
return m return m
} }
@ -98,7 +86,7 @@ func (m *mockNetConn) ExpectNothing() {
// Implement net.Conn interface // Implement net.Conn interface
func (m *mockNetConn) Read(b []byte) (int, error) { func (m *mockNetConn) Read(b []byte) (int, error) {
if m.closed { if m.Closed() {
return 0, os.ErrInvalid return 0, os.ErrInvalid
} }
l := 0 l := 0
@ -106,14 +94,14 @@ func (m *mockNetConn) Read(b []byte) (int, error) {
case s := <-m.in: case s := <-m.in:
l = len(s) l = len(s)
copy(b, s) copy(b, s)
case <-m.closers[mockReadCloser]: case <-m.die:
return 0, io.EOF return 0, io.EOF
} }
return l, nil return l, nil
} }
func (m *mockNetConn) Write(s []byte) (int, error) { func (m *mockNetConn) Write(s []byte) (int, error) {
if m.closed { if m.Closed() {
return 0, os.ErrInvalid return 0, os.ErrInvalid
} }
b := make([]byte, len(s)) b := make([]byte, len(s))
@ -123,18 +111,24 @@ func (m *mockNetConn) Write(s []byte) (int, error) {
} }
func (m *mockNetConn) Close() error { func (m *mockNetConn) Close() error {
if m.closed { if m.Closed() {
return os.ErrInvalid return os.ErrInvalid
} }
// Shut down *ALL* the goroutines! // Shut down *ALL* the goroutines!
// This will trigger an EOF event in Read() too // This will trigger an EOF event in Read() too
for _, c := range m.closers { close(m.die)
c <- true
}
m.closed = true
return nil return nil
} }
func (m *mockNetConn) Closed() bool {
select {
case <-m.die:
return true
default:
return false
}
}
func (m *mockNetConn) LocalAddr() net.Addr { func (m *mockNetConn) LocalAddr() net.Addr {
return &net.IPAddr{net.IPv4(127, 0, 0, 1), ""} return &net.IPAddr{net.IPv4(127, 0, 0, 1), ""}
} }