1
0
Fork 0

Testing the CONNECT handling

logger
Armon Dadgar 2014-01-23 14:07:39 -08:00
parent 5437f80e57
commit 9aca0ed614
2 changed files with 105 additions and 7 deletions

View File

@ -7,6 +7,7 @@ import (
"log"
"net"
"strings"
"time"
)
const (
@ -42,6 +43,11 @@ type addrSpec struct {
port int
}
type conn interface {
Write([]byte) (int, error)
RemoteAddr() net.Addr
}
func (a *addrSpec) String() string {
if a.fqdn != "" {
return fmt.Sprintf("%s (%s):%d", a.fqdn, a.ip, a.port)
@ -50,7 +56,7 @@ func (a *addrSpec) String() string {
}
// handleRequest is used for request processing after authentication
func (s *Server) handleRequest(conn net.Conn, bufConn io.Reader) error {
func (s *Server) handleRequest(conn conn, bufConn io.Reader) error {
// Read the version byte
header := []byte{0, 0, 0}
if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil {
@ -102,7 +108,7 @@ func (s *Server) handleRequest(conn net.Conn, bufConn io.Reader) error {
}
// handleConnect is used to handle a connect command
func (s *Server) handleConnect(conn net.Conn, bufConn io.Reader, dest *addrSpec) error {
func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest *addrSpec) error {
// Check if this is allowed
client := conn.RemoteAddr().(*net.TCPAddr)
if !s.config.Rules.AllowConnect(dest.ip, dest.port, client.IP, client.Port) {
@ -148,7 +154,7 @@ func (s *Server) handleConnect(conn net.Conn, bufConn io.Reader, dest *addrSpec)
}
// handleBind is used to handle a connect command
func (s *Server) handleBind(conn net.Conn, bufConn io.Reader, dest *addrSpec) error {
func (s *Server) handleBind(conn conn, bufConn io.Reader, dest *addrSpec) error {
// Check if this is allowed
client := conn.RemoteAddr().(*net.TCPAddr)
if !s.config.Rules.AllowBind(dest.ip, dest.port, client.IP, client.Port) {
@ -166,7 +172,7 @@ func (s *Server) handleBind(conn net.Conn, bufConn io.Reader, dest *addrSpec) er
}
// handleAssociate is used to handle a connect command
func (s *Server) handleAssociate(conn net.Conn, bufConn io.Reader, dest *addrSpec) error {
func (s *Server) handleAssociate(conn conn, bufConn io.Reader, dest *addrSpec) error {
// Check if this is allowed
client := conn.RemoteAddr().(*net.TCPAddr)
if !s.config.Rules.AllowAssociate(dest.ip, dest.port, client.IP, client.Port) {
@ -277,9 +283,15 @@ func sendReply(w io.Writer, resp uint8, addr *addrSpec) error {
// proxy is used to suffle data from src to destination, and sends errors
// down a dedicated channel
func proxy(name string, dst io.WriteCloser, src io.Reader, errCh chan error) {
defer dst.Close()
func proxy(name string, dst io.Writer, src io.Reader, errCh chan error) {
// Copy
n, err := io.Copy(dst, src)
errCh <- err
// Log, and sleep. This is jank but allows the otherside
// to finish a pending copy
log.Printf("[DEBUG] Copied %d bytes for %s", n, name)
time.Sleep(10 * time.Millisecond)
// Send any errors
errCh <- err
}

86
request_test.go Normal file
View File

@ -0,0 +1,86 @@
package socks5
import (
"bytes"
"encoding/binary"
"io"
"net"
"testing"
)
type MockConn struct {
buf bytes.Buffer
}
func (m *MockConn) Write(b []byte) (int, error) {
return m.buf.Write(b)
}
func (m *MockConn) RemoteAddr() net.Addr {
return &net.TCPAddr{IP: []byte{127, 0, 0, 1}, Port: 65432}
}
func TestRequest_Connect(t *testing.T) {
// Create a local listener
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}
go func() {
conn, err := l.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
buf := make([]byte, 4)
if _, err := io.ReadAtLeast(conn, buf, 4); err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(buf, []byte("ping")) {
t.Fatalf("bad: %v", buf)
}
conn.Write([]byte("pong"))
}()
lAddr := l.Addr().(*net.TCPAddr)
// Make server
s := &Server{config: &Config{
Rules: PermitAll(),
Resolver: DNSResolver{},
}}
// Create the connect request
req := bytes.NewBuffer(nil)
req.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1})
port := []byte{0, 0}
binary.BigEndian.PutUint16(port, uint16(lAddr.Port))
req.Write(port)
// Send a ping
req.Write([]byte("ping"))
// Handle the request
resp := &MockConn{}
if err := s.handleRequest(resp, req); err != nil {
t.Fatalf("err: %v", err)
}
// Verify response
out := resp.buf.Bytes()
expected := []byte{
5,
0,
0,
1,
127, 0, 0, 1,
port[0],
port[1],
'p', 'o', 'n', 'g',
}
if !bytes.Equal(out, expected) {
t.Fatalf("bad: %v", out)
}
}