diff --git a/client/commands.go b/client/commands.go index dac7059..f5d4c58 100644 --- a/client/commands.go +++ b/client/commands.go @@ -10,6 +10,7 @@ const ( CONNECTED = "CONNECTED" DISCONNECTED = "DISCONNECTED" ACTION = "ACTION" + AUTHENTICATE = "AUTHENTICATE" AWAY = "AWAY" CAP = "CAP" CTCP = "CTCP" @@ -322,3 +323,8 @@ func (conn *Conn) Cap(subcommmand string, capabilities ...string) { } } } + +// Authenticate send an AUTHENTICATE command to the server. +func (conn *Conn) Authenticate(message string) { + conn.Raw(AUTHENTICATE + " " + message) +} diff --git a/client/connection.go b/client/connection.go index 8ed5a84..722fb12 100644 --- a/client/connection.go +++ b/client/connection.go @@ -101,6 +101,9 @@ type Config struct { // A list of capabilities to request to the server during registration. Capabilites []string + // SASL configuration to use to authenticate the connection. + Sasl *SaslAuthenticator + // Replaceable function to customise the 433 handler's new nick. // By default an underscore "_" is appended to the current nick. NewNick func(string) string @@ -216,6 +219,11 @@ func Client(cfg *Config) *Conn { } } + if cfg.Sasl != nil && !cfg.EnableCapabilityNegotiation { + logging.Warn("Enabling capability negotiation as it's required for SASL") + cfg.EnableCapabilityNegotiation = true + } + conn := &Conn{ cfg: cfg, dialer: dialer, diff --git a/client/handlers.go b/client/handlers.go index d2317df..b253acb 100644 --- a/client/handlers.go +++ b/client/handlers.go @@ -14,14 +14,17 @@ import ( // sets up the internal event handlers to do essential IRC protocol things var intHandlers = map[string]HandlerFunc{ - REGISTER: (*Conn).h_REGISTER, - "001": (*Conn).h_001, - "433": (*Conn).h_433, - CTCP: (*Conn).h_CTCP, - NICK: (*Conn).h_NICK, - PING: (*Conn).h_PING, - CAP: (*Conn).h_CAP, - "410": (*Conn).h_410, + REGISTER: (*Conn).h_REGISTER, + "001": (*Conn).h_001, + "433": (*Conn).h_433, + CTCP: (*Conn).h_CTCP, + NICK: (*Conn).h_NICK, + PING: (*Conn).h_PING, + CAP: (*Conn).h_CAP, + "410": (*Conn).h_410, + AUTHENTICATE: (*Conn).h_AUTHENTICATE, + "903": (*Conn).h_903, + "904": (*Conn).h_904, } // set up the ircv3 capabilities supported by this client which will be requested by default to the server. @@ -59,6 +62,11 @@ func (conn *Conn) getRequestCapabilities() *capSet { // add capabilites supported by the client s.Add(defaultCaps...) + if conn.cfg.Sasl != nil { + // add the SASL cap if enabled + s.Add(saslCap) + } + // add capabilites requested by the user s.Add(conn.cfg.Capabilites...) @@ -79,10 +87,19 @@ func (conn *Conn) negotiateCapabilities(supportedCaps []string) { } func (conn *Conn) handleCapAck(caps []string) { + gotSasl := false for _, cap := range caps { conn.currCaps.Add(cap) + + if conn.cfg.Sasl != nil && cap == saslCap { + gotSasl = true + conn.Authenticate(string(conn.cfg.Sasl.mechanism)) + } + } + + if !gotSasl { + conn.Cap(CAP_END) } - conn.Cap(CAP_END) } func (conn *Conn) handleCapNak(caps []string) { @@ -181,6 +198,32 @@ func (conn *Conn) h_CAP(line *Line) { } } +// Handler for SASL authentication +func (conn *Conn) h_AUTHENTICATE(line *Line) { + if conn.cfg.Sasl == nil { + return + } + + if line.Args[0] != "+" { + return + } + + // start authentication + conn.Authenticate(conn.cfg.Sasl.authenticationRequest()) +} + +// Handler for RPL_SASLSUCCESS. +func (conn *Conn) h_903(line *Line) { + conn.Cap(CAP_END) +} + +// Handler for RPL_SASLFAILURE. +func (conn *Conn) h_904(line *Line) { + // TODO: do something about this? + logging.Warn("SASL authentication failed") + conn.Cap(CAP_END) +} + // Handler to trigger a CONNECTED event on receipt of numeric 001 // : 001 :Welcome message !@ func (conn *Conn) h_001(line *Line) { diff --git a/client/sasl.go b/client/sasl.go new file mode 100644 index 0000000..3bd7d9d --- /dev/null +++ b/client/sasl.go @@ -0,0 +1,43 @@ +package client + +import ( + "encoding/base64" +) + +// saslMechanism is the name of the SASL authentication mechanism used. +type saslMechanism string + +const ( + // saslPlain is the username and password based PLAIN + // authentication mechanism. + saslPlain saslMechanism = "PLAIN" +) + +// saslCap is the IRCv3 capability used for SASL authentication. +const saslCap = "sasl" + +// SaslAuthenticator authenticates the connection using SASL in the +// connection phase. +type SaslAuthenticator struct { + mechanism saslMechanism + authenticationRequest func() string +} + +func encodePlainUsernamePassword(username, password string) string { + requestBytes := []byte(username) + requestBytes = append(requestBytes, byte(0)) + requestBytes = append(requestBytes, []byte(username)...) + requestBytes = append(requestBytes, byte(0)) + requestBytes = append(requestBytes, []byte(password)...) + + return base64.StdEncoding.EncodeToString(requestBytes) +} + +func SaslPlain(username, password string) *SaslAuthenticator { + return &SaslAuthenticator{ + mechanism: saslPlain, + authenticationRequest: func() string { + return encodePlainUsernamePassword(username, password) + }, + } +} diff --git a/client/sasl_test.go b/client/sasl_test.go new file mode 100644 index 0000000..415c3bf --- /dev/null +++ b/client/sasl_test.go @@ -0,0 +1,26 @@ +package client + +import ( + "testing" +) + +func TestSaslPlainWorkflow(t *testing.T) { + c, s := setUp(t) + defer s.tearDown() + + c.Config().Sasl = SaslPlain("test", "password") + c.Config().EnableCapabilityNegotiation = true + + c.h_REGISTER(&Line{Cmd: REGISTER}) + s.nc.Expect("CAP LS") + s.nc.Expect("NICK test") + s.nc.Expect("USER test 12 * :Testing IRC") + s.nc.Send("CAP * LS :sasl foobar") + s.nc.Expect("CAP REQ :sasl") + s.nc.Send("CAP * ACK :sasl") + s.nc.Expect("AUTHENTICATE PLAIN") + s.nc.Send("AUTHENTICATE +") + s.nc.Expect("AUTHENTICATE dGVzdAB0ZXN0AHBhc3N3b3Jk") + s.nc.Send("904 test :SASL authentication successful") + s.nc.Expect("CAP END") +}