Use a Context to kill internal goroutines.

This commit is contained in:
Alex Bramley 2021-03-26 12:02:36 +00:00
parent 1bb2dff298
commit 27cc39787d
2 changed files with 41 additions and 43 deletions

View File

@ -45,8 +45,8 @@ type Conn struct {
out chan string
connected bool
// Control channel and WaitGroup for goroutines
die chan struct{}
// CancelFunc and WaitGroup for goroutines
die context.CancelFunc
wg sync.WaitGroup
// Internal counters for flood protection
@ -295,7 +295,7 @@ func (conn *Conn) initialise() {
conn.sock = nil
conn.in = make(chan *Line, 32)
conn.out = make(chan string, 32)
conn.die = make(chan struct{})
conn.die = nil
if conn.st != nil {
conn.st.Wipe()
}
@ -404,24 +404,25 @@ func (conn *Conn) internalConnect(ctx context.Context) error {
conn.sock = s
}
conn.postConnect(true)
conn.postConnect(ctx, true)
conn.connected = true
return nil
}
// postConnect performs post-connection setup, for ease of testing.
func (conn *Conn) postConnect(start bool) {
func (conn *Conn) postConnect(ctx context.Context, start bool) {
conn.io = bufio.NewReadWriter(
bufio.NewReader(conn.sock),
bufio.NewWriter(conn.sock))
if start {
ctx, conn.die = context.WithCancel(ctx)
conn.wg.Add(3)
go conn.send()
go conn.send(ctx)
go conn.recv()
go conn.runLoop()
go conn.runLoop(ctx)
if conn.cfg.PingFreq > 0 {
conn.wg.Add(1)
go conn.ping()
go conn.ping(ctx)
}
}
}
@ -434,8 +435,8 @@ func hasPort(s string) bool {
// send is started as a goroutine after a connection is established.
// It shuttles data from the output channel to write(), and is killed
// when Conn.die is closed.
func (conn *Conn) send() {
// when the context is cancelled.
func (conn *Conn) send(ctx context.Context) {
for {
select {
case line := <-conn.out:
@ -446,7 +447,7 @@ func (conn *Conn) send() {
conn.Close()
return
}
case <-conn.die:
case <-ctx.Done():
// control channel closed, bail out
conn.wg.Done()
return
@ -483,14 +484,14 @@ func (conn *Conn) recv() {
// ping is started as a goroutine after a connection is established, as
// long as Config.PingFreq >0. It pings the server every PingFreq seconds.
func (conn *Conn) ping() {
func (conn *Conn) ping(ctx context.Context) {
defer conn.wg.Done()
tick := time.NewTicker(conn.cfg.PingFreq)
for {
select {
case <-tick.C:
conn.Ping(fmt.Sprintf("%d", time.Now().UnixNano()))
case <-conn.die:
case <-ctx.Done():
// control channel closed, bail out
tick.Stop()
return
@ -501,13 +502,13 @@ func (conn *Conn) ping() {
// runLoop is started as a goroutine after a connection is established.
// It pulls Lines from the input channel and dispatches them to any
// handlers that have been registered for that IRC verb.
func (conn *Conn) runLoop() {
func (conn *Conn) runLoop(ctx context.Context) {
defer conn.wg.Done()
for {
select {
case line := <-conn.in:
conn.dispatch(line)
case <-conn.die:
case <-ctx.Done():
// control channel closed, bail out
return
}
@ -572,7 +573,9 @@ func (conn *Conn) Close() error {
logging.Info("irc.Close(): Disconnected from server.")
conn.connected = false
err := conn.sock.Close()
close(conn.die)
if conn.die != nil {
conn.die()
}
// Drain both in and out channels to avoid a deadlock if the buffers
// have filled. See TestSendDeadlockOnFullBuffer in connection_test.go.
conn.drainIn()

View File

@ -1,6 +1,7 @@
package client
import (
"context"
"runtime"
"strings"
"testing"
@ -54,6 +55,7 @@ func setUp(t *testing.T, start ...bool) (*Conn, *testState) {
nc := MockNetConn(t)
c := SimpleClient("test", "test", "Testing IRC")
c.initialise()
ctx := context.Background()
c.st = st
c.sock = nc
@ -61,7 +63,7 @@ func setUp(t *testing.T, start ...bool) (*Conn, *testState) {
c.connected = true
// If a second argument is passed to setUp, we tell postConnect not to
// start the various goroutines that shuttle data around.
c.postConnect(len(start) == 0)
c.postConnect(ctx, len(start) == 0)
// Sleep 1ms to allow background routines to start.
<-time.After(time.Millisecond)
@ -166,7 +168,7 @@ func TestClientAndStateTracking(t *testing.T) {
ctrl.Finish()
}
func TestSendExitsOnDie(t *testing.T) {
func TestSendExitsOnCancel(t *testing.T) {
// Passing a second value to setUp stops goroutines from starting
c, s := setUp(t, false)
defer s.tearDown()
@ -178,10 +180,11 @@ func TestSendExitsOnDie(t *testing.T) {
// We want to test that the a goroutine calling send will exit correctly.
exited := callCheck(t)
ctx, cancel := context.WithCancel(context.Background())
// send() will decrement the WaitGroup, so we must increment it.
c.wg.Add(1)
go func() {
c.send()
c.send(ctx)
exited.call()
}()
@ -194,13 +197,9 @@ func TestSendExitsOnDie(t *testing.T) {
c.out <- "SENT AFTER START"
s.nc.Expect("SENT AFTER START")
// Now, use the control channel to exit send and kill the goroutine.
// This sneakily uses the fact that the other two goroutines that would
// normally be waiting for die to close are not running, so we only send
// to the goroutine started above. Normally Close() closes c.die and
// signals to all three goroutines (send, ping, runLoop) to exit.
// Now, cancel the context to exit send and kill the goroutine.
exited.assertNotCalled("Exited before signal sent.")
c.die <- struct{}{}
cancel()
exited.assertWasCalled("Didn't exit after signal.")
s.nc.ExpectNothing()
@ -221,7 +220,7 @@ func TestSendExitsOnWriteError(t *testing.T) {
// send() will decrement the WaitGroup, so we must increment it.
c.wg.Add(1)
go func() {
c.send()
c.send(context.Background())
exited.call()
}()
@ -251,6 +250,7 @@ func TestSendDeadlockOnFullBuffer(t *testing.T) {
// We want to test that the a goroutine calling send will exit correctly.
loopExit := callCheck(t)
sendExit := callCheck(t)
ctx, cancel := context.WithCancel(context.Background())
// send() and runLoop() will decrement the WaitGroup, so we must increment it.
c.wg.Add(2)
@ -274,13 +274,13 @@ func TestSendDeadlockOnFullBuffer(t *testing.T) {
})
// And trigger it by starting runLoop and inserting a line into conn.in:
go func() {
c.runLoop()
c.runLoop(ctx)
loopExit.call()
}()
c.in <- &Line{Cmd: PRIVMSG, Raw: "WRITE THAT CAUSES DEADLOCK"}
// At this point the handler should be blocked on a write to conn.out,
// preventing runLoop from looping and thus noticing conn.die is closed.
// preventing runLoop from looping and thus noticng the cancelled context.
//
// The next part is to force send() to call conn.Close(), which can
// be done by closing the fake net.Conn so that it returns an error on
@ -292,8 +292,9 @@ func TestSendDeadlockOnFullBuffer(t *testing.T) {
// to write it to the socket. It should immediately receive an error and
// call conn.Close(), triggering the deadlock as it waits forever for
// runLoop to call conn.wg.Done.
c.die = cancel // Close needs to cancel the context for us.
go func() {
c.send()
c.send(ctx)
sendExit.call()
}()
@ -407,10 +408,11 @@ func TestPing(t *testing.T) {
// Start ping loop.
exited := callCheck(t)
ctx, cancel := context.WithCancel(context.Background())
// ping() will decrement the WaitGroup, so we must increment it.
c.wg.Add(1)
go func() {
c.ping()
c.ping(ctx)
exited.call()
}()
@ -438,13 +440,9 @@ func TestPing(t *testing.T) {
t.Errorf("Line not output after another %s.", 2*res)
}
// Now kill the ping loop.
// This sneakily uses the fact that the other two goroutines that would
// normally be waiting for die to close are not running, so we only send
// to the goroutine started above. Normally Close() closes c.die and
// signals to all three goroutines (send, ping, runLoop) to exit.
// Now kill the ping loop by cancelling the context.
exited.assertNotCalled("Exited before signal sent.")
c.die <- struct{}{}
cancel()
exited.assertWasCalled("Didn't exit after signal.")
// Make sure we're no longer pinging by waiting >2x PingFreq
<-time.After(2*c.cfg.PingFreq + res)
@ -478,10 +476,11 @@ func TestRunLoop(t *testing.T) {
// We want to test that the a goroutine calling runLoop will exit correctly.
// Now, we can expect the call to Dispatch to take place as runLoop starts.
exited := callCheck(t)
ctx, cancel := context.WithCancel(context.Background())
// runLoop() will decrement the WaitGroup, so we must increment it.
c.wg.Add(1)
go func() {
c.runLoop()
c.runLoop(ctx)
exited.call()
}()
h002.assertWasCalled("002 handler not called after runLoop started.")
@ -492,13 +491,9 @@ func TestRunLoop(t *testing.T) {
c.in <- l3
h003.assertWasCalled("003 handler not called while runLoop started.")
// Now, use the control channel to exit send and kill the goroutine.
// This sneakily uses the fact that the other two goroutines that would
// normally be waiting for die to close are not running, so we only send
// to the goroutine started above. Normally Close() closes c.die and
// signals to all three goroutines (send, ping, runLoop) to exit.
// Now, cancel the context to exit runLoop and kill the goroutine.
exited.assertNotCalled("Exited before signal sent.")
c.die <- struct{}{}
cancel()
exited.assertWasCalled("Didn't exit after signal.")
// Sending more on c.in shouldn't dispatch any further events