1
0
Fork 0

Add golang.org/x/net/context support.

Read http://blog.golang.org/context for its benefits.

This PR has not utilized contexts yet; just passing them
to every customization points to help them add/retrieve
request context values.
logger
ymmt2005 2016-03-08 23:29:11 +09:00
parent 3a873e99f5
commit 385bbe4759
6 changed files with 59 additions and 30 deletions

View File

@ -5,6 +5,8 @@ import (
"io"
"net"
"strings"
"golang.org/x/net/context"
)
const (
@ -34,7 +36,7 @@ var (
// AddressRewriter is used to rewrite a destination transparently
type AddressRewriter interface {
Rewrite(request *Request) *AddrSpec
Rewrite(ctx context.Context, request *Request) (context.Context, *AddrSpec)
}
// AddrSpec is used to return the target AddrSpec
@ -105,33 +107,36 @@ func NewRequest(bufConn io.Reader) (*Request, error) {
// handleRequest is used for request processing after authentication
func (s *Server) handleRequest(req *Request, conn conn) error {
ctx := context.Background()
// Resolve the address if we have a FQDN
dest := req.DestAddr
if dest.FQDN != "" {
addr, err := s.config.Resolver.Resolve(dest.FQDN)
ctx_, addr, err := s.config.Resolver.Resolve(ctx, dest.FQDN)
if err != nil {
if err := sendReply(conn, hostUnreachable, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
}
return fmt.Errorf("Failed to resolve destination '%v': %v", dest.FQDN, err)
}
ctx = ctx_
dest.IP = addr
}
// Apply any address rewrites
req.realDestAddr = req.DestAddr
if s.config.Rewriter != nil {
req.realDestAddr = s.config.Rewriter.Rewrite(req)
ctx, req.realDestAddr = s.config.Rewriter.Rewrite(ctx, req)
}
// Switch on the command
switch req.Command {
case ConnectCommand:
return s.handleConnect(conn, req)
return s.handleConnect(ctx, conn, req)
case BindCommand:
return s.handleBind(conn, req)
return s.handleBind(ctx, conn, req)
case AssociateCommand:
return s.handleAssociate(conn, req)
return s.handleAssociate(ctx, conn, req)
default:
if err := sendReply(conn, commandNotSupported, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
@ -141,22 +146,26 @@ func (s *Server) handleRequest(req *Request, conn conn) error {
}
// handleConnect is used to handle a connect command
func (s *Server) handleConnect(conn conn, req *Request) error {
func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) error {
// Check if this is allowed
if !s.config.Rules.Allow(req) {
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
}
return fmt.Errorf("Connect to %v blocked by rules", req.DestAddr)
} else {
ctx = ctx_
}
// Attempt to connect
addr := (&net.TCPAddr{IP: req.realDestAddr.IP, Port: req.realDestAddr.Port}).String()
dial := s.config.Dial
if dial == nil {
dial = net.Dial
dial = func(ctx context.Context, net_, addr string) (net.Conn, error) {
return net.Dial(net_, addr)
}
}
target, err := dial("tcp", addr)
target, err := dial(ctx, "tcp", addr)
if err != nil {
msg := err.Error()
resp := hostUnreachable
@ -196,13 +205,15 @@ func (s *Server) handleConnect(conn conn, req *Request) error {
}
// handleBind is used to handle a connect command
func (s *Server) handleBind(conn conn, req *Request) error {
func (s *Server) handleBind(ctx context.Context, conn conn, req *Request) error {
// Check if this is allowed
if !s.config.Rules.Allow(req) {
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
}
return fmt.Errorf("Bind to %v blocked by rules", req.DestAddr)
} else {
ctx = ctx_
}
// TODO: Support bind
@ -213,13 +224,15 @@ func (s *Server) handleBind(conn conn, req *Request) error {
}
// handleAssociate is used to handle a connect command
func (s *Server) handleAssociate(conn conn, req *Request) error {
func (s *Server) handleAssociate(ctx context.Context, conn conn, req *Request) error {
// Check if this is allowed
if !s.config.Rules.Allow(req) {
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
}
return fmt.Errorf("Associate to %v blocked by rules", req.DestAddr)
} else {
ctx = ctx_
}
// TODO: Support associate

View File

@ -2,20 +2,22 @@ package socks5
import (
"net"
"golang.org/x/net/context"
)
// NameResolver is used to implement custom name resolution
type NameResolver interface {
Resolve(name string) (net.IP, error)
Resolve(ctx context.Context, name string) (context.Context, net.IP, error)
}
// DNSResolver uses the system DNS to resolve host names
type DNSResolver struct{}
func (d DNSResolver) Resolve(name string) (net.IP, error) {
func (d DNSResolver) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) {
addr, err := net.ResolveIPAddr("ip", name)
if err != nil {
return nil, err
return ctx, nil, err
}
return addr.IP, err
return ctx, addr.IP, err
}

View File

@ -2,12 +2,15 @@ package socks5
import (
"testing"
"golang.org/x/net/context"
)
func TestDNSResolver(t *testing.T) {
d := DNSResolver{}
ctx := context.Background()
addr, err := d.Resolve("localhost")
_, addr, err := d.Resolve(ctx, "localhost")
if err != nil {
t.Fatalf("err: %v", err)
}

View File

@ -1,8 +1,12 @@
package socks5
import (
"golang.org/x/net/context"
)
// RuleSet is used to provide custom rules to allow or prohibit actions
type RuleSet interface {
Allow(req *Request) bool
Allow(ctx context.Context, req *Request) (context.Context, bool)
}
// PermitAll returns a RuleSet which allows all types of connections
@ -23,15 +27,15 @@ type PermitCommand struct {
EnableAssociate bool
}
func (p *PermitCommand) Allow(req *Request) bool {
func (p *PermitCommand) Allow(ctx context.Context, req *Request) (context.Context, bool) {
switch req.Command {
case ConnectCommand:
return p.EnableConnect
return ctx, p.EnableConnect
case BindCommand:
return p.EnableBind
return ctx, p.EnableBind
case AssociateCommand:
return p.EnableAssociate
return ctx, p.EnableAssociate
}
return false
return ctx, false
}

View File

@ -1,19 +1,24 @@
package socks5
import "testing"
import (
"testing"
"golang.org/x/net/context"
)
func TestPermitCommand(t *testing.T) {
ctx := context.Background()
r := &PermitCommand{true, false, false}
if !r.Allow(&Request{Command: ConnectCommand}) {
if _, ok := r.Allow(ctx, &Request{Command: ConnectCommand}); !ok {
t.Fatalf("expect connect")
}
if r.Allow(&Request{Command: BindCommand}) {
if _, ok := r.Allow(ctx, &Request{Command: BindCommand}); ok {
t.Fatalf("do not expect bind")
}
if r.Allow(&Request{Command: AssociateCommand}) {
if _, ok := r.Allow(ctx, &Request{Command: AssociateCommand}); ok {
t.Fatalf("do not expect associate")
}
}

View File

@ -6,6 +6,8 @@ import (
"log"
"net"
"os"
"golang.org/x/net/context"
)
const (
@ -45,7 +47,7 @@ type Config struct {
Logger *log.Logger
// Optional function for dialing out
Dial func(network, addr string) (net.Conn, error)
Dial func(ctx context.Context, network, addr string) (net.Conn, error)
}
// Server is reponsible for accepting connections and handling