1
0
Fork 0
forward/server.go

135 lines
2.5 KiB
Go
Raw Permalink Normal View History

2018-03-26 19:48:14 +08:00
package main
import (
2019-09-04 15:09:05 +08:00
"bytes"
2018-03-26 19:48:14 +08:00
"flag"
"io"
"log"
"net"
)
// Command-line configuration, populated by setup from flags.
var (
	listen     string // address to listen on (-listen)
	connect    string // default forward target (-connect)
	connectSSH string // optional SSH forward target (-ssh); empty disables SSH routing
)

// Runtime state derived from the configuration in setup.
var (
	ln      net.Listener // accepting socket
	conn    *net.TCPAddr // resolved -connect address
	connSSH *net.TCPAddr // resolved -ssh address, or nil when SSH routing is off
)
2018-03-26 19:48:14 +08:00
func setup() {
2019-09-04 15:09:05 +08:00
flag.StringVar(&listen, "listen", ":8000", "listen on address")
flag.StringVar(&connect, "connect", "", "forward to address")
flag.StringVar(&connectSSH, "ssh", "", "if set, will do basic introspection to forward SSH traffic to this address")
2018-03-26 19:48:14 +08:00
flag.Parse()
var err error
// check and parse address
conn, err = net.ResolveTCPAddr("tcp", connect)
if err != nil {
flag.PrintDefaults()
log.Fatal(err)
}
2019-09-04 15:09:05 +08:00
// check and parse SSH address
connSSH, _ = net.ResolveTCPAddr("tcp", connectSSH)
if connectSSH == "" {
connSSH = nil
}
2018-03-26 19:48:14 +08:00
// listen on address
ln, err = net.Listen("tcp", listen)
if err != nil {
flag.PrintDefaults()
log.Fatal(err)
}
log.Printf("listening on %v", ln.Addr())
log.Printf("will connect to %v", conn)
2019-09-04 15:09:05 +08:00
if connSSH != nil {
log.Printf("will connect SSH to %v", connSSH)
}
2018-03-26 19:48:14 +08:00
}
func serve() {
for i := 0; ; i++ {
// accept new connection
c, err := ln.Accept()
if err != nil {
log.Print(err)
break
}
log.Printf("connection %v from %v", i, c.RemoteAddr())
2019-09-04 15:09:05 +08:00
go handle(c, i)
}
}
2018-03-26 19:48:14 +08:00
2019-09-04 15:09:05 +08:00
// magic is the prefix of the SSH identification string (RFC 4253, section
// 4.2). handle compares a connection's first magicLen bytes against it to
// decide whether to route the connection to the SSH backend.
var (
	magic    = []byte("SSH-")
	magicLen = len(magic)
)
func handle(c net.Conn, count int) {
2019-09-04 15:16:02 +08:00
if connSSH != nil {
2018-03-26 19:48:14 +08:00
2019-09-04 15:16:02 +08:00
// read first four characters
readMagic := make([]byte, magicLen, magicLen)
n, err := c.Read(readMagic)
if n != magicLen {
log.Printf("warning! could not read header")
return
}
opError, ok := err.(*net.OpError)
if err != nil && (!ok || opError.Op != "readfrom") {
log.Printf("warning! %v", err)
return
}
2019-09-04 15:09:05 +08:00
2019-09-04 15:16:02 +08:00
connTo := conn
// if the header looks like SSH, forward to SSH connection
if bytes.Equal(readMagic, magic) {
connTo = connSSH
}
2019-09-04 15:09:05 +08:00
2019-09-04 15:16:02 +08:00
cn, err := net.DialTCP("tcp", nil, connTo)
if err != nil {
c.Close()
log.Print(err)
return
}
// write the first four characters
cn.Write(readMagic)
go pipe(c, cn, count)
go pipe(cn, c, count)
} else {
2019-09-04 15:09:05 +08:00
2019-09-04 15:16:02 +08:00
cn, err := net.DialTCP("tcp", nil, conn)
if err != nil {
c.Close()
log.Print(err)
return
}
go pipe(c, cn, count)
go pipe(cn, c, count)
}
2018-03-26 19:48:14 +08:00
}
// pipe copies everything read from r into w until r reaches EOF or the copy
// fails. It then closes both r and w and logs the number of bytes copied.
// When pipe is run once per direction on the same pair of connections (as in
// handle), closing both ends also stops the opposite-direction copy.
//
// A copy error is logged as a warning unless it is a *net.OpError whose Op is
// "readfrom".
// NOTE(review): that exception appears intended to hide the error produced
// when the other direction closes the connection first — confirm.
func pipe(w io.WriteCloser, r io.ReadCloser, count int) {
	copied, copyErr := io.Copy(w, r)
	r.Close()
	w.Close()
	log.Printf("connection %v closed, %v bytes", count, copied)

	if copyErr == nil {
		return
	}
	if opErr, isOp := copyErr.(*net.OpError); isOp && opErr.Op == "readfrom" {
		return
	}
	log.Printf("warning! %v", copyErr)
}