Add SSH connection option
parent
bb031be2f1
commit
089bc9c6e4
6
go.mod
6
go.mod
|
@ -1,6 +1,6 @@
|
|||
module "github.com/productionwentdown/forward"
|
||||
module github.com/productionwentdown/forward
|
||||
|
||||
require (
|
||||
"github.com/judwhite/go-svc" v1.0.0
|
||||
"golang.org/x/sys" v0.0.0-20180322165403-91ee8cde4354
|
||||
github.com/judwhite/go-svc v1.0.0
|
||||
golang.org/x/sys v0.0.0-20180322165403-91ee8cde4354 // indirect
|
||||
)
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
github.com/judwhite/go-svc v1.0.0/go.mod h1:EeMSAFO3mLgEQfcvnZ50JDG0O1uQlagpAbMS6talrXE=
|
||||
golang.org/x/sys v0.0.0-20180322165403-91ee8cde4354/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
66
server.go
66
server.go
|
@ -1,6 +1,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"flag"
|
||||
"io"
|
||||
"log"
|
||||
|
@ -9,13 +10,16 @@ import (
|
|||
|
||||
var listen string
|
||||
var connect string
|
||||
var connectSSH string
|
||||
|
||||
var ln net.Listener
|
||||
var conn *net.TCPAddr
|
||||
var connSSH *net.TCPAddr
|
||||
|
||||
func setup() {
|
||||
flag.StringVar(&listen, "listen", ":8000", "listen on ip and port")
|
||||
flag.StringVar(&connect, "connect", "", "forward to ip and port")
|
||||
flag.StringVar(&listen, "listen", ":8000", "listen on address")
|
||||
flag.StringVar(&connect, "connect", "", "forward to address")
|
||||
flag.StringVar(&connectSSH, "ssh", "", "if set, will do basic introspection to forward SSH traffic to this address")
|
||||
flag.Parse()
|
||||
|
||||
var err error
|
||||
|
@ -27,6 +31,12 @@ func setup() {
|
|||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// check and parse SSH address
|
||||
connSSH, _ = net.ResolveTCPAddr("tcp", connectSSH)
|
||||
if connectSSH == "" {
|
||||
connSSH = nil
|
||||
}
|
||||
|
||||
// listen on address
|
||||
ln, err = net.Listen("tcp", listen)
|
||||
if err != nil {
|
||||
|
@ -36,6 +46,9 @@ func setup() {
|
|||
|
||||
log.Printf("listening on %v", ln.Addr())
|
||||
log.Printf("will connect to %v", conn)
|
||||
if connSSH != nil {
|
||||
log.Printf("will connect SSH to %v", connSSH)
|
||||
}
|
||||
}
|
||||
|
||||
func serve() {
|
||||
|
@ -48,19 +61,48 @@ func serve() {
|
|||
}
|
||||
|
||||
log.Printf("connection %v from %v", i, c.RemoteAddr())
|
||||
|
||||
cn, err := net.DialTCP("tcp", nil, conn)
|
||||
if err != nil {
|
||||
c.Close()
|
||||
log.Print(err)
|
||||
continue
|
||||
}
|
||||
|
||||
go pipe(c, cn, i)
|
||||
go pipe(cn, c, i)
|
||||
go handle(c, i)
|
||||
}
|
||||
}
|
||||
|
||||
var magic = []byte{'S', 'S', 'H', '-'}
|
||||
|
||||
var magicLen = len(magic)
|
||||
|
||||
func handle(c net.Conn, count int) {
|
||||
// read first four characters
|
||||
readMagic := make([]byte, magicLen, magicLen)
|
||||
n, err := c.Read(readMagic)
|
||||
if n != magicLen {
|
||||
log.Printf("warning! could not read header")
|
||||
return
|
||||
}
|
||||
opError, ok := err.(*net.OpError)
|
||||
if err != nil && (!ok || opError.Op != "readfrom") {
|
||||
log.Printf("warning! %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
connTo := conn
|
||||
// if the header looks like SSH, forward to SSH connection
|
||||
if bytes.Equal(readMagic, magic) {
|
||||
connTo = connSSH
|
||||
}
|
||||
|
||||
cn, err := net.DialTCP("tcp", nil, connTo)
|
||||
if err != nil {
|
||||
c.Close()
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
// write the first four characters
|
||||
cn.Write(readMagic)
|
||||
|
||||
go pipe(c, cn, count)
|
||||
go pipe(cn, c, count)
|
||||
}
|
||||
|
||||
func pipe(w io.WriteCloser, r io.ReadCloser, count int) {
|
||||
n, err := io.Copy(w, r)
|
||||
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2016 Jud White
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
|
@ -1,62 +0,0 @@
|
|||
package test
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Equal asserts two parameters are equal by using reflect.DeepEqual.
|
||||
func Equal(t *testing.T, expected, actual interface{}) {
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
_, file, line, _ := runtime.Caller(1)
|
||||
t.Logf("\033[31m%s:%d:\n\n\t %#v (expected)\n\n\t!= %#v (actual)\033[39m\n\n",
|
||||
filepath.Base(file), line, expected, actual)
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
// NotEqual asserts two parameters are not equal by using reflect.DeepEqual.
|
||||
func NotEqual(t *testing.T, expected, actual interface{}) {
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
_, file, line, _ := runtime.Caller(1)
|
||||
t.Logf("\033[31m%s:%d:\n\n\tvalue should not equal %#v\033[39m\n\n",
|
||||
filepath.Base(file), line, actual)
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
// Nil asserts the parameter is nil.
|
||||
func Nil(t *testing.T, object interface{}) {
|
||||
if !isNil(object) {
|
||||
_, file, line, _ := runtime.Caller(1)
|
||||
t.Logf("\033[31m%s:%d:\n\n\t <nil> (expected)\n\n\t!= %#v (actual)\033[39m\n\n",
|
||||
filepath.Base(file), line, object)
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
// NotNil asserts the parameter is not nil.
|
||||
func NotNil(t *testing.T, object interface{}) {
|
||||
if isNil(object) {
|
||||
_, file, line, _ := runtime.Caller(1)
|
||||
t.Logf("\033[31m%s:%d:\n\n\tExpected value not to be <nil>\033[39m\n\n",
|
||||
filepath.Base(file), line)
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
func isNil(object interface{}) bool {
|
||||
if object == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
value := reflect.ValueOf(object)
|
||||
kind := value.Kind()
|
||||
if kind >= reflect.Chan && kind <= reflect.Slice && value.IsNil() {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
|
@ -1,36 +0,0 @@
|
|||
package svc
|
||||
|
||||
type mockProgram struct {
|
||||
start func() error
|
||||
stop func() error
|
||||
init func(Environment) error
|
||||
}
|
||||
|
||||
func (p *mockProgram) Start() error {
|
||||
return p.start()
|
||||
}
|
||||
|
||||
func (p *mockProgram) Stop() error {
|
||||
return p.stop()
|
||||
}
|
||||
|
||||
func (p *mockProgram) Init(wse Environment) error {
|
||||
return p.init(wse)
|
||||
}
|
||||
|
||||
func makeProgram(startCalled, stopCalled, initCalled *int) *mockProgram {
|
||||
return &mockProgram{
|
||||
start: func() error {
|
||||
*startCalled++
|
||||
return nil
|
||||
},
|
||||
stop: func() error {
|
||||
*stopCalled++
|
||||
return nil
|
||||
},
|
||||
init: func(wse Environment) error {
|
||||
*initCalled++
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
|
@ -1,68 +0,0 @@
|
|||
// +build !windows
|
||||
|
||||
package svc
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
"github.com/judwhite/go-svc/svc/internal/test"
|
||||
)
|
||||
|
||||
func TestDefaultSignalHandling(t *testing.T) {
|
||||
signals := []os.Signal{syscall.SIGINT, syscall.SIGTERM} // default signals handled
|
||||
for _, signal := range signals {
|
||||
testSignalNotify(t, signal)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserDefinedSignalHandling(t *testing.T) {
|
||||
signals := []os.Signal{syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP}
|
||||
for _, signal := range signals {
|
||||
testSignalNotify(t, signal, signals...)
|
||||
}
|
||||
}
|
||||
|
||||
func testSignalNotify(t *testing.T, signal os.Signal, sig ...os.Signal) {
|
||||
// arrange
|
||||
|
||||
// sigChan is the chan we'll send to here. if a signal matches a registered signal
|
||||
// type in the Run function (in svc_other.go) the signal will be delegated to the
|
||||
// channel passed to signalNotify, which is created in the Run function in svc_other.go.
|
||||
// shortly: we send here and the Run function gets it if it matches the filter.
|
||||
sigChan := make(chan os.Signal)
|
||||
|
||||
var startCalled, stopCalled, initCalled int
|
||||
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
|
||||
|
||||
signalNotify = func(c chan<- os.Signal, sig ...os.Signal) {
|
||||
if c == nil {
|
||||
panic("os/signal: Notify using nil channel")
|
||||
}
|
||||
|
||||
go func() {
|
||||
for val := range sigChan {
|
||||
for _, registeredSig := range sig {
|
||||
if val == registeredSig {
|
||||
c <- val
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
go func() {
|
||||
sigChan <- signal
|
||||
}()
|
||||
|
||||
// act
|
||||
if err := Run(prg, sig...); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// assert
|
||||
test.Equal(t, 1, startCalled)
|
||||
test.Equal(t, 1, stopCalled)
|
||||
test.Equal(t, 1, initCalled)
|
||||
}
|
|
@ -1,438 +0,0 @@
|
|||
// +build windows
|
||||
|
||||
package svc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/judwhite/go-svc/svc/internal/test"
|
||||
wsvc "golang.org/x/sys/windows/svc"
|
||||
)
|
||||
|
||||
func setupWinServiceTest(wsf *mockWinServiceFuncs) {
|
||||
// wsfWrapper allows signalNotify, svcIsInteractive, and svcRun to be set once.
|
||||
// Inidivual test functions set "wsf" to add behavior.
|
||||
wsfWrapper := &mockWinServiceFuncs{
|
||||
signalNotify: func(c chan<- os.Signal, sig ...os.Signal) {
|
||||
if c == nil {
|
||||
panic("os/signal: Notify using nil channel")
|
||||
}
|
||||
|
||||
if wsf.signalNotify != nil {
|
||||
wsf.signalNotify(c, sig...)
|
||||
} else {
|
||||
wsf1 := *wsf
|
||||
go func() {
|
||||
for val := range wsf1.sigChan {
|
||||
for _, registeredSig := range sig {
|
||||
if val == registeredSig {
|
||||
c <- val
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
},
|
||||
svcIsInteractive: func() (bool, error) {
|
||||
return wsf.svcIsInteractive()
|
||||
},
|
||||
svcRun: func(name string, handler wsvc.Handler) error {
|
||||
return wsf.svcRun(name, handler)
|
||||
},
|
||||
}
|
||||
|
||||
signalNotify = wsfWrapper.signalNotify
|
||||
svcIsAnInteractiveSession = wsfWrapper.svcIsInteractive
|
||||
svcRun = wsfWrapper.svcRun
|
||||
}
|
||||
|
||||
type mockWinServiceFuncs struct {
|
||||
signalNotify func(chan<- os.Signal, ...os.Signal)
|
||||
svcIsInteractive func() (bool, error)
|
||||
sigChan chan os.Signal
|
||||
svcRun func(string, wsvc.Handler) error
|
||||
ws *windowsService
|
||||
executeReturnedBool bool
|
||||
executeReturnedUInt32 uint32
|
||||
changes []wsvc.Status
|
||||
}
|
||||
|
||||
func setWindowsServiceFuncs(isInteractive bool, onRunningSendCmd *wsvc.Cmd) (*mockWinServiceFuncs, chan<- wsvc.ChangeRequest) {
|
||||
changeRequestChan := make(chan wsvc.ChangeRequest, 4)
|
||||
changesChan := make(chan wsvc.Status)
|
||||
done := make(chan struct{})
|
||||
|
||||
var wsf *mockWinServiceFuncs
|
||||
wsf = &mockWinServiceFuncs{
|
||||
sigChan: make(chan os.Signal),
|
||||
svcIsInteractive: func() (bool, error) {
|
||||
return isInteractive, nil
|
||||
},
|
||||
svcRun: func(name string, handler wsvc.Handler) error {
|
||||
wsf.ws = handler.(*windowsService)
|
||||
wsf.executeReturnedBool, wsf.executeReturnedUInt32 = handler.Execute(nil, changeRequestChan, changesChan)
|
||||
done <- struct{}{}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
var currentState wsvc.State
|
||||
|
||||
go func() {
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case change := <-changesChan:
|
||||
wsf.changes = append(wsf.changes, change)
|
||||
currentState = change.State
|
||||
|
||||
if change.State == wsvc.Running && onRunningSendCmd != nil {
|
||||
changeRequestChan <- wsvc.ChangeRequest{
|
||||
Cmd: *onRunningSendCmd,
|
||||
CurrentStatus: wsvc.Status{State: currentState},
|
||||
}
|
||||
}
|
||||
case <-done:
|
||||
break loop
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
setupWinServiceTest(wsf)
|
||||
|
||||
return wsf, changeRequestChan
|
||||
}
|
||||
|
||||
func TestWinService_RunWindowsService_NonInteractive(t *testing.T) {
|
||||
for _, svcCmd := range []wsvc.Cmd{wsvc.Stop, wsvc.Shutdown} {
|
||||
testRunWindowsServiceNonInteractive(t, svcCmd)
|
||||
}
|
||||
}
|
||||
|
||||
func testRunWindowsServiceNonInteractive(t *testing.T, svcCmd wsvc.Cmd) {
|
||||
// arrange
|
||||
var startCalled, stopCalled, initCalled int
|
||||
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
|
||||
|
||||
wsf, _ := setWindowsServiceFuncs(false, &svcCmd)
|
||||
|
||||
// act
|
||||
if err := Run(prg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// assert
|
||||
changes := wsf.changes
|
||||
|
||||
test.Equal(t, 1, startCalled)
|
||||
test.Equal(t, 1, stopCalled)
|
||||
test.Equal(t, 1, initCalled)
|
||||
|
||||
test.Equal(t, 3, len(changes))
|
||||
test.Equal(t, wsvc.StartPending, changes[0].State)
|
||||
test.Equal(t, wsvc.Running, changes[1].State)
|
||||
test.Equal(t, wsvc.StopPending, changes[2].State)
|
||||
|
||||
test.Equal(t, false, wsf.executeReturnedBool)
|
||||
test.Equal(t, uint32(0), wsf.executeReturnedUInt32)
|
||||
|
||||
test.Nil(t, wsf.ws.getError())
|
||||
}
|
||||
|
||||
func TestRunWindowsServiceNonInteractive_StartError(t *testing.T) {
|
||||
// arrange
|
||||
var startCalled, stopCalled, initCalled int
|
||||
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
|
||||
prg.start = func() error {
|
||||
startCalled++
|
||||
return errors.New("start error")
|
||||
}
|
||||
|
||||
svcStop := wsvc.Stop
|
||||
wsf, _ := setWindowsServiceFuncs(false, &svcStop)
|
||||
|
||||
// act
|
||||
err := Run(prg)
|
||||
|
||||
// assert
|
||||
test.Equal(t, "start error", err.Error())
|
||||
|
||||
changes := wsf.changes
|
||||
|
||||
test.Equal(t, 1, startCalled)
|
||||
test.Equal(t, 0, stopCalled)
|
||||
test.Equal(t, 1, initCalled)
|
||||
|
||||
test.Equal(t, 1, len(changes))
|
||||
test.Equal(t, wsvc.StartPending, changes[0].State)
|
||||
|
||||
test.Equal(t, true, wsf.executeReturnedBool)
|
||||
test.Equal(t, uint32(1), wsf.executeReturnedUInt32)
|
||||
|
||||
test.Equal(t, "start error", wsf.ws.getError().Error())
|
||||
}
|
||||
|
||||
func TestRunWindowsServiceInteractive_StartError(t *testing.T) {
|
||||
// arrange
|
||||
var startCalled, stopCalled, initCalled int
|
||||
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
|
||||
prg.start = func() error {
|
||||
startCalled++
|
||||
return errors.New("start error")
|
||||
}
|
||||
|
||||
wsf, _ := setWindowsServiceFuncs(true, nil)
|
||||
|
||||
// act
|
||||
err := Run(prg)
|
||||
|
||||
// assert
|
||||
test.Equal(t, "start error", err.Error())
|
||||
|
||||
changes := wsf.changes
|
||||
|
||||
test.Equal(t, 1, startCalled)
|
||||
test.Equal(t, 0, stopCalled)
|
||||
test.Equal(t, 1, initCalled)
|
||||
|
||||
test.Equal(t, 0, len(changes))
|
||||
}
|
||||
|
||||
func TestRunWindowsService_BeforeStartError(t *testing.T) {
|
||||
// arrange
|
||||
var startCalled, stopCalled, initCalled int
|
||||
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
|
||||
prg.init = func(Environment) error {
|
||||
initCalled++
|
||||
return errors.New("before start error")
|
||||
}
|
||||
|
||||
wsf, _ := setWindowsServiceFuncs(false, nil)
|
||||
|
||||
// act
|
||||
err := Run(prg)
|
||||
|
||||
// assert
|
||||
test.Equal(t, "before start error", err.Error())
|
||||
|
||||
changes := wsf.changes
|
||||
|
||||
test.Equal(t, 0, startCalled)
|
||||
test.Equal(t, 0, stopCalled)
|
||||
test.Equal(t, 1, initCalled)
|
||||
|
||||
test.Equal(t, 0, len(changes))
|
||||
}
|
||||
|
||||
func TestRunWindowsService_IsAnInteractiveSessionError(t *testing.T) {
|
||||
// arrange
|
||||
var startCalled, stopCalled, initCalled int
|
||||
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
|
||||
|
||||
wsf, _ := setWindowsServiceFuncs(false, nil)
|
||||
wsf.svcIsInteractive = func() (bool, error) {
|
||||
return false, errors.New("IsAnInteractiveSession error")
|
||||
}
|
||||
|
||||
// act
|
||||
err := Run(prg)
|
||||
|
||||
// assert
|
||||
test.Equal(t, "IsAnInteractiveSession error", err.Error())
|
||||
|
||||
changes := wsf.changes
|
||||
|
||||
test.Equal(t, 0, startCalled)
|
||||
test.Equal(t, 0, stopCalled)
|
||||
test.Equal(t, 0, initCalled)
|
||||
|
||||
test.Equal(t, 0, len(changes))
|
||||
}
|
||||
|
||||
func TestRunWindowsServiceNonInteractive_RunError(t *testing.T) {
|
||||
// arrange
|
||||
var startCalled, stopCalled, initCalled int
|
||||
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
|
||||
|
||||
svcStop := wsvc.Stop
|
||||
wsf, _ := setWindowsServiceFuncs(false, &svcStop)
|
||||
wsf.svcRun = func(name string, handler wsvc.Handler) error {
|
||||
wsf.ws = handler.(*windowsService)
|
||||
return errors.New("wsvc.Run error")
|
||||
}
|
||||
|
||||
// act
|
||||
err := Run(prg)
|
||||
|
||||
// assert
|
||||
test.Equal(t, "wsvc.Run error", err.Error())
|
||||
|
||||
changes := wsf.changes
|
||||
|
||||
test.Equal(t, 0, startCalled)
|
||||
test.Equal(t, 0, stopCalled)
|
||||
test.Equal(t, 1, initCalled)
|
||||
|
||||
test.Equal(t, 0, len(changes))
|
||||
|
||||
test.Nil(t, wsf.ws.getError())
|
||||
}
|
||||
|
||||
func TestRunWindowsServiceNonInteractive_Interrogate(t *testing.T) {
|
||||
// arrange
|
||||
var startCalled, stopCalled, initCalled int
|
||||
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
|
||||
|
||||
wsf, changeRequest := setWindowsServiceFuncs(false, nil)
|
||||
|
||||
time.AfterFunc(50*time.Millisecond, func() {
|
||||
// ignored, PausePending won't be in changes slice
|
||||
// make sure we don't panic/err on unexpected values
|
||||
changeRequest <- wsvc.ChangeRequest{
|
||||
Cmd: wsvc.Pause,
|
||||
CurrentStatus: wsvc.Status{State: wsvc.PausePending},
|
||||
}
|
||||
})
|
||||
|
||||
time.AfterFunc(100*time.Millisecond, func() {
|
||||
// handled, Paused will be in changes slice
|
||||
changeRequest <- wsvc.ChangeRequest{
|
||||
Cmd: wsvc.Interrogate,
|
||||
CurrentStatus: wsvc.Status{State: wsvc.Paused},
|
||||
}
|
||||
})
|
||||
|
||||
time.AfterFunc(200*time.Millisecond, func() {
|
||||
// handled, but CurrentStatus overridden with StopPending;
|
||||
// ContinuePending won't be in changes slice
|
||||
changeRequest <- wsvc.ChangeRequest{
|
||||
Cmd: wsvc.Stop,
|
||||
CurrentStatus: wsvc.Status{State: wsvc.ContinuePending},
|
||||
}
|
||||
})
|
||||
|
||||
// act
|
||||
if err := Run(prg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// assert
|
||||
changes := wsf.changes
|
||||
|
||||
test.Equal(t, 1, startCalled)
|
||||
test.Equal(t, 1, stopCalled)
|
||||
test.Equal(t, 1, initCalled)
|
||||
|
||||
test.Equal(t, 4, len(changes))
|
||||
test.Equal(t, wsvc.StartPending, changes[0].State)
|
||||
test.Equal(t, wsvc.Running, changes[1].State)
|
||||
test.Equal(t, wsvc.Paused, changes[2].State)
|
||||
test.Equal(t, wsvc.StopPending, changes[3].State)
|
||||
|
||||
test.Equal(t, false, wsf.executeReturnedBool)
|
||||
test.Equal(t, uint32(0), wsf.executeReturnedUInt32)
|
||||
|
||||
test.Nil(t, wsf.ws.getError())
|
||||
}
|
||||
|
||||
func TestRunWindowsServiceInteractive_StopError(t *testing.T) {
|
||||
// arrange
|
||||
var startCalled, stopCalled, initCalled int
|
||||
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
|
||||
prg.stop = func() error {
|
||||
stopCalled++
|
||||
return errors.New("stop error")
|
||||
}
|
||||
|
||||
wsf, _ := setWindowsServiceFuncs(true, nil)
|
||||
|
||||
go func() {
|
||||
wsf.sigChan <- os.Interrupt
|
||||
}()
|
||||
|
||||
// act
|
||||
err := Run(prg)
|
||||
|
||||
// assert
|
||||
test.Equal(t, "stop error", err.Error())
|
||||
test.Equal(t, 1, startCalled)
|
||||
test.Equal(t, 1, stopCalled)
|
||||
test.Equal(t, 1, initCalled)
|
||||
test.Equal(t, 0, len(wsf.changes))
|
||||
}
|
||||
|
||||
func TestRunWindowsServiceNonInteractive_StopError(t *testing.T) {
|
||||
// arrange
|
||||
var startCalled, stopCalled, initCalled int
|
||||
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
|
||||
prg.stop = func() error {
|
||||
stopCalled++
|
||||
return errors.New("stop error")
|
||||
}
|
||||
|
||||
shutdownCmd := wsvc.Shutdown
|
||||
wsf, _ := setWindowsServiceFuncs(false, &shutdownCmd)
|
||||
|
||||
// act
|
||||
err := Run(prg)
|
||||
|
||||
// assert
|
||||
changes := wsf.changes
|
||||
|
||||
test.Equal(t, "stop error", err.Error())
|
||||
|
||||
test.Equal(t, 1, startCalled)
|
||||
test.Equal(t, 1, stopCalled)
|
||||
test.Equal(t, 1, initCalled)
|
||||
|
||||
test.Equal(t, 3, len(changes))
|
||||
test.Equal(t, wsvc.StartPending, changes[0].State)
|
||||
test.Equal(t, wsvc.Running, changes[1].State)
|
||||
test.Equal(t, wsvc.StopPending, changes[2].State)
|
||||
|
||||
test.Equal(t, true, wsf.executeReturnedBool)
|
||||
test.Equal(t, uint32(2), wsf.executeReturnedUInt32)
|
||||
|
||||
test.Equal(t, "stop error", wsf.ws.getError().Error())
|
||||
}
|
||||
|
||||
func TestDefaultSignalHandling(t *testing.T) {
|
||||
signals := []os.Signal{syscall.SIGINT} // default signal handled
|
||||
for _, signal := range signals {
|
||||
testSignalNotify(t, signal)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserDefinedSignalHandling(t *testing.T) {
|
||||
signals := []os.Signal{syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP}
|
||||
for _, signal := range signals {
|
||||
testSignalNotify(t, signal, signals...)
|
||||
}
|
||||
}
|
||||
|
||||
func testSignalNotify(t *testing.T, signal os.Signal, sig ...os.Signal) {
|
||||
// arrange
|
||||
var startCalled, stopCalled, initCalled int
|
||||
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
|
||||
|
||||
wsf, _ := setWindowsServiceFuncs(true, nil)
|
||||
|
||||
go func() {
|
||||
wsf.sigChan <- signal
|
||||
}()
|
||||
|
||||
// act
|
||||
if err := Run(prg, sig...); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// assert
|
||||
test.Equal(t, 1, startCalled)
|
||||
test.Equal(t, 1, stopCalled)
|
||||
test.Equal(t, 1, initCalled)
|
||||
test.Equal(t, 0, len(wsf.changes))
|
||||
}
|
|
@ -0,0 +1,3 @@
|
|||
# This source code refers to The Go Authors for copyright purposes.
|
||||
# The master list of authors is in the main Go distribution,
|
||||
# visible at http://tip.golang.org/AUTHORS.
|
|
@ -0,0 +1,3 @@
|
|||
# This source code was written by the Go contributors.
|
||||
# The master list of contributors is in the main Go distribution,
|
||||
# visible at http://tip.golang.org/CONTRIBUTORS.
|
|
@ -0,0 +1,27 @@
|
|||
Copyright (c) 2009 The Go Authors. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of Google Inc. nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@ -0,0 +1,22 @@
|
|||
Additional IP Rights Grant (Patents)
|
||||
|
||||
"This implementation" means the copyrightable works distributed by
|
||||
Google as part of the Go project.
|
||||
|
||||
Google hereby grants to You a perpetual, worldwide, non-exclusive,
|
||||
no-charge, royalty-free, irrevocable (except as stated in this section)
|
||||
patent license to make, have made, use, offer to sell, sell, import,
|
||||
transfer and otherwise run, modify and propagate the contents of this
|
||||
implementation of Go, where such license applies only to those patent
|
||||
claims, both currently owned or controlled by Google and acquired in
|
||||
the future, licensable by Google that are necessarily infringed by this
|
||||
implementation of Go. This grant does not include claims that would be
|
||||
infringed only as a consequence of further modification of this
|
||||
implementation. If you or your agent or exclusive licensee institute or
|
||||
order or agree to the institution of patent litigation against any
|
||||
entity (including a cross-claim or counterclaim in a lawsuit) alleging
|
||||
that this implementation of Go or any code incorporated within this
|
||||
implementation of Go constitutes direct or contributory patent
|
||||
infringement, or inducement of patent infringement, then any patent
|
||||
rights granted to you under this License for this implementation of Go
|
||||
shall terminate as of the date such litigation is filed.
|
|
@ -1,139 +0,0 @@
|
|||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build windows
|
||||
|
||||
package mgr
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unicode/utf16"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
const (
|
||||
// Service start types.
|
||||
StartManual = windows.SERVICE_DEMAND_START // the service must be started manually
|
||||
StartAutomatic = windows.SERVICE_AUTO_START // the service will start by itself whenever the computer reboots
|
||||
StartDisabled = windows.SERVICE_DISABLED // the service cannot be started
|
||||
|
||||
// The severity of the error, and action taken,
|
||||
// if this service fails to start.
|
||||
ErrorCritical = windows.SERVICE_ERROR_CRITICAL
|
||||
ErrorIgnore = windows.SERVICE_ERROR_IGNORE
|
||||
ErrorNormal = windows.SERVICE_ERROR_NORMAL
|
||||
ErrorSevere = windows.SERVICE_ERROR_SEVERE
|
||||
)
|
||||
|
||||
// TODO(brainman): Password is not returned by windows.QueryServiceConfig, not sure how to get it.
|
||||
|
||||
type Config struct {
|
||||
ServiceType uint32
|
||||
StartType uint32
|
||||
ErrorControl uint32
|
||||
BinaryPathName string // fully qualified path to the service binary file, can also include arguments for an auto-start service
|
||||
LoadOrderGroup string
|
||||
TagId uint32
|
||||
Dependencies []string
|
||||
ServiceStartName string // name of the account under which the service should run
|
||||
DisplayName string
|
||||
Password string
|
||||
Description string
|
||||
}
|
||||
|
||||
func toString(p *uint16) string {
|
||||
if p == nil {
|
||||
return ""
|
||||
}
|
||||
return syscall.UTF16ToString((*[4096]uint16)(unsafe.Pointer(p))[:])
|
||||
}
|
||||
|
||||
func toStringSlice(ps *uint16) []string {
|
||||
if ps == nil {
|
||||
return nil
|
||||
}
|
||||
r := make([]string, 0)
|
||||
for from, i, p := 0, 0, (*[1 << 24]uint16)(unsafe.Pointer(ps)); true; i++ {
|
||||
if p[i] == 0 {
|
||||
// empty string marks the end
|
||||
if i <= from {
|
||||
break
|
||||
}
|
||||
r = append(r, string(utf16.Decode(p[from:i])))
|
||||
from = i + 1
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// Config retrieves service s configuration paramteres.
|
||||
func (s *Service) Config() (Config, error) {
|
||||
var p *windows.QUERY_SERVICE_CONFIG
|
||||
n := uint32(1024)
|
||||
for {
|
||||
b := make([]byte, n)
|
||||
p = (*windows.QUERY_SERVICE_CONFIG)(unsafe.Pointer(&b[0]))
|
||||
err := windows.QueryServiceConfig(s.Handle, p, n, &n)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if err.(syscall.Errno) != syscall.ERROR_INSUFFICIENT_BUFFER {
|
||||
return Config{}, err
|
||||
}
|
||||
if n <= uint32(len(b)) {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
var p2 *windows.SERVICE_DESCRIPTION
|
||||
n = uint32(1024)
|
||||
for {
|
||||
b := make([]byte, n)
|
||||
p2 = (*windows.SERVICE_DESCRIPTION)(unsafe.Pointer(&b[0]))
|
||||
err := windows.QueryServiceConfig2(s.Handle,
|
||||
windows.SERVICE_CONFIG_DESCRIPTION, &b[0], n, &n)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if err.(syscall.Errno) != syscall.ERROR_INSUFFICIENT_BUFFER {
|
||||
return Config{}, err
|
||||
}
|
||||
if n <= uint32(len(b)) {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
return Config{
|
||||
ServiceType: p.ServiceType,
|
||||
StartType: p.StartType,
|
||||
ErrorControl: p.ErrorControl,
|
||||
BinaryPathName: toString(p.BinaryPathName),
|
||||
LoadOrderGroup: toString(p.LoadOrderGroup),
|
||||
TagId: p.TagId,
|
||||
Dependencies: toStringSlice(p.Dependencies),
|
||||
ServiceStartName: toString(p.ServiceStartName),
|
||||
DisplayName: toString(p.DisplayName),
|
||||
Description: toString(p2.Description),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func updateDescription(handle windows.Handle, desc string) error {
|
||||
d := windows.SERVICE_DESCRIPTION{toPtr(desc)}
|
||||
return windows.ChangeServiceConfig2(handle,
|
||||
windows.SERVICE_CONFIG_DESCRIPTION, (*byte)(unsafe.Pointer(&d)))
|
||||
}
|
||||
|
||||
// UpdateConfig updates service s configuration parameters.
|
||||
func (s *Service) UpdateConfig(c Config) error {
|
||||
err := windows.ChangeServiceConfig(s.Handle, c.ServiceType, c.StartType,
|
||||
c.ErrorControl, toPtr(c.BinaryPathName), toPtr(c.LoadOrderGroup),
|
||||
nil, toStringBlock(c.Dependencies), toPtr(c.ServiceStartName),
|
||||
toPtr(c.Password), toPtr(c.DisplayName))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return updateDescription(s.Handle, c.Description)
|
||||
}
|
|
@ -1,162 +0,0 @@
|
|||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build windows
|
||||
|
||||
// Package mgr can be used to manage Windows service programs.
|
||||
// It can be used to install and remove them. It can also start,
|
||||
// stop and pause them. The package can query / change current
|
||||
// service state and config parameters.
|
||||
//
|
||||
package mgr
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unicode/utf16"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// Mgr is used to manage Windows service.
|
||||
type Mgr struct {
|
||||
Handle windows.Handle
|
||||
}
|
||||
|
||||
// Connect establishes a connection to the service control manager.
|
||||
func Connect() (*Mgr, error) {
|
||||
return ConnectRemote("")
|
||||
}
|
||||
|
||||
// ConnectRemote establishes a connection to the
|
||||
// service control manager on computer named host.
|
||||
func ConnectRemote(host string) (*Mgr, error) {
|
||||
var s *uint16
|
||||
if host != "" {
|
||||
s = syscall.StringToUTF16Ptr(host)
|
||||
}
|
||||
h, err := windows.OpenSCManager(s, nil, windows.SC_MANAGER_ALL_ACCESS)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Mgr{Handle: h}, nil
|
||||
}
|
||||
|
||||
// Disconnect closes connection to the service control manager m.
|
||||
func (m *Mgr) Disconnect() error {
|
||||
return windows.CloseServiceHandle(m.Handle)
|
||||
}
|
||||
|
||||
func toPtr(s string) *uint16 {
|
||||
if len(s) == 0 {
|
||||
return nil
|
||||
}
|
||||
return syscall.StringToUTF16Ptr(s)
|
||||
}
|
||||
|
||||
// toStringBlock terminates strings in ss with 0, and then
|
||||
// concatenates them together. It also adds extra 0 at the end.
|
||||
func toStringBlock(ss []string) *uint16 {
|
||||
if len(ss) == 0 {
|
||||
return nil
|
||||
}
|
||||
t := ""
|
||||
for _, s := range ss {
|
||||
if s != "" {
|
||||
t += s + "\x00"
|
||||
}
|
||||
}
|
||||
if t == "" {
|
||||
return nil
|
||||
}
|
||||
t += "\x00"
|
||||
return &utf16.Encode([]rune(t))[0]
|
||||
}
|
||||
|
||||
// CreateService installs new service name on the system.
|
||||
// The service will be executed by running exepath binary.
|
||||
// Use config c to specify service parameters.
|
||||
// Any args will be passed as command-line arguments when
|
||||
// the service is started; these arguments are distinct from
|
||||
// the arguments passed to Service.Start or via the "Start
|
||||
// parameters" field in the service's Properties dialog box.
|
||||
func (m *Mgr) CreateService(name, exepath string, c Config, args ...string) (*Service, error) {
|
||||
if c.StartType == 0 {
|
||||
c.StartType = StartManual
|
||||
}
|
||||
if c.ErrorControl == 0 {
|
||||
c.ErrorControl = ErrorNormal
|
||||
}
|
||||
if c.ServiceType == 0 {
|
||||
c.ServiceType = windows.SERVICE_WIN32_OWN_PROCESS
|
||||
}
|
||||
s := syscall.EscapeArg(exepath)
|
||||
for _, v := range args {
|
||||
s += " " + syscall.EscapeArg(v)
|
||||
}
|
||||
h, err := windows.CreateService(m.Handle, toPtr(name), toPtr(c.DisplayName),
|
||||
windows.SERVICE_ALL_ACCESS, c.ServiceType,
|
||||
c.StartType, c.ErrorControl, toPtr(s), toPtr(c.LoadOrderGroup),
|
||||
nil, toStringBlock(c.Dependencies), toPtr(c.ServiceStartName), toPtr(c.Password))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if c.Description != "" {
|
||||
err = updateDescription(h, c.Description)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &Service{Name: name, Handle: h}, nil
|
||||
}
|
||||
|
||||
// OpenService retrieves access to service name, so it can
|
||||
// be interrogated and controlled.
|
||||
func (m *Mgr) OpenService(name string) (*Service, error) {
|
||||
h, err := windows.OpenService(m.Handle, syscall.StringToUTF16Ptr(name), windows.SERVICE_ALL_ACCESS)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Service{Name: name, Handle: h}, nil
|
||||
}
|
||||
|
||||
// ListServices enumerates services in the specified
|
||||
// service control manager database m.
|
||||
// If the caller does not have the SERVICE_QUERY_STATUS
|
||||
// access right to a service, the service is silently
|
||||
// omitted from the list of services returned.
|
||||
func (m *Mgr) ListServices() ([]string, error) {
|
||||
var err error
|
||||
var bytesNeeded, servicesReturned uint32
|
||||
var buf []byte
|
||||
for {
|
||||
var p *byte
|
||||
if len(buf) > 0 {
|
||||
p = &buf[0]
|
||||
}
|
||||
err = windows.EnumServicesStatusEx(m.Handle, windows.SC_ENUM_PROCESS_INFO,
|
||||
windows.SERVICE_WIN32, windows.SERVICE_STATE_ALL,
|
||||
p, uint32(len(buf)), &bytesNeeded, &servicesReturned, nil, nil)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if err != syscall.ERROR_MORE_DATA {
|
||||
return nil, err
|
||||
}
|
||||
if bytesNeeded <= uint32(len(buf)) {
|
||||
return nil, err
|
||||
}
|
||||
buf = make([]byte, bytesNeeded)
|
||||
}
|
||||
if servicesReturned == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
services := (*[1 << 20]windows.ENUM_SERVICE_STATUS_PROCESS)(unsafe.Pointer(&buf[0]))[:servicesReturned]
|
||||
var names []string
|
||||
for _, s := range services {
|
||||
name := syscall.UTF16ToString((*[1 << 20]uint16)(unsafe.Pointer(s.ServiceName))[:])
|
||||
names = append(names, name)
|
||||
}
|
||||
return names, nil
|
||||
}
|
|
@ -1,169 +0,0 @@
|
|||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build windows
|
||||
|
||||
package mgr_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/windows/svc/mgr"
|
||||
)
|
||||
|
||||
func TestOpenLanManServer(t *testing.T) {
|
||||
m, err := mgr.Connect()
|
||||
if err != nil {
|
||||
if errno, ok := err.(syscall.Errno); ok && errno == syscall.ERROR_ACCESS_DENIED {
|
||||
t.Skip("Skipping test: we don't have rights to manage services.")
|
||||
}
|
||||
t.Fatalf("SCM connection failed: %s", err)
|
||||
}
|
||||
defer m.Disconnect()
|
||||
|
||||
s, err := m.OpenService("LanmanServer")
|
||||
if err != nil {
|
||||
t.Fatalf("OpenService(lanmanserver) failed: %s", err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
_, err = s.Config()
|
||||
if err != nil {
|
||||
t.Fatalf("Config failed: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func install(t *testing.T, m *mgr.Mgr, name, exepath string, c mgr.Config) {
|
||||
// Sometimes it takes a while for the service to get
|
||||
// removed after previous test run.
|
||||
for i := 0; ; i++ {
|
||||
s, err := m.OpenService(name)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
s.Close()
|
||||
|
||||
if i > 10 {
|
||||
t.Fatalf("service %s already exists", name)
|
||||
}
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
}
|
||||
|
||||
s, err := m.CreateService(name, exepath, c)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateService(%s) failed: %v", name, err)
|
||||
}
|
||||
defer s.Close()
|
||||
}
|
||||
|
||||
func depString(d []string) string {
|
||||
if len(d) == 0 {
|
||||
return ""
|
||||
}
|
||||
for i := range d {
|
||||
d[i] = strings.ToLower(d[i])
|
||||
}
|
||||
ss := sort.StringSlice(d)
|
||||
ss.Sort()
|
||||
return strings.Join([]string(ss), " ")
|
||||
}
|
||||
|
||||
func testConfig(t *testing.T, s *mgr.Service, should mgr.Config) mgr.Config {
|
||||
is, err := s.Config()
|
||||
if err != nil {
|
||||
t.Fatalf("Config failed: %s", err)
|
||||
}
|
||||
if should.DisplayName != is.DisplayName {
|
||||
t.Fatalf("config mismatch: DisplayName is %q, but should have %q", is.DisplayName, should.DisplayName)
|
||||
}
|
||||
if should.StartType != is.StartType {
|
||||
t.Fatalf("config mismatch: StartType is %v, but should have %v", is.StartType, should.StartType)
|
||||
}
|
||||
if should.Description != is.Description {
|
||||
t.Fatalf("config mismatch: Description is %q, but should have %q", is.Description, should.Description)
|
||||
}
|
||||
if depString(should.Dependencies) != depString(is.Dependencies) {
|
||||
t.Fatalf("config mismatch: Dependencies is %v, but should have %v", is.Dependencies, should.Dependencies)
|
||||
}
|
||||
return is
|
||||
}
|
||||
|
||||
func remove(t *testing.T, s *mgr.Service) {
|
||||
err := s.Delete()
|
||||
if err != nil {
|
||||
t.Fatalf("Delete failed: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMyService(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping test in short mode - it modifies system services")
|
||||
}
|
||||
|
||||
const name = "myservice"
|
||||
|
||||
m, err := mgr.Connect()
|
||||
if err != nil {
|
||||
if errno, ok := err.(syscall.Errno); ok && errno == syscall.ERROR_ACCESS_DENIED {
|
||||
t.Skip("Skipping test: we don't have rights to manage services.")
|
||||
}
|
||||
t.Fatalf("SCM connection failed: %s", err)
|
||||
}
|
||||
defer m.Disconnect()
|
||||
|
||||
c := mgr.Config{
|
||||
StartType: mgr.StartDisabled,
|
||||
DisplayName: "my service",
|
||||
Description: "my service is just a test",
|
||||
Dependencies: []string{"LanmanServer", "W32Time"},
|
||||
}
|
||||
|
||||
exename := os.Args[0]
|
||||
exepath, err := filepath.Abs(exename)
|
||||
if err != nil {
|
||||
t.Fatalf("filepath.Abs(%s) failed: %s", exename, err)
|
||||
}
|
||||
|
||||
install(t, m, name, exepath, c)
|
||||
|
||||
s, err := m.OpenService(name)
|
||||
if err != nil {
|
||||
t.Fatalf("service %s is not installed", name)
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
c.BinaryPathName = exepath
|
||||
c = testConfig(t, s, c)
|
||||
|
||||
c.StartType = mgr.StartManual
|
||||
err = s.UpdateConfig(c)
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateConfig failed: %v", err)
|
||||
}
|
||||
|
||||
testConfig(t, s, c)
|
||||
|
||||
svcnames, err := m.ListServices()
|
||||
if err != nil {
|
||||
t.Fatalf("ListServices failed: %v", err)
|
||||
}
|
||||
var myserviceIsInstalled bool
|
||||
for _, sn := range svcnames {
|
||||
if sn == name {
|
||||
myserviceIsInstalled = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !myserviceIsInstalled {
|
||||
t.Errorf("ListServices failed to find %q service", name)
|
||||
}
|
||||
|
||||
remove(t, s)
|
||||
}
|
|
@ -1,72 +0,0 @@
|
|||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build windows
|
||||
|
||||
package mgr
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/svc"
|
||||
)
|
||||
|
||||
// TODO(brainman): Use EnumDependentServices to enumerate dependent services.
|
||||
|
||||
// Service is used to access Windows service.
|
||||
type Service struct {
|
||||
Name string
|
||||
Handle windows.Handle
|
||||
}
|
||||
|
||||
// Delete marks service s for deletion from the service control manager database.
|
||||
func (s *Service) Delete() error {
|
||||
return windows.DeleteService(s.Handle)
|
||||
}
|
||||
|
||||
// Close relinquish access to the service s.
|
||||
func (s *Service) Close() error {
|
||||
return windows.CloseServiceHandle(s.Handle)
|
||||
}
|
||||
|
||||
// Start starts service s.
|
||||
// args will be passed to svc.Handler.Execute.
|
||||
func (s *Service) Start(args ...string) error {
|
||||
var p **uint16
|
||||
if len(args) > 0 {
|
||||
vs := make([]*uint16, len(args))
|
||||
for i := range vs {
|
||||
vs[i] = syscall.StringToUTF16Ptr(args[i])
|
||||
}
|
||||
p = &vs[0]
|
||||
}
|
||||
return windows.StartService(s.Handle, uint32(len(args)), p)
|
||||
}
|
||||
|
||||
// Control sends state change request c to the servce s.
|
||||
func (s *Service) Control(c svc.Cmd) (svc.Status, error) {
|
||||
var t windows.SERVICE_STATUS
|
||||
err := windows.ControlService(s.Handle, uint32(c), &t)
|
||||
if err != nil {
|
||||
return svc.Status{}, err
|
||||
}
|
||||
return svc.Status{
|
||||
State: svc.State(t.CurrentState),
|
||||
Accepts: svc.Accepted(t.ControlsAccepted),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Query returns current status of service s.
|
||||
func (s *Service) Query() (svc.Status, error) {
|
||||
var t windows.SERVICE_STATUS
|
||||
err := windows.QueryServiceStatus(s.Handle, &t)
|
||||
if err != nil {
|
||||
return svc.Status{}, err
|
||||
}
|
||||
return svc.Status{
|
||||
State: svc.State(t.CurrentState),
|
||||
Accepts: svc.Accepted(t.ControlsAccepted),
|
||||
}, nil
|
||||
}
|
|
@ -1,118 +0,0 @@
|
|||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build windows
|
||||
|
||||
package svc_test
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/windows/svc"
|
||||
"golang.org/x/sys/windows/svc/mgr"
|
||||
)
|
||||
|
||||
func getState(t *testing.T, s *mgr.Service) svc.State {
|
||||
status, err := s.Query()
|
||||
if err != nil {
|
||||
t.Fatalf("Query(%s) failed: %s", s.Name, err)
|
||||
}
|
||||
return status.State
|
||||
}
|
||||
|
||||
func testState(t *testing.T, s *mgr.Service, want svc.State) {
|
||||
have := getState(t, s)
|
||||
if have != want {
|
||||
t.Fatalf("%s state is=%d want=%d", s.Name, have, want)
|
||||
}
|
||||
}
|
||||
|
||||
func waitState(t *testing.T, s *mgr.Service, want svc.State) {
|
||||
for i := 0; ; i++ {
|
||||
have := getState(t, s)
|
||||
if have == want {
|
||||
return
|
||||
}
|
||||
if i > 10 {
|
||||
t.Fatalf("%s state is=%d, waiting timeout", s.Name, have)
|
||||
}
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExample(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping test in short mode - it modifies system services")
|
||||
}
|
||||
|
||||
const name = "myservice"
|
||||
|
||||
m, err := mgr.Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("SCM connection failed: %s", err)
|
||||
}
|
||||
defer m.Disconnect()
|
||||
|
||||
dir, err := ioutil.TempDir("", "svc")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
exepath := filepath.Join(dir, "a.exe")
|
||||
o, err := exec.Command("go", "build", "-o", exepath, "golang.org/x/sys/windows/svc/example").CombinedOutput()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to build service program: %v\n%v", err, string(o))
|
||||
}
|
||||
|
||||
s, err := m.OpenService(name)
|
||||
if err == nil {
|
||||
err = s.Delete()
|
||||
if err != nil {
|
||||
s.Close()
|
||||
t.Fatalf("Delete failed: %s", err)
|
||||
}
|
||||
s.Close()
|
||||
}
|
||||
s, err = m.CreateService(name, exepath, mgr.Config{DisplayName: "my service"}, "is", "auto-started")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateService(%s) failed: %v", name, err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
testState(t, s, svc.Stopped)
|
||||
err = s.Start("is", "manual-started")
|
||||
if err != nil {
|
||||
t.Fatalf("Start(%s) failed: %s", s.Name, err)
|
||||
}
|
||||
waitState(t, s, svc.Running)
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
// testing deadlock from issues 4.
|
||||
_, err = s.Control(svc.Interrogate)
|
||||
if err != nil {
|
||||
t.Fatalf("Control(%s) failed: %s", s.Name, err)
|
||||
}
|
||||
_, err = s.Control(svc.Interrogate)
|
||||
if err != nil {
|
||||
t.Fatalf("Control(%s) failed: %s", s.Name, err)
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
_, err = s.Control(svc.Stop)
|
||||
if err != nil {
|
||||
t.Fatalf("Control(%s) failed: %s", s.Name, err)
|
||||
}
|
||||
waitState(t, s, svc.Stopped)
|
||||
|
||||
err = s.Delete()
|
||||
if err != nil {
|
||||
t.Fatalf("Delete failed: %s", err)
|
||||
}
|
||||
}
|
|
@ -1,53 +0,0 @@
|
|||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build windows
|
||||
|
||||
package windows_test
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func testSetGetenv(t *testing.T, key, value string) {
|
||||
err := windows.Setenv(key, value)
|
||||
if err != nil {
|
||||
t.Fatalf("Setenv failed to set %q: %v", value, err)
|
||||
}
|
||||
newvalue, found := windows.Getenv(key)
|
||||
if !found {
|
||||
t.Fatalf("Getenv failed to find %v variable (want value %q)", key, value)
|
||||
}
|
||||
if newvalue != value {
|
||||
t.Fatalf("Getenv(%v) = %q; want %q", key, newvalue, value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnv(t *testing.T) {
|
||||
testSetGetenv(t, "TESTENV", "AVALUE")
|
||||
// make sure TESTENV gets set to "", not deleted
|
||||
testSetGetenv(t, "TESTENV", "")
|
||||
}
|
||||
|
||||
func TestGetProcAddressByOrdinal(t *testing.T) {
|
||||
// Attempt calling shlwapi.dll:IsOS, resolving it by ordinal, as
|
||||
// suggested in
|
||||
// https://msdn.microsoft.com/en-us/library/windows/desktop/bb773795.aspx
|
||||
h, err := windows.LoadLibrary("shlwapi.dll")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load shlwapi.dll: %s", err)
|
||||
}
|
||||
procIsOS, err := windows.GetProcAddressByOrdinal(h, 437)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not find shlwapi.dll:IsOS by ordinal: %s", err)
|
||||
}
|
||||
const OS_NT = 1
|
||||
r, _, _ := syscall.Syscall(procIsOS, 1, OS_NT, 0, 0)
|
||||
if r == 0 {
|
||||
t.Error("shlwapi.dll:IsOS(OS_NT) returned 0, expected non-zero value")
|
||||
}
|
||||
}
|
|
@ -1,107 +0,0 @@
|
|||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package windows_test
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
"testing"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func TestWin32finddata(t *testing.T) {
|
||||
dir, err := ioutil.TempDir("", "go-build")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
path := filepath.Join(dir, "long_name.and_extension")
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create %v: %v", path, err)
|
||||
}
|
||||
f.Close()
|
||||
|
||||
type X struct {
|
||||
fd windows.Win32finddata
|
||||
got byte
|
||||
pad [10]byte // to protect ourselves
|
||||
|
||||
}
|
||||
var want byte = 2 // it is unlikely to have this character in the filename
|
||||
x := X{got: want}
|
||||
|
||||
pathp, _ := windows.UTF16PtrFromString(path)
|
||||
h, err := windows.FindFirstFile(pathp, &(x.fd))
|
||||
if err != nil {
|
||||
t.Fatalf("FindFirstFile failed: %v", err)
|
||||
}
|
||||
err = windows.FindClose(h)
|
||||
if err != nil {
|
||||
t.Fatalf("FindClose failed: %v", err)
|
||||
}
|
||||
|
||||
if x.got != want {
|
||||
t.Fatalf("memory corruption: want=%d got=%d", want, x.got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatMessage(t *testing.T) {
|
||||
dll := windows.MustLoadDLL("pdh.dll")
|
||||
|
||||
pdhOpenQuery := func(datasrc *uint16, userdata uint32, query *windows.Handle) (errno uintptr) {
|
||||
r0, _, _ := syscall.Syscall(dll.MustFindProc("PdhOpenQueryW").Addr(), 3, uintptr(unsafe.Pointer(datasrc)), uintptr(userdata), uintptr(unsafe.Pointer(query)))
|
||||
return r0
|
||||
}
|
||||
|
||||
pdhCloseQuery := func(query windows.Handle) (errno uintptr) {
|
||||
r0, _, _ := syscall.Syscall(dll.MustFindProc("PdhCloseQuery").Addr(), 1, uintptr(query), 0, 0)
|
||||
return r0
|
||||
}
|
||||
|
||||
var q windows.Handle
|
||||
name, err := windows.UTF16PtrFromString("no_such_source")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
errno := pdhOpenQuery(name, 0, &q)
|
||||
if errno == 0 {
|
||||
pdhCloseQuery(q)
|
||||
t.Fatal("PdhOpenQuery succeeded, but expected to fail.")
|
||||
}
|
||||
|
||||
const flags uint32 = syscall.FORMAT_MESSAGE_FROM_HMODULE | syscall.FORMAT_MESSAGE_ARGUMENT_ARRAY | syscall.FORMAT_MESSAGE_IGNORE_INSERTS
|
||||
buf := make([]uint16, 300)
|
||||
_, err = windows.FormatMessage(flags, uintptr(dll.Handle), uint32(errno), 0, buf, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("FormatMessage for handle=%x and errno=%x failed: %v", dll.Handle, errno, err)
|
||||
}
|
||||
}
|
||||
|
||||
func abort(funcname string, err error) {
|
||||
panic(funcname + " failed: " + err.Error())
|
||||
}
|
||||
|
||||
func ExampleLoadLibrary() {
|
||||
h, err := windows.LoadLibrary("kernel32.dll")
|
||||
if err != nil {
|
||||
abort("LoadLibrary", err)
|
||||
}
|
||||
defer windows.FreeLibrary(h)
|
||||
proc, err := windows.GetProcAddress(h, "GetVersion")
|
||||
if err != nil {
|
||||
abort("GetProcAddress", err)
|
||||
}
|
||||
r, _, _ := syscall.Syscall(uintptr(proc), 0, 0, 0, 0)
|
||||
major := byte(r)
|
||||
minor := uint8(r >> 8)
|
||||
build := uint16(r >> 16)
|
||||
print("windows version ", major, ".", minor, " (Build ", build, ")\n")
|
||||
}
|
|
@ -0,0 +1,5 @@
|
|||
# github.com/judwhite/go-svc v1.0.0
|
||||
github.com/judwhite/go-svc/svc
|
||||
# golang.org/x/sys v0.0.0-20180322165403-91ee8cde4354
|
||||
golang.org/x/sys/windows/svc
|
||||
golang.org/x/sys/windows
|
|
@ -1,4 +0,0 @@
|
|||
MODULE VERSION
|
||||
github.com/productionwentdown/forward -
|
||||
github.com/judwhite/go-svc v1.0.0
|
||||
golang.org/x/sys v0.0.0-20180322165403-91ee8cde4354
|
Loading…
Reference in New Issue