package socks5

import (
	"bytes"
	"fmt"
	"io"
	"log"
	"math/big"
	"net"
	"strconv"
	"strings"

	"golang.org/x/net/context"
)

// Command codes a client may request, and the address type markers that
// prefix an address on the wire.
const (
	ConnectCommand   = uint8(1)
	BindCommand      = uint8(2)
	AssociateCommand = uint8(3)
	ipv4Address      = uint8(1)
	fqdnAddress      = uint8(3)
	ipv6Address      = uint8(4)
)

// Reply codes written into the second byte of a reply message.
const (
	successReply uint8 = iota
	serverFailure
	ruleFailure
	networkUnreachable
	hostUnreachable
	connectionRefused
	ttlExpired
	commandNotSupported
	addrTypeNotSupported
)

var (
	// unrecognizedAddrType is returned by readAddrSpec when the address type
	// byte is none of ipv4Address, fqdnAddress or ipv6Address.
	unrecognizedAddrType = fmt.Errorf("Unrecognized address type")
)

// AddressRewriter is used to rewrite a destination transparently
type AddressRewriter interface {
	Rewrite(ctx context.Context, request *Request) (context.Context, *AddrSpec)
}

// AddrSpec is used to return the target AddrSpec
// which may be specified as IPv4, IPv6, or a FQDN
type AddrSpec struct {
	FQDN string
	IP   net.IP
	Port int
}

// String formats the address for display. When a FQDN is set the output is
// "fqdn (ip):port", otherwise "ip:port".
func (a *AddrSpec) String() string {
	if a.FQDN != "" {
		return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port)
	}
	return fmt.Sprintf("%s:%d", a.IP, a.Port)
}

// Address returns a string suitable to dial; prefer returning IP-based
// address, fallback to FQDN
func (a AddrSpec) Address() string {
	if 0 != len(a.IP) {
		return net.JoinHostPort(a.IP.String(), strconv.Itoa(a.Port))
	}
	return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port))
}

// A Request represents request received by a server
type Request struct {
	// Protocol version
	Version uint8
	// Requested command
	Command uint8
	// AuthContext provided during negotiation
	AuthContext *AuthContext
	// AddrSpec of the network that sent the request
	RemoteAddr *AddrSpec
	// AddrSpec of the desired destination
	DestAddr *AddrSpec
	// AddrSpec of the actual destination (might be affected by rewrite)
	realDestAddr *AddrSpec
	// Reader the request was parsed from; handleConnect keeps reading the
	// client's payload from it once the connection is established.
	bufConn io.Reader
}

// conn is the subset of a client connection the handlers need: the ability
// to write replies/payload and to report the remote address.
type conn interface {
	Write([]byte) (int, error)
	RemoteAddr() net.Addr
}

// NewRequest creates a new Request from the tcp connection.
//
// It reads a three byte header (version, command and a third byte that is
// not inspected) followed by the destination address. An error is returned
// when the header cannot be read, when the version is not socks5Version, or
// when the address cannot be parsed.
func NewRequest(bufConn io.Reader) (*Request, error) {
	// Read the version byte
	header := []byte{0, 0, 0}
	if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil {
		return nil, fmt.Errorf("Failed to get command version: %v", err)
	}

	// Ensure we are compatible
	if header[0] != socks5Version {
		return nil, fmt.Errorf("Unsupported command version: %v", header[0])
	}

	// Read in the destination address
	dest, err := readAddrSpec(bufConn)
	if err != nil {
		return nil, err
	}

	request := &Request{
		Version:  socks5Version,
		Command:  header[1],
		DestAddr: dest,
		bufConn:  bufConn,
	}

	return request, nil
}

