1
0
Fork 0

Add golang.org/x/net/context support.

Read http://blog.golang.org/context for its benefits.

This PR has not utilized contexts yet; just passing them
to every customization points to help them add/retrieve
request context values.
logger
ymmt2005 2016-03-08 23:29:11 +09:00
parent 3a873e99f5
commit 385bbe4759
6 changed files with 59 additions and 30 deletions

View File

@ -5,6 +5,8 @@ import (
"io" "io"
"net" "net"
"strings" "strings"
"golang.org/x/net/context"
) )
const ( const (
@ -34,7 +36,7 @@ var (
// AddressRewriter is used to rewrite a destination transparently // AddressRewriter is used to rewrite a destination transparently
type AddressRewriter interface { type AddressRewriter interface {
Rewrite(request *Request) *AddrSpec Rewrite(ctx context.Context, request *Request) (context.Context, *AddrSpec)
} }
// AddrSpec is used to return the target AddrSpec // AddrSpec is used to return the target AddrSpec
@ -105,33 +107,36 @@ func NewRequest(bufConn io.Reader) (*Request, error) {
// handleRequest is used for request processing after authentication // handleRequest is used for request processing after authentication
func (s *Server) handleRequest(req *Request, conn conn) error { func (s *Server) handleRequest(req *Request, conn conn) error {
ctx := context.Background()
// Resolve the address if we have a FQDN // Resolve the address if we have a FQDN
dest := req.DestAddr dest := req.DestAddr
if dest.FQDN != "" { if dest.FQDN != "" {
addr, err := s.config.Resolver.Resolve(dest.FQDN) ctx_, addr, err := s.config.Resolver.Resolve(ctx, dest.FQDN)
if err != nil { if err != nil {
if err := sendReply(conn, hostUnreachable, nil); err != nil { if err := sendReply(conn, hostUnreachable, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err) return fmt.Errorf("Failed to send reply: %v", err)
} }
return fmt.Errorf("Failed to resolve destination '%v': %v", dest.FQDN, err) return fmt.Errorf("Failed to resolve destination '%v': %v", dest.FQDN, err)
} }
ctx = ctx_
dest.IP = addr dest.IP = addr
} }
// Apply any address rewrites // Apply any address rewrites
req.realDestAddr = req.DestAddr req.realDestAddr = req.DestAddr
if s.config.Rewriter != nil { if s.config.Rewriter != nil {
req.realDestAddr = s.config.Rewriter.Rewrite(req) ctx, req.realDestAddr = s.config.Rewriter.Rewrite(ctx, req)
} }
// Switch on the command // Switch on the command
switch req.Command { switch req.Command {
case ConnectCommand: case ConnectCommand:
return s.handleConnect(conn, req) return s.handleConnect(ctx, conn, req)
case BindCommand: case BindCommand:
return s.handleBind(conn, req) return s.handleBind(ctx, conn, req)
case AssociateCommand: case AssociateCommand:
return s.handleAssociate(conn, req) return s.handleAssociate(ctx, conn, req)
default: default:
if err := sendReply(conn, commandNotSupported, nil); err != nil { if err := sendReply(conn, commandNotSupported, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err) return fmt.Errorf("Failed to send reply: %v", err)
@ -141,22 +146,26 @@ func (s *Server) handleRequest(req *Request, conn conn) error {
} }
// handleConnect is used to handle a connect command // handleConnect is used to handle a connect command
func (s *Server) handleConnect(conn conn, req *Request) error { func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) error {
// Check if this is allowed // Check if this is allowed
if !s.config.Rules.Allow(req) { if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil); err != nil { if err := sendReply(conn, ruleFailure, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err) return fmt.Errorf("Failed to send reply: %v", err)
} }
return fmt.Errorf("Connect to %v blocked by rules", req.DestAddr) return fmt.Errorf("Connect to %v blocked by rules", req.DestAddr)
} else {
ctx = ctx_
} }
// Attempt to connect // Attempt to connect
addr := (&net.TCPAddr{IP: req.realDestAddr.IP, Port: req.realDestAddr.Port}).String() addr := (&net.TCPAddr{IP: req.realDestAddr.IP, Port: req.realDestAddr.Port}).String()
dial := s.config.Dial dial := s.config.Dial
if dial == nil { if dial == nil {
dial = net.Dial dial = func(ctx context.Context, net_, addr string) (net.Conn, error) {
return net.Dial(net_, addr)
}
} }
target, err := dial("tcp", addr) target, err := dial(ctx, "tcp", addr)
if err != nil { if err != nil {
msg := err.Error() msg := err.Error()
resp := hostUnreachable resp := hostUnreachable
@ -196,13 +205,15 @@ func (s *Server) handleConnect(conn conn, req *Request) error {
} }
// handleBind is used to handle a connect command // handleBind is used to handle a connect command
func (s *Server) handleBind(conn conn, req *Request) error { func (s *Server) handleBind(ctx context.Context, conn conn, req *Request) error {
// Check if this is allowed // Check if this is allowed
if !s.config.Rules.Allow(req) { if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil); err != nil { if err := sendReply(conn, ruleFailure, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err) return fmt.Errorf("Failed to send reply: %v", err)
} }
return fmt.Errorf("Bind to %v blocked by rules", req.DestAddr) return fmt.Errorf("Bind to %v blocked by rules", req.DestAddr)
} else {
ctx = ctx_
} }
// TODO: Support bind // TODO: Support bind
@ -213,13 +224,15 @@ func (s *Server) handleBind(conn conn, req *Request) error {
} }
// handleAssociate is used to handle a connect command // handleAssociate is used to handle a connect command
func (s *Server) handleAssociate(conn conn, req *Request) error { func (s *Server) handleAssociate(ctx context.Context, conn conn, req *Request) error {
// Check if this is allowed // Check if this is allowed
if !s.config.Rules.Allow(req) { if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil); err != nil { if err := sendReply(conn, ruleFailure, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err) return fmt.Errorf("Failed to send reply: %v", err)
} }
return fmt.Errorf("Associate to %v blocked by rules", req.DestAddr) return fmt.Errorf("Associate to %v blocked by rules", req.DestAddr)
} else {
ctx = ctx_
} }
// TODO: Support associate // TODO: Support associate

View File

@ -2,20 +2,22 @@ package socks5
import ( import (
"net" "net"
"golang.org/x/net/context"
) )
// NameResolver is used to implement custom name resolution // NameResolver is used to implement custom name resolution
type NameResolver interface { type NameResolver interface {
Resolve(name string) (net.IP, error) Resolve(ctx context.Context, name string) (context.Context, net.IP, error)
} }
// DNSResolver uses the system DNS to resolve host names // DNSResolver uses the system DNS to resolve host names
type DNSResolver struct{} type DNSResolver struct{}
func (d DNSResolver) Resolve(name string) (net.IP, error) { func (d DNSResolver) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) {
addr, err := net.ResolveIPAddr("ip", name) addr, err := net.ResolveIPAddr("ip", name)
if err != nil { if err != nil {
return nil, err return ctx, nil, err
} }
return addr.IP, err return ctx, addr.IP, err
} }

View File

@ -2,12 +2,15 @@ package socks5
import ( import (
"testing" "testing"
"golang.org/x/net/context"
) )
func TestDNSResolver(t *testing.T) { func TestDNSResolver(t *testing.T) {
d := DNSResolver{} d := DNSResolver{}
ctx := context.Background()
addr, err := d.Resolve("localhost") _, addr, err := d.Resolve(ctx, "localhost")
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }

View File

@ -1,8 +1,12 @@
package socks5 package socks5
import (
"golang.org/x/net/context"
)
// RuleSet is used to provide custom rules to allow or prohibit actions // RuleSet is used to provide custom rules to allow or prohibit actions
type RuleSet interface { type RuleSet interface {
Allow(req *Request) bool Allow(ctx context.Context, req *Request) (context.Context, bool)
} }
// PermitAll returns a RuleSet which allows all types of connections // PermitAll returns a RuleSet which allows all types of connections
@ -23,15 +27,15 @@ type PermitCommand struct {
EnableAssociate bool EnableAssociate bool
} }
func (p *PermitCommand) Allow(req *Request) bool { func (p *PermitCommand) Allow(ctx context.Context, req *Request) (context.Context, bool) {
switch req.Command { switch req.Command {
case ConnectCommand: case ConnectCommand:
return p.EnableConnect return ctx, p.EnableConnect
case BindCommand: case BindCommand:
return p.EnableBind return ctx, p.EnableBind
case AssociateCommand: case AssociateCommand:
return p.EnableAssociate return ctx, p.EnableAssociate
} }
return false return ctx, false
} }

View File

@ -1,19 +1,24 @@
package socks5 package socks5
import "testing" import (
"testing"
"golang.org/x/net/context"
)
func TestPermitCommand(t *testing.T) { func TestPermitCommand(t *testing.T) {
ctx := context.Background()
r := &PermitCommand{true, false, false} r := &PermitCommand{true, false, false}
if !r.Allow(&Request{Command: ConnectCommand}) { if _, ok := r.Allow(ctx, &Request{Command: ConnectCommand}); !ok {
t.Fatalf("expect connect") t.Fatalf("expect connect")
} }
if r.Allow(&Request{Command: BindCommand}) { if _, ok := r.Allow(ctx, &Request{Command: BindCommand}); ok {
t.Fatalf("do not expect bind") t.Fatalf("do not expect bind")
} }
if r.Allow(&Request{Command: AssociateCommand}) { if _, ok := r.Allow(ctx, &Request{Command: AssociateCommand}); ok {
t.Fatalf("do not expect associate") t.Fatalf("do not expect associate")
} }
} }

View File

@ -6,6 +6,8 @@ import (
"log" "log"
"net" "net"
"os" "os"
"golang.org/x/net/context"
) )
const ( const (
@ -45,7 +47,7 @@ type Config struct {
Logger *log.Logger Logger *log.Logger
// Optional function for dialing out // Optional function for dialing out
Dial func(network, addr string) (net.Conn, error) Dial func(ctx context.Context, network, addr string) (net.Conn, error)
} }
// Server is reponsible for accepting connections and handling // Server is reponsible for accepting connections and handling