mirror of https://github.com/fluffle/goirc
Use a Context to kill internal goroutines.
This commit is contained in:
parent
1bb2dff298
commit
27cc39787d
|
@ -45,8 +45,8 @@ type Conn struct {
|
||||||
out chan string
|
out chan string
|
||||||
connected bool
|
connected bool
|
||||||
|
|
||||||
// Control channel and WaitGroup for goroutines
|
// CancelFunc and WaitGroup for goroutines
|
||||||
die chan struct{}
|
die context.CancelFunc
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
|
|
||||||
// Internal counters for flood protection
|
// Internal counters for flood protection
|
||||||
|
@ -295,7 +295,7 @@ func (conn *Conn) initialise() {
|
||||||
conn.sock = nil
|
conn.sock = nil
|
||||||
conn.in = make(chan *Line, 32)
|
conn.in = make(chan *Line, 32)
|
||||||
conn.out = make(chan string, 32)
|
conn.out = make(chan string, 32)
|
||||||
conn.die = make(chan struct{})
|
conn.die = nil
|
||||||
if conn.st != nil {
|
if conn.st != nil {
|
||||||
conn.st.Wipe()
|
conn.st.Wipe()
|
||||||
}
|
}
|
||||||
|
@ -404,24 +404,25 @@ func (conn *Conn) internalConnect(ctx context.Context) error {
|
||||||
conn.sock = s
|
conn.sock = s
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.postConnect(true)
|
conn.postConnect(ctx, true)
|
||||||
conn.connected = true
|
conn.connected = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// postConnect performs post-connection setup, for ease of testing.
|
// postConnect performs post-connection setup, for ease of testing.
|
||||||
func (conn *Conn) postConnect(start bool) {
|
func (conn *Conn) postConnect(ctx context.Context, start bool) {
|
||||||
conn.io = bufio.NewReadWriter(
|
conn.io = bufio.NewReadWriter(
|
||||||
bufio.NewReader(conn.sock),
|
bufio.NewReader(conn.sock),
|
||||||
bufio.NewWriter(conn.sock))
|
bufio.NewWriter(conn.sock))
|
||||||
if start {
|
if start {
|
||||||
|
ctx, conn.die = context.WithCancel(ctx)
|
||||||
conn.wg.Add(3)
|
conn.wg.Add(3)
|
||||||
go conn.send()
|
go conn.send(ctx)
|
||||||
go conn.recv()
|
go conn.recv()
|
||||||
go conn.runLoop()
|
go conn.runLoop(ctx)
|
||||||
if conn.cfg.PingFreq > 0 {
|
if conn.cfg.PingFreq > 0 {
|
||||||
conn.wg.Add(1)
|
conn.wg.Add(1)
|
||||||
go conn.ping()
|
go conn.ping(ctx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -434,8 +435,8 @@ func hasPort(s string) bool {
|
||||||
|
|
||||||
// send is started as a goroutine after a connection is established.
|
// send is started as a goroutine after a connection is established.
|
||||||
// It shuttles data from the output channel to write(), and is killed
|
// It shuttles data from the output channel to write(), and is killed
|
||||||
// when Conn.die is closed.
|
// when the context is cancelled.
|
||||||
func (conn *Conn) send() {
|
func (conn *Conn) send(ctx context.Context) {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case line := <-conn.out:
|
case line := <-conn.out:
|
||||||
|
@ -446,7 +447,7 @@ func (conn *Conn) send() {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case <-conn.die:
|
case <-ctx.Done():
|
||||||
// control channel closed, bail out
|
// control channel closed, bail out
|
||||||
conn.wg.Done()
|
conn.wg.Done()
|
||||||
return
|
return
|
||||||
|
@ -483,14 +484,14 @@ func (conn *Conn) recv() {
|
||||||
|
|
||||||
// ping is started as a goroutine after a connection is established, as
|
// ping is started as a goroutine after a connection is established, as
|
||||||
// long as Config.PingFreq >0. It pings the server every PingFreq seconds.
|
// long as Config.PingFreq >0. It pings the server every PingFreq seconds.
|
||||||
func (conn *Conn) ping() {
|
func (conn *Conn) ping(ctx context.Context) {
|
||||||
defer conn.wg.Done()
|
defer conn.wg.Done()
|
||||||
tick := time.NewTicker(conn.cfg.PingFreq)
|
tick := time.NewTicker(conn.cfg.PingFreq)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-tick.C:
|
case <-tick.C:
|
||||||
conn.Ping(fmt.Sprintf("%d", time.Now().UnixNano()))
|
conn.Ping(fmt.Sprintf("%d", time.Now().UnixNano()))
|
||||||
case <-conn.die:
|
case <-ctx.Done():
|
||||||
// control channel closed, bail out
|
// control channel closed, bail out
|
||||||
tick.Stop()
|
tick.Stop()
|
||||||
return
|
return
|
||||||
|
@ -501,13 +502,13 @@ func (conn *Conn) ping() {
|
||||||
// runLoop is started as a goroutine after a connection is established.
|
// runLoop is started as a goroutine after a connection is established.
|
||||||
// It pulls Lines from the input channel and dispatches them to any
|
// It pulls Lines from the input channel and dispatches them to any
|
||||||
// handlers that have been registered for that IRC verb.
|
// handlers that have been registered for that IRC verb.
|
||||||
func (conn *Conn) runLoop() {
|
func (conn *Conn) runLoop(ctx context.Context) {
|
||||||
defer conn.wg.Done()
|
defer conn.wg.Done()
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case line := <-conn.in:
|
case line := <-conn.in:
|
||||||
conn.dispatch(line)
|
conn.dispatch(line)
|
||||||
case <-conn.die:
|
case <-ctx.Done():
|
||||||
// control channel closed, bail out
|
// control channel closed, bail out
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -572,7 +573,9 @@ func (conn *Conn) Close() error {
|
||||||
logging.Info("irc.Close(): Disconnected from server.")
|
logging.Info("irc.Close(): Disconnected from server.")
|
||||||
conn.connected = false
|
conn.connected = false
|
||||||
err := conn.sock.Close()
|
err := conn.sock.Close()
|
||||||
close(conn.die)
|
if conn.die != nil {
|
||||||
|
conn.die()
|
||||||
|
}
|
||||||
// Drain both in and out channels to avoid a deadlock if the buffers
|
// Drain both in and out channels to avoid a deadlock if the buffers
|
||||||
// have filled. See TestSendDeadlockOnFullBuffer in connection_test.go.
|
// have filled. See TestSendDeadlockOnFullBuffer in connection_test.go.
|
||||||
conn.drainIn()
|
conn.drainIn()
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -54,6 +55,7 @@ func setUp(t *testing.T, start ...bool) (*Conn, *testState) {
|
||||||
nc := MockNetConn(t)
|
nc := MockNetConn(t)
|
||||||
c := SimpleClient("test", "test", "Testing IRC")
|
c := SimpleClient("test", "test", "Testing IRC")
|
||||||
c.initialise()
|
c.initialise()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
c.st = st
|
c.st = st
|
||||||
c.sock = nc
|
c.sock = nc
|
||||||
|
@ -61,7 +63,7 @@ func setUp(t *testing.T, start ...bool) (*Conn, *testState) {
|
||||||
c.connected = true
|
c.connected = true
|
||||||
// If a second argument is passed to setUp, we tell postConnect not to
|
// If a second argument is passed to setUp, we tell postConnect not to
|
||||||
// start the various goroutines that shuttle data around.
|
// start the various goroutines that shuttle data around.
|
||||||
c.postConnect(len(start) == 0)
|
c.postConnect(ctx, len(start) == 0)
|
||||||
// Sleep 1ms to allow background routines to start.
|
// Sleep 1ms to allow background routines to start.
|
||||||
<-time.After(time.Millisecond)
|
<-time.After(time.Millisecond)
|
||||||
|
|
||||||
|
@ -166,7 +168,7 @@ func TestClientAndStateTracking(t *testing.T) {
|
||||||
ctrl.Finish()
|
ctrl.Finish()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSendExitsOnDie(t *testing.T) {
|
func TestSendExitsOnCancel(t *testing.T) {
|
||||||
// Passing a second value to setUp stops goroutines from starting
|
// Passing a second value to setUp stops goroutines from starting
|
||||||
c, s := setUp(t, false)
|
c, s := setUp(t, false)
|
||||||
defer s.tearDown()
|
defer s.tearDown()
|
||||||
|
@ -178,10 +180,11 @@ func TestSendExitsOnDie(t *testing.T) {
|
||||||
|
|
||||||
// We want to test that the a goroutine calling send will exit correctly.
|
// We want to test that the a goroutine calling send will exit correctly.
|
||||||
exited := callCheck(t)
|
exited := callCheck(t)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
// send() will decrement the WaitGroup, so we must increment it.
|
// send() will decrement the WaitGroup, so we must increment it.
|
||||||
c.wg.Add(1)
|
c.wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
c.send()
|
c.send(ctx)
|
||||||
exited.call()
|
exited.call()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -194,13 +197,9 @@ func TestSendExitsOnDie(t *testing.T) {
|
||||||
c.out <- "SENT AFTER START"
|
c.out <- "SENT AFTER START"
|
||||||
s.nc.Expect("SENT AFTER START")
|
s.nc.Expect("SENT AFTER START")
|
||||||
|
|
||||||
// Now, use the control channel to exit send and kill the goroutine.
|
// Now, cancel the context to exit send and kill the goroutine.
|
||||||
// This sneakily uses the fact that the other two goroutines that would
|
|
||||||
// normally be waiting for die to close are not running, so we only send
|
|
||||||
// to the goroutine started above. Normally Close() closes c.die and
|
|
||||||
// signals to all three goroutines (send, ping, runLoop) to exit.
|
|
||||||
exited.assertNotCalled("Exited before signal sent.")
|
exited.assertNotCalled("Exited before signal sent.")
|
||||||
c.die <- struct{}{}
|
cancel()
|
||||||
exited.assertWasCalled("Didn't exit after signal.")
|
exited.assertWasCalled("Didn't exit after signal.")
|
||||||
s.nc.ExpectNothing()
|
s.nc.ExpectNothing()
|
||||||
|
|
||||||
|
@ -221,7 +220,7 @@ func TestSendExitsOnWriteError(t *testing.T) {
|
||||||
// send() will decrement the WaitGroup, so we must increment it.
|
// send() will decrement the WaitGroup, so we must increment it.
|
||||||
c.wg.Add(1)
|
c.wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
c.send()
|
c.send(context.Background())
|
||||||
exited.call()
|
exited.call()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -251,6 +250,7 @@ func TestSendDeadlockOnFullBuffer(t *testing.T) {
|
||||||
// We want to test that the a goroutine calling send will exit correctly.
|
// We want to test that the a goroutine calling send will exit correctly.
|
||||||
loopExit := callCheck(t)
|
loopExit := callCheck(t)
|
||||||
sendExit := callCheck(t)
|
sendExit := callCheck(t)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
// send() and runLoop() will decrement the WaitGroup, so we must increment it.
|
// send() and runLoop() will decrement the WaitGroup, so we must increment it.
|
||||||
c.wg.Add(2)
|
c.wg.Add(2)
|
||||||
|
|
||||||
|
@ -274,13 +274,13 @@ func TestSendDeadlockOnFullBuffer(t *testing.T) {
|
||||||
})
|
})
|
||||||
// And trigger it by starting runLoop and inserting a line into conn.in:
|
// And trigger it by starting runLoop and inserting a line into conn.in:
|
||||||
go func() {
|
go func() {
|
||||||
c.runLoop()
|
c.runLoop(ctx)
|
||||||
loopExit.call()
|
loopExit.call()
|
||||||
}()
|
}()
|
||||||
c.in <- &Line{Cmd: PRIVMSG, Raw: "WRITE THAT CAUSES DEADLOCK"}
|
c.in <- &Line{Cmd: PRIVMSG, Raw: "WRITE THAT CAUSES DEADLOCK"}
|
||||||
|
|
||||||
// At this point the handler should be blocked on a write to conn.out,
|
// At this point the handler should be blocked on a write to conn.out,
|
||||||
// preventing runLoop from looping and thus noticing conn.die is closed.
|
// preventing runLoop from looping and thus noticng the cancelled context.
|
||||||
//
|
//
|
||||||
// The next part is to force send() to call conn.Close(), which can
|
// The next part is to force send() to call conn.Close(), which can
|
||||||
// be done by closing the fake net.Conn so that it returns an error on
|
// be done by closing the fake net.Conn so that it returns an error on
|
||||||
|
@ -292,8 +292,9 @@ func TestSendDeadlockOnFullBuffer(t *testing.T) {
|
||||||
// to write it to the socket. It should immediately receive an error and
|
// to write it to the socket. It should immediately receive an error and
|
||||||
// call conn.Close(), triggering the deadlock as it waits forever for
|
// call conn.Close(), triggering the deadlock as it waits forever for
|
||||||
// runLoop to call conn.wg.Done.
|
// runLoop to call conn.wg.Done.
|
||||||
|
c.die = cancel // Close needs to cancel the context for us.
|
||||||
go func() {
|
go func() {
|
||||||
c.send()
|
c.send(ctx)
|
||||||
sendExit.call()
|
sendExit.call()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -407,10 +408,11 @@ func TestPing(t *testing.T) {
|
||||||
|
|
||||||
// Start ping loop.
|
// Start ping loop.
|
||||||
exited := callCheck(t)
|
exited := callCheck(t)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
// ping() will decrement the WaitGroup, so we must increment it.
|
// ping() will decrement the WaitGroup, so we must increment it.
|
||||||
c.wg.Add(1)
|
c.wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
c.ping()
|
c.ping(ctx)
|
||||||
exited.call()
|
exited.call()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -438,13 +440,9 @@ func TestPing(t *testing.T) {
|
||||||
t.Errorf("Line not output after another %s.", 2*res)
|
t.Errorf("Line not output after another %s.", 2*res)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now kill the ping loop.
|
// Now kill the ping loop by cancelling the context.
|
||||||
// This sneakily uses the fact that the other two goroutines that would
|
|
||||||
// normally be waiting for die to close are not running, so we only send
|
|
||||||
// to the goroutine started above. Normally Close() closes c.die and
|
|
||||||
// signals to all three goroutines (send, ping, runLoop) to exit.
|
|
||||||
exited.assertNotCalled("Exited before signal sent.")
|
exited.assertNotCalled("Exited before signal sent.")
|
||||||
c.die <- struct{}{}
|
cancel()
|
||||||
exited.assertWasCalled("Didn't exit after signal.")
|
exited.assertWasCalled("Didn't exit after signal.")
|
||||||
// Make sure we're no longer pinging by waiting >2x PingFreq
|
// Make sure we're no longer pinging by waiting >2x PingFreq
|
||||||
<-time.After(2*c.cfg.PingFreq + res)
|
<-time.After(2*c.cfg.PingFreq + res)
|
||||||
|
@ -478,10 +476,11 @@ func TestRunLoop(t *testing.T) {
|
||||||
// We want to test that the a goroutine calling runLoop will exit correctly.
|
// We want to test that the a goroutine calling runLoop will exit correctly.
|
||||||
// Now, we can expect the call to Dispatch to take place as runLoop starts.
|
// Now, we can expect the call to Dispatch to take place as runLoop starts.
|
||||||
exited := callCheck(t)
|
exited := callCheck(t)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
// runLoop() will decrement the WaitGroup, so we must increment it.
|
// runLoop() will decrement the WaitGroup, so we must increment it.
|
||||||
c.wg.Add(1)
|
c.wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
c.runLoop()
|
c.runLoop(ctx)
|
||||||
exited.call()
|
exited.call()
|
||||||
}()
|
}()
|
||||||
h002.assertWasCalled("002 handler not called after runLoop started.")
|
h002.assertWasCalled("002 handler not called after runLoop started.")
|
||||||
|
@ -492,13 +491,9 @@ func TestRunLoop(t *testing.T) {
|
||||||
c.in <- l3
|
c.in <- l3
|
||||||
h003.assertWasCalled("003 handler not called while runLoop started.")
|
h003.assertWasCalled("003 handler not called while runLoop started.")
|
||||||
|
|
||||||
// Now, use the control channel to exit send and kill the goroutine.
|
// Now, cancel the context to exit runLoop and kill the goroutine.
|
||||||
// This sneakily uses the fact that the other two goroutines that would
|
|
||||||
// normally be waiting for die to close are not running, so we only send
|
|
||||||
// to the goroutine started above. Normally Close() closes c.die and
|
|
||||||
// signals to all three goroutines (send, ping, runLoop) to exit.
|
|
||||||
exited.assertNotCalled("Exited before signal sent.")
|
exited.assertNotCalled("Exited before signal sent.")
|
||||||
c.die <- struct{}{}
|
cancel()
|
||||||
exited.assertWasCalled("Didn't exit after signal.")
|
exited.assertWasCalled("Didn't exit after signal.")
|
||||||
|
|
||||||
// Sending more on c.in shouldn't dispatch any further events
|
// Sending more on c.in shouldn't dispatch any further events
|
||||||
|
|
Loading…
Reference in New Issue