Add SSH connection option
parent
bb031be2f1
commit
089bc9c6e4
6
go.mod
6
go.mod
|
@ -1,6 +1,6 @@
|
||||||
module "github.com/productionwentdown/forward"
|
module github.com/productionwentdown/forward
|
||||||
|
|
||||||
require (
|
require (
|
||||||
"github.com/judwhite/go-svc" v1.0.0
|
github.com/judwhite/go-svc v1.0.0
|
||||||
"golang.org/x/sys" v0.0.0-20180322165403-91ee8cde4354
|
golang.org/x/sys v0.0.0-20180322165403-91ee8cde4354 // indirect
|
||||||
)
|
)
|
||||||
|
|
|
@ -0,0 +1,2 @@
|
||||||
|
github.com/judwhite/go-svc v1.0.0/go.mod h1:EeMSAFO3mLgEQfcvnZ50JDG0O1uQlagpAbMS6talrXE=
|
||||||
|
golang.org/x/sys v0.0.0-20180322165403-91ee8cde4354/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
66
server.go
66
server.go
|
@ -1,6 +1,7 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"flag"
|
"flag"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
@ -9,13 +10,16 @@ import (
|
||||||
|
|
||||||
var listen string
|
var listen string
|
||||||
var connect string
|
var connect string
|
||||||
|
var connectSSH string
|
||||||
|
|
||||||
var ln net.Listener
|
var ln net.Listener
|
||||||
var conn *net.TCPAddr
|
var conn *net.TCPAddr
|
||||||
|
var connSSH *net.TCPAddr
|
||||||
|
|
||||||
func setup() {
|
func setup() {
|
||||||
flag.StringVar(&listen, "listen", ":8000", "listen on ip and port")
|
flag.StringVar(&listen, "listen", ":8000", "listen on address")
|
||||||
flag.StringVar(&connect, "connect", "", "forward to ip and port")
|
flag.StringVar(&connect, "connect", "", "forward to address")
|
||||||
|
flag.StringVar(&connectSSH, "ssh", "", "if set, will do basic introspection to forward SSH traffic to this address")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
|
@ -27,6 +31,12 @@ func setup() {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check and parse SSH address
|
||||||
|
connSSH, _ = net.ResolveTCPAddr("tcp", connectSSH)
|
||||||
|
if connectSSH == "" {
|
||||||
|
connSSH = nil
|
||||||
|
}
|
||||||
|
|
||||||
// listen on address
|
// listen on address
|
||||||
ln, err = net.Listen("tcp", listen)
|
ln, err = net.Listen("tcp", listen)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -36,6 +46,9 @@ func setup() {
|
||||||
|
|
||||||
log.Printf("listening on %v", ln.Addr())
|
log.Printf("listening on %v", ln.Addr())
|
||||||
log.Printf("will connect to %v", conn)
|
log.Printf("will connect to %v", conn)
|
||||||
|
if connSSH != nil {
|
||||||
|
log.Printf("will connect SSH to %v", connSSH)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func serve() {
|
func serve() {
|
||||||
|
@ -48,19 +61,48 @@ func serve() {
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("connection %v from %v", i, c.RemoteAddr())
|
log.Printf("connection %v from %v", i, c.RemoteAddr())
|
||||||
|
go handle(c, i)
|
||||||
cn, err := net.DialTCP("tcp", nil, conn)
|
|
||||||
if err != nil {
|
|
||||||
c.Close()
|
|
||||||
log.Print(err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
go pipe(c, cn, i)
|
|
||||||
go pipe(cn, c, i)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var magic = []byte{'S', 'S', 'H', '-'}
|
||||||
|
|
||||||
|
var magicLen = len(magic)
|
||||||
|
|
||||||
|
func handle(c net.Conn, count int) {
|
||||||
|
// read first four characters
|
||||||
|
readMagic := make([]byte, magicLen, magicLen)
|
||||||
|
n, err := c.Read(readMagic)
|
||||||
|
if n != magicLen {
|
||||||
|
log.Printf("warning! could not read header")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
opError, ok := err.(*net.OpError)
|
||||||
|
if err != nil && (!ok || opError.Op != "readfrom") {
|
||||||
|
log.Printf("warning! %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
connTo := conn
|
||||||
|
// if the header looks like SSH, forward to SSH connection
|
||||||
|
if bytes.Equal(readMagic, magic) {
|
||||||
|
connTo = connSSH
|
||||||
|
}
|
||||||
|
|
||||||
|
cn, err := net.DialTCP("tcp", nil, connTo)
|
||||||
|
if err != nil {
|
||||||
|
c.Close()
|
||||||
|
log.Print(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// write the first four characters
|
||||||
|
cn.Write(readMagic)
|
||||||
|
|
||||||
|
go pipe(c, cn, count)
|
||||||
|
go pipe(cn, c, count)
|
||||||
|
}
|
||||||
|
|
||||||
func pipe(w io.WriteCloser, r io.ReadCloser, count int) {
|
func pipe(w io.WriteCloser, r io.ReadCloser, count int) {
|
||||||
n, err := io.Copy(w, r)
|
n, err := io.Copy(w, r)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
The MIT License (MIT)
|
||||||
|
|
||||||
|
Copyright (c) 2016 Jud White
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
|
@ -1,62 +0,0 @@
|
||||||
package test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"path/filepath"
|
|
||||||
"reflect"
|
|
||||||
"runtime"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Equal asserts two parameters are equal by using reflect.DeepEqual.
|
|
||||||
func Equal(t *testing.T, expected, actual interface{}) {
|
|
||||||
if !reflect.DeepEqual(expected, actual) {
|
|
||||||
_, file, line, _ := runtime.Caller(1)
|
|
||||||
t.Logf("\033[31m%s:%d:\n\n\t %#v (expected)\n\n\t!= %#v (actual)\033[39m\n\n",
|
|
||||||
filepath.Base(file), line, expected, actual)
|
|
||||||
t.FailNow()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NotEqual asserts two parameters are not equal by using reflect.DeepEqual.
|
|
||||||
func NotEqual(t *testing.T, expected, actual interface{}) {
|
|
||||||
if !reflect.DeepEqual(expected, actual) {
|
|
||||||
_, file, line, _ := runtime.Caller(1)
|
|
||||||
t.Logf("\033[31m%s:%d:\n\n\tvalue should not equal %#v\033[39m\n\n",
|
|
||||||
filepath.Base(file), line, actual)
|
|
||||||
t.FailNow()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Nil asserts the parameter is nil.
|
|
||||||
func Nil(t *testing.T, object interface{}) {
|
|
||||||
if !isNil(object) {
|
|
||||||
_, file, line, _ := runtime.Caller(1)
|
|
||||||
t.Logf("\033[31m%s:%d:\n\n\t <nil> (expected)\n\n\t!= %#v (actual)\033[39m\n\n",
|
|
||||||
filepath.Base(file), line, object)
|
|
||||||
t.FailNow()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NotNil asserts the parameter is not nil.
|
|
||||||
func NotNil(t *testing.T, object interface{}) {
|
|
||||||
if isNil(object) {
|
|
||||||
_, file, line, _ := runtime.Caller(1)
|
|
||||||
t.Logf("\033[31m%s:%d:\n\n\tExpected value not to be <nil>\033[39m\n\n",
|
|
||||||
filepath.Base(file), line)
|
|
||||||
t.FailNow()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func isNil(object interface{}) bool {
|
|
||||||
if object == nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
value := reflect.ValueOf(object)
|
|
||||||
kind := value.Kind()
|
|
||||||
if kind >= reflect.Chan && kind <= reflect.Slice && value.IsNil() {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
|
@ -1,36 +0,0 @@
|
||||||
package svc
|
|
||||||
|
|
||||||
type mockProgram struct {
|
|
||||||
start func() error
|
|
||||||
stop func() error
|
|
||||||
init func(Environment) error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *mockProgram) Start() error {
|
|
||||||
return p.start()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *mockProgram) Stop() error {
|
|
||||||
return p.stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *mockProgram) Init(wse Environment) error {
|
|
||||||
return p.init(wse)
|
|
||||||
}
|
|
||||||
|
|
||||||
func makeProgram(startCalled, stopCalled, initCalled *int) *mockProgram {
|
|
||||||
return &mockProgram{
|
|
||||||
start: func() error {
|
|
||||||
*startCalled++
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
stop: func() error {
|
|
||||||
*stopCalled++
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
init: func(wse Environment) error {
|
|
||||||
*initCalled++
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,68 +0,0 @@
|
||||||
// +build !windows
|
|
||||||
|
|
||||||
package svc
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"syscall"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/judwhite/go-svc/svc/internal/test"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestDefaultSignalHandling(t *testing.T) {
|
|
||||||
signals := []os.Signal{syscall.SIGINT, syscall.SIGTERM} // default signals handled
|
|
||||||
for _, signal := range signals {
|
|
||||||
testSignalNotify(t, signal)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUserDefinedSignalHandling(t *testing.T) {
|
|
||||||
signals := []os.Signal{syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP}
|
|
||||||
for _, signal := range signals {
|
|
||||||
testSignalNotify(t, signal, signals...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func testSignalNotify(t *testing.T, signal os.Signal, sig ...os.Signal) {
|
|
||||||
// arrange
|
|
||||||
|
|
||||||
// sigChan is the chan we'll send to here. if a signal matches a registered signal
|
|
||||||
// type in the Run function (in svc_other.go) the signal will be delegated to the
|
|
||||||
// channel passed to signalNotify, which is created in the Run function in svc_other.go.
|
|
||||||
// shortly: we send here and the Run function gets it if it matches the filter.
|
|
||||||
sigChan := make(chan os.Signal)
|
|
||||||
|
|
||||||
var startCalled, stopCalled, initCalled int
|
|
||||||
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
|
|
||||||
|
|
||||||
signalNotify = func(c chan<- os.Signal, sig ...os.Signal) {
|
|
||||||
if c == nil {
|
|
||||||
panic("os/signal: Notify using nil channel")
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for val := range sigChan {
|
|
||||||
for _, registeredSig := range sig {
|
|
||||||
if val == registeredSig {
|
|
||||||
c <- val
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
sigChan <- signal
|
|
||||||
}()
|
|
||||||
|
|
||||||
// act
|
|
||||||
if err := Run(prg, sig...); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// assert
|
|
||||||
test.Equal(t, 1, startCalled)
|
|
||||||
test.Equal(t, 1, stopCalled)
|
|
||||||
test.Equal(t, 1, initCalled)
|
|
||||||
}
|
|
|
@ -1,438 +0,0 @@
|
||||||
// +build windows
|
|
||||||
|
|
||||||
package svc
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"os"
|
|
||||||
"syscall"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/judwhite/go-svc/svc/internal/test"
|
|
||||||
wsvc "golang.org/x/sys/windows/svc"
|
|
||||||
)
|
|
||||||
|
|
||||||
func setupWinServiceTest(wsf *mockWinServiceFuncs) {
|
|
||||||
// wsfWrapper allows signalNotify, svcIsInteractive, and svcRun to be set once.
|
|
||||||
// Inidivual test functions set "wsf" to add behavior.
|
|
||||||
wsfWrapper := &mockWinServiceFuncs{
|
|
||||||
signalNotify: func(c chan<- os.Signal, sig ...os.Signal) {
|
|
||||||
if c == nil {
|
|
||||||
panic("os/signal: Notify using nil channel")
|
|
||||||
}
|
|
||||||
|
|
||||||
if wsf.signalNotify != nil {
|
|
||||||
wsf.signalNotify(c, sig...)
|
|
||||||
} else {
|
|
||||||
wsf1 := *wsf
|
|
||||||
go func() {
|
|
||||||
for val := range wsf1.sigChan {
|
|
||||||
for _, registeredSig := range sig {
|
|
||||||
if val == registeredSig {
|
|
||||||
c <- val
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
},
|
|
||||||
svcIsInteractive: func() (bool, error) {
|
|
||||||
return wsf.svcIsInteractive()
|
|
||||||
},
|
|
||||||
svcRun: func(name string, handler wsvc.Handler) error {
|
|
||||||
return wsf.svcRun(name, handler)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
signalNotify = wsfWrapper.signalNotify
|
|
||||||
svcIsAnInteractiveSession = wsfWrapper.svcIsInteractive
|
|
||||||
svcRun = wsfWrapper.svcRun
|
|
||||||
}
|
|
||||||
|
|
||||||
type mockWinServiceFuncs struct {
|
|
||||||
signalNotify func(chan<- os.Signal, ...os.Signal)
|
|
||||||
svcIsInteractive func() (bool, error)
|
|
||||||
sigChan chan os.Signal
|
|
||||||
svcRun func(string, wsvc.Handler) error
|
|
||||||
ws *windowsService
|
|
||||||
executeReturnedBool bool
|
|
||||||
executeReturnedUInt32 uint32
|
|
||||||
changes []wsvc.Status
|
|
||||||
}
|
|
||||||
|
|
||||||
func setWindowsServiceFuncs(isInteractive bool, onRunningSendCmd *wsvc.Cmd) (*mockWinServiceFuncs, chan<- wsvc.ChangeRequest) {
|
|
||||||
changeRequestChan := make(chan wsvc.ChangeRequest, 4)
|
|
||||||
changesChan := make(chan wsvc.Status)
|
|
||||||
done := make(chan struct{})
|
|
||||||
|
|
||||||
var wsf *mockWinServiceFuncs
|
|
||||||
wsf = &mockWinServiceFuncs{
|
|
||||||
sigChan: make(chan os.Signal),
|
|
||||||
svcIsInteractive: func() (bool, error) {
|
|
||||||
return isInteractive, nil
|
|
||||||
},
|
|
||||||
svcRun: func(name string, handler wsvc.Handler) error {
|
|
||||||
wsf.ws = handler.(*windowsService)
|
|
||||||
wsf.executeReturnedBool, wsf.executeReturnedUInt32 = handler.Execute(nil, changeRequestChan, changesChan)
|
|
||||||
done <- struct{}{}
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var currentState wsvc.State
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
loop:
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case change := <-changesChan:
|
|
||||||
wsf.changes = append(wsf.changes, change)
|
|
||||||
currentState = change.State
|
|
||||||
|
|
||||||
if change.State == wsvc.Running && onRunningSendCmd != nil {
|
|
||||||
changeRequestChan <- wsvc.ChangeRequest{
|
|
||||||
Cmd: *onRunningSendCmd,
|
|
||||||
CurrentStatus: wsvc.Status{State: currentState},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case <-done:
|
|
||||||
break loop
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
setupWinServiceTest(wsf)
|
|
||||||
|
|
||||||
return wsf, changeRequestChan
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWinService_RunWindowsService_NonInteractive(t *testing.T) {
|
|
||||||
for _, svcCmd := range []wsvc.Cmd{wsvc.Stop, wsvc.Shutdown} {
|
|
||||||
testRunWindowsServiceNonInteractive(t, svcCmd)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func testRunWindowsServiceNonInteractive(t *testing.T, svcCmd wsvc.Cmd) {
|
|
||||||
// arrange
|
|
||||||
var startCalled, stopCalled, initCalled int
|
|
||||||
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
|
|
||||||
|
|
||||||
wsf, _ := setWindowsServiceFuncs(false, &svcCmd)
|
|
||||||
|
|
||||||
// act
|
|
||||||
if err := Run(prg); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// assert
|
|
||||||
changes := wsf.changes
|
|
||||||
|
|
||||||
test.Equal(t, 1, startCalled)
|
|
||||||
test.Equal(t, 1, stopCalled)
|
|
||||||
test.Equal(t, 1, initCalled)
|
|
||||||
|
|
||||||
test.Equal(t, 3, len(changes))
|
|
||||||
test.Equal(t, wsvc.StartPending, changes[0].State)
|
|
||||||
test.Equal(t, wsvc.Running, changes[1].State)
|
|
||||||
test.Equal(t, wsvc.StopPending, changes[2].State)
|
|
||||||
|
|
||||||
test.Equal(t, false, wsf.executeReturnedBool)
|
|
||||||
test.Equal(t, uint32(0), wsf.executeReturnedUInt32)
|
|
||||||
|
|
||||||
test.Nil(t, wsf.ws.getError())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRunWindowsServiceNonInteractive_StartError(t *testing.T) {
|
|
||||||
// arrange
|
|
||||||
var startCalled, stopCalled, initCalled int
|
|
||||||
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
|
|
||||||
prg.start = func() error {
|
|
||||||
startCalled++
|
|
||||||
return errors.New("start error")
|
|
||||||
}
|
|
||||||
|
|
||||||
svcStop := wsvc.Stop
|
|
||||||
wsf, _ := setWindowsServiceFuncs(false, &svcStop)
|
|
||||||
|
|
||||||
// act
|
|
||||||
err := Run(prg)
|
|
||||||
|
|
||||||
// assert
|
|
||||||
test.Equal(t, "start error", err.Error())
|
|
||||||
|
|
||||||
changes := wsf.changes
|
|
||||||
|
|
||||||
test.Equal(t, 1, startCalled)
|
|
||||||
test.Equal(t, 0, stopCalled)
|
|
||||||
test.Equal(t, 1, initCalled)
|
|
||||||
|
|
||||||
test.Equal(t, 1, len(changes))
|
|
||||||
test.Equal(t, wsvc.StartPending, changes[0].State)
|
|
||||||
|
|
||||||
test.Equal(t, true, wsf.executeReturnedBool)
|
|
||||||
test.Equal(t, uint32(1), wsf.executeReturnedUInt32)
|
|
||||||
|
|
||||||
test.Equal(t, "start error", wsf.ws.getError().Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRunWindowsServiceInteractive_StartError(t *testing.T) {
|
|
||||||
// arrange
|
|
||||||
var startCalled, stopCalled, initCalled int
|
|
||||||
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
|
|
||||||
prg.start = func() error {
|
|
||||||
startCalled++
|
|
||||||
return errors.New("start error")
|
|
||||||
}
|
|
||||||
|
|
||||||
wsf, _ := setWindowsServiceFuncs(true, nil)
|
|
||||||
|
|
||||||
// act
|
|
||||||
err := Run(prg)
|
|
||||||
|
|
||||||
// assert
|
|
||||||
test.Equal(t, "start error", err.Error())
|
|
||||||
|
|
||||||
changes := wsf.changes
|
|
||||||
|
|
||||||
test.Equal(t, 1, startCalled)
|
|
||||||
test.Equal(t, 0, stopCalled)
|
|
||||||
test.Equal(t, 1, initCalled)
|
|
||||||
|
|
||||||
test.Equal(t, 0, len(changes))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRunWindowsService_BeforeStartError(t *testing.T) {
|
|
||||||
// arrange
|
|
||||||
var startCalled, stopCalled, initCalled int
|
|
||||||
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
|
|
||||||
prg.init = func(Environment) error {
|
|
||||||
initCalled++
|
|
||||||
return errors.New("before start error")
|
|
||||||
}
|
|
||||||
|
|
||||||
wsf, _ := setWindowsServiceFuncs(false, nil)
|
|
||||||
|
|
||||||
// act
|
|
||||||
err := Run(prg)
|
|
||||||
|
|
||||||
// assert
|
|
||||||
test.Equal(t, "before start error", err.Error())
|
|
||||||
|
|
||||||
changes := wsf.changes
|
|
||||||
|
|
||||||
test.Equal(t, 0, startCalled)
|
|
||||||
test.Equal(t, 0, stopCalled)
|
|
||||||
test.Equal(t, 1, initCalled)
|
|
||||||
|
|
||||||
test.Equal(t, 0, len(changes))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRunWindowsService_IsAnInteractiveSessionError(t *testing.T) {
|
|
||||||
// arrange
|
|
||||||
var startCalled, stopCalled, initCalled int
|
|
||||||
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
|
|
||||||
|
|
||||||
wsf, _ := setWindowsServiceFuncs(false, nil)
|
|
||||||
wsf.svcIsInteractive = func() (bool, error) {
|
|
||||||
return false, errors.New("IsAnInteractiveSession error")
|
|
||||||
}
|
|
||||||
|
|
||||||
// act
|
|
||||||
err := Run(prg)
|
|
||||||
|
|
||||||
// assert
|
|
||||||
test.Equal(t, "IsAnInteractiveSession error", err.Error())
|
|
||||||
|
|
||||||
changes := wsf.changes
|
|
||||||
|
|
||||||
test.Equal(t, 0, startCalled)
|
|
||||||
test.Equal(t, 0, stopCalled)
|
|
||||||
test.Equal(t, 0, initCalled)
|
|
||||||
|
|
||||||
test.Equal(t, 0, len(changes))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRunWindowsServiceNonInteractive_RunError(t *testing.T) {
|
|
||||||
// arrange
|
|
||||||
var startCalled, stopCalled, initCalled int
|
|
||||||
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
|
|
||||||
|
|
||||||
svcStop := wsvc.Stop
|
|
||||||
wsf, _ := setWindowsServiceFuncs(false, &svcStop)
|
|
||||||
wsf.svcRun = func(name string, handler wsvc.Handler) error {
|
|
||||||
wsf.ws = handler.(*windowsService)
|
|
||||||
return errors.New("wsvc.Run error")
|
|
||||||
}
|
|
||||||
|
|
||||||
// act
|
|
||||||
err := Run(prg)
|
|
||||||
|
|
||||||
// assert
|
|
||||||
test.Equal(t, "wsvc.Run error", err.Error())
|
|
||||||
|
|
||||||
changes := wsf.changes
|
|
||||||
|
|
||||||
test.Equal(t, 0, startCalled)
|
|
||||||
test.Equal(t, 0, stopCalled)
|
|
||||||
test.Equal(t, 1, initCalled)
|
|
||||||
|
|
||||||
test.Equal(t, 0, len(changes))
|
|
||||||
|
|
||||||
test.Nil(t, wsf.ws.getError())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRunWindowsServiceNonInteractive_Interrogate(t *testing.T) {
|
|
||||||
// arrange
|
|
||||||
var startCalled, stopCalled, initCalled int
|
|
||||||
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
|
|
||||||
|
|
||||||
wsf, changeRequest := setWindowsServiceFuncs(false, nil)
|
|
||||||
|
|
||||||
time.AfterFunc(50*time.Millisecond, func() {
|
|
||||||
// ignored, PausePending won't be in changes slice
|
|
||||||
// make sure we don't panic/err on unexpected values
|
|
||||||
changeRequest <- wsvc.ChangeRequest{
|
|
||||||
Cmd: wsvc.Pause,
|
|
||||||
CurrentStatus: wsvc.Status{State: wsvc.PausePending},
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
time.AfterFunc(100*time.Millisecond, func() {
|
|
||||||
// handled, Paused will be in changes slice
|
|
||||||
changeRequest <- wsvc.ChangeRequest{
|
|
||||||
Cmd: wsvc.Interrogate,
|
|
||||||
CurrentStatus: wsvc.Status{State: wsvc.Paused},
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
time.AfterFunc(200*time.Millisecond, func() {
|
|
||||||
// handled, but CurrentStatus overridden with StopPending;
|
|
||||||
// ContinuePending won't be in changes slice
|
|
||||||
changeRequest <- wsvc.ChangeRequest{
|
|
||||||
Cmd: wsvc.Stop,
|
|
||||||
CurrentStatus: wsvc.Status{State: wsvc.ContinuePending},
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// act
|
|
||||||
if err := Run(prg); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// assert
|
|
||||||
changes := wsf.changes
|
|
||||||
|
|
||||||
test.Equal(t, 1, startCalled)
|
|
||||||
test.Equal(t, 1, stopCalled)
|
|
||||||
test.Equal(t, 1, initCalled)
|
|
||||||
|
|
||||||
test.Equal(t, 4, len(changes))
|
|
||||||
test.Equal(t, wsvc.StartPending, changes[0].State)
|
|
||||||
test.Equal(t, wsvc.Running, changes[1].State)
|
|
||||||
test.Equal(t, wsvc.Paused, changes[2].State)
|
|
||||||
test.Equal(t, wsvc.StopPending, changes[3].State)
|
|
||||||
|
|
||||||
test.Equal(t, false, wsf.executeReturnedBool)
|
|
||||||
test.Equal(t, uint32(0), wsf.executeReturnedUInt32)
|
|
||||||
|
|
||||||
test.Nil(t, wsf.ws.getError())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRunWindowsServiceInteractive_StopError(t *testing.T) {
|
|
||||||
// arrange
|
|
||||||
var startCalled, stopCalled, initCalled int
|
|
||||||
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
|
|
||||||
prg.stop = func() error {
|
|
||||||
stopCalled++
|
|
||||||
return errors.New("stop error")
|
|
||||||
}
|
|
||||||
|
|
||||||
wsf, _ := setWindowsServiceFuncs(true, nil)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
wsf.sigChan <- os.Interrupt
|
|
||||||
}()
|
|
||||||
|
|
||||||
// act
|
|
||||||
err := Run(prg)
|
|
||||||
|
|
||||||
// assert
|
|
||||||
test.Equal(t, "stop error", err.Error())
|
|
||||||
test.Equal(t, 1, startCalled)
|
|
||||||
test.Equal(t, 1, stopCalled)
|
|
||||||
test.Equal(t, 1, initCalled)
|
|
||||||
test.Equal(t, 0, len(wsf.changes))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRunWindowsServiceNonInteractive_StopError(t *testing.T) {
|
|
||||||
// arrange
|
|
||||||
var startCalled, stopCalled, initCalled int
|
|
||||||
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
|
|
||||||
prg.stop = func() error {
|
|
||||||
stopCalled++
|
|
||||||
return errors.New("stop error")
|
|
||||||
}
|
|
||||||
|
|
||||||
shutdownCmd := wsvc.Shutdown
|
|
||||||
wsf, _ := setWindowsServiceFuncs(false, &shutdownCmd)
|
|
||||||
|
|
||||||
// act
|
|
||||||
err := Run(prg)
|
|
||||||
|
|
||||||
// assert
|
|
||||||
changes := wsf.changes
|
|
||||||
|
|
||||||
test.Equal(t, "stop error", err.Error())
|
|
||||||
|
|
||||||
test.Equal(t, 1, startCalled)
|
|
||||||
test.Equal(t, 1, stopCalled)
|
|
||||||
test.Equal(t, 1, initCalled)
|
|
||||||
|
|
||||||
test.Equal(t, 3, len(changes))
|
|
||||||
test.Equal(t, wsvc.StartPending, changes[0].State)
|
|
||||||
test.Equal(t, wsvc.Running, changes[1].State)
|
|
||||||
test.Equal(t, wsvc.StopPending, changes[2].State)
|
|
||||||
|
|
||||||
test.Equal(t, true, wsf.executeReturnedBool)
|
|
||||||
test.Equal(t, uint32(2), wsf.executeReturnedUInt32)
|
|
||||||
|
|
||||||
test.Equal(t, "stop error", wsf.ws.getError().Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDefaultSignalHandling(t *testing.T) {
|
|
||||||
signals := []os.Signal{syscall.SIGINT} // default signal handled
|
|
||||||
for _, signal := range signals {
|
|
||||||
testSignalNotify(t, signal)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUserDefinedSignalHandling(t *testing.T) {
|
|
||||||
signals := []os.Signal{syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP}
|
|
||||||
for _, signal := range signals {
|
|
||||||
testSignalNotify(t, signal, signals...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func testSignalNotify(t *testing.T, signal os.Signal, sig ...os.Signal) {
|
|
||||||
// arrange
|
|
||||||
var startCalled, stopCalled, initCalled int
|
|
||||||
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
|
|
||||||
|
|
||||||
wsf, _ := setWindowsServiceFuncs(true, nil)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
wsf.sigChan <- signal
|
|
||||||
}()
|
|
||||||
|
|
||||||
// act
|
|
||||||
if err := Run(prg, sig...); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// assert
|
|
||||||
test.Equal(t, 1, startCalled)
|
|
||||||
test.Equal(t, 1, stopCalled)
|
|
||||||
test.Equal(t, 1, initCalled)
|
|
||||||
test.Equal(t, 0, len(wsf.changes))
|
|
||||||
}
|
|
|
@ -0,0 +1,3 @@
|
||||||
|
# This source code refers to The Go Authors for copyright purposes.
|
||||||
|
# The master list of authors is in the main Go distribution,
|
||||||
|
# visible at http://tip.golang.org/AUTHORS.
|
|
@ -0,0 +1,3 @@
|
||||||
|
# This source code was written by the Go contributors.
|
||||||
|
# The master list of contributors is in the main Go distribution,
|
||||||
|
# visible at http://tip.golang.org/CONTRIBUTORS.
|
|
@ -0,0 +1,27 @@
|
||||||
|
Copyright (c) 2009 The Go Authors. All rights reserved.
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are
|
||||||
|
met:
|
||||||
|
|
||||||
|
* Redistributions of source code must retain the above copyright
|
||||||
|
notice, this list of conditions and the following disclaimer.
|
||||||
|
* Redistributions in binary form must reproduce the above
|
||||||
|
copyright notice, this list of conditions and the following disclaimer
|
||||||
|
in the documentation and/or other materials provided with the
|
||||||
|
distribution.
|
||||||
|
* Neither the name of Google Inc. nor the names of its
|
||||||
|
contributors may be used to endorse or promote products derived from
|
||||||
|
this software without specific prior written permission.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||||
|
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||||
|
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||||
|
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||||
|
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||||
|
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||||
|
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@ -0,0 +1,22 @@
|
||||||
|
Additional IP Rights Grant (Patents)
|
||||||
|
|
||||||
|
"This implementation" means the copyrightable works distributed by
|
||||||
|
Google as part of the Go project.
|
||||||
|
|
||||||
|
Google hereby grants to You a perpetual, worldwide, non-exclusive,
|
||||||
|
no-charge, royalty-free, irrevocable (except as stated in this section)
|
||||||
|
patent license to make, have made, use, offer to sell, sell, import,
|
||||||
|
transfer and otherwise run, modify and propagate the contents of this
|
||||||
|
implementation of Go, where such license applies only to those patent
|
||||||
|
claims, both currently owned or controlled by Google and acquired in
|
||||||
|
the future, licensable by Google that are necessarily infringed by this
|
||||||
|
implementation of Go. This grant does not include claims that would be
|
||||||
|
infringed only as a consequence of further modification of this
|
||||||
|
implementation. If you or your agent or exclusive licensee institute or
|
||||||
|
order or agree to the institution of patent litigation against any
|
||||||
|
entity (including a cross-claim or counterclaim in a lawsuit) alleging
|
||||||
|
that this implementation of Go or any code incorporated within this
|
||||||
|
implementation of Go constitutes direct or contributory patent
|
||||||
|
infringement, or inducement of patent infringement, then any patent
|
||||||
|
rights granted to you under this License for this implementation of Go
|
||||||
|
shall terminate as of the date such litigation is filed.
|
|
@ -1,139 +0,0 @@
|
||||||
// Copyright 2012 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// +build windows
|
|
||||||
|
|
||||||
package mgr
|
|
||||||
|
|
||||||
import (
|
|
||||||
"syscall"
|
|
||||||
"unicode/utf16"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// Service start types.
|
|
||||||
StartManual = windows.SERVICE_DEMAND_START // the service must be started manually
|
|
||||||
StartAutomatic = windows.SERVICE_AUTO_START // the service will start by itself whenever the computer reboots
|
|
||||||
StartDisabled = windows.SERVICE_DISABLED // the service cannot be started
|
|
||||||
|
|
||||||
// The severity of the error, and action taken,
|
|
||||||
// if this service fails to start.
|
|
||||||
ErrorCritical = windows.SERVICE_ERROR_CRITICAL
|
|
||||||
ErrorIgnore = windows.SERVICE_ERROR_IGNORE
|
|
||||||
ErrorNormal = windows.SERVICE_ERROR_NORMAL
|
|
||||||
ErrorSevere = windows.SERVICE_ERROR_SEVERE
|
|
||||||
)
|
|
||||||
|
|
||||||
// TODO(brainman): Password is not returned by windows.QueryServiceConfig, not sure how to get it.
|
|
||||||
|
|
||||||
type Config struct {
|
|
||||||
ServiceType uint32
|
|
||||||
StartType uint32
|
|
||||||
ErrorControl uint32
|
|
||||||
BinaryPathName string // fully qualified path to the service binary file, can also include arguments for an auto-start service
|
|
||||||
LoadOrderGroup string
|
|
||||||
TagId uint32
|
|
||||||
Dependencies []string
|
|
||||||
ServiceStartName string // name of the account under which the service should run
|
|
||||||
DisplayName string
|
|
||||||
Password string
|
|
||||||
Description string
|
|
||||||
}
|
|
||||||
|
|
||||||
func toString(p *uint16) string {
|
|
||||||
if p == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return syscall.UTF16ToString((*[4096]uint16)(unsafe.Pointer(p))[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
func toStringSlice(ps *uint16) []string {
|
|
||||||
if ps == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
r := make([]string, 0)
|
|
||||||
for from, i, p := 0, 0, (*[1 << 24]uint16)(unsafe.Pointer(ps)); true; i++ {
|
|
||||||
if p[i] == 0 {
|
|
||||||
// empty string marks the end
|
|
||||||
if i <= from {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
r = append(r, string(utf16.Decode(p[from:i])))
|
|
||||||
from = i + 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// Config retrieves service s configuration paramteres.
|
|
||||||
func (s *Service) Config() (Config, error) {
|
|
||||||
var p *windows.QUERY_SERVICE_CONFIG
|
|
||||||
n := uint32(1024)
|
|
||||||
for {
|
|
||||||
b := make([]byte, n)
|
|
||||||
p = (*windows.QUERY_SERVICE_CONFIG)(unsafe.Pointer(&b[0]))
|
|
||||||
err := windows.QueryServiceConfig(s.Handle, p, n, &n)
|
|
||||||
if err == nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if err.(syscall.Errno) != syscall.ERROR_INSUFFICIENT_BUFFER {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
if n <= uint32(len(b)) {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var p2 *windows.SERVICE_DESCRIPTION
|
|
||||||
n = uint32(1024)
|
|
||||||
for {
|
|
||||||
b := make([]byte, n)
|
|
||||||
p2 = (*windows.SERVICE_DESCRIPTION)(unsafe.Pointer(&b[0]))
|
|
||||||
err := windows.QueryServiceConfig2(s.Handle,
|
|
||||||
windows.SERVICE_CONFIG_DESCRIPTION, &b[0], n, &n)
|
|
||||||
if err == nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if err.(syscall.Errno) != syscall.ERROR_INSUFFICIENT_BUFFER {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
if n <= uint32(len(b)) {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return Config{
|
|
||||||
ServiceType: p.ServiceType,
|
|
||||||
StartType: p.StartType,
|
|
||||||
ErrorControl: p.ErrorControl,
|
|
||||||
BinaryPathName: toString(p.BinaryPathName),
|
|
||||||
LoadOrderGroup: toString(p.LoadOrderGroup),
|
|
||||||
TagId: p.TagId,
|
|
||||||
Dependencies: toStringSlice(p.Dependencies),
|
|
||||||
ServiceStartName: toString(p.ServiceStartName),
|
|
||||||
DisplayName: toString(p.DisplayName),
|
|
||||||
Description: toString(p2.Description),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateDescription(handle windows.Handle, desc string) error {
|
|
||||||
d := windows.SERVICE_DESCRIPTION{toPtr(desc)}
|
|
||||||
return windows.ChangeServiceConfig2(handle,
|
|
||||||
windows.SERVICE_CONFIG_DESCRIPTION, (*byte)(unsafe.Pointer(&d)))
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateConfig updates service s configuration parameters.
|
|
||||||
func (s *Service) UpdateConfig(c Config) error {
|
|
||||||
err := windows.ChangeServiceConfig(s.Handle, c.ServiceType, c.StartType,
|
|
||||||
c.ErrorControl, toPtr(c.BinaryPathName), toPtr(c.LoadOrderGroup),
|
|
||||||
nil, toStringBlock(c.Dependencies), toPtr(c.ServiceStartName),
|
|
||||||
toPtr(c.Password), toPtr(c.DisplayName))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return updateDescription(s.Handle, c.Description)
|
|
||||||
}
|
|
|
@ -1,162 +0,0 @@
|
||||||
// Copyright 2012 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// +build windows
|
|
||||||
|
|
||||||
// Package mgr can be used to manage Windows service programs.
|
|
||||||
// It can be used to install and remove them. It can also start,
|
|
||||||
// stop and pause them. The package can query / change current
|
|
||||||
// service state and config parameters.
|
|
||||||
//
|
|
||||||
package mgr
|
|
||||||
|
|
||||||
import (
|
|
||||||
"syscall"
|
|
||||||
"unicode/utf16"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Mgr is used to manage Windows service.
|
|
||||||
type Mgr struct {
|
|
||||||
Handle windows.Handle
|
|
||||||
}
|
|
||||||
|
|
||||||
// Connect establishes a connection to the service control manager.
|
|
||||||
func Connect() (*Mgr, error) {
|
|
||||||
return ConnectRemote("")
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConnectRemote establishes a connection to the
|
|
||||||
// service control manager on computer named host.
|
|
||||||
func ConnectRemote(host string) (*Mgr, error) {
|
|
||||||
var s *uint16
|
|
||||||
if host != "" {
|
|
||||||
s = syscall.StringToUTF16Ptr(host)
|
|
||||||
}
|
|
||||||
h, err := windows.OpenSCManager(s, nil, windows.SC_MANAGER_ALL_ACCESS)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &Mgr{Handle: h}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Disconnect closes connection to the service control manager m.
|
|
||||||
func (m *Mgr) Disconnect() error {
|
|
||||||
return windows.CloseServiceHandle(m.Handle)
|
|
||||||
}
|
|
||||||
|
|
||||||
func toPtr(s string) *uint16 {
|
|
||||||
if len(s) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return syscall.StringToUTF16Ptr(s)
|
|
||||||
}
|
|
||||||
|
|
||||||
// toStringBlock terminates strings in ss with 0, and then
|
|
||||||
// concatenates them together. It also adds extra 0 at the end.
|
|
||||||
func toStringBlock(ss []string) *uint16 {
|
|
||||||
if len(ss) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
t := ""
|
|
||||||
for _, s := range ss {
|
|
||||||
if s != "" {
|
|
||||||
t += s + "\x00"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if t == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
t += "\x00"
|
|
||||||
return &utf16.Encode([]rune(t))[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateService installs new service name on the system.
|
|
||||||
// The service will be executed by running exepath binary.
|
|
||||||
// Use config c to specify service parameters.
|
|
||||||
// Any args will be passed as command-line arguments when
|
|
||||||
// the service is started; these arguments are distinct from
|
|
||||||
// the arguments passed to Service.Start or via the "Start
|
|
||||||
// parameters" field in the service's Properties dialog box.
|
|
||||||
func (m *Mgr) CreateService(name, exepath string, c Config, args ...string) (*Service, error) {
|
|
||||||
if c.StartType == 0 {
|
|
||||||
c.StartType = StartManual
|
|
||||||
}
|
|
||||||
if c.ErrorControl == 0 {
|
|
||||||
c.ErrorControl = ErrorNormal
|
|
||||||
}
|
|
||||||
if c.ServiceType == 0 {
|
|
||||||
c.ServiceType = windows.SERVICE_WIN32_OWN_PROCESS
|
|
||||||
}
|
|
||||||
s := syscall.EscapeArg(exepath)
|
|
||||||
for _, v := range args {
|
|
||||||
s += " " + syscall.EscapeArg(v)
|
|
||||||
}
|
|
||||||
h, err := windows.CreateService(m.Handle, toPtr(name), toPtr(c.DisplayName),
|
|
||||||
windows.SERVICE_ALL_ACCESS, c.ServiceType,
|
|
||||||
c.StartType, c.ErrorControl, toPtr(s), toPtr(c.LoadOrderGroup),
|
|
||||||
nil, toStringBlock(c.Dependencies), toPtr(c.ServiceStartName), toPtr(c.Password))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if c.Description != "" {
|
|
||||||
err = updateDescription(h, c.Description)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &Service{Name: name, Handle: h}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// OpenService retrieves access to service name, so it can
|
|
||||||
// be interrogated and controlled.
|
|
||||||
func (m *Mgr) OpenService(name string) (*Service, error) {
|
|
||||||
h, err := windows.OpenService(m.Handle, syscall.StringToUTF16Ptr(name), windows.SERVICE_ALL_ACCESS)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &Service{Name: name, Handle: h}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListServices enumerates services in the specified
|
|
||||||
// service control manager database m.
|
|
||||||
// If the caller does not have the SERVICE_QUERY_STATUS
|
|
||||||
// access right to a service, the service is silently
|
|
||||||
// omitted from the list of services returned.
|
|
||||||
func (m *Mgr) ListServices() ([]string, error) {
|
|
||||||
var err error
|
|
||||||
var bytesNeeded, servicesReturned uint32
|
|
||||||
var buf []byte
|
|
||||||
for {
|
|
||||||
var p *byte
|
|
||||||
if len(buf) > 0 {
|
|
||||||
p = &buf[0]
|
|
||||||
}
|
|
||||||
err = windows.EnumServicesStatusEx(m.Handle, windows.SC_ENUM_PROCESS_INFO,
|
|
||||||
windows.SERVICE_WIN32, windows.SERVICE_STATE_ALL,
|
|
||||||
p, uint32(len(buf)), &bytesNeeded, &servicesReturned, nil, nil)
|
|
||||||
if err == nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if err != syscall.ERROR_MORE_DATA {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if bytesNeeded <= uint32(len(buf)) {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
buf = make([]byte, bytesNeeded)
|
|
||||||
}
|
|
||||||
if servicesReturned == 0 {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
services := (*[1 << 20]windows.ENUM_SERVICE_STATUS_PROCESS)(unsafe.Pointer(&buf[0]))[:servicesReturned]
|
|
||||||
var names []string
|
|
||||||
for _, s := range services {
|
|
||||||
name := syscall.UTF16ToString((*[1 << 20]uint16)(unsafe.Pointer(s.ServiceName))[:])
|
|
||||||
names = append(names, name)
|
|
||||||
}
|
|
||||||
return names, nil
|
|
||||||
}
|
|
|
@ -1,169 +0,0 @@
|
||||||
// Copyright 2012 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// +build windows
|
|
||||||
|
|
||||||
package mgr_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"sort"
|
|
||||||
"strings"
|
|
||||||
"syscall"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows/svc/mgr"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestOpenLanManServer(t *testing.T) {
|
|
||||||
m, err := mgr.Connect()
|
|
||||||
if err != nil {
|
|
||||||
if errno, ok := err.(syscall.Errno); ok && errno == syscall.ERROR_ACCESS_DENIED {
|
|
||||||
t.Skip("Skipping test: we don't have rights to manage services.")
|
|
||||||
}
|
|
||||||
t.Fatalf("SCM connection failed: %s", err)
|
|
||||||
}
|
|
||||||
defer m.Disconnect()
|
|
||||||
|
|
||||||
s, err := m.OpenService("LanmanServer")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("OpenService(lanmanserver) failed: %s", err)
|
|
||||||
}
|
|
||||||
defer s.Close()
|
|
||||||
|
|
||||||
_, err = s.Config()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Config failed: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func install(t *testing.T, m *mgr.Mgr, name, exepath string, c mgr.Config) {
|
|
||||||
// Sometimes it takes a while for the service to get
|
|
||||||
// removed after previous test run.
|
|
||||||
for i := 0; ; i++ {
|
|
||||||
s, err := m.OpenService(name)
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
s.Close()
|
|
||||||
|
|
||||||
if i > 10 {
|
|
||||||
t.Fatalf("service %s already exists", name)
|
|
||||||
}
|
|
||||||
time.Sleep(300 * time.Millisecond)
|
|
||||||
}
|
|
||||||
|
|
||||||
s, err := m.CreateService(name, exepath, c)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("CreateService(%s) failed: %v", name, err)
|
|
||||||
}
|
|
||||||
defer s.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func depString(d []string) string {
|
|
||||||
if len(d) == 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
for i := range d {
|
|
||||||
d[i] = strings.ToLower(d[i])
|
|
||||||
}
|
|
||||||
ss := sort.StringSlice(d)
|
|
||||||
ss.Sort()
|
|
||||||
return strings.Join([]string(ss), " ")
|
|
||||||
}
|
|
||||||
|
|
||||||
func testConfig(t *testing.T, s *mgr.Service, should mgr.Config) mgr.Config {
|
|
||||||
is, err := s.Config()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Config failed: %s", err)
|
|
||||||
}
|
|
||||||
if should.DisplayName != is.DisplayName {
|
|
||||||
t.Fatalf("config mismatch: DisplayName is %q, but should have %q", is.DisplayName, should.DisplayName)
|
|
||||||
}
|
|
||||||
if should.StartType != is.StartType {
|
|
||||||
t.Fatalf("config mismatch: StartType is %v, but should have %v", is.StartType, should.StartType)
|
|
||||||
}
|
|
||||||
if should.Description != is.Description {
|
|
||||||
t.Fatalf("config mismatch: Description is %q, but should have %q", is.Description, should.Description)
|
|
||||||
}
|
|
||||||
if depString(should.Dependencies) != depString(is.Dependencies) {
|
|
||||||
t.Fatalf("config mismatch: Dependencies is %v, but should have %v", is.Dependencies, should.Dependencies)
|
|
||||||
}
|
|
||||||
return is
|
|
||||||
}
|
|
||||||
|
|
||||||
func remove(t *testing.T, s *mgr.Service) {
|
|
||||||
err := s.Delete()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Delete failed: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMyService(t *testing.T) {
|
|
||||||
if testing.Short() {
|
|
||||||
t.Skip("skipping test in short mode - it modifies system services")
|
|
||||||
}
|
|
||||||
|
|
||||||
const name = "myservice"
|
|
||||||
|
|
||||||
m, err := mgr.Connect()
|
|
||||||
if err != nil {
|
|
||||||
if errno, ok := err.(syscall.Errno); ok && errno == syscall.ERROR_ACCESS_DENIED {
|
|
||||||
t.Skip("Skipping test: we don't have rights to manage services.")
|
|
||||||
}
|
|
||||||
t.Fatalf("SCM connection failed: %s", err)
|
|
||||||
}
|
|
||||||
defer m.Disconnect()
|
|
||||||
|
|
||||||
c := mgr.Config{
|
|
||||||
StartType: mgr.StartDisabled,
|
|
||||||
DisplayName: "my service",
|
|
||||||
Description: "my service is just a test",
|
|
||||||
Dependencies: []string{"LanmanServer", "W32Time"},
|
|
||||||
}
|
|
||||||
|
|
||||||
exename := os.Args[0]
|
|
||||||
exepath, err := filepath.Abs(exename)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("filepath.Abs(%s) failed: %s", exename, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
install(t, m, name, exepath, c)
|
|
||||||
|
|
||||||
s, err := m.OpenService(name)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("service %s is not installed", name)
|
|
||||||
}
|
|
||||||
defer s.Close()
|
|
||||||
|
|
||||||
c.BinaryPathName = exepath
|
|
||||||
c = testConfig(t, s, c)
|
|
||||||
|
|
||||||
c.StartType = mgr.StartManual
|
|
||||||
err = s.UpdateConfig(c)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("UpdateConfig failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
testConfig(t, s, c)
|
|
||||||
|
|
||||||
svcnames, err := m.ListServices()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ListServices failed: %v", err)
|
|
||||||
}
|
|
||||||
var myserviceIsInstalled bool
|
|
||||||
for _, sn := range svcnames {
|
|
||||||
if sn == name {
|
|
||||||
myserviceIsInstalled = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !myserviceIsInstalled {
|
|
||||||
t.Errorf("ListServices failed to find %q service", name)
|
|
||||||
}
|
|
||||||
|
|
||||||
remove(t, s)
|
|
||||||
}
|
|
|
@ -1,72 +0,0 @@
|
||||||
// Copyright 2012 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// +build windows
|
|
||||||
|
|
||||||
package mgr
|
|
||||||
|
|
||||||
import (
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
"golang.org/x/sys/windows/svc"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TODO(brainman): Use EnumDependentServices to enumerate dependent services.
|
|
||||||
|
|
||||||
// Service is used to access Windows service.
|
|
||||||
type Service struct {
|
|
||||||
Name string
|
|
||||||
Handle windows.Handle
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete marks service s for deletion from the service control manager database.
|
|
||||||
func (s *Service) Delete() error {
|
|
||||||
return windows.DeleteService(s.Handle)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close relinquish access to the service s.
|
|
||||||
func (s *Service) Close() error {
|
|
||||||
return windows.CloseServiceHandle(s.Handle)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start starts service s.
|
|
||||||
// args will be passed to svc.Handler.Execute.
|
|
||||||
func (s *Service) Start(args ...string) error {
|
|
||||||
var p **uint16
|
|
||||||
if len(args) > 0 {
|
|
||||||
vs := make([]*uint16, len(args))
|
|
||||||
for i := range vs {
|
|
||||||
vs[i] = syscall.StringToUTF16Ptr(args[i])
|
|
||||||
}
|
|
||||||
p = &vs[0]
|
|
||||||
}
|
|
||||||
return windows.StartService(s.Handle, uint32(len(args)), p)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Control sends state change request c to the servce s.
|
|
||||||
func (s *Service) Control(c svc.Cmd) (svc.Status, error) {
|
|
||||||
var t windows.SERVICE_STATUS
|
|
||||||
err := windows.ControlService(s.Handle, uint32(c), &t)
|
|
||||||
if err != nil {
|
|
||||||
return svc.Status{}, err
|
|
||||||
}
|
|
||||||
return svc.Status{
|
|
||||||
State: svc.State(t.CurrentState),
|
|
||||||
Accepts: svc.Accepted(t.ControlsAccepted),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Query returns current status of service s.
|
|
||||||
func (s *Service) Query() (svc.Status, error) {
|
|
||||||
var t windows.SERVICE_STATUS
|
|
||||||
err := windows.QueryServiceStatus(s.Handle, &t)
|
|
||||||
if err != nil {
|
|
||||||
return svc.Status{}, err
|
|
||||||
}
|
|
||||||
return svc.Status{
|
|
||||||
State: svc.State(t.CurrentState),
|
|
||||||
Accepts: svc.Accepted(t.ControlsAccepted),
|
|
||||||
}, nil
|
|
||||||
}
|
|
|
@ -1,118 +0,0 @@
|
||||||
// Copyright 2012 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// +build windows
|
|
||||||
|
|
||||||
package svc_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io/ioutil"
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"path/filepath"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows/svc"
|
|
||||||
"golang.org/x/sys/windows/svc/mgr"
|
|
||||||
)
|
|
||||||
|
|
||||||
func getState(t *testing.T, s *mgr.Service) svc.State {
|
|
||||||
status, err := s.Query()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query(%s) failed: %s", s.Name, err)
|
|
||||||
}
|
|
||||||
return status.State
|
|
||||||
}
|
|
||||||
|
|
||||||
func testState(t *testing.T, s *mgr.Service, want svc.State) {
|
|
||||||
have := getState(t, s)
|
|
||||||
if have != want {
|
|
||||||
t.Fatalf("%s state is=%d want=%d", s.Name, have, want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func waitState(t *testing.T, s *mgr.Service, want svc.State) {
|
|
||||||
for i := 0; ; i++ {
|
|
||||||
have := getState(t, s)
|
|
||||||
if have == want {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if i > 10 {
|
|
||||||
t.Fatalf("%s state is=%d, waiting timeout", s.Name, have)
|
|
||||||
}
|
|
||||||
time.Sleep(300 * time.Millisecond)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExample(t *testing.T) {
|
|
||||||
if testing.Short() {
|
|
||||||
t.Skip("skipping test in short mode - it modifies system services")
|
|
||||||
}
|
|
||||||
|
|
||||||
const name = "myservice"
|
|
||||||
|
|
||||||
m, err := mgr.Connect()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("SCM connection failed: %s", err)
|
|
||||||
}
|
|
||||||
defer m.Disconnect()
|
|
||||||
|
|
||||||
dir, err := ioutil.TempDir("", "svc")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to create temp directory: %v", err)
|
|
||||||
}
|
|
||||||
defer os.RemoveAll(dir)
|
|
||||||
|
|
||||||
exepath := filepath.Join(dir, "a.exe")
|
|
||||||
o, err := exec.Command("go", "build", "-o", exepath, "golang.org/x/sys/windows/svc/example").CombinedOutput()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to build service program: %v\n%v", err, string(o))
|
|
||||||
}
|
|
||||||
|
|
||||||
s, err := m.OpenService(name)
|
|
||||||
if err == nil {
|
|
||||||
err = s.Delete()
|
|
||||||
if err != nil {
|
|
||||||
s.Close()
|
|
||||||
t.Fatalf("Delete failed: %s", err)
|
|
||||||
}
|
|
||||||
s.Close()
|
|
||||||
}
|
|
||||||
s, err = m.CreateService(name, exepath, mgr.Config{DisplayName: "my service"}, "is", "auto-started")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("CreateService(%s) failed: %v", name, err)
|
|
||||||
}
|
|
||||||
defer s.Close()
|
|
||||||
|
|
||||||
testState(t, s, svc.Stopped)
|
|
||||||
err = s.Start("is", "manual-started")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Start(%s) failed: %s", s.Name, err)
|
|
||||||
}
|
|
||||||
waitState(t, s, svc.Running)
|
|
||||||
time.Sleep(1 * time.Second)
|
|
||||||
|
|
||||||
// testing deadlock from issues 4.
|
|
||||||
_, err = s.Control(svc.Interrogate)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Control(%s) failed: %s", s.Name, err)
|
|
||||||
}
|
|
||||||
_, err = s.Control(svc.Interrogate)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Control(%s) failed: %s", s.Name, err)
|
|
||||||
}
|
|
||||||
time.Sleep(1 * time.Second)
|
|
||||||
|
|
||||||
_, err = s.Control(svc.Stop)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Control(%s) failed: %s", s.Name, err)
|
|
||||||
}
|
|
||||||
waitState(t, s, svc.Stopped)
|
|
||||||
|
|
||||||
err = s.Delete()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Delete failed: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,53 +0,0 @@
|
||||||
// Copyright 2013 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// +build windows
|
|
||||||
|
|
||||||
package windows_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"syscall"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
)
|
|
||||||
|
|
||||||
func testSetGetenv(t *testing.T, key, value string) {
|
|
||||||
err := windows.Setenv(key, value)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Setenv failed to set %q: %v", value, err)
|
|
||||||
}
|
|
||||||
newvalue, found := windows.Getenv(key)
|
|
||||||
if !found {
|
|
||||||
t.Fatalf("Getenv failed to find %v variable (want value %q)", key, value)
|
|
||||||
}
|
|
||||||
if newvalue != value {
|
|
||||||
t.Fatalf("Getenv(%v) = %q; want %q", key, newvalue, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEnv(t *testing.T) {
|
|
||||||
testSetGetenv(t, "TESTENV", "AVALUE")
|
|
||||||
// make sure TESTENV gets set to "", not deleted
|
|
||||||
testSetGetenv(t, "TESTENV", "")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetProcAddressByOrdinal(t *testing.T) {
|
|
||||||
// Attempt calling shlwapi.dll:IsOS, resolving it by ordinal, as
|
|
||||||
// suggested in
|
|
||||||
// https://msdn.microsoft.com/en-us/library/windows/desktop/bb773795.aspx
|
|
||||||
h, err := windows.LoadLibrary("shlwapi.dll")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to load shlwapi.dll: %s", err)
|
|
||||||
}
|
|
||||||
procIsOS, err := windows.GetProcAddressByOrdinal(h, 437)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Could not find shlwapi.dll:IsOS by ordinal: %s", err)
|
|
||||||
}
|
|
||||||
const OS_NT = 1
|
|
||||||
r, _, _ := syscall.Syscall(procIsOS, 1, OS_NT, 0, 0)
|
|
||||||
if r == 0 {
|
|
||||||
t.Error("shlwapi.dll:IsOS(OS_NT) returned 0, expected non-zero value")
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,107 +0,0 @@
|
||||||
// Copyright 2012 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package windows_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io/ioutil"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"syscall"
|
|
||||||
"testing"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestWin32finddata(t *testing.T) {
|
|
||||||
dir, err := ioutil.TempDir("", "go-build")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to create temp directory: %v", err)
|
|
||||||
}
|
|
||||||
defer os.RemoveAll(dir)
|
|
||||||
|
|
||||||
path := filepath.Join(dir, "long_name.and_extension")
|
|
||||||
f, err := os.Create(path)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to create %v: %v", path, err)
|
|
||||||
}
|
|
||||||
f.Close()
|
|
||||||
|
|
||||||
type X struct {
|
|
||||||
fd windows.Win32finddata
|
|
||||||
got byte
|
|
||||||
pad [10]byte // to protect ourselves
|
|
||||||
|
|
||||||
}
|
|
||||||
var want byte = 2 // it is unlikely to have this character in the filename
|
|
||||||
x := X{got: want}
|
|
||||||
|
|
||||||
pathp, _ := windows.UTF16PtrFromString(path)
|
|
||||||
h, err := windows.FindFirstFile(pathp, &(x.fd))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("FindFirstFile failed: %v", err)
|
|
||||||
}
|
|
||||||
err = windows.FindClose(h)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("FindClose failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if x.got != want {
|
|
||||||
t.Fatalf("memory corruption: want=%d got=%d", want, x.got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFormatMessage(t *testing.T) {
|
|
||||||
dll := windows.MustLoadDLL("pdh.dll")
|
|
||||||
|
|
||||||
pdhOpenQuery := func(datasrc *uint16, userdata uint32, query *windows.Handle) (errno uintptr) {
|
|
||||||
r0, _, _ := syscall.Syscall(dll.MustFindProc("PdhOpenQueryW").Addr(), 3, uintptr(unsafe.Pointer(datasrc)), uintptr(userdata), uintptr(unsafe.Pointer(query)))
|
|
||||||
return r0
|
|
||||||
}
|
|
||||||
|
|
||||||
pdhCloseQuery := func(query windows.Handle) (errno uintptr) {
|
|
||||||
r0, _, _ := syscall.Syscall(dll.MustFindProc("PdhCloseQuery").Addr(), 1, uintptr(query), 0, 0)
|
|
||||||
return r0
|
|
||||||
}
|
|
||||||
|
|
||||||
var q windows.Handle
|
|
||||||
name, err := windows.UTF16PtrFromString("no_such_source")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
errno := pdhOpenQuery(name, 0, &q)
|
|
||||||
if errno == 0 {
|
|
||||||
pdhCloseQuery(q)
|
|
||||||
t.Fatal("PdhOpenQuery succeeded, but expected to fail.")
|
|
||||||
}
|
|
||||||
|
|
||||||
const flags uint32 = syscall.FORMAT_MESSAGE_FROM_HMODULE | syscall.FORMAT_MESSAGE_ARGUMENT_ARRAY | syscall.FORMAT_MESSAGE_IGNORE_INSERTS
|
|
||||||
buf := make([]uint16, 300)
|
|
||||||
_, err = windows.FormatMessage(flags, uintptr(dll.Handle), uint32(errno), 0, buf, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("FormatMessage for handle=%x and errno=%x failed: %v", dll.Handle, errno, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func abort(funcname string, err error) {
|
|
||||||
panic(funcname + " failed: " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
func ExampleLoadLibrary() {
|
|
||||||
h, err := windows.LoadLibrary("kernel32.dll")
|
|
||||||
if err != nil {
|
|
||||||
abort("LoadLibrary", err)
|
|
||||||
}
|
|
||||||
defer windows.FreeLibrary(h)
|
|
||||||
proc, err := windows.GetProcAddress(h, "GetVersion")
|
|
||||||
if err != nil {
|
|
||||||
abort("GetProcAddress", err)
|
|
||||||
}
|
|
||||||
r, _, _ := syscall.Syscall(uintptr(proc), 0, 0, 0, 0)
|
|
||||||
major := byte(r)
|
|
||||||
minor := uint8(r >> 8)
|
|
||||||
build := uint16(r >> 16)
|
|
||||||
print("windows version ", major, ".", minor, " (Build ", build, ")\n")
|
|
||||||
}
|
|
|
@ -0,0 +1,5 @@
|
||||||
|
# github.com/judwhite/go-svc v1.0.0
|
||||||
|
github.com/judwhite/go-svc/svc
|
||||||
|
# golang.org/x/sys v0.0.0-20180322165403-91ee8cde4354
|
||||||
|
golang.org/x/sys/windows/svc
|
||||||
|
golang.org/x/sys/windows
|
|
@ -1,4 +0,0 @@
|
||||||
MODULE VERSION
|
|
||||||
github.com/productionwentdown/forward -
|
|
||||||
github.com/judwhite/go-svc v1.0.0
|
|
||||||
golang.org/x/sys v0.0.0-20180322165403-91ee8cde4354
|
|
Loading…
Reference in New Issue