1
0
Fork 0

Add support for address rewriting

logger
Armon Dadgar 2014-01-23 16:55:08 -08:00
parent 86690d7131
commit c9813fbde2
2 changed files with 28 additions and 14 deletions

View File

@ -35,6 +35,11 @@ var (
unrecognizedAddrType = fmt.Errorf("Unrecognized address type") unrecognizedAddrType = fmt.Errorf("Unrecognized address type")
) )
// AddressRewriter is used to rewrite a destination transparently
type AddressRewriter interface {
Rewrite(addr *AddrSpec) *AddrSpec
}
// AddrSpec is used to return the target AddrSpec // AddrSpec is used to return the target AddrSpec
// which may be specified as IPv4, IPv6, or a FQDN // which may be specified as IPv4, IPv6, or a FQDN
type AddrSpec struct { type AddrSpec struct {
@ -91,18 +96,22 @@ func (s *Server) handleRequest(conn conn, bufConn io.Reader) error {
dest.IP = addr dest.IP = addr
} }
// Check if we have a rewriter // Apply any address rewrites
realDest := dest
if s.config.Rewriter != nil {
realDest = s.config.Rewriter.Rewrite(dest)
}
// Switch on the command // Switch on the command
switch header[1] { switch header[1] {
case connectCommand: case connectCommand:
return s.handleConnect(conn, bufConn, dest) return s.handleConnect(conn, bufConn, dest, realDest)
case bindCommand: case bindCommand:
return s.handleBind(conn, bufConn, dest) return s.handleBind(conn, bufConn, dest, realDest)
case associateCommand: case associateCommand:
return s.handleAssociate(conn, bufConn, dest) return s.handleAssociate(conn, bufConn, dest, realDest)
default: default:
if err := sendReply(conn, commandNotSupported, nil); err != nil { if err := sendReply(conn, commandNotSupported, dest); err != nil {
return fmt.Errorf("Failed to send reply: %v", err) return fmt.Errorf("Failed to send reply: %v", err)
} }
return fmt.Errorf("Unsupported command: %v", header[1]) return fmt.Errorf("Unsupported command: %v", header[1])
@ -110,10 +119,10 @@ func (s *Server) handleRequest(conn conn, bufConn io.Reader) error {
} }
// handleConnect is used to handle a connect command // handleConnect is used to handle a connect command
func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest *AddrSpec) error { func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest, realDest *AddrSpec) error {
// Check if this is allowed // Check if this is allowed
client := conn.RemoteAddr().(*net.TCPAddr) client := conn.RemoteAddr().(*net.TCPAddr)
if !s.config.Rules.AllowConnect(dest.IP, dest.Port, client.IP, client.Port) { if !s.config.Rules.AllowConnect(realDest.IP, realDest.Port, client.IP, client.Port) {
if err := sendReply(conn, ruleFailure, dest); err != nil { if err := sendReply(conn, ruleFailure, dest); err != nil {
return fmt.Errorf("Failed to send reply: %v", err) return fmt.Errorf("Failed to send reply: %v", err)
} }
@ -121,7 +130,7 @@ func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest *AddrSpec) err
} }
// Attempt to connect // Attempt to connect
addr := net.TCPAddr{IP: dest.IP, Port: dest.Port} addr := net.TCPAddr{IP: realDest.IP, Port: realDest.Port}
target, err := net.DialTCP("tcp", nil, &addr) target, err := net.DialTCP("tcp", nil, &addr)
if err != nil { if err != nil {
msg := err.Error() msg := err.Error()
@ -156,10 +165,10 @@ func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest *AddrSpec) err
} }
// handleBind is used to handle a connect command // handleBind is used to handle a connect command
func (s *Server) handleBind(conn conn, bufConn io.Reader, dest *AddrSpec) error { func (s *Server) handleBind(conn conn, bufConn io.Reader, dest, realDest *AddrSpec) error {
// Check if this is allowed // Check if this is allowed
client := conn.RemoteAddr().(*net.TCPAddr) client := conn.RemoteAddr().(*net.TCPAddr)
if !s.config.Rules.AllowBind(dest.IP, dest.Port, client.IP, client.Port) { if !s.config.Rules.AllowBind(realDest.IP, realDest.Port, client.IP, client.Port) {
if err := sendReply(conn, ruleFailure, dest); err != nil { if err := sendReply(conn, ruleFailure, dest); err != nil {
return fmt.Errorf("Failed to send reply: %v", err) return fmt.Errorf("Failed to send reply: %v", err)
} }
@ -167,17 +176,17 @@ func (s *Server) handleBind(conn conn, bufConn io.Reader, dest *AddrSpec) error
} }
// TODO: Support bind // TODO: Support bind
if err := sendReply(conn, commandNotSupported, nil); err != nil { if err := sendReply(conn, commandNotSupported, dest); err != nil {
return fmt.Errorf("Failed to send reply: %v", err) return fmt.Errorf("Failed to send reply: %v", err)
} }
return nil return nil
} }
// handleAssociate is used to handle a connect command // handleAssociate is used to handle a connect command
func (s *Server) handleAssociate(conn conn, bufConn io.Reader, dest *AddrSpec) error { func (s *Server) handleAssociate(conn conn, bufConn io.Reader, dest, realDest *AddrSpec) error {
// Check if this is allowed // Check if this is allowed
client := conn.RemoteAddr().(*net.TCPAddr) client := conn.RemoteAddr().(*net.TCPAddr)
if !s.config.Rules.AllowAssociate(dest.IP, dest.Port, client.IP, client.Port) { if !s.config.Rules.AllowAssociate(realDest.IP, realDest.Port, client.IP, client.Port) {
if err := sendReply(conn, ruleFailure, dest); err != nil { if err := sendReply(conn, ruleFailure, dest); err != nil {
return fmt.Errorf("Failed to send reply: %v", err) return fmt.Errorf("Failed to send reply: %v", err)
} }
@ -185,7 +194,7 @@ func (s *Server) handleAssociate(conn conn, bufConn io.Reader, dest *AddrSpec) e
} }
// TODO: Support associate // TODO: Support associate
if err := sendReply(conn, commandNotSupported, nil); err != nil { if err := sendReply(conn, commandNotSupported, dest); err != nil {
return fmt.Errorf("Failed to send reply: %v", err) return fmt.Errorf("Failed to send reply: %v", err)
} }
return nil return nil

View File

@ -25,6 +25,11 @@ type Config struct {
// various commands. If not provided, PermitAll is used. // various commands. If not provided, PermitAll is used.
Rules RuleSet Rules RuleSet
// Rewriter can be used to transparently rewrite addresses.
// This is invoked before the RuleSet is invoked.
// Defaults to NoRewrite.
Rewriter AddressRewriter
// BindIP is used for bind or udp associate // BindIP is used for bind or udp associate
BindIP net.IP BindIP net.IP
} }