diff --git a/request.go b/request.go index 7e713f5..e3eebff 100644 --- a/request.go +++ b/request.go @@ -5,6 +5,8 @@ import ( "io" "net" "strings" + + "golang.org/x/net/context" ) const ( @@ -34,7 +36,7 @@ var ( // AddressRewriter is used to rewrite a destination transparently type AddressRewriter interface { - Rewrite(request *Request) *AddrSpec + Rewrite(ctx context.Context, request *Request) (context.Context, *AddrSpec) } // AddrSpec is used to return the target AddrSpec @@ -105,33 +107,36 @@ func NewRequest(bufConn io.Reader) (*Request, error) { // handleRequest is used for request processing after authentication func (s *Server) handleRequest(req *Request, conn conn) error { + ctx := context.Background() + // Resolve the address if we have a FQDN dest := req.DestAddr if dest.FQDN != "" { - addr, err := s.config.Resolver.Resolve(dest.FQDN) + ctx_, addr, err := s.config.Resolver.Resolve(ctx, dest.FQDN) if err != nil { if err := sendReply(conn, hostUnreachable, nil); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } return fmt.Errorf("Failed to resolve destination '%v': %v", dest.FQDN, err) } + ctx = ctx_ dest.IP = addr } // Apply any address rewrites req.realDestAddr = req.DestAddr if s.config.Rewriter != nil { - req.realDestAddr = s.config.Rewriter.Rewrite(req) + ctx, req.realDestAddr = s.config.Rewriter.Rewrite(ctx, req) } // Switch on the command switch req.Command { case ConnectCommand: - return s.handleConnect(conn, req) + return s.handleConnect(ctx, conn, req) case BindCommand: - return s.handleBind(conn, req) + return s.handleBind(ctx, conn, req) case AssociateCommand: - return s.handleAssociate(conn, req) + return s.handleAssociate(ctx, conn, req) default: if err := sendReply(conn, commandNotSupported, nil); err != nil { return fmt.Errorf("Failed to send reply: %v", err) @@ -141,22 +146,26 @@ func (s *Server) handleRequest(req *Request, conn conn) error { } // handleConnect is used to handle a connect command -func (s *Server) handleConnect(conn conn, req *Request) error { +func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) error { // Check if this is allowed - if !s.config.Rules.Allow(req) { + if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { if err := sendReply(conn, ruleFailure, nil); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } return fmt.Errorf("Connect to %v blocked by rules", req.DestAddr) + } else { + ctx = ctx_ } // Attempt to connect addr := (&net.TCPAddr{IP: req.realDestAddr.IP, Port: req.realDestAddr.Port}).String() dial := s.config.Dial if dial == nil { - dial = net.Dial + dial = func(ctx context.Context, net_, addr string) (net.Conn, error) { + return net.Dial(net_, addr) + } } - target, err := dial("tcp", addr) + target, err := dial(ctx, "tcp", addr) if err != nil { msg := err.Error() resp := hostUnreachable @@ -196,13 +205,15 @@ func (s *Server) handleConnect(conn conn, req *Request) error { } // handleBind is used to handle a connect command -func (s *Server) handleBind(conn conn, req *Request) error { +func (s *Server) handleBind(ctx context.Context, conn conn, req *Request) error { // Check if this is allowed - if !s.config.Rules.Allow(req) { + if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { if err := sendReply(conn, ruleFailure, nil); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } return fmt.Errorf("Bind to %v blocked by rules", req.DestAddr) + } else { + ctx = ctx_ } // TODO: Support bind @@ -213,13 +224,15 @@ func (s *Server) handleBind(conn conn, req *Request) error { } // handleAssociate is used to handle a connect command -func (s *Server) handleAssociate(conn conn, req *Request) error { +func (s *Server) handleAssociate(ctx context.Context, conn conn, req *Request) error { // Check if this is allowed - if !s.config.Rules.Allow(req) { + if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { if err := sendReply(conn, ruleFailure, nil); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } return fmt.Errorf("Associate to %v blocked by rules", req.DestAddr) + } else { + ctx = ctx_ } // TODO: Support associate diff --git a/resolver.go b/resolver.go index 2aadd05..b75a5c4 100644 --- a/resolver.go +++ b/resolver.go @@ -2,20 +2,22 @@ package socks5 import ( "net" + + "golang.org/x/net/context" ) // NameResolver is used to implement custom name resolution type NameResolver interface { - Resolve(name string) (net.IP, error) + Resolve(ctx context.Context, name string) (context.Context, net.IP, error) } // DNSResolver uses the system DNS to resolve host names type DNSResolver struct{} -func (d DNSResolver) Resolve(name string) (net.IP, error) { +func (d DNSResolver) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) { addr, err := net.ResolveIPAddr("ip", name) if err != nil { - return nil, err + return ctx, nil, err } - return addr.IP, err + return ctx, addr.IP, err } diff --git a/resolver_test.go b/resolver_test.go index 4ac3c0e..16d56ee 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -2,12 +2,15 @@ package socks5 import ( "testing" + + "golang.org/x/net/context" ) func TestDNSResolver(t *testing.T) { d := DNSResolver{} + ctx := context.Background() - addr, err := d.Resolve("localhost") + _, addr, err := d.Resolve(ctx, "localhost") if err != nil { t.Fatalf("err: %v", err) } diff --git a/ruleset.go b/ruleset.go index 1f67878..ba0e353 100644 --- a/ruleset.go +++ b/ruleset.go @@ -1,8 +1,12 @@ package socks5 +import ( + "golang.org/x/net/context" +) + // RuleSet is used to provide custom rules to allow or prohibit actions type RuleSet interface { - Allow(req *Request) bool + Allow(ctx context.Context, req *Request) (context.Context, bool) } // PermitAll returns a RuleSet which allows all types of connections @@ -23,15 +27,15 @@ type PermitCommand struct { EnableAssociate bool } -func (p *PermitCommand) Allow(req *Request) bool { +func (p *PermitCommand) Allow(ctx context.Context, req *Request) (context.Context, bool) { switch req.Command { case ConnectCommand: - return p.EnableConnect + return ctx, p.EnableConnect case BindCommand: - return p.EnableBind + return ctx, p.EnableBind case AssociateCommand: - return p.EnableAssociate + return ctx, p.EnableAssociate } - return false + return ctx, false } diff --git a/ruleset_test.go b/ruleset_test.go index abdaed4..b93f4a8 100644 --- a/ruleset_test.go +++ b/ruleset_test.go @@ -1,19 +1,24 @@ package socks5 -import "testing" +import ( + "testing" + + "golang.org/x/net/context" +) func TestPermitCommand(t *testing.T) { + ctx := context.Background() r := &PermitCommand{true, false, false} - if !r.Allow(&Request{Command: ConnectCommand}) { + if _, ok := r.Allow(ctx, &Request{Command: ConnectCommand}); !ok { t.Fatalf("expect connect") } - if r.Allow(&Request{Command: BindCommand}) { + if _, ok := r.Allow(ctx, &Request{Command: BindCommand}); ok { t.Fatalf("do not expect bind") } - if r.Allow(&Request{Command: AssociateCommand}) { + if _, ok := r.Allow(ctx, &Request{Command: AssociateCommand}); ok { t.Fatalf("do not expect associate") } } diff --git a/socks5.go b/socks5.go index e6bca39..a17be68 100644 --- a/socks5.go +++ b/socks5.go @@ -6,6 +6,8 @@ import ( "log" "net" "os" + + "golang.org/x/net/context" ) const ( @@ -45,7 +47,7 @@ type Config struct { Logger *log.Logger // Optional function for dialing out - Dial func(network, addr string) (net.Conn, error) + Dial func(ctx context.Context, network, addr string) (net.Conn, error) } // Server is reponsible for accepting connections and handling