diff --git a/filter/acls_parser.go b/filter/acls_parser.go index b01e641..1de6002 100644 --- a/filter/acls_parser.go +++ b/filter/acls_parser.go @@ -1,14 +1,20 @@ package filter import ( + "fmt" + "github.com/inverse-inc/wireguard-go/device" "strconv" "strings" ) -func AclsToRules(acls ...string) Rules { +func AclsToRules(logger *device.Logger, acls ...string) Rules { rules := []RuleFunc{} for _, acl := range acls { - rule := AclToRule(acl) + rule, err := AclToRule(acl) + if err != nil { + logger.Error.Println(err) + } + if rule != nil { rules = append(rules, rule) } @@ -17,19 +23,17 @@ func AclsToRules(acls ...string) Rules { return rules } -func AclsToRulesFilter(acls []string, pre, post RuleFunc) func([]byte) error { +func AclsToRulesFilter(logger *device.Logger, acls []string, pre, post RuleFunc) func([]byte) error { rules := Rules([]RuleFunc{}) if pre != nil { rules = append(rules, pre) } - aclRules := AclsToRules(acls...) + aclRules := AclsToRules(logger, acls...) rules = append(rules, aclRules...) if post != nil { rules = append(rules, post) - } - - if len(aclRules) == 0 { + } else if len(aclRules) == 0 { rules = append(rules, RulePermit) } @@ -104,14 +108,14 @@ func (r *Ipv4RuleData) AnyProto() bool { return r.protocol == allProtocols } -func AclToRule(acl string) RuleFunc { +func AclToRule(acl string) (RuleFunc, error) { tokens := strings.Fields(acl) if len(tokens) < 2 { - return nil + return nil, fmt.Errorf("Acl to short: '%s'", acl) } if len(tokens) == 2 { - return twoPartAcl(tokens) + return twoPartAcl(acl, tokens) } rule := NewIpv4RuleData() @@ -124,22 +128,22 @@ func AclToRule(acl string) RuleFunc { case "deny": rule.cmd = Deny default: - return nil + return nil, invalidAcl(acl) } rule.src.Network, rule.src.Mask, tokens, ok = getSource(tokens) if ok { - return singleIpRule(rule.src, rule.cmd, SRC_IP_OFFSET) + return singleIpRule(rule.src, rule.cmd, SRC_IP_OFFSET), nil } rule.protocol, tokens, ok = getProtocol(tokens) if !ok { - return nil + return nil, invalidAcl(acl) } rule.src.Network, rule.src.Mask, tokens, ok = getSource(tokens) if !ok { - return nil + return nil, invalidAcl(acl) } if rule.protocol == 6 || rule.protocol == 17 { @@ -148,7 +152,7 @@ func AclToRule(acl string) RuleFunc { rule.dst.Network, rule.dst.Mask, tokens, ok = getSource(tokens) if !ok { - return nil + return nil, invalidAcl(acl) } switch rule.protocol { @@ -161,47 +165,47 @@ func AclToRule(acl string) RuleFunc { if rule.AnyIp() { if rule.AnyProto() { if rule.cmd == Permit { - return PermitAllRule() + return PermitAllRule(), nil } - return DenyAllRule() + return DenyAllRule(), nil } if rule.protocol == 1 { - return icmpRuleRule(rule.icmpRule, rule.cmd) + return icmpRuleRule(rule.icmpRule, rule.cmd), nil } if rule.AnyPort() { - return protoRule(byte(rule.protocol), rule.cmd) + return protoRule(byte(rule.protocol), rule.cmd), nil } if rule.AnySrcPort() && !rule.AnyDstPort() { - return portProtoRule(rule.dstPort, byte(rule.protocol), rule.cmd, 2) + return portProtoRule(rule.dstPort, byte(rule.protocol), rule.cmd, 2), nil } if !rule.AnySrcPort() && rule.AnyDstPort() { - return portProtoRule(rule.srcPort, byte(rule.protocol), rule.cmd, 0) + return portProtoRule(rule.srcPort, byte(rule.protocol), rule.cmd, 0), nil } } if rule.AnyProto() { if rule.AnySrcIP() && !rule.AnyDstIP() { - return singleIpRule(rule.dst, rule.cmd, DST_IP_OFFSET) + return singleIpRule(rule.dst, rule.cmd, DST_IP_OFFSET), nil } if !rule.AnySrcIP() && rule.AnyDstIP() { - return singleIpRule(rule.src, rule.cmd, SRC_IP_OFFSET) + return singleIpRule(rule.src, rule.cmd, SRC_IP_OFFSET), nil } - return srcDstRule(rule.src, rule.dst, rule.cmd) + return srcDstRule(rule.src, rule.dst, rule.cmd), nil } if rule.AnyPort() { - return srcDstProtoRule(rule.src, rule.dst, byte(rule.protocol), rule.cmd) + return srcDstProtoRule(rule.src, rule.dst, byte(rule.protocol), rule.cmd), nil } - return srcDstProtoSrcPortDstPort(rule.src, rule.dst, byte(rule.protocol), rule.srcPort, rule.dstPort, rule.cmd) + return srcDstProtoSrcPortDstPort(rule.src, rule.dst, byte(rule.protocol), rule.srcPort, rule.dstPort, rule.cmd), nil } func getProtocol(oTokens []string) (int16, []string, bool) { @@ -461,15 +465,19 @@ func dtoi(s string) (n int, i int, ok bool) { return n, i, true } -func twoPartAcl(tokens []string) RuleFunc { +func invalidAcl(acl string) error { + return fmt.Errorf("Invalid Acl: '%s'", acl) +} + +func twoPartAcl(acl string, tokens []string) (RuleFunc, error) { if tokens[1] == "any" { switch tokens[0] { case "permit": - return RulePermit + return RulePermit, nil case "deny": - return RuleDeny + return RuleDeny, nil } } - return nil + return nil, invalidAcl(acl) } diff --git a/filter/errors.go b/filter/errors.go index 7266d2b..b67b929 100644 --- a/filter/errors.go +++ b/filter/errors.go @@ -1,7 +1,7 @@ package filter import ( - "errors" + "errors" ) var ErrDenyAll = errors.New("Deny All") diff --git a/filter/rule_test.go b/filter/rule_test.go index 1f050ea..174131d 100644 --- a/filter/rule_test.go +++ b/filter/rule_test.go @@ -1,9 +1,15 @@ package filter import ( + "github.com/inverse-inc/wireguard-go/device" "testing" ) +var logger = device.NewLogger( + device.LogLevelSilent, + "(Testing)", +) + var rulePackets = [][]byte{ []byte{69, 0, 0, 60, 203, 131, 64, 0, 64, 6, 99, 226, 192, 168, 69, 3, 192, 168, 68, 2, 153, 222, 17, 92, 31, 232, 147, 213, 0, 0, 0, 0, 160, 2, 107, 208, 25, 66, 0, 0, 2, 4, 5, 100, 4, 2, 8, 10, 0, 153, 88, 86, 0, 0, 0, 0, 1, 3, 3, 7}, []byte{69, 0, 0, 52, 203, 132, 64, 0, 64, 6, 99, 233, 192, 168, 69, 3, 192, 168, 69, 2, 153, 222, 17, 92, 31, 232, 147, 214, 181, 110, 17, 81, 128, 16, 0, 216, 252, 127, 0, 0, 1, 1, 8, 10, 0, 153, 88, 88, 0, 151, 238, 205}, @@ -128,33 +134,33 @@ func TestPermitDstIpRule(t *testing.T) { } func TestPermitAny(t *testing.T) { - rules := AclsToRules("permit any") + rules := AclsToRules(logger, "permit any") if !rules.PassDefaultDeny(rulePackets[0]) { t.Error("permit any failed") } - rules = AclsToRules("permit 0.0.0.0 255.255.255.255") + rules = AclsToRules(logger, "permit 0.0.0.0 255.255.255.255") if !rules.PassDefaultDeny(rulePackets[0]) { t.Error("permit any failed") } } func TestPermitHost(t *testing.T) { - rules := AclsToRules("permit host 192.168.69.3") + rules := AclsToRules(logger, "permit host 192.168.69.3") if !rules.PassDefaultDeny(rulePackets[0]) { t.Error("permit host 192.169.68.3 failed") } } func TestDenyAny(t *testing.T) { - rules := AclsToRules("deny any") + rules := AclsToRules(logger, "deny any") if rules.PassDefaultPermit(rulePackets[0]) { t.Error("deny any failed") } } func TestPermitSrcDstPortProto(t *testing.T) { - rules := AclsToRules("permit tcp any any eq 80") + rules := AclsToRules(logger, "permit tcp any any eq 80") if rules.PassDefaultDeny(rulePackets[0]) { t.Error("permit tcp any any eq 80 failed") } @@ -168,7 +174,7 @@ func TestSimpleHost(t *testing.T) { } for _, acl := range acls { - rules := AclsToRules(acl) + rules := AclsToRules(logger, acl) if !rules.PassDefaultDeny(rulePackets[0]) { t.Errorf("acl '%s' failed", acl) } @@ -181,7 +187,7 @@ func TestIcmpPermitAny(t *testing.T) { } for _, acl := range acls { - rules := AclsToRules(acl) + rules := AclsToRules(logger, acl) if !rules.PassDefaultDeny(icmpPacket1) { t.Errorf("acl '%s' failed", acl) } @@ -199,7 +205,7 @@ func TestIcmpPermit(t *testing.T) { } for _, acl := range acls { - rules := AclsToRules(acl) + rules := AclsToRules(logger, acl) if !rules.PassDefaultDeny(icmpPacket1) { t.Errorf("acl '%s' failed", acl) } @@ -207,7 +213,7 @@ func TestIcmpPermit(t *testing.T) { } for _, acl := range acls { - rules := AclsToRules(acl) + rules := AclsToRules(logger, acl) if rules.PassDefaultDeny(icmpPacket2) { t.Errorf("acl '%s' passed", acl) } @@ -222,7 +228,7 @@ func TestIcmpDeny(t *testing.T) { } for _, acl := range acls { - rules := AclsToRules(acl) + rules := AclsToRules(logger, acl) if rules.PassDefaultDeny(icmpPacket1) { t.Errorf("acl '%s' passed", acl) } @@ -230,7 +236,7 @@ func TestIcmpDeny(t *testing.T) { } for _, acl := range acls { - rules := AclsToRules(acl) + rules := AclsToRules(logger, acl) if !rules.PassDefaultPermit(icmpPacket2) { t.Errorf("acl '%s' deny", acl) } diff --git a/main_shared.go b/main_shared.go index 93807bc..7235a4c 100644 --- a/main_shared.go +++ b/main_shared.go @@ -155,7 +155,7 @@ func startInverse(interfaceName string, device *device.Device) { preFilter = filter.BuildRBACFilter(ztn.APIClientCtx, ztn.APIClient, logger) } - filter := filter.AclsToRulesFilter(profile.ACLs, preFilter, nil) + filter := filter.AclsToRulesFilter(logger, profile.ACLs, preFilter, nil) device.SetReceiveFilter(filter) for _, peerID := range profile.AllowedPeers {