You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

118 lines
2.2 KiB

  1. package main
  2. import (
  3. "bytes"
  4. "flag"
  5. "io"
  6. "log"
  7. "net"
  8. )
  9. var listen string
  10. var connect string
  11. var connectSSH string
  12. var ln net.Listener
  13. var conn *net.TCPAddr
  14. var connSSH *net.TCPAddr
  15. func setup() {
  16. flag.StringVar(&listen, "listen", ":8000", "listen on address")
  17. flag.StringVar(&connect, "connect", "", "forward to address")
  18. flag.StringVar(&connectSSH, "ssh", "", "if set, will do basic introspection to forward SSH traffic to this address")
  19. flag.Parse()
  20. var err error
  21. // check and parse address
  22. conn, err = net.ResolveTCPAddr("tcp", connect)
  23. if err != nil {
  24. flag.PrintDefaults()
  25. log.Fatal(err)
  26. }
  27. // check and parse SSH address
  28. connSSH, _ = net.ResolveTCPAddr("tcp", connectSSH)
  29. if connectSSH == "" {
  30. connSSH = nil
  31. }
  32. // listen on address
  33. ln, err = net.Listen("tcp", listen)
  34. if err != nil {
  35. flag.PrintDefaults()
  36. log.Fatal(err)
  37. }
  38. log.Printf("listening on %v", ln.Addr())
  39. log.Printf("will connect to %v", conn)
  40. if connSSH != nil {
  41. log.Printf("will connect SSH to %v", connSSH)
  42. }
  43. }
  44. func serve() {
  45. for i := 0; ; i++ {
  46. // accept new connection
  47. c, err := ln.Accept()
  48. if err != nil {
  49. log.Print(err)
  50. break
  51. }
  52. log.Printf("connection %v from %v", i, c.RemoteAddr())
  53. go handle(c, i)
  54. }
  55. }
  56. var magic = []byte{'S', 'S', 'H', '-'}
  57. var magicLen = len(magic)
  58. func handle(c net.Conn, count int) {
  59. // read first four characters
  60. readMagic := make([]byte, magicLen, magicLen)
  61. n, err := c.Read(readMagic)
  62. if n != magicLen {
  63. log.Printf("warning! could not read header")
  64. return
  65. }
  66. opError, ok := err.(*net.OpError)
  67. if err != nil && (!ok || opError.Op != "readfrom") {
  68. log.Printf("warning! %v", err)
  69. return
  70. }
  71. connTo := conn
  72. // if the header looks like SSH, forward to SSH connection
  73. if bytes.Equal(readMagic, magic) {
  74. connTo = connSSH
  75. }
  76. cn, err := net.DialTCP("tcp", nil, connTo)
  77. if err != nil {
  78. c.Close()
  79. log.Print(err)
  80. return
  81. }
  82. // write the first four characters
  83. cn.Write(readMagic)
  84. go pipe(c, cn, count)
  85. go pipe(cn, c, count)
  86. }
  87. func pipe(w io.WriteCloser, r io.ReadCloser, count int) {
  88. n, err := io.Copy(w, r)
  89. r.Close()
  90. w.Close()
  91. log.Printf("connection %v closed, %v bytes", count, n)
  92. opError, ok := err.(*net.OpError)
  93. if err != nil && (!ok || opError.Op != "readfrom") {
  94. log.Printf("warning! %v", err)
  95. }
  96. }