// handleRequest is used for request processing after authentication.
//
// It resolves a FQDN destination through the configured resolver, applies
// the optional address rewriter and then dispatches on the request command.
// A reply is sent to the client on every failure path handled here.
func (s *Server) handleRequest(req *Request, conn conn) error {
	ctx := context.Background()

	// Resolve the address if we have a FQDN
	dest := req.DestAddr
	if dest.FQDN != "" {
		resolvedCtx, addr, err := s.config.Resolver.Resolve(ctx, dest.FQDN)
		if err != nil {
			if err := sendReply(conn, hostUnreachable, nil); err != nil {
				return fmt.Errorf("Failed to send reply: %v", err)
			}
			return fmt.Errorf("Failed to resolve destination '%v': %v", dest.FQDN, err)
		}
		ctx = resolvedCtx
		// dest aliases req.DestAddr, so this also updates the request.
		dest.IP = addr
	}

	// Apply any address rewrites
	req.realDestAddr = req.DestAddr
	if s.config.Rewriter != nil {
		ctx, req.realDestAddr = s.config.Rewriter.Rewrite(ctx, req)
	}

	// Switch on the command
	switch req.Command {
	case ConnectCommand:
		return s.handleConnect(ctx, conn, req)
	case BindCommand:
		return s.handleBind(ctx, conn, req)
	case AssociateCommand:
		return s.handleAssociate(ctx, conn, req)
	default:
		if err := sendReply(conn, commandNotSupported, nil); err != nil {
			return fmt.Errorf("Failed to send reply: %v", err)
		}
		return fmt.Errorf("Unsupported command: %v", req.Command)
	}
}

// handleConnect is used to handle a connect command.
//
// After the rule set allows the request it dials the (possibly rewritten)
// destination, reports the outcome to the client and then copies data in
// both directions until one direction fails or both have finished.
func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) error {
	// Check if this is allowed
	allowedCtx, ok := s.config.Rules.Allow(ctx, req)
	if !ok {
		if err := sendReply(conn, ruleFailure, nil); err != nil {
			return fmt.Errorf("Failed to send reply: %v", err)
		}
		return fmt.Errorf("Connect to %v blocked by rules", req.DestAddr)
	}
	ctx = allowedCtx

	// Attempt to connect
	dial := s.config.Dial
	if dial == nil {
		dial = func(ctx context.Context, network, addr string) (net.Conn, error) {
			return net.Dial(network, addr)
		}
	}
	target, err := dial(ctx, "tcp", req.realDestAddr.Address())
	if err != nil {
		// Map the dial error text onto the closest reply code.
		msg := err.Error()
		resp := hostUnreachable
		if strings.Contains(msg, "refused") {
			resp = connectionRefused
		} else if strings.Contains(msg, "network is unreachable") {
			resp = networkUnreachable
		}
		if err := sendReply(conn, resp, nil); err != nil {
			return fmt.Errorf("Failed to send reply: %v", err)
		}
		return fmt.Errorf("Connect to %v failed: %v", req.DestAddr, err)
	}
	defer target.Close()

	// Send success. A custom Dial may hand back a connection whose local
	// address is not a *net.TCPAddr; in that case reply with the zero
	// address (nil AddrSpec) instead of panicking on the type assertion.
	var bind *AddrSpec
	if local, ok := target.LocalAddr().(*net.TCPAddr); ok {
		bind = &AddrSpec{IP: local.IP, Port: local.Port}
	}
	if err := sendReply(conn, successReply, bind); err != nil {
		return fmt.Errorf("Failed to send reply: %v", err)
	}

	// Start proxying. The channel is buffered for both goroutines so neither
	// blocks on send if this function returns after the first error.
	errCh := make(chan error, 2)
	go proxy(target, req.bufConn, errCh)
	go proxy(conn, target, errCh)

	// Wait
	for i := 0; i < 2; i++ {
		e := <-errCh
		if e != nil {
			// return from this function closes target (and conn).
			return e
		}
	}
	return nil
}

// handleBind is used to handle a bind command.
//
// Bind is not implemented: once the rule set allows the request the client
// is told that the command is not supported.
func (s *Server) handleBind(ctx context.Context, conn conn, req *Request) error {
	// Check if this is allowed. The context returned by Allow is not needed
	// until bind is actually supported.
	if _, ok := s.config.Rules.Allow(ctx, req); !ok {
		if err := sendReply(conn, ruleFailure, nil); err != nil {
			return fmt.Errorf("Failed to send reply: %v", err)
		}
		return fmt.Errorf("Bind to %v blocked by rules", req.DestAddr)
	}

	// TODO: Support bind
	if err := sendReply(conn, commandNotSupported, nil); err != nil {
		return fmt.Errorf("Failed to send reply: %v", err)
	}
	return nil
}

