Skip to content

Commit

Permalink
Replace decisionListsMutex with RWMutex
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremy5189 committed Sep 24, 2024
1 parent 4115911 commit 1ab030d
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 36 deletions.
2 changes: 1 addition & 1 deletion banjax.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func main() {
var passwordProtectedPaths internal.PasswordProtectedPaths

// XXX protects decisionLists
var decisionListsMutex sync.Mutex
var decisionListsMutex sync.RWMutex
var decisionLists internal.DecisionLists

standaloneTestingPtr := flag.Bool("standalone-testing", false, "makes it easy to test standalone")
Expand Down
14 changes: 7 additions & 7 deletions internal/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ func checkExpiringDecisionListsByDomain(domain string, decisionLists *DecisionLi

// XXX mmm could hold the lock for a while?
func RemoveExpiredDecisions(
decisionListsMutex *sync.Mutex,
decisionListsMutex *sync.RWMutex,
decisionLists *DecisionLists,
) {
decisionListsMutex.Lock()
Expand All @@ -466,7 +466,7 @@ func RemoveExpiredDecisions(
}

func removeExpiredDecisionsByIp(
decisionListsMutex *sync.Mutex,
decisionListsMutex *sync.RWMutex,
decisionLists *DecisionLists,
ip string,
) {
Expand All @@ -480,7 +480,7 @@ func removeExpiredDecisionsByIp(
func updateExpiringDecisionLists(
config *Config,
ip string,
decisionListsMutex *sync.Mutex,
decisionListsMutex *sync.RWMutex,
decisionLists *DecisionLists,
expires time.Time,
newDecision Decision,
Expand Down Expand Up @@ -514,7 +514,7 @@ func updateExpiringDecisionListsSessionId(
config *Config,
ip string,
sessionId string,
decisionListsMutex *sync.Mutex,
decisionListsMutex *sync.RWMutex,
decisionLists *DecisionLists,
expires time.Time,
newDecision Decision,
Expand Down Expand Up @@ -550,14 +550,14 @@ type MetricsLogLine struct {

func WriteMetricsToEncoder(
metricsLogEncoder *json.Encoder,
decisionListsMutex *sync.Mutex,
decisionListsMutex *sync.RWMutex,
decisionLists *DecisionLists,
rateLimitMutex *sync.Mutex,
ipToRegexStates *IpToRegexStates,
failedChallengeStates *FailedChallengeStates,
) {
decisionListsMutex.Lock()
defer decisionListsMutex.Unlock()
decisionListsMutex.RLock()
defer decisionListsMutex.RUnlock()

lenExpiringChallenges := 0
lenExpiringBlocks := 0
Expand Down
30 changes: 15 additions & 15 deletions internal/http_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ const (

func RunHttpServer(
config *Config,
decisionListsMutex *sync.Mutex,
decisionListsMutex *sync.RWMutex,
decisionLists *DecisionLists,
passwordProtectedPaths *PasswordProtectedPaths,
rateLimitMutex *sync.Mutex,
Expand Down Expand Up @@ -487,7 +487,7 @@ func tooManyFailedChallenges(
rateLimitMutex *sync.Mutex,
failedChallengeStates *FailedChallengeStates,
method string,
decisionListsMutex *sync.Mutex,
decisionListsMutex *sync.RWMutex,
decisionLists *DecisionLists,
) (tooManyFailedChallengesResult TooManyFailedChallengesResult) {
rateLimitMutex.Lock()
Expand Down Expand Up @@ -591,7 +591,7 @@ func sendOrValidateShaChallenge(
rateLimitMutex *sync.Mutex,
failedChallengeStates *FailedChallengeStates,
failAction FailAction,
decisionListsMutex *sync.Mutex,
decisionListsMutex *sync.RWMutex,
decisionLists *DecisionLists,
) (sendOrValidateShaChallengeResult SendOrValidateShaChallengeResult) {
clientIp := c.Request.Header.Get("X-Client-IP")
Expand Down Expand Up @@ -693,7 +693,7 @@ func sendOrValidatePassword(
banner BannerInterface,
rateLimitMutex *sync.Mutex,
failedChallengeStates *FailedChallengeStates,
decisionListsMutex *sync.Mutex,
decisionListsMutex *sync.RWMutex,
decisionLists *DecisionLists,
) (sendOrValidatePasswordResult SendOrValidatePasswordResult) {
clientIp := c.Request.Header.Get("X-Client-IP")
Expand Down Expand Up @@ -833,7 +833,7 @@ type DecisionForNginxResult struct {

func decisionForNginx(
config *Config,
decisionListsMutex *sync.Mutex,
decisionListsMutex *sync.RWMutex,
decisionLists *DecisionLists,
passwordProtectedPaths *PasswordProtectedPaths,
rateLimitMutex *sync.Mutex,
Expand Down Expand Up @@ -868,16 +868,16 @@ func decisionForNginx(

func checkPerSiteDecisionLists(
config *Config,
decisionListsMutex *sync.Mutex,
decisionListsMutex *sync.RWMutex,
decisionLists *DecisionLists,
requestedHost string,
clientIp string,
) (bool, Decision) {
// XXX ugh this locking is awful
// i got bit by just checking against the zero value here, which is a valid iota enum
decisionListsMutex.Lock()
decisionListsMutex.RLock()
decision, ok := (*decisionLists).PerSiteDecisionLists[requestedHost][clientIp]
decisionListsMutex.Unlock()
decisionListsMutex.RUnlock()

// found as plain IP form, no need to check IPFilter
if ok {
Expand Down Expand Up @@ -907,7 +907,7 @@ func checkPerSiteDecisionLists(
func decisionForNginx2(
c *gin.Context,
config *Config,
decisionListsMutex *sync.Mutex,
decisionListsMutex *sync.RWMutex,
decisionLists *DecisionLists,
passwordProtectedPaths *PasswordProtectedPaths,
rateLimitMutex *sync.Mutex,
Expand Down Expand Up @@ -1020,9 +1020,9 @@ func decisionForNginx2(
}
}

decisionListsMutex.Lock()
decisionListsMutex.RLock()
decision, ok = (*decisionLists).GlobalDecisionLists[clientIp]
decisionListsMutex.Unlock()
decisionListsMutex.RUnlock()
foundInIpFilter := false
if !ok {
for _, iterateDecision := range []Decision{Allow, Challenge, NginxBlock, IptablesBlock} {
Expand Down Expand Up @@ -1074,9 +1074,9 @@ func decisionForNginx2(
// when we insert something into the list, really we might just be extending the expiry time and/or
// changing the decision.
// XXX i forget if that comment is stale^
decisionListsMutex.Lock()
decisionListsMutex.RLock()
expiringDecision, ok := checkExpiringDecisionLists(c, clientIp, decisionLists)
decisionListsMutex.Unlock()
decisionListsMutex.RUnlock()
if !ok {
// log.Println("no mention in expiring lists")
} else {
Expand Down Expand Up @@ -1118,9 +1118,9 @@ func decisionForNginx2(

// the legacy banjax_sha_inv and user_banjax_sha_inv
// difference is one blocks after many failures and the other doesn't
decisionListsMutex.Lock()
decisionListsMutex.RLock()
failAction, ok := (*decisionLists).SitewideShaInvList[requestedHost]
decisionListsMutex.Unlock()
decisionListsMutex.RUnlock()
if !ok {
// log.Println("no mention in sitewide list")
} else {
Expand Down
2 changes: 1 addition & 1 deletion internal/iptables.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ type BannerInterface interface {
}

type Banner struct {
DecisionListsMutex *sync.Mutex
DecisionListsMutex *sync.RWMutex
DecisionLists *DecisionLists
Logger *log.Logger
LoggerTemp *log.Logger
Expand Down
8 changes: 4 additions & 4 deletions internal/kafka.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func getDialer(config *Config) *kafka.Dialer {

func RunKafkaReader(
config *Config,
decisionListsMutex *sync.Mutex,
decisionListsMutex *sync.RWMutex,
decisionLists *DecisionLists,
wg *sync.WaitGroup,
) {
Expand Down Expand Up @@ -158,7 +158,7 @@ func getBlockSessionTtl(config *Config, host string) (blockSessionTtl int) {
func handleCommand(
config *Config,
command commandMessage,
decisionListsMutex *sync.Mutex,
decisionListsMutex *sync.RWMutex,
decisionLists *DecisionLists,
) {
// exempt a site from baskerville according to config
Expand Down Expand Up @@ -191,7 +191,7 @@ func handleCommand(
func handleIPCommand(
config *Config,
command commandMessage,
decisionListsMutex *sync.Mutex,
decisionListsMutex *sync.RWMutex,
decisionLists *DecisionLists,
decision Decision,
expireDuration int,
Expand Down Expand Up @@ -219,7 +219,7 @@ func handleIPCommand(
func handleSessionCommand(
config *Config,
command commandMessage,
decisionListsMutex *sync.Mutex,
decisionListsMutex *sync.RWMutex,
decisionLists *DecisionLists,
decision Decision,
expireDuration int,
Expand Down
16 changes: 8 additions & 8 deletions internal/regex_rate_limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func RunLogTailer(
banner BannerInterface,
rateLimitMutex *sync.Mutex,
ipToRegexStates *IpToRegexStates,
decisionListsMutex *sync.Mutex,
decisionListsMutex *sync.RWMutex,
decisionLists *DecisionLists,
wg *sync.WaitGroup,
) {
Expand Down Expand Up @@ -120,12 +120,12 @@ func parseTimestamp(timeIpRest []string) (timestamp time.Time, err error) {

func checkIpInGlobalDecisionList(
ipString string,
decisionListsMutex *sync.Mutex,
decisionListsMutex *sync.RWMutex,
decisionLists *DecisionLists,
)(bool) {
// Check if IP is in the global allow list that should be skipped
decisionListsMutex.Lock()
defer decisionListsMutex.Unlock()
decisionListsMutex.RLock()
defer decisionListsMutex.RUnlock()

decision, ok := (*decisionLists).GlobalDecisionLists[ipString]
if (ok && decision == Allow) {
Expand All @@ -146,11 +146,11 @@ func checkIpInGlobalDecisionList(
func checkIpInPerSiteDecisionList(
urlString string,
ipString string,
decisionListsMutex *sync.Mutex,
decisionListsMutex *sync.RWMutex,
decisionLists *DecisionLists,
) (bool) {
decisionListsMutex.Lock()
defer decisionListsMutex.Unlock()
decisionListsMutex.RLock()
defer decisionListsMutex.RUnlock()

decision, ok := (*decisionLists).PerSiteDecisionLists[urlString][ipString]
if (ok && decision == Allow) {
Expand Down Expand Up @@ -181,7 +181,7 @@ func consumeLine(
ipToRegexStates *IpToRegexStates,
banner BannerInterface,
config *Config,
decisionListsMutex *sync.Mutex,
decisionListsMutex *sync.RWMutex,
decisionLists *DecisionLists,
) (consumeLineResult ConsumeLineResult) {

Expand Down

0 comments on commit 1ab030d

Please sign in to comment.