diff --git a/client/connection_test.go b/client/connection_test.go index acf4713..017d664 100644 --- a/client/connection_test.go +++ b/client/connection_test.go @@ -153,13 +153,6 @@ func TestClientAndStateTracking(t *testing.T) { t.Errorf("State tracker not disabled correctly.") } - // Finally, check state tracking handlers were all removed correctly - for k, _ := range stHandlers { - if _, ok := c.intHandlers.set[strings.ToLower(k)]; ok && k != "NICK" { - // A bit leaky, because intHandlers adds a NICK handler. - t.Errorf("State handler for '%s' not removed correctly.", k) - } - } if len(c.stRemovers) != 0 { t.Errorf("stRemovers not zeroed correctly when removing state handlers.") } diff --git a/client/dispatch.go b/client/dispatch.go index dc38a33..3aeb000 100644 --- a/client/dispatch.go +++ b/client/dispatch.go @@ -39,110 +39,133 @@ func (hf HandlerFunc) Handle(conn *Conn, line *Line) { hf(conn, line) } -// Handlers are organised using a map of linked-lists, with each map -// key representing an IRC verb or numeric, and the linked list values -// being handlers that are executed in parallel when a Line from the +// Handlers are organised using a map of lockless singly linked lists, with +// each map key representing an IRC verb or numeric, and the linked list +// values being handlers that are executed in parallel when a Line from the // server with that verb or numeric arrives. type hSet struct { set map[string]*hList sync.RWMutex } -type hList struct { - start, end *hNode +func (hs *hSet) getList(ev string) (hl *hList, ok bool) { + ev = strings.ToLower(ev) + hs.RLock() + defer hs.RUnlock() + hl, ok = hs.set[ev] + return } -// Storing the forward and backward links in the node allows O(1) removal. -// This probably isn't strictly necessary but I think it's kinda nice. -type hNode struct { - next, prev *hNode - set *hSet - event string - handler Handler +func (hs *hSet) getOrMakeList(ev string) (hl *hList) { + ev = strings.ToLower(ev) + hs.Lock() + defer hs.Unlock() + hl, ok := hs.set[ev] + if !ok { + hl = makeHList() + hs.set[ev] = hl + } + return hl } +// Lists are lockless thanks to atomic pointers. (which hNodePtr wraps) +type hList struct { + first, last hNodePtr +} + +// In order for the whole thing to be goroutine-safe, each list must contain a +// zero-valued node at any given time as its last element. You'll see why +// later down. +func makeHList() (hl *hList) { + hl, hn0 := &hList{}, &hNode{} + hl.first.store(hn0) + hl.last.store(hn0) + return +} + +// (hNodeState is also an atomic wrapper.) +type hNode struct { + next hNodePtr + state hNodeState + handler Handler +} + +// Nodes progress through these three stages in order as the program runs. +const ( + unready hNodeState = iota + active + unlinkable +) + // A hNode implements both Handler (with configurable panic recovery)... func (hn *hNode) Handle(conn *Conn, line *Line) { defer conn.cfg.Recover(conn, line) hn.handler.Handle(conn, line) } -// ... and Remover. +// ... and Remover, which works by flagging the node so the goroutines running +// hSet.dispatch know to ignore its handler and to dispose of it. func (hn *hNode) Remove() { - hn.set.remove(hn) + hn.state.store(unlinkable) } func handlerSet() *hSet { return &hSet{set: make(map[string]*hList)} } -// When a new Handler is added for an event, it is wrapped in a hNode and -// returned as a Remover so the caller can remove it at a later time. +// When a new Handler is added for an event, it is assigned into a hNode, +// which is returned as a Remover so the caller can remove it at a later time. +// +// Concerning goroutine-safety, the point is that the atomic swap there +// reserves the previous last node for this handler and puts up a new one. +// The former node has the desirable property that the rest of the list points +// to it, and the latter inherits this property once the former becomes part +// of the list. It's also the case that handler should't be read by +// hSet.dispatch before the node is marked as ready via state. func (hs *hSet) add(ev string, h Handler) Remover { - hs.Lock() - defer hs.Unlock() - ev = strings.ToLower(ev) - l, ok := hs.set[ev] - if !ok { - l = &hList{} - } - hn := &hNode{ - set: hs, - event: ev, - handler: h, - } - if !ok { - l.start = hn - } else { - hn.prev = l.end - l.end.next = hn - } - l.end = hn - hs.set[ev] = l + hl := hs.getOrMakeList(ev) + hn0 := &hNode{} + hn := hl.last.swap(hn0) + hn.next.store(hn0) + hn.handler = h + hn.state.compareAndSwap(unready, active) return hn } -func (hs *hSet) remove(hn *hNode) { - hs.Lock() - defer hs.Unlock() - l, ok := hs.set[hn.event] - if !ok { - logging.Error("Removing node for unknown event '%s'", hn.event) - return - } - if hn.next == nil { - l.end = hn.prev - } else { - hn.next.prev = hn.prev - } - if hn.prev == nil { - l.start = hn.next - } else { - hn.prev.next = hn.next - } - hn.next = nil - hn.prev = nil - hn.set = nil - if l.start == nil || l.end == nil { - delete(hs.set, hn.event) - } -} - +// And finally, dispatch works like so: it goes through the whole list while +// remembering the adress of the pointer that led it to the current node, +// which allows it to unlink it if it must be. Since the pointers are atomic, +// if many goroutine enter the same unlinkable node at the same time, they +// will all end up writing the same value to the pointer anyway. Even in +// cases where consecutive nodes are flagged and unlinking node n revives node +// n+1 which had been unlinked by making node n point to n+2 without the +// unlinker of n+1 noticing, all dead nodes are unmistakable and will +// eventually be definitely unlinked and garbage-collected. Also note that +// the fact that the last node is always a zero node, as well as letting the +// list grow concurrently, allows the next-to-last node to be unlinked safely. func (hs *hSet) dispatch(conn *Conn, line *Line) { - hs.RLock() - defer hs.RUnlock() - ev := strings.ToLower(line.Cmd) - list, ok := hs.set[ev] + hl, ok := hs.getList(line.Cmd) if !ok { - return + return // nothing to do } wg := &sync.WaitGroup{} - for hn := list.start; hn != nil; hn = hn.next { - wg.Add(1) - go func(hn *hNode) { - hn.Handle(conn, line.Copy()) - wg.Done() - }(hn) + hn, hnptr := hl.first.load(), &hl.first + for hn != nil { + switch hn.state.load() { + case active: + wg.Add(1) + go func(hn *hNode) { + hn.Handle(conn, line.Copy()) + wg.Done() + }(hn) + fallthrough + case unready: + hnptr = &hn.next + hn = hnptr.load() + case unlinkable: + hn = hn.next.load() + hnptr.store(hn) + } } wg.Wait() } diff --git a/client/dispatch_test.go b/client/dispatch_test.go index b79df64..64c1f3c 100644 --- a/client/dispatch_test.go +++ b/client/dispatch_test.go @@ -24,66 +24,28 @@ func TestHandlerSet(t *testing.T) { // Add one hn1 := hs.add("ONE", HandlerFunc(f)).(*hNode) - hl, ok := hs.set["one"] + _, ok := hs.set["one"] if len(hs.set) != 1 || !ok { t.Errorf("Set doesn't contain 'one' list after add().") } - if hn1.set != hs || hn1.event != "one" || hn1.prev != nil || hn1.next != nil { - t.Errorf("First node for 'one' not created correctly") - } - if hl.start != hn1 || hl.end != hn1 { - t.Errorf("Node not added to empty 'one' list correctly.") - } // Add another one... hn2 := hs.add("one", HandlerFunc(f)).(*hNode) if len(hs.set) != 1 { t.Errorf("Set contains more than 'one' list after add().") } - if hn2.set != hs || hn2.event != "one" { - t.Errorf("Second node for 'one' not created correctly") - } - if hn1.prev != nil || hn1.next != hn2 || hn2.prev != hn1 || hn2.next != nil { - t.Errorf("Nodes for 'one' not linked correctly.") - } - if hl.start != hn1 || hl.end != hn2 { - t.Errorf("Node not appended to 'one' list correctly.") - } // Add a third one! hn3 := hs.add("one", HandlerFunc(f)).(*hNode) if len(hs.set) != 1 { t.Errorf("Set contains more than 'one' list after add().") } - if hn3.set != hs || hn3.event != "one" { - t.Errorf("Third node for 'one' not created correctly") - } - if hn1.prev != nil || hn1.next != hn2 || - hn2.prev != hn1 || hn2.next != hn3 || - hn3.prev != hn2 || hn3.next != nil { - t.Errorf("Nodes for 'one' not linked correctly.") - } - if hl.start != hn1 || hl.end != hn3 { - t.Errorf("Node not appended to 'one' list correctly.") - } // And finally a fourth one! hn4 := hs.add("one", HandlerFunc(f)).(*hNode) if len(hs.set) != 1 { t.Errorf("Set contains more than 'one' list after add().") } - if hn4.set != hs || hn4.event != "one" { - t.Errorf("Fourth node for 'one' not created correctly.") - } - if hn1.prev != nil || hn1.next != hn2 || - hn2.prev != hn1 || hn2.next != hn3 || - hn3.prev != hn2 || hn3.next != hn4 || - hn4.prev != hn3 || hn4.next != nil { - t.Errorf("Nodes for 'one' not linked correctly.") - } - if hl.start != hn1 || hl.end != hn4 { - t.Errorf("Node not appended to 'one' list correctly.") - } // Dispatch should result in 4 additions. if atomic.LoadInt32(callcount) != 0 { @@ -100,17 +62,6 @@ func TestHandlerSet(t *testing.T) { if len(hs.set) != 1 { t.Errorf("Set list count changed after remove().") } - if hn3.set != nil || hn3.prev != nil || hn3.next != nil { - t.Errorf("Third node for 'one' not removed correctly.") - } - if hn1.prev != nil || hn1.next != hn2 || - hn2.prev != hn1 || hn2.next != hn4 || - hn4.prev != hn2 || hn4.next != nil { - t.Errorf("Third node for 'one' not unlinked correctly.") - } - if hl.start != hn1 || hl.end != hn4 { - t.Errorf("Third node for 'one' changed list pointers.") - } // Dispatch should result in 3 additions. hs.dispatch(c, &Line{Cmd: "One"}) @@ -120,19 +71,10 @@ func TestHandlerSet(t *testing.T) { } // Remove node 1. - hs.remove(hn1) + hn1.Remove() if len(hs.set) != 1 { t.Errorf("Set list count changed after remove().") } - if hn1.set != nil || hn1.prev != nil || hn1.next != nil { - t.Errorf("First node for 'one' not removed correctly.") - } - if hn2.prev != nil || hn2.next != hn4 || hn4.prev != hn2 || hn4.next != nil { - t.Errorf("First node for 'one' not unlinked correctly.") - } - if hl.start != hn2 || hl.end != hn4 { - t.Errorf("First node for 'one' didn't change list pointers.") - } // Dispatch should result in 2 additions. hs.dispatch(c, &Line{Cmd: "One"}) @@ -146,15 +88,6 @@ func TestHandlerSet(t *testing.T) { if len(hs.set) != 1 { t.Errorf("Set list count changed after remove().") } - if hn4.set != nil || hn4.prev != nil || hn4.next != nil { - t.Errorf("Fourth node for 'one' not removed correctly.") - } - if hn2.prev != nil || hn2.next != nil { - t.Errorf("Fourth node for 'one' not unlinked correctly.") - } - if hl.start != hn2 || hl.end != hn2 { - t.Errorf("Fourth node for 'one' didn't change list pointers.") - } // Dispatch should result in 1 addition. hs.dispatch(c, &Line{Cmd: "One"}) @@ -164,16 +97,7 @@ func TestHandlerSet(t *testing.T) { } // Remove node 2. - hs.remove(hn2) - if len(hs.set) != 0 { - t.Errorf("Removing last node in 'one' didn't remove list.") - } - if hn2.set != nil || hn2.prev != nil || hn2.next != nil { - t.Errorf("Second node for 'one' not removed correctly.") - } - if hl.start != nil || hl.end != nil { - t.Errorf("Second node for 'one' didn't change list pointers.") - } + hn2.Remove() // Dispatch should result in NO additions. hs.dispatch(c, &Line{Cmd: "One"}) diff --git a/client/dispatch_unsafe.go b/client/dispatch_unsafe.go new file mode 100644 index 0000000..2074d23 --- /dev/null +++ b/client/dispatch_unsafe.go @@ -0,0 +1,38 @@ +package client + +import ( + "sync/atomic" + "unsafe" +) + +type hNodePtr struct { + ptr unsafe.Pointer +} + +func (p *hNodePtr) load() *hNode { + return (*hNode)(atomic.LoadPointer(&p.ptr)) +} +func (p *hNodePtr) store(new *hNode) { + atomic.StorePointer(&p.ptr, unsafe.Pointer(new)) +} +func (p *hNodePtr) swap(new *hNode) (old *hNode) { + return (*hNode)(atomic.SwapPointer(&p.ptr, unsafe.Pointer(new))) +} +func (p *hNodePtr) compareAndSwap(old, new *hNode) (swapped bool) { + return atomic.CompareAndSwapPointer(&p.ptr, unsafe.Pointer(old), unsafe.Pointer(new)) +} + +type hNodeState uintptr + +func (s *hNodeState) load() hNodeState { + return hNodeState(atomic.LoadUintptr((*uintptr)(s))) +} +func (s *hNodeState) store(new hNodeState) { + atomic.StoreUintptr((*uintptr)(s), uintptr(new)) +} +func (s *hNodeState) swap(new hNodeState) (old hNodeState) { + return hNodeState(atomic.SwapUintptr((*uintptr)(s), uintptr(new))) +} +func (s *hNodeState) compareAndSwap(old, new hNodeState) (swapped bool) { + return atomic.CompareAndSwapUintptr((*uintptr)(s), uintptr(old), uintptr(new)) +}