370 lines
8.1 KiB
Go
370 lines
8.1 KiB
Go
package cmd
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"os/exec"
|
|
"os/signal"
|
|
"syscall"
|
|
|
|
"github.com/serverwentdown/wireguard-negotiator/lib"
|
|
"github.com/urfave/cli/v2"
|
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
"gopkg.in/ini.v1"
|
|
)
|
|
|
|
// ErrNoAddressesFound is returned by runServer when the managed network
// interface reports no addresses, so there is nothing to allocate peers from.
var ErrNoAddressesFound = fmt.Errorf("No address found on the interface")
|
|
|
|
var CmdServer = &cli.Command{
|
|
Name: "server",
|
|
Usage: "Start the wireguard-negotiator server",
|
|
Flags: []cli.Flag{
|
|
&cli.StringFlag{
|
|
Name: "interface",
|
|
Aliases: []string{"i"},
|
|
Value: "wg0",
|
|
Usage: "An existing WireGuard interface to manage",
|
|
},
|
|
&cli.StringFlag{
|
|
Name: "config",
|
|
Aliases: []string{"c"},
|
|
Value: "",
|
|
DefaultText: "/etc/wireguard/<interface>.conf",
|
|
Usage: "Path to the existing WireGuard configuration file. WARNING: wireguard-negotiator will remove any comments in the file",
|
|
},
|
|
&cli.StringFlag{
|
|
Name: "endpoint",
|
|
Aliases: []string{"e"},
|
|
Value: "",
|
|
Required: true,
|
|
Usage: "Set the endpoint address",
|
|
},
|
|
&cli.StringFlag{
|
|
Name: "listen",
|
|
Aliases: []string{"l"},
|
|
Value: ":8080",
|
|
Usage: "Listen on this address",
|
|
},
|
|
&cli.BoolFlag{
|
|
Name: "interactive",
|
|
Aliases: []string{"I"},
|
|
Usage: "Enable interactive prompt before accepting new peers",
|
|
},
|
|
&cli.BoolFlag{
|
|
Name: "bin",
|
|
Aliases: []string{"B"},
|
|
Usage: "Serve the current wireguard-negotiator binary file upon GET request to /",
|
|
},
|
|
},
|
|
Action: runServer,
|
|
}
|
|
|
|
// request is one peer's enrolment request as it travels from the HTTP handler
// through the gate queue to the add queue.
type request struct {
	// publicKey is the peer's public key as submitted in the "PublicKey"
	// form value.
	publicKey string
	// ip is the address the server allocated for the peer.
	ip net.IP
}
|
|
|
|
func runServer(ctx *cli.Context) error {
|
|
inter := ctx.String("interface")
|
|
config := ctx.String("config")
|
|
if !ctx.IsSet("config") {
|
|
config = "/etc/wireguard/" + inter + ".conf"
|
|
}
|
|
endpoint := ctx.String("endpoint")
|
|
listen := ctx.String("listen")
|
|
interactive := ctx.Bool("interactive")
|
|
|
|
// Obtain the network interface
|
|
interf, err := net.InterfaceByName(inter)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Obtain the server's public key
|
|
serverPublicKey, err := configReadInterfacePublicKey(config)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// TODO: Define this allocation method
|
|
// TODO: Include allocation behaviour in README
|
|
terribleCounterThatShouldNotExist := 1
|
|
interfAddrs, err := interf.Addrs()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Obtain interface address for use in allocation
|
|
var interfIPNet *net.IPNet
|
|
if len(interfAddrs) < 1 {
|
|
return ErrNoAddressesFound
|
|
}
|
|
_, interfIPNet, err = net.ParseCIDR(interfAddrs[0].String())
|
|
|
|
// Set up interactive stuff
|
|
lineReader := bufio.NewReader(os.Stdin)
|
|
if !interactive {
|
|
lineReader = nil
|
|
}
|
|
|
|
addQueue := make(chan request, 0)
|
|
go adder(addQueue, inter, config)
|
|
|
|
gateQueue := make(chan request, 0)
|
|
go gater(gateQueue, addQueue, lineReader)
|
|
|
|
// TODO: Rate limiting
|
|
|
|
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.Method {
|
|
case "GET":
|
|
bin, err := os.Executable()
|
|
if err != nil {
|
|
w.WriteHeader(500)
|
|
return
|
|
}
|
|
file, err := os.Open(bin)
|
|
if err != nil {
|
|
w.WriteHeader(500)
|
|
return
|
|
}
|
|
_, err = io.Copy(w, file)
|
|
if err != nil {
|
|
log.Println("WARNING: Write binary executable to response failed")
|
|
return
|
|
}
|
|
default:
|
|
w.WriteHeader(405)
|
|
}
|
|
})
|
|
http.HandleFunc("/request", func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.Method {
|
|
case "POST":
|
|
publicKey := r.PostFormValue("PublicKey")
|
|
// TODO: Ensure public key is new
|
|
// TODO: Validate public key
|
|
if len(publicKey) == 0 {
|
|
w.WriteHeader(400)
|
|
return
|
|
}
|
|
|
|
// Assign an IP address
|
|
terribleCounterThatShouldNotExist += 1
|
|
ip := incrementIP(interfIPNet.IP, terribleCounterThatShouldNotExist)
|
|
if !interfIPNet.Contains(ip) {
|
|
log.Println("WARNING: Ran out of addresses to allocate")
|
|
w.WriteHeader(500)
|
|
return
|
|
}
|
|
|
|
// Enqueue request into the gate
|
|
req := request{
|
|
ip: ip,
|
|
publicKey: publicKey,
|
|
}
|
|
|
|
// Wait for flush of configuration
|
|
gateQueue <- req
|
|
|
|
// Produce configuration to client
|
|
ipNet := &net.IPNet{
|
|
IP: ip,
|
|
Mask: interfIPNet.Mask,
|
|
}
|
|
netIPNet := &net.IPNet{
|
|
IP: interfIPNet.IP.Mask(interfIPNet.Mask),
|
|
Mask: interfIPNet.Mask,
|
|
}
|
|
resp := lib.PeerConfigResponse{
|
|
[]string{ipNet.String()},
|
|
[]string{netIPNet.String()},
|
|
serverPublicKey,
|
|
endpoint,
|
|
25,
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(resp)
|
|
default:
|
|
w.WriteHeader(405)
|
|
}
|
|
})
|
|
|
|
server := &http.Server{
|
|
Addr: listen,
|
|
Handler: http.DefaultServeMux,
|
|
}
|
|
|
|
// Shutdown notifier
|
|
go func() {
|
|
sigint := make(chan os.Signal, 1)
|
|
signal.Notify(sigint, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
|
|
<-sigint
|
|
close(gateQueue)
|
|
close(addQueue)
|
|
if err := server.Shutdown(context.Background()); err != nil {
|
|
log.Printf("Server shutdown error: %v\n", err)
|
|
}
|
|
}()
|
|
|
|
log.Printf("Server listening on %v\n", listen)
|
|
|
|
return server.ListenAndServe()
|
|
}
|
|
|
|
func adder(queue chan request, inter string, config string) {
|
|
// Write requests to config and add peer
|
|
for {
|
|
select {
|
|
case req, ok := <-queue:
|
|
if !ok {
|
|
break
|
|
}
|
|
err := configAddPeer(config, req)
|
|
if err != nil {
|
|
log.Println(err)
|
|
continue
|
|
}
|
|
err = interAddPeer(inter, req, config)
|
|
if err != nil {
|
|
log.Println(err)
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func gater(queue chan request, result chan request, lineReader *bufio.Reader) {
|
|
// Receive requests and prompt the admin
|
|
for {
|
|
select {
|
|
case req, ok := <-queue:
|
|
if !ok {
|
|
return
|
|
}
|
|
fmt.Println(req.ip.String(), req.publicKey)
|
|
|
|
done := false
|
|
allowed := false
|
|
|
|
if lineReader == nil {
|
|
done = true
|
|
allowed = true
|
|
}
|
|
|
|
for !done {
|
|
fmt.Print("Allow? (y/n) ")
|
|
line, err := lineReader.ReadString('\n')
|
|
if err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
|
|
switch line[:len(line)-1] {
|
|
case "y", "yes":
|
|
done = true
|
|
allowed = true
|
|
case "n", "no":
|
|
done = true
|
|
allowed = false
|
|
}
|
|
}
|
|
|
|
if allowed {
|
|
result <- req
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func configAddPeer(config string, req request) error {
|
|
// For every request, open the config file again and rewrite it. Acceptable
|
|
// because this happens infrequently
|
|
|
|
// Preferably in the future, treat the configuration as a database
|
|
|
|
// For now, append to the config file
|
|
cfg := ini.Empty()
|
|
sec, _ := cfg.NewSection("Peer")
|
|
publicKey := sec.Key("PublicKey")
|
|
// TODO: Validation is needed
|
|
publicKey.SetValue(req.publicKey)
|
|
allowedIPs := sec.Key("AllowedIPs")
|
|
allowedHost := ipToIPNetWithHostMask(req.ip)
|
|
allowedIPs.SetValue((&allowedHost).String())
|
|
|
|
f, err := os.OpenFile(config, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
|
if err != nil {
|
|
return fmt.Errorf("opening %s failed: %w", config, err)
|
|
}
|
|
_, err = cfg.WriteTo(f)
|
|
if err != nil {
|
|
return fmt.Errorf("writing to %s failed: %w", config, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func configReadInterfacePublicKey(config string) (string, error) {
|
|
cfg, err := ini.Load(config)
|
|
if err != nil {
|
|
return "", fmt.Errorf("read interface public key failed: %w", err)
|
|
}
|
|
|
|
b64PrivateKey := cfg.Section("Interface").Key("PrivateKey").String()
|
|
wgPrivateKey, err := wgtypes.ParseKey(b64PrivateKey)
|
|
if err != nil {
|
|
return "", fmt.Errorf("read interface public key failed: %w", err)
|
|
}
|
|
wgPublicKey := wgPrivateKey.PublicKey()
|
|
return wgPublicKey.String(), nil
|
|
}
|
|
|
|
func interAddPeer(inter string, req request, config string) error {
|
|
// For every request, dynamically add the peer to the interface
|
|
|
|
// For now, simply run one fixed command to reread from the config file
|
|
cmd := exec.Command("wg", "setconf", inter, config)
|
|
err := cmd.Run()
|
|
if err != nil {
|
|
return fmt.Errorf("wq setconf failed: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ipToIPNetWithHostMask wraps ip in a net.IPNet whose mask covers every bit,
// so the network describes exactly that one host: /32 when ip has an IPv4
// form, /128 otherwise. ip itself is stored unchanged.
func ipToIPNetWithHostMask(ip net.IP) net.IPNet {
	bits := 8 * net.IPv6len
	if ip.To4() != nil {
		bits = 8 * net.IPv4len
	}
	return net.IPNet{IP: ip, Mask: net.CIDRMask(bits, bits)}
}
|
|
|
|
// incrementIP returns a copy of ip with inc added to it, treating the address
// bytes as one big-endian unsigned integer. ip is not modified and the result
// has the same length as ip.
//
// inc may be negative, in which case the address is decremented with proper
// borrowing between bytes. Arithmetic wraps around at the ends of the address
// space (e.g. 255.255.255.255 + 1 is 0.0.0.0); callers that care must check
// the result themselves.
func incrementIP(ip net.IP, inc int) net.IP {
	result := make(net.IP, len(ip))
	copy(result, ip)

	// Walk from the least significant byte, carrying (or borrowing) into the
	// next one. Only the low byte of carry is added per step, so the sum
	// stays within (-256, 511) and cannot overflow int.
	carry := inc
	for i := len(result) - 1; i >= 0 && carry != 0; i-- {
		sum := int(result[i]) + carry%256
		carry /= 256
		switch {
		case sum > 255:
			sum -= 256
			carry++
		case sum < 0:
			sum += 256
			carry--
		}
		result[i] = byte(sum)
	}
	return result
}
|