From 7c425c11f866a7eeb157710422fd1f63184f36c5 Mon Sep 17 00:00:00 2001 From: Ryan Diers <39590744+radsec@users.noreply.github.com> Date: Fri, 29 Mar 2024 21:51:49 -0700 Subject: [PATCH] [Feature] Update to include latest Santa rule types of `TeamID` and `SigningID` support (#44) --- examples/sample-rules.csv | 2 +- examples/sample-rules2.json | 4 +- internal/cli/flags/client_mode.go | 6 + internal/cli/flags/rule_info.go | 20 +- internal/cli/flags/rule_type.go | 24 +- internal/cli/flags/target.go | 2 +- internal/cli/rule/rule-allow.go | 2 +- internal/cli/rule/rule-common.go | 37 ++- internal/cli/rule/rule-remove.go | 54 +++- internal/cli/rule/rule-update.go | 57 +++-- internal/cli/rules/rules-export.go | 10 +- internal/cli/rules/rules-import.go | 25 +- internal/cli/rules/rules.go | 11 +- internal/cli/santa_sensor/santa_fileinfo.go | 2 + internal/handlers/ruledownload/response.go | 10 +- pkg/model/feedrules/create.go | 16 +- pkg/model/feedrules/model.go | 14 ++ pkg/model/feedrules/query.go | 19 +- pkg/model/globalrules/add.go | 34 +-- pkg/model/globalrules/get.go | 17 +- pkg/model/globalrules/model.go | 4 +- pkg/model/globalrules/query.go | 19 +- pkg/model/globalrules/remove.go | 13 +- pkg/model/globalrules/update.go | 22 +- pkg/model/machinerules/add.go | 51 ++-- pkg/model/machinerules/get.go | 7 +- pkg/model/machinerules/model.go | 4 +- pkg/model/machinerules/services.go | 26 +- pkg/model/machinerules/services_test.go | 16 +- pkg/model/rules/constants.go | 2 + pkg/model/rules/primary_key.go | 16 +- pkg/model/rules/rule.go | 9 +- pkg/model/syncstate/get_test.go | 3 +- pkg/types/data_type.go | 58 ++++- pkg/types/data_type_test.go | 266 ++++++++------------ pkg/types/policy.go | 34 +-- pkg/types/policy_test.go | 38 ++- pkg/types/rule_type.go | 24 +- pkg/types/rule_type_test.go | 35 ++- 39 files changed, 598 insertions(+), 415 deletions(-) diff --git a/examples/sample-rules.csv b/examples/sample-rules.csv index 4e0a7e4..8eb6363 100644 --- a/examples/sample-rules.csv +++ b/examples/sample-rules.csv @@ -1,4 +1,4 @@ -sha256,type,policy,custom_msg,description +identifier,type,policy,custom_msg,description d84db96af8c2e60ac4c851a21ec460f6f84e0235beb17d24a78712b9b021ed57,CERTIFICATE,ALLOWLIST,,"Software Signing by Apple Inc." d292f56f78effeb715382f3578b3716309da04e31589b23b68c3750edd526660,CERTIFICATE,ALLOWLIST,,"Developer ID Application: Zoom Video Communications, Inc. (BJ4HAAB9B3)" 96f18e09d65445985c7df5df74ef152a0bc42e8934175a626180d9700c343e7b,CERTIFICATE,ALLOWLIST,,"Developer ID Application: Mozilla Corporation (43AQ936H96)" diff --git a/examples/sample-rules2.json b/examples/sample-rules2.json index 9621f8c..07525c0 100644 --- a/examples/sample-rules2.json +++ b/examples/sample-rules2.json @@ -1,13 +1,13 @@ [ { - "sha256": "d84db96af8c2e60ac4c851a21ec460f6f84e0235beb17d24a78712b9b021ed57", + "identifier": "d84db96af8c2e60ac4c851a21ec460f6f84e0235beb17d24a78712b9b021ed57", "type": "CERTIFICATE", "policy": "ALLOWLIST", "custom_msg": "", "description": "Software Signing by Apple Inc." }, { - "sha256": "345a8e098bd04794aaeefda8c9ef56a0bf3d3706d67d35bc0e23f11bb3bffce5", + "identifier": "345a8e098bd04794aaeefda8c9ef56a0bf3d3706d67d35bc0e23f11bb3bffce5", "type": "CERTIFICATE", "policy": "ALLOWLIST", "custom_msg": "", diff --git a/internal/cli/flags/client_mode.go b/internal/cli/flags/client_mode.go index c89a14e..9463631 100644 --- a/internal/cli/flags/client_mode.go +++ b/internal/cli/flags/client_mode.go @@ -6,6 +6,12 @@ import ( "github.com/airbnb/rudolph/pkg/types" ) +const ( + monitorMode = "monitor" + lockdownMode = "lockdown" + defaultClientMode = "monitor" +) + // configMode is a custom type for use as a CLI flag representing the type of config mode being applied type ClientMode types.ClientMode diff --git a/internal/cli/flags/rule_info.go b/internal/cli/flags/rule_info.go index e4c0a28..3538d9c 100644 --- a/internal/cli/flags/rule_info.go +++ b/internal/cli/flags/rule_info.go @@ -3,24 +3,24 @@ package flags import "github.com/spf13/cobra" type RuleInfoFlags struct { - RuleType *RuleType - SHA256 *string - FilePath *string + RuleType *RuleType + Identifier *string + FilePath *string } func (r *RuleInfoFlags) AddRuleInfoFlags(cmd *cobra.Command) { var ( - ruleTypeArg RuleType - sha256Arg string - filepathArg string + ruleTypeArg RuleType + identifierArg string + filepathArg string ) // Flag specifying the binary - cmd.Flags().StringVarP(&filepathArg, "filepath", "f", "", `The filepath of a binary. Provide exactly one of [--filepath|--sha]`) - cmd.Flags().StringVarP(&sha256Arg, "sha", "s", "", `The sha256 of a file`) + cmd.Flags().StringVarP(&filepathArg, "filepath", "f", "", `The filepath of a binary/application. Provide exactly one of [--filepath|--sha]`) + cmd.Flags().StringVarP(&identifierArg, "identifier", "i", "", `The Identifier/SHA256 for a file, application, teamID, or signingID`) // rule-type should be one of "binary" or "cert" ("bin" and "certificate" also work) - cmd.Flags().VarP(&ruleTypeArg, "rule-type", "t", `type of rule being applied. valid options are: "binary", "bin", "certificate", or "cert"`) + cmd.Flags().VarP(&ruleTypeArg, "rule-type", "t", `type of rule being applied. valid options are: "binary", "bin", "certificate", "cert", "teamid", "signingid"`) _ = cmd.MarkFlagRequired("rule-type") // If we want to make the `rule-type` flag optional with a default (say "binary"), @@ -30,6 +30,6 @@ func (r *RuleInfoFlags) AddRuleInfoFlags(cmd *cobra.Command) { // rule-policy is to specify the policy for edit commands r.RuleType = &ruleTypeArg - r.SHA256 = &sha256Arg + r.Identifier = &identifierArg r.FilePath = &filepathArg } diff --git a/internal/cli/flags/rule_type.go b/internal/cli/flags/rule_type.go index 33b44f1..e2f053e 100644 --- a/internal/cli/flags/rule_type.go +++ b/internal/cli/flags/rule_type.go @@ -2,18 +2,18 @@ package flags import ( "fmt" + "strings" "github.com/airbnb/rudolph/pkg/types" ) const ( - binType = "binary" - binTypeShort = "bin" - certType = "certificate" - certTypeShort = "cert" - monitorMode = "monitor" - lockdownMode = "lockdown" - defaultClientMode = "monitor" + binType = "binary" + binTypeShort = "bin" + certType = "certificate" + certTypeShort = "cert" + teamIDType = "teamid" + signingIDType = "signingid" ) // ruleType is a custom type for use as a CLI flag representing the type of rule being applied @@ -34,11 +34,15 @@ func (i *RuleType) AsRuleType() types.RuleType { } func (i *RuleType) Set(s string) error { - switch s { + switch strings.ToLower(s) { case binType, binTypeShort: *i = RuleType(types.RuleTypeBinary) case certType, certTypeShort: *i = RuleType(types.RuleTypeCertificate) + case teamIDType: + *i = RuleType(types.RuleTypeTeamID) + case signingIDType: + *i = RuleType(types.RuleTypeSigningID) default: return fmt.Errorf(`invalid rule type; must be "binary" or "cert"`) } @@ -56,6 +60,10 @@ func (i *RuleType) String() string { return binType case types.RuleTypeCertificate: return certType + case types.RuleTypeTeamID: + return teamIDType + case types.RuleTypeSigningID: + return signingIDType } // No default diff --git a/internal/cli/flags/target.go b/internal/cli/flags/target.go index 703ecee..bf10673 100644 --- a/internal/cli/flags/target.go +++ b/internal/cli/flags/target.go @@ -28,7 +28,7 @@ func (t *TargetFlags) AddTargetFlags(cmd *cobra.Command) { } func (t *TargetFlags) AddTargetFlagsRules(cmd *cobra.Command) { - cmd.Flags().BoolVarP(&t.IsGlobal, "global", "g", false, "Retrive rules that apply globally.") + cmd.Flags().BoolVarP(&t.IsGlobal, "global", "g", false, "Retrieve rules that apply globally.") cmd.Flags().StringVarP(&t.MachineID, "machine", "m", "", "Retrieve rules for a single machine. Omit to apply to the current machine.") } diff --git a/internal/cli/rule/rule-allow.go b/internal/cli/rule/rule-allow.go index dd8b902..24e98b5 100644 --- a/internal/cli/rule/rule-allow.go +++ b/internal/cli/rule/rule-allow.go @@ -13,7 +13,7 @@ func init() { rf := flags.RuleInfoFlags{} var ruleAllowCmd = &cobra.Command{ - Use: "allow [-f |-s ] -t [-m |--global]", + Use: "allow [-f |-i ] -t [-m |--global]", Short: "Create a rule that applies the Allowlist policy to the specified file", Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { diff --git a/internal/cli/rule/rule-common.go b/internal/cli/rule/rule-common.go index 439e846..a7afeea 100644 --- a/internal/cli/rule/rule-common.go +++ b/internal/cli/rule/rule-common.go @@ -3,6 +3,7 @@ package rule import ( "bufio" "fmt" + "log" "os" "strings" "time" @@ -28,19 +29,23 @@ var ( ) func applyPolicyForPath(timeProvider clock.TimeProvider, client dynamodb.DynamoDBClient, policy types.Policy, tf flags.TargetFlags, rf flags.RuleInfoFlags) (err error) { - // Second, determine the rule type and sha256 + // Second, determine the rule type and identifier ruleType := (*rf.RuleType).AsRuleType() var description string - var sha256 string + var identifier string + if *rf.FilePath != "" { fileInfo, err := santa_sensor.RunSantaFileInfo(*rf.FilePath) if err != nil { return fmt.Errorf("encountered an error while attempting to get file information for %q", *rf.FilePath) } - - sha256 = fileInfo.SHA256 + identifier = fileInfo.SHA256 description = fmt.Sprintf("%s from %s", fileInfo.Path, tf.SelfMachineID) // FIXME (derek.wang) tf.SelfMachineID is Not initialized. - if ruleType == types.Certificate { + + switch ruleType { + case types.RuleTypeBinary: + break + case types.RuleTypeCertificate: if len(fileInfo.SigningChain) == 0 { return fmt.Errorf("NO SIGNING INFO FOUND FOR GIVEN BINARY") } @@ -51,12 +56,18 @@ func applyPolicyForPath(timeProvider clock.TimeProvider, client dynamodb.DynamoD return fmt.Errorf("NO CERTIFICATE NAME FOUND FOR GIVEN BINARY") } - sha256 = fileInfo.SigningChain[0].SHA256 + identifier = fileInfo.SigningChain[0].SHA256 description = fmt.Sprintf("%v, by %v (%v)", fileInfo.SigningChain[0].CommonName, fileInfo.SigningChain[0].Organization, fileInfo.SigningChain[0].OrganizationalUnit) + case types.RuleTypeTeamID: + identifier = fileInfo.TeamID + case types.RuleTypeSigningID: + identifier = fileInfo.SigningID + default: + log.Printf("error (recovered): encountered unknown ruleType: (%+v)", ruleType) + return fmt.Errorf("error (recovered): encountered unknown ruleType: (%+v)", ruleType) } - - } else if *rf.SHA256 != "" { - sha256 = *rf.SHA256 + } else if *rf.Identifier != "" { + identifier = *rf.Identifier } // TODO @@ -93,7 +104,7 @@ func applyPolicyForPath(timeProvider clock.TimeProvider, client dynamodb.DynamoD fmt.Println("Uploading the following rule:") fmt.Println(" MachineID: ", machineID, suffix) - fmt.Println(" SHA256: ", sha256) + fmt.Println(" Identifier/SHA256: ", identifier) fmt.Println(" Policy: ", policy, " (", string(policyDescription), ")") fmt.Println(" RuleType: ", ruleType, " (", string(ruleTypeDescription), ")") fmt.Println(" Description: ", description) @@ -105,13 +116,13 @@ func applyPolicyForPath(timeProvider clock.TimeProvider, client dynamodb.DynamoD reader := bufio.NewReader(os.Stdin) text, _ := reader.ReadString('\n') text = strings.Replace(text, "\n", "", -1) - if text == "ok" || text == "yes" { + if strings.ToLower(text) == "ok" || strings.ToLower(text) == "yes" { // Do rule creation if tf.IsGlobal { - err = globalrules.AddNewGlobalRule(timeProvider, client, sha256, ruleType, policy, description) + err = globalrules.AddNewGlobalRule(timeProvider, client, identifier, ruleType, policy, description) } else { expires := timeProvider.Now().Add(time.Hour * machinerules.MachineRuleDefaultExpirationHours).UTC() - err = machinerules.AddNewMachineRule(client, machineID, sha256, ruleType, policy, description, expires) + err = machinerules.AddNewMachineRule(client, machineID, identifier, ruleType, policy, description, expires) } if err != nil { return fmt.Errorf("could not upload rule to DynamoDB: %w", err) diff --git a/internal/cli/rule/rule-remove.go b/internal/cli/rule/rule-remove.go index 0160a9f..b5ed94a 100644 --- a/internal/cli/rule/rule-remove.go +++ b/internal/cli/rule/rule-remove.go @@ -1,7 +1,10 @@ package rule import ( + "bufio" "fmt" + "os" + "strings" "github.com/airbnb/rudolph/internal/cli/flags" "github.com/airbnb/rudolph/pkg/clock" @@ -18,9 +21,10 @@ func init() { tf := flags.TargetFlags{} var removeRuleCmd = &cobra.Command{ - Use: "remove ", + Use: `remove ex: 'TeamID#1234567'`, Aliases: []string{"delete"}, Short: "Removes/deletes a rule from the backing store", + Long: ` | # | 'TeamID#1234567'`, Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { region, _ := cmd.Flags().GetString("region") @@ -47,15 +51,51 @@ func init() { } func removeRule(globalRemover globalrules.RuleRemovalService, machineRuleRemover machinerules.RuleRemovalService, ruleName string, tf flags.TargetFlags) error { - if !tf.IsGlobal { - machineID, err := tf.GetMachineID() + // First, determine which machine to apply + var machineID string + if !tf.IsGlobal || tf.IsTargetSelf() { + var err error + machineID, err = tf.GetMachineID() if err != nil { - return fmt.Errorf("failed to get MachineID: %v", err) + return fmt.Errorf("failed to get MachineID: %w", err) } - return machineRuleRemover.RemoveMachineRule(machineID, ruleName) } - idempotencyKey := uuid.NewString() + fmt.Println("Removing the following rule:") + if machineID != "" { + fmt.Println(" MachineID: ", machineID) + } + fmt.Println(" Identifier/SHA256: ", ruleName) + fmt.Println("") + fmt.Println(`Apply changes? (Enter: "yes" or "ok")`) + fmt.Print("> ") + + // Read confirmation + reader := bufio.NewReader(os.Stdin) + text, _ := reader.ReadString('\n') + text = strings.Replace(text, "\n", "", -1) + if strings.ToLower(text) == "ok" || strings.ToLower(text) == "yes" { + // Do rule deletion + if !tf.IsGlobal { + machineID, err := tf.GetMachineID() + if err != nil { + return fmt.Errorf("failed to get MachineID: %v", err) + } + return machineRuleRemover.RemoveMachineRule(machineID, ruleName) + } + + idempotencyKey := uuid.NewString() + + err := globalRemover.RemoveGlobalRule(ruleName, idempotencyKey) + if err != nil { + return fmt.Errorf("failed to remove global rule: %v", err) + } + + fmt.Println("Successfully sent a rule to dynamodb") + } else { + fmt.Println("Well ok then") + } + fmt.Println("") - return globalRemover.RemoveGlobalRule(ruleName, idempotencyKey) + return nil } diff --git a/internal/cli/rule/rule-update.go b/internal/cli/rule/rule-update.go index 189cf13..0412cf8 100644 --- a/internal/cli/rule/rule-update.go +++ b/internal/cli/rule/rule-update.go @@ -4,6 +4,7 @@ import ( "bufio" "errors" "fmt" + "log" "os" "strings" @@ -70,16 +71,20 @@ func (rh *ruleHandler) updateRulePolicy(tf flags.TargetFlags, rf flags.RuleInfoF ruleType := rf.RuleType.AsRuleType() rulePolicy := ru.RulePolicy.AsRulePolicy() var description string - var sha256 string - if *rf.FilePath != "" { - fileInfo, err := santa_sensor.RunSantaFileInfo(*rf.FilePath) - if err != nil { - return fmt.Errorf("encountered an error while attempting to get file information for %q", *rf.FilePath) - } + var identifier string + + fileInfo, err := santa_sensor.RunSantaFileInfo(*rf.FilePath) + if err != nil { + return fmt.Errorf("encountered an error while attempting to get file information for %q", *rf.FilePath) + } - sha256 = fileInfo.SHA256 + if *rf.FilePath != "" { + identifier = fileInfo.SHA256 description = fmt.Sprintf("%s from %s", fileInfo.Path, tf.SelfMachineID) // FIXME (derek.wang) tf.SelfMachineID is Not initialized. - if ruleType == types.Certificate { + switch ruleType { + case types.RuleTypeBinary: + break + case types.RuleTypeCertificate: if len(fileInfo.SigningChain) == 0 { return fmt.Errorf("NO SIGNING INFO FOUND FOR GIVEN BINARY") } @@ -90,14 +95,28 @@ func (rh *ruleHandler) updateRulePolicy(tf flags.TargetFlags, rf flags.RuleInfoF return fmt.Errorf("NO CERTIFICATE NAME FOUND FOR GIVEN BINARY") } - sha256 = fileInfo.SigningChain[0].SHA256 + identifier = fileInfo.SigningChain[0].SHA256 description = fmt.Sprintf("%v, by %v (%v)", fileInfo.SigningChain[0].CommonName, fileInfo.SigningChain[0].Organization, fileInfo.SigningChain[0].OrganizationalUnit) + case types.RuleTypeTeamID: + identifier = fileInfo.TeamID + case types.RuleTypeSigningID: + identifier = fileInfo.SigningID + default: + log.Printf("error (recovered): encountered unknown ruleType: (%+v)", ruleType) + return fmt.Errorf("error (recovered): encountered unknown ruleType: (%+v)", ruleType) } - - } else if *rf.SHA256 != "" { - sha256 = *rf.SHA256 + } else if *rf.Identifier != "" { + identifier = *rf.Identifier } + // TODO + // Query if there is an existing rule: and show the before/after + // partitionKey := fmt.Sprintf("%s%s", store.MachineRulesPKPrefix, machineID) + // ruleName := + // if ruleType == types.RuleTypeTeamID { + // ruleName = fmt.Sprintf("%s%s", store.TeamRulesPKPrefix, identifier) + // } else if *rf.SHA256 != "" { + rulePolicyDescription, err := rulePolicy.MarshalText() if err != nil { return @@ -124,7 +143,7 @@ func (rh *ruleHandler) updateRulePolicy(tf flags.TargetFlags, rf flags.RuleInfoF // Query if there is an existing rule: and show the before/after if machineID == "(Global)" { - existingItem, err := globalrules.GetGlobalRuleByShaType(rh.dynamodbClient, sha256, ruleType) + existingItem, err := globalrules.GetGlobalRuleByShaType(rh.dynamodbClient, identifier, ruleType) if err != nil { return err } @@ -146,12 +165,12 @@ func (rh *ruleHandler) updateRulePolicy(tf flags.TargetFlags, rf flags.RuleInfoF fmt.Println("The current rule is rule:") fmt.Println(" MachineID: ", machineID, suffix) - fmt.Println(" SHA256: ", existingItem.SHA256) + fmt.Println(" Identifier/SHA256: ", existingItem.Identifier) fmt.Println(" Policy: ", existingItem.Policy, " (", string(rulePolicyDescription), ")") fmt.Println(" RuleType: ", existingItem.RuleType, " (", string(ruleTypeDescription), ")") fmt.Println(" Description: ", description) } else { - existingItem, err := machinerules.GetMachineRuleByShaType(rh.dynamodbClient, machineID, sha256, ruleType) + existingItem, err := machinerules.GetMachineRuleByShaType(rh.dynamodbClient, machineID, identifier, ruleType) if err != nil { return err } @@ -173,7 +192,7 @@ func (rh *ruleHandler) updateRulePolicy(tf flags.TargetFlags, rf flags.RuleInfoF fmt.Println("The current rule is rule:") fmt.Println(" MachineID: ", machineID, suffix) - fmt.Println(" SHA256: ", existingItem.SHA256) + fmt.Println(" Identifier/SHA256: ", existingItem.Identifier) fmt.Println(" Policy: ", existingItem.Policy, " (", string(rulePolicyDescription), ")") fmt.Println(" RuleType: ", existingItem.RuleType, " (", string(ruleTypeDescription), ")") fmt.Println(" Description: ", description) @@ -181,7 +200,7 @@ func (rh *ruleHandler) updateRulePolicy(tf flags.TargetFlags, rf flags.RuleInfoF fmt.Println("Updating the rule to the following:") fmt.Println(" MachineID: ", machineID, suffix) - fmt.Println(" SHA256: ", sha256) + fmt.Println(" Identifier/SHA256: ", identifier) fmt.Println(" Policy: ", rulePolicy, " (", string(rulePolicyDescription), ")") fmt.Println(" RuleType: ", ruleType, " (", string(ruleTypeDescription), ")") fmt.Println(" Description: ", description) @@ -196,9 +215,9 @@ func (rh *ruleHandler) updateRulePolicy(tf flags.TargetFlags, rf flags.RuleInfoF if text == "ok" || text == "yes" { // Do rule update if tf.IsGlobal { - err = rh.globalRuleUpdater.UpdateGlobalRule(sha256, ruleType, rulePolicy) + err = rh.globalRuleUpdater.UpdateGlobalRule(identifier, ruleType, rulePolicy) } else { - err = rh.machineRuleUpdater.UpdateMachineRulePolicy(machineID, sha256, ruleType, rulePolicy) + err = rh.machineRuleUpdater.UpdateMachineRulePolicy(machineID, identifier, ruleType, rulePolicy) } if err != nil { return fmt.Errorf("could not upload rule to dynamodb: %w", err) diff --git a/internal/cli/rules/rules-export.go b/internal/cli/rules/rules-export.go index 915fb5d..caa911b 100644 --- a/internal/cli/rules/rules-export.go +++ b/internal/cli/rules/rules-export.go @@ -3,7 +3,7 @@ package rules import ( "encoding/json" "fmt" - "io/ioutil" + "os" "github.com/spf13/cobra" @@ -57,7 +57,7 @@ func runExport( type fileRule struct { RuleType types.RuleType `json:"type"` Policy types.Policy `json:"policy"` - SHA256 string `json:"sha256"` + Identifier string `json:"identifier"` CustomMessage string `json:"custom_msg,omitempty"` Description string `json:"description"` } @@ -67,7 +67,7 @@ func runJsonExport(client dynamodb.QueryAPI, filename string) (err error) { fmt.Println("Querying rules from DynamoDB...") total, err := getRules(client, func(rule globalrules.GlobalRuleRow) (err error) { jsonRules = append(jsonRules, fileRule{ - SHA256: rule.SHA256, + Identifier: rule.Identifier, RuleType: rule.RuleType, Policy: rule.Policy, CustomMessage: rule.CustomMessage, @@ -83,7 +83,7 @@ func runJsonExport(client dynamodb.QueryAPI, filename string) (err error) { if err != nil { return } - err = ioutil.WriteFile(filename, jsondata, 0644) + err = os.WriteFile(filename, jsondata, 0644) if err != nil { return } @@ -124,7 +124,7 @@ func runCsvExport( return } record := []string{ - rule.SHA256, + rule.Identifier, string(ruleType), string(policy), rule.CustomMessage, diff --git a/internal/cli/rules/rules-import.go b/internal/cli/rules/rules-import.go index dcd77aa..620c257 100644 --- a/internal/cli/rules/rules-import.go +++ b/internal/cli/rules/rules-import.go @@ -172,9 +172,9 @@ func runCsvImport( // Start taking lines from the csv and shoveling them into the workers for line := range data { - sha256, ok := line["sha256"] + identifier, ok := line["identifier"] if !ok { - panic("no sha256") + panic("no identifier") } ruleTypeStr, ok := line["type"] if !ok { @@ -205,7 +205,7 @@ func runCsvImport( } rules <- fileRule{ - SHA256: sha256, + Identifier: identifier, RuleType: ruleType, Policy: policy, Description: description, @@ -232,14 +232,21 @@ func ddbWriter( var err error atomic.AddUint64(total, 1) - suffix := "" - if rule.RuleType == types.Certificate { + var suffix string + switch rule.RuleType { + case types.RuleTypeCertificate: suffix = " (Cert)" + case types.RuleTypeTeamID: + suffix = " (TeamID)" + case types.RuleTypeSigningID: + suffix = " (SigningID)" + default: + suffix = "" } if rule.Policy == types.RulePolicyRemove { - fmt.Printf(" Removing rule: [%s]\n", rule.SHA256) - sortkey := rudolphrules.RuleSortKeyFromTypeSHA(rule.SHA256, rule.RuleType) + fmt.Printf(" Removing rule: [%s]\n", rule.Identifier) + sortkey := rudolphrules.RuleSortKeyFromTypeIdentifier(rule.Identifier, rule.RuleType) err = globalrules.RemoveGlobalRule( timeProvider, client, @@ -249,11 +256,11 @@ func ddbWriter( ) } else { - fmt.Printf(" Writing rule: [%+v] %s%s\n", rule.Policy, rule.SHA256, suffix) + fmt.Printf(" Writing rule: [%+v] %s%s\n", rule.Policy, rule.Identifier, suffix) err = globalrules.AddNewGlobalRule( timeProvider, client, - rule.SHA256, + rule.Identifier, rule.RuleType, rule.Policy, rule.Description, diff --git a/internal/cli/rules/rules.go b/internal/cli/rules/rules.go index b9ad842..59a17c4 100644 --- a/internal/cli/rules/rules.go +++ b/internal/cli/rules/rules.go @@ -108,6 +108,10 @@ func renderRule(item modelrules.SantaRule) string { predicate = "binary" case types.RuleTypeCertificate: predicate = "certificate" + case types.RuleTypeTeamID: + predicate = "teamID" + case types.RuleTypeSigningID: + predicate = "signingID" default: predicate = "?" } @@ -130,5 +134,10 @@ func renderRule(item modelrules.SantaRule) string { verb = "?" } - return fmt.Sprintf("Rule %v %v (%s)", verb, predicate, item.SHA256) + // Support backwards compatibility with SHA256 + if item.SHA256 != "" && item.Identifier == "" { + item.Identifier = item.SHA256 + } + + return fmt.Sprintf("Rule %v %v (%s)", verb, predicate, item.Identifier) } diff --git a/internal/cli/santa_sensor/santa_fileinfo.go b/internal/cli/santa_sensor/santa_fileinfo.go index 58dff20..b957f6b 100644 --- a/internal/cli/santa_sensor/santa_fileinfo.go +++ b/internal/cli/santa_sensor/santa_fileinfo.go @@ -13,6 +13,8 @@ type santaFileInfo struct { Path string `json:"Path"` SHA256 string `json:"SHA-256"` SHA1 string `json:"SHA-1"` + TeamID string `json:"Team ID"` + SigningID string `json:"Signing ID"` BundleName string `json:"Bundle Name"` BundleVersion string `json:"Bundle Version"` BundleVersionStr string `json:"Bundle Version Str"` diff --git a/internal/handlers/ruledownload/response.go b/internal/handlers/ruledownload/response.go index b5b2734..f1b42ed 100644 --- a/internal/handlers/ruledownload/response.go +++ b/internal/handlers/ruledownload/response.go @@ -15,10 +15,12 @@ type RuledownloadResponse struct { // RuledownloadRule is a single rule returned in a RuledownloadResponse // It duck-types to/from the SantaRule struct type +// Documentation: https://santa.dev/development/sync-protocol.html#rules-objects type RuledownloadRule struct { RuleType types.RuleType `json:"rule_type"` Policy types.Policy `json:"policy"` - SHA256 string `json:"sha256"` + SHA256 string `json:"sha256,omitempty"` + Identifier string `json:"identifier"` CustomMessage string `json:"custom_msg,omitempty"` } @@ -29,6 +31,12 @@ func DDBRulesToResponseRules(rulesList []rules.SantaRule) (responseRules []Ruled for i, rule := range rulesList { responseRules[i] = RuledownloadRule(rule) + // responseRules[i] = RuledownloadRule{ + // RuleType: rule.RuleType, + // Policy: rule.Policy, + // Identifier: rule.Identifier, + // CustomMessage: rule.CustomMessage, + // } } return } diff --git a/pkg/model/feedrules/create.go b/pkg/model/feedrules/create.go index 6a77616..3480d88 100644 --- a/pkg/model/feedrules/create.go +++ b/pkg/model/feedrules/create.go @@ -1,25 +1,27 @@ package feedrules import ( - "fmt" - "github.com/airbnb/rudolph/pkg/clock" "github.com/airbnb/rudolph/pkg/dynamodb" "github.com/airbnb/rudolph/pkg/model/rules" ) func ConstructFeedRuleFromBaseRule(timeProvider clock.TimeProvider, rule rules.SantaRule) FeedRuleRow { + var identifier string + // Support backwards compatibility with legacy SHA256 identifier + if rule.SHA256 != "" && rule.Identifier == "" { + identifier = rule.SHA256 + } else { + identifier = rule.Identifier + } + return FeedRuleRow{ PrimaryKey: dynamodb.PrimaryKey{ PartitionKey: feedRulesPK, // With this sort key, all rules will be ordered by the date they are created, // but also uniqified by the specific binary. This means that you can seek all rules // over time, kind of like picking up diffs through VCS changes. - SortKey: fmt.Sprintf( - "%s#%s", - clock.RFC3339(timeProvider.Now()), - rules.RuleSortKeyFromTypeSHA(rule.SHA256, rule.RuleType), - ), + SortKey: feedRulesSK(timeProvider, identifier, rule.RuleType), }, SantaRule: rule, ExpiresAfter: GetSyncStateExpiresAfter(timeProvider), diff --git a/pkg/model/feedrules/model.go b/pkg/model/feedrules/model.go index a4e8d43..bd85467 100644 --- a/pkg/model/feedrules/model.go +++ b/pkg/model/feedrules/model.go @@ -1,6 +1,8 @@ package feedrules import ( + "fmt" + "github.com/airbnb/rudolph/pkg/clock" "github.com/airbnb/rudolph/pkg/dynamodb" "github.com/airbnb/rudolph/pkg/model/rules" @@ -26,3 +28,15 @@ func GetSyncStateExpiresAfter(timeProvider clock.TimeProvider) int64 { func GetDataType() types.DataType { return types.DataTypeRulesFeed } + +func feedRulesSK( + timeProvider clock.TimeProvider, + identifier string, + ruleType types.RuleType, +) string { + return fmt.Sprintf( + "%s#%s", + clock.RFC3339(timeProvider.Now()), + rules.RuleSortKeyFromTypeIdentifier(identifier, ruleType), + ) +} diff --git a/pkg/model/feedrules/query.go b/pkg/model/feedrules/query.go index 77bbeec..8a09fd9 100644 --- a/pkg/model/feedrules/query.go +++ b/pkg/model/feedrules/query.go @@ -15,11 +15,19 @@ import ( // GetPaginatedFeedRules returns zero or more rules on the feed, up to the limit // If there are more rules to paginate through, will return a lastEvaluatedKey that can be passed in as the // exclusiveStartKey in subsequent requests. Otherwise, lastEvaluatedKey is nil when there are no more items. -func GetPaginatedFeedRules(client dynamodb.QueryAPI, limit int, exclusiveStartKey *dynamodb.PrimaryKey) (items *[]FeedRuleRow, lastEvaluatedKey *dynamodb.PrimaryKey, err error) { +func GetPaginatedFeedRules( + client dynamodb.QueryAPI, + limit int, + exclusiveStartKey *dynamodb.PrimaryKey, +) ( + items *[]FeedRuleRow, + lastEvaluatedKey *dynamodb.PrimaryKey, + err error, +) { partitionKey := feedRulesPK if limit <= 0 { - err = errors.New("Invalid limit/batchsize specified") + err = errors.New("invalid limit/batchsize specified") return } @@ -67,5 +75,12 @@ func GetPaginatedFeedRules(client dynamodb.QueryAPI, limit int, exclusiveStartKe return } // log.Printf(" got %d items from query.", len(*items)) + + // To support legacy SHA256 types, we must transform the datasets before returning + for _, item := range *items { + if item.SHA256 != "" && item.Identifier == "" { + item.Identifier = item.SHA256 + } + } return } diff --git a/pkg/model/globalrules/add.go b/pkg/model/globalrules/add.go index 1a92d69..42c9c14 100644 --- a/pkg/model/globalrules/add.go +++ b/pkg/model/globalrules/add.go @@ -11,9 +11,16 @@ import ( awsdynamodbtypes "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) -func AddNewGlobalRule(time clock.TimeProvider, client dynamodb.TransactWriteItemsAPI, sha256 string, ruleType types.RuleType, policy types.Policy, description string) error { +func AddNewGlobalRule( + time clock.TimeProvider, + client dynamodb.TransactWriteItemsAPI, + identifier string, + ruleType types.RuleType, + policy types.Policy, + description string, +) error { // Input Validation - isValid, err := inputValidation(sha256, ruleType, policy, description) + isValid, err := ruleValidation(ruleType, policy) if err != nil { return err } @@ -24,13 +31,13 @@ func AddNewGlobalRule(time clock.TimeProvider, client dynamodb.TransactWriteItem rule := GlobalRuleRow{ PrimaryKey: dynamodb.PrimaryKey{ PartitionKey: globalRulesPK, - SortKey: globalRulesSK(sha256, ruleType), + SortKey: globalRulesSK(identifier, ruleType), }, Description: description, SantaRule: rules.SantaRule{ - RuleType: ruleType, - Policy: policy, - SHA256: sha256, + RuleType: ruleType, + Policy: policy, + Identifier: identifier, }, } @@ -54,17 +61,12 @@ func AddNewGlobalRule(time clock.TimeProvider, client dynamodb.TransactWriteItem return err } -func inputValidation(sha256 string, ruleType types.RuleType, policy types.Policy, description string) (bool, error) { - var err error - - // RuleSha256 validation - err = types.ValidateSha256(sha256) - if err != nil { - return false, err - } - +func ruleValidation( + ruleType types.RuleType, + policy types.Policy, +) (bool, error) { // RuleType validation - _, err = ruleType.MarshalText() + _, err := ruleType.MarshalText() if err != nil { return false, err } diff --git a/pkg/model/globalrules/get.go b/pkg/model/globalrules/get.go index d5c9ac7..cb7e68a 100644 --- a/pkg/model/globalrules/get.go +++ b/pkg/model/globalrules/get.go @@ -12,13 +12,22 @@ func GetGlobalRuleBySortKey(client dynamodb.GetItemAPI, ruleSortKey string) (*Gl return getItemAsGlobalRule(client, globalRulesPK, ruleSortKey) } +// @deprecated Use GetGlobalRuleByIdentifier func GetGlobalRuleByShaType(client dynamodb.GetItemAPI, sha256 string, ruleType types.RuleType) (*GlobalRuleRow, error) { + return GetGlobalRuleByIdentifier(client, sha256, ruleType) +} + +func GetGlobalRuleByIdentifier(client dynamodb.GetItemAPI, identifier string, ruleType types.RuleType) (*GlobalRuleRow, error) { pk := globalRulesPK - sk := globalRulesSK(sha256, ruleType) + sk := globalRulesSK(identifier, ruleType) return getItemAsGlobalRule(client, pk, sk) } -func getItemAsGlobalRule(client dynamodb.GetItemAPI, partitionKey string, sortKey string) (rule *GlobalRuleRow, err error) { +func getItemAsGlobalRule( + client dynamodb.GetItemAPI, + partitionKey string, + sortKey string, +) (rule *GlobalRuleRow, err error) { output, err := client.GetItem( dynamodb.PrimaryKey{ PartitionKey: partitionKey, @@ -42,5 +51,9 @@ func getItemAsGlobalRule(client dynamodb.GetItemAPI, partitionKey string, sortKe return } + if rule.SHA256 != "" && rule.Identifier == "" { + rule.Identifier = rule.SHA256 + } + return } diff --git a/pkg/model/globalrules/model.go b/pkg/model/globalrules/model.go index 91541ab..9e7ea4e 100644 --- a/pkg/model/globalrules/model.go +++ b/pkg/model/globalrules/model.go @@ -20,6 +20,6 @@ type updateRulePolicyRequest struct { Policy types.Policy `dynamodbav:"Policy"` } -func globalRulesSK(sha256 string, ruleType types.RuleType) string { - return rules.RuleSortKeyFromTypeSHA(sha256, ruleType) +func globalRulesSK(identifier string, ruleType types.RuleType) string { + return rules.RuleSortKeyFromTypeIdentifier(identifier, ruleType) } diff --git a/pkg/model/globalrules/query.go b/pkg/model/globalrules/query.go index 87ad816..f05f5f1 100644 --- a/pkg/model/globalrules/query.go +++ b/pkg/model/globalrules/query.go @@ -16,11 +16,19 @@ func PingDatabase(client dynamodb.QueryAPI) (err error) { return } -func GetPaginatedGlobalRules(client dynamodb.QueryAPI, limit int, exclusiveStartKey *dynamodb.PrimaryKey) (items *[]GlobalRuleRow, lastEvaluatedKey *dynamodb.PrimaryKey, err error) { +func GetPaginatedGlobalRules( + client dynamodb.QueryAPI, + limit int, + exclusiveStartKey *dynamodb.PrimaryKey, +) ( + items *[]GlobalRuleRow, + lastEvaluatedKey *dynamodb.PrimaryKey, + err error, +) { partitionKey := globalRulesPK if limit <= 0 { - err = errors.New("Invalid limit/batchsize specified") + err = errors.New("invalid limit/batchsize specified") return } @@ -66,5 +74,12 @@ func GetPaginatedGlobalRules(client dynamodb.QueryAPI, limit int, exclusiveStart err = fmt.Errorf("failed to unmarshal result from DynamoDB: %w", err) return } + + // To support legacy SHA256 types, we must transform the datasets before returning + for _, item := range *items { + if item.SHA256 != "" && item.Identifier == "" { + item.Identifier = item.SHA256 + } + } return } diff --git a/pkg/model/globalrules/remove.go b/pkg/model/globalrules/remove.go index 1e08ac0..53e9f14 100644 --- a/pkg/model/globalrules/remove.go +++ b/pkg/model/globalrules/remove.go @@ -1,7 +1,6 @@ package globalrules import ( - "errors" "fmt" "github.com/airbnb/rudolph/pkg/clock" @@ -15,13 +14,19 @@ import ( // RemoveGlobalRule will remove the rule from the global repository of rules. // It also creates a new rule entry in the feed that explicitly tells future syncs // to remove the rule too. -func RemoveGlobalRule(timeProvider clock.TimeProvider, getter dynamodb.GetItemAPI, transacter dynamodb.TransactWriteItemsAPI, ruleSortKey string, txnIdempotencyKey string) (err error) { +func RemoveGlobalRule( + timeProvider clock.TimeProvider, + getter dynamodb.GetItemAPI, + transacter dynamodb.TransactWriteItemsAPI, + ruleSortKey string, + txnIdempotencyKey string, +) error { rule, err := GetGlobalRuleBySortKey(getter, ruleSortKey) if err != nil { return fmt.Errorf("query to retrieve existing rule failed: %w", err) } if rule == nil { - return errors.New(fmt.Sprintf("no such rule with sk (%s) exists", ruleSortKey)) + return fmt.Errorf("no such rule with sk (%s) exists", ruleSortKey) } // Delete the global rule @@ -55,7 +60,7 @@ func RemoveGlobalRule(timeProvider clock.TimeProvider, getter dynamodb.GetItemAP return fmt.Errorf("transaction delete failed: %w", err) } - return + return nil } type RuleRemovalService interface { diff --git a/pkg/model/globalrules/update.go b/pkg/model/globalrules/update.go index f5d9ad2..c3c15b3 100644 --- a/pkg/model/globalrules/update.go +++ b/pkg/model/globalrules/update.go @@ -11,10 +11,16 @@ import ( awsdynamodbtypes "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) -func UpdateGlobalRule(time clock.TimeProvider, client dynamodb.TransactWriteItemsAPI, sha256 string, ruleType types.RuleType, rulePolicy types.Policy) (err error) { +func UpdateGlobalRule( + time clock.TimeProvider, + client dynamodb.TransactWriteItemsAPI, + identifier string, + ruleType types.RuleType, + rulePolicy types.Policy, +) error { // Get the PK/SK values pk := globalRulesPK - sk := globalRulesSK(sha256, ruleType) + sk := globalRulesSK(identifier, ruleType) primaryKey := dynamodb.PrimaryKey{ PartitionKey: pk, @@ -28,9 +34,9 @@ func UpdateGlobalRule(time clock.TimeProvider, client dynamodb.TransactWriteItem // UpdatedRule for the ruleFeed Update updatedRule := rules.SantaRule{ - RuleType: ruleType, - Policy: rulePolicy, - SHA256: sha256, + RuleType: ruleType, + Policy: rulePolicy, + Identifier: identifier, } updateFeedRuleItem := feedrules.ConstructFeedRuleFromBaseRule(time, updatedRule) @@ -56,8 +62,8 @@ func UpdateGlobalRule(time clock.TimeProvider, client dynamodb.TransactWriteItem // Send the TransactWriteRequest _, err = client.TransactWriteItems(transactItems, nil) if err != nil { - err = fmt.Errorf("failed to update global rule: %w", err) - return + return fmt.Errorf("failed to update global rule: %w", err) } - return + + return nil } diff --git a/pkg/model/machinerules/add.go b/pkg/model/machinerules/add.go index e6fe971..4f62565 100644 --- a/pkg/model/machinerules/add.go +++ b/pkg/model/machinerules/add.go @@ -9,9 +9,24 @@ import ( "github.com/airbnb/rudolph/pkg/types" ) -func AddNewMachineRule(client dynamodb.PutItemAPI, machineID string, sha256 string, ruleType types.RuleType, policy types.Policy, description string, expires time.Time) (err error) { +func AddNewMachineRule( + client dynamodb.PutItemAPI, + machineID string, + identifier string, + ruleType types.RuleType, + policy types.Policy, + description string, + expires time.Time, +) error { // Input Validation - isValid, err := inputValidation(machineID, sha256, ruleType, policy, description, expires) + isValid, err := ruleValidation( + machineID, + identifier, + ruleType, + policy, + description, + expires, + ) if err != nil { return err } @@ -22,32 +37,34 @@ func AddNewMachineRule(client dynamodb.PutItemAPI, machineID string, sha256 stri rule := MachineRuleRow{ PrimaryKey: dynamodb.PrimaryKey{ PartitionKey: machineRulePK(machineID), - SortKey: machineRuleSK(sha256, ruleType), + SortKey: machineRuleSK(identifier, ruleType), }, Description: description, SantaRule: rules.SantaRule{ - RuleType: ruleType, - Policy: policy, - SHA256: sha256, + RuleType: ruleType, + Policy: policy, + Identifier: identifier, }, ExpiresAfter: expires.Unix(), } _, err = client.PutItem(rule) - return -} - -func inputValidation(machineID, sha256 string, ruleType types.RuleType, policy types.Policy, description string, expires time.Time) (bool, error) { - var err error - - // MachineID validation - err = types.ValidateMachineID(machineID) if err != nil { - return false, err + return err } + return nil +} - // RuleSha256 validation - err = types.ValidateSha256(sha256) +func ruleValidation( + machineID, + identifier string, + ruleType types.RuleType, + policy types.Policy, + description string, + expires time.Time, +) (bool, error) { + // MachineID validation + err := types.ValidateMachineID(machineID) if err != nil { return false, err } diff --git a/pkg/model/machinerules/get.go b/pkg/model/machinerules/get.go index 55c602f..dc3d2dc 100644 --- a/pkg/model/machinerules/get.go +++ b/pkg/model/machinerules/get.go @@ -35,8 +35,13 @@ func getItemAsMachineRule(client dynamodb.GetItemAPI, partitionKey string, sortK return } +// @deprecated func GetMachineRuleByShaType(client dynamodb.GetItemAPI, machineID string, sha256 string, ruleType types.RuleType) (rule *MachineRuleRow, err error) { + return GetMachineRuleByIdentifierType(client, machineID, sha256, ruleType) +} + +func GetMachineRuleByIdentifierType(client dynamodb.GetItemAPI, machineID string, identifier string, ruleType types.RuleType) (rule *MachineRuleRow, err error) { pk := machineRulePK(machineID) - sk := machineRuleSK(sha256, ruleType) + sk := machineRuleSK(identifier, ruleType) return getItemAsMachineRule(client, pk, sk) } diff --git a/pkg/model/machinerules/model.go b/pkg/model/machinerules/model.go index 1cd7140..6e19f65 100644 --- a/pkg/model/machinerules/model.go +++ b/pkg/model/machinerules/model.go @@ -35,6 +35,6 @@ type updateRulePolicyRequest struct { func machineRulePK(machineID string) string { return fmt.Sprintf("%s%s", machineRulesPKPrefix, machineID) } -func machineRuleSK(sha256 string, ruleType types.RuleType) string { - return rules.RuleSortKeyFromTypeSHA(sha256, ruleType) +func machineRuleSK(identifier string, ruleType types.RuleType) string { + return rules.RuleSortKeyFromTypeIdentifier(identifier, ruleType) } diff --git a/pkg/model/machinerules/services.go b/pkg/model/machinerules/services.go index 613407f..2e8f23d 100644 --- a/pkg/model/machinerules/services.go +++ b/pkg/model/machinerules/services.go @@ -26,14 +26,12 @@ func (c ConcreteMachineRulesUpdater) UpdateMachineRulePolicy(machineID string, s return UpdateMachineRule(c.Updater, machineID, sha256, ruleType, rulePolicy, expires) } -// // This service exposes all machine rules access methods -// type MachineRulesService interface { - Get(machineId string, sha256 string, ruleType types.RuleType) (rule *MachineRuleRow, err error) - Add(machineId string, sha256 string, ruleType types.RuleType, policy types.Policy, description string, expires time.Time) error - Update(machineId string, sha256 string, ruleType types.RuleType, rulePolicy types.Policy, expires time.Time) error - Remove(machineId string, sha256 string, ruleType types.RuleType) error + Get(machineId string, identifier string, ruleType types.RuleType) (rule *MachineRuleRow, err error) + Add(machineId string, identifier string, ruleType types.RuleType, policy types.Policy, description string, expires time.Time) error + Update(machineId string, identifier string, ruleType types.RuleType, rulePolicy types.Policy, expires time.Time) error + Remove(machineId string, identifier string, ruleType types.RuleType) error RemoveBySortKey(machineId string, ruleSortKey string) error GetMachineRules(machineID string) (items *[]MachineRuleRow, err error) } @@ -48,20 +46,20 @@ func GetMachineRulesService(dynamodb dynamodb.DynamoDBClient) MachineRulesServic } } -func (s ConcreteMachineRulesService) Get(machineId string, sha256 string, ruleType types.RuleType) (rule *MachineRuleRow, err error) { - return getItemAsMachineRule(s.dynamodb, machineRulePK(machineId), machineRuleSK(sha256, ruleType)) +func (s ConcreteMachineRulesService) Get(machineId string, identifier string, ruleType types.RuleType) (rule *MachineRuleRow, err error) { + return getItemAsMachineRule(s.dynamodb, machineRulePK(machineId), machineRuleSK(identifier, ruleType)) } -func (s ConcreteMachineRulesService) Add(machineId string, sha256 string, ruleType types.RuleType, policy types.Policy, description string, expires time.Time) error { - return AddNewMachineRule(s.dynamodb, machineId, sha256, ruleType, policy, description, expires) +func (s ConcreteMachineRulesService) Add(machineId string, identifier string, ruleType types.RuleType, policy types.Policy, description string, expires time.Time) error { + return AddNewMachineRule(s.dynamodb, machineId, identifier, ruleType, policy, description, expires) } -func (s ConcreteMachineRulesService) Update(machineId string, sha256 string, ruleType types.RuleType, rulePolicy types.Policy, expires time.Time) error { - return UpdateMachineRule(s.dynamodb, machineId, sha256, ruleType, rulePolicy, expires) +func (s ConcreteMachineRulesService) Update(machineId string, identifier string, ruleType types.RuleType, rulePolicy types.Policy, expires time.Time) error { + return UpdateMachineRule(s.dynamodb, machineId, identifier, ruleType, rulePolicy, expires) } func (s ConcreteMachineRulesService) RemoveBySortKey(machineId string, ruleSortKey string) error { return RemoveMachineRule(s.dynamodb, s.dynamodb, machineId, ruleSortKey) } -func (s ConcreteMachineRulesService) Remove(machineId string, sha256 string, ruleType types.RuleType) error { - return RemoveMachineRule(s.dynamodb, s.dynamodb, machineId, machineRuleSK(sha256, ruleType)) +func (s ConcreteMachineRulesService) Remove(machineId string, identifier string, ruleType types.RuleType) error { + return RemoveMachineRule(s.dynamodb, s.dynamodb, machineId, machineRuleSK(identifier, ruleType)) } func (s ConcreteMachineRulesService) GetMachineRules(machineId string) (items *[]MachineRuleRow, err error) { return GetMachineRules(s.dynamodb, machineId) diff --git a/pkg/model/machinerules/services_test.go b/pkg/model/machinerules/services_test.go index 2162929..a8e9ff0 100644 --- a/pkg/model/machinerules/services_test.go +++ b/pkg/model/machinerules/services_test.go @@ -62,7 +62,7 @@ func (m *MockDynamodb) PutItem(item interface{}) (*awsdynamodb.PutItemOutput, er func Test_Service_Get(t *testing.T) { machineID := "858CBF28-5EAA-58A3-A155-BA5E81D5B5DD" - sha256 := "ed0a9ba83449b5966363e0c20fe7755defcb2d7136657d3880bb462a8d7a7025" + identifier := "ed0a9ba83449b5966363e0c20fe7755defcb2d7136657d3880bb462a8d7a7025" t.Run("GetItem returns no item", func(t *testing.T) { mocked := &MockDynamodb{} @@ -72,7 +72,7 @@ func Test_Service_Get(t *testing.T) { dynamodb: mocked, } - item, err := service.Get(machineID, sha256, types.Binary) + item, err := service.Get(machineID, identifier, types.Binary) assert.Empty(t, err) assert.Empty(t, item) }) @@ -90,6 +90,9 @@ func Test_Service_Get(t *testing.T) { "SHA256": &awsdynamodbtypes.AttributeValueMemberS{ Value: "ed0a9ba83449b5966363e0c20fe7755defcb2d7136657d3880bb462a8d7a7025", }, + "Identifier": &awsdynamodbtypes.AttributeValueMemberS{ + Value: "ed0a9ba83449b5966363e0c20fe7755defcb2d7136657d3880bb462a8d7a7025", + }, "Policy": &awsdynamodbtypes.AttributeValueMemberN{ Value: "1", }, @@ -100,9 +103,10 @@ func Test_Service_Get(t *testing.T) { dynamodb: mocked, } - item, err := service.Get(machineID, sha256, types.Binary) + item, err := service.Get(machineID, identifier, types.Binary) assert.Empty(t, err) + assert.Equal(t, item.Identifier, "ed0a9ba83449b5966363e0c20fe7755defcb2d7136657d3880bb462a8d7a7025") assert.Equal(t, item.SHA256, "ed0a9ba83449b5966363e0c20fe7755defcb2d7136657d3880bb462a8d7a7025") assert.Equal(t, item.Policy, types.Allowlist) }) @@ -111,7 +115,7 @@ func Test_Service_Get(t *testing.T) { func Test_Service_Add_OK(t *testing.T) { t.Run("PutItem works with no errors", func(t *testing.T) { machineID := "858CBF28-5EAA-58A3-A155-BA5E81D5B5DD" - sha256 := "ed0a9ba83449b5966363e0c20fe7755defcb2d7136657d3880bb462a8d7a7025" + identifier := "ed0a9ba83449b5966363e0c20fe7755defcb2d7136657d3880bb462a8d7a7025" ruleType := types.Binary description := "Description" policy := types.AllowlistCompiler @@ -123,14 +127,14 @@ func Test_Service_Add_OK(t *testing.T) { mocked := &MockDynamodb{} mocked.On("PutItem", mock.MatchedBy(func(item interface{}) bool { rule := item.(MachineRuleRow) - return rule.Description == description && rule.Policy == policy && rule.SHA256 == sha256 + return rule.Description == description && rule.Policy == policy && rule.Identifier == identifier })).Return(&awsdynamodb.PutItemOutput{}, nil) service := ConcreteMachineRulesService{ dynamodb: mocked, } - err := service.Add(machineID, sha256, ruleType, policy, description, expires) + err := service.Add(machineID, identifier, ruleType, policy, description, expires) assert.Empty(t, err) mocked.AssertCalled(t, "PutItem", mock.Anything) }) diff --git a/pkg/model/rules/constants.go b/pkg/model/rules/constants.go index 1f0ce73..b925021 100644 --- a/pkg/model/rules/constants.go +++ b/pkg/model/rules/constants.go @@ -3,4 +3,6 @@ package rules const ( binaryRuleSKPrefix = "Binary#" certificateRuleSKPrefix = "Cert#" + teamIDRuleSKPrefix = "TeamID#" + signingIDRuleSKPrefix = "SigningID#" ) diff --git a/pkg/model/rules/primary_key.go b/pkg/model/rules/primary_key.go index 5033421..a02690c 100644 --- a/pkg/model/rules/primary_key.go +++ b/pkg/model/rules/primary_key.go @@ -7,17 +7,21 @@ import ( "github.com/airbnb/rudolph/pkg/types" ) +// @deprecated func RuleSortKeyFromTypeSHA(sha256 string, ruleType types.RuleType) string { - if len(sha256) != 64 { - log.Printf("error (recovered): invalid sha256: (%s)", sha256) - return "" - } + return RuleSortKeyFromTypeIdentifier(sha256, ruleType) +} +func RuleSortKeyFromTypeIdentifier(identifier string, ruleType types.RuleType) string { switch ruleType { case types.RuleTypeBinary: - return fmt.Sprintf("%s%s", binaryRuleSKPrefix, sha256) + return fmt.Sprintf("%s%s", binaryRuleSKPrefix, identifier) case types.RuleTypeCertificate: - return fmt.Sprintf("%s%s", certificateRuleSKPrefix, sha256) + return fmt.Sprintf("%s%s", certificateRuleSKPrefix, identifier) + case types.RuleTypeTeamID: + return fmt.Sprintf("%s%s", teamIDRuleSKPrefix, identifier) + case types.RuleTypeSigningID: + return fmt.Sprintf("%s%s", signingIDRuleSKPrefix, identifier) default: log.Printf("error (recovered): encountered unknown ruleType: (%+v)", ruleType) return "" diff --git a/pkg/model/rules/rule.go b/pkg/model/rules/rule.go index 825fae7..17e6b7d 100644 --- a/pkg/model/rules/rule.go +++ b/pkg/model/rules/rule.go @@ -11,8 +11,9 @@ import ( // - feed rules: A timeline-based feed of rule "diffs" that endpoints can download // - machine rules: Rules intended only to be deployed to specific endpoints type SantaRule struct { - RuleType types.RuleType `dynamodbav:"RuleType"` - Policy types.Policy `dynamodbav:"Policy"` - SHA256 string `dynamodbav:"SHA256"` - CustomMessage string `dynamodbav:"CustomMessage,omitempty"` + RuleType types.RuleType `dynamodbav:"RuleType" json:"rule_type"` + Policy types.Policy `dynamodbav:"Policy" json:"policy"` + SHA256 string `dynamodbav:"SHA256,omitempty" json:"sha256,omitempty"` // @deprecated - Use Identifier instead + Identifier string `dynamodbav:"Identifier" json:"identifier"` + CustomMessage string `dynamodbav:"CustomMessage,omitempty" json:"custom_msg,omitempty"` } diff --git a/pkg/model/syncstate/get_test.go b/pkg/model/syncstate/get_test.go index 9e11d6d..f04d09a 100644 --- a/pkg/model/syncstate/get_test.go +++ b/pkg/model/syncstate/get_test.go @@ -52,7 +52,6 @@ func Test_GetIntendedConfig(t *testing.T) { } for _, test := range cases { - dataTypeAV, _ := test.expectedDataType.MarshalDynamoDBAttributeValue() dynamodb := getSyncState( func(key dynamodb.PrimaryKey, consistentRead bool) (*awsdynamodb.GetItemOutput, error) { if test.dbError { @@ -66,7 +65,7 @@ func Test_GetIntendedConfig(t *testing.T) { Item: map[string]awstypes.AttributeValue{ "CleanSync": &awstypes.AttributeValueMemberBOOL{Value: false}, "BatchSize": &awstypes.AttributeValueMemberN{Value: "31"}, - "DataType": dataTypeAV, + "DataType": &awstypes.AttributeValueMemberS{Value: string(test.expectedDataType)}, }, }, nil } diff --git a/pkg/types/data_type.go b/pkg/types/data_type.go index 81a2a10..4ea2531 100644 --- a/pkg/types/data_type.go +++ b/pkg/types/data_type.go @@ -4,8 +4,6 @@ import ( "fmt" awstypes "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" ) // DataType identifies the current DynamoDB data model @@ -25,22 +23,32 @@ func (dt *DataType) UnmarshalText(text []byte) error { case "SENSOR_DATA": fallthrough case "SENSORDATA": + fallthrough + case "SensorData": *dt = DataTypeSensorData case "RULES_FEED": fallthrough case "RULESFEED": + fallthrough + case "RulesFeed": *dt = DataTypeRulesFeed case "SYNC_STATE": fallthrough case "SYNCSTATE": + fallthrough + case "SyncState": *dt = DataTypeSyncState case "MACHINE_CONFIG": fallthrough case "MACHINECONFIG": - *dt = DataTypeGlobalConfig + fallthrough + case "MachineConfig": + *dt = DataTypeMachineConfig case "GLOBAL_CONFIG": fallthrough case "GLOBALCONFIG": + fallthrough + case "GlobalConfig": *dt = DataTypeGlobalConfig default: return fmt.Errorf("unknown data_type value %q", mode) @@ -52,15 +60,15 @@ func (dt *DataType) UnmarshalText(text []byte) error { func (dt DataType) MarshalText() ([]byte, error) { switch dt { case DataTypeSensorData: - return []byte("SENSORDATA"), nil + return []byte("SensorData"), nil case DataTypeSyncState: - return []byte("SYNCSTATE"), nil + return []byte("SyncState"), nil case DataTypeMachineConfig: - return []byte("MACHINECONFIG"), nil + return []byte("MachineConfig"), nil case DataTypeGlobalConfig: - return []byte("GLOBALCONFIG"), nil + return []byte("GlobalConfig"), nil case DataTypeRulesFeed: - return []byte("RULESFEED"), nil + return []byte("RulesFeed"), nil default: return nil, fmt.Errorf("unknown data_type %s", dt) } @@ -87,26 +95,56 @@ func (dt DataType) MarshalDynamoDBAttributeValue() (awstypes.AttributeValue, err } // UnmarshalDynamoDBAttributeValue implements the Unmarshaler interface -func (dt *DataType) UnmarshalDynamoDBAttributeValue(av *dynamodb.AttributeValue) error { - switch t := aws.StringValue(av.N); t { +func (dt *DataType) UnmarshalDynamoDBAttributeValue(av awstypes.AttributeValue) error { + v, ok := av.(*awstypes.AttributeValueMemberS) + if !ok { + return fmt.Errorf("unexpected data_type value type: %T", av) + } + + switch t := v.Value; t { case "1": fallthrough case "SENSOR_DATA": fallthrough case "SENSORDATA": + fallthrough + case "SensorData": *dt = DataTypeSensorData case "2": fallthrough case "SYNC_STATE": fallthrough case "SYNCSTATE": + fallthrough + case "SyncState": *dt = DataTypeSyncState case "3": fallthrough + case "GLOBAL_CONFIG": + fallthrough + case "GLOBALCONFIG": + fallthrough + case "GlobalConfig": + *dt = DataTypeGlobalConfig + case "4": + fallthrough case "MACHINE_CONFIG": + fallthrough + case "MACHINECONFIG": + fallthrough + case "MachineConfig": *dt = DataTypeMachineConfig + case "5": + fallthrough + case "RULES_FEED": + fallthrough + case "RULESFEED": + fallthrough + case "RulesFeed": + *dt = DataTypeRulesFeed default: return fmt.Errorf("unknown data_type value %q", t) } + return nil } diff --git a/pkg/types/data_type_test.go b/pkg/types/data_type_test.go index 530855f..9ecae6e 100644 --- a/pkg/types/data_type_test.go +++ b/pkg/types/data_type_test.go @@ -1,171 +1,115 @@ package types import ( - "errors" "testing" awstypes "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" "github.com/stretchr/testify/assert" ) -// SensorData - Success Validation // -func TestTypes_DataType_Marshal_SensorData_Success(t *testing.T) { - dataType, err := DataType("SensorData").MarshalText() - assert.Empty(t, err) - assert.Equal(t, dataType, []byte("SENSORDATA")) -} - -func TestTypes_DataType_Unmarshal_DataTypeSensorData_Success(t *testing.T) { - dataType := DataTypeSensorData - err := dataType.UnmarshalText([]byte("SENSORDATA")) - assert.Empty(t, err) -} - -func TestTypes_DataType_Unmarshal_DataTypeSENSOR_STATE_Success(t *testing.T) { - dataType := DataTypeSensorData - err := dataType.UnmarshalText([]byte("SENSOR_DATA")) - assert.Empty(t, err) -} - -// SensorData - Failure Validation // -func TestTypes_DataType_Marshal_SensorData_Failure(t *testing.T) { - _, err := DataType("SensorDatas").MarshalText() - assert.NotEmpty(t, err) -} - -func TestTypes_DataType_Unmarshal_DataTypeSensorData_Failure(t *testing.T) { - dataType := DataTypeSensorData - err := dataType.UnmarshalText([]byte("SENSORDATAS")) - assert.NotEmpty(t, err) -} - -// SyncState - Success Validation // -func TestTypes_DataType_Marshal_SyncState_Success(t *testing.T) { - dataType, err := DataType("SyncState").MarshalText() - assert.Empty(t, err) - assert.Equal(t, dataType, []byte("SYNCSTATE")) -} - -func TestTypes_DataType_Unmarshal_SyncState_Success(t *testing.T) { - dataType := DataTypeSyncState - err := dataType.UnmarshalText([]byte("SYNCSTATE")) - assert.Empty(t, err) -} - -func TestTypes_DataType_Unmarshal_SYNC_STATE_Success(t *testing.T) { - dataType := DataTypeSyncState - err := dataType.UnmarshalText([]byte("SYNC_STATE")) - assert.Empty(t, err) -} - -// RulesFeed - Success Validation // -func TestTypes_DataType_Marshal_RulesFeed_Success(t *testing.T) { - dataType, err := DataType("RulesFeed").MarshalText() - assert.Empty(t, err) - assert.Equal(t, dataType, []byte("RULESFEED")) -} - -func TestTypes_DataType_Unmarshal_RulesFeed_Success(t *testing.T) { - dataType := DataTypeRulesFeed - err := dataType.UnmarshalText([]byte("RULESFEED")) - assert.Empty(t, err) -} - -func TestTypes_DataType_Unmarshal_RULES_FEED_Success(t *testing.T) { - dataType := DataTypeRulesFeed - err := dataType.UnmarshalText([]byte("RULES_FEED")) - assert.Empty(t, err) -} - -func TestTypes_DataType_Marshal_SyncState_Failure(t *testing.T) { - _, err := DataType("SYNCSTATE").MarshalText() - assert.NotEmpty(t, err) -} - -// RulesFeed - Failure Validation // -func TestTypes_DataType_Unmarshal_RulesFeed_Failure(t *testing.T) { - dataType := DataTypeRulesFeed - err := dataType.UnmarshalText([]byte("RULESFEEDS")) - assert.NotEmpty(t, err) -} - -func TestTypes_DataType_Marshal_RulesFeed_Failure(t *testing.T) { - _, err := DataType("RULESFEED").MarshalText() - assert.NotEmpty(t, err) -} - -// SyncState - Failure Validation // -func TestTypes_DataType_Unmarshal_SyncState_Failure(t *testing.T) { - dataType := DataTypeSensorData - err := dataType.UnmarshalText([]byte("SYNCSTATESS")) - assert.NotEmpty(t, err) -} - -// MachineConfig - Success Validation // -func TestTypes_DataType_Marshal_MachineConfig_Success(t *testing.T) { - dataType, err := DataType("MachineConfig").MarshalText() - assert.Empty(t, err) - assert.Equal(t, dataType, []byte("MACHINECONFIG")) -} - -func TestTypes_DataType_Unmarshal_MachineConfig_Success(t *testing.T) { - dataType := DataTypeMachineConfig - err := dataType.UnmarshalText([]byte("MACHINECONFIG")) - assert.Empty(t, err) -} - -func TestTypes_DataType_Unmarshal_MACHINE_CONFIG_Success(t *testing.T) { - dataType := DataTypeMachineConfig - err := dataType.UnmarshalText([]byte("MACHINE_CONFIG")) - assert.Empty(t, err) -} - -// SyncState - Failure Validation // -func TestTypes_DataType_Marshal_MachineConfig_Failure(t *testing.T) { - _, err := DataType("MACHINECONFIGS").MarshalText() - assert.NotEmpty(t, err) -} - -func TestTypes_DataType_Unmarshal_MachineConfig_Failure(t *testing.T) { - dataType := DataTypeSensorData - err := dataType.UnmarshalText([]byte("MACHINECONFIGS")) - assert.NotEmpty(t, err) -} - -// MarshalDynamoDBAttributeValue - Success Validation // -func TestTypes_DataType_MarshalDynamoDBAttributeValue_SensorState_Success(t *testing.T) { - var dataType DataType = DataTypeSensorData - - av, err := dataType.MarshalDynamoDBAttributeValue() - expectedAV := &awstypes.AttributeValueMemberS{Value: "SensorData"} - assert.Empty(t, err) - assert.Equal(t, av, expectedAV) -} - -func TestTypes_DataType_MarshalDynamoDBAttributeValue_SyncState_Success(t *testing.T) { - var dataType DataType = DataTypeSyncState - - av, err := dataType.MarshalDynamoDBAttributeValue() - expectedAV := &awstypes.AttributeValueMemberS{Value: "SyncState"} - assert.Empty(t, err) - assert.Equal(t, av, expectedAV) -} - -func TestTypes_DataType_MarshalDynamoDBAttributeValue_MachineConfig_Success(t *testing.T) { - var dataType DataType = DataTypeMachineConfig - - av, err := dataType.MarshalDynamoDBAttributeValue() - expectedAV := &awstypes.AttributeValueMemberS{Value: "MachineConfig"} - assert.Empty(t, err) - assert.Equal(t, av, expectedAV) -} - -// MarshalDynamoDBAttributeValue - Failure Validation // -func TestTypes_DataType_MarshalDynamoDBAttributeValue_MachineConfigs_Failure(t *testing.T) { - var dataType DataType = "MachineConfigs" - - _, err := dataType.MarshalDynamoDBAttributeValue() - assert.NotEmpty(t, err) - expectedErr := errors.New(`unknown data_type value "MachineConfigs"`) - assert.Equal(t, err.Error(), expectedErr.Error()) +func TestDataTypes_MarshalTest(t *testing.T) { + tests := []struct { + name string + dataType DataType + want []byte + wantErr bool + }{ + {"SensorData", DataTypeSensorData, []byte(DataTypeSensorData), false}, + {"SyncState", DataTypeSyncState, []byte(DataTypeSyncState), false}, + {"RulesFeed", DataTypeRulesFeed, []byte(DataTypeRulesFeed), false}, + {"MachineConfig", DataTypeMachineConfig, []byte(DataTypeMachineConfig), false}, + {"GlobalConfig", DataTypeGlobalConfig, []byte(DataTypeGlobalConfig), false}, + {"MISSPELLED", DataType(""), []byte(nil), true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.dataType.MarshalText() + if (err != nil) != tt.wantErr { + t.Errorf("DataType.MarshalText() error = %v, wantErr %v", err, tt.wantErr) + return + } + assert.Equal(t, tt.want, got) + }) + } +} + +func TestDataTypes_UnmarshalText(t *testing.T) { + tests := []struct { + name string + text []byte + want DataType + wantErr bool + }{ + {"SensorData", []byte(DataTypeSensorData), DataTypeSensorData, false}, + {"SyncState", []byte(DataTypeSyncState), DataTypeSyncState, false}, + {"RulesFeed", []byte(DataTypeRulesFeed), DataTypeRulesFeed, false}, + {"MachineConfig", []byte(DataTypeMachineConfig), DataTypeMachineConfig, false}, + {"GlobalConfig", []byte(DataTypeGlobalConfig), DataTypeGlobalConfig, false}, + {"MISSPELLED", []byte(""), DataType(""), true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var dt DataType + err := dt.UnmarshalText(tt.text) + if (err != nil) != tt.wantErr { + t.Errorf("DataType.UnmarshalText() error = %v, wantErr %v", err, tt.wantErr) + return + } + assert.Equal(t, tt.want, dt) + }) + } +} + +func TestDataTypes_MarshalDynamoDBAttributeValue(t *testing.T) { + tests := []struct { + name string + dataType DataType + want awstypes.AttributeValue + wantErr bool + }{ + {"SensorData", DataTypeSensorData, &awstypes.AttributeValueMemberS{Value: string(DataTypeSensorData)}, false}, + {"SyncState", DataTypeSyncState, &awstypes.AttributeValueMemberS{Value: string(DataTypeSyncState)}, false}, + {"RulesFeed", DataTypeRulesFeed, &awstypes.AttributeValueMemberS{Value: string(DataTypeRulesFeed)}, false}, + {"MachineConfig", DataTypeMachineConfig, &awstypes.AttributeValueMemberS{Value: string(DataTypeMachineConfig)}, false}, + {"GlobalConfig", DataTypeGlobalConfig, &awstypes.AttributeValueMemberS{Value: string(DataTypeGlobalConfig)}, false}, + {"MISSPELLED", DataType(""), nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.dataType.MarshalDynamoDBAttributeValue() + if (err != nil) != tt.wantErr { + t.Errorf("DataType.MarshalDynamoDBAttributeValue() error = %v, wantErr %v", err, tt.wantErr) + return + } + assert.Equal(t, tt.want, got) + }) + } +} + +func TestDataTypes_UnmarshalDynamoDBAttributeValue(t *testing.T) { + tests := []struct { + name string + av awstypes.AttributeValue + want DataType + wantErr bool + }{ + {"SensorData", &awstypes.AttributeValueMemberS{Value: string(DataTypeSensorData)}, DataTypeSensorData, false}, + {"SyncState", &awstypes.AttributeValueMemberS{Value: string(DataTypeSyncState)}, DataTypeSyncState, false}, + {"RulesFeed", &awstypes.AttributeValueMemberS{Value: string(DataTypeRulesFeed)}, DataTypeRulesFeed, false}, + {"MachineConfig", &awstypes.AttributeValueMemberS{Value: string(DataTypeMachineConfig)}, DataTypeMachineConfig, false}, + {"GlobalConfig", &awstypes.AttributeValueMemberS{Value: string(DataTypeGlobalConfig)}, DataTypeGlobalConfig, false}, + {"MISSPELLED", nil, DataType(""), true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := DataType("") + err := got.UnmarshalDynamoDBAttributeValue(tt.av) + if (err != nil) != tt.wantErr { + t.Errorf("DataType.UnmarshalDynamoDBAttributeValue() error = %v, wantErr %v", err, tt.wantErr) + return + } + assert.Equal(t, tt.want, got) + }) + } } diff --git a/pkg/types/policy.go b/pkg/types/policy.go index 6d54dba..ef13633 100644 --- a/pkg/types/policy.go +++ b/pkg/types/policy.go @@ -3,8 +3,7 @@ package types import ( "fmt" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" + awstypes "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) // Policy represents the Santa Rule Policy. @@ -79,7 +78,7 @@ func (p Policy) MarshalText() ([]byte, error) { } // MarshalDynamoDBAttributeValue for ddb -func (p Policy) MarshalDynamoDBAttributeValue(av *dynamodb.AttributeValue) error { +func (p Policy) MarshalDynamoDBAttributeValue() (awstypes.AttributeValue, error) { var s string switch p { case RulePolicyAllowlist: @@ -95,39 +94,32 @@ func (p Policy) MarshalDynamoDBAttributeValue(av *dynamodb.AttributeValue) error case RulePolicyAllowlistTransitive: s = "6" default: - return fmt.Errorf("unknown policy value %q", p) + return nil, fmt.Errorf("unknown policy value %q", p) } - // av.S = &s - av.N = &s - return nil + return &awstypes.AttributeValueMemberN{Value: s}, nil } // UnmarshalDynamoDBAttributeValue implements the Unmarshaler interface -func (p *Policy) UnmarshalDynamoDBAttributeValue(av *dynamodb.AttributeValue) error { - switch t := aws.StringValue(av.N); t { + +func (p *Policy) UnmarshalDynamoDBAttributeValue(av awstypes.AttributeValue) error { + // return attributevalue.Unmarshal(av, p) + v, ok := av.(*awstypes.AttributeValueMemberN) + if !ok { + return fmt.Errorf("unexpected policy value type %T", av) + } + + switch t := v.Value; t { case "1": - fallthrough - case "ALLOWLIST": *p = RulePolicyAllowlist case "2": - fallthrough - case "BLOCKLIST": *p = RulePolicyBlocklist case "3": - fallthrough - case "SILENT_BLOCKLIST": *p = RulePolicySilentBlocklist case "4": - fallthrough - case "REMOVE": *p = RulePolicyRemove case "5": - fallthrough - case "ALLOWLIST_COMPILER": *p = RulePolicyAllowlistCompiler case "6": - fallthrough - case "ALLOWLIST_TRANSITIVE": *p = RulePolicyAllowlistTransitive default: return fmt.Errorf("unknown policy value %q", t) diff --git a/pkg/types/policy_test.go b/pkg/types/policy_test.go index d0225b7..6111830 100644 --- a/pkg/types/policy_test.go +++ b/pkg/types/policy_test.go @@ -3,7 +3,7 @@ package types import ( "testing" - "github.com/aws/aws-sdk-go/service/dynamodb" + awstypes "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" "github.com/stretchr/testify/assert" ) @@ -68,22 +68,21 @@ func TestPolicyTypes_MarshalDynamoDBAttributeValue(t *testing.T) { tests := []struct { name string policy Policy - want *dynamodb.AttributeValue + want awstypes.AttributeValue wantErr bool }{ - {"ALLOWLIST", RulePolicyAllowlist, new(dynamodb.AttributeValue).SetN("1"), false}, - {"BLOCKLIST", RulePolicyBlocklist, new(dynamodb.AttributeValue).SetN("2"), false}, - {"SILENT_BLOCKLIST", RulePolicySilentBlocklist, new(dynamodb.AttributeValue).SetN("3"), false}, - {"REMOVE", RulePolicyRemove, new(dynamodb.AttributeValue).SetN("4"), false}, - {"ALLOWLIST_COMPILER", RulePolicyAllowlistCompiler, new(dynamodb.AttributeValue).SetN("5"), false}, - {"ALLOWLIST_TRANSITIVE", RulePolicyAllowlistTransitive, new(dynamodb.AttributeValue).SetN("6"), false}, - {"MISSPELLED", Policy(0), new(dynamodb.AttributeValue), true}, + {"ALLOWLIST", RulePolicyAllowlist, &awstypes.AttributeValueMemberN{Value: "1"}, false}, + {"BLOCKLIST", RulePolicyBlocklist, &awstypes.AttributeValueMemberN{Value: "2"}, false}, + {"SILENT_BLOCKLIST", RulePolicySilentBlocklist, &awstypes.AttributeValueMemberN{Value: "3"}, false}, + {"REMOVE", RulePolicyRemove, &awstypes.AttributeValueMemberN{Value: "4"}, false}, + {"ALLOWLIST_COMPILER", RulePolicyAllowlistCompiler, &awstypes.AttributeValueMemberN{Value: "5"}, false}, + {"ALLOWLIST_TRANSITIVE", RulePolicyAllowlistTransitive, &awstypes.AttributeValueMemberN{Value: "6"}, false}, + {"MISSPELLED", Policy(0), nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - av := &dynamodb.AttributeValue{} - err := tt.policy.MarshalDynamoDBAttributeValue(av) + av, err := tt.policy.MarshalDynamoDBAttributeValue() if (err != nil) != tt.wantErr { t.Errorf("Policy.MarshalDynamoDBAttributeValue() error = %v, wantErr %v", err, tt.wantErr) return @@ -96,17 +95,17 @@ func TestPolicyTypes_MarshalDynamoDBAttributeValue(t *testing.T) { func TestPolicyType_UnmarshalDynamoDBAttributeValue(t *testing.T) { tests := []struct { name string - av *dynamodb.AttributeValue + av awstypes.AttributeValue want Policy wantErr bool }{ - {"ALLOWLIST", new(dynamodb.AttributeValue).SetN("1"), RulePolicyAllowlist, false}, - {"BLOCKLIST", new(dynamodb.AttributeValue).SetN("2"), RulePolicyBlocklist, false}, - {"SILENT_BLOCKLIST", new(dynamodb.AttributeValue).SetN("3"), RulePolicySilentBlocklist, false}, - {"REMOVE", new(dynamodb.AttributeValue).SetN("4"), RulePolicyRemove, false}, - {"ALLOWLIST_COMPILER", new(dynamodb.AttributeValue).SetN("5"), RulePolicyAllowlistCompiler, false}, - {"ALLOWLIST_TRANSITIVE", new(dynamodb.AttributeValue).SetN("6"), RulePolicyAllowlistTransitive, false}, - {"MISSPELLED", new(dynamodb.AttributeValue), Policy(0), true}, + {"ALLOWLIST", &awstypes.AttributeValueMemberN{Value: "1"}, RulePolicyAllowlist, false}, + {"BLOCKLIST", &awstypes.AttributeValueMemberN{Value: "2"}, RulePolicyBlocklist, false}, + {"SILENT_BLOCKLIST", &awstypes.AttributeValueMemberN{Value: "3"}, RulePolicySilentBlocklist, false}, + {"REMOVE", &awstypes.AttributeValueMemberN{Value: "4"}, RulePolicyRemove, false}, + {"ALLOWLIST_COMPILER", &awstypes.AttributeValueMemberN{Value: "5"}, RulePolicyAllowlistCompiler, false}, + {"ALLOWLIST_TRANSITIVE", &awstypes.AttributeValueMemberN{Value: "6"}, RulePolicyAllowlistTransitive, false}, + {"MISSPELLED", nil, Policy(0), true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -119,5 +118,4 @@ func TestPolicyType_UnmarshalDynamoDBAttributeValue(t *testing.T) { assert.Equal(t, tt.want, got) }) } - } diff --git a/pkg/types/rule_type.go b/pkg/types/rule_type.go index bbce172..a4a514e 100644 --- a/pkg/types/rule_type.go +++ b/pkg/types/rule_type.go @@ -3,8 +3,7 @@ package types import ( "fmt" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" + awstypes "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) // RuleType represents a Santa rule type. @@ -77,31 +76,32 @@ func (r RuleType) MarshalText() ([]byte, error) { } // MarshalDynamoDBAttributeValue for ddb -func (r RuleType) MarshalDynamoDBAttributeValue(av *dynamodb.AttributeValue) error { +func (r RuleType) MarshalDynamoDBAttributeValue() (awstypes.AttributeValue, error) { var s string switch r { case RuleTypeBinary: - // s = "BINARY" s = "1" case RuleTypeCertificate: - // s = "CERTIFICATE" s = "2" case RuleTypeSigningID: s = "3" case RuleTypeTeamID: s = "4" default: - return fmt.Errorf("unknown rule_type value %q", r) + return nil, fmt.Errorf("unknown rule_type value %q", r) } - // av.S = &s - av.N = &s - return nil + return &awstypes.AttributeValueMemberN{Value: s}, nil } // UnmarshalDynamoDBAttributeValue implements the Unmarshaler interface -func (r *RuleType) UnmarshalDynamoDBAttributeValue(av *dynamodb.AttributeValue) error { - // switch t := aws.StringValue(av.S); t { - switch t := aws.StringValue(av.N); t { +func (r *RuleType) UnmarshalDynamoDBAttributeValue(av awstypes.AttributeValue) error { + // return attributevalue.Unmarshal(av, p) + v, ok := av.(*awstypes.AttributeValueMemberN) + if !ok { + return fmt.Errorf("unexpected rule_type value type %T", av) + } + + switch t := v.Value; t { case "1": fallthrough case "BINARY": diff --git a/pkg/types/rule_type_test.go b/pkg/types/rule_type_test.go index b6d4495..260ec64 100644 --- a/pkg/types/rule_type_test.go +++ b/pkg/types/rule_type_test.go @@ -3,7 +3,7 @@ package types import ( "testing" - "github.com/aws/aws-sdk-go/service/dynamodb" + awstypes "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" "github.com/stretchr/testify/assert" ) @@ -62,22 +62,21 @@ func TestRuleType_UnmarshalText(t *testing.T) { func TestRuleType_MarshalDynamoDBAttributeValue(t *testing.T) { tests := []struct { - name string - rule RuleType - want *dynamodb.AttributeValue - wantErr bool + name string + ruleType RuleType + want awstypes.AttributeValue + wantErr bool }{ - {"BINARY", RuleTypeBinary, new(dynamodb.AttributeValue).SetN("1"), false}, - {"CERTIFICATE", RuleTypeCertificate, new(dynamodb.AttributeValue).SetN("2"), false}, - {"SIGNINGID", RuleTypeSigningID, new(dynamodb.AttributeValue).SetN("3"), false}, - {"TEAMID", RuleTypeTeamID, new(dynamodb.AttributeValue).SetN("4"), false}, - {"INVALID", RuleType(0), new(dynamodb.AttributeValue), true}, + {"BINARY", RuleTypeBinary, &awstypes.AttributeValueMemberN{Value: "1"}, false}, + {"CERTIFICATE", RuleTypeCertificate, &awstypes.AttributeValueMemberN{Value: "2"}, false}, + {"SIGNINGID", RuleTypeSigningID, &awstypes.AttributeValueMemberN{Value: "3"}, false}, + {"TEAMID", RuleTypeTeamID, &awstypes.AttributeValueMemberN{Value: "4"}, false}, + {"INVALID", RuleType(0), nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - av := &dynamodb.AttributeValue{} - err := tt.rule.MarshalDynamoDBAttributeValue(av) + av, err := tt.ruleType.MarshalDynamoDBAttributeValue() if (err != nil) != tt.wantErr { t.Errorf("RuleType.MarshalDynamoDBAttributeValue() error = %v, wantErr %v", err, tt.wantErr) return @@ -90,15 +89,15 @@ func TestRuleType_MarshalDynamoDBAttributeValue(t *testing.T) { func TestRuleType_UnmarshalDynamoDBAttributeValue(t *testing.T) { tests := []struct { name string - av *dynamodb.AttributeValue + av awstypes.AttributeValue want RuleType wantErr bool }{ - {"BINARY", new(dynamodb.AttributeValue).SetN("1"), RuleTypeBinary, false}, - {"CERTIFICATE", new(dynamodb.AttributeValue).SetN("2"), RuleTypeCertificate, false}, - {"SIGNINGID", new(dynamodb.AttributeValue).SetN("3"), RuleTypeSigningID, false}, - {"TEAMID", new(dynamodb.AttributeValue).SetN("4"), RuleTypeTeamID, false}, - {"INVALID", new(dynamodb.AttributeValue).SetN("0"), RuleType(0), true}, + {"BINARY", &awstypes.AttributeValueMemberN{Value: "1"}, RuleTypeBinary, false}, + {"CERTIFICATE", &awstypes.AttributeValueMemberN{Value: "2"}, RuleTypeCertificate, false}, + {"SIGNINGID", &awstypes.AttributeValueMemberN{Value: "3"}, RuleTypeSigningID, false}, + {"TEAMID", &awstypes.AttributeValueMemberN{Value: "4"}, RuleTypeTeamID, false}, + {"INVALID", nil, RuleType(0), true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {