From 92518943ff8e6f32bacce61a2c008fc0feb69c32 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Taavi=20V=C3=A4=C3=A4n=C3=A4nen?= <hi@taavi.wtf>
Date: Sun, 30 Oct 2022 12:01:33 +0200
Subject: [PATCH] Add SASL authentication support

This hacks together support for IRCv3.1 SASL. Currently only SASL PLAIN
is supported, but it's implemented in a way that adding support for
other types should not require too many changes to the current code.
---
 client/commands.go   |   6 +++
 client/connection.go |  38 ++++++++++------
 client/handlers.go   | 106 +++++++++++++++++++++++++++++++++++++++----
 client/sasl_test.go  | 103 +++++++++++++++++++++++++++++++++++++++++
 go.mod               |   1 +
 go.sum               |   3 +-
 6 files changed, 234 insertions(+), 23 deletions(-)
 create mode 100644 client/sasl_test.go

diff --git a/client/commands.go b/client/commands.go
index dac7059..f5d4c58 100644
--- a/client/commands.go
+++ b/client/commands.go
@@ -10,6 +10,7 @@ const (
 	CONNECTED    = "CONNECTED"
 	DISCONNECTED = "DISCONNECTED"
 	ACTION       = "ACTION"
+	AUTHENTICATE = "AUTHENTICATE"
 	AWAY         = "AWAY"
 	CAP          = "CAP"
 	CTCP         = "CTCP"
@@ -322,3 +323,8 @@ func (conn *Conn) Cap(subcommmand string, capabilities ...string) {
 		}
 	}
 }
+
+// Authenticate send an AUTHENTICATE command to the server.
+func (conn *Conn) Authenticate(message string) {
+	conn.Raw(AUTHENTICATE + " " + message)
+}
diff --git a/client/connection.go b/client/connection.go
index 8ed5a84..aa9f2f5 100644
--- a/client/connection.go
+++ b/client/connection.go
@@ -13,6 +13,7 @@ import (
 	"sync"
 	"time"
 
+	"github.com/emersion/go-sasl"
 	"github.com/fluffle/goirc/logging"
 	"github.com/fluffle/goirc/state"
 	"golang.org/x/net/proxy"
@@ -51,6 +52,9 @@ type Conn struct {
 	// Capabilites currently enabled
 	currCaps *capSet
 
+	// SASL internals
+	saslRemainingData []byte
+
 	// CancelFunc and WaitGroup for goroutines
 	die context.CancelFunc
 	wg  sync.WaitGroup
@@ -101,6 +105,9 @@ type Config struct {
 	// A list of capabilities to request to the server during registration.
 	Capabilites []string
 
+	// SASL configuration to use to authenticate the connection.
+	Sasl sasl.Client
+
 	// Replaceable function to customise the 433 handler's new nick.
 	// By default an underscore "_" is appended to the current nick.
 	NewNick func(string) string
@@ -216,16 +223,22 @@ func Client(cfg *Config) *Conn {
 		}
 	}
 
+	if cfg.Sasl != nil && !cfg.EnableCapabilityNegotiation {
+		logging.Warn("Enabling capability negotiation as it's required for SASL")
+		cfg.EnableCapabilityNegotiation = true
+	}
+
 	conn := &Conn{
-		cfg:           cfg,
-		dialer:        dialer,
-		intHandlers:   handlerSet(),
-		fgHandlers:    handlerSet(),
-		bgHandlers:    handlerSet(),
-		stRemovers:    make([]Remover, 0, len(stHandlers)),
-		lastsent:      time.Now(),
-		supportedCaps: capabilitySet(),
-		currCaps:      capabilitySet(),
+		cfg:               cfg,
+		dialer:            dialer,
+		intHandlers:       handlerSet(),
+		fgHandlers:        handlerSet(),
+		bgHandlers:        handlerSet(),
+		stRemovers:        make([]Remover, 0, len(stHandlers)),
+		lastsent:          time.Now(),
+		supportedCaps:     capabilitySet(),
+		currCaps:          capabilitySet(),
+		saslRemainingData: nil,
 	}
 	conn.addIntHandlers()
 	return conn
@@ -245,10 +258,9 @@ func (conn *Conn) Connected() bool {
 // affect client behaviour. To disable flood protection temporarily,
 // for example, a handler could do:
 //
-//     conn.Config().Flood = true
-//     // Send many lines to the IRC server, risking "excess flood"
-//     conn.Config().Flood = false
-//
+//	conn.Config().Flood = true
+//	// Send many lines to the IRC server, risking "excess flood"
+//	conn.Config().Flood = false
 func (conn *Conn) Config() *Config {
 	return conn.cfg
 }
diff --git a/client/handlers.go b/client/handlers.go
index d2317df..3920c02 100644
--- a/client/handlers.go
+++ b/client/handlers.go
@@ -9,19 +9,27 @@ import (
 	"sync"
 	"time"
 
+	"encoding/base64"
 	"github.com/fluffle/goirc/logging"
 )
 
+// saslCap is the IRCv3 capability used for SASL authentication.
+const saslCap = "sasl"
+
 // sets up the internal event handlers to do essential IRC protocol things
 var intHandlers = map[string]HandlerFunc{
-	REGISTER: (*Conn).h_REGISTER,
-	"001":    (*Conn).h_001,
-	"433":    (*Conn).h_433,
-	CTCP:     (*Conn).h_CTCP,
-	NICK:     (*Conn).h_NICK,
-	PING:     (*Conn).h_PING,
-	CAP:      (*Conn).h_CAP,
-	"410":    (*Conn).h_410,
+	REGISTER:     (*Conn).h_REGISTER,
+	"001":        (*Conn).h_001,
+	"433":        (*Conn).h_433,
+	CTCP:         (*Conn).h_CTCP,
+	NICK:         (*Conn).h_NICK,
+	PING:         (*Conn).h_PING,
+	CAP:          (*Conn).h_CAP,
+	"410":        (*Conn).h_410,
+	AUTHENTICATE: (*Conn).h_AUTHENTICATE,
+	"903":        (*Conn).h_903,
+	"904":        (*Conn).h_904,
+	"908":        (*Conn).h_908,
 }
 
 // set up the ircv3 capabilities supported by this client which will be requested by default to the server.
@@ -59,6 +67,11 @@ func (conn *Conn) getRequestCapabilities() *capSet {
 	// add capabilites supported by the client
 	s.Add(defaultCaps...)
 
+	if conn.cfg.Sasl != nil {
+		// add the SASL cap if enabled
+		s.Add(saslCap)
+	}
+
 	// add capabilites requested by the user
 	s.Add(conn.cfg.Capabilites...)
 
@@ -79,10 +92,31 @@ func (conn *Conn) negotiateCapabilities(supportedCaps []string) {
 }
 
 func (conn *Conn) handleCapAck(caps []string) {
+	gotSasl := false
 	for _, cap := range caps {
 		conn.currCaps.Add(cap)
+
+		if conn.cfg.Sasl != nil && cap == saslCap {
+			mech, ir, err := conn.cfg.Sasl.Start()
+
+			if err != nil {
+				logging.Warn("SASL authentication failed: %v", err)
+				continue
+			}
+
+			// TODO: when IRC 3.2 capability negotiation is supported, ensure the
+			// capability value is used to match the chosen mechanism
+
+			gotSasl = true
+			conn.saslRemainingData = ir
+
+			conn.Authenticate(mech)
+		}
+	}
+
+	if !gotSasl {
+		conn.Cap(CAP_END)
 	}
-	conn.Cap(CAP_END)
 }
 
 func (conn *Conn) handleCapNak(caps []string) {
@@ -181,6 +215,60 @@ func (conn *Conn) h_CAP(line *Line) {
 	}
 }
 
+// Handler for SASL authentication
+func (conn *Conn) h_AUTHENTICATE(line *Line) {
+	if conn.cfg.Sasl == nil {
+		return
+	}
+
+	if conn.saslRemainingData != nil {
+		data := "+" // plus sign representing empty data
+		if len(conn.saslRemainingData) > 0 {
+			data = base64.StdEncoding.EncodeToString(conn.saslRemainingData)
+		}
+
+		// TODO: batch data into chunks of 400 bytes per the spec
+
+		conn.Authenticate(data)
+		conn.saslRemainingData = nil
+		return
+	}
+
+	// TODO: handle data over 400 bytes long (which will be chunked into multiple messages per the spec)
+	challenge, err := base64.StdEncoding.DecodeString(line.Args[0])
+	if err != nil {
+		logging.Error("Failed to decode SASL challenge: %v", err)
+		return
+	}
+
+	response, err := conn.cfg.Sasl.Next(challenge)
+	if err != nil {
+		logging.Error("Failed to generate response for SASL challenge: %v", err)
+		return
+	}
+
+	// TODO: batch data into chunks of 400 bytes per the spec
+	data := base64.StdEncoding.EncodeToString(response)
+	conn.Authenticate(data)
+}
+
+// Handler for RPL_SASLSUCCESS.
+func (conn *Conn) h_903(line *Line) {
+	conn.Cap(CAP_END)
+}
+
+// Handler for RPL_SASLFAILURE.
+func (conn *Conn) h_904(line *Line) {
+	logging.Warn("SASL authentication failed")
+	conn.Cap(CAP_END)
+}
+
+// Handler for RPL_SASLMECHS.
+func (conn *Conn) h_908(line *Line) {
+	logging.Warn("SASL mechanism not supported, supported mechanisms are: %v", line.Args[1])
+	conn.Cap(CAP_END)
+}
+
 // Handler to trigger a CONNECTED event on receipt of numeric 001
 // :<server> 001 <nick> :Welcome message <nick>!<user>@<host>
 func (conn *Conn) h_001(line *Line) {
diff --git a/client/sasl_test.go b/client/sasl_test.go
new file mode 100644
index 0000000..073a074
--- /dev/null
+++ b/client/sasl_test.go
@@ -0,0 +1,103 @@
+package client
+
+import (
+	"github.com/emersion/go-sasl"
+	"testing"
+)
+
+func TestSaslPlainSuccessWorkflow(t *testing.T) {
+	c, s := setUp(t)
+	defer s.tearDown()
+
+	c.Config().Sasl = sasl.NewPlainClient("", "example", "password")
+	c.Config().EnableCapabilityNegotiation = true
+
+	c.h_REGISTER(&Line{Cmd: REGISTER})
+	s.nc.Expect("CAP LS")
+	s.nc.Expect("NICK test")
+	s.nc.Expect("USER test 12 * :Testing IRC")
+	s.nc.Send("CAP * LS :sasl foobar")
+	s.nc.Expect("CAP REQ :sasl")
+	s.nc.Send("CAP * ACK :sasl")
+	s.nc.Expect("AUTHENTICATE PLAIN")
+	s.nc.Send("AUTHENTICATE +")
+	s.nc.Expect("AUTHENTICATE AGV4YW1wbGUAcGFzc3dvcmQ=")
+	s.nc.Send("904 test :SASL authentication successful")
+	s.nc.Expect("CAP END")
+}
+
+func TestSaslPlainWrongPassword(t *testing.T) {
+	c, s := setUp(t)
+	defer s.tearDown()
+
+	c.Config().Sasl = sasl.NewPlainClient("", "example", "password")
+	c.Config().EnableCapabilityNegotiation = true
+
+	c.h_REGISTER(&Line{Cmd: REGISTER})
+	s.nc.Expect("CAP LS")
+	s.nc.Expect("NICK test")
+	s.nc.Expect("USER test 12 * :Testing IRC")
+	s.nc.Send("CAP * LS :sasl foobar")
+	s.nc.Expect("CAP REQ :sasl")
+	s.nc.Send("CAP * ACK :sasl")
+	s.nc.Expect("AUTHENTICATE PLAIN")
+	s.nc.Send("AUTHENTICATE +")
+	s.nc.Expect("AUTHENTICATE AGV4YW1wbGUAcGFzc3dvcmQ=")
+	s.nc.Send("904 test :SASL authentication failed")
+	s.nc.Expect("CAP END")
+}
+
+func TestSaslExternalSuccessWorkflow(t *testing.T) {
+	c, s := setUp(t)
+	defer s.tearDown()
+
+	c.Config().Sasl = sasl.NewExternalClient("")
+	c.Config().EnableCapabilityNegotiation = true
+
+	c.h_REGISTER(&Line{Cmd: REGISTER})
+	s.nc.Expect("CAP LS")
+	s.nc.Expect("NICK test")
+	s.nc.Expect("USER test 12 * :Testing IRC")
+	s.nc.Send("CAP * LS :sasl foobar")
+	s.nc.Expect("CAP REQ :sasl")
+	s.nc.Send("CAP * ACK :sasl")
+	s.nc.Expect("AUTHENTICATE EXTERNAL")
+	s.nc.Send("AUTHENTICATE +")
+	s.nc.Expect("AUTHENTICATE +")
+	s.nc.Send("904 test :SASL authentication successful")
+	s.nc.Expect("CAP END")
+}
+
+func TestSaslNoSaslCap(t *testing.T) {
+	c, s := setUp(t)
+	defer s.tearDown()
+
+	c.Config().Sasl = sasl.NewPlainClient("", "example", "password")
+	c.Config().EnableCapabilityNegotiation = true
+
+	c.h_REGISTER(&Line{Cmd: REGISTER})
+	s.nc.Expect("CAP LS")
+	s.nc.Expect("NICK test")
+	s.nc.Expect("USER test 12 * :Testing IRC")
+	s.nc.Send("CAP * LS :foobar")
+	s.nc.Expect("CAP END")
+}
+
+func TestSaslUnsupportedMechanism(t *testing.T) {
+	c, s := setUp(t)
+	defer s.tearDown()
+
+	c.Config().Sasl = sasl.NewPlainClient("", "example", "password")
+	c.Config().EnableCapabilityNegotiation = true
+
+	c.h_REGISTER(&Line{Cmd: REGISTER})
+	s.nc.Expect("CAP LS")
+	s.nc.Expect("NICK test")
+	s.nc.Expect("USER test 12 * :Testing IRC")
+	s.nc.Send("CAP * LS :sasl foobar")
+	s.nc.Expect("CAP REQ :sasl")
+	s.nc.Send("CAP * ACK :sasl")
+	s.nc.Expect("AUTHENTICATE PLAIN")
+	s.nc.Send("908 test external :are available SASL mechanisms")
+	s.nc.Expect("CAP END")
+}
diff --git a/go.mod b/go.mod
index c9d6b80..5c388bb 100644
--- a/go.mod
+++ b/go.mod
@@ -1,6 +1,7 @@
 module github.com/fluffle/goirc
 
 require (
+	github.com/emersion/go-sasl v0.0.0-20220912192320-0145f2c60ead
 	github.com/golang/mock v1.5.0
 	golang.org/x/net v0.0.0-20210119194325-5f4716e94777
 )
diff --git a/go.sum b/go.sum
index 4a29870..d8b7012 100644
--- a/go.sum
+++ b/go.sum
@@ -1,9 +1,10 @@
+github.com/emersion/go-sasl v0.0.0-20220912192320-0145f2c60ead h1:fI1Jck0vUrXT8bnphprS1EoVRe2Q5CKCX8iDlpqjQ/Y=
+github.com/emersion/go-sasl v0.0.0-20220912192320-0145f2c60ead/go.mod h1:iL2twTeMvZnrg54ZoPDNfJaJaqy0xIQFuBdrLsmspwQ=
 github.com/golang/mock v1.5.0 h1:jlYHihg//f7RRwuPfptm04yp4s7O6Kw8EZiVYIGcH0g=
 github.com/golang/mock v1.5.0/go.mod h1:CWnOUgYIOo4TcNZ0wHX3YZCqsaM1I1Jvs6v3mP3KVu8=
 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
 golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
-golang.org/x/net v0.0.0-20180926154720-4dfa2610cdf3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
 golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
 golang.org/x/net v0.0.0-20210119194325-5f4716e94777 h1:003p0dJM77cxMSyCPFphvZf/Y5/NXf5fzg6ufd1/Oew=