wireguard-negotiator/cmd/server.go

370 lines
8.1 KiB
Go

package cmd
import (
	"bufio"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"os"
	"os/exec"
	"os/signal"
	"sync"
	"syscall"

	"github.com/serverwentdown/wireguard-negotiator/lib"
	"github.com/urfave/cli/v2"
	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
	"gopkg.in/ini.v1"
)
var (
	// ErrNoAddressesFound is returned by runServer when the managed interface
	// reports no addresses, leaving nothing to allocate peer addresses from.
	ErrNoAddressesFound = fmt.Errorf("No address found on the interface")
)
// CmdServer defines the "server" subcommand: its flags and, as its action,
// runServer.
var CmdServer = &cli.Command{
Name: "server",
Usage: "Start the wireguard-negotiator server",
Flags: []cli.Flag{
// Name of the interface to manage. runServer also uses it to build the
// default configuration path.
&cli.StringFlag{
Name: "interface",
Aliases: []string{"i"},
Value: "wg0",
Usage: "An existing WireGuard interface to manage",
},
// When this flag is not set, runServer falls back to
// /etc/wireguard/<interface>.conf, which is why DefaultText is spelled
// out while Value stays empty.
&cli.StringFlag{
Name: "config",
Aliases: []string{"c"},
Value: "",
DefaultText: "/etc/wireguard/<interface>.conf",
Usage: "Path to the existing WireGuard configuration file. WARNING: wireguard-negotiator will remove any comments in the file",
},
// Mandatory. runServer passes the value through unchanged into every
// peer configuration response.
&cli.StringFlag{
Name: "endpoint",
Aliases: []string{"e"},
Value: "",
Required: true,
Usage: "Set the endpoint address",
},
// Address handed to the HTTP server as its listen address.
&cli.StringFlag{
Name: "listen",
Aliases: []string{"l"},
Value: ":8080",
Usage: "Listen on this address",
},
// When set, runServer gives the gater a reader on standard input so that
// each request is confirmed with a y/n prompt.
&cli.BoolFlag{
Name: "interactive",
Aliases: []string{"I"},
Usage: "Enable interactive prompt before accepting new peers",
},
// NOTE(review): runServer in this file never reads this flag; its GET /
// handler serves the executable whether or not the flag is given.
&cli.BoolFlag{
Name: "bin",
Aliases: []string{"B"},
Usage: "Serve the current wireguard-negotiator binary file upon GET request to /",
},
},
Action: runServer,
}
// request describes one peer asking to join. It is created by the /request
// handler in runServer, passed through gater for approval and consumed by
// adder.
type request struct {
// publicKey is the "PublicKey" form value as submitted; it is only checked
// for being non-empty.
publicKey string
// ip is the address runServer allocated for the peer.
ip net.IP
}
// runServer is the action of the "server" command. It reads the flags,
// derives the allocation network from the first address of the WireGuard
// interface, starts the gater and adder goroutines and serves HTTP until
// SIGINT or SIGTERM arrives.
//
// Routes:
//   - GET /          streams the running executable
//   - POST /request  allocates an address for the submitted "PublicKey" form
//     value, hands the request to the gater and replies with a JSON
//     lib.PeerConfigResponse; it replies 503 when the gater has stopped or
//     the server is shutting down
//
// It returns nil after a signal-triggered shutdown has completed, and an
// error when setup fails or the HTTP server stops for any other reason.
func runServer(ctx *cli.Context) error {
	inter := ctx.String("interface")
	config := ctx.String("config")
	if !ctx.IsSet("config") {
		config = "/etc/wireguard/" + inter + ".conf"
	}
	endpoint := ctx.String("endpoint")
	listen := ctx.String("listen")
	interactive := ctx.Bool("interactive")

	// Obtain the network interface
	interf, err := net.InterfaceByName(inter)
	if err != nil {
		return err
	}

	// Obtain the server's public key
	serverPublicKey, err := configReadInterfacePublicKey(config)
	if err != nil {
		return err
	}

	// Obtain interface address for use in allocation
	interfAddrs, err := interf.Addrs()
	if err != nil {
		return err
	}
	if len(interfAddrs) < 1 {
		return ErrNoAddressesFound
	}
	_, interfIPNet, err := net.ParseCIDR(interfAddrs[0].String())
	if err != nil {
		// Without this check a nil network would panic inside the handlers
		return fmt.Errorf("parsing address of %s failed: %w", inter, err)
	}

	// TODO: Define this allocation method
	// TODO: Include allocation behaviour in README
	// Handlers run concurrently, one goroutine per connection, so the
	// counter must be guarded to avoid handing out the same address twice.
	var allocMu sync.Mutex
	allocCounter := 1

	// Set up interactive stuff; a nil reader makes the gater approve
	// everything
	var lineReader *bufio.Reader
	if interactive {
		lineReader = bufio.NewReader(os.Stdin)
	}

	addQueue := make(chan request)
	go adder(addQueue, inter, config)

	gateQueue := make(chan request)
	// gaterDone is closed when the gater returns, either because gateQueue
	// was closed or because reading the admin's answer failed
	gaterDone := make(chan struct{})
	go func() {
		defer close(gaterDone)
		gater(gateQueue, addQueue, lineReader)
	}()

	// stopping is closed when a shutdown signal arrives so that handlers
	// waiting for the gater give up instead of delaying the shutdown
	stopping := make(chan struct{})

	// TODO: Rate limiting

	// NOTE(review): the "bin" flag declared on CmdServer is not consulted;
	// the executable is served on every GET regardless.
	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodGet {
			w.WriteHeader(http.StatusMethodNotAllowed)
			return
		}
		bin, err := os.Executable()
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		file, err := os.Open(bin)
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		// Release the descriptor once the response has been written
		defer file.Close()
		if _, err := io.Copy(w, file); err != nil {
			log.Println("WARNING: Write binary executable to response failed")
		}
	})

	http.HandleFunc("/request", func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			w.WriteHeader(http.StatusMethodNotAllowed)
			return
		}
		publicKey := r.PostFormValue("PublicKey")
		// TODO: Ensure public key is new
		// TODO: Validate public key
		if len(publicKey) == 0 {
			w.WriteHeader(http.StatusBadRequest)
			return
		}

		// Assign an IP address
		allocMu.Lock()
		allocCounter++
		ip := incrementIP(interfIPNet.IP, allocCounter)
		allocMu.Unlock()
		if !interfIPNet.Contains(ip) {
			log.Println("WARNING: Ran out of addresses to allocate")
			w.WriteHeader(http.StatusInternalServerError)
			return
		}

		// Enqueue request into the gate. The send blocks until the gater
		// picks the request up.
		// NOTE(review): the reply below is sent without waiting for the
		// gater's decision or for the adder to write the configuration.
		req := request{ip: ip, publicKey: publicKey}
		select {
		case gateQueue <- req:
		case <-gaterDone:
			w.WriteHeader(http.StatusServiceUnavailable)
			return
		case <-stopping:
			w.WriteHeader(http.StatusServiceUnavailable)
			return
		}

		// Produce configuration to client
		ipNet := &net.IPNet{
			IP:   ip,
			Mask: interfIPNet.Mask,
		}
		netIPNet := &net.IPNet{
			IP:   interfIPNet.IP.Mask(interfIPNet.Mask),
			Mask: interfIPNet.Mask,
		}
		resp := lib.PeerConfigResponse{
			[]string{ipNet.String()},
			[]string{netIPNet.String()},
			serverPublicKey,
			endpoint,
			25,
		}
		w.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(w).Encode(resp); err != nil {
			log.Printf("WARNING: Write peer configuration to response failed: %v\n", err)
		}
	})

	server := &http.Server{
		Addr:    listen,
		Handler: http.DefaultServeMux,
	}

	// Shutdown notifier
	shutdownDone := make(chan struct{})
	go func() {
		defer close(shutdownDone)
		sigint := make(chan os.Signal, 1)
		signal.Notify(sigint, os.Interrupt, syscall.SIGTERM)
		<-sigint
		// Stop relaying signals so that a second one falls back to the
		// default behaviour, e.g. while the gater still waits for an answer
		signal.Stop(sigint)

		close(stopping)
		// Stop the HTTP server first: once Shutdown has returned no handler
		// is left that could send on gateQueue
		if err := server.Shutdown(context.Background()); err != nil {
			log.Printf("Server shutdown error: %v\n", err)
		}
		close(gateQueue)
		// The gater is the only sender on addQueue, so wait for it before
		// closing that channel
		<-gaterDone
		close(addQueue)
	}()

	log.Printf("Server listening on %v\n", listen)
	if err := server.ListenAndServe(); err != http.ErrServerClosed {
		return err
	}
	// ListenAndServe returns as soon as Shutdown is called; wait until the
	// shutdown sequence above has finished
	<-shutdownDone
	return nil
}
// adder consumes approved requests from queue. Each request is appended to
// the configuration file at config and the configuration is then loaded into
// the interface inter. A failing step is logged and the request is dropped.
// adder returns once queue has been closed and drained.
func adder(queue chan request, inter string, config string) {
	// Write requests to config and add peer. Ranging over the channel ends
	// the loop, and with it the goroutine, when the channel is closed.
	for req := range queue {
		if err := configAddPeer(config, req); err != nil {
			log.Println(err)
			continue
		}
		if err := interAddPeer(inter, req, config); err != nil {
			log.Println(err)
		}
	}
}
// gater receives requests from queue, prints the address and public key of
// each one and forwards the approved ones to result. With a nil lineReader
// every request is approved. Otherwise the admin is prompted until the answer
// is one of y, yes, n or no. gater returns when queue is closed or when
// reading an answer fails; a read failure is logged first.
func gater(queue chan request, result chan request, lineReader *bufio.Reader) {
	for req := range queue {
		fmt.Println(req.ip.String(), req.publicKey)

		// Without a reader there is nobody to ask, so the request passes
		approved := lineReader == nil
		for answered := approved; !answered; {
			fmt.Print("Allow? (y/n) ")
			line, err := lineReader.ReadString('\n')
			if err != nil {
				log.Println(err)
				return
			}
			// Drop the trailing newline before comparing
			switch answer := line[:len(line)-1]; answer {
			case "y", "yes":
				approved, answered = true, true
			case "n", "no":
				answered = true
			}
		}

		if approved {
			result <- req
		}
	}
}
// configAddPeer appends a [Peer] section for req to the configuration file at
// config, creating the file if it does not exist. The section holds the
// request's public key and, as AllowedIPs, the request's address with a
// host-length mask. It returns an error when building the section, opening,
// writing or closing the file fails.
func configAddPeer(config string, req request) error {
	// For every request, open the config file again and rewrite it. Acceptable
	// because this happens infrequently
	// Preferably in the future, treat the configuration as a database
	// For now, append to the config file
	cfg := ini.Empty()
	sec, err := cfg.NewSection("Peer")
	if err != nil {
		return fmt.Errorf("creating peer section failed: %w", err)
	}
	// TODO: Validation is needed
	sec.Key("PublicKey").SetValue(req.publicKey)
	allowedHost := ipToIPNetWithHostMask(req.ip)
	sec.Key("AllowedIPs").SetValue((&allowedHost).String())

	f, err := os.OpenFile(config, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
	if err != nil {
		return fmt.Errorf("opening %s failed: %w", config, err)
	}
	if _, err := cfg.WriteTo(f); err != nil {
		// The write error is the one worth reporting; closing here only
		// releases the descriptor
		_ = f.Close()
		return fmt.Errorf("writing to %s failed: %w", config, err)
	}
	// Close explicitly so that the descriptor is released on every call and a
	// failure to close the file is reported
	if err := f.Close(); err != nil {
		return fmt.Errorf("closing %s failed: %w", config, err)
	}
	return nil
}
// configReadInterfacePublicKey loads the configuration file at config, parses
// the PrivateKey value of its [Interface] section and returns the matching
// public key in its string form. Loading and parsing failures are wrapped in
// the returned error.
func configReadInterfacePublicKey(config string) (string, error) {
	const failure = "read interface public key failed: %w"

	cfg, err := ini.Load(config)
	if err != nil {
		return "", fmt.Errorf(failure, err)
	}

	encoded := cfg.Section("Interface").Key("PrivateKey").String()
	privateKey, err := wgtypes.ParseKey(encoded)
	if err != nil {
		return "", fmt.Errorf(failure, err)
	}

	return privateKey.PublicKey().String(), nil
}
// interAddPeer makes the interface inter pick up a newly added peer by
// running "wg setconf <inter> <config>". req is currently unused: the whole
// configuration file is reloaded instead of adding the single peer. It
// returns an error wrapping the command failure.
func interAddPeer(inter string, req request, config string) error {
	// For every request, dynamically add the peer to the interface
	// For now, simply run one fixed command to reread from the config file
	if err := exec.Command("wg", "setconf", inter, config).Run(); err != nil {
		return fmt.Errorf("wg setconf failed: %w", err)
	}
	return nil
}
// ipToIPNetWithHostMask wraps ip in a network that contains only that
// address: a /32 when ip converts to IPv4, a /128 otherwise. ip itself is
// stored unchanged.
func ipToIPNetWithHostMask(ip net.IP) net.IPNet {
	bits := 8 * net.IPv6len
	if ip.To4() != nil {
		bits = 8 * net.IPv4len
	}
	return net.IPNet{
		IP:   ip,
		Mask: net.CIDRMask(bits, bits),
	}
}
// incrementIP returns a copy of ip with inc added to it, treating the address
// bytes as one big-endian number. The addition runs from the last byte to the
// first and carries into the next byte on overflow; ip is left untouched.
func incrementIP(ip net.IP, inc int) net.IP {
	out := make(net.IP, len(ip))
	copy(out, ip)

	for pos := len(out) - 1; pos >= 0; pos-- {
		// Add the lowest byte of what is left of inc to this position
		sum := int(out[pos]) + inc%256
		out[pos] = byte(sum)

		// Move on to the next byte of inc, plus one when this byte overflowed
		if sum > 255 {
			inc += 256
		}
		inc /= 256
	}
	return out
}