1
0
Fork 0

Merge pull request #6 from ap4y/request_struct

Auth based rewrites and rulesets
logger
Armon Dadgar 2016-01-18 14:18:52 -08:00
commit b7e3cc6811
8 changed files with 201 additions and 129 deletions

64
auth.go
View File

@ -6,9 +6,9 @@ import (
) )
const ( const (
noAuth = uint8(0) NoAuth = uint8(0)
noAcceptable = uint8(255) noAcceptable = uint8(255)
userPassAuth = uint8(2) UserPassAuth = uint8(2)
userAuthVersion = uint8(1) userAuthVersion = uint8(1)
authSuccess = uint8(0) authSuccess = uint8(0)
authFailure = uint8(1) authFailure = uint8(1)
@ -19,21 +19,32 @@ var (
NoSupportedAuth = fmt.Errorf("No supported authentication mechanism") NoSupportedAuth = fmt.Errorf("No supported authentication mechanism")
) )
// A Request encapsulates authentication state provided
// during negotiation
type AuthContext struct {
// Provided auth method
Method uint8
// Payload provided during negotiation.
// Keys depend on the used auth method.
// For UserPassauth contains Username
Payload map[string]string
}
type Authenticator interface { type Authenticator interface {
Authenticate(reader io.Reader, writer io.Writer) error Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error)
GetCode() uint8 GetCode() uint8
} }
// NoAuthAuthenticator is used to handle the "No Authentication" mode // NoAuthAuthenticator is used to handle the "No Authentication" mode
type NoAuthAuthenticator struct {} type NoAuthAuthenticator struct{}
func (a NoAuthAuthenticator) GetCode() uint8 { func (a NoAuthAuthenticator) GetCode() uint8 {
return noAuth return NoAuth
} }
func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer) error { func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) {
_, err := writer.Write([]byte{socks5Version, noAuth}) _, err := writer.Write([]byte{socks5Version, NoAuth})
return err return &AuthContext{NoAuth, nil}, err
} }
// UserPassAuthenticator is used to handle username/password based // UserPassAuthenticator is used to handle username/password based
@ -43,70 +54,67 @@ type UserPassAuthenticator struct {
} }
func (a UserPassAuthenticator) GetCode() uint8 { func (a UserPassAuthenticator) GetCode() uint8 {
return userPassAuth return UserPassAuth
} }
func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer) error { func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) {
// Tell the client to use user/pass auth // Tell the client to use user/pass auth
if _, err := writer.Write([]byte{socks5Version, userPassAuth}); err != nil { if _, err := writer.Write([]byte{socks5Version, UserPassAuth}); err != nil {
return err return nil, err
} }
// Get the version and username length // Get the version and username length
header := []byte{0, 0} header := []byte{0, 0}
if _, err := io.ReadAtLeast(reader, header, 2); err != nil { if _, err := io.ReadAtLeast(reader, header, 2); err != nil {
return err return nil, err
} }
// Ensure we are compatible // Ensure we are compatible
if header[0] != userAuthVersion { if header[0] != userAuthVersion {
return fmt.Errorf("Unsupported auth version: %v", header[0]) return nil, fmt.Errorf("Unsupported auth version: %v", header[0])
} }
// Get the user name // Get the user name
userLen := int(header[1]) userLen := int(header[1])
user := make([]byte, userLen) user := make([]byte, userLen)
if _, err := io.ReadAtLeast(reader, user, userLen); err != nil { if _, err := io.ReadAtLeast(reader, user, userLen); err != nil {
return err return nil, err
} }
// Get the password length // Get the password length
if _, err := reader.Read(header[:1]); err != nil { if _, err := reader.Read(header[:1]); err != nil {
return err return nil, err
} }
// Get the password // Get the password
passLen := int(header[0]) passLen := int(header[0])
pass := make([]byte, passLen) pass := make([]byte, passLen)
if _, err := io.ReadAtLeast(reader, pass, passLen); err != nil { if _, err := io.ReadAtLeast(reader, pass, passLen); err != nil {
return err return nil, err
} }
// Verify the password // Verify the password
if a.Credentials.Valid(string(user), string(pass)) { if a.Credentials.Valid(string(user), string(pass)) {
if _, err := writer.Write([]byte{userAuthVersion, authSuccess}); err != nil { if _, err := writer.Write([]byte{userAuthVersion, authSuccess}); err != nil {
return err return nil, err
} }
} else { } else {
if _, err := writer.Write([]byte{userAuthVersion, authFailure}); err != nil { if _, err := writer.Write([]byte{userAuthVersion, authFailure}); err != nil {
return err return nil, err
} }
return UserAuthFailed return nil, UserAuthFailed
} }
// Done // Done
return nil return &AuthContext{UserPassAuth, map[string]string{"Username": string(user)}}, nil
} }
// authenticate is used to handle connection authentication // authenticate is used to handle connection authentication
func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) error { func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) (*AuthContext, error) {
// Get the methods // Get the methods
methods, err := readMethods(bufConn) methods, err := readMethods(bufConn)
if err != nil { if err != nil {
return fmt.Errorf("Failed to get auth methods: %v", err) return nil, fmt.Errorf("Failed to get auth methods: %v", err)
} }
// Select a usable method // Select a usable method
@ -118,11 +126,9 @@ func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) error {
} }
// No usable method found // No usable method found
return noAcceptableAuth(conn) return nil, noAcceptableAuth(conn)
} }
// noAcceptableAuth is used to handle when we have no eligible // noAcceptableAuth is used to handle when we have no eligible
// authentication mechanism // authentication mechanism
func noAcceptableAuth(conn io.Writer) error { func noAcceptableAuth(conn io.Writer) error {

View File

@ -7,47 +7,66 @@ import (
func TestNoAuth(t *testing.T) { func TestNoAuth(t *testing.T) {
req := bytes.NewBuffer(nil) req := bytes.NewBuffer(nil)
req.Write([]byte{1, noAuth}) req.Write([]byte{1, NoAuth})
var resp bytes.Buffer var resp bytes.Buffer
s, _ := New(&Config{}) s, _ := New(&Config{})
if err := s.authenticate(&resp, req); err != nil { ctx, err := s.authenticate(&resp, req)
if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
if ctx.Method != NoAuth {
t.Fatal("Invalid Context Method")
}
out := resp.Bytes() out := resp.Bytes()
if !bytes.Equal(out, []byte{socks5Version, noAuth}) { if !bytes.Equal(out, []byte{socks5Version, NoAuth}) {
t.Fatalf("bad: %v", out) t.Fatalf("bad: %v", out)
} }
} }
func TestPasswordAuth_Valid(t *testing.T) { func TestPasswordAuth_Valid(t *testing.T) {
req := bytes.NewBuffer(nil) req := bytes.NewBuffer(nil)
req.Write([]byte{2, noAuth, userPassAuth}) req.Write([]byte{2, NoAuth, UserPassAuth})
req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'}) req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'})
var resp bytes.Buffer var resp bytes.Buffer
cred := StaticCredentials{ cred := StaticCredentials{
"foo": "bar", "foo": "bar",
} }
cator := UserPassAuthenticator{Credentials: cred} cator := UserPassAuthenticator{Credentials: cred}
s, _ := New(&Config{AuthMethods:[]Authenticator{cator}}) s, _ := New(&Config{AuthMethods: []Authenticator{cator}})
if err := s.authenticate(&resp, req); err != nil { ctx, err := s.authenticate(&resp, req)
if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
if ctx.Method != UserPassAuth {
t.Fatal("Invalid Context Method")
}
val, ok := ctx.Payload["Username"]
if !ok {
t.Fatal("Missing key Username in auth context's payload")
}
if val != "foo" {
t.Fatal("Invalid Username in auth context's payload")
}
out := resp.Bytes() out := resp.Bytes()
if !bytes.Equal(out, []byte{socks5Version, userPassAuth, 1, authSuccess}) { if !bytes.Equal(out, []byte{socks5Version, UserPassAuth, 1, authSuccess}) {
t.Fatalf("bad: %v", out) t.Fatalf("bad: %v", out)
} }
} }
func TestPasswordAuth_Invalid(t *testing.T) { func TestPasswordAuth_Invalid(t *testing.T) {
req := bytes.NewBuffer(nil) req := bytes.NewBuffer(nil)
req.Write([]byte{2, noAuth, userPassAuth}) req.Write([]byte{2, NoAuth, UserPassAuth})
req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'z'}) req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'z'})
var resp bytes.Buffer var resp bytes.Buffer
@ -55,20 +74,26 @@ func TestPasswordAuth_Invalid(t *testing.T) {
"foo": "bar", "foo": "bar",
} }
cator := UserPassAuthenticator{Credentials: cred} cator := UserPassAuthenticator{Credentials: cred}
s, _ := New(&Config{AuthMethods:[]Authenticator{cator}}) s, _ := New(&Config{AuthMethods: []Authenticator{cator}})
if err := s.authenticate(&resp, req); err != UserAuthFailed {
ctx, err := s.authenticate(&resp, req)
if err != UserAuthFailed {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
if ctx != nil {
t.Fatal("Invalid Context Method")
}
out := resp.Bytes() out := resp.Bytes()
if !bytes.Equal(out, []byte{socks5Version, userPassAuth, 1, authFailure}) { if !bytes.Equal(out, []byte{socks5Version, UserPassAuth, 1, authFailure}) {
t.Fatalf("bad: %v", out) t.Fatalf("bad: %v", out)
} }
} }
func TestNoSupportedAuth(t *testing.T) { func TestNoSupportedAuth(t *testing.T) {
req := bytes.NewBuffer(nil) req := bytes.NewBuffer(nil)
req.Write([]byte{1, noAuth}) req.Write([]byte{1, NoAuth})
var resp bytes.Buffer var resp bytes.Buffer
cred := StaticCredentials{ cred := StaticCredentials{
@ -76,11 +101,17 @@ func TestNoSupportedAuth(t *testing.T) {
} }
cator := UserPassAuthenticator{Credentials: cred} cator := UserPassAuthenticator{Credentials: cred}
s, _ := New(&Config{AuthMethods:[]Authenticator{cator}}) s, _ := New(&Config{AuthMethods: []Authenticator{cator}})
if err := s.authenticate(&resp, req); err != NoSupportedAuth {
ctx, err := s.authenticate(&resp, req)
if err != NoSupportedAuth {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
if ctx != nil {
t.Fatal("Invalid Context Method")
}
out := resp.Bytes() out := resp.Bytes()
if !bytes.Equal(out, []byte{socks5Version, noAcceptable}) { if !bytes.Equal(out, []byte{socks5Version, noAcceptable}) {
t.Fatalf("bad: %v", out) t.Fatalf("bad: %v", out)

View File

@ -10,9 +10,9 @@ import (
) )
const ( const (
connectCommand = uint8(1) ConnectCommand = uint8(1)
bindCommand = uint8(2) BindCommand = uint8(2)
associateCommand = uint8(3) AssociateCommand = uint8(3)
ipv4Address = uint8(1) ipv4Address = uint8(1)
fqdnAddress = uint8(3) fqdnAddress = uint8(3)
ipv6Address = uint8(4) ipv6Address = uint8(4)
@ -36,7 +36,7 @@ var (
// AddressRewriter is used to rewrite a destination transparently // AddressRewriter is used to rewrite a destination transparently
type AddressRewriter interface { type AddressRewriter interface {
Rewrite(addr *AddrSpec) *AddrSpec Rewrite(request *Request) *AddrSpec
} }
// AddrSpec is used to return the target AddrSpec // AddrSpec is used to return the target AddrSpec
@ -47,11 +47,6 @@ type AddrSpec struct {
Port int Port int
} }
type conn interface {
Write([]byte) (int, error)
RemoteAddr() net.Addr
}
func (a *AddrSpec) String() string { func (a *AddrSpec) String() string {
if a.FQDN != "" { if a.FQDN != "" {
return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port) return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port)
@ -59,31 +54,61 @@ func (a *AddrSpec) String() string {
return fmt.Sprintf("%s:%d", a.IP, a.Port) return fmt.Sprintf("%s:%d", a.IP, a.Port)
} }
// handleRequest is used for request processing after authentication // A Request represents request received by a server
func (s *Server) handleRequest(conn conn, bufConn io.Reader) error { type Request struct {
// Protocol version
Version uint8
// Requested command
Command uint8
// AuthContext provided during negotiation
AuthContext *AuthContext
// AddrSpec of the the network that sent the request
RemoteAddr *AddrSpec
// AddrSpec of the desired destination
DestAddr *AddrSpec
// AddrSpec of the actual destination (might be affected by rewrite)
realDestAddr *AddrSpec
bufConn io.Reader
}
type conn interface {
Write([]byte) (int, error)
RemoteAddr() net.Addr
}
// NewRequest creates a new Request from the tcp connection
func NewRequest(bufConn io.Reader) (*Request, error) {
// Read the version byte // Read the version byte
header := []byte{0, 0, 0} header := []byte{0, 0, 0}
if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil { if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil {
return fmt.Errorf("Failed to get command version: %v", err) return nil, fmt.Errorf("Failed to get command version: %v", err)
} }
// Ensure we are compatible // Ensure we are compatible
if header[0] != socks5Version { if header[0] != socks5Version {
return fmt.Errorf("Unsupported command version: %v", header[0]) return nil, fmt.Errorf("Unsupported command version: %v", header[0])
} }
// Read in the destination address // Read in the destination address
dest, err := readAddrSpec(bufConn) dest, err := readAddrSpec(bufConn)
if err != nil { if err != nil {
if err == unrecognizedAddrType { return nil, err
if err := sendReply(conn, addrTypeNotSupported, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
}
}
return fmt.Errorf("Failed to read destination address: %v", err)
} }
request := &Request{
Version: socks5Version,
Command: header[1],
DestAddr: dest,
bufConn: bufConn,
}
return request, nil
}
// handleRequest is used for request processing after authentication
func (s *Server) handleRequest(req *Request, conn conn) error {
// Resolve the address if we have a FQDN // Resolve the address if we have a FQDN
dest := req.DestAddr
if dest.FQDN != "" { if dest.FQDN != "" {
addr, err := s.config.Resolver.Resolve(dest.FQDN) addr, err := s.config.Resolver.Resolve(dest.FQDN)
if err != nil { if err != nil {
@ -96,40 +121,39 @@ func (s *Server) handleRequest(conn conn, bufConn io.Reader) error {
} }
// Apply any address rewrites // Apply any address rewrites
realDest := dest req.realDestAddr = req.DestAddr
if s.config.Rewriter != nil { if s.config.Rewriter != nil {
realDest = s.config.Rewriter.Rewrite(dest) req.realDestAddr = s.config.Rewriter.Rewrite(req)
} }
// Switch on the command // Switch on the command
switch header[1] { switch req.Command {
case connectCommand: case ConnectCommand:
return s.handleConnect(conn, bufConn, dest, realDest) return s.handleConnect(conn, req)
case bindCommand: case BindCommand:
return s.handleBind(conn, bufConn, dest, realDest) return s.handleBind(conn, req)
case associateCommand: case AssociateCommand:
return s.handleAssociate(conn, bufConn, dest, realDest) return s.handleAssociate(conn, req)
default: default:
if err := sendReply(conn, commandNotSupported, nil); err != nil { if err := sendReply(conn, commandNotSupported, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err) return fmt.Errorf("Failed to send reply: %v", err)
} }
return fmt.Errorf("Unsupported command: %v", header[1]) return fmt.Errorf("Unsupported command: %v", req.Command)
} }
} }
// handleConnect is used to handle a connect command // handleConnect is used to handle a connect command
func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest, realDest *AddrSpec) error { func (s *Server) handleConnect(conn conn, req *Request) error {
// Check if this is allowed // Check if this is allowed
client := conn.RemoteAddr().(*net.TCPAddr) if !s.config.Rules.Allow(req) {
if !s.config.Rules.AllowConnect(realDest.IP, realDest.Port, client.IP, client.Port) {
if err := sendReply(conn, ruleFailure, nil); err != nil { if err := sendReply(conn, ruleFailure, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err) return fmt.Errorf("Failed to send reply: %v", err)
} }
return fmt.Errorf("Connect to %v blocked by rules", dest) return fmt.Errorf("Connect to %v blocked by rules", req.DestAddr)
} }
// Attempt to connect // Attempt to connect
addr := net.TCPAddr{IP: realDest.IP, Port: realDest.Port} addr := net.TCPAddr{IP: req.realDestAddr.IP, Port: req.realDestAddr.Port}
target, err := net.DialTCP("tcp", nil, &addr) target, err := net.DialTCP("tcp", nil, &addr)
if err != nil { if err != nil {
msg := err.Error() msg := err.Error()
@ -142,7 +166,7 @@ func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest, realDest *Add
if err := sendReply(conn, resp, nil); err != nil { if err := sendReply(conn, resp, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err) return fmt.Errorf("Failed to send reply: %v", err)
} }
return fmt.Errorf("Connect to %v failed: %v", dest, err) return fmt.Errorf("Connect to %v failed: %v", req.DestAddr, err)
} }
defer target.Close() defer target.Close()
@ -155,7 +179,7 @@ func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest, realDest *Add
// Start proxying // Start proxying
errCh := make(chan error, 2) errCh := make(chan error, 2)
go proxy("target", target, bufConn, errCh, s.config.Logger) go proxy("target", target, req.bufConn, errCh, s.config.Logger)
go proxy("client", conn, target, errCh, s.config.Logger) go proxy("client", conn, target, errCh, s.config.Logger)
// Wait // Wait
@ -166,14 +190,13 @@ func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest, realDest *Add
} }
// handleBind is used to handle a connect command // handleBind is used to handle a connect command
func (s *Server) handleBind(conn conn, bufConn io.Reader, dest, realDest *AddrSpec) error { func (s *Server) handleBind(conn conn, req *Request) error {
// Check if this is allowed // Check if this is allowed
client := conn.RemoteAddr().(*net.TCPAddr) if !s.config.Rules.Allow(req) {
if !s.config.Rules.AllowBind(realDest.IP, realDest.Port, client.IP, client.Port) {
if err := sendReply(conn, ruleFailure, nil); err != nil { if err := sendReply(conn, ruleFailure, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err) return fmt.Errorf("Failed to send reply: %v", err)
} }
return fmt.Errorf("Bind to %v blocked by rules", dest) return fmt.Errorf("Bind to %v blocked by rules", req.DestAddr)
} }
// TODO: Support bind // TODO: Support bind
@ -184,14 +207,13 @@ func (s *Server) handleBind(conn conn, bufConn io.Reader, dest, realDest *AddrSp
} }
// handleAssociate is used to handle a connect command // handleAssociate is used to handle a connect command
func (s *Server) handleAssociate(conn conn, bufConn io.Reader, dest, realDest *AddrSpec) error { func (s *Server) handleAssociate(conn conn, req *Request) error {
// Check if this is allowed // Check if this is allowed
client := conn.RemoteAddr().(*net.TCPAddr) if !s.config.Rules.Allow(req) {
if !s.config.Rules.AllowAssociate(realDest.IP, realDest.Port, client.IP, client.Port) {
if err := sendReply(conn, ruleFailure, nil); err != nil { if err := sendReply(conn, ruleFailure, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err) return fmt.Errorf("Failed to send reply: %v", err)
} }
return fmt.Errorf("Associate to %v blocked by rules", dest) return fmt.Errorf("Associate to %v blocked by rules", req.DestAddr)
} }
// TODO: Support associate // TODO: Support associate

View File

@ -56,19 +56,24 @@ func TestRequest_Connect(t *testing.T) {
}} }}
// Create the connect request // Create the connect request
req := bytes.NewBuffer(nil) buf := bytes.NewBuffer(nil)
req.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1}) buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1})
port := []byte{0, 0} port := []byte{0, 0}
binary.BigEndian.PutUint16(port, uint16(lAddr.Port)) binary.BigEndian.PutUint16(port, uint16(lAddr.Port))
req.Write(port) buf.Write(port)
// Send a ping // Send a ping
req.Write([]byte("ping")) buf.Write([]byte("ping"))
// Handle the request // Handle the request
resp := &MockConn{} resp := &MockConn{}
if err := s.handleRequest(resp, req); err != nil { req, err := NewRequest(buf)
if err != nil {
t.Fatalf("err: %v", err)
}
if err := s.handleRequest(req, resp); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -126,19 +131,24 @@ func TestRequest_Connect_RuleFail(t *testing.T) {
}} }}
// Create the connect request // Create the connect request
req := bytes.NewBuffer(nil) buf := bytes.NewBuffer(nil)
req.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1}) buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1})
port := []byte{0, 0} port := []byte{0, 0}
binary.BigEndian.PutUint16(port, uint16(lAddr.Port)) binary.BigEndian.PutUint16(port, uint16(lAddr.Port))
req.Write(port) buf.Write(port)
// Send a ping // Send a ping
req.Write([]byte("ping")) buf.Write([]byte("ping"))
// Handle the request // Handle the request
resp := &MockConn{} resp := &MockConn{}
if err := s.handleRequest(resp, req); !strings.Contains(err.Error(), "blocked by rules") { req, err := NewRequest(buf)
if err != nil {
t.Fatalf("err: %v", err)
}
if err := s.handleRequest(req, resp); !strings.Contains(err.Error(), "blocked by rules") {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }

View File

@ -1,19 +1,8 @@
package socks5 package socks5
import (
"net"
)
// RuleSet is used to provide custom rules to allow or prohibit actions // RuleSet is used to provide custom rules to allow or prohibit actions
type RuleSet interface { type RuleSet interface {
// AllowConnect is used to filter connect requests Allow(req *Request) bool
AllowConnect(dstIP net.IP, dstPort int, srcIP net.IP, srcPort int) bool
// AllowBind is used to filter bind requests
AllowBind(dstIP net.IP, dstPort int, srcIP net.IP, srcPort int) bool
// AllowAssociate is used to filter associate requests
AllowAssociate(dstIP net.IP, dstPort int, srcIP net.IP, srcPort int) bool
} }
// PermitAll returns a RuleSet which allows all types of connections // PermitAll returns a RuleSet which allows all types of connections
@ -34,14 +23,15 @@ type PermitCommand struct {
EnableAssociate bool EnableAssociate bool
} }
func (p *PermitCommand) AllowConnect(net.IP, int, net.IP, int) bool { func (p *PermitCommand) Allow(req *Request) bool {
return p.EnableConnect switch req.Command {
} case ConnectCommand:
return p.EnableConnect
case BindCommand:
return p.EnableBind
case AssociateCommand:
return p.EnableAssociate
}
func (p *PermitCommand) AllowBind(net.IP, int, net.IP, int) bool { return false
return p.EnableBind
}
func (p *PermitCommand) AllowAssociate(net.IP, int, net.IP, int) bool {
return p.EnableAssociate
} }

