diff --git a/auth.go b/auth.go new file mode 100644 index 0000000..3494824 --- /dev/null +++ b/auth.go @@ -0,0 +1,126 @@ +package socks5 + +import ( + "fmt" + "io" +) + +const ( + noAuth = uint8(0) + noAcceptable = uint8(255) + userPassAuth = uint8(2) + userAuthVersion = uint8(1) + authSuccess = uint8(0) + authFailure = uint8(1) +) + +var ( + UserAuthFailed = fmt.Errorf("User authentication failed") + NoSupportedAuth = fmt.Errorf("No supported authentication mechanism") +) + +// authenticate is used to handle connection authentication +func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) error { + // Get the methods + methods, err := readMethods(bufConn) + if err != nil { + return fmt.Errorf("Failed to get auth methods: %v", err) + } + + // Determine what is supported + supportUserPass := s.config.Credentials != nil + + // Select a usable method + for _, method := range methods { + if method == noAuth && !supportUserPass { + return noAuthMode(conn) + } + if method == userPassAuth && supportUserPass { + return s.userPassAuth(conn, bufConn) + } + } + + // No usable method found + return noAcceptableAuth(conn) +} + +// userPassAuth is used to handle username/password based +// authentication +func (s *Server) userPassAuth(conn io.Writer, bufConn io.Reader) error { + // Tell the client to use user/pass auth + if _, err := conn.Write([]byte{socks5Version, userPassAuth}); err != nil { + return err + } + + // Get the version and username length + header := []byte{0, 0} + if _, err := io.ReadAtLeast(bufConn, header, 2); err != nil { + return err + } + + // Ensure we are compatible + if header[0] != userAuthVersion { + return fmt.Errorf("Unsupported auth version: %v", header[0]) + } + + // Get the user name + userLen := int(header[1]) + user := make([]byte, userLen) + if _, err := io.ReadAtLeast(bufConn, user, userLen); err != nil { + return err + } + + // Get the password length + if _, err := bufConn.Read(header[:1]); err != nil { + return err + } + + // Get the password + passLen := int(header[0]) + pass := make([]byte, passLen) + if _, err := io.ReadAtLeast(bufConn, pass, passLen); err != nil { + return err + } + + // Verify the password + if s.config.Credentials.Valid(string(user), string(pass)) { + if _, err := conn.Write([]byte{userAuthVersion, authSuccess}); err != nil { + return err + } + } else { + if _, err := conn.Write([]byte{userAuthVersion, authFailure}); err != nil { + return err + } + return UserAuthFailed + } + + // Done + return nil +} + +// noAuth is used to handle the "No Authentication" mode +func noAuthMode(conn io.Writer) error { + _, err := conn.Write([]byte{socks5Version, noAuth}) + return err +} + +// noAcceptableAuth is used to handle when we have no eligible +// authentication mechanism +func noAcceptableAuth(conn io.Writer) error { + conn.Write([]byte{socks5Version, noAcceptable}) + return NoSupportedAuth +} + +// readMethods is used to read the number of methods +// and proceeding auth methods +func readMethods(r io.Reader) ([]byte, error) { + header := []byte{0} + if _, err := r.Read(header); err != nil { + return nil, err + } + + numMethods := int(header[0]) + methods := make([]byte, numMethods) + _, err := io.ReadAtLeast(r, methods, numMethods) + return methods, err +} diff --git a/auth_test.go b/auth_test.go new file mode 100644 index 0000000..f45697b --- /dev/null +++ b/auth_test.go @@ -0,0 +1,81 @@ +package socks5 + +import ( + "bytes" + "testing" +) + +func TestNoAuth(t *testing.T) { + req := bytes.NewBuffer(nil) + req.Write([]byte{1, noAuth}) + var resp bytes.Buffer + + s := &Server{config: &Config{}} + if err := s.authenticate(&resp, req); err != nil { + t.Fatalf("err: %v", err) + } + + out := resp.Bytes() + if !bytes.Equal(out, []byte{socks5Version, noAuth}) { + t.Fatalf("bad: %v", out) + } +} + +func TestPasswordAuth_Valid(t *testing.T) { + req := bytes.NewBuffer(nil) + req.Write([]byte{2, noAuth, userPassAuth}) + req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'}) + var resp bytes.Buffer + + cred := StaticCredentials{ + "foo": "bar", + } + s := &Server{config: &Config{Credentials: cred}} + if err := s.authenticate(&resp, req); err != nil { + t.Fatalf("err: %v", err) + } + + out := resp.Bytes() + if !bytes.Equal(out, []byte{socks5Version, userPassAuth, 1, authSuccess}) { + t.Fatalf("bad: %v", out) + } +} + +func TestPasswordAuth_Invalid(t *testing.T) { + req := bytes.NewBuffer(nil) + req.Write([]byte{2, noAuth, userPassAuth}) + req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'z'}) + var resp bytes.Buffer + + cred := StaticCredentials{ + "foo": "bar", + } + s := &Server{config: &Config{Credentials: cred}} + if err := s.authenticate(&resp, req); err != UserAuthFailed { + t.Fatalf("err: %v", err) + } + + out := resp.Bytes() + if !bytes.Equal(out, []byte{socks5Version, userPassAuth, 1, authFailure}) { + t.Fatalf("bad: %v", out) + } +} + +func TestNoSupportedAuth(t *testing.T) { + req := bytes.NewBuffer(nil) + req.Write([]byte{1, noAuth}) + var resp bytes.Buffer + + cred := StaticCredentials{ + "foo": "bar", + } + s := &Server{config: &Config{Credentials: cred}} + if err := s.authenticate(&resp, req); err != NoSupportedAuth { + t.Fatalf("err: %v", err) + } + + out := resp.Bytes() + if !bytes.Equal(out, []byte{socks5Version, noAcceptable}) { + t.Fatalf("bad: %v", out) + } +}