diff --git a/client/connection.go b/client/connection.go index 9186dee..2775933 100644 --- a/client/connection.go +++ b/client/connection.go @@ -236,10 +236,12 @@ func (conn *Conn) postConnect(start bool) { bufio.NewReader(conn.sock), bufio.NewWriter(conn.sock)) if start { + conn.wg.Add(3) go conn.send() go conn.recv() go conn.runLoop() if conn.cfg.PingFreq > 0 { + conn.wg.Add(1) go conn.ping() } } @@ -252,7 +254,6 @@ func hasPort(s string) bool { // goroutine to pass data from output channel to write() func (conn *Conn) send() { - conn.wg.Add(1) defer conn.wg.Done() for { select { @@ -267,7 +268,6 @@ func (conn *Conn) send() { // receive one \r\n terminated line from peer, parse and dispatch it func (conn *Conn) recv() { - conn.wg.Add(1) for { s, err := conn.io.ReadString('\n') if err != nil { @@ -293,7 +293,6 @@ func (conn *Conn) recv() { // Repeatedly pings the server every PingFreq seconds (no matter what) func (conn *Conn) ping() { - conn.wg.Add(1) defer conn.wg.Done() tick := time.NewTicker(conn.cfg.PingFreq) for { @@ -310,7 +309,6 @@ func (conn *Conn) ping() { // goroutine to dispatch events for lines received on input channel func (conn *Conn) runLoop() { - conn.wg.Add(1) defer conn.wg.Done() for { select { diff --git a/client/connection_test.go b/client/connection_test.go index a4bf26d..aaa11f7 100644 --- a/client/connection_test.go +++ b/client/connection_test.go @@ -170,6 +170,8 @@ func TestSend(t *testing.T) { // We want to test that the a goroutine calling send will exit correctly. exited := callCheck(t) + // send() will decrement the WaitGroup, so we must increment it. + c.wg.Add(1) go func() { c.send() exited.call() @@ -224,6 +226,8 @@ func TestRecv(t *testing.T) { // We want to test that the a goroutine calling recv will exit correctly. exited := callCheck(t) + // recv() will decrement the WaitGroup, so we must increment it. + c.wg.Add(1) go func() { c.recv() exited.call() @@ -282,6 +286,8 @@ func TestPing(t *testing.T) { // Start ping loop. exited := callCheck(t) + // ping() will decrement the WaitGroup, so we must increment it. + c.wg.Add(1) go func() { c.ping() exited.call() @@ -349,6 +355,8 @@ func TestRunLoop(t *testing.T) { // We want to test that the a goroutine calling runLoop will exit correctly. // Now, we can expect the call to Dispatch to take place as runLoop starts. exited := callCheck(t) + // runLoop() will decrement the WaitGroup, so we must increment it. + c.wg.Add(1) go func() { c.runLoop() exited.call()