Add golang.org/x/net/context support.
Read http://blog.golang.org/context for its benefits. This PR has not utilized contexts yet; just passing them to every customization points to help them add/retrieve request context values.logger
parent
3a873e99f5
commit
385bbe4759
41
request.go
41
request.go
|
@ -5,6 +5,8 @@ import (
|
|||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -34,7 +36,7 @@ var (
|
|||
|
||||
// AddressRewriter is used to rewrite a destination transparently
|
||||
type AddressRewriter interface {
|
||||
Rewrite(request *Request) *AddrSpec
|
||||
Rewrite(ctx context.Context, request *Request) (context.Context, *AddrSpec)
|
||||
}
|
||||
|
||||
// AddrSpec is used to return the target AddrSpec
|
||||
|
@ -105,33 +107,36 @@ func NewRequest(bufConn io.Reader) (*Request, error) {
|
|||
|
||||
// handleRequest is used for request processing after authentication
|
||||
func (s *Server) handleRequest(req *Request, conn conn) error {
|
||||
ctx := context.Background()
|
||||
|
||||
// Resolve the address if we have a FQDN
|
||||
dest := req.DestAddr
|
||||
if dest.FQDN != "" {
|
||||
addr, err := s.config.Resolver.Resolve(dest.FQDN)
|
||||
ctx_, addr, err := s.config.Resolver.Resolve(ctx, dest.FQDN)
|
||||
if err != nil {
|
||||
if err := sendReply(conn, hostUnreachable, nil); err != nil {
|
||||
return fmt.Errorf("Failed to send reply: %v", err)
|
||||
}
|
||||
return fmt.Errorf("Failed to resolve destination '%v': %v", dest.FQDN, err)
|
||||
}
|
||||
ctx = ctx_
|
||||
dest.IP = addr
|
||||
}
|
||||
|
||||
// Apply any address rewrites
|
||||
req.realDestAddr = req.DestAddr
|
||||
if s.config.Rewriter != nil {
|
||||
req.realDestAddr = s.config.Rewriter.Rewrite(req)
|
||||
ctx, req.realDestAddr = s.config.Rewriter.Rewrite(ctx, req)
|
||||
}
|
||||
|
||||
// Switch on the command
|
||||
switch req.Command {
|
||||
case ConnectCommand:
|
||||
return s.handleConnect(conn, req)
|
||||
return s.handleConnect(ctx, conn, req)
|
||||
case BindCommand:
|
||||
return s.handleBind(conn, req)
|
||||
return s.handleBind(ctx, conn, req)
|
||||
case AssociateCommand:
|
||||
return s.handleAssociate(conn, req)
|
||||
return s.handleAssociate(ctx, conn, req)
|
||||
default:
|
||||
if err := sendReply(conn, commandNotSupported, nil); err != nil {
|
||||
return fmt.Errorf("Failed to send reply: %v", err)
|
||||
|
@ -141,22 +146,26 @@ func (s *Server) handleRequest(req *Request, conn conn) error {
|
|||
}
|
||||
|
||||
// handleConnect is used to handle a connect command
|
||||
func (s *Server) handleConnect(conn conn, req *Request) error {
|
||||
func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) error {
|
||||
// Check if this is allowed
|
||||
if !s.config.Rules.Allow(req) {
|
||||
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
|
||||
if err := sendReply(conn, ruleFailure, nil); err != nil {
|
||||
return fmt.Errorf("Failed to send reply: %v", err)
|
||||
}
|
||||
return fmt.Errorf("Connect to %v blocked by rules", req.DestAddr)
|
||||
} else {
|
||||
ctx = ctx_
|
||||
}
|
||||
|
||||
// Attempt to connect
|
||||
addr := (&net.TCPAddr{IP: req.realDestAddr.IP, Port: req.realDestAddr.Port}).String()
|
||||
dial := s.config.Dial
|
||||
if dial == nil {
|
||||
dial = net.Dial
|
||||
dial = func(ctx context.Context, net_, addr string) (net.Conn, error) {
|
||||
return net.Dial(net_, addr)
|
||||
}
|
||||
}
|
||||
target, err := dial("tcp", addr)
|
||||
target, err := dial(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
msg := err.Error()
|
||||
resp := hostUnreachable
|
||||
|
@ -196,13 +205,15 @@ func (s *Server) handleConnect(conn conn, req *Request) error {
|
|||
}
|
||||
|
||||
// handleBind is used to handle a connect command
|
||||
func (s *Server) handleBind(conn conn, req *Request) error {
|
||||
func (s *Server) handleBind(ctx context.Context, conn conn, req *Request) error {
|
||||
// Check if this is allowed
|
||||
if !s.config.Rules.Allow(req) {
|
||||
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
|
||||
if err := sendReply(conn, ruleFailure, nil); err != nil {
|
||||
return fmt.Errorf("Failed to send reply: %v", err)
|
||||
}
|
||||
return fmt.Errorf("Bind to %v blocked by rules", req.DestAddr)
|
||||
} else {
|
||||
ctx = ctx_
|
||||
}
|
||||
|
||||
// TODO: Support bind
|
||||
|
@ -213,13 +224,15 @@ func (s *Server) handleBind(conn conn, req *Request) error {
|
|||
}
|
||||
|
||||
// handleAssociate is used to handle a connect command
|
||||
func (s *Server) handleAssociate(conn conn, req *Request) error {
|
||||
func (s *Server) handleAssociate(ctx context.Context, conn conn, req *Request) error {
|
||||
// Check if this is allowed
|
||||
if !s.config.Rules.Allow(req) {
|
||||
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
|
||||
if err := sendReply(conn, ruleFailure, nil); err != nil {
|
||||
return fmt.Errorf("Failed to send reply: %v", err)
|
||||
}
|
||||
return fmt.Errorf("Associate to %v blocked by rules", req.DestAddr)
|
||||
} else {
|
||||
ctx = ctx_
|
||||
}
|
||||
|
||||
// TODO: Support associate
|
||||
|
|
10
resolver.go
10
resolver.go
|
@ -2,20 +2,22 @@ package socks5
|
|||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// NameResolver is used to implement custom name resolution
|
||||
type NameResolver interface {
|
||||
Resolve(name string) (net.IP, error)
|
||||
Resolve(ctx context.Context, name string) (context.Context, net.IP, error)
|
||||
}
|
||||
|
||||
// DNSResolver uses the system DNS to resolve host names
|
||||
type DNSResolver struct{}
|
||||
|
||||
func (d DNSResolver) Resolve(name string) (net.IP, error) {
|
||||
func (d DNSResolver) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) {
|
||||
addr, err := net.ResolveIPAddr("ip", name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return ctx, nil, err
|
||||
}
|
||||
return addr.IP, err
|
||||
return ctx, addr.IP, err
|
||||
}
|
||||
|
|
|
@ -2,12 +2,15 @@ package socks5
|
|||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
func TestDNSResolver(t *testing.T) {
|
||||
d := DNSResolver{}
|
||||
ctx := context.Background()
|
||||
|
||||
addr, err := d.Resolve("localhost")
|
||||
_, addr, err := d.Resolve(ctx, "localhost")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
|
16
ruleset.go
16
ruleset.go
|
@ -1,8 +1,12 @@
|
|||
package socks5
|
||||
|
||||
import (
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// RuleSet is used to provide custom rules to allow or prohibit actions
|
||||
type RuleSet interface {
|
||||
Allow(req *Request) bool
|
||||
Allow(ctx context.Context, req *Request) (context.Context, bool)
|
||||
}
|
||||
|
||||
// PermitAll returns a RuleSet which allows all types of connections
|
||||
|
@ -23,15 +27,15 @@ type PermitCommand struct {
|
|||
EnableAssociate bool
|
||||
}
|
||||
|
||||
func (p *PermitCommand) Allow(req *Request) bool {
|
||||
func (p *PermitCommand) Allow(ctx context.Context, req *Request) (context.Context, bool) {
|
||||
switch req.Command {
|
||||
case ConnectCommand:
|
||||
return p.EnableConnect
|
||||
return ctx, p.EnableConnect
|
||||
case BindCommand:
|
||||
return p.EnableBind
|
||||
return ctx, p.EnableBind
|
||||
case AssociateCommand:
|
||||
return p.EnableAssociate
|
||||
return ctx, p.EnableAssociate
|
||||
}
|
||||
|
||||
return false
|
||||
return ctx, false
|
||||
}
|
||||
|
|
|
@ -1,19 +1,24 @@
|
|||
package socks5
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
func TestPermitCommand(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
r := &PermitCommand{true, false, false}
|
||||
|
||||
if !r.Allow(&Request{Command: ConnectCommand}) {
|
||||
if _, ok := r.Allow(ctx, &Request{Command: ConnectCommand}); !ok {
|
||||
t.Fatalf("expect connect")
|
||||
}
|
||||
|
||||
if r.Allow(&Request{Command: BindCommand}) {
|
||||
if _, ok := r.Allow(ctx, &Request{Command: BindCommand}); ok {
|
||||
t.Fatalf("do not expect bind")
|
||||
}
|
||||
|
||||
if r.Allow(&Request{Command: AssociateCommand}) {
|
||||
if _, ok := r.Allow(ctx, &Request{Command: AssociateCommand}); ok {
|
||||
t.Fatalf("do not expect associate")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,6 +6,8 @@ import (
|
|||
"log"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -45,7 +47,7 @@ type Config struct {
|
|||
Logger *log.Logger
|
||||
|
||||
// Optional function for dialing out
|
||||
Dial func(network, addr string) (net.Conn, error)
|
||||
Dial func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// Server is reponsible for accepting connections and handling
|
||||
|
|
Loading…
Reference in New Issue