1
0
Fork 0

Add SSH connection option

master
Ambrose Chua 2019-09-04 15:09:05 +08:00
parent bb031be2f1
commit 089bc9c6e4
Signed by: ambrose
GPG Key ID: B34FBE029276BA5D
21 changed files with 140 additions and 1443 deletions

6
go.mod
View File

@ -1,6 +1,6 @@
module "github.com/productionwentdown/forward"
module github.com/productionwentdown/forward
require (
"github.com/judwhite/go-svc" v1.0.0
"golang.org/x/sys" v0.0.0-20180322165403-91ee8cde4354
github.com/judwhite/go-svc v1.0.0
golang.org/x/sys v0.0.0-20180322165403-91ee8cde4354 // indirect
)

2
go.sum Normal file
View File

@ -0,0 +1,2 @@
github.com/judwhite/go-svc v1.0.0/go.mod h1:EeMSAFO3mLgEQfcvnZ50JDG0O1uQlagpAbMS6talrXE=
golang.org/x/sys v0.0.0-20180322165403-91ee8cde4354/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=

View File

@ -1,6 +1,7 @@
package main
import (
"bytes"
"flag"
"io"
"log"
@ -9,13 +10,16 @@ import (
var listen string
var connect string
var connectSSH string
var ln net.Listener
var conn *net.TCPAddr
var connSSH *net.TCPAddr
func setup() {
flag.StringVar(&listen, "listen", ":8000", "listen on ip and port")
flag.StringVar(&connect, "connect", "", "forward to ip and port")
flag.StringVar(&listen, "listen", ":8000", "listen on address")
flag.StringVar(&connect, "connect", "", "forward to address")
flag.StringVar(&connectSSH, "ssh", "", "if set, will do basic introspection to forward SSH traffic to this address")
flag.Parse()
var err error
@ -27,6 +31,12 @@ func setup() {
log.Fatal(err)
}
// check and parse SSH address
connSSH, _ = net.ResolveTCPAddr("tcp", connectSSH)
if connectSSH == "" {
connSSH = nil
}
// listen on address
ln, err = net.Listen("tcp", listen)
if err != nil {
@ -36,6 +46,9 @@ func setup() {
log.Printf("listening on %v", ln.Addr())
log.Printf("will connect to %v", conn)
if connSSH != nil {
log.Printf("will connect SSH to %v", connSSH)
}
}
func serve() {
@ -48,17 +61,46 @@ func serve() {
}
log.Printf("connection %v from %v", i, c.RemoteAddr())
go handle(c, i)
}
}
cn, err := net.DialTCP("tcp", nil, conn)
var magic = []byte{'S', 'S', 'H', '-'}
var magicLen = len(magic)
func handle(c net.Conn, count int) {
// read first four characters
readMagic := make([]byte, magicLen, magicLen)
n, err := c.Read(readMagic)
if n != magicLen {
log.Printf("warning! could not read header")
return
}
opError, ok := err.(*net.OpError)
if err != nil && (!ok || opError.Op != "readfrom") {
log.Printf("warning! %v", err)
return
}
connTo := conn
// if the header looks like SSH, forward to SSH connection
if bytes.Equal(readMagic, magic) {
connTo = connSSH
}
cn, err := net.DialTCP("tcp", nil, connTo)
if err != nil {
c.Close()
log.Print(err)
continue
return
}
go pipe(c, cn, i)
go pipe(cn, c, i)
}
// write the first four characters
cn.Write(readMagic)
go pipe(c, cn, count)
go pipe(cn, c, count)
}
func pipe(w io.WriteCloser, r io.ReadCloser, count int) {

21
vendor/github.com/judwhite/go-svc/LICENSE generated vendored Normal file
View File

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2016 Jud White
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -1,62 +0,0 @@
package test
import (
"path/filepath"
"reflect"
"runtime"
"testing"
)
// Equal asserts two parameters are equal by using reflect.DeepEqual.
func Equal(t *testing.T, expected, actual interface{}) {
if !reflect.DeepEqual(expected, actual) {
_, file, line, _ := runtime.Caller(1)
t.Logf("\033[31m%s:%d:\n\n\t %#v (expected)\n\n\t!= %#v (actual)\033[39m\n\n",
filepath.Base(file), line, expected, actual)
t.FailNow()
}
}
// NotEqual asserts two parameters are not equal by using reflect.DeepEqual.
func NotEqual(t *testing.T, expected, actual interface{}) {
if !reflect.DeepEqual(expected, actual) {
_, file, line, _ := runtime.Caller(1)
t.Logf("\033[31m%s:%d:\n\n\tvalue should not equal %#v\033[39m\n\n",
filepath.Base(file), line, actual)
t.FailNow()
}
}
// Nil asserts the parameter is nil.
func Nil(t *testing.T, object interface{}) {
if !isNil(object) {
_, file, line, _ := runtime.Caller(1)
t.Logf("\033[31m%s:%d:\n\n\t <nil> (expected)\n\n\t!= %#v (actual)\033[39m\n\n",
filepath.Base(file), line, object)
t.FailNow()
}
}
// NotNil asserts the parameter is not nil.
func NotNil(t *testing.T, object interface{}) {
if isNil(object) {
_, file, line, _ := runtime.Caller(1)
t.Logf("\033[31m%s:%d:\n\n\tExpected value not to be <nil>\033[39m\n\n",
filepath.Base(file), line)
t.FailNow()
}
}
func isNil(object interface{}) bool {
if object == nil {
return true
}
value := reflect.ValueOf(object)
kind := value.Kind()
if kind >= reflect.Chan && kind <= reflect.Slice && value.IsNil() {
return true
}
return false
}

View File

@ -1,36 +0,0 @@
package svc
type mockProgram struct {
start func() error
stop func() error
init func(Environment) error
}
func (p *mockProgram) Start() error {
return p.start()
}
func (p *mockProgram) Stop() error {
return p.stop()
}
func (p *mockProgram) Init(wse Environment) error {
return p.init(wse)
}
func makeProgram(startCalled, stopCalled, initCalled *int) *mockProgram {
return &mockProgram{
start: func() error {
*startCalled++
return nil
},
stop: func() error {
*stopCalled++
return nil
},
init: func(wse Environment) error {
*initCalled++
return nil
},
}
}

View File

@ -1,68 +0,0 @@
// +build !windows
package svc
import (
"os"
"syscall"
"testing"
"github.com/judwhite/go-svc/svc/internal/test"
)
func TestDefaultSignalHandling(t *testing.T) {
signals := []os.Signal{syscall.SIGINT, syscall.SIGTERM} // default signals handled
for _, signal := range signals {
testSignalNotify(t, signal)
}
}
func TestUserDefinedSignalHandling(t *testing.T) {
signals := []os.Signal{syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP}
for _, signal := range signals {
testSignalNotify(t, signal, signals...)
}
}
func testSignalNotify(t *testing.T, signal os.Signal, sig ...os.Signal) {
// arrange
// sigChan is the chan we'll send to here. if a signal matches a registered signal
// type in the Run function (in svc_other.go) the signal will be delegated to the
// channel passed to signalNotify, which is created in the Run function in svc_other.go.
// shortly: we send here and the Run function gets it if it matches the filter.
sigChan := make(chan os.Signal)
var startCalled, stopCalled, initCalled int
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
signalNotify = func(c chan<- os.Signal, sig ...os.Signal) {
if c == nil {
panic("os/signal: Notify using nil channel")
}
go func() {
for val := range sigChan {
for _, registeredSig := range sig {
if val == registeredSig {
c <- val
}
}
}
}()
}
go func() {
sigChan <- signal
}()
// act
if err := Run(prg, sig...); err != nil {
t.Fatal(err)
}
// assert
test.Equal(t, 1, startCalled)
test.Equal(t, 1, stopCalled)
test.Equal(t, 1, initCalled)
}

View File

@ -1,438 +0,0 @@
// +build windows
package svc
import (
"errors"
"os"
"syscall"
"testing"
"time"
"github.com/judwhite/go-svc/svc/internal/test"
wsvc "golang.org/x/sys/windows/svc"
)
func setupWinServiceTest(wsf *mockWinServiceFuncs) {
// wsfWrapper allows signalNotify, svcIsInteractive, and svcRun to be set once.
// Inidivual test functions set "wsf" to add behavior.
wsfWrapper := &mockWinServiceFuncs{
signalNotify: func(c chan<- os.Signal, sig ...os.Signal) {
if c == nil {
panic("os/signal: Notify using nil channel")
}
if wsf.signalNotify != nil {
wsf.signalNotify(c, sig...)
} else {
wsf1 := *wsf
go func() {
for val := range wsf1.sigChan {
for _, registeredSig := range sig {
if val == registeredSig {
c <- val
}
}
}
}()
}
},
svcIsInteractive: func() (bool, error) {
return wsf.svcIsInteractive()
},
svcRun: func(name string, handler wsvc.Handler) error {
return wsf.svcRun(name, handler)
},
}
signalNotify = wsfWrapper.signalNotify
svcIsAnInteractiveSession = wsfWrapper.svcIsInteractive
svcRun = wsfWrapper.svcRun
}
type mockWinServiceFuncs struct {
signalNotify func(chan<- os.Signal, ...os.Signal)
svcIsInteractive func() (bool, error)
sigChan chan os.Signal
svcRun func(string, wsvc.Handler) error
ws *windowsService
executeReturnedBool bool
executeReturnedUInt32 uint32
changes []wsvc.Status
}
func setWindowsServiceFuncs(isInteractive bool, onRunningSendCmd *wsvc.Cmd) (*mockWinServiceFuncs, chan<- wsvc.ChangeRequest) {
changeRequestChan := make(chan wsvc.ChangeRequest, 4)
changesChan := make(chan wsvc.Status)
done := make(chan struct{})
var wsf *mockWinServiceFuncs
wsf = &mockWinServiceFuncs{
sigChan: make(chan os.Signal),
svcIsInteractive: func() (bool, error) {
return isInteractive, nil
},
svcRun: func(name string, handler wsvc.Handler) error {
wsf.ws = handler.(*windowsService)
wsf.executeReturnedBool, wsf.executeReturnedUInt32 = handler.Execute(nil, changeRequestChan, changesChan)
done <- struct{}{}
return nil
},
}
var currentState wsvc.State
go func() {
loop:
for {
select {
case change := <-changesChan:
wsf.changes = append(wsf.changes, change)
currentState = change.State
if change.State == wsvc.Running && onRunningSendCmd != nil {
changeRequestChan <- wsvc.ChangeRequest{
Cmd: *onRunningSendCmd,
CurrentStatus: wsvc.Status{State: currentState},
}
}
case <-done:
break loop
}
}
}()
setupWinServiceTest(wsf)
return wsf, changeRequestChan
}
func TestWinService_RunWindowsService_NonInteractive(t *testing.T) {
for _, svcCmd := range []wsvc.Cmd{wsvc.Stop, wsvc.Shutdown} {
testRunWindowsServiceNonInteractive(t, svcCmd)
}
}
func testRunWindowsServiceNonInteractive(t *testing.T, svcCmd wsvc.Cmd) {
// arrange
var startCalled, stopCalled, initCalled int
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
wsf, _ := setWindowsServiceFuncs(false, &svcCmd)
// act
if err := Run(prg); err != nil {
t.Fatal(err)
}
// assert
changes := wsf.changes
test.Equal(t, 1, startCalled)
test.Equal(t, 1, stopCalled)
test.Equal(t, 1, initCalled)
test.Equal(t, 3, len(changes))
test.Equal(t, wsvc.StartPending, changes[0].State)
test.Equal(t, wsvc.Running, changes[1].State)
test.Equal(t, wsvc.StopPending, changes[2].State)
test.Equal(t, false, wsf.executeReturnedBool)
test.Equal(t, uint32(0), wsf.executeReturnedUInt32)
test.Nil(t, wsf.ws.getError())
}
func TestRunWindowsServiceNonInteractive_StartError(t *testing.T) {
// arrange
var startCalled, stopCalled, initCalled int
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
prg.start = func() error {
startCalled++
return errors.New("start error")
}
svcStop := wsvc.Stop
wsf, _ := setWindowsServiceFuncs(false, &svcStop)
// act
err := Run(prg)
// assert
test.Equal(t, "start error", err.Error())
changes := wsf.changes
test.Equal(t, 1, startCalled)
test.Equal(t, 0, stopCalled)
test.Equal(t, 1, initCalled)
test.Equal(t, 1, len(changes))
test.Equal(t, wsvc.StartPending, changes[0].State)
test.Equal(t, true, wsf.executeReturnedBool)
test.Equal(t, uint32(1), wsf.executeReturnedUInt32)
test.Equal(t, "start error", wsf.ws.getError().Error())
}
func TestRunWindowsServiceInteractive_StartError(t *testing.T) {
// arrange
var startCalled, stopCalled, initCalled int
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
prg.start = func() error {
startCalled++
return errors.New("start error")
}
wsf, _ := setWindowsServiceFuncs(true, nil)
// act
err := Run(prg)
// assert
test.Equal(t, "start error", err.Error())
changes := wsf.changes
test.Equal(t, 1, startCalled)
test.Equal(t, 0, stopCalled)
test.Equal(t, 1, initCalled)
test.Equal(t, 0, len(changes))
}
func TestRunWindowsService_BeforeStartError(t *testing.T) {
// arrange
var startCalled, stopCalled, initCalled int
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
prg.init = func(Environment) error {
initCalled++
return errors.New("before start error")
}
wsf, _ := setWindowsServiceFuncs(false, nil)
// act
err := Run(prg)
// assert
test.Equal(t, "before start error", err.Error())
changes := wsf.changes
test.Equal(t, 0, startCalled)
test.Equal(t, 0, stopCalled)
test.Equal(t, 1, initCalled)
test.Equal(t, 0, len(changes))
}
func TestRunWindowsService_IsAnInteractiveSessionError(t *testing.T) {
// arrange
var startCalled, stopCalled, initCalled int
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
wsf, _ := setWindowsServiceFuncs(false, nil)
wsf.svcIsInteractive = func() (bool, error) {
return false, errors.New("IsAnInteractiveSession error")
}
// act
err := Run(prg)
// assert
test.Equal(t, "IsAnInteractiveSession error", err.Error())
changes := wsf.changes
test.Equal(t, 0, startCalled)
test.Equal(t, 0, stopCalled)
test.Equal(t, 0, initCalled)
test.Equal(t, 0, len(changes))
}
func TestRunWindowsServiceNonInteractive_RunError(t *testing.T) {
// arrange
var startCalled, stopCalled, initCalled int
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
svcStop := wsvc.Stop
wsf, _ := setWindowsServiceFuncs(false, &svcStop)
wsf.svcRun = func(name string, handler wsvc.Handler) error {
wsf.ws = handler.(*windowsService)
return errors.New("wsvc.Run error")
}
// act
err := Run(prg)
// assert
test.Equal(t, "wsvc.Run error", err.Error())
changes := wsf.changes
test.Equal(t, 0, startCalled)
test.Equal(t, 0, stopCalled)
test.Equal(t, 1, initCalled)
test.Equal(t, 0, len(changes))
test.Nil(t, wsf.ws.getError())
}
func TestRunWindowsServiceNonInteractive_Interrogate(t *testing.T) {
// arrange
var startCalled, stopCalled, initCalled int
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
wsf, changeRequest := setWindowsServiceFuncs(false, nil)
time.AfterFunc(50*time.Millisecond, func() {
// ignored, PausePending won't be in changes slice
// make sure we don't panic/err on unexpected values
changeRequest <- wsvc.ChangeRequest{
Cmd: wsvc.Pause,
CurrentStatus: wsvc.Status{State: wsvc.PausePending},
}
})
time.AfterFunc(100*time.Millisecond, func() {
// handled, Paused will be in changes slice
changeRequest <- wsvc.ChangeRequest{
Cmd: wsvc.Interrogate,
CurrentStatus: wsvc.Status{State: wsvc.Paused},
}
})
time.AfterFunc(200*time.Millisecond, func() {
// handled, but CurrentStatus overridden with StopPending;
// ContinuePending won't be in changes slice
changeRequest <- wsvc.ChangeRequest{
Cmd: wsvc.Stop,
CurrentStatus: wsvc.Status{State: wsvc.ContinuePending},
}
})
// act
if err := Run(prg); err != nil {
t.Fatal(err)
}
// assert
changes := wsf.changes
test.Equal(t, 1, startCalled)
test.Equal(t, 1, stopCalled)
test.Equal(t, 1, initCalled)
test.Equal(t, 4, len(changes))
test.Equal(t, wsvc.StartPending, changes[0].State)
test.Equal(t, wsvc.Running, changes[1].State)
test.Equal(t, wsvc.Paused, changes[2].State)
test.Equal(t, wsvc.StopPending, changes[3].State)
test.Equal(t, false, wsf.executeReturnedBool)
test.Equal(t, uint32(0), wsf.executeReturnedUInt32)
test.Nil(t, wsf.ws.getError())
}
func TestRunWindowsServiceInteractive_StopError(t *testing.T) {
// arrange
var startCalled, stopCalled, initCalled int
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
prg.stop = func() error {
stopCalled++
return errors.New("stop error")
}
wsf, _ := setWindowsServiceFuncs(true, nil)
go func() {
wsf.sigChan <- os.Interrupt
}()
// act
err := Run(prg)
// assert
test.Equal(t, "stop error", err.Error())
test.Equal(t, 1, startCalled)
test.Equal(t, 1, stopCalled)
test.Equal(t, 1, initCalled)
test.Equal(t, 0, len(wsf.changes))
}
func TestRunWindowsServiceNonInteractive_StopError(t *testing.T) {
// arrange
var startCalled, stopCalled, initCalled int
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
prg.stop = func() error {
stopCalled++
return errors.New("stop error")
}
shutdownCmd := wsvc.Shutdown
wsf, _ := setWindowsServiceFuncs(false, &shutdownCmd)
// act
err := Run(prg)
// assert
changes := wsf.changes
test.Equal(t, "stop error", err.Error())
test.Equal(t, 1, startCalled)
test.Equal(t, 1, stopCalled)
test.Equal(t, 1, initCalled)
test.Equal(t, 3, len(changes))
test.Equal(t, wsvc.StartPending, changes[0].State)
test.Equal(t, wsvc.Running, changes[1].State)
test.Equal(t, wsvc.StopPending, changes[2].State)
test.Equal(t, true, wsf.executeReturnedBool)
test.Equal(t, uint32(2), wsf.executeReturnedUInt32)
test.Equal(t, "stop error", wsf.ws.getError().Error())
}
func TestDefaultSignalHandling(t *testing.T) {
signals := []os.Signal{syscall.SIGINT} // default signal handled
for _, signal := range signals {
testSignalNotify(t, signal)
}
}
func TestUserDefinedSignalHandling(t *testing.T) {
signals := []os.Signal{syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP}
for _, signal := range signals {
testSignalNotify(t, signal, signals...)
}
}
func testSignalNotify(t *testing.T, signal os.Signal, sig ...os.Signal) {
// arrange
var startCalled, stopCalled, initCalled int
prg := makeProgram(&startCalled, &stopCalled, &initCalled)
wsf, _ := setWindowsServiceFuncs(true, nil)
go func() {
wsf.sigChan <- signal
}()
// act
if err := Run(prg, sig...); err != nil {
t.Fatal(err)
}
// assert
test.Equal(t, 1, startCalled)
test.Equal(t, 1, stopCalled)
test.Equal(t, 1, initCalled)
test.Equal(t, 0, len(wsf.changes))
}

3
vendor/golang.org/x/sys/AUTHORS generated vendored Normal file
View File

@ -0,0 +1,3 @@
# This source code refers to The Go Authors for copyright purposes.
# The master list of authors is in the main Go distribution,
# visible at http://tip.golang.org/AUTHORS.

3
vendor/golang.org/x/sys/CONTRIBUTORS generated vendored Normal file
View File

@ -0,0 +1,3 @@
# This source code was written by the Go contributors.
# The master list of contributors is in the main Go distribution,
# visible at http://tip.golang.org/CONTRIBUTORS.

27
vendor/golang.org/x/sys/LICENSE generated vendored Normal file
View File

@ -0,0 +1,27 @@
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

22
vendor/golang.org/x/sys/PATENTS generated vendored Normal file
View File

@ -0,0 +1,22 @@
Additional IP Rights Grant (Patents)
"This implementation" means the copyrightable works distributed by
Google as part of the Go project.
Google hereby grants to You a perpetual, worldwide, non-exclusive,
no-charge, royalty-free, irrevocable (except as stated in this section)
patent license to make, have made, use, offer to sell, sell, import,
transfer and otherwise run, modify and propagate the contents of this
implementation of Go, where such license applies only to those patent
claims, both currently owned or controlled by Google and acquired in
the future, licensable by Google that are necessarily infringed by this
implementation of Go. This grant does not include claims that would be
infringed only as a consequence of further modification of this
implementation. If you or your agent or exclusive licensee institute or
order or agree to the institution of patent litigation against any
entity (including a cross-claim or counterclaim in a lawsuit) alleging
that this implementation of Go or any code incorporated within this
implementation of Go constitutes direct or contributory patent
infringement, or inducement of patent infringement, then any patent
rights granted to you under this License for this implementation of Go
shall terminate as of the date such litigation is filed.

View File

@ -1,139 +0,0 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build windows
package mgr
import (
"syscall"
"unicode/utf16"
"unsafe"
"golang.org/x/sys/windows"
)
const (
// Service start types.
StartManual = windows.SERVICE_DEMAND_START // the service must be started manually
StartAutomatic = windows.SERVICE_AUTO_START // the service will start by itself whenever the computer reboots
StartDisabled = windows.SERVICE_DISABLED // the service cannot be started
// The severity of the error, and action taken,
// if this service fails to start.
ErrorCritical = windows.SERVICE_ERROR_CRITICAL
ErrorIgnore = windows.SERVICE_ERROR_IGNORE
ErrorNormal = windows.SERVICE_ERROR_NORMAL
ErrorSevere = windows.SERVICE_ERROR_SEVERE
)
// TODO(brainman): Password is not returned by windows.QueryServiceConfig, not sure how to get it.
type Config struct {
ServiceType uint32
StartType uint32
ErrorControl uint32
BinaryPathName string // fully qualified path to the service binary file, can also include arguments for an auto-start service
LoadOrderGroup string
TagId uint32
Dependencies []string
ServiceStartName string // name of the account under which the service should run
DisplayName string
Password string
Description string
}
func toString(p *uint16) string {
if p == nil {
return ""
}
return syscall.UTF16ToString((*[4096]uint16)(unsafe.Pointer(p))[:])
}
func toStringSlice(ps *uint16) []string {
if ps == nil {
return nil
}
r := make([]string, 0)
for from, i, p := 0, 0, (*[1 << 24]uint16)(unsafe.Pointer(ps)); true; i++ {
if p[i] == 0 {
// empty string marks the end
if i <= from {
break
}
r = append(r, string(utf16.Decode(p[from:i])))
from = i + 1
}
}
return r
}
// Config retrieves service s configuration paramteres.
func (s *Service) Config() (Config, error) {
var p *windows.QUERY_SERVICE_CONFIG
n := uint32(1024)
for {
b := make([]byte, n)
p = (*windows.QUERY_SERVICE_CONFIG)(unsafe.Pointer(&b[0]))
err := windows.QueryServiceConfig(s.Handle, p, n, &n)
if err == nil {
break
}
if err.(syscall.Errno) != syscall.ERROR_INSUFFICIENT_BUFFER {
return Config{}, err
}
if n <= uint32(len(b)) {
return Config{}, err
}
}
var p2 *windows.SERVICE_DESCRIPTION
n = uint32(1024)
for {
b := make([]byte, n)
p2 = (*windows.SERVICE_DESCRIPTION)(unsafe.Pointer(&b[0]))
err := windows.QueryServiceConfig2(s.Handle,
windows.SERVICE_CONFIG_DESCRIPTION, &b[0], n, &n)
if err == nil {
break
}
if err.(syscall.Errno) != syscall.ERROR_INSUFFICIENT_BUFFER {
return Config{}, err
}
if n <= uint32(len(b)) {
return Config{}, err
}
}
return Config{
ServiceType: p.ServiceType,
StartType: p.StartType,
ErrorControl: p.ErrorControl,
BinaryPathName: toString(p.BinaryPathName),
LoadOrderGroup: toString(p.LoadOrderGroup),
TagId: p.TagId,
Dependencies: toStringSlice(p.Dependencies),
ServiceStartName: toString(p.ServiceStartName),
DisplayName: toString(p.DisplayName),
Description: toString(p2.Description),
}, nil
}
func updateDescription(handle windows.Handle, desc string) error {
d := windows.SERVICE_DESCRIPTION{toPtr(desc)}
return windows.ChangeServiceConfig2(handle,
windows.SERVICE_CONFIG_DESCRIPTION, (*byte)(unsafe.Pointer(&d)))
}
// UpdateConfig updates service s configuration parameters.
func (s *Service) UpdateConfig(c Config) error {
err := windows.ChangeServiceConfig(s.Handle, c.ServiceType, c.StartType,
c.ErrorControl, toPtr(c.BinaryPathName), toPtr(c.LoadOrderGroup),
nil, toStringBlock(c.Dependencies), toPtr(c.ServiceStartName),
toPtr(c.Password), toPtr(c.DisplayName))
if err != nil {
return err
}
return updateDescription(s.Handle, c.Description)
}

View File

@ -1,162 +0,0 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build windows
// Package mgr can be used to manage Windows service programs.
// It can be used to install and remove them. It can also start,
// stop and pause them. The package can query / change current
// service state and config parameters.
//
package mgr
import (
"syscall"
"unicode/utf16"
"unsafe"
"golang.org/x/sys/windows"
)
// Mgr is used to manage Windows service.
type Mgr struct {
Handle windows.Handle
}
// Connect establishes a connection to the service control manager.
func Connect() (*Mgr, error) {
return ConnectRemote("")
}
// ConnectRemote establishes a connection to the
// service control manager on computer named host.
func ConnectRemote(host string) (*Mgr, error) {
var s *uint16
if host != "" {
s = syscall.StringToUTF16Ptr(host)
}
h, err := windows.OpenSCManager(s, nil, windows.SC_MANAGER_ALL_ACCESS)
if err != nil {
return nil, err
}
return &Mgr{Handle: h}, nil
}
// Disconnect closes connection to the service control manager m.
func (m *Mgr) Disconnect() error {
return windows.CloseServiceHandle(m.Handle)
}
func toPtr(s string) *uint16 {
if len(s) == 0 {
return nil
}
return syscall.StringToUTF16Ptr(s)
}
// toStringBlock terminates strings in ss with 0, and then
// concatenates them together. It also adds extra 0 at the end.
func toStringBlock(ss []string) *uint16 {
if len(ss) == 0 {
return nil
}
t := ""
for _, s := range ss {
if s != "" {
t += s + "\x00"
}
}
if t == "" {
return nil
}
t += "\x00"
return &utf16.Encode([]rune(t))[0]
}
// CreateService installs new service name on the system.
// The service will be executed by running exepath binary.
// Use config c to specify service parameters.
// Any args will be passed as command-line arguments when
// the service is started; these arguments are distinct from
// the arguments passed to Service.Start or via the "Start
// parameters" field in the service's Properties dialog box.
func (m *Mgr) CreateService(name, exepath string, c Config, args ...string) (*Service, error) {
if c.StartType == 0 {
c.StartType = StartManual
}
if c.ErrorControl == 0 {
c.ErrorControl = ErrorNormal
}
if c.ServiceType == 0 {
c.ServiceType = windows.SERVICE_WIN32_OWN_PROCESS
}
s := syscall.EscapeArg(exepath)
for _, v := range args {
s += " " + syscall.EscapeArg(v)
}
h, err := windows.CreateService(m.Handle, toPtr(name), toPtr(c.DisplayName),
windows.SERVICE_ALL_ACCESS, c.ServiceType,
c.StartType, c.ErrorControl, toPtr(s), toPtr(c.LoadOrderGroup),
nil, toStringBlock(c.Dependencies), toPtr(c.ServiceStartName), toPtr(c.Password))
if err != nil {
return nil, err
}
if c.Description != "" {
err = updateDescription(h, c.Description)
if err != nil {
return nil, err
}
}
return &Service{Name: name, Handle: h}, nil
}
// OpenService retrieves access to service name, so it can
// be interrogated and controlled.
func (m *Mgr) OpenService(name string) (*Service, error) {
h, err := windows.OpenService(m.Handle, syscall.StringToUTF16Ptr(name), windows.SERVICE_ALL_ACCESS)
if err != nil {
return nil, err
}
return &Service{Name: name, Handle: h}, nil
}
// ListServices enumerates services in the specified
// service control manager database m.
// If the caller does not have the SERVICE_QUERY_STATUS
// access right to a service, the service is silently
// omitted from the list of services returned.
func (m *Mgr) ListServices() ([]string, error) {
var err error
var bytesNeeded, servicesReturned uint32
var buf []byte
for {
var p *byte
if len(buf) > 0 {
p = &buf[0]
}
err = windows.EnumServicesStatusEx(m.Handle, windows.SC_ENUM_PROCESS_INFO,
windows.SERVICE_WIN32, windows.SERVICE_STATE_ALL,
p, uint32(len(buf)), &bytesNeeded, &servicesReturned, nil, nil)
if err == nil {
break
}
if err != syscall.ERROR_MORE_DATA {
return nil, err
}
if bytesNeeded <= uint32(len(buf)) {
return nil, err
}
buf = make([]byte, bytesNeeded)
}
if servicesReturned == 0 {
return nil, nil
}
services := (*[1 << 20]windows.ENUM_SERVICE_STATUS_PROCESS)(unsafe.Pointer(&buf[0]))[:servicesReturned]
var names []string
for _, s := range services {
name := syscall.UTF16ToString((*[1 << 20]uint16)(unsafe.Pointer(s.ServiceName))[:])
names = append(names, name)
}
return names, nil
}

View File

@ -1,169 +0,0 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build windows
package mgr_test
import (
"os"
"path/filepath"
"sort"
"strings"
"syscall"
"testing"
"time"
"golang.org/x/sys/windows/svc/mgr"
)
func TestOpenLanManServer(t *testing.T) {
m, err := mgr.Connect()
if err != nil {
if errno, ok := err.(syscall.Errno); ok && errno == syscall.ERROR_ACCESS_DENIED {
t.Skip("Skipping test: we don't have rights to manage services.")
}
t.Fatalf("SCM connection failed: %s", err)
}
defer m.Disconnect()
s, err := m.OpenService("LanmanServer")
if err != nil {
t.Fatalf("OpenService(lanmanserver) failed: %s", err)
}
defer s.Close()
_, err = s.Config()
if err != nil {
t.Fatalf("Config failed: %s", err)
}
}
func install(t *testing.T, m *mgr.Mgr, name, exepath string, c mgr.Config) {
// Sometimes it takes a while for the service to get
// removed after previous test run.
for i := 0; ; i++ {
s, err := m.OpenService(name)
if err != nil {
break
}
s.Close()
if i > 10 {
t.Fatalf("service %s already exists", name)
}
time.Sleep(300 * time.Millisecond)
}
s, err := m.CreateService(name, exepath, c)
if err != nil {
t.Fatalf("CreateService(%s) failed: %v", name, err)
}
defer s.Close()
}
func depString(d []string) string {
if len(d) == 0 {
return ""
}
for i := range d {
d[i] = strings.ToLower(d[i])
}
ss := sort.StringSlice(d)
ss.Sort()
return strings.Join([]string(ss), " ")
}
func testConfig(t *testing.T, s *mgr.Service, should mgr.Config) mgr.Config {
is, err := s.Config()
if err != nil {
t.Fatalf("Config failed: %s", err)
}
if should.DisplayName != is.DisplayName {
t.Fatalf("config mismatch: DisplayName is %q, but should have %q", is.DisplayName, should.DisplayName)
}
if should.StartType != is.StartType {
t.Fatalf("config mismatch: StartType is %v, but should have %v", is.StartType, should.StartType)
}
if should.Description != is.Description {
t.Fatalf("config mismatch: Description is %q, but should have %q", is.Description, should.Description)
}
if depString(should.Dependencies) != depString(is.Dependencies) {
t.Fatalf("config mismatch: Dependencies is %v, but should have %v", is.Dependencies, should.Dependencies)
}
return is
}
func remove(t *testing.T, s *mgr.Service) {
err := s.Delete()
if err != nil {
t.Fatalf("Delete failed: %s", err)
}
}
func TestMyService(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode - it modifies system services")
}
const name = "myservice"
m, err := mgr.Connect()
if err != nil {
if errno, ok := err.(syscall.Errno); ok && errno == syscall.ERROR_ACCESS_DENIED {
t.Skip("Skipping test: we don't have rights to manage services.")
}
t.Fatalf("SCM connection failed: %s", err)
}
defer m.Disconnect()
c := mgr.Config{
StartType: mgr.StartDisabled,
DisplayName: "my service",
Description: "my service is just a test",
Dependencies: []string{"LanmanServer", "W32Time"},
}
exename := os.Args[0]
exepath, err := filepath.Abs(exename)
if err != nil {
t.Fatalf("filepath.Abs(%s) failed: %s", exename, err)
}
install(t, m, name, exepath, c)
s, err := m.OpenService(name)
if err != nil {
t.Fatalf("service %s is not installed", name)
}
defer s.Close()
c.BinaryPathName = exepath
c = testConfig(t, s, c)
c.StartType = mgr.StartManual
err = s.UpdateConfig(c)
if err != nil {
t.Fatalf("UpdateConfig failed: %v", err)
}
testConfig(t, s, c)
svcnames, err := m.ListServices()
if err != nil {
t.Fatalf("ListServices failed: %v", err)
}
var myserviceIsInstalled bool
for _, sn := range svcnames {
if sn == name {
myserviceIsInstalled = true
break
}
}
if !myserviceIsInstalled {
t.Errorf("ListServices failed to find %q service", name)
}
remove(t, s)
}

View File

@ -1,72 +0,0 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build windows
package mgr
import (
"syscall"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc"
)
// TODO(brainman): Use EnumDependentServices to enumerate dependent services.
// Service is used to access Windows service.
type Service struct {
Name string
Handle windows.Handle
}
// Delete marks service s for deletion from the service control manager database.
func (s *Service) Delete() error {
return windows.DeleteService(s.Handle)
}
// Close relinquish access to the service s.
func (s *Service) Close() error {
return windows.CloseServiceHandle(s.Handle)
}
// Start starts service s.
// args will be passed to svc.Handler.Execute.
func (s *Service) Start(args ...string) error {
var p **uint16
if len(args) > 0 {
vs := make([]*uint16, len(args))
for i := range vs {
vs[i] = syscall.StringToUTF16Ptr(args[i])
}
p = &vs[0]
}
return windows.StartService(s.Handle, uint32(len(args)), p)
}
// Control sends state change request c to the servce s.
func (s *Service) Control(c svc.Cmd) (svc.Status, error) {
var t windows.SERVICE_STATUS
err := windows.ControlService(s.Handle, uint32(c), &t)
if err != nil {
return svc.Status{}, err
}
return svc.Status{
State: svc.State(t.CurrentState),
Accepts: svc.Accepted(t.ControlsAccepted),
}, nil
}
// Query returns current status of service s.
func (s *Service) Query() (svc.Status, error) {
var t windows.SERVICE_STATUS
err := windows.QueryServiceStatus(s.Handle, &t)
if err != nil {
return svc.Status{}, err
}
return svc.Status{
State: svc.State(t.CurrentState),
Accepts: svc.Accepted(t.ControlsAccepted),
}, nil
}

View File

@ -1,118 +0,0 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build windows
package svc_test
import (
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"testing"
"time"
"golang.org/x/sys/windows/svc"
"golang.org/x/sys/windows/svc/mgr"
)
func getState(t *testing.T, s *mgr.Service) svc.State {
status, err := s.Query()
if err != nil {
t.Fatalf("Query(%s) failed: %s", s.Name, err)
}
return status.State
}
func testState(t *testing.T, s *mgr.Service, want svc.State) {
have := getState(t, s)
if have != want {
t.Fatalf("%s state is=%d want=%d", s.Name, have, want)
}
}
func waitState(t *testing.T, s *mgr.Service, want svc.State) {
for i := 0; ; i++ {
have := getState(t, s)
if have == want {
return
}
if i > 10 {
t.Fatalf("%s state is=%d, waiting timeout", s.Name, have)
}
time.Sleep(300 * time.Millisecond)
}
}
func TestExample(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode - it modifies system services")
}
const name = "myservice"
m, err := mgr.Connect()
if err != nil {
t.Fatalf("SCM connection failed: %s", err)
}
defer m.Disconnect()
dir, err := ioutil.TempDir("", "svc")
if err != nil {
t.Fatalf("failed to create temp directory: %v", err)
}
defer os.RemoveAll(dir)
exepath := filepath.Join(dir, "a.exe")
o, err := exec.Command("go", "build", "-o", exepath, "golang.org/x/sys/windows/svc/example").CombinedOutput()
if err != nil {
t.Fatalf("failed to build service program: %v\n%v", err, string(o))
}
s, err := m.OpenService(name)
if err == nil {
err = s.Delete()
if err != nil {
s.Close()
t.Fatalf("Delete failed: %s", err)
}
s.Close()
}
s, err = m.CreateService(name, exepath, mgr.Config{DisplayName: "my service"}, "is", "auto-started")
if err != nil {
t.Fatalf("CreateService(%s) failed: %v", name, err)
}
defer s.Close()
testState(t, s, svc.Stopped)
err = s.Start("is", "manual-started")
if err != nil {
t.Fatalf("Start(%s) failed: %s", s.Name, err)
}
waitState(t, s, svc.Running)
time.Sleep(1 * time.Second)
// testing deadlock from issues 4.
_, err = s.Control(svc.Interrogate)
if err != nil {
t.Fatalf("Control(%s) failed: %s", s.Name, err)
}
_, err = s.Control(svc.Interrogate)
if err != nil {
t.Fatalf("Control(%s) failed: %s", s.Name, err)
}
time.Sleep(1 * time.Second)
_, err = s.Control(svc.Stop)
if err != nil {
t.Fatalf("Control(%s) failed: %s", s.Name, err)
}
waitState(t, s, svc.Stopped)
err = s.Delete()
if err != nil {
t.Fatalf("Delete failed: %s", err)
}
}

View File

@ -1,53 +0,0 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build windows
package windows_test
import (
"syscall"
"testing"
"golang.org/x/sys/windows"
)
func testSetGetenv(t *testing.T, key, value string) {
err := windows.Setenv(key, value)
if err != nil {
t.Fatalf("Setenv failed to set %q: %v", value, err)
}
newvalue, found := windows.Getenv(key)
if !found {
t.Fatalf("Getenv failed to find %v variable (want value %q)", key, value)
}
if newvalue != value {
t.Fatalf("Getenv(%v) = %q; want %q", key, newvalue, value)
}
}
func TestEnv(t *testing.T) {
testSetGetenv(t, "TESTENV", "AVALUE")
// make sure TESTENV gets set to "", not deleted
testSetGetenv(t, "TESTENV", "")
}
func TestGetProcAddressByOrdinal(t *testing.T) {
// Attempt calling shlwapi.dll:IsOS, resolving it by ordinal, as
// suggested in
// https://msdn.microsoft.com/en-us/library/windows/desktop/bb773795.aspx
h, err := windows.LoadLibrary("shlwapi.dll")
if err != nil {
t.Fatalf("Failed to load shlwapi.dll: %s", err)
}
procIsOS, err := windows.GetProcAddressByOrdinal(h, 437)
if err != nil {
t.Fatalf("Could not find shlwapi.dll:IsOS by ordinal: %s", err)
}
const OS_NT = 1
r, _, _ := syscall.Syscall(procIsOS, 1, OS_NT, 0, 0)
if r == 0 {
t.Error("shlwapi.dll:IsOS(OS_NT) returned 0, expected non-zero value")
}
}

View File

@ -1,107 +0,0 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package windows_test
import (
"io/ioutil"
"os"
"path/filepath"
"syscall"
"testing"
"unsafe"
"golang.org/x/sys/windows"
)
func TestWin32finddata(t *testing.T) {
dir, err := ioutil.TempDir("", "go-build")
if err != nil {
t.Fatalf("failed to create temp directory: %v", err)
}
defer os.RemoveAll(dir)
path := filepath.Join(dir, "long_name.and_extension")
f, err := os.Create(path)
if err != nil {
t.Fatalf("failed to create %v: %v", path, err)
}
f.Close()
type X struct {
fd windows.Win32finddata
got byte
pad [10]byte // to protect ourselves
}
var want byte = 2 // it is unlikely to have this character in the filename
x := X{got: want}
pathp, _ := windows.UTF16PtrFromString(path)
h, err := windows.FindFirstFile(pathp, &(x.fd))
if err != nil {
t.Fatalf("FindFirstFile failed: %v", err)
}
err = windows.FindClose(h)
if err != nil {
t.Fatalf("FindClose failed: %v", err)
}
if x.got != want {
t.Fatalf("memory corruption: want=%d got=%d", want, x.got)
}
}
func TestFormatMessage(t *testing.T) {
dll := windows.MustLoadDLL("pdh.dll")
pdhOpenQuery := func(datasrc *uint16, userdata uint32, query *windows.Handle) (errno uintptr) {
r0, _, _ := syscall.Syscall(dll.MustFindProc("PdhOpenQueryW").Addr(), 3, uintptr(unsafe.Pointer(datasrc)), uintptr(userdata), uintptr(unsafe.Pointer(query)))
return r0
}
pdhCloseQuery := func(query windows.Handle) (errno uintptr) {
r0, _, _ := syscall.Syscall(dll.MustFindProc("PdhCloseQuery").Addr(), 1, uintptr(query), 0, 0)
return r0
}
var q windows.Handle
name, err := windows.UTF16PtrFromString("no_such_source")
if err != nil {
t.Fatal(err)
}
errno := pdhOpenQuery(name, 0, &q)
if errno == 0 {
pdhCloseQuery(q)
t.Fatal("PdhOpenQuery succeeded, but expected to fail.")
}
const flags uint32 = syscall.FORMAT_MESSAGE_FROM_HMODULE | syscall.FORMAT_MESSAGE_ARGUMENT_ARRAY | syscall.FORMAT_MESSAGE_IGNORE_INSERTS
buf := make([]uint16, 300)
_, err = windows.FormatMessage(flags, uintptr(dll.Handle), uint32(errno), 0, buf, nil)
if err != nil {
t.Fatalf("FormatMessage for handle=%x and errno=%x failed: %v", dll.Handle, errno, err)
}
}
func abort(funcname string, err error) {
panic(funcname + " failed: " + err.Error())
}
func ExampleLoadLibrary() {
h, err := windows.LoadLibrary("kernel32.dll")
if err != nil {
abort("LoadLibrary", err)
}
defer windows.FreeLibrary(h)
proc, err := windows.GetProcAddress(h, "GetVersion")
if err != nil {
abort("GetProcAddress", err)
}
r, _, _ := syscall.Syscall(uintptr(proc), 0, 0, 0, 0)
major := byte(r)
minor := uint8(r >> 8)
build := uint16(r >> 16)
print("windows version ", major, ".", minor, " (Build ", build, ")\n")
}

5
vendor/modules.txt vendored Normal file
View File

@ -0,0 +1,5 @@
# github.com/judwhite/go-svc v1.0.0
github.com/judwhite/go-svc/svc
# golang.org/x/sys v0.0.0-20180322165403-91ee8cde4354
golang.org/x/sys/windows/svc
golang.org/x/sys/windows

4
vendor/vgo.list vendored
View File

@ -1,4 +0,0 @@
MODULE VERSION
github.com/productionwentdown/forward -
github.com/judwhite/go-svc v1.0.0
golang.org/x/sys v0.0.0-20180322165403-91ee8cde4354