diff --git a/client/connection.go b/client/connection.go index e46b50d..9b4587e 100644 --- a/client/connection.go +++ b/client/connection.go @@ -176,8 +176,6 @@ func Client(cfg *Config) *Conn { conn := &Conn{ cfg: cfg, dialer: dialer, - in: make(chan *Line, 32), - out: make(chan string, 32), intHandlers: handlerSet(), fgHandlers: handlerSet(), bgHandlers: handlerSet(), @@ -265,6 +263,8 @@ func (conn *Conn) DisableStateTracking() { func (conn *Conn) initialise() { conn.io = nil conn.sock = nil + conn.in = make(chan *Line, 32) + conn.out = make(chan string, 32) conn.die = make(chan struct{}) if conn.st != nil { conn.st.Wipe() @@ -510,6 +510,10 @@ func (conn *Conn) shutdown() { conn.connected = false conn.sock.Close() close(conn.die) + // Drain both in and out channels to avoid a deadlock if the buffers + // have filled. See TestSendDeadlockOnFullBuffer in connection_test.go. + conn.drainIn() + conn.drainOut() conn.wg.Wait() conn.mu.Unlock() // Dispatch after closing connection but before reinit @@ -517,6 +521,28 @@ func (conn *Conn) shutdown() { conn.dispatch(&Line{Cmd: DISCONNECTED, Time: time.Now()}) } +// drainIn sends all data buffered in conn.in to /dev/null. +func (conn *Conn) drainIn() { + for { + select { + case <-conn.in: + default: + return + } + } +} + +// drainOut does the same for conn.out. Generics! +func (conn *Conn) drainOut() { + for { + select { + case <-conn.out: + default: + return + } + } +} + // Dumps a load of information about the current state of the connection to a // string for debugging state tracking and other such things. func (conn *Conn) String() string {