mirror of https://github.com/fluffle/goirc
Use a Context to kill internal goroutines.
This commit is contained in:
parent
1bb2dff298
commit
27cc39787d
|
@ -45,8 +45,8 @@ type Conn struct {
|
|||
out chan string
|
||||
connected bool
|
||||
|
||||
// Control channel and WaitGroup for goroutines
|
||||
die chan struct{}
|
||||
// CancelFunc and WaitGroup for goroutines
|
||||
die context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
|
||||
// Internal counters for flood protection
|
||||
|
@ -295,7 +295,7 @@ func (conn *Conn) initialise() {
|
|||
conn.sock = nil
|
||||
conn.in = make(chan *Line, 32)
|
||||
conn.out = make(chan string, 32)
|
||||
conn.die = make(chan struct{})
|
||||
conn.die = nil
|
||||
if conn.st != nil {
|
||||
conn.st.Wipe()
|
||||
}
|
||||
|
@ -404,24 +404,25 @@ func (conn *Conn) internalConnect(ctx context.Context) error {
|
|||
conn.sock = s
|
||||
}
|
||||
|
||||
conn.postConnect(true)
|
||||
conn.postConnect(ctx, true)
|
||||
conn.connected = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// postConnect performs post-connection setup, for ease of testing.
|
||||
func (conn *Conn) postConnect(start bool) {
|
||||
func (conn *Conn) postConnect(ctx context.Context, start bool) {
|
||||
conn.io = bufio.NewReadWriter(
|
||||
bufio.NewReader(conn.sock),
|
||||
bufio.NewWriter(conn.sock))
|
||||
if start {
|
||||
ctx, conn.die = context.WithCancel(ctx)
|
||||
conn.wg.Add(3)
|
||||
go conn.send()
|
||||
go conn.send(ctx)
|
||||
go conn.recv()
|
||||
go conn.runLoop()
|
||||
go conn.runLoop(ctx)
|
||||
if conn.cfg.PingFreq > 0 {
|
||||
conn.wg.Add(1)
|
||||
go conn.ping()
|
||||
go conn.ping(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -434,8 +435,8 @@ func hasPort(s string) bool {
|
|||
|
||||
// send is started as a goroutine after a connection is established.
|
||||
// It shuttles data from the output channel to write(), and is killed
|
||||
// when Conn.die is closed.
|
||||
func (conn *Conn) send() {
|
||||
// when the context is cancelled.
|
||||
func (conn *Conn) send(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case line := <-conn.out:
|
||||
|
@ -446,7 +447,7 @@ func (conn *Conn) send() {
|
|||
conn.Close()
|
||||
return
|
||||
}
|
||||
case <-conn.die:
|
||||
case <-ctx.Done():
|
||||
// control channel closed, bail out
|
||||
conn.wg.Done()
|
||||
return
|
||||
|
@ -483,14 +484,14 @@ func (conn *Conn) recv() {
|
|||
|
||||
// ping is started as a goroutine after a connection is established, as
|
||||
// long as Config.PingFreq >0. It pings the server every PingFreq seconds.
|
||||
func (conn *Conn) ping() {
|
||||
func (conn *Conn) ping(ctx context.Context) {
|
||||
defer conn.wg.Done()
|
||||
tick := time.NewTicker(conn.cfg.PingFreq)
|
||||
for {
|
||||
select {
|
||||
case <-tick.C:
|
||||
conn.Ping(fmt.Sprintf("%d", time.Now().UnixNano()))
|
||||
case <-conn.die:
|
||||
case <-ctx.Done():
|
||||
// control channel closed, bail out
|
||||
tick.Stop()
|
||||
return
|
||||
|
@ -501,13 +502,13 @@ func (conn *Conn) ping() {
|
|||
// runLoop is started as a goroutine after a connection is established.
|
||||
// It pulls Lines from the input channel and dispatches them to any
|
||||
// handlers that have been registered for that IRC verb.
|
||||
func (conn *Conn) runLoop() {
|
||||
func (conn *Conn) runLoop(ctx context.Context) {
|
||||
defer conn.wg.Done()
|
||||
for {
|
||||
select {
|
||||
case line := <-conn.in:
|
||||
conn.dispatch(line)
|
||||
case <-conn.die:
|
||||
case <-ctx.Done():
|
||||
// control channel closed, bail out
|
||||
return
|
||||
}
|
||||
|
@ -572,7 +573,9 @@ func (conn *Conn) Close() error {
|
|||
logging.Info("irc.Close(): Disconnected from server.")
|
||||
conn.connected = false
|
||||
err := conn.sock.Close()
|
||||
close(conn.die)
|
||||
if conn.die != nil {
|
||||
conn.die()
|
||||
}
|
||||
// Drain both in and out channels to avoid a deadlock if the buffers
|
||||
// have filled. See TestSendDeadlockOnFullBuffer in connection_test.go.
|
||||
conn.drainIn()
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
@ -54,6 +55,7 @@ func setUp(t *testing.T, start ...bool) (*Conn, *testState) {
|
|||
nc := MockNetConn(t)
|
||||
c := SimpleClient("test", "test", "Testing IRC")
|
||||
c.initialise()
|
||||
ctx := context.Background()
|
||||
|
||||
c.st = st
|
||||
c.sock = nc
|
||||
|
@ -61,7 +63,7 @@ func setUp(t *testing.T, start ...bool) (*Conn, *testState) {
|
|||
c.connected = true
|
||||
// If a second argument is passed to setUp, we tell postConnect not to
|
||||
// start the various goroutines that shuttle data around.
|
||||
c.postConnect(len(start) == 0)
|
||||
c.postConnect(ctx, len(start) == 0)
|
||||
// Sleep 1ms to allow background routines to start.
|
||||
<-time.After(time.Millisecond)
|
||||
|
||||
|
@ -166,7 +168,7 @@ func TestClientAndStateTracking(t *testing.T) {
|
|||
ctrl.Finish()
|
||||
}
|
||||
|
||||
func TestSendExitsOnDie(t *testing.T) {
|
||||
func TestSendExitsOnCancel(t *testing.T) {
|
||||
// Passing a second value to setUp stops goroutines from starting
|
||||
c, s := setUp(t, false)
|
||||
defer s.tearDown()
|
||||
|
@ -178,10 +180,11 @@ func TestSendExitsOnDie(t *testing.T) {
|
|||
|
||||
// We want to test that the a goroutine calling send will exit correctly.
|
||||
exited := callCheck(t)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
// send() will decrement the WaitGroup, so we must increment it.
|
||||
c.wg.Add(1)
|
||||
go func() {
|
||||
c.send()
|
||||
c.send(ctx)
|
||||
exited.call()
|
||||
}()
|
||||
|
||||
|
@ -194,13 +197,9 @@ func TestSendExitsOnDie(t *testing.T) {
|
|||
c.out <- "SENT AFTER START"
|
||||
s.nc.Expect("SENT AFTER START")
|
||||
|
||||
// Now, use the control channel to exit send and kill the goroutine.
|
||||
// This sneakily uses the fact that the other two goroutines that would
|
||||
// normally be waiting for die to close are not running, so we only send
|
||||
// to the goroutine started above. Normally Close() closes c.die and
|
||||
// signals to all three goroutines (send, ping, runLoop) to exit.
|
||||
// Now, cancel the context to exit send and kill the goroutine.
|
||||
exited.assertNotCalled("Exited before signal sent.")
|
||||
c.die <- struct{}{}
|
||||
cancel()
|
||||
exited.assertWasCalled("Didn't exit after signal.")
|
||||
s.nc.ExpectNothing()
|
||||
|
||||
|
@ -221,7 +220,7 @@ func TestSendExitsOnWriteError(t *testing.T) {
|
|||
// send() will decrement the WaitGroup, so we must increment it.
|
||||
c.wg.Add(1)
|
||||
go func() {
|
||||
c.send()
|
||||
c.send(context.Background())
|
||||
exited.call()
|
||||
}()
|
||||
|
||||
|
@ -251,6 +250,7 @@ func TestSendDeadlockOnFullBuffer(t *testing.T) {
|
|||
// We want to test that the a goroutine calling send will exit correctly.
|
||||
loopExit := callCheck(t)
|
||||
sendExit := callCheck(t)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
// send() and runLoop() will decrement the WaitGroup, so we must increment it.
|
||||
c.wg.Add(2)
|
||||
|
||||
|
@ -274,13 +274,13 @@ func TestSendDeadlockOnFullBuffer(t *testing.T) {
|
|||
})
|
||||
// And trigger it by starting runLoop and inserting a line into conn.in:
|
||||
go func() {
|
||||
c.runLoop()
|
||||
c.runLoop(ctx)
|
||||
loopExit.call()
|
||||
}()
|
||||
c.in <- &Line{Cmd: PRIVMSG, Raw: "WRITE THAT CAUSES DEADLOCK"}
|
||||
|
||||
// At this point the handler should be blocked on a write to conn.out,
|
||||
// preventing runLoop from looping and thus noticing conn.die is closed.
|
||||
// preventing runLoop from looping and thus noticng the cancelled context.
|
||||
//
|
||||
// The next part is to force send() to call conn.Close(), which can
|
||||
// be done by closing the fake net.Conn so that it returns an error on
|
||||
|
@ -292,8 +292,9 @@ func TestSendDeadlockOnFullBuffer(t *testing.T) {
|
|||
// to write it to the socket. It should immediately receive an error and
|
||||
// call conn.Close(), triggering the deadlock as it waits forever for
|
||||
// runLoop to call conn.wg.Done.
|
||||
c.die = cancel // Close needs to cancel the context for us.
|
||||
go func() {
|
||||
c.send()
|
||||
c.send(ctx)
|
||||
sendExit.call()
|
||||
}()
|
||||
|
||||
|
@ -407,10 +408,11 @@ func TestPing(t *testing.T) {
|
|||
|
||||
// Start ping loop.
|
||||
exited := callCheck(t)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
// ping() will decrement the WaitGroup, so we must increment it.
|
||||
c.wg.Add(1)
|
||||
go func() {
|
||||
c.ping()
|
||||
c.ping(ctx)
|
||||
exited.call()
|
||||
}()
|
||||
|
||||
|
@ -438,13 +440,9 @@ func TestPing(t *testing.T) {
|
|||
t.Errorf("Line not output after another %s.", 2*res)
|
||||
}
|
||||
|
||||
// Now kill the ping loop.
|
||||
// This sneakily uses the fact that the other two goroutines that would
|
||||
// normally be waiting for die to close are not running, so we only send
|
||||
// to the goroutine started above. Normally Close() closes c.die and
|
||||
// signals to all three goroutines (send, ping, runLoop) to exit.
|
||||
// Now kill the ping loop by cancelling the context.
|
||||
exited.assertNotCalled("Exited before signal sent.")
|
||||
c.die <- struct{}{}
|
||||
cancel()
|
||||
exited.assertWasCalled("Didn't exit after signal.")
|
||||
// Make sure we're no longer pinging by waiting >2x PingFreq
|
||||
<-time.After(2*c.cfg.PingFreq + res)
|
||||
|
@ -478,10 +476,11 @@ func TestRunLoop(t *testing.T) {
|
|||
// We want to test that the a goroutine calling runLoop will exit correctly.
|
||||
// Now, we can expect the call to Dispatch to take place as runLoop starts.
|
||||
exited := callCheck(t)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
// runLoop() will decrement the WaitGroup, so we must increment it.
|
||||
c.wg.Add(1)
|
||||
go func() {
|
||||
c.runLoop()
|
||||
c.runLoop(ctx)
|
||||
exited.call()
|
||||
}()
|
||||
h002.assertWasCalled("002 handler not called after runLoop started.")
|
||||
|
@ -492,13 +491,9 @@ func TestRunLoop(t *testing.T) {
|
|||
c.in <- l3
|
||||
h003.assertWasCalled("003 handler not called while runLoop started.")
|
||||
|
||||
// Now, use the control channel to exit send and kill the goroutine.
|
||||
// This sneakily uses the fact that the other two goroutines that would
|
||||
// normally be waiting for die to close are not running, so we only send
|
||||
// to the goroutine started above. Normally Close() closes c.die and
|
||||
// signals to all three goroutines (send, ping, runLoop) to exit.
|
||||
// Now, cancel the context to exit runLoop and kill the goroutine.
|
||||
exited.assertNotCalled("Exited before signal sent.")
|
||||
c.die <- struct{}{}
|
||||
cancel()
|
||||
exited.assertWasCalled("Didn't exit after signal.")
|
||||
|
||||
// Sending more on c.in shouldn't dispatch any further events
|
||||
|
|
Loading…
Reference in New Issue