293 lines
7.8 KiB
Go
293 lines
7.8 KiB
Go
// TODO: split this file into separate encoder and decoder parts.
|
|
package lib
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
)
|
|
|
|
// Syntax characters of the INI-like configuration format.
const (
	// CommentChar starts a comment; readConfigLine splits each line at its
	// first occurrence.
	CommentChar = "#"
	// AssignmentChar separates a key from its value; parseLine splits each
	// line at its first occurrence.
	AssignmentChar = "="
)
|
|
|
|
// Section identifiers used by ReadConfig to remember which part of the
// configuration the current line belongs to.
const (
	// SectionNone means no section header has been seen yet.
	SectionNone = iota
	// SectionDevice is selected by an "Interface" header.
	SectionDevice
	// SectionPeer is selected by a "Peer" header.
	SectionPeer
)

// sectionNames holds the display name of each section identifier, indexed by
// the identifier itself; it is used when building error messages.
var sectionNames = []string{
	SectionNone:   "None",
	SectionDevice: "Device",
	SectionPeer:   "Peer",
}
|
|
|
|
// Sentinel errors reported by the configuration parser. ReadConfig wraps
// them with additional context, so callers should compare with errors.Is.
var (
	// ErrUnknownSection is wrapped when a section header is neither
	// "Interface" nor "Peer".
	ErrUnknownSection = fmt.Errorf("unknown section")
	// ErrUnknownKey is wrapped when a key is not recognised in the section
	// it appears in.
	ErrUnknownKey = fmt.Errorf("unknown key")

	// ErrValueParse is wrapped, together with the underlying cause, when the
	// value of a recognised key cannot be parsed.
	ErrValueParse = fmt.Errorf("value parse failed")
	// ErrPersistentKeepaliveRange is returned by parsePersistentKeepalive for
	// numeric values below 0 or above 65535.
	ErrPersistentKeepaliveRange = fmt.Errorf("persistent keepalive interval is neither 0/off nor 1-65535")
)
|
|
|
|
// EndpointMap associates the string form of a resolved UDP address with the
// endpoint text it was originally resolved from, so that the original text
// can be written back out instead of the resolved address.
type EndpointMap map[string]string

// insert records endpoint as the original text for udpAddr, replacing any
// previous entry for the same address.
func (e EndpointMap) insert(udpAddr net.UDPAddr, endpoint string) {
	key := udpAddr.String()
	e[key] = endpoint
}

