diff --git a/lib/wgconfig.go b/lib/wgconfig.go index bef2b9f..3c73179 100644 --- a/lib/wgconfig.go +++ b/lib/wgconfig.go @@ -3,6 +3,7 @@ package lib import ( "bufio" + "bytes" "fmt" "io" "net" @@ -34,15 +35,31 @@ var ( ErrUnknownSection = fmt.Errorf("unknown section") ErrUnknownKey = fmt.Errorf("unknown key") - ErrValueParse = fmt.Errorf("value parse failed") + ErrValueParse = fmt.Errorf("value parse failed") + ErrPersistentKeepaliveRange = fmt.Errorf("persistent keepalive interval is neither 0/off nor 1-65535") ) -// ReadConfig is yet another INI-like configuration file parser, but for WireGuard -func ReadConfig(r io.Reader) (wgtypes.Config, error) { +type EndpointMap map[string]string + +func (e EndpointMap) insert(udpAddr net.UDPAddr, endpoint string) { + e[udpAddr.String()] = endpoint +} + +func (e EndpointMap) revert(udpAddr net.UDPAddr) string { + initial, ok := e[udpAddr.String()] + if !ok { + return udpAddr.String() + } + return initial +} + +// ReadConfig is yet another INI-like configuration file parser, but for WireGuard Config +func ReadConfig(r io.Reader) (wgtypes.Config, EndpointMap, error) { scanner := bufio.NewScanner(r) config := wgtypes.Config{ReplacePeers: true} section := SectionNone + endpointMap := make(EndpointMap) for scanner.Scan() { text := scanner.Text() @@ -59,7 +76,7 @@ func ReadConfig(r io.Reader) (wgtypes.Config, error) { ReplaceAllowedIPs: true, }) case len(s) > 0: - return config, fmt.Errorf("%w: %v", ErrUnknownSection, s) + return config, endpointMap, fmt.Errorf("%w: %v", ErrUnknownSection, s) } if len(k) == 0 { @@ -71,30 +88,25 @@ func ReadConfig(r io.Reader) (wgtypes.Config, error) { case SectionDevice: switch { case insensetiveMatch(k, "ListenPort"): - listenPort, err := strconv.ParseInt(v, 0, 0) + port, err := parsePort(v) if err != nil { - return config, fmt.Errorf("%w: %w: %v=%v", ErrValueParse, err, k, v) + return config, endpointMap, fmt.Errorf("%w: %w: %v=%v", ErrValueParse, err, k, v) } - listenPortInt := int(listenPort) - config.ListenPort = &listenPortInt + config.ListenPort = &port case insensetiveMatch(k, "FwMark"): - fwMarkInt := 0 - if !insensetiveMatch(v, "off") { - fwMark, err := strconv.ParseInt(v, 0, 0) - if err != nil { - return config, fmt.Errorf("%w: %w: %v=%v", ErrValueParse, err, k, v) - } - fwMarkInt = int(fwMark) + fwMark, err := parseFwMark(v) + if err != nil { + return config, endpointMap, fmt.Errorf("%w: %w: %v=%v", ErrValueParse, err, k, v) } - config.FirewallMark = &fwMarkInt + config.FirewallMark = &fwMark case insensetiveMatch(k, "PrivateKey"): key, err := wgtypes.ParseKey(v) if err != nil { - return config, fmt.Errorf("%w: %w: %v=%v", ErrValueParse, err, k, v) + return config, endpointMap, fmt.Errorf("%w: %w: %v=%v", ErrValueParse, err, k, v) } config.PrivateKey = &key default: - return config, fmt.Errorf("%w: %v: %v", ErrUnknownKey, sectionNames[section], k) + return config, endpointMap, fmt.Errorf("%w: %v: %v", ErrUnknownKey, sectionNames[section], k) } case SectionPeer: peer := &config.Peers[len(config.Peers)-1] @@ -102,44 +114,72 @@ func ReadConfig(r io.Reader) (wgtypes.Config, error) { case insensetiveMatch(k, "Endpoint"): endpoint, err := net.ResolveUDPAddr("udp", v) if err != nil { - return config, fmt.Errorf("%w: %w: %v=%v", ErrValueParse, err, k, v) + return config, endpointMap, fmt.Errorf("%w: %w: %v=%v", ErrValueParse, err, k, v) } + endpointMap.insert(*endpoint, v) peer.Endpoint = endpoint case insensetiveMatch(k, "PublicKey"): key, err := wgtypes.ParseKey(v) if err != nil { - return config, fmt.Errorf("%w: %w: %v=%v", ErrValueParse, err, k, v) + return config, endpointMap, fmt.Errorf("%w: %w: %v=%v", ErrValueParse, err, k, v) } peer.PublicKey = key case insensetiveMatch(k, "AllowedIPs"): allowedIPs, err := parseAllowedIPs(v) if err != nil { - return config, fmt.Errorf("%w: %w: %v=%v", ErrValueParse, err, k, v) + return config, endpointMap, fmt.Errorf("%w: %w: %v=%v", ErrValueParse, err, k, v) } peer.AllowedIPs = allowedIPs case insensetiveMatch(k, "PersistentKeepalive"): - persistentKeepalive := int64(0) - var err error - if !insensetiveMatch(v, "off") { - persistentKeepalive, err = strconv.ParseInt(v, 0, 64) - if err != nil { - return config, fmt.Errorf("%w: %w: %v=%v", ErrValueParse, err, k, v) - } + persistentKeepalive, err := parsePersistentKeepalive(v) + if err != nil { + return config, endpointMap, fmt.Errorf("%w: %w: %v=%v", ErrValueParse, err, k, v) } - if persistentKeepalive < 0 || persistentKeepalive > 65535 { - return config, fmt.Errorf("%w: Persistent keepalive interval is neither 0/off nor 1-65535: %v=%v", ErrValueParse, k, v) - } - persistentKeepaliveDuration := time.Duration(persistentKeepalive * int64(time.Second)) - peer.PersistentKeepaliveInterval = &persistentKeepaliveDuration + peer.PersistentKeepaliveInterval = &persistentKeepalive case insensetiveMatch(k, "PresharedKey"): - + key, err := wgtypes.ParseKey(v) + if err != nil { + return config, endpointMap, fmt.Errorf("%w: %w: %v=%v", ErrValueParse, err, k, v) + } + peer.PresharedKey = &key default: - return config, fmt.Errorf("%w: %v: %v", ErrUnknownKey, sectionNames[section], k) + return config, endpointMap, fmt.Errorf("%w: %v: %v", ErrUnknownKey, sectionNames[section], k) } } } - return config, nil + return config, endpointMap, nil +} + +// WriteConfig writes out WireGuard Device configuration into a buffer +func WriteConfig(w io.Writer, config wgtypes.Device, endpointMap EndpointMap) { + var emptyKey [wgtypes.KeyLen]byte + + writeConfigLine(w, formatSection("Interface")) + + writeConfigLine(w, formatLineKeyValue("PrivateKey", config.PrivateKey.String())) + if config.ListenPort > 0 { + writeConfigLine(w, formatLineKeyValue("ListenPort", formatPort(config.ListenPort))) + } + if config.FirewallMark > 0 { + writeConfigLine(w, formatLineKeyValue("FwMark", formatFwMark(config.FirewallMark))) + } + + for _, peer := range config.Peers { + writeConfigLine(w, "") + writeConfigLine(w, formatSection("Peer")) + writeConfigLine(w, formatLineKeyValue("PublicKey", peer.PublicKey.String())) + if !bytes.Equal(peer.PresharedKey[:], emptyKey[:]) { + writeConfigLine(w, formatLineKeyValue("PresharedKey", peer.PresharedKey.String())) + } + writeConfigLine(w, formatLineKeyValue("AllowedIPs", formatAllowedIPs(peer.AllowedIPs))) + if peer.PersistentKeepaliveInterval > 0 { + writeConfigLine(w, formatLineKeyValue("PersistentKeepalive", formatPersistentKeepalive(peer.PersistentKeepaliveInterval))) + } + if peer.Endpoint != nil { + writeConfigLine(w, formatLineKeyValue("Endpoint", endpointMap.revert(*peer.Endpoint))) + } + } } func readConfigLine(text string) (line, comments string) { @@ -156,6 +196,9 @@ func readConfigLine(text string) (line, comments string) { comments = strings.TrimSpace(comments) return } +func writeConfigLine(w io.Writer, line string) (int, error) { + return io.WriteString(w, line+"\n") +} func parseLine(line string) (section, key, value string) { if len(line) < 1 { @@ -173,21 +216,77 @@ func parseLine(line string) (section, key, value string) { return "", "", "" } +func formatSection(section string) string { + return "[" + section + "]" +} +func formatLineKeyValue(key, value string) string { + return key + " = " + value +} func insensetiveMatch(a string, b string) bool { return strings.ToLower(a) == strings.ToLower(b) } +func parsePort(s string) (int, error) { + port, err := strconv.ParseInt(s, 0, 0) + return int(port), err +} +func formatPort(port int) string { + return strconv.FormatInt(int64(port), 10) +} + +func parseFwMark(s string) (int, error) { + if insensetiveMatch(s, "off") { + return 0, nil + } + fwMark, err := strconv.ParseInt(s, 0, 0) + return int(fwMark), err +} +func formatFwMark(fwMark int) string { + if fwMark == 0 { + return "off" + } + return strconv.FormatInt(int64(fwMark), 10) +} + func parseAllowedIPs(s string) ([]net.IPNet, error) { - parsedIPs := make([]net.IPNet, 0) stringIPs := strings.Split(s, ",") - for _, stringIP := range stringIPs { + parsedIPs := make([]net.IPNet, len(stringIPs)) + for i, stringIP := range stringIPs { stringIP := strings.TrimSpace(stringIP) _, parsedIP, err := net.ParseCIDR(stringIP) if err != nil { return parsedIPs, err } - parsedIPs = append(parsedIPs, *parsedIP) + parsedIPs[i] = *parsedIP } return parsedIPs, nil } +func formatAllowedIPs(allowedIPs []net.IPNet) string { + stringIPs := make([]string, len(allowedIPs)) + for i, allowedIP := range allowedIPs { + stringIPs[i] = allowedIP.String() + } + return strings.Join(stringIPs, ", ") +} + +func parsePersistentKeepalive(s string) (time.Duration, error) { + if insensetiveMatch(s, "off") { + return time.Duration(0), nil + } + persistentKeepalive, err := strconv.ParseInt(s, 0, 64) + if err != nil { + return time.Duration(0), err + } + if persistentKeepalive < 0 || persistentKeepalive > 65535 { + return time.Duration(0), ErrPersistentKeepaliveRange + } + + return time.Duration(persistentKeepalive * int64(time.Second)), err +} +func formatPersistentKeepalive(persistentKeepalive time.Duration) string { + if int64(persistentKeepalive) == 0 { + return "off" + } + return strconv.FormatInt(int64(persistentKeepalive/time.Second), 10) +} diff --git a/lib/wgconfig_test.go b/lib/wgconfig_test.go index a997d66..1dc2d26 100644 --- a/lib/wgconfig_test.go +++ b/lib/wgconfig_test.go @@ -1,7 +1,9 @@ package lib import ( + "errors" "net" + "strconv" "strings" "testing" "time" @@ -19,19 +21,21 @@ ListenPort = 3333 PrivateKey = MITUgapB4QfRFF54ITXL3TaiYiSsVYkchqfjAXjxM10= [Peer] PublicKey = pjFx72IjbMh84SH1nq8Qfbl7HD5mSScHXCV1eISR7lk= -AllowedIPs = 192.168.10.2/32, 2001:470:ed5d:a::2/128 +AllowedIPs = 192.168.10.2/32, 2001:470:ed5d:a::2/128 PersistentKeepalive = 80 [Peer] -AllowedIPs = 192.168.10.40/32, 2001:470:ed5d:a::28/128 +AllowedIPs = 192.168.10.40/32 , 2001:470:ed5d:a::28/128 PublicKey = wXU+vSTdEoIwSi+Tmv35SCOFg17wCAwnmYxeQPpbzDg= +PresharedKey = wXU+vSTdEoIwSi+Tmv35SCOFg17wCAwnmYxeQPpbzDg= +Endpoint = example.com:4444 ` func TestReadConfig1(t *testing.T) { buf := strings.NewReader(testGoodConfig1) - got, err := ReadConfig(buf) + got, gotEndpointMap, err := ReadConfig(buf) if err != nil { - t.Fatalf("config read failed: %w", err) + t.Fatalf("config read failed: %v", err) } wantPrivateKey, _ := wgtypes.ParseKey("MITUgapB4QfRFF54ITXL3TaiYiSsVYkchqfjAXjxM10=") @@ -43,6 +47,8 @@ func TestReadConfig1(t *testing.T) { wantPeer2PublicKey, _ := wgtypes.ParseKey("wXU+vSTdEoIwSi+Tmv35SCOFg17wCAwnmYxeQPpbzDg=") _, wantPeer2AllowedIP1, _ := net.ParseCIDR("192.168.10.40/32") _, wantPeer2AllowedIP2, _ := net.ParseCIDR("2001:470:ed5d:a::28/128") + wantPeer2PresharedKey, _ := wgtypes.ParseKey("wXU+vSTdEoIwSi+Tmv35SCOFg17wCAwnmYxeQPpbzDg=") + wantPeer2Endpoint, _ := net.ResolveUDPAddr("udp", "example.com:4444") want := wgtypes.Config{ PrivateKey: &wantPrivateKey, @@ -65,6 +71,56 @@ func TestReadConfig1(t *testing.T) { *wantPeer2AllowedIP1, *wantPeer2AllowedIP2, }, + Endpoint: wantPeer2Endpoint, + PresharedKey: &wantPeer2PresharedKey, + }, + }, + } + + wantEndpointMap := EndpointMap{} + wantEndpointMap.insert(*wantPeer2Endpoint, "example.com:4444") + + if diff := cmp.Diff(want, got); diff != "" { + t.Fatalf("returned config is not what is wanted: \n%s", diff) + } + if diff := cmp.Diff(wantEndpointMap, gotEndpointMap); diff != "" { + t.Fatalf("returned endpointMap is not what is wanted: \n%s", diff) + } +} + +const testGoodConfig2 = ` +[Interface] +PrivateKey = MITUgapB4QfRFF54ITXL3TaiYiSsVYkchqfjAXjxM10= + +[Peer] +PublicKey = pjFx72IjbMh84SH1nq8Qfbl7HD5mSScHXCV1eISR7lk= + +[Peer] +PublicKey = wXU+vSTdEoIwSi+Tmv35SCOFg17wCAwnmYxeQPpbzDg= +` + +func TestReadConfig2(t *testing.T) { + buf := strings.NewReader(testGoodConfig2) + got, _, err := ReadConfig(buf) + if err != nil { + t.Fatalf("config read failed: %v", err) + } + + wantPrivateKey, _ := wgtypes.ParseKey("MITUgapB4QfRFF54ITXL3TaiYiSsVYkchqfjAXjxM10=") + wantPeer1PublicKey, _ := wgtypes.ParseKey("pjFx72IjbMh84SH1nq8Qfbl7HD5mSScHXCV1eISR7lk=") + wantPeer2PublicKey, _ := wgtypes.ParseKey("wXU+vSTdEoIwSi+Tmv35SCOFg17wCAwnmYxeQPpbzDg=") + + want := wgtypes.Config{ + PrivateKey: &wantPrivateKey, + ReplacePeers: true, + Peers: []wgtypes.PeerConfig{ + wgtypes.PeerConfig{ + PublicKey: wantPeer1PublicKey, + ReplaceAllowedIPs: true, + }, + wgtypes.PeerConfig{ + PublicKey: wantPeer2PublicKey, + ReplaceAllowedIPs: true, }, }, } @@ -73,3 +129,117 @@ func TestReadConfig1(t *testing.T) { t.Fatalf("returned config is not what is wanted: \n%s", diff) } } + +const testWantConfig1 = `[Interface] +PrivateKey = MITUgapB4QfRFF54ITXL3TaiYiSsVYkchqfjAXjxM10= +ListenPort = 3333 + +[Peer] +PublicKey = pjFx72IjbMh84SH1nq8Qfbl7HD5mSScHXCV1eISR7lk= +AllowedIPs = 192.168.10.2/32, 2001:470:ed5d:a::2/128 +PersistentKeepalive = 80 +Endpoint = example.com:4444 + +[Peer] +PublicKey = wXU+vSTdEoIwSi+Tmv35SCOFg17wCAwnmYxeQPpbzDg= +PresharedKey = wXU+vSTdEoIwSi+Tmv35SCOFg17wCAwnmYxeQPpbzDg= +AllowedIPs = 192.168.10.40/32, 2001:470:ed5d:a::28/128 +` + +func TestWriteConfig1(t *testing.T) { + var buf strings.Builder + + wantPrivateKey, _ := wgtypes.ParseKey("MITUgapB4QfRFF54ITXL3TaiYiSsVYkchqfjAXjxM10=") + wantListenPort := 3333 + wantPeer1PublicKey, _ := wgtypes.ParseKey("pjFx72IjbMh84SH1nq8Qfbl7HD5mSScHXCV1eISR7lk=") + _, wantPeer1AllowedIP1, _ := net.ParseCIDR("192.168.10.2/32") + _, wantPeer1AllowedIP2, _ := net.ParseCIDR("2001:470:ed5d:a::2/128") + wantPeer1PersistentKeepalive, _ := time.ParseDuration("80s") + wantPeer2PublicKey, _ := wgtypes.ParseKey("wXU+vSTdEoIwSi+Tmv35SCOFg17wCAwnmYxeQPpbzDg=") + _, wantPeer2AllowedIP1, _ := net.ParseCIDR("192.168.10.40/32") + _, wantPeer2AllowedIP2, _ := net.ParseCIDR("2001:470:ed5d:a::28/128") + wantPeer2PresharedKey, _ := wgtypes.ParseKey("wXU+vSTdEoIwSi+Tmv35SCOFg17wCAwnmYxeQPpbzDg=") + wantPeer2Endpoint, _ := net.ResolveUDPAddr("udp", "example.com:4444") + + config := wgtypes.Device{ + PrivateKey: wantPrivateKey, + ListenPort: wantListenPort, + Peers: []wgtypes.Peer{ + wgtypes.Peer{ + PublicKey: wantPeer1PublicKey, + AllowedIPs: []net.IPNet{ + *wantPeer1AllowedIP1, + *wantPeer1AllowedIP2, + }, + Endpoint: wantPeer2Endpoint, + PersistentKeepaliveInterval: wantPeer1PersistentKeepalive, + }, + wgtypes.Peer{ + PublicKey: wantPeer2PublicKey, + AllowedIPs: []net.IPNet{ + *wantPeer2AllowedIP1, + *wantPeer2AllowedIP2, + }, + PresharedKey: wantPeer2PresharedKey, + }, + }, + } + + endpointMap := EndpointMap{} + endpointMap.insert(*wantPeer2Endpoint, "example.com:4444") + + WriteConfig(&buf, config, endpointMap) + + if diff := cmp.Diff(testWantConfig1, buf.String()); diff != "" { + t.Fatalf("returned config is not what is wanted: \n%s", diff) + } +} + +func TestPersistentKeepalive(t *testing.T) { + parseWant, _ := time.ParseDuration("10s") + parseGot, err := parsePersistentKeepalive("10") + if err != nil { + t.Fatalf("parsed error %v, want %v", err, parseWant) + } + if parseWant != parseGot { + t.Fatalf("parsed %v, want %v", parseGot, parseWant) + } + + parseWant, _ = time.ParseDuration("0s") + parseGot, err = parsePersistentKeepalive("off") + if err != nil { + t.Fatalf("parsed error %v, want %v", err, parseWant) + } + if parseWant != parseGot { + t.Fatalf("parsed %v, want %v", parseGot, parseWant) + } + + var reasonErr *strconv.NumError + + _, err = parsePersistentKeepalive("10e") + if !errors.As(err, &reasonErr) || reasonErr.Err != strconv.ErrSyntax { + t.Fatalf("parsed error %v, want error %v", reasonErr.Err, strconv.ErrSyntax) + } + + _, err = parsePersistentKeepalive("1000000s") + if !errors.As(err, &reasonErr) || reasonErr.Err != strconv.ErrSyntax { + t.Fatalf("parsed error %v, want error %v", reasonErr.Err, strconv.ErrSyntax) + } + + _, err = parsePersistentKeepalive("1000000") + if !errors.Is(err, ErrPersistentKeepaliveRange) { + t.Fatalf("parsed error %v, want error %v", err, ErrPersistentKeepaliveRange) + } + + formatWant := "off" + formatGot := formatPersistentKeepalive(time.Duration(0)) + if formatWant != formatGot { + t.Fatalf("format %v, want %v", formatGot, formatWant) + } + + formatWant = "11" + formatGot = formatPersistentKeepalive(time.Duration(11 * int64(time.Second))) + if formatWant != formatGot { + t.Fatalf("format %v, want %v", formatGot, formatWant) + } +}