// handleAssociate is used to handle an associate command.
//
// Associate is not implemented: once the rule set allows the request the
// client is told that the command is not supported.
func (s *Server) handleAssociate(ctx context.Context, conn conn, req *Request) error {
	// Check if this is allowed. The context returned by Allow is not needed
	// until associate is actually supported.
	if _, ok := s.config.Rules.Allow(ctx, req); !ok {
		if err := sendReply(conn, ruleFailure, nil); err != nil {
			return fmt.Errorf("Failed to send reply: %v", err)
		}
		return fmt.Errorf("Associate to %v blocked by rules", req.DestAddr)
	}

	// TODO: Support associate
	if err := sendReply(conn, commandNotSupported, nil); err != nil {
		return fmt.Errorf("Failed to send reply: %v", err)
	}
	return nil
}

// readAddrSpec is used to read AddrSpec.
// Expects an address type byte, followed by the address and port.
//
// Returns unrecognizedAddrType for an unknown address type and the
// underlying read error when the input ends early.
func readAddrSpec(r io.Reader) (*AddrSpec, error) {
	d := &AddrSpec{}

	// Get the address type. io.ReadFull is used for every field because a
	// bare Read is allowed to return fewer bytes than requested.
	addrType := []byte{0}
	if _, err := io.ReadFull(r, addrType); err != nil {
		return nil, err
	}

	// Handle on a per type basis
	switch addrType[0] {
	case ipv4Address:
		addr := make([]byte, 4)
		if _, err := io.ReadFull(r, addr); err != nil {
			return nil, err
		}
		d.IP = net.IP(addr)

	case ipv6Address:
		addr := make([]byte, 16)
		if _, err := io.ReadFull(r, addr); err != nil {
			return nil, err
		}
		d.IP = net.IP(addr)

	case fqdnAddress:
		// One length byte, then that many bytes of name.
		nameLen := []byte{0}
		if _, err := io.ReadFull(r, nameLen); err != nil {
			return nil, err
		}
		fqdn := make([]byte, int(nameLen[0]))
		if _, err := io.ReadFull(r, fqdn); err != nil {
			return nil, err
		}
		d.FQDN = string(fqdn)

	default:
		return nil, unrecognizedAddrType
	}

	// Read the port (two bytes, most significant first)
	port := []byte{0, 0}
	if _, err := io.ReadFull(r, port); err != nil {
		return nil, err
	}
	d.Port = (int(port[0]) << 8) | int(port[1])

	return d, nil
}

// sendReply is used to send a reply message.
//
// The message is: version, resp, a reserved zero byte, the address type, the
// address body and a two byte port. A nil addr is encoded as IPv4 0.0.0.0
// with port 0. An error is returned when addr cannot be encoded or when the
// write fails.
func sendReply(w io.Writer, resp uint8, addr *AddrSpec) error {
	// Format the address
	var addrType uint8
	var addrBody []byte
	var addrPort uint16
	switch {
	case addr == nil:
		addrType = ipv4Address
		addrBody = []byte{0, 0, 0, 0}
		addrPort = 0

	case addr.FQDN != "":
		// The name length is a single byte on the wire; a longer name would
		// be silently truncated into a malformed message.
		if len(addr.FQDN) > 255 {
			return fmt.Errorf("Failed to format address: %v", addr)
		}
		addrType = fqdnAddress
		addrBody = append([]byte{byte(len(addr.FQDN))}, addr.FQDN...)
		addrPort = uint16(addr.Port)

	case addr.IP.To4() != nil:
		addrType = ipv4Address
		addrBody = []byte(addr.IP.To4())
		addrPort = uint16(addr.Port)

	case addr.IP.To16() != nil:
		addrType = ipv6Address
		addrBody = []byte(addr.IP.To16())
		addrPort = uint16(addr.Port)

	default:
		return fmt.Errorf("Failed to format address: %v", addr)
	}

	// Format the message
	msg := make([]byte, 6+len(addrBody))
	msg[0] = socks5Version
	msg[1] = resp
	msg[2] = 0 // Reserved
	msg[3] = addrType
	copy(msg[4:], addrBody)
	msg[4+len(addrBody)] = byte(addrPort >> 8)
	msg[4+len(addrBody)+1] = byte(addrPort & 0xff)

	// Send the message
	_, err := w.Write(msg)
	return err
}

// closeWriter is implemented by connections that can shut down only their
// write side.
type closeWriter interface {
	CloseWrite() error
}

// proxy is used to shuffle data from src to destination, and sends errors
// down a dedicated channel. Exactly one value (possibly nil) is sent on
// errCh when copying has finished.
func proxy(dst io.Writer, src io.Reader, errCh chan error) {
	_, err := copyAndModify(dst, src)
	if tcpConn, ok := dst.(closeWriter); ok {
		// Best effort: the copy result is what gets reported to the caller.
		_ = tcpConn.CloseWrite()
	}
	errCh <- err
}

