Use a Context in mockNetConn too.

This commit is contained in:
Alex Bramley 2021-03-26 12:36:19 +00:00
parent 27cc39787d
commit 1a10eba91a
1 changed files with 15 additions and 19 deletions

View File

@ -1,6 +1,7 @@
package client package client
import ( import (
"context"
"io" "io"
"net" "net"
"os" "os"
@ -14,7 +15,7 @@ type mockNetConn struct {
In, Out chan string In, Out chan string
in, out chan []byte in, out chan []byte
die chan struct{} die context.CancelFunc
closed bool closed bool
rt, wt time.Time rt, wt time.Time
@ -22,7 +23,8 @@ type mockNetConn struct {
func MockNetConn(t *testing.T) *mockNetConn { func MockNetConn(t *testing.T) *mockNetConn {
// Our mock connection is a testing object // Our mock connection is a testing object
m := &mockNetConn{T: t, die: make(chan struct{})} ctx, cancel := context.WithCancel(context.Background())
m := &mockNetConn{T: t, die: cancel}
// buffer input // buffer input
m.In = make(chan string, 20) m.In = make(chan string, 20)
@ -30,7 +32,7 @@ func MockNetConn(t *testing.T) *mockNetConn {
go func() { go func() {
for { for {
select { select {
case <-m.die: case <-ctx.Done():
return return
case s := <-m.In: case s := <-m.In:
m.in <- []byte(s) m.in <- []byte(s)
@ -44,7 +46,7 @@ func MockNetConn(t *testing.T) *mockNetConn {
go func() { go func() {
for { for {
select { select {
case <-m.die: case <-ctx.Done():
return return
case b := <-m.out: case b := <-m.out:
m.Out <- string(b) m.Out <- string(b)
@ -89,15 +91,12 @@ func (m *mockNetConn) Read(b []byte) (int, error) {
if m.Closed() { if m.Closed() {
return 0, os.ErrInvalid return 0, os.ErrInvalid
} }
l := 0 s, ok := <-m.in
select {
case s := <-m.in:
l = len(s)
copy(b, s) copy(b, s)
case <-m.die: if !ok {
return 0, io.EOF return len(s), io.EOF
} }
return l, nil return len(s), nil
} }
func (m *mockNetConn) Write(s []byte) (int, error) { func (m *mockNetConn) Write(s []byte) (int, error) {
@ -114,19 +113,16 @@ func (m *mockNetConn) Close() error {
if m.Closed() { if m.Closed() {
return os.ErrInvalid return os.ErrInvalid
} }
m.closed = true
// Shut down *ALL* the goroutines! // Shut down *ALL* the goroutines!
// This will trigger an EOF event in Read() too // This will trigger an EOF event in Read() too
close(m.die) m.die()
close(m.in)
return nil return nil
} }
func (m *mockNetConn) Closed() bool { func (m *mockNetConn) Closed() bool {
select { return m.closed
case <-m.die:
return true
default:
return false
}
} }
func (m *mockNetConn) LocalAddr() net.Addr { func (m *mockNetConn) LocalAddr() net.Addr {