diff --git a/request.go b/request.go index c349d5d..36dc739 100644 --- a/request.go +++ b/request.go @@ -19,7 +19,7 @@ const ( ) const ( - successReply uint8 = 0 + successReply uint8 = iota serverFailure ruleFailure networkUnreachable diff --git a/request_test.go b/request_test.go index 820bfac..ce1db9d 100644 --- a/request_test.go +++ b/request_test.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "io" "net" + "strings" "testing" ) @@ -88,3 +89,67 @@ func TestRequest_Connect(t *testing.T) { t.Fatalf("bad: %v %v", out, expected) } } + +func TestRequest_Connect_RuleFail(t *testing.T) { + // Create a local listener + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %v", err) + } + go func() { + conn, err := l.Accept() + if err != nil { + t.Fatalf("err: %v", err) + } + defer conn.Close() + + buf := make([]byte, 4) + if _, err := io.ReadAtLeast(conn, buf, 4); err != nil { + t.Fatalf("err: %v", err) + } + + if !bytes.Equal(buf, []byte("ping")) { + t.Fatalf("bad: %v", buf) + } + conn.Write([]byte("pong")) + }() + lAddr := l.Addr().(*net.TCPAddr) + + // Make server + s := &Server{config: &Config{ + Rules: PermitNone(), + Resolver: DNSResolver{}, + }} + + // Create the connect request + req := bytes.NewBuffer(nil) + req.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1}) + + port := []byte{0, 0} + binary.BigEndian.PutUint16(port, uint16(lAddr.Port)) + req.Write(port) + + // Send a ping + req.Write([]byte("ping")) + + // Handle the request + resp := &MockConn{} + if err := s.handleRequest(resp, req); !strings.Contains(err.Error(), "blocked by rules") { + t.Fatalf("err: %v", err) + } + + // Verify response + out := resp.buf.Bytes() + expected := []byte{ + 5, + 2, + 0, + 1, + 0, 0, 0, 0, + 0, 0, + } + + if !bytes.Equal(out, expected) { + t.Fatalf("bad: %v %v", out, expected) + } +} diff --git a/ruleset.go b/ruleset.go index 18f22a0..1b19a9a 100644 --- a/ruleset.go +++ b/ruleset.go @@ -16,11 +16,16 @@ type RuleSet interface { AllowAssociate(dstIP net.IP, dstPort int, srcIP net.IP, srcPort int) bool } -// PermitAll is an returns a RuleSet which allows all types of connections +// PermitAll returns a RuleSet which allows all types of connections func PermitAll() RuleSet { return &PermitCommand{true, true, true} } +// PermitNone returns a RuleSet which disallows all types of connections +func PermitNone() RuleSet { + return &PermitCommand{false, false, false} +} + // PermitCommand is an implementation of the RuleSet which // enables filtering supported commands type PermitCommand struct {