commit
b7e3cc6811
64
auth.go
64
auth.go
|
@ -6,9 +6,9 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
noAuth = uint8(0)
|
NoAuth = uint8(0)
|
||||||
noAcceptable = uint8(255)
|
noAcceptable = uint8(255)
|
||||||
userPassAuth = uint8(2)
|
UserPassAuth = uint8(2)
|
||||||
userAuthVersion = uint8(1)
|
userAuthVersion = uint8(1)
|
||||||
authSuccess = uint8(0)
|
authSuccess = uint8(0)
|
||||||
authFailure = uint8(1)
|
authFailure = uint8(1)
|
||||||
|
@ -19,21 +19,32 @@ var (
|
||||||
NoSupportedAuth = fmt.Errorf("No supported authentication mechanism")
|
NoSupportedAuth = fmt.Errorf("No supported authentication mechanism")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// A Request encapsulates authentication state provided
|
||||||
|
// during negotiation
|
||||||
|
type AuthContext struct {
|
||||||
|
// Provided auth method
|
||||||
|
Method uint8
|
||||||
|
// Payload provided during negotiation.
|
||||||
|
// Keys depend on the used auth method.
|
||||||
|
// For UserPassauth contains Username
|
||||||
|
Payload map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
type Authenticator interface {
|
type Authenticator interface {
|
||||||
Authenticate(reader io.Reader, writer io.Writer) error
|
Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error)
|
||||||
GetCode() uint8
|
GetCode() uint8
|
||||||
}
|
}
|
||||||
|
|
||||||
// NoAuthAuthenticator is used to handle the "No Authentication" mode
|
// NoAuthAuthenticator is used to handle the "No Authentication" mode
|
||||||
type NoAuthAuthenticator struct {}
|
type NoAuthAuthenticator struct{}
|
||||||
|
|
||||||
func (a NoAuthAuthenticator) GetCode() uint8 {
|
func (a NoAuthAuthenticator) GetCode() uint8 {
|
||||||
return noAuth
|
return NoAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer) error {
|
func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) {
|
||||||
_, err := writer.Write([]byte{socks5Version, noAuth})
|
_, err := writer.Write([]byte{socks5Version, NoAuth})
|
||||||
return err
|
return &AuthContext{NoAuth, nil}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// UserPassAuthenticator is used to handle username/password based
|
// UserPassAuthenticator is used to handle username/password based
|
||||||
|
@ -43,70 +54,67 @@ type UserPassAuthenticator struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a UserPassAuthenticator) GetCode() uint8 {
|
func (a UserPassAuthenticator) GetCode() uint8 {
|
||||||
return userPassAuth
|
return UserPassAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer) error {
|
func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) {
|
||||||
// Tell the client to use user/pass auth
|
// Tell the client to use user/pass auth
|
||||||
if _, err := writer.Write([]byte{socks5Version, userPassAuth}); err != nil {
|
if _, err := writer.Write([]byte{socks5Version, UserPassAuth}); err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the version and username length
|
// Get the version and username length
|
||||||
header := []byte{0, 0}
|
header := []byte{0, 0}
|
||||||
if _, err := io.ReadAtLeast(reader, header, 2); err != nil {
|
if _, err := io.ReadAtLeast(reader, header, 2); err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure we are compatible
|
// Ensure we are compatible
|
||||||
if header[0] != userAuthVersion {
|
if header[0] != userAuthVersion {
|
||||||
return fmt.Errorf("Unsupported auth version: %v", header[0])
|
return nil, fmt.Errorf("Unsupported auth version: %v", header[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the user name
|
// Get the user name
|
||||||
userLen := int(header[1])
|
userLen := int(header[1])
|
||||||
user := make([]byte, userLen)
|
user := make([]byte, userLen)
|
||||||
if _, err := io.ReadAtLeast(reader, user, userLen); err != nil {
|
if _, err := io.ReadAtLeast(reader, user, userLen); err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the password length
|
// Get the password length
|
||||||
if _, err := reader.Read(header[:1]); err != nil {
|
if _, err := reader.Read(header[:1]); err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the password
|
// Get the password
|
||||||
passLen := int(header[0])
|
passLen := int(header[0])
|
||||||
pass := make([]byte, passLen)
|
pass := make([]byte, passLen)
|
||||||
if _, err := io.ReadAtLeast(reader, pass, passLen); err != nil {
|
if _, err := io.ReadAtLeast(reader, pass, passLen); err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify the password
|
// Verify the password
|
||||||
if a.Credentials.Valid(string(user), string(pass)) {
|
if a.Credentials.Valid(string(user), string(pass)) {
|
||||||
if _, err := writer.Write([]byte{userAuthVersion, authSuccess}); err != nil {
|
if _, err := writer.Write([]byte{userAuthVersion, authSuccess}); err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if _, err := writer.Write([]byte{userAuthVersion, authFailure}); err != nil {
|
if _, err := writer.Write([]byte{userAuthVersion, authFailure}); err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
return UserAuthFailed
|
return nil, UserAuthFailed
|
||||||
}
|
}
|
||||||
|
|
||||||
// Done
|
// Done
|
||||||
return nil
|
return &AuthContext{UserPassAuth, map[string]string{"Username": string(user)}}, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// authenticate is used to handle connection authentication
|
// authenticate is used to handle connection authentication
|
||||||
func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) error {
|
func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) (*AuthContext, error) {
|
||||||
// Get the methods
|
// Get the methods
|
||||||
methods, err := readMethods(bufConn)
|
methods, err := readMethods(bufConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Failed to get auth methods: %v", err)
|
return nil, fmt.Errorf("Failed to get auth methods: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Select a usable method
|
// Select a usable method
|
||||||
|
@ -118,11 +126,9 @@ func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// No usable method found
|
// No usable method found
|
||||||
return noAcceptableAuth(conn)
|
return nil, noAcceptableAuth(conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// noAcceptableAuth is used to handle when we have no eligible
|
// noAcceptableAuth is used to handle when we have no eligible
|
||||||
// authentication mechanism
|
// authentication mechanism
|
||||||
func noAcceptableAuth(conn io.Writer) error {
|
func noAcceptableAuth(conn io.Writer) error {
|
||||||
|
|
61
auth_test.go
61
auth_test.go
|
@ -7,47 +7,66 @@ import (
|
||||||
|
|
||||||
func TestNoAuth(t *testing.T) {
|
func TestNoAuth(t *testing.T) {
|
||||||
req := bytes.NewBuffer(nil)
|
req := bytes.NewBuffer(nil)
|
||||||
req.Write([]byte{1, noAuth})
|
req.Write([]byte{1, NoAuth})
|
||||||
var resp bytes.Buffer
|
var resp bytes.Buffer
|
||||||
|
|
||||||
s, _ := New(&Config{})
|
s, _ := New(&Config{})
|
||||||
if err := s.authenticate(&resp, req); err != nil {
|
ctx, err := s.authenticate(&resp, req)
|
||||||
|
if err != nil {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ctx.Method != NoAuth {
|
||||||
|
t.Fatal("Invalid Context Method")
|
||||||
|
}
|
||||||
|
|
||||||
out := resp.Bytes()
|
out := resp.Bytes()
|
||||||
if !bytes.Equal(out, []byte{socks5Version, noAuth}) {
|
if !bytes.Equal(out, []byte{socks5Version, NoAuth}) {
|
||||||
t.Fatalf("bad: %v", out)
|
t.Fatalf("bad: %v", out)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPasswordAuth_Valid(t *testing.T) {
|
func TestPasswordAuth_Valid(t *testing.T) {
|
||||||
req := bytes.NewBuffer(nil)
|
req := bytes.NewBuffer(nil)
|
||||||
req.Write([]byte{2, noAuth, userPassAuth})
|
req.Write([]byte{2, NoAuth, UserPassAuth})
|
||||||
req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'})
|
req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'})
|
||||||
var resp bytes.Buffer
|
var resp bytes.Buffer
|
||||||
|
|
||||||
cred := StaticCredentials{
|
cred := StaticCredentials{
|
||||||
"foo": "bar",
|
"foo": "bar",
|
||||||
}
|
}
|
||||||
|
|
||||||
cator := UserPassAuthenticator{Credentials: cred}
|
cator := UserPassAuthenticator{Credentials: cred}
|
||||||
|
|
||||||
s, _ := New(&Config{AuthMethods:[]Authenticator{cator}})
|
s, _ := New(&Config{AuthMethods: []Authenticator{cator}})
|
||||||
|
|
||||||
if err := s.authenticate(&resp, req); err != nil {
|
ctx, err := s.authenticate(&resp, req)
|
||||||
|
if err != nil {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ctx.Method != UserPassAuth {
|
||||||
|
t.Fatal("Invalid Context Method")
|
||||||
|
}
|
||||||
|
|
||||||
|
val, ok := ctx.Payload["Username"]
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Missing key Username in auth context's payload")
|
||||||
|
}
|
||||||
|
|
||||||
|
if val != "foo" {
|
||||||
|
t.Fatal("Invalid Username in auth context's payload")
|
||||||
|
}
|
||||||
|
|
||||||
out := resp.Bytes()
|
out := resp.Bytes()
|
||||||
if !bytes.Equal(out, []byte{socks5Version, userPassAuth, 1, authSuccess}) {
|
if !bytes.Equal(out, []byte{socks5Version, UserPassAuth, 1, authSuccess}) {
|
||||||
t.Fatalf("bad: %v", out)
|
t.Fatalf("bad: %v", out)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPasswordAuth_Invalid(t *testing.T) {
|
func TestPasswordAuth_Invalid(t *testing.T) {
|
||||||
req := bytes.NewBuffer(nil)
|
req := bytes.NewBuffer(nil)
|
||||||
req.Write([]byte{2, noAuth, userPassAuth})
|
req.Write([]byte{2, NoAuth, UserPassAuth})
|
||||||
req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'z'})
|
req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'z'})
|
||||||
var resp bytes.Buffer
|
var resp bytes.Buffer
|
||||||
|
|
||||||
|
@ -55,20 +74,26 @@ func TestPasswordAuth_Invalid(t *testing.T) {
|
||||||
"foo": "bar",
|
"foo": "bar",
|
||||||
}
|
}
|
||||||
cator := UserPassAuthenticator{Credentials: cred}
|
cator := UserPassAuthenticator{Credentials: cred}
|
||||||
s, _ := New(&Config{AuthMethods:[]Authenticator{cator}})
|
s, _ := New(&Config{AuthMethods: []Authenticator{cator}})
|
||||||
if err := s.authenticate(&resp, req); err != UserAuthFailed {
|
|
||||||
|
ctx, err := s.authenticate(&resp, req)
|
||||||
|
if err != UserAuthFailed {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ctx != nil {
|
||||||
|
t.Fatal("Invalid Context Method")
|
||||||
|
}
|
||||||
|
|
||||||
out := resp.Bytes()
|
out := resp.Bytes()
|
||||||
if !bytes.Equal(out, []byte{socks5Version, userPassAuth, 1, authFailure}) {
|
if !bytes.Equal(out, []byte{socks5Version, UserPassAuth, 1, authFailure}) {
|
||||||
t.Fatalf("bad: %v", out)
|
t.Fatalf("bad: %v", out)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNoSupportedAuth(t *testing.T) {
|
func TestNoSupportedAuth(t *testing.T) {
|
||||||
req := bytes.NewBuffer(nil)
|
req := bytes.NewBuffer(nil)
|
||||||
req.Write([]byte{1, noAuth})
|
req.Write([]byte{1, NoAuth})
|
||||||
var resp bytes.Buffer
|
var resp bytes.Buffer
|
||||||
|
|
||||||
cred := StaticCredentials{
|
cred := StaticCredentials{
|
||||||
|
@ -76,11 +101,17 @@ func TestNoSupportedAuth(t *testing.T) {
|
||||||
}
|
}
|
||||||
cator := UserPassAuthenticator{Credentials: cred}
|
cator := UserPassAuthenticator{Credentials: cred}
|
||||||
|
|
||||||
s, _ := New(&Config{AuthMethods:[]Authenticator{cator}})
|
s, _ := New(&Config{AuthMethods: []Authenticator{cator}})
|
||||||
if err := s.authenticate(&resp, req); err != NoSupportedAuth {
|
|
||||||
|
ctx, err := s.authenticate(&resp, req)
|
||||||
|
if err != NoSupportedAuth {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ctx != nil {
|
||||||
|
t.Fatal("Invalid Context Method")
|
||||||
|
}
|
||||||
|
|
||||||
out := resp.Bytes()
|
out := resp.Bytes()
|
||||||
if !bytes.Equal(out, []byte{socks5Version, noAcceptable}) {
|
if !bytes.Equal(out, []byte{socks5Version, noAcceptable}) {
|
||||||
t.Fatalf("bad: %v", out)
|
t.Fatalf("bad: %v", out)
|
||||||
|
|
110
request.go
110
request.go
|
@ -10,9 +10,9 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
connectCommand = uint8(1)
|
ConnectCommand = uint8(1)
|
||||||
bindCommand = uint8(2)
|
BindCommand = uint8(2)
|
||||||
associateCommand = uint8(3)
|
AssociateCommand = uint8(3)
|
||||||
ipv4Address = uint8(1)
|
ipv4Address = uint8(1)
|
||||||
fqdnAddress = uint8(3)
|
fqdnAddress = uint8(3)
|
||||||
ipv6Address = uint8(4)
|
ipv6Address = uint8(4)
|
||||||
|
@ -36,7 +36,7 @@ var (
|
||||||
|
|
||||||
// AddressRewriter is used to rewrite a destination transparently
|
// AddressRewriter is used to rewrite a destination transparently
|
||||||
type AddressRewriter interface {
|
type AddressRewriter interface {
|
||||||
Rewrite(addr *AddrSpec) *AddrSpec
|
Rewrite(request *Request) *AddrSpec
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddrSpec is used to return the target AddrSpec
|
// AddrSpec is used to return the target AddrSpec
|
||||||
|
@ -47,11 +47,6 @@ type AddrSpec struct {
|
||||||
Port int
|
Port int
|
||||||
}
|
}
|
||||||
|
|
||||||
type conn interface {
|
|
||||||
Write([]byte) (int, error)
|
|
||||||
RemoteAddr() net.Addr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *AddrSpec) String() string {
|
func (a *AddrSpec) String() string {
|
||||||
if a.FQDN != "" {
|
if a.FQDN != "" {
|
||||||
return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port)
|
return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port)
|
||||||
|
@ -59,31 +54,61 @@ func (a *AddrSpec) String() string {
|
||||||
return fmt.Sprintf("%s:%d", a.IP, a.Port)
|
return fmt.Sprintf("%s:%d", a.IP, a.Port)
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleRequest is used for request processing after authentication
|
// A Request represents request received by a server
|
||||||
func (s *Server) handleRequest(conn conn, bufConn io.Reader) error {
|
type Request struct {
|
||||||
|
// Protocol version
|
||||||
|
Version uint8
|
||||||
|
// Requested command
|
||||||
|
Command uint8
|
||||||
|
// AuthContext provided during negotiation
|
||||||
|
AuthContext *AuthContext
|
||||||
|
// AddrSpec of the the network that sent the request
|
||||||
|
RemoteAddr *AddrSpec
|
||||||
|
// AddrSpec of the desired destination
|
||||||
|
DestAddr *AddrSpec
|
||||||
|
// AddrSpec of the actual destination (might be affected by rewrite)
|
||||||
|
realDestAddr *AddrSpec
|
||||||
|
bufConn io.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
type conn interface {
|
||||||
|
Write([]byte) (int, error)
|
||||||
|
RemoteAddr() net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRequest creates a new Request from the tcp connection
|
||||||
|
func NewRequest(bufConn io.Reader) (*Request, error) {
|
||||||
// Read the version byte
|
// Read the version byte
|
||||||
header := []byte{0, 0, 0}
|
header := []byte{0, 0, 0}
|
||||||
if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil {
|
if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil {
|
||||||
return fmt.Errorf("Failed to get command version: %v", err)
|
return nil, fmt.Errorf("Failed to get command version: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure we are compatible
|
// Ensure we are compatible
|
||||||
if header[0] != socks5Version {
|
if header[0] != socks5Version {
|
||||||
return fmt.Errorf("Unsupported command version: %v", header[0])
|
return nil, fmt.Errorf("Unsupported command version: %v", header[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read in the destination address
|
// Read in the destination address
|
||||||
dest, err := readAddrSpec(bufConn)
|
dest, err := readAddrSpec(bufConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == unrecognizedAddrType {
|
return nil, err
|
||||||
if err := sendReply(conn, addrTypeNotSupported, nil); err != nil {
|
|
||||||
return fmt.Errorf("Failed to send reply: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return fmt.Errorf("Failed to read destination address: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
request := &Request{
|
||||||
|
Version: socks5Version,
|
||||||
|
Command: header[1],
|
||||||
|
DestAddr: dest,
|
||||||
|
bufConn: bufConn,
|
||||||
|
}
|
||||||
|
|
||||||
|
return request, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleRequest is used for request processing after authentication
|
||||||
|
func (s *Server) handleRequest(req *Request, conn conn) error {
|
||||||
// Resolve the address if we have a FQDN
|
// Resolve the address if we have a FQDN
|
||||||
|
dest := req.DestAddr
|
||||||
if dest.FQDN != "" {
|
if dest.FQDN != "" {
|
||||||
addr, err := s.config.Resolver.Resolve(dest.FQDN)
|
addr, err := s.config.Resolver.Resolve(dest.FQDN)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -96,40 +121,39 @@ func (s *Server) handleRequest(conn conn, bufConn io.Reader) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply any address rewrites
|
// Apply any address rewrites
|
||||||
realDest := dest
|
req.realDestAddr = req.DestAddr
|
||||||
if s.config.Rewriter != nil {
|
if s.config.Rewriter != nil {
|
||||||
realDest = s.config.Rewriter.Rewrite(dest)
|
req.realDestAddr = s.config.Rewriter.Rewrite(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Switch on the command
|
// Switch on the command
|
||||||
switch header[1] {
|
switch req.Command {
|
||||||
case connectCommand:
|
case ConnectCommand:
|
||||||
return s.handleConnect(conn, bufConn, dest, realDest)
|
return s.handleConnect(conn, req)
|
||||||
case bindCommand:
|
case BindCommand:
|
||||||
return s.handleBind(conn, bufConn, dest, realDest)
|
return s.handleBind(conn, req)
|
||||||
case associateCommand:
|
case AssociateCommand:
|
||||||
return s.handleAssociate(conn, bufConn, dest, realDest)
|
return s.handleAssociate(conn, req)
|
||||||
default:
|
default:
|
||||||
if err := sendReply(conn, commandNotSupported, nil); err != nil {
|
if err := sendReply(conn, commandNotSupported, nil); err != nil {
|
||||||
return fmt.Errorf("Failed to send reply: %v", err)
|
return fmt.Errorf("Failed to send reply: %v", err)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("Unsupported command: %v", header[1])
|
return fmt.Errorf("Unsupported command: %v", req.Command)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleConnect is used to handle a connect command
|
// handleConnect is used to handle a connect command
|
||||||
func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest, realDest *AddrSpec) error {
|
func (s *Server) handleConnect(conn conn, req *Request) error {
|
||||||
// Check if this is allowed
|
// Check if this is allowed
|
||||||
client := conn.RemoteAddr().(*net.TCPAddr)
|
if !s.config.Rules.Allow(req) {
|
||||||
if !s.config.Rules.AllowConnect(realDest.IP, realDest.Port, client.IP, client.Port) {
|
|
||||||
if err := sendReply(conn, ruleFailure, nil); err != nil {
|
if err := sendReply(conn, ruleFailure, nil); err != nil {
|
||||||
return fmt.Errorf("Failed to send reply: %v", err)
|
return fmt.Errorf("Failed to send reply: %v", err)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("Connect to %v blocked by rules", dest)
|
return fmt.Errorf("Connect to %v blocked by rules", req.DestAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Attempt to connect
|
// Attempt to connect
|
||||||
addr := net.TCPAddr{IP: realDest.IP, Port: realDest.Port}
|
addr := net.TCPAddr{IP: req.realDestAddr.IP, Port: req.realDestAddr.Port}
|
||||||
target, err := net.DialTCP("tcp", nil, &addr)
|
target, err := net.DialTCP("tcp", nil, &addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
msg := err.Error()
|
msg := err.Error()
|
||||||
|
@ -142,7 +166,7 @@ func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest, realDest *Add
|
||||||
if err := sendReply(conn, resp, nil); err != nil {
|
if err := sendReply(conn, resp, nil); err != nil {
|
||||||
return fmt.Errorf("Failed to send reply: %v", err)
|
return fmt.Errorf("Failed to send reply: %v", err)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("Connect to %v failed: %v", dest, err)
|
return fmt.Errorf("Connect to %v failed: %v", req.DestAddr, err)
|
||||||
}
|
}
|
||||||
defer target.Close()
|
defer target.Close()
|
||||||
|
|
||||||
|
@ -155,7 +179,7 @@ func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest, realDest *Add
|
||||||
|
|
||||||
// Start proxying
|
// Start proxying
|
||||||
errCh := make(chan error, 2)
|
errCh := make(chan error, 2)
|
||||||
go proxy("target", target, bufConn, errCh, s.config.Logger)
|
go proxy("target", target, req.bufConn, errCh, s.config.Logger)
|
||||||
go proxy("client", conn, target, errCh, s.config.Logger)
|
go proxy("client", conn, target, errCh, s.config.Logger)
|
||||||
|
|
||||||
// Wait
|
// Wait
|
||||||
|
@ -166,14 +190,13 @@ func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest, realDest *Add
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleBind is used to handle a connect command
|
// handleBind is used to handle a connect command
|
||||||
func (s *Server) handleBind(conn conn, bufConn io.Reader, dest, realDest *AddrSpec) error {
|
func (s *Server) handleBind(conn conn, req *Request) error {
|
||||||
// Check if this is allowed
|
// Check if this is allowed
|
||||||
client := conn.RemoteAddr().(*net.TCPAddr)
|
if !s.config.Rules.Allow(req) {
|
||||||
if !s.config.Rules.AllowBind(realDest.IP, realDest.Port, client.IP, client.Port) {
|
|
||||||
if err := sendReply(conn, ruleFailure, nil); err != nil {
|
if err := sendReply(conn, ruleFailure, nil); err != nil {
|
||||||
return fmt.Errorf("Failed to send reply: %v", err)
|
return fmt.Errorf("Failed to send reply: %v", err)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("Bind to %v blocked by rules", dest)
|
return fmt.Errorf("Bind to %v blocked by rules", req.DestAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Support bind
|
// TODO: Support bind
|
||||||
|
@ -184,14 +207,13 @@ func (s *Server) handleBind(conn conn, bufConn io.Reader, dest, realDest *AddrSp
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleAssociate is used to handle a connect command
|
// handleAssociate is used to handle a connect command
|
||||||
func (s *Server) handleAssociate(conn conn, bufConn io.Reader, dest, realDest *AddrSpec) error {
|
func (s *Server) handleAssociate(conn conn, req *Request) error {
|
||||||
// Check if this is allowed
|
// Check if this is allowed
|
||||||
client := conn.RemoteAddr().(*net.TCPAddr)
|
if !s.config.Rules.Allow(req) {
|
||||||
if !s.config.Rules.AllowAssociate(realDest.IP, realDest.Port, client.IP, client.Port) {
|
|
||||||
if err := sendReply(conn, ruleFailure, nil); err != nil {
|
if err := sendReply(conn, ruleFailure, nil); err != nil {
|
||||||
return fmt.Errorf("Failed to send reply: %v", err)
|
return fmt.Errorf("Failed to send reply: %v", err)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("Associate to %v blocked by rules", dest)
|
return fmt.Errorf("Associate to %v blocked by rules", req.DestAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Support associate
|
// TODO: Support associate
|
||||||
|
|
|
@ -56,19 +56,24 @@ func TestRequest_Connect(t *testing.T) {
|
||||||
}}
|
}}
|
||||||
|
|
||||||
// Create the connect request
|
// Create the connect request
|
||||||
req := bytes.NewBuffer(nil)
|
buf := bytes.NewBuffer(nil)
|
||||||
req.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1})
|
buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1})
|
||||||
|
|
||||||
port := []byte{0, 0}
|
port := []byte{0, 0}
|
||||||
binary.BigEndian.PutUint16(port, uint16(lAddr.Port))
|
binary.BigEndian.PutUint16(port, uint16(lAddr.Port))
|
||||||
req.Write(port)
|
buf.Write(port)
|
||||||
|
|
||||||
// Send a ping
|
// Send a ping
|
||||||
req.Write([]byte("ping"))
|
buf.Write([]byte("ping"))
|
||||||
|
|
||||||
// Handle the request
|
// Handle the request
|
||||||
resp := &MockConn{}
|
resp := &MockConn{}
|
||||||
if err := s.handleRequest(resp, req); err != nil {
|
req, err := NewRequest(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.handleRequest(req, resp); err != nil {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -126,19 +131,24 @@ func TestRequest_Connect_RuleFail(t *testing.T) {
|
||||||
}}
|
}}
|
||||||
|
|
||||||
// Create the connect request
|
// Create the connect request
|
||||||
req := bytes.NewBuffer(nil)
|
buf := bytes.NewBuffer(nil)
|
||||||
req.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1})
|
buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1})
|
||||||
|
|
||||||
port := []byte{0, 0}
|
port := []byte{0, 0}
|
||||||
binary.BigEndian.PutUint16(port, uint16(lAddr.Port))
|
binary.BigEndian.PutUint16(port, uint16(lAddr.Port))
|
||||||
req.Write(port)
|
buf.Write(port)
|
||||||
|
|
||||||
// Send a ping
|
// Send a ping
|
||||||
req.Write([]byte("ping"))
|
buf.Write([]byte("ping"))
|
||||||
|
|
||||||
// Handle the request
|
// Handle the request
|
||||||
resp := &MockConn{}
|
resp := &MockConn{}
|
||||||
if err := s.handleRequest(resp, req); !strings.Contains(err.Error(), "blocked by rules") {
|
req, err := NewRequest(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.handleRequest(req, resp); !strings.Contains(err.Error(), "blocked by rules") {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
32
ruleset.go
32
ruleset.go
|
@ -1,19 +1,8 @@
|
||||||
package socks5
|
package socks5
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
)
|
|
||||||
|
|
||||||
// RuleSet is used to provide custom rules to allow or prohibit actions
|
// RuleSet is used to provide custom rules to allow or prohibit actions
|
||||||
type RuleSet interface {
|
type RuleSet interface {
|
||||||
// AllowConnect is used to filter connect requests
|
Allow(req *Request) bool
|
||||||
AllowConnect(dstIP net.IP, dstPort int, srcIP net.IP, srcPort int) bool
|
|
||||||
|
|
||||||
// AllowBind is used to filter bind requests
|
|
||||||
AllowBind(dstIP net.IP, dstPort int, srcIP net.IP, srcPort int) bool
|
|
||||||
|
|
||||||
// AllowAssociate is used to filter associate requests
|
|
||||||
AllowAssociate(dstIP net.IP, dstPort int, srcIP net.IP, srcPort int) bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// PermitAll returns a RuleSet which allows all types of connections
|
// PermitAll returns a RuleSet which allows all types of connections
|
||||||
|
@ -34,14 +23,15 @@ type PermitCommand struct {
|
||||||
EnableAssociate bool
|
EnableAssociate bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PermitCommand) AllowConnect(net.IP, int, net.IP, int) bool {
|
func (p *PermitCommand) Allow(req *Request) bool {
|
||||||
return p.EnableConnect
|
switch req.Command {
|
||||||
}
|
case ConnectCommand:
|
||||||
|
return p.EnableConnect
|
||||||
|
case BindCommand:
|
||||||
|
return p.EnableBind
|
||||||
|
case AssociateCommand:
|
||||||
|
return p.EnableAssociate
|
||||||
|
}
|
||||||
|
|
||||||
func (p *PermitCommand) AllowBind(net.IP, int, net.IP, int) bool {
|
return false
|
||||||
return p.EnableBind
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *PermitCommand) AllowAssociate(net.IP, int, net.IP, int) bool {
|
|
||||||
return p.EnableAssociate
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,21 +1,19 @@
|
||||||
package socks5
|
package socks5
|
||||||
|
|
||||||
import (
|
import "testing"
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestPermitCommand(t *testing.T) {
|
func TestPermitCommand(t *testing.T) {
|
||||||
r := &PermitCommand{true, false, false}
|
r := &PermitCommand{true, false, false}
|
||||||
|
|
||||||
if !r.AllowConnect(nil, 500, nil, 1000) {
|
if !r.Allow(&Request{Command: ConnectCommand}) {
|
||||||
t.Fatalf("expect connect")
|
t.Fatalf("expect connect")
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.AllowBind(nil, 500, nil, 1000) {
|
if r.Allow(&Request{Command: BindCommand}) {
|
||||||
t.Fatalf("do not expect bind")
|
t.Fatalf("do not expect bind")
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.AllowAssociate(nil, 500, nil, 1000) {
|
if r.Allow(&Request{Command: AssociateCommand}) {
|
||||||
t.Fatalf("do not expect associate")
|
t.Fatalf("do not expect associate")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
19
socks5.go
19
socks5.go
|
@ -132,14 +132,29 @@ func (s *Server) ServeConn(conn net.Conn) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Authenticate the connection
|
// Authenticate the connection
|
||||||
if err := s.authenticate(conn, bufConn); err != nil {
|
authContext, err := s.authenticate(conn, bufConn)
|
||||||
|
if err != nil {
|
||||||
err = fmt.Errorf("Failed to authenticate: %v", err)
|
err = fmt.Errorf("Failed to authenticate: %v", err)
|
||||||
s.config.Logger.Printf("[ERR] socks: %v", err)
|
s.config.Logger.Printf("[ERR] socks: %v", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
request, err := NewRequest(bufConn)
|
||||||
|
if err != nil {
|
||||||
|
if err == unrecognizedAddrType {
|
||||||
|
if err := sendReply(conn, addrTypeNotSupported, nil); err != nil {
|
||||||
|
return fmt.Errorf("Failed to send reply: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fmt.Errorf("Failed to read destination address: %v", err)
|
||||||
|
}
|
||||||
|
request.AuthContext = authContext
|
||||||
|
if client, ok := conn.RemoteAddr().(*net.TCPAddr); ok {
|
||||||
|
request.RemoteAddr = &AddrSpec{IP: client.IP, Port: client.Port}
|
||||||
|
}
|
||||||
|
|
||||||
// Process the client request
|
// Process the client request
|
||||||
if err := s.handleRequest(conn, bufConn); err != nil {
|
if err := s.handleRequest(request, conn); err != nil {
|
||||||
err = fmt.Errorf("Failed to handle request: %v", err)
|
err = fmt.Errorf("Failed to handle request: %v", err)
|
||||||
s.config.Logger.Printf("[ERR] socks: %v", err)
|
s.config.Logger.Printf("[ERR] socks: %v", err)
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -67,7 +67,7 @@ func TestSOCKS5_Connect(t *testing.T) {
|
||||||
// Connect, auth and connec to local
|
// Connect, auth and connec to local
|
||||||
req := bytes.NewBuffer(nil)
|
req := bytes.NewBuffer(nil)
|
||||||
req.Write([]byte{5})
|
req.Write([]byte{5})
|
||||||
req.Write([]byte{2, noAuth, userPassAuth})
|
req.Write([]byte{2, NoAuth, UserPassAuth})
|
||||||
req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'})
|
req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'})
|
||||||
req.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1})
|
req.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1})
|
||||||
|
|
||||||
|
@ -83,7 +83,7 @@ func TestSOCKS5_Connect(t *testing.T) {
|
||||||
|
|
||||||
// Verify response
|
// Verify response
|
||||||
expected := []byte{
|
expected := []byte{
|
||||||
socks5Version, userPassAuth,
|
socks5Version, UserPassAuth,
|
||||||
1, authSuccess,
|
1, authSuccess,
|
||||||
5,
|
5,
|
||||||
0,
|
0,
|
||||||
|
|
Loading…
Reference in New Issue