Skip to content

Commit

Permalink
Reduce complexity
Browse files Browse the repository at this point in the history
  • Loading branch information
lixmal committed Dec 20, 2024
1 parent 35cf423 commit 10a4793
Showing 1 changed file with 111 additions and 90 deletions.
201 changes: 111 additions & 90 deletions client/server/debug_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,6 @@ func collectNFTablesFromCommand() (string, error) {

// collectNFTablesFromNetlink collects rules using netlink library
func collectNFTablesFromNetlink() (string, error) {
var builder strings.Builder

conn, err := nftables.New()
if err != nil {
return "", fmt.Errorf("create nftables connection: %w", err)
Expand All @@ -168,69 +166,29 @@ func collectNFTablesFromNetlink() (string, error) {
return "", fmt.Errorf("list tables: %w", err)
}

// Sort tables by family for consistent output
sort.Slice(tables, func(i, j int) bool {
if tables[i].Family != tables[j].Family {
return tables[i].Family < tables[j].Family
}
return tables[i].Name < tables[j].Name
})
sortTables(tables)
return formatTables(conn, tables), nil
}

func formatTables(conn *nftables.Conn, tables []*nftables.Table) string {
var builder strings.Builder

for _, table := range tables {
builder.WriteString(fmt.Sprintf("table %s %s {\n", formatFamily(table.Family), table.Name))

chains, err := conn.ListChains()
chains, err := getAndSortTableChains(conn, table)
if err != nil {
log.Warnf("Failed to list chains for table %s: %v", table.Name, err)
continue
}

// Filter and sort chains for this table
var tableChains []*nftables.Chain
// Format chains
for _, chain := range chains {
if chain.Table.Name == table.Name && chain.Table.Family == table.Family {
tableChains = append(tableChains, chain)
}
formatChain(conn, table, chain, &builder)
}
sort.Slice(tableChains, func(i, j int) bool {
return tableChains[i].Name < tableChains[j].Name
})

for _, chain := range tableChains {
builder.WriteString(fmt.Sprintf("\tchain %s {\n", chain.Name))
if chain.Type != "" {
var policy string
if chain.Policy != nil {
policy = fmt.Sprintf("; policy %s", formatPolicy(*chain.Policy))
}
builder.WriteString(fmt.Sprintf("\t\ttype %s hook %s priority %d%s\n",
formatChainType(chain.Type),
formatChainHook(chain.Hooknum),
chain.Priority,
policy))
}

rules, err := conn.GetRules(table, chain)
if err != nil {
log.Warnf("Failed to get rules for chain %s: %v", chain.Name, err)
continue
}

// Sort rules by position for consistent output
sort.Slice(rules, func(i, j int) bool {
return rules[i].Position < rules[j].Position
})

for _, rule := range rules {
builder.WriteString(formatRule(rule))
}

builder.WriteString("\t}\n")
}

// Add sets if any exist
sets, err := conn.GetSets(table)
if err != nil {
// Format sets
if sets, err := conn.GetSets(table); err != nil {
log.Warnf("Failed to get sets for table %s: %v", table.Name, err)
} else if len(sets) > 0 {
builder.WriteString("\n")
Expand All @@ -242,7 +200,66 @@ func collectNFTablesFromNetlink() (string, error) {
builder.WriteString("}\n")
}

return builder.String(), nil
return builder.String()
}

func getAndSortTableChains(conn *nftables.Conn, table *nftables.Table) ([]*nftables.Chain, error) {
chains, err := conn.ListChains()
if err != nil {
return nil, err
}

var tableChains []*nftables.Chain
for _, chain := range chains {
if chain.Table.Name == table.Name && chain.Table.Family == table.Family {
tableChains = append(tableChains, chain)
}
}

sort.Slice(tableChains, func(i, j int) bool {
return tableChains[i].Name < tableChains[j].Name
})

return tableChains, nil
}

func formatChain(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, builder *strings.Builder) {
builder.WriteString(fmt.Sprintf("\tchain %s {\n", chain.Name))

if chain.Type != "" {
var policy string
if chain.Policy != nil {
policy = fmt.Sprintf("; policy %s", formatPolicy(*chain.Policy))
}
builder.WriteString(fmt.Sprintf("\t\ttype %s hook %s priority %d%s\n",
formatChainType(chain.Type),
formatChainHook(chain.Hooknum),
chain.Priority,
policy))
}

rules, err := conn.GetRules(table, chain)
if err != nil {
log.Warnf("Failed to get rules for chain %s: %v", chain.Name, err)
} else {
sort.Slice(rules, func(i, j int) bool {
return rules[i].Position < rules[j].Position
})
for _, rule := range rules {
builder.WriteString(formatRule(rule))
}
}

builder.WriteString("\t}\n")
}

func sortTables(tables []*nftables.Table) {
sort.Slice(tables, func(i, j int) bool {
if tables[i].Family != tables[j].Family {
return tables[i].Family < tables[j].Family
}
return tables[i].Name < tables[j].Name
})
}

func formatFamily(family nftables.TableFamily) string {
Expand Down Expand Up @@ -308,57 +325,61 @@ func formatPolicy(policy nftables.ChainPolicy) string {
}
}

// formatRule formats a rule in nft-like syntax
func formatRule(rule *nftables.Rule) string {
var builder strings.Builder
builder.WriteString("\t\t")

// Process expressions in sequence
for i := 0; i < len(rule.Exprs); i++ {
if i > 0 {
builder.WriteString(" ")
}
i = formatExprSequence(&builder, rule.Exprs, i)
}

exp := rule.Exprs[i]

// Look ahead for special sequences
if meta, ok := exp.(*expr.Meta); ok && i+1 < len(rule.Exprs) {
if cmp, ok := rule.Exprs[i+1].(*expr.Cmp); ok {
// Meta + Cmp sequence
switch meta.Key {
case expr.MetaKeyIIFNAME:
name := strings.TrimRight(string(cmp.Data), "\x00")
builder.WriteString(fmt.Sprintf("iifname %s %q", formatCmpOp(cmp.Op), name))
case expr.MetaKeyOIFNAME:
name := strings.TrimRight(string(cmp.Data), "\x00")
builder.WriteString(fmt.Sprintf("oifname %s %q", formatCmpOp(cmp.Op), name))
case expr.MetaKeyMARK:
if len(cmp.Data) == 4 {
val := binary.BigEndian.Uint32(cmp.Data)
builder.WriteString(fmt.Sprintf("meta mark %s 0x%x", formatCmpOp(cmp.Op), val))
}
default:
builder.WriteString(formatExpr(exp))
}
i++ // Skip the next expression since we handled it
continue
}
}
builder.WriteString("\n")
return builder.String()
}

func formatExprSequence(builder *strings.Builder, exprs []expr.Any, i int) int {
curr := exprs[i]

// Look ahead for Payload + Cmp sequences
if payload, ok := exp.(*expr.Payload); ok && i+1 < len(rule.Exprs) {
if cmp, ok := rule.Exprs[i+1].(*expr.Cmp); ok {
builder.WriteString(formatPayloadWithCmp(payload, cmp))
i++ // Skip the next expression
continue
// Handle Meta + Cmp sequence
if meta, ok := curr.(*expr.Meta); ok && i+1 < len(exprs) {
if cmp, ok := exprs[i+1].(*expr.Cmp); ok {
if formatted := formatMetaWithCmp(meta, cmp); formatted != "" {
builder.WriteString(formatted)
return i + 1
}
}
}

builder.WriteString(formatExpr(exp))
// Handle Payload + Cmp sequence
if payload, ok := curr.(*expr.Payload); ok && i+1 < len(exprs) {
if cmp, ok := exprs[i+1].(*expr.Cmp); ok {
builder.WriteString(formatPayloadWithCmp(payload, cmp))
return i + 1
}
}

builder.WriteString("\n")
return builder.String()
builder.WriteString(formatExpr(curr))
return i
}

func formatMetaWithCmp(meta *expr.Meta, cmp *expr.Cmp) string {
switch meta.Key {
case expr.MetaKeyIIFNAME:
name := strings.TrimRight(string(cmp.Data), "\x00")
return fmt.Sprintf("iifname %s %q", formatCmpOp(cmp.Op), name)
case expr.MetaKeyOIFNAME:
name := strings.TrimRight(string(cmp.Data), "\x00")
return fmt.Sprintf("oifname %s %q", formatCmpOp(cmp.Op), name)
case expr.MetaKeyMARK:
if len(cmp.Data) == 4 {
val := binary.BigEndian.Uint32(cmp.Data)
return fmt.Sprintf("meta mark %s 0x%x", formatCmpOp(cmp.Op), val)
}
}
return ""
}

func formatPayloadWithCmp(p *expr.Payload, cmp *expr.Cmp) string {
Expand Down

0 comments on commit 10a4793

Please sign in to comment.