Introduce AuthContext struct and refactor auth api using it
parent
aad92b1d46
commit
353e906a7b
64
auth.go
64
auth.go
|
@ -6,9 +6,9 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
noAuth = uint8(0)
|
NoAuth = uint8(0)
|
||||||
noAcceptable = uint8(255)
|
noAcceptable = uint8(255)
|
||||||
userPassAuth = uint8(2)
|
UserPassAuth = uint8(2)
|
||||||
userAuthVersion = uint8(1)
|
userAuthVersion = uint8(1)
|
||||||
authSuccess = uint8(0)
|
authSuccess = uint8(0)
|
||||||
authFailure = uint8(1)
|
authFailure = uint8(1)
|
||||||
|
@ -19,21 +19,32 @@ var (
|
||||||
NoSupportedAuth = fmt.Errorf("No supported authentication mechanism")
|
NoSupportedAuth = fmt.Errorf("No supported authentication mechanism")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// A Request encapsulates authentication state provided
|
||||||
|
// during negotiation
|
||||||
|
type AuthContext struct {
|
||||||
|
// Provided auth method
|
||||||
|
Method uint8
|
||||||
|
// Payload provided during negotiation.
|
||||||
|
// Keys depend on the used auth method.
|
||||||
|
// For UserPassauth contains Username
|
||||||
|
Payload map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
type Authenticator interface {
|
type Authenticator interface {
|
||||||
Authenticate(reader io.Reader, writer io.Writer) error
|
Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error)
|
||||||
GetCode() uint8
|
GetCode() uint8
|
||||||
}
|
}
|
||||||
|
|
||||||
// NoAuthAuthenticator is used to handle the "No Authentication" mode
|
// NoAuthAuthenticator is used to handle the "No Authentication" mode
|
||||||
type NoAuthAuthenticator struct {}
|
type NoAuthAuthenticator struct{}
|
||||||
|
|
||||||
func (a NoAuthAuthenticator) GetCode() uint8 {
|
func (a NoAuthAuthenticator) GetCode() uint8 {
|
||||||
return noAuth
|
return NoAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer) error {
|
func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) {
|
||||||
_, err := writer.Write([]byte{socks5Version, noAuth})
|
_, err := writer.Write([]byte{socks5Version, NoAuth})
|
||||||
return err
|
return &AuthContext{NoAuth, nil}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// UserPassAuthenticator is used to handle username/password based
|
// UserPassAuthenticator is used to handle username/password based
|
||||||
|
@ -43,70 +54,67 @@ type UserPassAuthenticator struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a UserPassAuthenticator) GetCode() uint8 {
|
func (a UserPassAuthenticator) GetCode() uint8 {
|
||||||
return userPassAuth
|
return UserPassAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer) error {
|
func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) {
|
||||||
// Tell the client to use user/pass auth
|
// Tell the client to use user/pass auth
|
||||||
if _, err := writer.Write([]byte{socks5Version, userPassAuth}); err != nil {
|
if _, err := writer.Write([]byte{socks5Version, UserPassAuth}); err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the version and username length
|
// Get the version and username length
|
||||||
header := []byte{0, 0}
|
header := []byte{0, 0}
|
||||||
if _, err := io.ReadAtLeast(reader, header, 2); err != nil {
|
if _, err := io.ReadAtLeast(reader, header, 2); err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure we are compatible
|
// Ensure we are compatible
|
||||||
if header[0] != userAuthVersion {
|
if header[0] != userAuthVersion {
|
||||||
return fmt.Errorf("Unsupported auth version: %v", header[0])
|
return nil, fmt.Errorf("Unsupported auth version: %v", header[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the user name
|
// Get the user name
|
||||||
userLen := int(header[1])
|
userLen := int(header[1])
|
||||||
user := make([]byte, userLen)
|
user := make([]byte, userLen)
|
||||||
if _, err := io.ReadAtLeast(reader, user, userLen); err != nil {
|
if _, err := io.ReadAtLeast(reader, user, userLen); err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the password length
|
// Get the password length
|
||||||
if _, err := reader.Read(header[:1]); err != nil {
|
if _, err := reader.Read(header[:1]); err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the password
|
// Get the password
|
||||||
passLen := int(header[0])
|
passLen := int(header[0])
|
||||||
pass := make([]byte, passLen)
|
pass := make([]byte, passLen)
|
||||||
if _, err := io.ReadAtLeast(reader, pass, passLen); err != nil {
|
if _, err := io.ReadAtLeast(reader, pass, passLen); err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify the password
|
// Verify the password
|
||||||
if a.Credentials.Valid(string(user), string(pass)) {
|
if a.Credentials.Valid(string(user), string(pass)) {
|
||||||
if _, err := writer.Write([]byte{userAuthVersion, authSuccess}); err != nil {
|
if _, err := writer.Write([]byte{userAuthVersion, authSuccess}); err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if _, err := writer.Write([]byte{userAuthVersion, authFailure}); err != nil {
|
if _, err := writer.Write([]byte{userAuthVersion, authFailure}); err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
return UserAuthFailed
|
return nil, UserAuthFailed
|
||||||
}
|
}
|
||||||
|
|
||||||
// Done
|
// Done
|
||||||
return nil
|
return &AuthContext{UserPassAuth, map[string]string{"Username": string(user)}}, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// authenticate is used to handle connection authentication
|
// authenticate is used to handle connection authentication
|
||||||
func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) error {
|
func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) (*AuthContext, error) {
|
||||||
// Get the methods
|
// Get the methods
|
||||||
methods, err := readMethods(bufConn)
|
methods, err := readMethods(bufConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Failed to get auth methods: %v", err)
|
return nil, fmt.Errorf("Failed to get auth methods: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Select a usable method
|
// Select a usable method
|
||||||
|
@ -118,11 +126,9 @@ func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// No usable method found
|
// No usable method found
|
||||||
return noAcceptableAuth(conn)
|
return nil, noAcceptableAuth(conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// noAcceptableAuth is used to handle when we have no eligible
|
// noAcceptableAuth is used to handle when we have no eligible
|
||||||
// authentication mechanism
|
// authentication mechanism
|
||||||
func noAcceptableAuth(conn io.Writer) error {
|
func noAcceptableAuth(conn io.Writer) error {
|
||||||
|
|
61
auth_test.go
61
auth_test.go
|
@ -7,47 +7,66 @@ import (
|
||||||
|
|
||||||
func TestNoAuth(t *testing.T) {
|
func TestNoAuth(t *testing.T) {
|
||||||
req := bytes.NewBuffer(nil)
|
req := bytes.NewBuffer(nil)
|
||||||
req.Write([]byte{1, noAuth})
|
req.Write([]byte{1, NoAuth})
|
||||||
var resp bytes.Buffer
|
var resp bytes.Buffer
|
||||||
|
|
||||||
s, _ := New(&Config{})
|
s, _ := New(&Config{})
|
||||||
if err := s.authenticate(&resp, req); err != nil {
|
ctx, err := s.authenticate(&resp, req)
|
||||||
|
if err != nil {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ctx.Method != NoAuth {
|
||||||
|
t.Fatal("Invalid Context Method")
|
||||||
|
}
|
||||||
|
|
||||||
out := resp.Bytes()
|
out := resp.Bytes()
|
||||||
if !bytes.Equal(out, []byte{socks5Version, noAuth}) {
|
if !bytes.Equal(out, []byte{socks5Version, NoAuth}) {
|
||||||
t.Fatalf("bad: %v", out)
|
t.Fatalf("bad: %v", out)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPasswordAuth_Valid(t *testing.T) {
|
func TestPasswordAuth_Valid(t *testing.T) {
|
||||||
req := bytes.NewBuffer(nil)
|
req := bytes.NewBuffer(nil)
|
||||||
req.Write([]byte{2, noAuth, userPassAuth})
|
req.Write([]byte{2, NoAuth, UserPassAuth})
|
||||||
req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'})
|
req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'})
|
||||||
var resp bytes.Buffer
|
var resp bytes.Buffer
|
||||||
|
|
||||||
cred := StaticCredentials{
|
cred := StaticCredentials{
|
||||||
"foo": "bar",
|
"foo": "bar",
|
||||||
}
|
}
|
||||||
|
|
||||||
cator := UserPassAuthenticator{Credentials: cred}
|
cator := UserPassAuthenticator{Credentials: cred}
|
||||||
|
|
||||||
s, _ := New(&Config{AuthMethods:[]Authenticator{cator}})
|
s, _ := New(&Config{AuthMethods: []Authenticator{cator}})
|
||||||
|
|
||||||
if err := s.authenticate(&resp, req); err != nil {
|
ctx, err := s.authenticate(&resp, req)
|
||||||
|
if err != nil {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ctx.Method != UserPassAuth {
|
||||||
|
t.Fatal("Invalid Context Method")
|
||||||
|
}
|
||||||
|
|
||||||
|
val, ok := ctx.Payload["Username"]
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Missing key Username in auth context's payload")
|
||||||
|
}
|
||||||
|
|
||||||
|
if val != "foo" {
|
||||||
|
t.Fatal("Invalid Username in auth context's payload")
|
||||||
|
}
|
||||||
|
|
||||||
out := resp.Bytes()
|
out := resp.Bytes()
|
||||||
if !bytes.Equal(out, []byte{socks5Version, userPassAuth, 1, authSuccess}) {
|
if !bytes.Equal(out, []byte{socks5Version, UserPassAuth, 1, authSuccess}) {
|
||||||
t.Fatalf("bad: %v", out)
|
t.Fatalf("bad: %v", out)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPasswordAuth_Invalid(t *testing.T) {
|
func TestPasswordAuth_Invalid(t *testing.T) {
|
||||||
req := bytes.NewBuffer(nil)
|
req := bytes.NewBuffer(nil)
|
||||||
req.Write([]byte{2, noAuth, userPassAuth})
|
req.Write([]byte{2, NoAuth, UserPassAuth})
|
||||||
req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'z'})
|
req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'z'})
|
||||||
var resp bytes.Buffer
|
var resp bytes.Buffer
|
||||||
|
|
||||||
|
@ -55,20 +74,26 @@ func TestPasswordAuth_Invalid(t *testing.T) {
|
||||||
"foo": "bar",
|
"foo": "bar",
|
||||||
}
|
}
|
||||||
cator := UserPassAuthenticator{Credentials: cred}
|
cator := UserPassAuthenticator{Credentials: cred}
|
||||||
s, _ := New(&Config{AuthMethods:[]Authenticator{cator}})
|
s, _ := New(&Config{AuthMethods: []Authenticator{cator}})
|
||||||
if err := s.authenticate(&resp, req); err != UserAuthFailed {
|
|
||||||
|
ctx, err := s.authenticate(&resp, req)
|
||||||
|
if err != UserAuthFailed {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ctx != nil {
|
||||||
|
t.Fatal("Invalid Context Method")
|
||||||
|
}
|
||||||
|
|
||||||
out := resp.Bytes()
|
out := resp.Bytes()
|
||||||
if !bytes.Equal(out, []byte{socks5Version, userPassAuth, 1, authFailure}) {
|
if !bytes.Equal(out, []byte{socks5Version, UserPassAuth, 1, authFailure}) {
|
||||||
t.Fatalf("bad: %v", out)
|
t.Fatalf("bad: %v", out)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNoSupportedAuth(t *testing.T) {
|
func TestNoSupportedAuth(t *testing.T) {
|
||||||
req := bytes.NewBuffer(nil)
|
req := bytes.NewBuffer(nil)
|
||||||
req.Write([]byte{1, noAuth})
|
req.Write([]byte{1, NoAuth})
|
||||||
var resp bytes.Buffer
|
var resp bytes.Buffer
|
||||||
|
|
||||||
cred := StaticCredentials{
|
cred := StaticCredentials{
|
||||||
|
@ -76,11 +101,17 @@ func TestNoSupportedAuth(t *testing.T) {
|
||||||
}
|
}
|
||||||
cator := UserPassAuthenticator{Credentials: cred}
|
cator := UserPassAuthenticator{Credentials: cred}
|
||||||
|
|
||||||
s, _ := New(&Config{AuthMethods:[]Authenticator{cator}})
|
s, _ := New(&Config{AuthMethods: []Authenticator{cator}})
|
||||||
if err := s.authenticate(&resp, req); err != NoSupportedAuth {
|
|
||||||
|
ctx, err := s.authenticate(&resp, req)
|
||||||
|
if err != NoSupportedAuth {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ctx != nil {
|
||||||
|
t.Fatal("Invalid Context Method")
|
||||||
|
}
|
||||||
|
|
||||||
out := resp.Bytes()
|
out := resp.Bytes()
|
||||||
if !bytes.Equal(out, []byte{socks5Version, noAcceptable}) {
|
if !bytes.Equal(out, []byte{socks5Version, noAcceptable}) {
|
||||||
t.Fatalf("bad: %v", out)
|
t.Fatalf("bad: %v", out)
|
||||||
|
|
Loading…
Reference in New Issue