1
0
Fork 0

Introduce AuthContext struct and refactor auth api using it

logger
ap4y 2016-01-11 20:52:06 +13:00
parent aad92b1d46
commit 353e906a7b
2 changed files with 81 additions and 44 deletions

64
auth.go
View File

@ -6,9 +6,9 @@ import (
) )
const ( const (
noAuth = uint8(0) NoAuth = uint8(0)
noAcceptable = uint8(255) noAcceptable = uint8(255)
userPassAuth = uint8(2) UserPassAuth = uint8(2)
userAuthVersion = uint8(1) userAuthVersion = uint8(1)
authSuccess = uint8(0) authSuccess = uint8(0)
authFailure = uint8(1) authFailure = uint8(1)
@ -19,21 +19,32 @@ var (
NoSupportedAuth = fmt.Errorf("No supported authentication mechanism") NoSupportedAuth = fmt.Errorf("No supported authentication mechanism")
) )
// A Request encapsulates authentication state provided
// during negotiation
type AuthContext struct {
// Provided auth method
Method uint8
// Payload provided during negotiation.
// Keys depend on the used auth method.
// For UserPassauth contains Username
Payload map[string]string
}
type Authenticator interface { type Authenticator interface {
Authenticate(reader io.Reader, writer io.Writer) error Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error)
GetCode() uint8 GetCode() uint8
} }
// NoAuthAuthenticator is used to handle the "No Authentication" mode // NoAuthAuthenticator is used to handle the "No Authentication" mode
type NoAuthAuthenticator struct {} type NoAuthAuthenticator struct{}
func (a NoAuthAuthenticator) GetCode() uint8 { func (a NoAuthAuthenticator) GetCode() uint8 {
return noAuth return NoAuth
} }
func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer) error { func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) {
_, err := writer.Write([]byte{socks5Version, noAuth}) _, err := writer.Write([]byte{socks5Version, NoAuth})
return err return &AuthContext{NoAuth, nil}, err
} }
// UserPassAuthenticator is used to handle username/password based // UserPassAuthenticator is used to handle username/password based
@ -43,70 +54,67 @@ type UserPassAuthenticator struct {
} }
func (a UserPassAuthenticator) GetCode() uint8 { func (a UserPassAuthenticator) GetCode() uint8 {
return userPassAuth return UserPassAuth
} }
func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer) error { func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) {
// Tell the client to use user/pass auth // Tell the client to use user/pass auth
if _, err := writer.Write([]byte{socks5Version, userPassAuth}); err != nil { if _, err := writer.Write([]byte{socks5Version, UserPassAuth}); err != nil {
return err return nil, err
} }
// Get the version and username length // Get the version and username length
header := []byte{0, 0} header := []byte{0, 0}
if _, err := io.ReadAtLeast(reader, header, 2); err != nil { if _, err := io.ReadAtLeast(reader, header, 2); err != nil {
return err return nil, err
} }
// Ensure we are compatible // Ensure we are compatible
if header[0] != userAuthVersion { if header[0] != userAuthVersion {
return fmt.Errorf("Unsupported auth version: %v", header[0]) return nil, fmt.Errorf("Unsupported auth version: %v", header[0])
} }
// Get the user name // Get the user name
userLen := int(header[1]) userLen := int(header[1])
user := make([]byte, userLen) user := make([]byte, userLen)
if _, err := io.ReadAtLeast(reader, user, userLen); err != nil { if _, err := io.ReadAtLeast(reader, user, userLen); err != nil {
return err return nil, err
} }
// Get the password length // Get the password length
if _, err := reader.Read(header[:1]); err != nil { if _, err := reader.Read(header[:1]); err != nil {
return err return nil, err
} }
// Get the password // Get the password
passLen := int(header[0]) passLen := int(header[0])
pass := make([]byte, passLen) pass := make([]byte, passLen)
if _, err := io.ReadAtLeast(reader, pass, passLen); err != nil { if _, err := io.ReadAtLeast(reader, pass, passLen); err != nil {
return err return nil, err
} }
// Verify the password // Verify the password
if a.Credentials.Valid(string(user), string(pass)) { if a.Credentials.Valid(string(user), string(pass)) {
if _, err := writer.Write([]byte{userAuthVersion, authSuccess}); err != nil { if _, err := writer.Write([]byte{userAuthVersion, authSuccess}); err != nil {
return err return nil, err
} }
} else { } else {
if _, err := writer.Write([]byte{userAuthVersion, authFailure}); err != nil { if _, err := writer.Write([]byte{userAuthVersion, authFailure}); err != nil {
return err return nil, err
} }
return UserAuthFailed return nil, UserAuthFailed
} }
// Done // Done
return nil return &AuthContext{UserPassAuth, map[string]string{"Username": string(user)}}, nil
} }
// authenticate is used to handle connection authentication // authenticate is used to handle connection authentication
func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) error { func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) (*AuthContext, error) {
// Get the methods // Get the methods
methods, err := readMethods(bufConn) methods, err := readMethods(bufConn)
if err != nil { if err != nil {
return fmt.Errorf("Failed to get auth methods: %v", err) return nil, fmt.Errorf("Failed to get auth methods: %v", err)
} }
// Select a usable method // Select a usable method
@ -118,11 +126,9 @@ func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) error {
} }
// No usable method found // No usable method found
return noAcceptableAuth(conn) return nil, noAcceptableAuth(conn)
} }
// noAcceptableAuth is used to handle when we have no eligible // noAcceptableAuth is used to handle when we have no eligible
// authentication mechanism // authentication mechanism
func noAcceptableAuth(conn io.Writer) error { func noAcceptableAuth(conn io.Writer) error {

View File

@ -7,47 +7,66 @@ import (
func TestNoAuth(t *testing.T) { func TestNoAuth(t *testing.T) {
req := bytes.NewBuffer(nil) req := bytes.NewBuffer(nil)
req.Write([]byte{1, noAuth}) req.Write([]byte{1, NoAuth})
var resp bytes.Buffer var resp bytes.Buffer
s, _ := New(&Config{}) s, _ := New(&Config{})
if err := s.authenticate(&resp, req); err != nil { ctx, err := s.authenticate(&resp, req)
if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
if ctx.Method != NoAuth {
t.Fatal("Invalid Context Method")
}
out := resp.Bytes() out := resp.Bytes()
if !bytes.Equal(out, []byte{socks5Version, noAuth}) { if !bytes.Equal(out, []byte{socks5Version, NoAuth}) {
t.Fatalf("bad: %v", out) t.Fatalf("bad: %v", out)
} }
} }
func TestPasswordAuth_Valid(t *testing.T) { func TestPasswordAuth_Valid(t *testing.T) {
req := bytes.NewBuffer(nil) req := bytes.NewBuffer(nil)
req.Write([]byte{2, noAuth, userPassAuth}) req.Write([]byte{2, NoAuth, UserPassAuth})
req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'}) req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'})
var resp bytes.Buffer var resp bytes.Buffer
cred := StaticCredentials{ cred := StaticCredentials{
"foo": "bar", "foo": "bar",
} }
cator := UserPassAuthenticator{Credentials: cred} cator := UserPassAuthenticator{Credentials: cred}
s, _ := New(&Config{AuthMethods:[]Authenticator{cator}}) s, _ := New(&Config{AuthMethods: []Authenticator{cator}})
if err := s.authenticate(&resp, req); err != nil { ctx, err := s.authenticate(&resp, req)
if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
if ctx.Method != UserPassAuth {
t.Fatal("Invalid Context Method")
}
val, ok := ctx.Payload["Username"]
if !ok {
t.Fatal("Missing key Username in auth context's payload")
}
if val != "foo" {
t.Fatal("Invalid Username in auth context's payload")
}
out := resp.Bytes() out := resp.Bytes()
if !bytes.Equal(out, []byte{socks5Version, userPassAuth, 1, authSuccess}) { if !bytes.Equal(out, []byte{socks5Version, UserPassAuth, 1, authSuccess}) {
t.Fatalf("bad: %v", out) t.Fatalf("bad: %v", out)
} }
} }
func TestPasswordAuth_Invalid(t *testing.T) { func TestPasswordAuth_Invalid(t *testing.T) {
req := bytes.NewBuffer(nil) req := bytes.NewBuffer(nil)
req.Write([]byte{2, noAuth, userPassAuth}) req.Write([]byte{2, NoAuth, UserPassAuth})
req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'z'}) req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'z'})
var resp bytes.Buffer var resp bytes.Buffer
@ -55,20 +74,26 @@ func TestPasswordAuth_Invalid(t *testing.T) {
"foo": "bar", "foo": "bar",
} }
cator := UserPassAuthenticator{Credentials: cred} cator := UserPassAuthenticator{Credentials: cred}
s, _ := New(&Config{AuthMethods:[]Authenticator{cator}}) s, _ := New(&Config{AuthMethods: []Authenticator{cator}})
if err := s.authenticate(&resp, req); err != UserAuthFailed {
ctx, err := s.authenticate(&resp, req)
if err != UserAuthFailed {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
if ctx != nil {
t.Fatal("Invalid Context Method")
}
out := resp.Bytes() out := resp.Bytes()
if !bytes.Equal(out, []byte{socks5Version, userPassAuth, 1, authFailure}) { if !bytes.Equal(out, []byte{socks5Version, UserPassAuth, 1, authFailure}) {
t.Fatalf("bad: %v", out) t.Fatalf("bad: %v", out)
} }
} }
func TestNoSupportedAuth(t *testing.T) { func TestNoSupportedAuth(t *testing.T) {
req := bytes.NewBuffer(nil) req := bytes.NewBuffer(nil)
req.Write([]byte{1, noAuth}) req.Write([]byte{1, NoAuth})
var resp bytes.Buffer var resp bytes.Buffer
cred := StaticCredentials{ cred := StaticCredentials{
@ -76,11 +101,17 @@ func TestNoSupportedAuth(t *testing.T) {
} }
cator := UserPassAuthenticator{Credentials: cred} cator := UserPassAuthenticator{Credentials: cred}
s, _ := New(&Config{AuthMethods:[]Authenticator{cator}}) s, _ := New(&Config{AuthMethods: []Authenticator{cator}})
if err := s.authenticate(&resp, req); err != NoSupportedAuth {
ctx, err := s.authenticate(&resp, req)
if err != NoSupportedAuth {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
if ctx != nil {
t.Fatal("Invalid Context Method")
}
out := resp.Bytes() out := resp.Bytes()
if !bytes.Equal(out, []byte{socks5Version, noAcceptable}) { if !bytes.Equal(out, []byte{socks5Version, noAcceptable}) {
t.Fatalf("bad: %v", out) t.Fatalf("bad: %v", out)