diff --git a/client/mocknetconn_test.go b/client/mocknetconn_test.go index e736c88..fc80de4 100644 --- a/client/mocknetconn_test.go +++ b/client/mocknetconn_test.go @@ -1,6 +1,7 @@ package client import ( + "context" "io" "net" "os" @@ -14,7 +15,7 @@ type mockNetConn struct { In, Out chan string in, out chan []byte - die chan struct{} + die context.CancelFunc closed bool rt, wt time.Time @@ -22,7 +23,8 @@ type mockNetConn struct { func MockNetConn(t *testing.T) *mockNetConn { // Our mock connection is a testing object - m := &mockNetConn{T: t, die: make(chan struct{})} + ctx, cancel := context.WithCancel(context.Background()) + m := &mockNetConn{T: t, die: cancel} // buffer input m.In = make(chan string, 20) @@ -30,7 +32,7 @@ func MockNetConn(t *testing.T) *mockNetConn { go func() { for { select { - case <-m.die: + case <-ctx.Done(): return case s := <-m.In: m.in <- []byte(s) @@ -44,7 +46,7 @@ func MockNetConn(t *testing.T) *mockNetConn { go func() { for { select { - case <-m.die: + case <-ctx.Done(): return case b := <-m.out: m.Out <- string(b) @@ -89,15 +91,12 @@ func (m *mockNetConn) Read(b []byte) (int, error) { if m.Closed() { return 0, os.ErrInvalid } - l := 0 - select { - case s := <-m.in: - l = len(s) - copy(b, s) - case <-m.die: - return 0, io.EOF + s, ok := <-m.in + copy(b, s) + if !ok { + return len(s), io.EOF } - return l, nil + return len(s), nil } func (m *mockNetConn) Write(s []byte) (int, error) { @@ -114,19 +113,16 @@ func (m *mockNetConn) Close() error { if m.Closed() { return os.ErrInvalid } + m.closed = true // Shut down *ALL* the goroutines! // This will trigger an EOF event in Read() too - close(m.die) + m.die() + close(m.in) return nil } func (m *mockNetConn) Closed() bool { - select { - case <-m.die: - return true - default: - return false - } + return m.closed } func (m *mockNetConn) LocalAddr() net.Addr {