View File

@ -1,21 +1,19 @@
package socks5 package socks5
import ( import "testing"
"testing"
)
func TestPermitCommand(t *testing.T) { func TestPermitCommand(t *testing.T) {
r := &PermitCommand{true, false, false} r := &PermitCommand{true, false, false}
if !r.AllowConnect(nil, 500, nil, 1000) { if !r.Allow(&Request{Command: ConnectCommand}) {
t.Fatalf("expect connect") t.Fatalf("expect connect")
} }
if r.AllowBind(nil, 500, nil, 1000) { if r.Allow(&Request{Command: BindCommand}) {
t.Fatalf("do not expect bind") t.Fatalf("do not expect bind")
} }
if r.AllowAssociate(nil, 500, nil, 1000) { if r.Allow(&Request{Command: AssociateCommand}) {
t.Fatalf("do not expect associate") t.Fatalf("do not expect associate")
} }
} }

View File

@ -132,14 +132,29 @@ func (s *Server) ServeConn(conn net.Conn) error {
} }
// Authenticate the connection // Authenticate the connection
if err := s.authenticate(conn, bufConn); err != nil { authContext, err := s.authenticate(conn, bufConn)
if err != nil {
err = fmt.Errorf("Failed to authenticate: %v", err) err = fmt.Errorf("Failed to authenticate: %v", err)
s.config.Logger.Printf("[ERR] socks: %v", err) s.config.Logger.Printf("[ERR] socks: %v", err)
return err return err
} }
request, err := NewRequest(bufConn)
if err != nil {
if err == unrecognizedAddrType {
if err := sendReply(conn, addrTypeNotSupported, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
}
}
return fmt.Errorf("Failed to read destination address: %v", err)
}
request.AuthContext = authContext
if client, ok := conn.RemoteAddr().(*net.TCPAddr); ok {
request.RemoteAddr = &AddrSpec{IP: client.IP, Port: client.Port}
}
// Process the client request // Process the client request
if err := s.handleRequest(conn, bufConn); err != nil { if err := s.handleRequest(request, conn); err != nil {
err = fmt.Errorf("Failed to handle request: %v", err) err = fmt.Errorf("Failed to handle request: %v", err)
s.config.Logger.Printf("[ERR] socks: %v", err) s.config.Logger.Printf("[ERR] socks: %v", err)
return err return err

View File

@ -67,7 +67,7 @@ func TestSOCKS5_Connect(t *testing.T) {
// Connect, auth and connec to local // Connect, auth and connec to local
req := bytes.NewBuffer(nil) req := bytes.NewBuffer(nil)
req.Write([]byte{5}) req.Write([]byte{5})
req.Write([]byte{2, noAuth, userPassAuth}) req.Write([]byte{2, NoAuth, UserPassAuth})
req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'}) req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'})
req.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1}) req.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1})
@ -83,7 +83,7 @@ func TestSOCKS5_Connect(t *testing.T) {
// Verify response // Verify response
expected := []byte{ expected := []byte{
socks5Version, userPassAuth, socks5Version, UserPassAuth,
1, authSuccess, 1, authSuccess,
5, 5,
0, 0,