diff --git a/client/connection.go b/client/connection.go index e0bbc47..f98bc68 100644 --- a/client/connection.go +++ b/client/connection.go @@ -45,8 +45,8 @@ type Conn struct { out chan string connected bool - // Control channel and WaitGroup for goroutines - die chan struct{} + // CancelFunc and WaitGroup for goroutines + die context.CancelFunc wg sync.WaitGroup // Internal counters for flood protection @@ -295,7 +295,7 @@ func (conn *Conn) initialise() { conn.sock = nil conn.in = make(chan *Line, 32) conn.out = make(chan string, 32) - conn.die = make(chan struct{}) + conn.die = nil if conn.st != nil { conn.st.Wipe() } @@ -404,24 +404,25 @@ func (conn *Conn) internalConnect(ctx context.Context) error { conn.sock = s } - conn.postConnect(true) + conn.postConnect(ctx, true) conn.connected = true return nil } // postConnect performs post-connection setup, for ease of testing. -func (conn *Conn) postConnect(start bool) { +func (conn *Conn) postConnect(ctx context.Context, start bool) { conn.io = bufio.NewReadWriter( bufio.NewReader(conn.sock), bufio.NewWriter(conn.sock)) if start { + ctx, conn.die = context.WithCancel(ctx) conn.wg.Add(3) - go conn.send() + go conn.send(ctx) go conn.recv() - go conn.runLoop() + go conn.runLoop(ctx) if conn.cfg.PingFreq > 0 { conn.wg.Add(1) - go conn.ping() + go conn.ping(ctx) } } } @@ -434,8 +435,8 @@ func hasPort(s string) bool { // send is started as a goroutine after a connection is established. // It shuttles data from the output channel to write(), and is killed -// when Conn.die is closed. -func (conn *Conn) send() { +// when the context is cancelled. +func (conn *Conn) send(ctx context.Context) { for { select { case line := <-conn.out: @@ -446,7 +447,7 @@ func (conn *Conn) send() { conn.Close() return } - case <-conn.die: + case <-ctx.Done(): // control channel closed, bail out conn.wg.Done() return @@ -483,14 +484,14 @@ func (conn *Conn) recv() { // ping is started as a goroutine after a connection is established, as // long as Config.PingFreq >0. It pings the server every PingFreq seconds. -func (conn *Conn) ping() { +func (conn *Conn) ping(ctx context.Context) { defer conn.wg.Done() tick := time.NewTicker(conn.cfg.PingFreq) for { select { case <-tick.C: conn.Ping(fmt.Sprintf("%d", time.Now().UnixNano())) - case <-conn.die: + case <-ctx.Done(): // control channel closed, bail out tick.Stop() return @@ -501,13 +502,13 @@ func (conn *Conn) ping() { // runLoop is started as a goroutine after a connection is established. // It pulls Lines from the input channel and dispatches them to any // handlers that have been registered for that IRC verb. -func (conn *Conn) runLoop() { +func (conn *Conn) runLoop(ctx context.Context) { defer conn.wg.Done() for { select { case line := <-conn.in: conn.dispatch(line) - case <-conn.die: + case <-ctx.Done(): // control channel closed, bail out return } @@ -572,7 +573,9 @@ func (conn *Conn) Close() error { logging.Info("irc.Close(): Disconnected from server.") conn.connected = false err := conn.sock.Close() - close(conn.die) + if conn.die != nil { + conn.die() + } // Drain both in and out channels to avoid a deadlock if the buffers // have filled. See TestSendDeadlockOnFullBuffer in connection_test.go. conn.drainIn() diff --git a/client/connection_test.go b/client/connection_test.go index e56df44..b4f3145 100644 --- a/client/connection_test.go +++ b/client/connection_test.go @@ -1,6 +1,7 @@ package client import ( + "context" "runtime" "strings" "testing" @@ -54,6 +55,7 @@ func setUp(t *testing.T, start ...bool) (*Conn, *testState) { nc := MockNetConn(t) c := SimpleClient("test", "test", "Testing IRC") c.initialise() + ctx := context.Background() c.st = st c.sock = nc @@ -61,7 +63,7 @@ func setUp(t *testing.T, start ...bool) (*Conn, *testState) { c.connected = true // If a second argument is passed to setUp, we tell postConnect not to // start the various goroutines that shuttle data around. - c.postConnect(len(start) == 0) + c.postConnect(ctx, len(start) == 0) // Sleep 1ms to allow background routines to start. <-time.After(time.Millisecond) @@ -166,7 +168,7 @@ func TestClientAndStateTracking(t *testing.T) { ctrl.Finish() } -func TestSendExitsOnDie(t *testing.T) { +func TestSendExitsOnCancel(t *testing.T) { // Passing a second value to setUp stops goroutines from starting c, s := setUp(t, false) defer s.tearDown() @@ -178,10 +180,11 @@ func TestSendExitsOnDie(t *testing.T) { // We want to test that the a goroutine calling send will exit correctly. exited := callCheck(t) + ctx, cancel := context.WithCancel(context.Background()) // send() will decrement the WaitGroup, so we must increment it. c.wg.Add(1) go func() { - c.send() + c.send(ctx) exited.call() }() @@ -194,13 +197,9 @@ func TestSendExitsOnDie(t *testing.T) { c.out <- "SENT AFTER START" s.nc.Expect("SENT AFTER START") - // Now, use the control channel to exit send and kill the goroutine. - // This sneakily uses the fact that the other two goroutines that would - // normally be waiting for die to close are not running, so we only send - // to the goroutine started above. Normally Close() closes c.die and - // signals to all three goroutines (send, ping, runLoop) to exit. + // Now, cancel the context to exit send and kill the goroutine. exited.assertNotCalled("Exited before signal sent.") - c.die <- struct{}{} + cancel() exited.assertWasCalled("Didn't exit after signal.") s.nc.ExpectNothing() @@ -221,7 +220,7 @@ func TestSendExitsOnWriteError(t *testing.T) { // send() will decrement the WaitGroup, so we must increment it. c.wg.Add(1) go func() { - c.send() + c.send(context.Background()) exited.call() }() @@ -251,6 +250,7 @@ func TestSendDeadlockOnFullBuffer(t *testing.T) { // We want to test that the a goroutine calling send will exit correctly. loopExit := callCheck(t) sendExit := callCheck(t) + ctx, cancel := context.WithCancel(context.Background()) // send() and runLoop() will decrement the WaitGroup, so we must increment it. c.wg.Add(2) @@ -274,13 +274,13 @@ func TestSendDeadlockOnFullBuffer(t *testing.T) { }) // And trigger it by starting runLoop and inserting a line into conn.in: go func() { - c.runLoop() + c.runLoop(ctx) loopExit.call() }() c.in <- &Line{Cmd: PRIVMSG, Raw: "WRITE THAT CAUSES DEADLOCK"} // At this point the handler should be blocked on a write to conn.out, - // preventing runLoop from looping and thus noticing conn.die is closed. + // preventing runLoop from looping and thus noticng the cancelled context. // // The next part is to force send() to call conn.Close(), which can // be done by closing the fake net.Conn so that it returns an error on @@ -292,8 +292,9 @@ func TestSendDeadlockOnFullBuffer(t *testing.T) { // to write it to the socket. It should immediately receive an error and // call conn.Close(), triggering the deadlock as it waits forever for // runLoop to call conn.wg.Done. + c.die = cancel // Close needs to cancel the context for us. go func() { - c.send() + c.send(ctx) sendExit.call() }() @@ -407,10 +408,11 @@ func TestPing(t *testing.T) { // Start ping loop. exited := callCheck(t) + ctx, cancel := context.WithCancel(context.Background()) // ping() will decrement the WaitGroup, so we must increment it. c.wg.Add(1) go func() { - c.ping() + c.ping(ctx) exited.call() }() @@ -438,13 +440,9 @@ func TestPing(t *testing.T) { t.Errorf("Line not output after another %s.", 2*res) } - // Now kill the ping loop. - // This sneakily uses the fact that the other two goroutines that would - // normally be waiting for die to close are not running, so we only send - // to the goroutine started above. Normally Close() closes c.die and - // signals to all three goroutines (send, ping, runLoop) to exit. + // Now kill the ping loop by cancelling the context. exited.assertNotCalled("Exited before signal sent.") - c.die <- struct{}{} + cancel() exited.assertWasCalled("Didn't exit after signal.") // Make sure we're no longer pinging by waiting >2x PingFreq <-time.After(2*c.cfg.PingFreq + res) @@ -478,10 +476,11 @@ func TestRunLoop(t *testing.T) { // We want to test that the a goroutine calling runLoop will exit correctly. // Now, we can expect the call to Dispatch to take place as runLoop starts. exited := callCheck(t) + ctx, cancel := context.WithCancel(context.Background()) // runLoop() will decrement the WaitGroup, so we must increment it. c.wg.Add(1) go func() { - c.runLoop() + c.runLoop(ctx) exited.call() }() h002.assertWasCalled("002 handler not called after runLoop started.") @@ -492,13 +491,9 @@ func TestRunLoop(t *testing.T) { c.in <- l3 h003.assertWasCalled("003 handler not called while runLoop started.") - // Now, use the control channel to exit send and kill the goroutine. - // This sneakily uses the fact that the other two goroutines that would - // normally be waiting for die to close are not running, so we only send - // to the goroutine started above. Normally Close() closes c.die and - // signals to all three goroutines (send, ping, runLoop) to exit. + // Now, cancel the context to exit runLoop and kill the goroutine. exited.assertNotCalled("Exited before signal sent.") - c.die <- struct{}{} + cancel() exited.assertWasCalled("Didn't exit after signal.") // Sending more on c.in shouldn't dispatch any further events