diff --git a/request.go b/request.go index 08f9452..d66813b 100644 --- a/request.go +++ b/request.go @@ -7,6 +7,7 @@ import ( "log" "net" "strings" + "time" ) const ( @@ -42,6 +43,11 @@ type addrSpec struct { port int } +type conn interface { + Write([]byte) (int, error) + RemoteAddr() net.Addr +} + func (a *addrSpec) String() string { if a.fqdn != "" { return fmt.Sprintf("%s (%s):%d", a.fqdn, a.ip, a.port) @@ -50,7 +56,7 @@ func (a *addrSpec) String() string { } // handleRequest is used for request processing after authentication -func (s *Server) handleRequest(conn net.Conn, bufConn io.Reader) error { +func (s *Server) handleRequest(conn conn, bufConn io.Reader) error { // Read the version byte header := []byte{0, 0, 0} if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil { @@ -102,7 +108,7 @@ func (s *Server) handleRequest(conn net.Conn, bufConn io.Reader) error { } // handleConnect is used to handle a connect command -func (s *Server) handleConnect(conn net.Conn, bufConn io.Reader, dest *addrSpec) error { +func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest *addrSpec) error { // Check if this is allowed client := conn.RemoteAddr().(*net.TCPAddr) if !s.config.Rules.AllowConnect(dest.ip, dest.port, client.IP, client.Port) { @@ -148,7 +154,7 @@ func (s *Server) handleConnect(conn net.Conn, bufConn io.Reader, dest *addrSpec) } // handleBind is used to handle a connect command -func (s *Server) handleBind(conn net.Conn, bufConn io.Reader, dest *addrSpec) error { +func (s *Server) handleBind(conn conn, bufConn io.Reader, dest *addrSpec) error { // Check if this is allowed client := conn.RemoteAddr().(*net.TCPAddr) if !s.config.Rules.AllowBind(dest.ip, dest.port, client.IP, client.Port) { @@ -166,7 +172,7 @@ func (s *Server) handleBind(conn net.Conn, bufConn io.Reader, dest *addrSpec) er } // handleAssociate is used to handle a connect command -func (s *Server) handleAssociate(conn net.Conn, bufConn io.Reader, dest *addrSpec) error { +func (s *Server) handleAssociate(conn conn, bufConn io.Reader, dest *addrSpec) error { // Check if this is allowed client := conn.RemoteAddr().(*net.TCPAddr) if !s.config.Rules.AllowAssociate(dest.ip, dest.port, client.IP, client.Port) { @@ -277,9 +283,15 @@ func sendReply(w io.Writer, resp uint8, addr *addrSpec) error { // proxy is used to suffle data from src to destination, and sends errors // down a dedicated channel -func proxy(name string, dst io.WriteCloser, src io.Reader, errCh chan error) { - defer dst.Close() +func proxy(name string, dst io.Writer, src io.Reader, errCh chan error) { + // Copy n, err := io.Copy(dst, src) - errCh <- err + + // Log, and sleep. This is jank but allows the otherside + // to finish a pending copy log.Printf("[DEBUG] Copied %d bytes for %s", n, name) + time.Sleep(10 * time.Millisecond) + + // Send any errors + errCh <- err } diff --git a/request_test.go b/request_test.go new file mode 100644 index 0000000..800d5ad --- /dev/null +++ b/request_test.go @@ -0,0 +1,86 @@ +package socks5 + +import ( + "bytes" + "encoding/binary" + "io" + "net" + "testing" +) + +type MockConn struct { + buf bytes.Buffer +} + +func (m *MockConn) Write(b []byte) (int, error) { + return m.buf.Write(b) +} + +func (m *MockConn) RemoteAddr() net.Addr { + return &net.TCPAddr{IP: []byte{127, 0, 0, 1}, Port: 65432} +} + +func TestRequest_Connect(t *testing.T) { + // Create a local listener + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %v", err) + } + go func() { + conn, err := l.Accept() + if err != nil { + t.Fatalf("err: %v", err) + } + defer conn.Close() + + buf := make([]byte, 4) + if _, err := io.ReadAtLeast(conn, buf, 4); err != nil { + t.Fatalf("err: %v", err) + } + + if !bytes.Equal(buf, []byte("ping")) { + t.Fatalf("bad: %v", buf) + } + conn.Write([]byte("pong")) + }() + lAddr := l.Addr().(*net.TCPAddr) + + // Make server + s := &Server{config: &Config{ + Rules: PermitAll(), + Resolver: DNSResolver{}, + }} + + // Create the connect request + req := bytes.NewBuffer(nil) + req.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1}) + + port := []byte{0, 0} + binary.BigEndian.PutUint16(port, uint16(lAddr.Port)) + req.Write(port) + + // Send a ping + req.Write([]byte("ping")) + + // Handle the request + resp := &MockConn{} + if err := s.handleRequest(resp, req); err != nil { + t.Fatalf("err: %v", err) + } + + // Verify response + out := resp.buf.Bytes() + expected := []byte{ + 5, + 0, + 0, + 1, + 127, 0, 0, 1, + port[0], + port[1], + 'p', 'o', 'n', 'g', + } + if !bytes.Equal(out, expected) { + t.Fatalf("bad: %v", out) + } +}