diff --git a/client/mocknetconn_test.go b/client/mocknetconn_test.go index 06f1393..b438a6c 100644 --- a/client/mocknetconn_test.go +++ b/client/mocknetconn_test.go @@ -9,19 +9,12 @@ import ( "time" ) -const ( - mockReadCloser = iota - mockInCloser - mockOutCloser -) - type mockNetConn struct { *testing.T In, Out chan string in, out chan []byte - closers [3]chan bool - rc chan bool + die chan struct{} closed bool rt, wt time.Time @@ -29,16 +22,15 @@ type mockNetConn struct { func MockNetConn(t *testing.T) *mockNetConn { // Our mock connection is a testing object - m := &mockNetConn{T: t} + m := &mockNetConn{T: t, die: make(chan struct{})} // buffer input m.In = make(chan string, 20) m.in = make(chan []byte) - m.closers[mockInCloser] = make(chan bool, 1) go func() { for { select { - case <-m.closers[mockInCloser]: + case <-m.die: return case s := <-m.In: m.in <- []byte(s) @@ -49,11 +41,10 @@ func MockNetConn(t *testing.T) *mockNetConn { // buffer output m.Out = make(chan string) m.out = make(chan []byte, 20) - m.closers[mockOutCloser] = make(chan bool, 1) go func() { for { select { - case <-m.closers[mockOutCloser]: + case <-m.die: return case b := <-m.out: m.Out <- string(b) @@ -61,9 +52,6 @@ func MockNetConn(t *testing.T) *mockNetConn { } }() - // Set up channel to force EOF to Read() on close. - m.closers[mockReadCloser] = make(chan bool, 1) - return m } @@ -98,7 +86,7 @@ func (m *mockNetConn) ExpectNothing() { // Implement net.Conn interface func (m *mockNetConn) Read(b []byte) (int, error) { - if m.closed { + if m.Closed() { return 0, os.ErrInvalid } l := 0 @@ -106,14 +94,14 @@ func (m *mockNetConn) Read(b []byte) (int, error) { case s := <-m.in: l = len(s) copy(b, s) - case <-m.closers[mockReadCloser]: + case <-m.die: return 0, io.EOF } return l, nil } func (m *mockNetConn) Write(s []byte) (int, error) { - if m.closed { + if m.Closed() { return 0, os.ErrInvalid } b := make([]byte, len(s)) @@ -123,18 +111,24 @@ func (m *mockNetConn) Write(s []byte) (int, error) { } func (m *mockNetConn) Close() error { - if m.closed { + if m.Closed() { return os.ErrInvalid } // Shut down *ALL* the goroutines! // This will trigger an EOF event in Read() too - for _, c := range m.closers { - c <- true - } - m.closed = true + close(m.die) return nil } +func (m *mockNetConn) Closed() bool { + select { + case <-m.die: + return true + default: + return false + } +} + func (m *mockNetConn) LocalAddr() net.Addr { return &net.IPAddr{net.IPv4(127, 0, 0, 1), ""} }