diff --git a/client/server/debug_linux.go b/client/server/debug_linux.go index 740a8fe97f1..60bc4056167 100644 --- a/client/server/debug_linux.go +++ b/client/server/debug_linux.go @@ -156,8 +156,6 @@ func collectNFTablesFromCommand() (string, error) { // collectNFTablesFromNetlink collects rules using netlink library func collectNFTablesFromNetlink() (string, error) { - var builder strings.Builder - conn, err := nftables.New() if err != nil { return "", fmt.Errorf("create nftables connection: %w", err) @@ -168,69 +166,29 @@ func collectNFTablesFromNetlink() (string, error) { return "", fmt.Errorf("list tables: %w", err) } - // Sort tables by family for consistent output - sort.Slice(tables, func(i, j int) bool { - if tables[i].Family != tables[j].Family { - return tables[i].Family < tables[j].Family - } - return tables[i].Name < tables[j].Name - }) + sortTables(tables) + return formatTables(conn, tables), nil +} + +func formatTables(conn *nftables.Conn, tables []*nftables.Table) string { + var builder strings.Builder for _, table := range tables { builder.WriteString(fmt.Sprintf("table %s %s {\n", formatFamily(table.Family), table.Name)) - chains, err := conn.ListChains() + chains, err := getAndSortTableChains(conn, table) if err != nil { log.Warnf("Failed to list chains for table %s: %v", table.Name, err) continue } - // Filter and sort chains for this table - var tableChains []*nftables.Chain + // Format chains for _, chain := range chains { - if chain.Table.Name == table.Name && chain.Table.Family == table.Family { - tableChains = append(tableChains, chain) - } + formatChain(conn, table, chain, &builder) } - sort.Slice(tableChains, func(i, j int) bool { - return tableChains[i].Name < tableChains[j].Name - }) - - for _, chain := range tableChains { - builder.WriteString(fmt.Sprintf("\tchain %s {\n", chain.Name)) - if chain.Type != "" { - var policy string - if chain.Policy != nil { - policy = fmt.Sprintf("; policy %s", formatPolicy(*chain.Policy)) - } - builder.WriteString(fmt.Sprintf("\t\ttype %s hook %s priority %d%s\n", - formatChainType(chain.Type), - formatChainHook(chain.Hooknum), - chain.Priority, - policy)) - } - rules, err := conn.GetRules(table, chain) - if err != nil { - log.Warnf("Failed to get rules for chain %s: %v", chain.Name, err) - continue - } - - // Sort rules by position for consistent output - sort.Slice(rules, func(i, j int) bool { - return rules[i].Position < rules[j].Position - }) - - for _, rule := range rules { - builder.WriteString(formatRule(rule)) - } - - builder.WriteString("\t}\n") - } - - // Add sets if any exist - sets, err := conn.GetSets(table) - if err != nil { + // Format sets + if sets, err := conn.GetSets(table); err != nil { log.Warnf("Failed to get sets for table %s: %v", table.Name, err) } else if len(sets) > 0 { builder.WriteString("\n") @@ -242,7 +200,66 @@ func collectNFTablesFromNetlink() (string, error) { builder.WriteString("}\n") } - return builder.String(), nil + return builder.String() +} + +func getAndSortTableChains(conn *nftables.Conn, table *nftables.Table) ([]*nftables.Chain, error) { + chains, err := conn.ListChains() + if err != nil { + return nil, err + } + + var tableChains []*nftables.Chain + for _, chain := range chains { + if chain.Table.Name == table.Name && chain.Table.Family == table.Family { + tableChains = append(tableChains, chain) + } + } + + sort.Slice(tableChains, func(i, j int) bool { + return tableChains[i].Name < tableChains[j].Name + }) + + return tableChains, nil +} + +func formatChain(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, builder *strings.Builder) { + builder.WriteString(fmt.Sprintf("\tchain %s {\n", chain.Name)) + + if chain.Type != "" { + var policy string + if chain.Policy != nil { + policy = fmt.Sprintf("; policy %s", formatPolicy(*chain.Policy)) + } + builder.WriteString(fmt.Sprintf("\t\ttype %s hook %s priority %d%s\n", + formatChainType(chain.Type), + formatChainHook(chain.Hooknum), + chain.Priority, + policy)) + } + + rules, err := conn.GetRules(table, chain) + if err != nil { + log.Warnf("Failed to get rules for chain %s: %v", chain.Name, err) + } else { + sort.Slice(rules, func(i, j int) bool { + return rules[i].Position < rules[j].Position + }) + for _, rule := range rules { + builder.WriteString(formatRule(rule)) + } + } + + builder.WriteString("\t}\n") +} + +func sortTables(tables []*nftables.Table) { + sort.Slice(tables, func(i, j int) bool { + if tables[i].Family != tables[j].Family { + return tables[i].Family < tables[j].Family + } + return tables[i].Name < tables[j].Name + }) } func formatFamily(family nftables.TableFamily) string { @@ -308,57 +325,61 @@ func formatPolicy(policy nftables.ChainPolicy) string { } } -// formatRule formats a rule in nft-like syntax func formatRule(rule *nftables.Rule) string { var builder strings.Builder builder.WriteString("\t\t") - // Process expressions in sequence for i := 0; i < len(rule.Exprs); i++ { if i > 0 { builder.WriteString(" ") } + i = formatExprSequence(&builder, rule.Exprs, i) + } - exp := rule.Exprs[i] - - // Look ahead for special sequences - if meta, ok := exp.(*expr.Meta); ok && i+1 < len(rule.Exprs) { - if cmp, ok := rule.Exprs[i+1].(*expr.Cmp); ok { - // Meta + Cmp sequence - switch meta.Key { - case expr.MetaKeyIIFNAME: - name := strings.TrimRight(string(cmp.Data), "\x00") - builder.WriteString(fmt.Sprintf("iifname %s %q", formatCmpOp(cmp.Op), name)) - case expr.MetaKeyOIFNAME: - name := strings.TrimRight(string(cmp.Data), "\x00") - builder.WriteString(fmt.Sprintf("oifname %s %q", formatCmpOp(cmp.Op), name)) - case expr.MetaKeyMARK: - if len(cmp.Data) == 4 { - val := binary.BigEndian.Uint32(cmp.Data) - builder.WriteString(fmt.Sprintf("meta mark %s 0x%x", formatCmpOp(cmp.Op), val)) - } - default: - builder.WriteString(formatExpr(exp)) - } - i++ // Skip the next expression since we handled it - continue - } - } + builder.WriteString("\n") + return builder.String() +} + +func formatExprSequence(builder *strings.Builder, exprs []expr.Any, i int) int { + curr := exprs[i] - // Look ahead for Payload + Cmp sequences - if payload, ok := exp.(*expr.Payload); ok && i+1 < len(rule.Exprs) { - if cmp, ok := rule.Exprs[i+1].(*expr.Cmp); ok { - builder.WriteString(formatPayloadWithCmp(payload, cmp)) - i++ // Skip the next expression - continue + // Handle Meta + Cmp sequence + if meta, ok := curr.(*expr.Meta); ok && i+1 < len(exprs) { + if cmp, ok := exprs[i+1].(*expr.Cmp); ok { + if formatted := formatMetaWithCmp(meta, cmp); formatted != "" { + builder.WriteString(formatted) + return i + 1 } } + } - builder.WriteString(formatExpr(exp)) + // Handle Payload + Cmp sequence + if payload, ok := curr.(*expr.Payload); ok && i+1 < len(exprs) { + if cmp, ok := exprs[i+1].(*expr.Cmp); ok { + builder.WriteString(formatPayloadWithCmp(payload, cmp)) + return i + 1 + } } - builder.WriteString("\n") - return builder.String() + builder.WriteString(formatExpr(curr)) + return i +} + +func formatMetaWithCmp(meta *expr.Meta, cmp *expr.Cmp) string { + switch meta.Key { + case expr.MetaKeyIIFNAME: + name := strings.TrimRight(string(cmp.Data), "\x00") + return fmt.Sprintf("iifname %s %q", formatCmpOp(cmp.Op), name) + case expr.MetaKeyOIFNAME: + name := strings.TrimRight(string(cmp.Data), "\x00") + return fmt.Sprintf("oifname %s %q", formatCmpOp(cmp.Op), name) + case expr.MetaKeyMARK: + if len(cmp.Data) == 4 { + val := binary.BigEndian.Uint32(cmp.Data) + return fmt.Sprintf("meta mark %s 0x%x", formatCmpOp(cmp.Op), val) + } + } + return "" } func formatPayloadWithCmp(p *expr.Payload, cmp *expr.Cmp) string {