From 353e906a7bede80eb956912a40ef0756345f1d31 Mon Sep 17 00:00:00 2001 From: ap4y Date: Mon, 11 Jan 2016 20:52:06 +1300 Subject: [PATCH] Introduce AuthContext struct and refactor auth api using it --- auth.go | 64 ++++++++++++++++++++++++++++------------------------ auth_test.go | 61 +++++++++++++++++++++++++++++++++++++------------ 2 files changed, 81 insertions(+), 44 deletions(-) diff --git a/auth.go b/auth.go index 2f194c0..7811e2a 100644 --- a/auth.go +++ b/auth.go @@ -6,9 +6,9 @@ import ( ) const ( - noAuth = uint8(0) + NoAuth = uint8(0) noAcceptable = uint8(255) - userPassAuth = uint8(2) + UserPassAuth = uint8(2) userAuthVersion = uint8(1) authSuccess = uint8(0) authFailure = uint8(1) @@ -19,21 +19,32 @@ var ( NoSupportedAuth = fmt.Errorf("No supported authentication mechanism") ) +// A Request encapsulates authentication state provided +// during negotiation +type AuthContext struct { + // Provided auth method + Method uint8 + // Payload provided during negotiation. + // Keys depend on the used auth method. + // For UserPassauth contains Username + Payload map[string]string +} + type Authenticator interface { - Authenticate(reader io.Reader, writer io.Writer) error + Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) GetCode() uint8 } // NoAuthAuthenticator is used to handle the "No Authentication" mode -type NoAuthAuthenticator struct {} +type NoAuthAuthenticator struct{} func (a NoAuthAuthenticator) GetCode() uint8 { - return noAuth + return NoAuth } -func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer) error { - _, err := writer.Write([]byte{socks5Version, noAuth}) - return err +func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) { + _, err := writer.Write([]byte{socks5Version, NoAuth}) + return &AuthContext{NoAuth, nil}, err } // UserPassAuthenticator is used to handle username/password based @@ -43,70 +54,67 @@ type UserPassAuthenticator struct { } func (a UserPassAuthenticator) GetCode() uint8 { - return userPassAuth + return UserPassAuth } -func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer) error { +func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) { // Tell the client to use user/pass auth - if _, err := writer.Write([]byte{socks5Version, userPassAuth}); err != nil { - return err + if _, err := writer.Write([]byte{socks5Version, UserPassAuth}); err != nil { + return nil, err } // Get the version and username length header := []byte{0, 0} if _, err := io.ReadAtLeast(reader, header, 2); err != nil { - return err + return nil, err } // Ensure we are compatible if header[0] != userAuthVersion { - return fmt.Errorf("Unsupported auth version: %v", header[0]) + return nil, fmt.Errorf("Unsupported auth version: %v", header[0]) } // Get the user name userLen := int(header[1]) user := make([]byte, userLen) if _, err := io.ReadAtLeast(reader, user, userLen); err != nil { - return err + return nil, err } // Get the password length if _, err := reader.Read(header[:1]); err != nil { - return err + return nil, err } // Get the password passLen := int(header[0]) pass := make([]byte, passLen) if _, err := io.ReadAtLeast(reader, pass, passLen); err != nil { - return err + return nil, err } // Verify the password if a.Credentials.Valid(string(user), string(pass)) { if _, err := writer.Write([]byte{userAuthVersion, authSuccess}); err != nil { - return err + return nil, err } } else { if _, err := writer.Write([]byte{userAuthVersion, authFailure}); err != nil { - return err + return nil, err } - return UserAuthFailed + return nil, UserAuthFailed } // Done - return nil - + return &AuthContext{UserPassAuth, map[string]string{"Username": string(user)}}, nil } - - // authenticate is used to handle connection authentication -func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) error { +func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) (*AuthContext, error) { // Get the methods methods, err := readMethods(bufConn) if err != nil { - return fmt.Errorf("Failed to get auth methods: %v", err) + return nil, fmt.Errorf("Failed to get auth methods: %v", err) } // Select a usable method @@ -118,11 +126,9 @@ func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) error { } // No usable method found - return noAcceptableAuth(conn) + return nil, noAcceptableAuth(conn) } - - // noAcceptableAuth is used to handle when we have no eligible // authentication mechanism func noAcceptableAuth(conn io.Writer) error { diff --git a/auth_test.go b/auth_test.go index f0ea2a5..f782f4a 100644 --- a/auth_test.go +++ b/auth_test.go @@ -7,47 +7,66 @@ import ( func TestNoAuth(t *testing.T) { req := bytes.NewBuffer(nil) - req.Write([]byte{1, noAuth}) + req.Write([]byte{1, NoAuth}) var resp bytes.Buffer s, _ := New(&Config{}) - if err := s.authenticate(&resp, req); err != nil { + ctx, err := s.authenticate(&resp, req) + if err != nil { t.Fatalf("err: %v", err) } + if ctx.Method != NoAuth { + t.Fatal("Invalid Context Method") + } + out := resp.Bytes() - if !bytes.Equal(out, []byte{socks5Version, noAuth}) { + if !bytes.Equal(out, []byte{socks5Version, NoAuth}) { t.Fatalf("bad: %v", out) } } func TestPasswordAuth_Valid(t *testing.T) { req := bytes.NewBuffer(nil) - req.Write([]byte{2, noAuth, userPassAuth}) + req.Write([]byte{2, NoAuth, UserPassAuth}) req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'}) var resp bytes.Buffer cred := StaticCredentials{ "foo": "bar", } - + cator := UserPassAuthenticator{Credentials: cred} - s, _ := New(&Config{AuthMethods:[]Authenticator{cator}}) + s, _ := New(&Config{AuthMethods: []Authenticator{cator}}) - if err := s.authenticate(&resp, req); err != nil { + ctx, err := s.authenticate(&resp, req) + if err != nil { t.Fatalf("err: %v", err) } + if ctx.Method != UserPassAuth { + t.Fatal("Invalid Context Method") + } + + val, ok := ctx.Payload["Username"] + if !ok { + t.Fatal("Missing key Username in auth context's payload") + } + + if val != "foo" { + t.Fatal("Invalid Username in auth context's payload") + } + out := resp.Bytes() - if !bytes.Equal(out, []byte{socks5Version, userPassAuth, 1, authSuccess}) { + if !bytes.Equal(out, []byte{socks5Version, UserPassAuth, 1, authSuccess}) { t.Fatalf("bad: %v", out) } } func TestPasswordAuth_Invalid(t *testing.T) { req := bytes.NewBuffer(nil) - req.Write([]byte{2, noAuth, userPassAuth}) + req.Write([]byte{2, NoAuth, UserPassAuth}) req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'z'}) var resp bytes.Buffer @@ -55,20 +74,26 @@ func TestPasswordAuth_Invalid(t *testing.T) { "foo": "bar", } cator := UserPassAuthenticator{Credentials: cred} - s, _ := New(&Config{AuthMethods:[]Authenticator{cator}}) - if err := s.authenticate(&resp, req); err != UserAuthFailed { + s, _ := New(&Config{AuthMethods: []Authenticator{cator}}) + + ctx, err := s.authenticate(&resp, req) + if err != UserAuthFailed { t.Fatalf("err: %v", err) } + if ctx != nil { + t.Fatal("Invalid Context Method") + } + out := resp.Bytes() - if !bytes.Equal(out, []byte{socks5Version, userPassAuth, 1, authFailure}) { + if !bytes.Equal(out, []byte{socks5Version, UserPassAuth, 1, authFailure}) { t.Fatalf("bad: %v", out) } } func TestNoSupportedAuth(t *testing.T) { req := bytes.NewBuffer(nil) - req.Write([]byte{1, noAuth}) + req.Write([]byte{1, NoAuth}) var resp bytes.Buffer cred := StaticCredentials{ @@ -76,11 +101,17 @@ func TestNoSupportedAuth(t *testing.T) { } cator := UserPassAuthenticator{Credentials: cred} - s, _ := New(&Config{AuthMethods:[]Authenticator{cator}}) - if err := s.authenticate(&resp, req); err != NoSupportedAuth { + s, _ := New(&Config{AuthMethods: []Authenticator{cator}}) + + ctx, err := s.authenticate(&resp, req) + if err != NoSupportedAuth { t.Fatalf("err: %v", err) } + if ctx != nil { + t.Fatal("Invalid Context Method") + } + out := resp.Bytes() if !bytes.Equal(out, []byte{socks5Version, noAcceptable}) { t.Fatalf("bad: %v", out)