From f6449feb5549ecc834cc062f7f178c94621c5f80 Mon Sep 17 00:00:00 2001 From: Ambrose Chua Date: Mon, 26 Mar 2018 19:48:14 +0800 Subject: [PATCH] Support Windows service with go-svc --- go.mod | 6 ++++ main.go | 100 ++++++++++++++++++++++++++---------------------------- server.go | 76 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 131 insertions(+), 51 deletions(-) create mode 100644 go.mod create mode 100644 server.go diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..3d86675 --- /dev/null +++ b/go.mod @@ -0,0 +1,6 @@ +module "github.com/productionwentdown/forward" + +require ( + "github.com/judwhite/go-svc" v1.0.0 + "golang.org/x/sys" v0.0.0-20180322165403-91ee8cde4354 +) diff --git a/main.go b/main.go index 3d8e2cc..260e5eb 100644 --- a/main.go +++ b/main.go @@ -1,68 +1,66 @@ package main import ( - "flag" - "io" "log" - "net" + "os" + "path/filepath" + "sync" + + "github.com/judwhite/go-svc/svc" ) -var listen string -var connect string +type program struct { + logFile *os.File + wg sync.WaitGroup +} func main() { - flag.StringVar(&listen, "listen", ":8000", "listen on ip and port") - flag.StringVar(&connect, "connect", "", "forward to ip and port") - flag.Parse() + prg := &program{} - // check and parse address - conn, err := net.ResolveTCPAddr("tcp", connect) - if err != nil { - flag.PrintDefaults() + if err := svc.Run(prg); err != nil { log.Fatal(err) } - - // listen on address - ln, err := net.Listen("tcp", listen) - if err != nil { - flag.PrintDefaults() - log.Fatal(err) - } - - log.Printf("listening on %v", ln.Addr()) - log.Printf("will connect to %v", conn) - - for i := 0; ; i++ { - // accept new connection - c, err := ln.Accept() - if err != nil { - log.Fatal(err) - } - - log.Printf("connection %v from %v", i, c.RemoteAddr()) - - cn, err := net.DialTCP("tcp", nil, conn) - if err != nil { - c.Close() - log.Print(err) - continue - } - - go pipe(c, cn, i) - go pipe(cn, c, i) - } } -func pipe(w io.WriteCloser, r io.ReadCloser, count int) { - n, err := io.Copy(w, r) +func (p *program) Init(env svc.Environment) error { + if env.IsWindowsService() { + dir, err := filepath.Abs(filepath.Dir(os.Args[0])) + if err != nil { + return err + } + logPath := filepath.Join(dir, "forward.log") - r.Close() - w.Close() + f, err := os.OpenFile(logPath, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0644) + if err != nil { + return err + } - log.Printf("connection %v closed, %v bytes", count, n) - - opError, ok := err.(*net.OpError) - if err != nil && (!ok || opError.Op != "readfrom") { - log.Printf("warning! %v", err) + p.logFile = f + log.SetOutput(f) } + + setup() + + return nil +} + +func (p *program) Start() error { + p.wg.Add(1) + go func() { + serve() + p.wg.Done() + }() + + log.Print("started") + return nil +} + +func (p *program) Stop() error { + log.Print("stopping...") + + ln.Close() + p.wg.Wait() + + log.Print("stopped") + return nil } diff --git a/server.go b/server.go new file mode 100644 index 0000000..e68e190 --- /dev/null +++ b/server.go @@ -0,0 +1,76 @@ +package main + +import ( + "flag" + "io" + "log" + "net" +) + +var listen string +var connect string + +var ln net.Listener +var conn *net.TCPAddr + +func setup() { + flag.StringVar(&listen, "listen", ":8000", "listen on ip and port") + flag.StringVar(&connect, "connect", "", "forward to ip and port") + flag.Parse() + + var err error + + // check and parse address + conn, err = net.ResolveTCPAddr("tcp", connect) + if err != nil { + flag.PrintDefaults() + log.Fatal(err) + } + + // listen on address + ln, err = net.Listen("tcp", listen) + if err != nil { + flag.PrintDefaults() + log.Fatal(err) + } + + log.Printf("listening on %v", ln.Addr()) + log.Printf("will connect to %v", conn) +} + +func serve() { + for i := 0; ; i++ { + // accept new connection + c, err := ln.Accept() + if err != nil { + log.Print(err) + break + } + + log.Printf("connection %v from %v", i, c.RemoteAddr()) + + cn, err := net.DialTCP("tcp", nil, conn) + if err != nil { + c.Close() + log.Print(err) + continue + } + + go pipe(c, cn, i) + go pipe(cn, c, i) + } +} + +func pipe(w io.WriteCloser, r io.ReadCloser, count int) { + n, err := io.Copy(w, r) + + r.Close() + w.Close() + + log.Printf("connection %v closed, %v bytes", count, n) + + opError, ok := err.(*net.OpError) + if err != nil && (!ok || opError.Op != "readfrom") { + log.Printf("warning! %v", err) + } +}