Use a Context in mockNetConn too.

This commit is contained in:
Alex Bramley 2021-03-26 12:36:19 +00:00
parent 27cc39787d
commit 1a10eba91a
1 changed files with 15 additions and 19 deletions

View File

@ -1,6 +1,7 @@
package client
import (
"context"
"io"
"net"
"os"
@ -14,7 +15,7 @@ type mockNetConn struct {
In, Out chan string
in, out chan []byte
die chan struct{}
die context.CancelFunc
closed bool
rt, wt time.Time
@ -22,7 +23,8 @@ type mockNetConn struct {
func MockNetConn(t *testing.T) *mockNetConn {
// Our mock connection is a testing object
m := &mockNetConn{T: t, die: make(chan struct{})}
ctx, cancel := context.WithCancel(context.Background())
m := &mockNetConn{T: t, die: cancel}
// buffer input
m.In = make(chan string, 20)
@ -30,7 +32,7 @@ func MockNetConn(t *testing.T) *mockNetConn {
go func() {
for {
select {
case <-m.die:
case <-ctx.Done():
return
case s := <-m.In:
m.in <- []byte(s)
@ -44,7 +46,7 @@ func MockNetConn(t *testing.T) *mockNetConn {
go func() {
for {
select {
case <-m.die:
case <-ctx.Done():
return
case b := <-m.out:
m.Out <- string(b)
@ -89,15 +91,12 @@ func (m *mockNetConn) Read(b []byte) (int, error) {
if m.Closed() {
return 0, os.ErrInvalid
}
l := 0
select {
case s := <-m.in:
l = len(s)
copy(b, s)
case <-m.die:
return 0, io.EOF
s, ok := <-m.in
copy(b, s)
if !ok {
return len(s), io.EOF
}
return l, nil
return len(s), nil
}
func (m *mockNetConn) Write(s []byte) (int, error) {
@ -114,19 +113,16 @@ func (m *mockNetConn) Close() error {
if m.Closed() {
return os.ErrInvalid
}
m.closed = true
// Shut down *ALL* the goroutines!
// This will trigger an EOF event in Read() too
close(m.die)
m.die()
close(m.in)
return nil
}
func (m *mockNetConn) Closed() bool {
select {
case <-m.die:
return true
default:
return false
}
return m.closed
}
func (m *mockNetConn) LocalAddr() net.Addr {