Refactor out proxy dialing code; store contextless proxy.Dialer.

This commit is contained in:
Alex Bramley 2023-11-23 20:44:32 +00:00
parent d655f8950c
commit b1a6e3a286
1 changed files with 28 additions and 28 deletions

View File

@ -4,7 +4,6 @@ import (
"bufio" "bufio"
"context" "context"
"crypto/tls" "crypto/tls"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -13,7 +12,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/emersion/go-sasl" sasl "github.com/emersion/go-sasl"
"github.com/fluffle/goirc/logging" "github.com/fluffle/goirc/logging"
"github.com/fluffle/goirc/state" "github.com/fluffle/goirc/state"
"golang.org/x/net/proxy" "golang.org/x/net/proxy"
@ -39,7 +38,7 @@ type Conn struct {
// I/O stuff to server // I/O stuff to server
dialer *net.Dialer dialer *net.Dialer
proxyDialer proxy.ContextDialer proxyDialer proxy.Dialer
sock net.Conn sock net.Conn
io *bufio.ReadWriter io *bufio.ReadWriter
in chan *Line in chan *Line
@ -404,34 +403,13 @@ func (conn *Conn) internalConnect(ctx context.Context) error {
} }
if conn.cfg.Proxy != "" { if conn.cfg.Proxy != "" {
proxyURL, err := url.Parse(conn.cfg.Proxy) s, err := conn.dialProxy(ctx)
if err != nil { if err != nil {
logging.Info("irc.Connect(): Connecting via proxy %q: %v",
conn.cfg.Proxy, err)
return err return err
} }
proxyDialer, err := proxy.FromURL(proxyURL, conn.dialer) conn.sock = s
if err != nil {
return err
}
contextProxyDialer, ok := proxyDialer.(proxy.ContextDialer)
if ok {
conn.proxyDialer = contextProxyDialer
logging.Info("irc.Connect(): Connecting to %s.", conn.cfg.Server)
if s, err := conn.proxyDialer.DialContext(ctx, "tcp", conn.cfg.Server); err == nil {
conn.sock = s
} else {
return err
}
} else {
logging.Warn("Dialer for proxy does not support context, please implement DialContext")
conn.proxyDialer = proxyDialer
logging.Info("irc.Connect(): Connecting to %s.", conn.cfg.Server)
if s, err := conn.proxyDialer.Dial("tcp", conn.cfg.Server); err == nil {
conn.sock = s
} else {
return err
}
}
} else { } else {
logging.Info("irc.Connect(): Connecting to %s.", conn.cfg.Server) logging.Info("irc.Connect(): Connecting to %s.", conn.cfg.Server)
if s, err := conn.dialer.DialContext(ctx, "tcp", conn.cfg.Server); err == nil { if s, err := conn.dialer.DialContext(ctx, "tcp", conn.cfg.Server); err == nil {
@ -455,6 +433,28 @@ func (conn *Conn) internalConnect(ctx context.Context) error {
return nil return nil
} }
// dialProxy handles dialling via a proxy
func (conn *Conn) dialProxy(ctx context.Context) (net.Conn, error) {
proxyURL, err := url.Parse(conn.cfg.Proxy)
if err != nil {
return nil, fmt.Errorf("parsing url: %v", err)
}
proxyDialer, err := proxy.FromURL(proxyURL, conn.dialer)
if err != nil {
return nil, fmt.Errorf("creating dialer: %v", err)
}
conn.proxyDialer = proxyDialer
contextProxyDialer, ok := proxyDialer.(proxy.ContextDialer)
if ok {
logging.Info("irc.Connect(): Connecting to %s.", conn.cfg.Server)
return contextProxyDialer.DialContext(ctx, "tcp", conn.cfg.Server)
} else {
logging.Warn("Dialer for proxy does not support context, please implement DialContext")
logging.Info("irc.Connect(): Connecting to %s.", conn.cfg.Server)
return conn.proxyDialer.Dial("tcp", conn.cfg.Server)
}
}
// postConnect performs post-connection setup, for ease of testing. // postConnect performs post-connection setup, for ease of testing.
func (conn *Conn) postConnect(ctx context.Context, start bool) { func (conn *Conn) postConnect(ctx context.Context, start bool) {
conn.io = bufio.NewReadWriter( conn.io = bufio.NewReadWriter(