// copyAndModify copies from src to dst until EOF or an error, like io.Copy,
// and returns the number of bytes written and the first error encountered
// (EOF is not reported as an error).
//
// Before anything has been written, a chunk that starts like a TLS
// ClientHello record is handed to modifyHello for inspection.
func copyAndModify(dst io.Writer, src io.Reader) (written int64, err error) {
	// based on copyBuffer in io.go
	const (
		recordTypeHandshake  = 0x16 // first byte of the chunk
		handshakeClientHello = 0x01 // sixth byte of the chunk
	)

	buf := make([]byte, 32*1024)
	for {
		nr, er := src.Read(buf)
		// Only look at bytes that were actually read in this call; anything
		// past nr is stale or zero.
		if written == 0 && nr > 5 &&
			buf[0] == recordTypeHandshake && buf[5] == handshakeClientHello {
			modifyHello(buf[:nr])
		}
		if nr > 0 {
			nw, ew := dst.Write(buf[0:nr])
			if nw > 0 {
				written += int64(nw)
			}
			if ew != nil {
				err = ew
				break
			}
			if nr != nw {
				err = io.ErrShortWrite
				break
			}
		}
		if er != nil {
			if er != io.EOF {
				err = er
			}
			break
		}
	}
	return written, err
}

// modifyHello inspects buf, which is expected to start with a TLS ClientHello
// record, and logs the host names found in its server name extension.
// Nothing is logged when no server name extension is found. Despite the
// name, buf is currently not modified.
func modifyHello(buf []byte) {
	if names, ok := clientHelloSNI(buf); ok {
		log.Println(names)
	}
}

// clientHelloSNI walks the ClientHello in buf and returns the entries of the
// first server name extension (extension type 0). found reports whether such
// an extension was present. All length fields come from the peer, so every
// offset is checked against len(buf); truncated input ends the walk early
// instead of panicking.
func clientHelloSNI(buf []byte) (names []string, found bool) {
	// Layout before the session id:
	//   1 byte content type, 2 byte tls version, 2 byte tls frame length,
	//   1 byte handshake type, 3 byte handshake length,
	//   2 byte version, 32 byte random, 1 byte session id length
	const sessionLenOffset = 43
	if len(buf) <= sessionLenOffset {
		return nil, false
	}
	// session id
	pos := sessionLenOffset + 1 + int(buf[sessionLenOffset])

	// 2 byte cipher suites length (in bytes), then the cipher suites
	if pos+2 > len(buf) {
		return nil, false
	}
	pos += 2 + beUint(buf[pos:pos+2])

	// 1 byte compression methods length, then that many methods
	if pos+1 > len(buf) {
		return nil, false
	}
	pos += 1 + int(buf[pos])

	// 2 byte extensions length, then the extensions
	if pos+2 > len(buf) {
		return nil, false
	}
	extEnd := pos + 2 + beUint(buf[pos:pos+2])
	pos += 2
	if extEnd > len(buf) {
		extEnd = len(buf)
	}

	for pos+4 < extEnd {
		// 2 byte extension type, 2 byte extension length, extension data
		extType := beUint(buf[pos : pos+2])
		extLen := beUint(buf[pos+2 : pos+4])
		pos += 4
		if extType != 0 {
			pos += extLen
			continue
		}

		// Server name extension: 2 byte list length, then the entries.
		if pos+2 > len(buf) {
			return names, true
		}
		listEnd := pos + 2 + beUint(buf[pos:pos+2])
		pos += 2
		if listEnd > len(buf) {
			listEnd = len(buf)
		}
		for pos+3 < listEnd {
			// 1 byte type, 2 byte length, name
			nameEnd := pos + 3 + beUint(buf[pos+1:pos+3])
			if nameEnd > len(buf) {
				break
			}
			// NOTE(review): each entry keeps the trailing NUL byte the
			// existing code appended; confirm whether that is intended.
			var name bytes.Buffer
			name.Write(buf[pos+3 : nameEnd])
			name.WriteByte(0)
			names = append(names, name.String())
			pos = nameEnd
		}
		return names, true
	}
	return nil, false
}

// beUint decodes b as a big-endian unsigned integer. It is only used for the
// one to three byte length fields of a TLS handshake message.
func beUint(b []byte) int {
	return int(new(big.Int).SetBytes(b).Uint64())
}