From ae345732fa8f1a21570a8dbe4b2d99b40415bc69 Mon Sep 17 00:00:00 2001 From: ap4y Date: Mon, 11 Jan 2016 21:01:29 +1300 Subject: [PATCH] Refactor handleRequest using new Request struct --- request.go | 95 ++++++++++++++++++++++++++----------------------- request_test.go | 30 ++++++++++------ 2 files changed, 70 insertions(+), 55 deletions(-) diff --git a/request.go b/request.go index acc348e..46b4e55 100644 --- a/request.go +++ b/request.go @@ -10,9 +10,9 @@ import ( ) const ( - connectCommand = uint8(1) - bindCommand = uint8(2) - associateCommand = uint8(3) + ConnectCommand = uint8(1) + BindCommand = uint8(2) + AssociateCommand = uint8(3) ipv4Address = uint8(1) fqdnAddress = uint8(3) ipv6Address = uint8(4) @@ -47,6 +47,13 @@ type AddrSpec struct { Port int } +func (a *AddrSpec) String() string { + if a.FQDN != "" { + return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port) + } + return fmt.Sprintf("%s:%d", a.IP, a.Port) +} + // A Request represents request received by a server type Request struct { // Protocol version @@ -69,38 +76,39 @@ type conn interface { RemoteAddr() net.Addr } -func (a *AddrSpec) String() string { - if a.FQDN != "" { - return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port) - } - return fmt.Sprintf("%s:%d", a.IP, a.Port) -} - -// handleRequest is used for request processing after authentication -func (s *Server) handleRequest(conn conn, bufConn io.Reader) error { +// NewRequest creates a new Request from the tcp connection +func NewRequest(bufConn io.Reader) (*Request, error) { // Read the version byte header := []byte{0, 0, 0} if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil { - return fmt.Errorf("Failed to get command version: %v", err) + return nil, fmt.Errorf("Failed to get command version: %v", err) } // Ensure we are compatible if header[0] != socks5Version { - return fmt.Errorf("Unsupported command version: %v", header[0]) + return nil, fmt.Errorf("Unsupported command version: %v", header[0]) } // Read in the destination address dest, err := readAddrSpec(bufConn) if err != nil { - if err == unrecognizedAddrType { - if err := sendReply(conn, addrTypeNotSupported, nil); err != nil { - return fmt.Errorf("Failed to send reply: %v", err) - } - } - return fmt.Errorf("Failed to read destination address: %v", err) + return nil, err } + request := &Request{ + Version: socks5Version, + Command: header[1], + DestAddr: dest, + bufConn: bufConn, + } + + return request, nil +} + +// handleRequest is used for request processing after authentication +func (s *Server) handleRequest(req *Request, conn conn) error { // Resolve the address if we have a FQDN + dest := req.DestAddr if dest.FQDN != "" { addr, err := s.config.Resolver.Resolve(dest.FQDN) if err != nil { @@ -113,40 +121,39 @@ func (s *Server) handleRequest(conn conn, bufConn io.Reader) error { } // Apply any address rewrites - realDest := dest + req.realDestAddr = req.DestAddr if s.config.Rewriter != nil { - realDest = s.config.Rewriter.Rewrite(dest) + req.realDestAddr = s.config.Rewriter.Rewrite(req) } // Switch on the command - switch header[1] { - case connectCommand: - return s.handleConnect(conn, bufConn, dest, realDest) - case bindCommand: - return s.handleBind(conn, bufConn, dest, realDest) - case associateCommand: - return s.handleAssociate(conn, bufConn, dest, realDest) + switch req.Command { + case ConnectCommand: + return s.handleConnect(conn, req) + case BindCommand: + return s.handleBind(conn, req) + case AssociateCommand: + return s.handleAssociate(conn, req) default: if err := sendReply(conn, commandNotSupported, nil); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } - return fmt.Errorf("Unsupported command: %v", header[1]) + return fmt.Errorf("Unsupported command: %v", req.Command) } } // handleConnect is used to handle a connect command -func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest, realDest *AddrSpec) error { +func (s *Server) handleConnect(conn conn, req *Request) error { // Check if this is allowed - client := conn.RemoteAddr().(*net.TCPAddr) - if !s.config.Rules.AllowConnect(realDest.IP, realDest.Port, client.IP, client.Port) { + if !s.config.Rules.Allow(req) { if err := sendReply(conn, ruleFailure, nil); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } - return fmt.Errorf("Connect to %v blocked by rules", dest) + return fmt.Errorf("Connect to %v blocked by rules", req.DestAddr) } // Attempt to connect - addr := net.TCPAddr{IP: realDest.IP, Port: realDest.Port} + addr := net.TCPAddr{IP: req.realDestAddr.IP, Port: req.realDestAddr.Port} target, err := net.DialTCP("tcp", nil, &addr) if err != nil { msg := err.Error() @@ -159,7 +166,7 @@ func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest, realDest *Add if err := sendReply(conn, resp, nil); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } - return fmt.Errorf("Connect to %v failed: %v", dest, err) + return fmt.Errorf("Connect to %v failed: %v", req.DestAddr, err) } defer target.Close() @@ -172,7 +179,7 @@ func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest, realDest *Add // Start proxying errCh := make(chan error, 2) - go proxy("target", target, bufConn, errCh, s.config.Logger) + go proxy("target", target, req.bufConn, errCh, s.config.Logger) go proxy("client", conn, target, errCh, s.config.Logger) // Wait @@ -183,14 +190,13 @@ func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest, realDest *Add } // handleBind is used to handle a connect command -func (s *Server) handleBind(conn conn, bufConn io.Reader, dest, realDest *AddrSpec) error { +func (s *Server) handleBind(conn conn, req *Request) error { // Check if this is allowed - client := conn.RemoteAddr().(*net.TCPAddr) - if !s.config.Rules.AllowBind(realDest.IP, realDest.Port, client.IP, client.Port) { + if !s.config.Rules.Allow(req) { if err := sendReply(conn, ruleFailure, nil); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } - return fmt.Errorf("Bind to %v blocked by rules", dest) + return fmt.Errorf("Bind to %v blocked by rules", req.DestAddr) } // TODO: Support bind @@ -201,14 +207,13 @@ func (s *Server) handleBind(conn conn, bufConn io.Reader, dest, realDest *AddrSp } // handleAssociate is used to handle a connect command -func (s *Server) handleAssociate(conn conn, bufConn io.Reader, dest, realDest *AddrSpec) error { +func (s *Server) handleAssociate(conn conn, req *Request) error { // Check if this is allowed - client := conn.RemoteAddr().(*net.TCPAddr) - if !s.config.Rules.AllowAssociate(realDest.IP, realDest.Port, client.IP, client.Port) { + if !s.config.Rules.Allow(req) { if err := sendReply(conn, ruleFailure, nil); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } - return fmt.Errorf("Associate to %v blocked by rules", dest) + return fmt.Errorf("Associate to %v blocked by rules", req.DestAddr) } // TODO: Support associate diff --git a/request_test.go b/request_test.go index 880a3cd..5465113 100644 --- a/request_test.go +++ b/request_test.go @@ -56,19 +56,24 @@ func TestRequest_Connect(t *testing.T) { }} // Create the connect request - req := bytes.NewBuffer(nil) - req.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1}) + buf := bytes.NewBuffer(nil) + buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1}) port := []byte{0, 0} binary.BigEndian.PutUint16(port, uint16(lAddr.Port)) - req.Write(port) + buf.Write(port) // Send a ping - req.Write([]byte("ping")) + buf.Write([]byte("ping")) // Handle the request resp := &MockConn{} - if err := s.handleRequest(resp, req); err != nil { + req, err := NewRequest(buf) + if err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.handleRequest(req, resp); err != nil { t.Fatalf("err: %v", err) } @@ -126,19 +131,24 @@ func TestRequest_Connect_RuleFail(t *testing.T) { }} // Create the connect request - req := bytes.NewBuffer(nil) - req.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1}) + buf := bytes.NewBuffer(nil) + buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1}) port := []byte{0, 0} binary.BigEndian.PutUint16(port, uint16(lAddr.Port)) - req.Write(port) + buf.Write(port) // Send a ping - req.Write([]byte("ping")) + buf.Write([]byte("ping")) // Handle the request resp := &MockConn{} - if err := s.handleRequest(resp, req); !strings.Contains(err.Error(), "blocked by rules") { + req, err := NewRequest(buf) + if err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.handleRequest(req, resp); !strings.Contains(err.Error(), "blocked by rules") { t.Fatalf("err: %v", err) }