diff --git a/request.go b/request.go index 9fe3f12..c5b8613 100644 --- a/request.go +++ b/request.go @@ -40,8 +40,15 @@ type addrSpec struct { port int } +func (a *addrSpec) String() string { + if a.fqdn != "" { + return fmt.Sprintf("%s (%s):%d", a.fqdn, a.ip, a.port) + } + return fmt.Sprintf("%s:%d", a.ip, a.port) +} + // handleRequest is used for request processing after authentication -func (s *Server) handleRequest(conn io.Writer, bufConn io.Reader) error { +func (s *Server) handleRequest(conn net.Conn, bufConn io.Reader) error { // Read the version byte header := []byte{0, 0, 0} if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil { @@ -64,6 +71,18 @@ func (s *Server) handleRequest(conn io.Writer, bufConn io.Reader) error { return fmt.Errorf("Failed to read destination address: %v", err) } + // Resolve the address if we have a FQDN + if dest.fqdn != "" { + addr, err := s.config.Resolver.Resolve(dest.fqdn) + if err != nil { + if err := sendReply(conn, hostUnreachable, nil); err != nil { + return fmt.Errorf("Failed to send reply: %v", err) + } + return fmt.Errorf("Failed to resolve destination '%v': %v", dest.fqdn, err) + } + dest.ip = addr + } + // Switch on the command switch header[1] { case connectCommand: @@ -78,17 +97,44 @@ func (s *Server) handleRequest(conn io.Writer, bufConn io.Reader) error { } // handleConnect is used to handle a connect command -func (s *Server) handleConnect(conn io.Writer, bufConn io.Reader, dest *addrSpec) error { +func (s *Server) handleConnect(conn net.Conn, bufConn io.Reader, dest *addrSpec) error { + // Check if this is allowed + client := conn.RemoteAddr().(*net.TCPAddr) + if !s.config.Rules.AllowConnect(dest.ip, dest.port, client.IP, client.Port) { + if err := sendReply(conn, ruleFailure, dest); err != nil { + return fmt.Errorf("Failed to send reply: %v", err) + } + return fmt.Errorf("Connect to %v blocked by rules", dest) + } + return nil } // handleBind is used to handle a connect command -func (s *Server) handleBind(conn io.Writer, bufConn io.Reader, dest *addrSpec) error { +func (s *Server) handleBind(conn net.Conn, bufConn io.Reader, dest *addrSpec) error { + // Check if this is allowed + client := conn.RemoteAddr().(*net.TCPAddr) + if !s.config.Rules.AllowBind(dest.ip, dest.port, client.IP, client.Port) { + if err := sendReply(conn, ruleFailure, dest); err != nil { + return fmt.Errorf("Failed to send reply: %v", err) + } + return fmt.Errorf("Bind to %v blocked by rules", dest) + } + return nil } // handleAssociate is used to handle a connect command -func (s *Server) handleAssociate(conn io.Writer, bufConn io.Reader, dest *addrSpec) error { +func (s *Server) handleAssociate(conn net.Conn, bufConn io.Reader, dest *addrSpec) error { + // Check if this is allowed + client := conn.RemoteAddr().(*net.TCPAddr) + if !s.config.Rules.AllowAssociate(dest.ip, dest.port, client.IP, client.Port) { + if err := sendReply(conn, ruleFailure, dest); err != nil { + return fmt.Errorf("Failed to send reply: %v", err) + } + return fmt.Errorf("Associate to %v blocked by rules", dest) + } + return nil }