From 85e909eb88f6fb7a884f615cef60850e591592ce Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Fri, 24 Jan 2014 14:31:58 -0800 Subject: [PATCH] Fixing bind port return value --- request.go | 26 ++++++++++++++------------ request_test.go | 10 +++++++--- socks5_test.go | 7 +++++-- 3 files changed, 26 insertions(+), 17 deletions(-) diff --git a/request.go b/request.go index 3b0f0da..c349d5d 100644 --- a/request.go +++ b/request.go @@ -1,7 +1,6 @@ package socks5 import ( - "encoding/binary" "fmt" "io" "log" @@ -88,7 +87,7 @@ func (s *Server) handleRequest(conn conn, bufConn io.Reader) error { if dest.FQDN != "" { addr, err := s.config.Resolver.Resolve(dest.FQDN) if err != nil { - if err := sendReply(conn, hostUnreachable, dest); err != nil { + if err := sendReply(conn, hostUnreachable, nil); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } return fmt.Errorf("Failed to resolve destination '%v': %v", dest.FQDN, err) @@ -111,7 +110,7 @@ func (s *Server) handleRequest(conn conn, bufConn io.Reader) error { case associateCommand: return s.handleAssociate(conn, bufConn, dest, realDest) default: - if err := sendReply(conn, commandNotSupported, dest); err != nil { + if err := sendReply(conn, commandNotSupported, nil); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } return fmt.Errorf("Unsupported command: %v", header[1]) @@ -123,7 +122,7 @@ func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest, realDest *Add // Check if this is allowed client := conn.RemoteAddr().(*net.TCPAddr) if !s.config.Rules.AllowConnect(realDest.IP, realDest.Port, client.IP, client.Port) { - if err := sendReply(conn, ruleFailure, dest); err != nil { + if err := sendReply(conn, ruleFailure, nil); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } return fmt.Errorf("Connect to %v blocked by rules", dest) @@ -140,7 +139,7 @@ func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest, realDest *Add } else if strings.Contains(msg, "network is unreachable") { resp = networkUnreachable } - if err := sendReply(conn, resp, dest); err != nil { + if err := sendReply(conn, resp, nil); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } return fmt.Errorf("Connect to %v failed: %v", dest, err) @@ -148,7 +147,9 @@ func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest, realDest *Add defer target.Close() // Send success - if err := sendReply(conn, successReply, dest); err != nil { + local := target.LocalAddr().(*net.TCPAddr) + bind := AddrSpec{IP: local.IP, Port: local.Port} + if err := sendReply(conn, successReply, &bind); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } @@ -169,14 +170,14 @@ func (s *Server) handleBind(conn conn, bufConn io.Reader, dest, realDest *AddrSp // Check if this is allowed client := conn.RemoteAddr().(*net.TCPAddr) if !s.config.Rules.AllowBind(realDest.IP, realDest.Port, client.IP, client.Port) { - if err := sendReply(conn, ruleFailure, dest); err != nil { + if err := sendReply(conn, ruleFailure, nil); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } return fmt.Errorf("Bind to %v blocked by rules", dest) } // TODO: Support bind - if err := sendReply(conn, commandNotSupported, dest); err != nil { + if err := sendReply(conn, commandNotSupported, nil); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } return nil @@ -187,14 +188,14 @@ func (s *Server) handleAssociate(conn conn, bufConn io.Reader, dest, realDest *A // Check if this is allowed client := conn.RemoteAddr().(*net.TCPAddr) if !s.config.Rules.AllowAssociate(realDest.IP, realDest.Port, client.IP, client.Port) { - if err := sendReply(conn, ruleFailure, dest); err != nil { + if err := sendReply(conn, ruleFailure, nil); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } return fmt.Errorf("Associate to %v blocked by rules", dest) } // TODO: Support associate - if err := sendReply(conn, commandNotSupported, dest); err != nil { + if err := sendReply(conn, commandNotSupported, nil); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } return nil @@ -247,7 +248,7 @@ func readAddrSpec(r io.Reader) (*AddrSpec, error) { if _, err := io.ReadAtLeast(r, port, 2); err != nil { return nil, err } - d.Port = int(binary.BigEndian.Uint16(port)) + d.Port = (int(port[0]) << 8) | int(port[1]) return d, nil } @@ -290,7 +291,8 @@ func sendReply(w io.Writer, resp uint8, addr *AddrSpec) error { msg[2] = 0 // Reserved msg[3] = addrType copy(msg[4:], addrBody) - binary.BigEndian.PutUint16(msg[4+len(addrBody):], uint16(addrPort)) + msg[4+len(addrBody)] = byte(addrPort >> 8) + msg[4+len(addrBody)+1] = byte(addrPort & 0xff) // Send the message _, err := w.Write(msg) diff --git a/request_test.go b/request_test.go index 800d5ad..820bfac 100644 --- a/request_test.go +++ b/request_test.go @@ -76,11 +76,15 @@ func TestRequest_Connect(t *testing.T) { 0, 1, 127, 0, 0, 1, - port[0], - port[1], + 0, 0, 'p', 'o', 'n', 'g', } + + // Ignore the port for both + out[8] = 0 + out[9] = 0 + if !bytes.Equal(out, expected) { - t.Fatalf("bad: %v", out) + t.Fatalf("bad: %v %v", out, expected) } } diff --git a/socks5_test.go b/socks5_test.go index 5f32e28..4fa724f 100644 --- a/socks5_test.go +++ b/socks5_test.go @@ -86,8 +86,7 @@ func TestSOCKS5_Connect(t *testing.T) { 0, 1, 127, 0, 0, 1, - port[0], - port[1], + 0, 0, 'p', 'o', 'n', 'g', } out := make([]byte, len(expected)) @@ -97,6 +96,10 @@ func TestSOCKS5_Connect(t *testing.T) { t.Fatalf("err: %v", err) } + // Ignore the port + out[12] = 0 + out[13] = 0 + if !bytes.Equal(out, expected) { t.Fatalf("bad: %v", out) }