diff --git a/request.go b/request.go index 82f1c4e..50e65c6 100644 --- a/request.go +++ b/request.go @@ -35,12 +35,12 @@ var ( unrecognizedAddrType = fmt.Errorf("Unrecognized address type") ) -// addrSpec is used to return the target addrSpec +// AddrSpec is used to return the target AddrSpec // which may be specified as IPv4, IPv6, or a FQDN -type addrSpec struct { - fqdn string - ip net.IP - port int +type AddrSpec struct { + FQDN string + IP net.IP + Port int } type conn interface { @@ -48,11 +48,11 @@ type conn interface { RemoteAddr() net.Addr } -func (a *addrSpec) String() string { - if a.fqdn != "" { - return fmt.Sprintf("%s (%s):%d", a.fqdn, a.ip, a.port) +func (a *AddrSpec) String() string { + if a.FQDN != "" { + return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port) } - return fmt.Sprintf("%s:%d", a.ip, a.port) + return fmt.Sprintf("%s:%d", a.IP, a.Port) } // handleRequest is used for request processing after authentication @@ -80,17 +80,19 @@ func (s *Server) handleRequest(conn conn, bufConn io.Reader) error { } // Resolve the address if we have a FQDN - if dest.fqdn != "" { - addr, err := s.config.Resolver.Resolve(dest.fqdn) + if dest.FQDN != "" { + addr, err := s.config.Resolver.Resolve(dest.FQDN) if err != nil { if err := sendReply(conn, hostUnreachable, dest); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } - return fmt.Errorf("Failed to resolve destination '%v': %v", dest.fqdn, err) + return fmt.Errorf("Failed to resolve destination '%v': %v", dest.FQDN, err) } - dest.ip = addr + dest.IP = addr } + // Check if we have a rewriter + // Switch on the command switch header[1] { case connectCommand: @@ -108,10 +110,10 @@ func (s *Server) handleRequest(conn conn, bufConn io.Reader) error { } // handleConnect is used to handle a connect command -func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest *addrSpec) error { +func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest *AddrSpec) error { // Check if this is allowed client := conn.RemoteAddr().(*net.TCPAddr) - if !s.config.Rules.AllowConnect(dest.ip, dest.port, client.IP, client.Port) { + if !s.config.Rules.AllowConnect(dest.IP, dest.Port, client.IP, client.Port) { if err := sendReply(conn, ruleFailure, dest); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } @@ -119,7 +121,7 @@ func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest *addrSpec) err } // Attempt to connect - addr := net.TCPAddr{IP: dest.ip, Port: dest.port} + addr := net.TCPAddr{IP: dest.IP, Port: dest.Port} target, err := net.DialTCP("tcp", nil, &addr) if err != nil { msg := err.Error() @@ -154,10 +156,10 @@ func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest *addrSpec) err } // handleBind is used to handle a connect command -func (s *Server) handleBind(conn conn, bufConn io.Reader, dest *addrSpec) error { +func (s *Server) handleBind(conn conn, bufConn io.Reader, dest *AddrSpec) error { // Check if this is allowed client := conn.RemoteAddr().(*net.TCPAddr) - if !s.config.Rules.AllowBind(dest.ip, dest.port, client.IP, client.Port) { + if !s.config.Rules.AllowBind(dest.IP, dest.Port, client.IP, client.Port) { if err := sendReply(conn, ruleFailure, dest); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } @@ -172,10 +174,10 @@ func (s *Server) handleBind(conn conn, bufConn io.Reader, dest *addrSpec) error } // handleAssociate is used to handle a connect command -func (s *Server) handleAssociate(conn conn, bufConn io.Reader, dest *addrSpec) error { +func (s *Server) handleAssociate(conn conn, bufConn io.Reader, dest *AddrSpec) error { // Check if this is allowed client := conn.RemoteAddr().(*net.TCPAddr) - if !s.config.Rules.AllowAssociate(dest.ip, dest.port, client.IP, client.Port) { + if !s.config.Rules.AllowAssociate(dest.IP, dest.Port, client.IP, client.Port) { if err := sendReply(conn, ruleFailure, dest); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } @@ -189,10 +191,10 @@ func (s *Server) handleAssociate(conn conn, bufConn io.Reader, dest *addrSpec) e return nil } -// readAddrSpec is used to read addrSpec. +// readAddrSpec is used to read AddrSpec. // Expects an address type byte, follwed by the address and port -func readAddrSpec(r io.Reader) (*addrSpec, error) { - d := &addrSpec{} +func readAddrSpec(r io.Reader) (*AddrSpec, error) { + d := &AddrSpec{} // Get the address type addrType := []byte{0} @@ -207,14 +209,14 @@ func readAddrSpec(r io.Reader) (*addrSpec, error) { if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil { return nil, err } - d.ip = net.IP(addr) + d.IP = net.IP(addr) case ipv6Address: addr := make([]byte, 16) if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil { return nil, err } - d.ip = net.IP(addr) + d.IP = net.IP(addr) case fqdnAddress: if _, err := r.Read(addrType); err != nil { @@ -225,7 +227,7 @@ func readAddrSpec(r io.Reader) (*addrSpec, error) { if _, err := io.ReadAtLeast(r, fqdn, addrLen); err != nil { return nil, err } - d.fqdn = string(fqdn) + d.FQDN = string(fqdn) default: return nil, unrecognizedAddrType @@ -236,13 +238,13 @@ func readAddrSpec(r io.Reader) (*addrSpec, error) { if _, err := io.ReadAtLeast(r, port, 2); err != nil { return nil, err } - d.port = int(binary.BigEndian.Uint16(port)) + d.Port = int(binary.BigEndian.Uint16(port)) return d, nil } // sendReply is used to send a reply message -func sendReply(w io.Writer, resp uint8, addr *addrSpec) error { +func sendReply(w io.Writer, resp uint8, addr *AddrSpec) error { // Format the address var addrType uint8 var addrBody []byte @@ -252,20 +254,20 @@ func sendReply(w io.Writer, resp uint8, addr *addrSpec) error { addrType = 0 addrBody = nil - case addr.fqdn != "": + case addr.FQDN != "": addrType = fqdnAddress - addrBody = append([]byte{byte(len(addr.fqdn))}, addr.fqdn...) - addrPort = uint16(addr.port) + addrBody = append([]byte{byte(len(addr.FQDN))}, addr.FQDN...) + addrPort = uint16(addr.Port) - case addr.ip.To4() != nil: + case addr.IP.To4() != nil: addrType = ipv4Address - addrBody = []byte(addr.ip.To4()) - addrPort = uint16(addr.port) + addrBody = []byte(addr.IP.To4()) + addrPort = uint16(addr.Port) - case addr.ip.To16() != nil: + case addr.IP.To16() != nil: addrType = ipv6Address - addrBody = []byte(addr.ip.To16()) - addrPort = uint16(addr.port) + addrBody = []byte(addr.IP.To16()) + addrPort = uint16(addr.Port) default: return fmt.Errorf("Failed to format address: %v", addr)