Use a Context to kill internal goroutines.

This commit is contained in:
Alex Bramley 2021-03-26 12:02:36 +00:00
parent 1bb2dff298
commit 27cc39787d
2 changed files with 41 additions and 43 deletions

View File

@ -45,8 +45,8 @@ type Conn struct {
out chan string out chan string
connected bool connected bool
// Control channel and WaitGroup for goroutines // CancelFunc and WaitGroup for goroutines
die chan struct{} die context.CancelFunc
wg sync.WaitGroup wg sync.WaitGroup
// Internal counters for flood protection // Internal counters for flood protection
@ -295,7 +295,7 @@ func (conn *Conn) initialise() {
conn.sock = nil conn.sock = nil
conn.in = make(chan *Line, 32) conn.in = make(chan *Line, 32)
conn.out = make(chan string, 32) conn.out = make(chan string, 32)
conn.die = make(chan struct{}) conn.die = nil
if conn.st != nil { if conn.st != nil {
conn.st.Wipe() conn.st.Wipe()
} }
@ -404,24 +404,25 @@ func (conn *Conn) internalConnect(ctx context.Context) error {
conn.sock = s conn.sock = s
} }
conn.postConnect(true) conn.postConnect(ctx, true)
conn.connected = true conn.connected = true
return nil return nil
} }
// postConnect performs post-connection setup, for ease of testing. // postConnect performs post-connection setup, for ease of testing.
func (conn *Conn) postConnect(start bool) { func (conn *Conn) postConnect(ctx context.Context, start bool) {
conn.io = bufio.NewReadWriter( conn.io = bufio.NewReadWriter(
bufio.NewReader(conn.sock), bufio.NewReader(conn.sock),
bufio.NewWriter(conn.sock)) bufio.NewWriter(conn.sock))
if start { if start {
ctx, conn.die = context.WithCancel(ctx)
conn.wg.Add(3) conn.wg.Add(3)
go conn.send() go conn.send(ctx)
go conn.recv() go conn.recv()
go conn.runLoop() go conn.runLoop(ctx)
if conn.cfg.PingFreq > 0 { if conn.cfg.PingFreq > 0 {
conn.wg.Add(1) conn.wg.Add(1)
go conn.ping() go conn.ping(ctx)
} }
} }
} }
@ -434,8 +435,8 @@ func hasPort(s string) bool {
// send is started as a goroutine after a connection is established. // send is started as a goroutine after a connection is established.
// It shuttles data from the output channel to write(), and is killed // It shuttles data from the output channel to write(), and is killed
// when Conn.die is closed. // when the context is cancelled.
func (conn *Conn) send() { func (conn *Conn) send(ctx context.Context) {
for { for {
select { select {
case line := <-conn.out: case line := <-conn.out:
@ -446,7 +447,7 @@ func (conn *Conn) send() {
conn.Close() conn.Close()
return return
} }
case <-conn.die: case <-ctx.Done():
// control channel closed, bail out // control channel closed, bail out
conn.wg.Done() conn.wg.Done()
return return
@ -483,14 +484,14 @@ func (conn *Conn) recv() {
// ping is started as a goroutine after a connection is established, as // ping is started as a goroutine after a connection is established, as
// long as Config.PingFreq >0. It pings the server every PingFreq seconds. // long as Config.PingFreq >0. It pings the server every PingFreq seconds.
func (conn *Conn) ping() { func (conn *Conn) ping(ctx context.Context) {
defer conn.wg.Done() defer conn.wg.Done()
tick := time.NewTicker(conn.cfg.PingFreq) tick := time.NewTicker(conn.cfg.PingFreq)
for { for {
select { select {
case <-tick.C: case <-tick.C:
conn.Ping(fmt.Sprintf("%d", time.Now().UnixNano())) conn.Ping(fmt.Sprintf("%d", time.Now().UnixNano()))
case <-conn.die: case <-ctx.Done():
// control channel closed, bail out // control channel closed, bail out
tick.Stop() tick.Stop()
return return
@ -501,13 +502,13 @@ func (conn *Conn) ping() {
// runLoop is started as a goroutine after a connection is established. // runLoop is started as a goroutine after a connection is established.
// It pulls Lines from the input channel and dispatches them to any // It pulls Lines from the input channel and dispatches them to any
// handlers that have been registered for that IRC verb. // handlers that have been registered for that IRC verb.
func (conn *Conn) runLoop() { func (conn *Conn) runLoop(ctx context.Context) {
defer conn.wg.Done() defer conn.wg.Done()
for { for {
select { select {
case line := <-conn.in: case line := <-conn.in:
conn.dispatch(line) conn.dispatch(line)
case <-conn.die: case <-ctx.Done():
// control channel closed, bail out // control channel closed, bail out
return return
} }
@ -572,7 +573,9 @@ func (conn *Conn) Close() error {
logging.Info("irc.Close(): Disconnected from server.") logging.Info("irc.Close(): Disconnected from server.")
conn.connected = false conn.connected = false
err := conn.sock.Close() err := conn.sock.Close()
close(conn.die) if conn.die != nil {
conn.die()
}
// Drain both in and out channels to avoid a deadlock if the buffers // Drain both in and out channels to avoid a deadlock if the buffers
// have filled. See TestSendDeadlockOnFullBuffer in connection_test.go. // have filled. See TestSendDeadlockOnFullBuffer in connection_test.go.
conn.drainIn() conn.drainIn()

View File

@ -1,6 +1,7 @@
package client package client
import ( import (
"context"
"runtime" "runtime"
"strings" "strings"
"testing" "testing"
@ -54,6 +55,7 @@ func setUp(t *testing.T, start ...bool) (*Conn, *testState) {
nc := MockNetConn(t) nc := MockNetConn(t)
c := SimpleClient("test", "test", "Testing IRC") c := SimpleClient("test", "test", "Testing IRC")
c.initialise() c.initialise()
ctx := context.Background()
c.st = st c.st = st
c.sock = nc c.sock = nc
@ -61,7 +63,7 @@ func setUp(t *testing.T, start ...bool) (*Conn, *testState) {
c.connected = true c.connected = true
// If a second argument is passed to setUp, we tell postConnect not to // If a second argument is passed to setUp, we tell postConnect not to
// start the various goroutines that shuttle data around. // start the various goroutines that shuttle data around.
c.postConnect(len(start) == 0) c.postConnect(ctx, len(start) == 0)
// Sleep 1ms to allow background routines to start. // Sleep 1ms to allow background routines to start.
<-time.After(time.Millisecond) <-time.After(time.Millisecond)
@ -166,7 +168,7 @@ func TestClientAndStateTracking(t *testing.T) {
ctrl.Finish() ctrl.Finish()
} }
func TestSendExitsOnDie(t *testing.T) { func TestSendExitsOnCancel(t *testing.T) {
// Passing a second value to setUp stops goroutines from starting // Passing a second value to setUp stops goroutines from starting
c, s := setUp(t, false) c, s := setUp(t, false)
defer s.tearDown() defer s.tearDown()
@ -178,10 +180,11 @@ func TestSendExitsOnDie(t *testing.T) {
// We want to test that the a goroutine calling send will exit correctly. // We want to test that the a goroutine calling send will exit correctly.
exited := callCheck(t) exited := callCheck(t)
ctx, cancel := context.WithCancel(context.Background())
// send() will decrement the WaitGroup, so we must increment it. // send() will decrement the WaitGroup, so we must increment it.
c.wg.Add(1) c.wg.Add(1)
go func() { go func() {
c.send() c.send(ctx)
exited.call() exited.call()
}() }()
@ -194,13 +197,9 @@ func TestSendExitsOnDie(t *testing.T) {
c.out <- "SENT AFTER START" c.out <- "SENT AFTER START"
s.nc.Expect("SENT AFTER START") s.nc.Expect("SENT AFTER START")
// Now, use the control channel to exit send and kill the goroutine. // Now, cancel the context to exit send and kill the goroutine.
// This sneakily uses the fact that the other two goroutines that would
// normally be waiting for die to close are not running, so we only send
// to the goroutine started above. Normally Close() closes c.die and
// signals to all three goroutines (send, ping, runLoop) to exit.
exited.assertNotCalled("Exited before signal sent.") exited.assertNotCalled("Exited before signal sent.")
c.die <- struct{}{} cancel()
exited.assertWasCalled("Didn't exit after signal.") exited.assertWasCalled("Didn't exit after signal.")
s.nc.ExpectNothing() s.nc.ExpectNothing()
@ -221,7 +220,7 @@ func TestSendExitsOnWriteError(t *testing.T) {
// send() will decrement the WaitGroup, so we must increment it. // send() will decrement the WaitGroup, so we must increment it.
c.wg.Add(1) c.wg.Add(1)
go func() { go func() {
c.send() c.send(context.Background())
exited.call() exited.call()
}() }()
@ -251,6 +250,7 @@ func TestSendDeadlockOnFullBuffer(t *testing.T) {
// We want to test that the a goroutine calling send will exit correctly. // We want to test that the a goroutine calling send will exit correctly.
loopExit := callCheck(t) loopExit := callCheck(t)
sendExit := callCheck(t) sendExit := callCheck(t)
ctx, cancel := context.WithCancel(context.Background())
// send() and runLoop() will decrement the WaitGroup, so we must increment it. // send() and runLoop() will decrement the WaitGroup, so we must increment it.
c.wg.Add(2) c.wg.Add(2)
@ -274,13 +274,13 @@ func TestSendDeadlockOnFullBuffer(t *testing.T) {
}) })
// And trigger it by starting runLoop and inserting a line into conn.in: // And trigger it by starting runLoop and inserting a line into conn.in:
go func() { go func() {
c.runLoop() c.runLoop(ctx)
loopExit.call() loopExit.call()
}() }()
c.in <- &Line{Cmd: PRIVMSG, Raw: "WRITE THAT CAUSES DEADLOCK"} c.in <- &Line{Cmd: PRIVMSG, Raw: "WRITE THAT CAUSES DEADLOCK"}
// At this point the handler should be blocked on a write to conn.out, // At this point the handler should be blocked on a write to conn.out,
// preventing runLoop from looping and thus noticing conn.die is closed. // preventing runLoop from looping and thus noticng the cancelled context.
// //
// The next part is to force send() to call conn.Close(), which can // The next part is to force send() to call conn.Close(), which can
// be done by closing the fake net.Conn so that it returns an error on // be done by closing the fake net.Conn so that it returns an error on
@ -292,8 +292,9 @@ func TestSendDeadlockOnFullBuffer(t *testing.T) {
// to write it to the socket. It should immediately receive an error and // to write it to the socket. It should immediately receive an error and
// call conn.Close(), triggering the deadlock as it waits forever for // call conn.Close(), triggering the deadlock as it waits forever for
// runLoop to call conn.wg.Done. // runLoop to call conn.wg.Done.
c.die = cancel // Close needs to cancel the context for us.
go func() { go func() {
c.send() c.send(ctx)
sendExit.call() sendExit.call()
}() }()
@ -407,10 +408,11 @@ func TestPing(t *testing.T) {
// Start ping loop. // Start ping loop.
exited := callCheck(t) exited := callCheck(t)
ctx, cancel := context.WithCancel(context.Background())
// ping() will decrement the WaitGroup, so we must increment it. // ping() will decrement the WaitGroup, so we must increment it.
c.wg.Add(1) c.wg.Add(1)
go func() { go func() {
c.ping() c.ping(ctx)
exited.call() exited.call()
}() }()
@ -438,13 +440,9 @@ func TestPing(t *testing.T) {
t.Errorf("Line not output after another %s.", 2*res) t.Errorf("Line not output after another %s.", 2*res)
} }
// Now kill the ping loop. // Now kill the ping loop by cancelling the context.
// This sneakily uses the fact that the other two goroutines that would
// normally be waiting for die to close are not running, so we only send
// to the goroutine started above. Normally Close() closes c.die and
// signals to all three goroutines (send, ping, runLoop) to exit.
exited.assertNotCalled("Exited before signal sent.") exited.assertNotCalled("Exited before signal sent.")
c.die <- struct{}{} cancel()
exited.assertWasCalled("Didn't exit after signal.") exited.assertWasCalled("Didn't exit after signal.")
// Make sure we're no longer pinging by waiting >2x PingFreq // Make sure we're no longer pinging by waiting >2x PingFreq
<-time.After(2*c.cfg.PingFreq + res) <-time.After(2*c.cfg.PingFreq + res)
@ -478,10 +476,11 @@ func TestRunLoop(t *testing.T) {
// We want to test that the a goroutine calling runLoop will exit correctly. // We want to test that the a goroutine calling runLoop will exit correctly.
// Now, we can expect the call to Dispatch to take place as runLoop starts. // Now, we can expect the call to Dispatch to take place as runLoop starts.
exited := callCheck(t) exited := callCheck(t)
ctx, cancel := context.WithCancel(context.Background())
// runLoop() will decrement the WaitGroup, so we must increment it. // runLoop() will decrement the WaitGroup, so we must increment it.
c.wg.Add(1) c.wg.Add(1)
go func() { go func() {
c.runLoop() c.runLoop(ctx)
exited.call() exited.call()
}() }()
h002.assertWasCalled("002 handler not called after runLoop started.") h002.assertWasCalled("002 handler not called after runLoop started.")
@ -492,13 +491,9 @@ func TestRunLoop(t *testing.T) {
c.in <- l3 c.in <- l3
h003.assertWasCalled("003 handler not called while runLoop started.") h003.assertWasCalled("003 handler not called while runLoop started.")
// Now, use the control channel to exit send and kill the goroutine. // Now, cancel the context to exit runLoop and kill the goroutine.
// This sneakily uses the fact that the other two goroutines that would
// normally be waiting for die to close are not running, so we only send
// to the goroutine started above. Normally Close() closes c.die and
// signals to all three goroutines (send, ping, runLoop) to exit.
exited.assertNotCalled("Exited before signal sent.") exited.assertNotCalled("Exited before signal sent.")
c.die <- struct{}{} cancel()
exited.assertWasCalled("Didn't exit after signal.") exited.assertWasCalled("Didn't exit after signal.")
// Sending more on c.in shouldn't dispatch any further events // Sending more on c.in shouldn't dispatch any further events