diff --git a/request.go b/request.go index 50e65c6..3e92f19 100644 --- a/request.go +++ b/request.go @@ -35,6 +35,11 @@ var ( unrecognizedAddrType = fmt.Errorf("Unrecognized address type") ) +// AddressRewriter is used to rewrite a destination transparently +type AddressRewriter interface { + Rewrite(addr *AddrSpec) *AddrSpec +} + // AddrSpec is used to return the target AddrSpec // which may be specified as IPv4, IPv6, or a FQDN type AddrSpec struct { @@ -91,18 +96,22 @@ func (s *Server) handleRequest(conn conn, bufConn io.Reader) error { dest.IP = addr } - // Check if we have a rewriter + // Apply any address rewrites + realDest := dest + if s.config.Rewriter != nil { + realDest = s.config.Rewriter.Rewrite(dest) + } // Switch on the command switch header[1] { case connectCommand: - return s.handleConnect(conn, bufConn, dest) + return s.handleConnect(conn, bufConn, dest, realDest) case bindCommand: - return s.handleBind(conn, bufConn, dest) + return s.handleBind(conn, bufConn, dest, realDest) case associateCommand: - return s.handleAssociate(conn, bufConn, dest) + return s.handleAssociate(conn, bufConn, dest, realDest) default: - if err := sendReply(conn, commandNotSupported, nil); err != nil { + if err := sendReply(conn, commandNotSupported, dest); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } return fmt.Errorf("Unsupported command: %v", header[1]) @@ -110,10 +119,10 @@ func (s *Server) handleRequest(conn conn, bufConn io.Reader) error { } // handleConnect is used to handle a connect command -func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest *AddrSpec) error { +func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest, realDest *AddrSpec) error { // Check if this is allowed client := conn.RemoteAddr().(*net.TCPAddr) - if !s.config.Rules.AllowConnect(dest.IP, dest.Port, client.IP, client.Port) { + if !s.config.Rules.AllowConnect(realDest.IP, realDest.Port, client.IP, client.Port) { if err := sendReply(conn, ruleFailure, dest); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } @@ -121,7 +130,7 @@ func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest *AddrSpec) err } // Attempt to connect - addr := net.TCPAddr{IP: dest.IP, Port: dest.Port} + addr := net.TCPAddr{IP: realDest.IP, Port: realDest.Port} target, err := net.DialTCP("tcp", nil, &addr) if err != nil { msg := err.Error() @@ -156,10 +165,10 @@ func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest *AddrSpec) err } // handleBind is used to handle a connect command -func (s *Server) handleBind(conn conn, bufConn io.Reader, dest *AddrSpec) error { +func (s *Server) handleBind(conn conn, bufConn io.Reader, dest, realDest *AddrSpec) error { // Check if this is allowed client := conn.RemoteAddr().(*net.TCPAddr) - if !s.config.Rules.AllowBind(dest.IP, dest.Port, client.IP, client.Port) { + if !s.config.Rules.AllowBind(realDest.IP, realDest.Port, client.IP, client.Port) { if err := sendReply(conn, ruleFailure, dest); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } @@ -167,17 +176,17 @@ func (s *Server) handleBind(conn conn, bufConn io.Reader, dest *AddrSpec) error } // TODO: Support bind - if err := sendReply(conn, commandNotSupported, nil); err != nil { + if err := sendReply(conn, commandNotSupported, dest); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } return nil } // handleAssociate is used to handle a connect command -func (s *Server) handleAssociate(conn conn, bufConn io.Reader, dest *AddrSpec) error { +func (s *Server) handleAssociate(conn conn, bufConn io.Reader, dest, realDest *AddrSpec) error { // Check if this is allowed client := conn.RemoteAddr().(*net.TCPAddr) - if !s.config.Rules.AllowAssociate(dest.IP, dest.Port, client.IP, client.Port) { + if !s.config.Rules.AllowAssociate(realDest.IP, realDest.Port, client.IP, client.Port) { if err := sendReply(conn, ruleFailure, dest); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } @@ -185,7 +194,7 @@ func (s *Server) handleAssociate(conn conn, bufConn io.Reader, dest *AddrSpec) e } // TODO: Support associate - if err := sendReply(conn, commandNotSupported, nil); err != nil { + if err := sendReply(conn, commandNotSupported, dest); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } return nil diff --git a/socks5.go b/socks5.go index 9b30fc6..0df5a90 100644 --- a/socks5.go +++ b/socks5.go @@ -25,6 +25,11 @@ type Config struct { // various commands. If not provided, PermitAll is used. Rules RuleSet + // Rewriter can be used to transparently rewrite addresses. + // This is invoked before the RuleSet is invoked. + // Defaults to NoRewrite. + Rewriter AddressRewriter + // BindIP is used for bind or udp associate BindIP net.IP }