1
0
Fork 0

Refactor handleRequest using new Request struct

logger
ap4y 2016-01-11 21:01:29 +13:00
parent 353e906a7b
commit ae345732fa
2 changed files with 70 additions and 55 deletions

View File

@ -10,9 +10,9 @@ import (
) )
const ( const (
connectCommand = uint8(1) ConnectCommand = uint8(1)
bindCommand = uint8(2) BindCommand = uint8(2)
associateCommand = uint8(3) AssociateCommand = uint8(3)
ipv4Address = uint8(1) ipv4Address = uint8(1)
fqdnAddress = uint8(3) fqdnAddress = uint8(3)
ipv6Address = uint8(4) ipv6Address = uint8(4)
@ -47,6 +47,13 @@ type AddrSpec struct {
Port int Port int
} }
func (a *AddrSpec) String() string {
if a.FQDN != "" {
return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port)
}
return fmt.Sprintf("%s:%d", a.IP, a.Port)
}
// A Request represents request received by a server // A Request represents request received by a server
type Request struct { type Request struct {
// Protocol version // Protocol version
@ -69,38 +76,39 @@ type conn interface {
RemoteAddr() net.Addr RemoteAddr() net.Addr
} }
func (a *AddrSpec) String() string { // NewRequest creates a new Request from the tcp connection
if a.FQDN != "" { func NewRequest(bufConn io.Reader) (*Request, error) {
return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port)
}
return fmt.Sprintf("%s:%d", a.IP, a.Port)
}
// handleRequest is used for request processing after authentication
func (s *Server) handleRequest(conn conn, bufConn io.Reader) error {
// Read the version byte // Read the version byte
header := []byte{0, 0, 0} header := []byte{0, 0, 0}
if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil { if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil {
return fmt.Errorf("Failed to get command version: %v", err) return nil, fmt.Errorf("Failed to get command version: %v", err)
} }
// Ensure we are compatible // Ensure we are compatible
if header[0] != socks5Version { if header[0] != socks5Version {
return fmt.Errorf("Unsupported command version: %v", header[0]) return nil, fmt.Errorf("Unsupported command version: %v", header[0])
} }
// Read in the destination address // Read in the destination address
dest, err := readAddrSpec(bufConn) dest, err := readAddrSpec(bufConn)
if err != nil { if err != nil {
if err == unrecognizedAddrType { return nil, err
if err := sendReply(conn, addrTypeNotSupported, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
}
}
return fmt.Errorf("Failed to read destination address: %v", err)
} }
request := &Request{
Version: socks5Version,
Command: header[1],
DestAddr: dest,
bufConn: bufConn,
}
return request, nil
}
// handleRequest is used for request processing after authentication
func (s *Server) handleRequest(req *Request, conn conn) error {
// Resolve the address if we have a FQDN // Resolve the address if we have a FQDN
dest := req.DestAddr
if dest.FQDN != "" { if dest.FQDN != "" {
addr, err := s.config.Resolver.Resolve(dest.FQDN) addr, err := s.config.Resolver.Resolve(dest.FQDN)
if err != nil { if err != nil {
@ -113,40 +121,39 @@ func (s *Server) handleRequest(conn conn, bufConn io.Reader) error {
} }
// Apply any address rewrites // Apply any address rewrites
realDest := dest req.realDestAddr = req.DestAddr
if s.config.Rewriter != nil { if s.config.Rewriter != nil {
realDest = s.config.Rewriter.Rewrite(dest) req.realDestAddr = s.config.Rewriter.Rewrite(req)
} }
// Switch on the command // Switch on the command
switch header[1] { switch req.Command {
case connectCommand: case ConnectCommand:
return s.handleConnect(conn, bufConn, dest, realDest) return s.handleConnect(conn, req)
case bindCommand: case BindCommand:
return s.handleBind(conn, bufConn, dest, realDest) return s.handleBind(conn, req)
case associateCommand: case AssociateCommand:
return s.handleAssociate(conn, bufConn, dest, realDest) return s.handleAssociate(conn, req)
default: default:
if err := sendReply(conn, commandNotSupported, nil); err != nil { if err := sendReply(conn, commandNotSupported, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err) return fmt.Errorf("Failed to send reply: %v", err)
} }
return fmt.Errorf("Unsupported command: %v", header[1]) return fmt.Errorf("Unsupported command: %v", req.Command)
} }
} }
// handleConnect is used to handle a connect command // handleConnect is used to handle a connect command
func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest, realDest *AddrSpec) error { func (s *Server) handleConnect(conn conn, req *Request) error {
// Check if this is allowed // Check if this is allowed
client := conn.RemoteAddr().(*net.TCPAddr) if !s.config.Rules.Allow(req) {
if !s.config.Rules.AllowConnect(realDest.IP, realDest.Port, client.IP, client.Port) {
if err := sendReply(conn, ruleFailure, nil); err != nil { if err := sendReply(conn, ruleFailure, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err) return fmt.Errorf("Failed to send reply: %v", err)
} }
return fmt.Errorf("Connect to %v blocked by rules", dest) return fmt.Errorf("Connect to %v blocked by rules", req.DestAddr)
} }
// Attempt to connect // Attempt to connect
addr := net.TCPAddr{IP: realDest.IP, Port: realDest.Port} addr := net.TCPAddr{IP: req.realDestAddr.IP, Port: req.realDestAddr.Port}
target, err := net.DialTCP("tcp", nil, &addr) target, err := net.DialTCP("tcp", nil, &addr)
if err != nil { if err != nil {
msg := err.Error() msg := err.Error()
@ -159,7 +166,7 @@ func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest, realDest *Add
if err := sendReply(conn, resp, nil); err != nil { if err := sendReply(conn, resp, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err) return fmt.Errorf("Failed to send reply: %v", err)
} }
return fmt.Errorf("Connect to %v failed: %v", dest, err) return fmt.Errorf("Connect to %v failed: %v", req.DestAddr, err)
} }
defer target.Close() defer target.Close()
@ -172,7 +179,7 @@ func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest, realDest *Add
// Start proxying // Start proxying
errCh := make(chan error, 2) errCh := make(chan error, 2)
go proxy("target", target, bufConn, errCh, s.config.Logger) go proxy("target", target, req.bufConn, errCh, s.config.Logger)
go proxy("client", conn, target, errCh, s.config.Logger) go proxy("client", conn, target, errCh, s.config.Logger)
// Wait // Wait
@ -183,14 +190,13 @@ func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest, realDest *Add
} }
// handleBind is used to handle a connect command // handleBind is used to handle a connect command
func (s *Server) handleBind(conn conn, bufConn io.Reader, dest, realDest *AddrSpec) error { func (s *Server) handleBind(conn conn, req *Request) error {
// Check if this is allowed // Check if this is allowed
client := conn.RemoteAddr().(*net.TCPAddr) if !s.config.Rules.Allow(req) {
if !s.config.Rules.AllowBind(realDest.IP, realDest.Port, client.IP, client.Port) {
if err := sendReply(conn, ruleFailure, nil); err != nil { if err := sendReply(conn, ruleFailure, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err) return fmt.Errorf("Failed to send reply: %v", err)
} }
return fmt.Errorf("Bind to %v blocked by rules", dest) return fmt.Errorf("Bind to %v blocked by rules", req.DestAddr)
} }
// TODO: Support bind // TODO: Support bind
@ -201,14 +207,13 @@ func (s *Server) handleBind(conn conn, bufConn io.Reader, dest, realDest *AddrSp
} }
// handleAssociate is used to handle a connect command // handleAssociate is used to handle a connect command
func (s *Server) handleAssociate(conn conn, bufConn io.Reader, dest, realDest *AddrSpec) error { func (s *Server) handleAssociate(conn conn, req *Request) error {
// Check if this is allowed // Check if this is allowed
client := conn.RemoteAddr().(*net.TCPAddr) if !s.config.Rules.Allow(req) {
if !s.config.Rules.AllowAssociate(realDest.IP, realDest.Port, client.IP, client.Port) {
if err := sendReply(conn, ruleFailure, nil); err != nil { if err := sendReply(conn, ruleFailure, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err) return fmt.Errorf("Failed to send reply: %v", err)
} }
return fmt.Errorf("Associate to %v blocked by rules", dest) return fmt.Errorf("Associate to %v blocked by rules", req.DestAddr)
} }
// TODO: Support associate // TODO: Support associate

View File

@ -56,19 +56,24 @@ func TestRequest_Connect(t *testing.T) {
}} }}
// Create the connect request // Create the connect request
req := bytes.NewBuffer(nil) buf := bytes.NewBuffer(nil)
req.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1}) buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1})
port := []byte{0, 0} port := []byte{0, 0}
binary.BigEndian.PutUint16(port, uint16(lAddr.Port)) binary.BigEndian.PutUint16(port, uint16(lAddr.Port))
req.Write(port) buf.Write(port)
// Send a ping // Send a ping
req.Write([]byte("ping")) buf.Write([]byte("ping"))
// Handle the request // Handle the request
resp := &MockConn{} resp := &MockConn{}
if err := s.handleRequest(resp, req); err != nil { req, err := NewRequest(buf)
if err != nil {
t.Fatalf("err: %v", err)
}
if err := s.handleRequest(req, resp); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -126,19 +131,24 @@ func TestRequest_Connect_RuleFail(t *testing.T) {
}} }}
// Create the connect request // Create the connect request
req := bytes.NewBuffer(nil) buf := bytes.NewBuffer(nil)
req.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1}) buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1})
port := []byte{0, 0} port := []byte{0, 0}
binary.BigEndian.PutUint16(port, uint16(lAddr.Port)) binary.BigEndian.PutUint16(port, uint16(lAddr.Port))
req.Write(port) buf.Write(port)
// Send a ping // Send a ping
req.Write([]byte("ping")) buf.Write([]byte("ping"))
// Handle the request // Handle the request
resp := &MockConn{} resp := &MockConn{}
if err := s.handleRequest(resp, req); !strings.Contains(err.Error(), "blocked by rules") { req, err := NewRequest(buf)
if err != nil {
t.Fatalf("err: %v", err)
}
if err := s.handleRequest(req, resp); !strings.Contains(err.Error(), "blocked by rules") {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }