Skip to content

Commit

Permalink
Rawlog fix. (#107)
Browse files Browse the repository at this point in the history
* Dynamic threshold. We use a more aggressive threshold if an incident is being detected. (#100)

* Fix the labeler: save all columng of  request_sets (#101)

* JAVA sdk downgraded to fix s3 issue. Incident detector null fix. Optional low rate attack detection. (#102)

* Rawlog fix. Attack detection fix(F.lit). (#106)

Co-authored-by: Maria Karanasou <[email protected]>
  • Loading branch information
mazhurin and mkaranasou authored Jan 21, 2022
1 parent e48390c commit 9659f4b
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 40 deletions.
2 changes: 1 addition & 1 deletion src/baskerville/models/base_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def get_window(self):
# self.spark_conf.storage_level
# )
if not window_df.rdd.isEmpty():
print(f'# Request sets = {window_df.count()}')
self.logger.info(f'# Request sets = {window_df.count()}')
yield window_df
else:
self.logger.info(f'Empty window df for {str(filter_._jc)}')
Expand Down
61 changes: 25 additions & 36 deletions src/baskerville/models/pipeline_tasks/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ def process_data(self):
self.logger.info('No data in to process.')
else:
for window_df in get_window(
self.df, self.time_bucket, self.config.spark.storage_level
df_original, self.time_bucket, self.config.spark.storage_level, self.logger
):
self.df = window_df.repartition(
*self.group_by_cols
Expand Down Expand Up @@ -1245,7 +1245,6 @@ def upsert_feedback_context(self):
def prepare_to_save(self):
try:
success = self.upsert_feedback_context()
self.df.show()
success = True
if success:
# explode submitted feedback first
Expand All @@ -1257,7 +1256,6 @@ def prepare_to_save(self):
F.col('id_fc').alias('sumbitted_context_id'),
F.explode('feedback').alias('feedback')
).cache()
self.df.show()
self.df = self.df.select(
F.col('uuid_organization').alias('top_uuid_organization'),
F.col('id_context').alias('client_id_context'),
Expand Down Expand Up @@ -1771,22 +1769,34 @@ def label_with_id_and_set(metric, self, return_value):

def classify_anomalies(self):
self.logger.info('Anomaly thresholding...')
hosts = self.incident_detector.get_hosts_with_incidents() if self.incident_detector else []
if self.incident_detector:
self.logger.info('Getting hosts with incidents...')
hosts = self.incident_detector.get_hosts_with_incidents()
else:
hosts = []

self.df = self.df.withColumn('threshold',
self.logger.info(f'Number of hosts under attack {len(hosts)}.')

self.df = self.df.withColumn('attack_prediction',
F.when(F.col('target').isin(hosts),
self.config.engine.anomaly_threshold_during_incident).otherwise(
self.config.engine.anomaly_threshold))
F.lit(1)).otherwise(F.lit(0)))

self.logger.info(f'Dynamic thresholds calculation...')
self.df = self.df.withColumn('threshold',
F.when(F.col('target').isin(hosts),
F.lit(self.config.engine.anomaly_threshold_during_incident)).otherwise(
F.lit(self.config.engine.anomaly_threshold)))
self.logger.info(f'Dynamic thresholding...')
self.df = self.df.withColumn(
'prediction',
F.when(F.col('score') > F.col('threshold'), F.lit(1.0)).otherwise(F.lit(0.)))
F.when(F.col('score') > F.col('threshold'), F.lit(1)).otherwise(F.lit(0)))

self.df = self.df.drop('threshold')

def detect_low_rate_attack(self):
if not self.config.engine.low_rate_attack_enabled:
self.df = self.df.withColumn('low_rate_attack', 0.0)
self.logger.info('Skipping low rate attack detection.')
self.df = self.df.withColumn('low_rate_attack', F.lit(0))
return

self.logger.info('Low rate attack detecting...')
Expand All @@ -1795,50 +1805,29 @@ def detect_low_rate_attack(self):
'features',
F.from_json('features', self.low_rate_attack_schema)
)
self.df.select('features').show(1, False)
self.df = self.df.withColumn(
'features.request_total',
F.col('features.request_total').cast(
T.DoubleType()
).alias('features.request_total')
).persist(self.config.spark.storage_level)
)
self.df = self.df.withColumn(
'low_rate_attack',
F.when(self.lra_condition, 1.0).otherwise(0.0)
)

def detect_attack(self):
self.logger.info('Attack detecting...')
self.detect_low_rate_attack()
# return df_attack
return self.df

def updated_df_with_attacks(self, df_attack):
self.df = self.df.join(
df_attack,
on=[df_attack.uuid_request_set == self.df.uuid_request_set],
how='left'
F.when(self.lra_condition, F.lit(1)).otherwise(F.lit(0))
)

def run(self):
if get_dtype_for_col(self.df, 'features') == 'string':
self.logger.info('Unwrapping features from json...')

# this can be true when running the raw log pipeline
self.df = self.df.withColumn(
"features",
F.from_json("features", self.features_schema)
)
self.df = self.df.repartition('target').persist(
self.config.spark.storage_level
)
self.classify_anomalies()
df_attack = self.detect_attack()

# 'attack_prediction' column is not set anymore in this task
self.df = self.df.withColumn('attack_prediction', F.lit(0))

if not df_has_rows(df_attack):
self.updated_df_with_attacks(df_attack)
self.logger.info('No attacks detected...')
self.classify_anomalies()
self.detect_low_rate_attack()

self.df = super().run()
return self.df
Expand Down
2 changes: 2 additions & 0 deletions src/baskerville/spark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ def get_or_create_spark_session(spark_conf):
conf.set('spark.kubernetes.driver.pod.name', os.environ['MY_POD_NAME'])
conf.set('spark.driver.host', os.environ['MY_POD_IP'])
conf.set('spark.driver.port', 20020)
else:
conf.set('spark.sql.codegen.wholeStage', 'false')

spark = SparkSession.builder \
.config(conf=conf) \
Expand Down
6 changes: 3 additions & 3 deletions src/baskerville/spark/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def columns_to_dict(df, col_name, columns_to_gather):
)


def get_window(df, time_bucket: TimeBucket, storage_level: str):
def get_window(df, time_bucket: TimeBucket, storage_level: str, logger):
df = df.withColumn(
'timestamp', F.col('@timestamp').cast('timestamp')
)
Expand All @@ -282,10 +282,10 @@ def get_window(df, time_bucket: TimeBucket, storage_level: str):
)
window_df = df.where(filter_) #.persist(storage_level)
if not window_df.rdd.isEmpty():
print(f'# Request sets = {window_df.count()}')
logger.info(f'# Request sets = {window_df.count()}')
yield window_df
else:
print(f'Empty window df for {str(filter_._jc)}')
logger.info(f'Empty window df for {str(filter_._jc)}')
current_window_start = current_window_start + time_bucket.td
current_end = current_window_start + time_bucket.td
if current_window_start >= stop:
Expand Down

0 comments on commit 9659f4b

Please sign in to comment.