diff --git a/auth.go b/auth.go index 3494824..2f194c0 100644 --- a/auth.go +++ b/auth.go @@ -19,42 +19,42 @@ var ( NoSupportedAuth = fmt.Errorf("No supported authentication mechanism") ) -// authenticate is used to handle connection authentication -func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) error { - // Get the methods - methods, err := readMethods(bufConn) - if err != nil { - return fmt.Errorf("Failed to get auth methods: %v", err) - } - - // Determine what is supported - supportUserPass := s.config.Credentials != nil - - // Select a usable method - for _, method := range methods { - if method == noAuth && !supportUserPass { - return noAuthMode(conn) - } - if method == userPassAuth && supportUserPass { - return s.userPassAuth(conn, bufConn) - } - } - - // No usable method found - return noAcceptableAuth(conn) +type Authenticator interface { + Authenticate(reader io.Reader, writer io.Writer) error + GetCode() uint8 } -// userPassAuth is used to handle username/password based +// NoAuthAuthenticator is used to handle the "No Authentication" mode +type NoAuthAuthenticator struct {} + +func (a NoAuthAuthenticator) GetCode() uint8 { + return noAuth +} + +func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer) error { + _, err := writer.Write([]byte{socks5Version, noAuth}) + return err +} + +// UserPassAuthenticator is used to handle username/password based // authentication -func (s *Server) userPassAuth(conn io.Writer, bufConn io.Reader) error { +type UserPassAuthenticator struct { + Credentials CredentialStore +} + +func (a UserPassAuthenticator) GetCode() uint8 { + return userPassAuth +} + +func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer) error { // Tell the client to use user/pass auth - if _, err := conn.Write([]byte{socks5Version, userPassAuth}); err != nil { + if _, err := writer.Write([]byte{socks5Version, userPassAuth}); err != nil { return err } // Get the version and username length header := []byte{0, 0} - if _, err := io.ReadAtLeast(bufConn, header, 2); err != nil { + if _, err := io.ReadAtLeast(reader, header, 2); err != nil { return err } @@ -66,29 +66,29 @@ func (s *Server) userPassAuth(conn io.Writer, bufConn io.Reader) error { // Get the user name userLen := int(header[1]) user := make([]byte, userLen) - if _, err := io.ReadAtLeast(bufConn, user, userLen); err != nil { + if _, err := io.ReadAtLeast(reader, user, userLen); err != nil { return err } // Get the password length - if _, err := bufConn.Read(header[:1]); err != nil { + if _, err := reader.Read(header[:1]); err != nil { return err } // Get the password passLen := int(header[0]) pass := make([]byte, passLen) - if _, err := io.ReadAtLeast(bufConn, pass, passLen); err != nil { + if _, err := io.ReadAtLeast(reader, pass, passLen); err != nil { return err } // Verify the password - if s.config.Credentials.Valid(string(user), string(pass)) { - if _, err := conn.Write([]byte{userAuthVersion, authSuccess}); err != nil { + if a.Credentials.Valid(string(user), string(pass)) { + if _, err := writer.Write([]byte{userAuthVersion, authSuccess}); err != nil { return err } } else { - if _, err := conn.Write([]byte{userAuthVersion, authFailure}); err != nil { + if _, err := writer.Write([]byte{userAuthVersion, authFailure}); err != nil { return err } return UserAuthFailed @@ -96,14 +96,33 @@ func (s *Server) userPassAuth(conn io.Writer, bufConn io.Reader) error { // Done return nil + } -// noAuth is used to handle the "No Authentication" mode -func noAuthMode(conn io.Writer) error { - _, err := conn.Write([]byte{socks5Version, noAuth}) - return err + + +// authenticate is used to handle connection authentication +func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) error { + // Get the methods + methods, err := readMethods(bufConn) + if err != nil { + return fmt.Errorf("Failed to get auth methods: %v", err) + } + + // Select a usable method + for _, method := range methods { + cator, found := s.authMethods[method] + if found { + return cator.Authenticate(bufConn, conn) + } + } + + // No usable method found + return noAcceptableAuth(conn) } + + // noAcceptableAuth is used to handle when we have no eligible // authentication mechanism func noAcceptableAuth(conn io.Writer) error { diff --git a/auth_test.go b/auth_test.go index f45697b..f0ea2a5 100644 --- a/auth_test.go +++ b/auth_test.go @@ -10,7 +10,7 @@ func TestNoAuth(t *testing.T) { req.Write([]byte{1, noAuth}) var resp bytes.Buffer - s := &Server{config: &Config{}} + s, _ := New(&Config{}) if err := s.authenticate(&resp, req); err != nil { t.Fatalf("err: %v", err) } @@ -30,7 +30,11 @@ func TestPasswordAuth_Valid(t *testing.T) { cred := StaticCredentials{ "foo": "bar", } - s := &Server{config: &Config{Credentials: cred}} + + cator := UserPassAuthenticator{Credentials: cred} + + s, _ := New(&Config{AuthMethods:[]Authenticator{cator}}) + if err := s.authenticate(&resp, req); err != nil { t.Fatalf("err: %v", err) } @@ -50,7 +54,8 @@ func TestPasswordAuth_Invalid(t *testing.T) { cred := StaticCredentials{ "foo": "bar", } - s := &Server{config: &Config{Credentials: cred}} + cator := UserPassAuthenticator{Credentials: cred} + s, _ := New(&Config{AuthMethods:[]Authenticator{cator}}) if err := s.authenticate(&resp, req); err != UserAuthFailed { t.Fatalf("err: %v", err) } @@ -69,7 +74,9 @@ func TestNoSupportedAuth(t *testing.T) { cred := StaticCredentials{ "foo": "bar", } - s := &Server{config: &Config{Credentials: cred}} + cator := UserPassAuthenticator{Credentials: cred} + + s, _ := New(&Config{AuthMethods:[]Authenticator{cator}}) if err := s.authenticate(&resp, req); err != NoSupportedAuth { t.Fatalf("err: %v", err) } diff --git a/socks5.go b/socks5.go index 0df5a90..a1334b8 100644 --- a/socks5.go +++ b/socks5.go @@ -13,9 +13,9 @@ const ( // Config is used to setup and configure a Server type Config struct { - // If provided, username/password authentication is enabled - // otherwise, non-authenticated mode is allowed - Credentials CredentialStore + // AuthMethods can be provided to implement custom authentication + // By default, "auth-less" mode is enabled. For password-based auth use UserPassAuthenticator. + AuthMethods []Authenticator // Resolver can be provided to do custom name resolution. // Defaults to DNSResolver if not provided. @@ -38,10 +38,16 @@ type Config struct { // the details of the SOCKS5 protocol type Server struct { config *Config + authMethods map[uint8]Authenticator } // New creates a new Server and potentially returns an error func New(conf *Config) (*Server, error) { + // Ensure we have at least one authentication method enabled + if conf.AuthMethods == nil || len(conf.AuthMethods) == 0 { + conf.AuthMethods = []Authenticator{&NoAuthAuthenticator{}} + } + // Ensure we have a DNS resolver if conf.Resolver == nil { conf.Resolver = DNSResolver{} @@ -55,6 +61,13 @@ func New(conf *Config) (*Server, error) { server := &Server{ config: conf, } + + server.authMethods = make(map[uint8]Authenticator) + + for _, a := range conf.AuthMethods { + server.authMethods[a.GetCode()] = a + } + return server, nil } diff --git a/socks5_test.go b/socks5_test.go index 4fa724f..f78fc88 100644 --- a/socks5_test.go +++ b/socks5_test.go @@ -38,8 +38,9 @@ func TestSOCKS5_Connect(t *testing.T) { creds := StaticCredentials{ "foo": "bar", } + cator := UserPassAuthenticator{Credentials : creds} conf := &Config{ - Credentials: creds, + AuthMethods : []Authenticator{cator}, } serv, err := New(conf) if err != nil {