// revert returns the original endpoint text recorded for udpAddr, or the
// string form of udpAddr itself when nothing was recorded for it.
func (e EndpointMap) revert(udpAddr net.UDPAddr) string {
	resolved := udpAddr.String()
	if original, found := e[resolved]; found {
		return original
	}
	return resolved
}
|
|
|
|
// ReadConfig is yet another INI-like configuration file parser, but for WireGuard Config
|
|
func ReadConfig(r io.Reader) (wgtypes.Config, EndpointMap, error) {
|
|
scanner := bufio.NewScanner(r)
|
|
|
|
config := wgtypes.Config{ReplacePeers: true}
|
|
section := SectionNone
|
|
endpointMap := make(EndpointMap)
|
|
|
|
for scanner.Scan() {
|
|
text := scanner.Text()
|
|
line, _ := readConfigLine(text)
|
|
|
|
s, k, v := parseLine(line)
|
|
|
|
switch {
|
|
case insensetiveMatch(s, "Interface"):
|
|
section = SectionDevice
|
|
case insensetiveMatch(s, "Peer"):
|
|
section = SectionPeer
|
|
config.Peers = append(config.Peers, wgtypes.PeerConfig{
|
|
ReplaceAllowedIPs: true,
|
|
})
|
|
case len(s) > 0:
|
|
return config, endpointMap, fmt.Errorf("%w: %v", ErrUnknownSection, s)
|
|
}
|
|
|
|
if len(k) == 0 {
|
|
continue
|
|
}
|
|
|
|
// TODO: break out parsers into functions
|
|
switch section {
|
|
case SectionDevice:
|
|
switch {
|
|
case insensetiveMatch(k, "ListenPort"):
|
|
port, err := parsePort(v)
|
|
if err != nil {
|
|
return config, endpointMap, fmt.Errorf("%w: %w: %v=%v", ErrValueParse, err, k, v)
|
|
}
|
|
config.ListenPort = &port
|
|
case insensetiveMatch(k, "FwMark"):
|
|
fwMark, err := parseFwMark(v)
|
|
if err != nil {
|
|
return config, endpointMap, fmt.Errorf("%w: %w: %v=%v", ErrValueParse, err, k, v)
|
|
}
|
|
config.FirewallMark = &fwMark
|
|
case insensetiveMatch(k, "PrivateKey"):
|
|
key, err := wgtypes.ParseKey(v)
|
|
if err != nil {
|
|
return config, endpointMap, fmt.Errorf("%w: %w: %v=%v", ErrValueParse, err, k, v)
|
|
}
|
|
config.PrivateKey = &key
|
|
default:
|
|
return config, endpointMap, fmt.Errorf("%w: %v: %v", ErrUnknownKey, sectionNames[section], k)
|
|
}
|
|
case SectionPeer:
|
|
peer := &config.Peers[len(config.Peers)-1]
|
|
switch {
|
|
case insensetiveMatch(k, "Endpoint"):
|
|
endpoint, err := net.ResolveUDPAddr("udp", v)
|
|
if err != nil {
|
|
return config, endpointMap, fmt.Errorf("%w: %w: %v=%v", ErrValueParse, err, k, v)
|
|
}
|
|
endpointMap.insert(*endpoint, v)
|
|
peer.Endpoint = endpoint
|
|
case insensetiveMatch(k, "PublicKey"):
|
|
key, err := wgtypes.ParseKey(v)
|
|
if err != nil {
|
|
return config, endpointMap, fmt.Errorf("%w: %w: %v=%v", ErrValueParse, err, k, v)
|
|
}
|
|
peer.PublicKey = key
|
|
case insensetiveMatch(k, "AllowedIPs"):
|
|
allowedIPs, err := parseAllowedIPs(v)
|
|
if err != nil {
|
|
return config, endpointMap, fmt.Errorf("%w: %w: %v=%v", ErrValueParse, err, k, v)
|
|
}
|
|
peer.AllowedIPs = allowedIPs
|
|
case insensetiveMatch(k, "PersistentKeepalive"):
|
|
persistentKeepalive, err := parsePersistentKeepalive(v)
|
|
if err != nil {
|
|
return config, endpointMap, fmt.Errorf("%w: %w: %v=%v", ErrValueParse, err, k, v)
|
|
}
|
|
peer.PersistentKeepaliveInterval = &persistentKeepalive
|
|
case insensetiveMatch(k, "PresharedKey"):
|
|
key, err := wgtypes.ParseKey(v)
|
|
if err != nil {
|
|
return config, endpointMap, fmt.Errorf("%w: %w: %v=%v", ErrValueParse, err, k, v)
|
|
}
|
|
peer.PresharedKey = &key
|
|
default:
|
|
return config, endpointMap, fmt.Errorf("%w: %v: %v", ErrUnknownKey, sectionNames[section], k)
|
|
}
|
|
}
|
|
}
|
|
|
|
return config, endpointMap, nil
|
|
}
|
|
|
|
// WriteConfig writes out WireGuard Device configuration into a buffer
|
|
func WriteConfig(w io.Writer, config wgtypes.Device, endpointMap EndpointMap) {
|
|
var emptyKey [wgtypes.KeyLen]byte
|
|
|
|
writeConfigLine(w, formatSection("Interface"))
|
|
|
|
writeConfigLine(w, formatLineKeyValue("PrivateKey", config.PrivateKey.String()))
|
|
if config.ListenPort > 0 {
|
|
writeConfigLine(w, formatLineKeyValue("ListenPort", formatPort(config.ListenPort)))
|
|
}
|
|
if config.FirewallMark > 0 {
|
|
writeConfigLine(w, formatLineKeyValue("FwMark", formatFwMark(config.FirewallMark)))
|
|
}
|
|
|
|
for _, peer := range config.Peers {
|
|
writeConfigLine(w, "")
|
|
writeConfigLine(w, formatSection("Peer"))
|
|
writeConfigLine(w, formatLineKeyValue("PublicKey", peer.PublicKey.String()))
|
|
if !bytes.Equal(peer.PresharedKey[:], emptyKey[:]) {
|
|
writeConfigLine(w, formatLineKeyValue("PresharedKey", peer.PresharedKey.String()))
|
|
}
|
|
writeConfigLine(w, formatLineKeyValue("AllowedIPs", formatAllowedIPs(peer.AllowedIPs)))
|
|
if peer.PersistentKeepaliveInterval > 0 {
|
|
writeConfigLine(w, formatLineKeyValue("PersistentKeepalive", formatPersistentKeepalive(peer.PersistentKeepaliveInterval)))
|
|
}
|
|
if peer.Endpoint != nil {
|
|
writeConfigLine(w, formatLineKeyValue("Endpoint", endpointMap.revert(*peer.Endpoint)))
|
|
}
|
|
}
|
|
}
|
|
|
|
func readConfigLine(text string) (line, comments string) {
|
|
line = text
|
|
comments = ""
|
|
|
|
comment := strings.Index(line, CommentChar)
|
|
if comment >= 0 {
|
|
line = text[:comment]
|
|
comments = text[comment+1:]
|
|
}
|
|
|
|
line = strings.TrimSpace(line)
|
|
comments = strings.TrimSpace(comments)
|
|
return
|
|
}
|
|
// writeConfigLine writes line followed by a newline to w and returns the
// result of that single write.
func writeConfigLine(w io.Writer, line string) (int, error) {
	terminated := line + "\n"
	return io.WriteString(w, terminated)
}
|
|
|
|
func parseLine(line string) (section, key, value string) {
|
|
if len(line) < 1 {
|
|
return "", "", ""
|
|
}
|
|
|
|
if line[0] == '[' && line[len(line)-1] == ']' {
|
|
return line[1 : len(line)-1], "", ""
|
|
}
|
|
|
|
assign := strings.Index(line, AssignmentChar)
|
|
if assign >= 0 {
|
|
return "", strings.TrimSpace(line[:assign]), strings.TrimSpace(line[assign+1:])
|
|
}
|
|
|
|
return "", "", ""
|
|
}
|
|
// formatSection renders a section header line: the section name enclosed in
// square brackets.
func formatSection(section string) string {
	const opening, closing = "[", "]"
	return opening + section + closing
}
|
|
// formatLineKeyValue renders an assignment line as "key = value".
func formatLineKeyValue(key, value string) string {
	const separator = " = "
	return key + separator + value
}
|
|
|
|
// insensetiveMatch reports whether a and b are equal when letter case is
// ignored. It uses Unicode case folding and does not allocate lowered copies
// of its arguments.
func insensetiveMatch(a string, b string) bool {
	return strings.EqualFold(a, b)
}
|
|
|
|
// parsePort parses s as a UDP port number. The base is taken from the prefix
// of s ("0x" for hexadecimal and so on, as accepted by strconv.ParseInt with
// base 0). It returns an error, and a zero port, when s is not a number or
// lies outside 0-65535.
func parsePort(s string) (int, error) {
	port, err := strconv.ParseInt(s, 0, 0)
	if err != nil {
		return 0, err
	}
	// A port is a 16-bit quantity; anything else would be truncated later.
	if port < 0 || port > 65535 {
		return 0, fmt.Errorf("port %d is outside 0-65535", port)
	}
	return int(port), nil
}
|
|
// formatPort renders port as a decimal string.
func formatPort(port int) string {
	return strconv.Itoa(port)
}
|
|
|
|
func parseFwMark(s string) (int, error) {
|
|
if insensetiveMatch(s, "off") {
|
|
return 0, nil
|
|
}
|
|
fwMark, err := strconv.ParseInt(s, 0, 0)
|
|
return int(fwMark), err
|
|
}
|
|
// formatFwMark renders a firewall mark: "off" for zero, otherwise the mark as
// a decimal string.
func formatFwMark(fwMark int) string {
	switch fwMark {
	case 0:
		return "off"
	default:
		return strconv.Itoa(fwMark)
	}
}
|
|
|
|
// parseAllowedIPs parses a comma-separated list of networks in CIDR notation.
// Whitespace around each entry is ignored. An empty (or all-whitespace) s
// yields an empty list without error, which is also what formatAllowedIPs
// produces for an empty list. If any entry fails to parse, the error from
// net.ParseCIDR is returned together with a nil slice.
func parseAllowedIPs(s string) ([]net.IPNet, error) {
	if strings.TrimSpace(s) == "" {
		return nil, nil
	}

	fields := strings.Split(s, ",")
	parsedIPs := make([]net.IPNet, 0, len(fields))
	for _, field := range fields {
		_, parsedIP, err := net.ParseCIDR(strings.TrimSpace(field))
		if err != nil {
			return nil, err
		}
		parsedIPs = append(parsedIPs, *parsedIP)
	}
	return parsedIPs, nil
}
|
|
// formatAllowedIPs renders the networks in CIDR notation, joined by ", ".
// An empty list yields the empty string.
func formatAllowedIPs(allowedIPs []net.IPNet) string {
	var out strings.Builder
	for i := range allowedIPs {
		if i > 0 {
			out.WriteString(", ")
		}
		out.WriteString(allowedIPs[i].String())
	}
	return out.String()
}
|
|
|
|
func parsePersistentKeepalive(s string) (time.Duration, error) {
|
|
if insensetiveMatch(s, "off") {
|
|
return time.Duration(0), nil
|
|
}
|
|
persistentKeepalive, err := strconv.ParseInt(s, 0, 64)
|
|
if err != nil {
|
|
return time.Duration(0), err
|
|
}
|
|
if persistentKeepalive < 0 || persistentKeepalive > 65535 {
|
|
return time.Duration(0), ErrPersistentKeepaliveRange
|
|
}
|
|
|
|
return time.Duration(persistentKeepalive * int64(time.Second)), err
|
|
}
|
|
// formatPersistentKeepalive renders a keepalive interval: "off" for a zero
// duration, otherwise the number of whole seconds as a decimal string.
func formatPersistentKeepalive(persistentKeepalive time.Duration) string {
	if persistentKeepalive == 0 {
		return "off"
	}
	// Integer division drops any sub-second remainder.
	seconds := int64(persistentKeepalive / time.Second)
	return strconv.FormatInt(seconds, 10)
}
|