Refactor handleRequest using new Request struct
parent
353e906a7b
commit
ae345732fa
95
request.go
95
request.go
|
@ -10,9 +10,9 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
connectCommand = uint8(1)
|
||||
bindCommand = uint8(2)
|
||||
associateCommand = uint8(3)
|
||||
ConnectCommand = uint8(1)
|
||||
BindCommand = uint8(2)
|
||||
AssociateCommand = uint8(3)
|
||||
ipv4Address = uint8(1)
|
||||
fqdnAddress = uint8(3)
|
||||
ipv6Address = uint8(4)
|
||||
|
@ -47,6 +47,13 @@ type AddrSpec struct {
|
|||
Port int
|
||||
}
|
||||
|
||||
func (a *AddrSpec) String() string {
|
||||
if a.FQDN != "" {
|
||||
return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port)
|
||||
}
|
||||
return fmt.Sprintf("%s:%d", a.IP, a.Port)
|
||||
}
|
||||
|
||||
// A Request represents request received by a server
|
||||
type Request struct {
|
||||
// Protocol version
|
||||
|
@ -69,38 +76,39 @@ type conn interface {
|
|||
RemoteAddr() net.Addr
|
||||
}
|
||||
|
||||
func (a *AddrSpec) String() string {
|
||||
if a.FQDN != "" {
|
||||
return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port)
|
||||
}
|
||||
return fmt.Sprintf("%s:%d", a.IP, a.Port)
|
||||
}
|
||||
|
||||
// handleRequest is used for request processing after authentication
|
||||
func (s *Server) handleRequest(conn conn, bufConn io.Reader) error {
|
||||
// NewRequest creates a new Request from the tcp connection
|
||||
func NewRequest(bufConn io.Reader) (*Request, error) {
|
||||
// Read the version byte
|
||||
header := []byte{0, 0, 0}
|
||||
if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil {
|
||||
return fmt.Errorf("Failed to get command version: %v", err)
|
||||
return nil, fmt.Errorf("Failed to get command version: %v", err)
|
||||
}
|
||||
|
||||
// Ensure we are compatible
|
||||
if header[0] != socks5Version {
|
||||
return fmt.Errorf("Unsupported command version: %v", header[0])
|
||||
return nil, fmt.Errorf("Unsupported command version: %v", header[0])
|
||||
}
|
||||
|
||||
// Read in the destination address
|
||||
dest, err := readAddrSpec(bufConn)
|
||||
if err != nil {
|
||||
if err == unrecognizedAddrType {
|
||||
if err := sendReply(conn, addrTypeNotSupported, nil); err != nil {
|
||||
return fmt.Errorf("Failed to send reply: %v", err)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("Failed to read destination address: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
request := &Request{
|
||||
Version: socks5Version,
|
||||
Command: header[1],
|
||||
DestAddr: dest,
|
||||
bufConn: bufConn,
|
||||
}
|
||||
|
||||
return request, nil
|
||||
}
|
||||
|
||||
// handleRequest is used for request processing after authentication
|
||||
func (s *Server) handleRequest(req *Request, conn conn) error {
|
||||
// Resolve the address if we have a FQDN
|
||||
dest := req.DestAddr
|
||||
if dest.FQDN != "" {
|
||||
addr, err := s.config.Resolver.Resolve(dest.FQDN)
|
||||
if err != nil {
|
||||
|
@ -113,40 +121,39 @@ func (s *Server) handleRequest(conn conn, bufConn io.Reader) error {
|
|||
}
|
||||
|
||||
// Apply any address rewrites
|
||||
realDest := dest
|
||||
req.realDestAddr = req.DestAddr
|
||||
if s.config.Rewriter != nil {
|
||||
realDest = s.config.Rewriter.Rewrite(dest)
|
||||
req.realDestAddr = s.config.Rewriter.Rewrite(req)
|
||||
}
|
||||
|
||||
// Switch on the command
|
||||
switch header[1] {
|
||||
case connectCommand:
|
||||
return s.handleConnect(conn, bufConn, dest, realDest)
|
||||
case bindCommand:
|
||||
return s.handleBind(conn, bufConn, dest, realDest)
|
||||
case associateCommand:
|
||||
return s.handleAssociate(conn, bufConn, dest, realDest)
|
||||
switch req.Command {
|
||||
case ConnectCommand:
|
||||
return s.handleConnect(conn, req)
|
||||
case BindCommand:
|
||||
return s.handleBind(conn, req)
|
||||
case AssociateCommand:
|
||||
return s.handleAssociate(conn, req)
|
||||
default:
|
||||
if err := sendReply(conn, commandNotSupported, nil); err != nil {
|
||||
return fmt.Errorf("Failed to send reply: %v", err)
|
||||
}
|
||||
return fmt.Errorf("Unsupported command: %v", header[1])
|
||||
return fmt.Errorf("Unsupported command: %v", req.Command)
|
||||
}
|
||||
}
|
||||
|
||||
// handleConnect is used to handle a connect command
|
||||
func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest, realDest *AddrSpec) error {
|
||||
func (s *Server) handleConnect(conn conn, req *Request) error {
|
||||
// Check if this is allowed
|
||||
client := conn.RemoteAddr().(*net.TCPAddr)
|
||||
if !s.config.Rules.AllowConnect(realDest.IP, realDest.Port, client.IP, client.Port) {
|
||||
if !s.config.Rules.Allow(req) {
|
||||
if err := sendReply(conn, ruleFailure, nil); err != nil {
|
||||
return fmt.Errorf("Failed to send reply: %v", err)
|
||||
}
|
||||
return fmt.Errorf("Connect to %v blocked by rules", dest)
|
||||
return fmt.Errorf("Connect to %v blocked by rules", req.DestAddr)
|
||||
}
|
||||
|
||||
// Attempt to connect
|
||||
addr := net.TCPAddr{IP: realDest.IP, Port: realDest.Port}
|
||||
addr := net.TCPAddr{IP: req.realDestAddr.IP, Port: req.realDestAddr.Port}
|
||||
target, err := net.DialTCP("tcp", nil, &addr)
|
||||
if err != nil {
|
||||
msg := err.Error()
|
||||
|
@ -159,7 +166,7 @@ func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest, realDest *Add
|
|||
if err := sendReply(conn, resp, nil); err != nil {
|
||||
return fmt.Errorf("Failed to send reply: %v", err)
|
||||
}
|
||||
return fmt.Errorf("Connect to %v failed: %v", dest, err)
|
||||
return fmt.Errorf("Connect to %v failed: %v", req.DestAddr, err)
|
||||
}
|
||||
defer target.Close()
|
||||
|
||||
|
@ -172,7 +179,7 @@ func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest, realDest *Add
|
|||
|
||||
// Start proxying
|
||||
errCh := make(chan error, 2)
|
||||
go proxy("target", target, bufConn, errCh, s.config.Logger)
|
||||
go proxy("target", target, req.bufConn, errCh, s.config.Logger)
|
||||
go proxy("client", conn, target, errCh, s.config.Logger)
|
||||
|
||||
// Wait
|
||||
|
@ -183,14 +190,13 @@ func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest, realDest *Add
|
|||
}
|
||||
|
||||
// handleBind is used to handle a connect command
|
||||
func (s *Server) handleBind(conn conn, bufConn io.Reader, dest, realDest *AddrSpec) error {
|
||||
func (s *Server) handleBind(conn conn, req *Request) error {
|
||||
// Check if this is allowed
|
||||
client := conn.RemoteAddr().(*net.TCPAddr)
|
||||
if !s.config.Rules.AllowBind(realDest.IP, realDest.Port, client.IP, client.Port) {
|
||||
if !s.config.Rules.Allow(req) {
|
||||
if err := sendReply(conn, ruleFailure, nil); err != nil {
|
||||
return fmt.Errorf("Failed to send reply: %v", err)
|
||||
}
|
||||
return fmt.Errorf("Bind to %v blocked by rules", dest)
|
||||
return fmt.Errorf("Bind to %v blocked by rules", req.DestAddr)
|
||||
}
|
||||
|
||||
// TODO: Support bind
|
||||
|
@ -201,14 +207,13 @@ func (s *Server) handleBind(conn conn, bufConn io.Reader, dest, realDest *AddrSp
|
|||
}
|
||||
|
||||
// handleAssociate is used to handle a connect command
|
||||
func (s *Server) handleAssociate(conn conn, bufConn io.Reader, dest, realDest *AddrSpec) error {
|
||||
func (s *Server) handleAssociate(conn conn, req *Request) error {
|
||||
// Check if this is allowed
|
||||
client := conn.RemoteAddr().(*net.TCPAddr)
|
||||
if !s.config.Rules.AllowAssociate(realDest.IP, realDest.Port, client.IP, client.Port) {
|
||||
if !s.config.Rules.Allow(req) {
|
||||
if err := sendReply(conn, ruleFailure, nil); err != nil {
|
||||
return fmt.Errorf("Failed to send reply: %v", err)
|
||||
}
|
||||
return fmt.Errorf("Associate to %v blocked by rules", dest)
|
||||
return fmt.Errorf("Associate to %v blocked by rules", req.DestAddr)
|
||||
}
|
||||
|
||||
// TODO: Support associate
|
||||
|
|
|
@ -56,19 +56,24 @@ func TestRequest_Connect(t *testing.T) {
|
|||
}}
|
||||
|
||||
// Create the connect request
|
||||
req := bytes.NewBuffer(nil)
|
||||
req.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1})
|
||||
buf := bytes.NewBuffer(nil)
|
||||
buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1})
|
||||
|
||||
port := []byte{0, 0}
|
||||
binary.BigEndian.PutUint16(port, uint16(lAddr.Port))
|
||||
req.Write(port)
|
||||
buf.Write(port)
|
||||
|
||||
// Send a ping
|
||||
req.Write([]byte("ping"))
|
||||
buf.Write([]byte("ping"))
|
||||
|
||||
// Handle the request
|
||||
resp := &MockConn{}
|
||||
if err := s.handleRequest(resp, req); err != nil {
|
||||
req, err := NewRequest(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if err := s.handleRequest(req, resp); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -126,19 +131,24 @@ func TestRequest_Connect_RuleFail(t *testing.T) {
|
|||
}}
|
||||
|
||||
// Create the connect request
|
||||
req := bytes.NewBuffer(nil)
|
||||
req.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1})
|
||||
buf := bytes.NewBuffer(nil)
|
||||
buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1})
|
||||
|
||||
port := []byte{0, 0}
|
||||
binary.BigEndian.PutUint16(port, uint16(lAddr.Port))
|
||||
req.Write(port)
|
||||
buf.Write(port)
|
||||
|
||||
// Send a ping
|
||||
req.Write([]byte("ping"))
|
||||
buf.Write([]byte("ping"))
|
||||
|
||||
// Handle the request
|
||||
resp := &MockConn{}
|
||||
if err := s.handleRequest(resp, req); !strings.Contains(err.Error(), "blocked by rules") {
|
||||
req, err := NewRequest(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if err := s.handleRequest(req, resp); !strings.Contains(err.Error(), "blocked by rules") {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue