From c400befab1ba3467dc24ad19859f2ad3e05f39ce Mon Sep 17 00:00:00 2001
From: Aaron Peterson
Date: Sat, 16 Oct 2021 11:04:59 -0700
Subject: [PATCH] Refactor worker dependencies on server (#919)

* Initial refactor pass
* Merge CronAnalysisTask changes from HEAD
* rename deserialize -> task_deserialize
* remove task imports from client
* import psq as needed
* remove --run_local and refactor TASK_MAP
* Move monitoring setup to own method
* Fix bad merge
* Do not overwrite output_manager
* clean-up
* Manually merge update and address review comments
---
 README.md                        |   5 -
 turbinia/client.py               | 119 +----------------
 turbinia/lib/recipe_helpers.py   |   8 +-
 turbinia/task_manager.py         |  38 +-----
 turbinia/task_utils.py           | 211 +++++++++++++++++++++++++++++++
 turbinia/turbiniactl.py          |  42 +-----
 turbinia/worker.py               | 123 +++++++++---------
 turbinia/workers/__init__.py     |  50 ++------
 turbinia/workers/workers_test.py |   6 +-
 9 files changed, 306 insertions(+), 296 deletions(-)
 create mode 100644 turbinia/task_utils.py

diff --git a/README.md b/README.md
index dcdae101e..df1ebf30a 100644
--- a/README.md
+++ b/README.md
@@ -96,9 +96,6 @@ optional arguments:
                         Log file
   -r REQUEST_ID, --request_id REQUEST_ID
                         Create new requests with this Request ID
-  -R, --run_local       Run completely locally without any server or other
-                        infrastructure. This can be used to run one-off Tasks to
-                        process data locally.
   -S, --server          Run Turbinia Server indefinitely
   -V, --version         Show the version
   -D, --dump_json       Dump JSON output of Turbinia Request instead of sending
@@ -125,8 +122,6 @@ optional arguments:
   -p POLL_INTERVAL, --poll_interval POLL_INTERVAL
                         Number of seconds to wait between polling for task
                        state info
-  -t TASK, --task TASK  The name of a single Task to run locally (must be used
-                        with --run_local.
  -T, --debug_tasks     Show debug output for all supported tasks
  -w, --wait            Wait to exit until all tasks for the given request
                        have completed

diff --git a/turbinia/client.py b/turbinia/client.py
index 03d12b6f0..e96d5aff0 100644
--- a/turbinia/client.py
+++ b/turbinia/client.py
@@ -37,80 +37,16 @@ from turbinia.config import logger
 from turbinia.config import DATETIME_FORMAT
 from turbinia import task_manager
+from turbinia import task_utils
 from turbinia import TurbiniaException
 from turbinia.lib import text_formatter as fmt
 from turbinia.lib import docker_manager
 from turbinia.jobs import manager as job_manager
 from turbinia.workers import Priority
-from turbinia.workers.artifact import FileArtifactExtractionTask
-from turbinia.workers.analysis.wordpress_access import WordpressAccessLogAnalysisTask
-from turbinia.workers.analysis.wordpress_creds import WordpressCredsAnalysisTask
-from turbinia.workers.analysis.jenkins import JenkinsAnalysisTask
-from turbinia.workers.analysis.jupyter import JupyterAnalysisTask
-from turbinia.workers.analysis.linux_acct import LinuxAccountAnalysisTask
-from turbinia.workers.analysis.loki import LokiAnalysisTask
-from turbinia.workers.analysis.windows_acct import WindowsAccountAnalysisTask
-from turbinia.workers.finalize_request import FinalizeRequestTask
-from turbinia.workers.cron import CronAnalysisTask
-from turbinia.workers.dfdewey import DfdeweyTask
-from turbinia.workers.docker import DockerContainersEnumerationTask
-from turbinia.workers.grep import GrepTask
-from turbinia.workers.fsstat import FsstatTask
-from turbinia.workers.hadoop import HadoopAnalysisTask
-from turbinia.workers.hindsight import HindsightTask
-from turbinia.workers.partitions import PartitionEnumerationTask
-from turbinia.workers.plaso import PlasoTask
-from turbinia.workers.psort import PsortTask
-from turbinia.workers.redis import RedisAnalysisTask
-from turbinia.workers.sshd import SSHDAnalysisTask
-from turbinia.workers.strings import StringsAsciiTask
-from turbinia.workers.strings import StringsUnicodeTask
-from turbinia.workers.tomcat import TomcatAnalysisTask
-from turbinia.workers.volatility import VolatilityTask
-from turbinia.workers.worker_stat import StatTask
-from turbinia.workers.binary_extractor import BinaryExtractorTask
-from turbinia.workers.bulk_extractor import BulkExtractorTask
-from turbinia.workers.photorec import PhotorecTask
-from turbinia.workers.abort import AbortTask
 
 MAX_RETRIES = 10
 RETRY_SLEEP = 60
 
-# TODO(aarontp): Remove this map after
-# https://github.com/google/turbinia/issues/278 is fixed.
-TASK_MAP = {
-    'fileartifactextractiontask': FileArtifactExtractionTask,
-    'wordpressaccessloganalysistask': WordpressAccessLogAnalysisTask,
-    'wordpresscredsanalysistask': WordpressCredsAnalysisTask,
-    'finalizerequesttask': FinalizeRequestTask,
-    'jenkinsanalysistask': JenkinsAnalysisTask,
-    'jupyteranalysistask': JupyterAnalysisTask,
-    'greptask': GrepTask,
-    'fsstattask': FsstatTask,
-    'hadoopanalysistask': HadoopAnalysisTask,
-    'hindsighttask': HindsightTask,
-    'linuxaccountanalysistask': LinuxAccountAnalysisTask,
-    'windowsaccountanalysistask': WindowsAccountAnalysisTask,
-    'lokianalysistask': LokiAnalysisTask,
-    'partitionenumerationtask': PartitionEnumerationTask,
-    'plasotask': PlasoTask,
-    'psorttask': PsortTask,
-    'redisanalysistask': RedisAnalysisTask,
-    'sshdanalysistask': SSHDAnalysisTask,
-    'stringsasciitask': StringsAsciiTask,
-    'stringsunicodetask': StringsUnicodeTask,
-    'tomcatanalysistask': TomcatAnalysisTask,
-    'volatilitytask': VolatilityTask,
-    'stattask': StatTask,
-    'binaryextractortask': BinaryExtractorTask,
-    'bulkextractortask': BulkExtractorTask,
-    'dockercontainersenumerationtask': DockerContainersEnumerationTask,
-    'photorectask': PhotorecTask,
-    'aborttask': AbortTask,
-    'cronanalysistask': CronAnalysisTask,
-    'dfdeweytask': DfdeweyTask
-}
-
 config.LoadConfig()
 if config.TASK_MANAGER.lower() == 'psq':
   from libcloudforensics.providers.gcp.internal import function as gcp_function
@@ -128,7 +64,7 @@ def setup(is_client=False):
   logger.setup()
 
 
-def get_turbinia_client(run_local=False):
+def get_turbinia_client():
   """Return Turbinia client based on config.
 
   Returns:
   # pylint: disable=no-else-return
   setup(is_client=True)
   if config.TASK_MANAGER.lower() == 'psq':
-    return BaseTurbiniaClient(run_local=run_local)
+    return BaseTurbiniaClient()
   elif config.TASK_MANAGER.lower() == 'celery':
-    return TurbiniaCeleryClient(run_local=run_local)
+    return TurbiniaCeleryClient()
   else:
     msg = 'Task Manager type "{0:s}" not implemented'.format(
         config.TASK_MANAGER)
@@ -225,31 +161,10 @@ class BaseTurbiniaClient:
     task_manager (TaskManager): Turbinia task manager
   """
 
-  def __init__(self, run_local=False):
+  def __init__(self):
     config.LoadConfig()
-    if run_local:
-      self.task_manager = None
-    else:
-      self.task_manager = task_manager.get_task_manager()
-      self.task_manager.setup(server=False)
-
-  def create_task(self, task_name):
-    """Creates a Turbinia Task by name.
-
-    Args:
-      task_name(string): Name of the Task we are going to run.
-
-    Returns:
-      TurbiniaTask: An instantiated Task object.
-
-    Raises:
-      TurbiniaException: When no Task object matching task_name is found.
-    """
-    task_obj = TASK_MAP.get(task_name.lower())
-    log.debug('Looking up Task {0:s} by name'.format(task_name))
-    if not task_obj:
-      raise TurbiniaException('No Task named {0:s} found'.format(task_name))
-    return task_obj()
+    self.task_manager = task_manager.get_task_manager()
+    self.task_manager.setup(server=False)
 
   def list_jobs(self):
     """List the available jobs."""
@@ -958,26 +873,6 @@ def format_task_status(
     return '\n'.join(report)
 
-  def run_local_task(self, task_name, request):
-    """Runs a Turbinia Task locally.
-
-    Args:
-      task_name(string): Name of the Task we are going to run.
-      request (TurbiniaRequest): Object containing request and evidence info.
-
-    Returns:
-      TurbiniaTaskResult: The result returned by the Task Execution.
-    """
-    task = self.create_task(task_name)
-    task.request_id = request.request_id
-    task.base_output_dir = config.OUTPUT_DIR
-    task.run_local = True
-    if not request.evidence:
-      raise TurbiniaException('TurbiniaRequest does not contain evidence.')
-    log.info('Running Task {0:s} locally'.format(task_name))
-    result = task.run_wrapper(request.evidence[0].serialize())
-    return result
-
   def send_request(self, request):
     """Sends a TurbiniaRequest message.

diff --git a/turbinia/lib/recipe_helpers.py b/turbinia/lib/recipe_helpers.py
index cdfac3239..1db52758d 100644
--- a/turbinia/lib/recipe_helpers.py
+++ b/turbinia/lib/recipe_helpers.py
@@ -21,6 +21,7 @@ from yaml import load
 from turbinia.lib.file_helpers import file_to_str
 from turbinia.lib.file_helpers import file_to_list
+from turbinia.task_utils import TaskLoader
 
 log = logging.getLogger('turbinia')
 
@@ -156,13 +157,12 @@ def validate_recipe(recipe_dict):
       return (False, message)
 
     proposed_task = recipe_item_contents['task']
-    # Doing a delayed import to avoid circular dependencies.
-    from turbinia.client import TASK_MAP
-    if proposed_task.lower() not in TASK_MAP:
+    task_loader = TaskLoader()
+    if not task_loader.check_task_name(proposed_task):
       log.error(
           'Task {0:s} defined for task recipe {1:s} does not exist.'.format(
               proposed_task, recipe_item))
       return (False, message)
     tasks_with_recipe.append(recipe_item)
-  return (True, '')
+  return (True, '')
\ No newline at end of file

diff --git a/turbinia/task_manager.py b/turbinia/task_manager.py
index 59f2b101b..8d89e12bb 100644
--- a/turbinia/task_manager.py
+++ b/turbinia/task_manager.py
@@ -28,6 +28,7 @@ from turbinia import evidence
 from turbinia import config
 from turbinia import state_manager
+from turbinia import task_utils
 from turbinia import TurbiniaException
 from turbinia.jobs import manager as jobs_manager
 from turbinia.lib import recipe_helpers
@@ -86,38 +87,6 @@ def get_task_manager():
     raise turbinia.TurbiniaException(msg)
 
 
-def task_runner(obj, *args, **kwargs):
-  """Wrapper function to run specified TurbiniaTask object.
-
-  Args:
-    obj: An instantiated TurbiniaTask object.
-    *args: Any Args to pass to obj.
-    **kwargs: Any keyword args to pass to obj.
-
-  Returns:
-    Output from TurbiniaTask (should be TurbiniaTaskResult).
-  """
-
-  # GKE Specific - do not queue more work if pod places this file
-  if os.path.exists(config.SCALEDOWN_WORKER_FILE):
-    raise psq.Retry()
-
-  # try to acquire lock, timeout and requeue task if the worker
-  # is already processing a task.
-  try:
-    lock = filelock.FileLock(config.LOCK_FILE)
-    with lock.acquire(timeout=0.001):
-      obj = workers.TurbiniaTask.deserialize(obj)
-      run = obj.run_wrapper(*args, **kwargs)
-  except filelock.Timeout:
-    raise psq.Retry()
-  finally:
-    # *always* make sure we release the lock
-    lock.release()
-
-  return run
-
-
 class BaseTaskManager:
   """Class to manage Turbinia Tasks.
 
@@ -600,7 +569,8 @@ def _backend_setup(self, *args, **kwargs):
     self.celery.setup()
     self.kombu = turbinia_celery.TurbiniaKombu(config.KOMBU_CHANNEL)
     self.kombu.setup()
-    self.celery_runner = self.celery.app.task(task_runner, name="task_runner")
+    self.celery_runner = self.celery.app.task(
+        task_utils.task_runner, name="task_runner")
 
   def process_tasks(self):
     """Determine the current state of our tasks.
@@ -761,5 +731,5 @@ def enqueue_task(self, task, evidence_):
         'Adding PSQ task {0:s} with evidence {1:s} to queue'.format(
             task.name, evidence_.name))
     task.stub = self.psq.enqueue(
-        task_runner, task.serialize(), evidence_.serialize())
+        task_utils.task_runner, task.serialize(), evidence_.serialize())
     time.sleep(PSQ_QUEUE_WAIT_SECONDS)

diff --git a/turbinia/task_utils.py b/turbinia/task_utils.py
new file mode 100644
index 000000000..5910afadf
--- /dev/null
+++ b/turbinia/task_utils.py
@@ -0,0 +1,211 @@
+#-*- coding: utf-8 -*-
+# Copyright 2021 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Task runner for Turbinia."""
+
+from datetime import datetime
+import logging
+import os
+import sys
+
+import filelock
+
+import turbinia
+from turbinia import config
+from turbinia.config import DATETIME_FORMAT
+from turbinia import TurbiniaException
+
+log = logging.getLogger('turbinia')
+
+config.LoadConfig()
+
+
+class TaskLoader():
+  """Utility class for handling Task loading/checking/deserialization.
+
+  Attributes:
+    TASK_LIST(list): A list of all valid Tasks.
+  """
+
+  TASK_LIST = [
+      'FileArtifactExtractionTask',
+      'WordpressAccessLogAnalysisTask',
+      'WordpressCredsAnalysisTask',
+      'FinalizeRequestTask',
+      'JenkinsAnalysisTask',
+      'JupyterAnalysisTask',
+      'GrepTask',
+      'FsstatTask',
+      'HadoopAnalysisTask',
+      'HindsightTask',
+      'LinuxAccountAnalysisTask',
+      'WindowsAccountAnalysisTask',
+      'LokiAnalysisTask',
+      'PartitionEnumerationTask',
+      'PlasoTask',
+      'PsortTask',
+      'RedisAnalysisTask',
+      'SSHDAnalysisTask',
+      'StringsAsciiTask',
+      'StringsUnicodeTask',
+      'TomcatAnalysisTask',
+      'VolatilityTask',
+      'StatTask',
+      'BinaryExtractorTask',
+      'BulkExtractorTask',
+      'DockerContainersEnumerationTask',
+      'PhotorecTask',
+      'AbortTask',
+      'CronAnalysisTask',
+  ]
+
+  def check_task_name(self, task_name):
+    """Checks whether a given task name is a valid Task.
+
+    Args:
+      task_name(str): Name of the Task to check.
+
+    Returns:
+      bool: True if a Task with the given name exists, else False.
+    """
+    for task in self.TASK_LIST:
+      if task.lower() == task_name.lower():
+        return True
+    return False
+
+  def get_task(self, task_name):
+    """Gets an instantiated Task object for the given name.
+
+    Args:
+      task_name(str): Name of the Task to return.
+
+    Returns:
+      TurbiniaTask: An instantiated Task object.
+    """
+    # TODO(aarontp): Remove this list after
+    # https://github.com/google/turbinia/issues/278 is fixed.
+    #
+    # Late imports to minimize what loads all Tasks
+    from turbinia.workers.artifact import FileArtifactExtractionTask
+    from turbinia.workers.analysis.wordpress_access import WordpressAccessLogAnalysisTask
+    from turbinia.workers.analysis.wordpress_creds import WordpressCredsAnalysisTask
+    from turbinia.workers.analysis.jenkins import JenkinsAnalysisTask
+    from turbinia.workers.analysis.jupyter import JupyterAnalysisTask
+    from turbinia.workers.analysis.linux_acct import LinuxAccountAnalysisTask
+    from turbinia.workers.analysis.loki import LokiAnalysisTask
+    from turbinia.workers.analysis.windows_acct import WindowsAccountAnalysisTask
+    from turbinia.workers.finalize_request import FinalizeRequestTask
+    from turbinia.workers.cron import CronAnalysisTask
+    from turbinia.workers.docker import DockerContainersEnumerationTask
+    from turbinia.workers.grep import GrepTask
+    from turbinia.workers.fsstat import FsstatTask
+    from turbinia.workers.hadoop import HadoopAnalysisTask
+    from turbinia.workers.hindsight import HindsightTask
+    from turbinia.workers.partitions import PartitionEnumerationTask
+    from turbinia.workers.plaso import PlasoTask
+    from turbinia.workers.psort import PsortTask
+    from turbinia.workers.redis import RedisAnalysisTask
+    from turbinia.workers.sshd import SSHDAnalysisTask
+    from turbinia.workers.strings import StringsAsciiTask
+    from turbinia.workers.strings import StringsUnicodeTask
+    from turbinia.workers.tomcat import TomcatAnalysisTask
+    from turbinia.workers.volatility import VolatilityTask
+    from turbinia.workers.worker_stat import StatTask
+    from turbinia.workers.binary_extractor import BinaryExtractorTask
+    from turbinia.workers.bulk_extractor import BulkExtractorTask
+    from turbinia.workers.photorec import PhotorecTask
+    from turbinia.workers.abort import AbortTask
+
+    for task in self.TASK_LIST:
+      if task.lower() == task_name.lower():
+        try:
+          task_obj = locals()[task]
+          return task_obj()
+        except (AttributeError, KeyError):
+          message = (
+              "Could not import {0:s} object! Make sure it is imported where "
+              "this method is defined.".format(task_name))
+          log.error(message)
+          raise TurbiniaException(message)
+
+    return
+
+  def get_task_names(self):
+    """Returns a list of Task names.
+
+    Returns:
+      (list) All Task names.
+    """
+    return self.TASK_LIST
+
+
+def task_deserialize(input_dict):
+  """Converts an input dictionary back into a TurbiniaTask object.
+
+  Args:
+    input_dict (dict): TurbiniaTask object dictionary.
+
+  Returns:
+    TurbiniaTask: Deserialized object.
+  """
+
+  type_ = input_dict['name']
+  task_loader = TaskLoader()
+  task = task_loader.get_task(type_)
+  if not task:
+    raise TurbiniaException('Could not load Task module {0:s}'.format(type_))
+  # Remove serialized output manager because this gets reinstantiated when the
+  # empty Task is instantiated and we don't want to overwrite it.
+  input_dict.pop('output_manager')
+  task.__dict__.update(input_dict)
+  task.last_update = datetime.strptime(
+      input_dict['last_update'], DATETIME_FORMAT)
+  return task
+
+
+def task_runner(obj, *args, **kwargs):
+  """Wrapper function to run specified TurbiniaTask object.
+
+  Args:
+    obj: An instantiated TurbiniaTask object.
+    *args: Any Args to pass to obj.
+    **kwargs: Any keyword args to pass to obj.
+
+  Returns:
+    Output from TurbiniaTask (should be TurbiniaTaskResult).
+  """
+
+  # GKE Specific - do not queue more work if pod places this file
+  if config.TASK_MANAGER.lower() == 'psq':
+    if os.path.exists(config.SCALEDOWN_WORKER_FILE):
+      # Late import because this is only needed for PSQ
+      import psq
+      raise psq.Retry()
+
+  # Try to acquire lock, timeout and requeue task if the worker
+  # is already processing a task.
+  try:
+    lock = filelock.FileLock(config.LOCK_FILE)
+    with lock.acquire(timeout=0.001):
+      obj = task_deserialize(obj)
+      run = obj.run_wrapper(*args, **kwargs)
+  except filelock.Timeout:
+    # Late import because this is only needed for PSQ
+    import psq
+    raise psq.Retry()
+  # *Always* make sure we release the lock
+  finally:
+    lock.release()
+
+  return run

diff --git a/turbinia/turbiniactl.py b/turbinia/turbiniactl.py
index 23722ac8b..311d970e4 100644
--- a/turbinia/turbiniactl.py
+++ b/turbinia/turbiniactl.py
@@ -30,10 +30,11 @@ from turbinia import config
 from turbinia import TurbiniaException
 from turbinia.config import logger
+from turbinia.lib import recipe_helpers
 from turbinia import __version__
-from turbinia.processors import archive
 from turbinia.output_manager import OutputManager
 from turbinia.output_manager import GCSOutputWriter
+from turbinia.processors import archive
 
 log = logging.getLogger('turbinia')
 # We set up the logger first without the file handler, and we will set up the
@@ -98,10 +99,6 @@ def main():
   parser.add_argument(
       '-r', '--request_id', help='Create new requests with this Request ID',
       required=False)
-  parser.add_argument(
-      '-R', '--run_local', action='store_true',
-      help='Run completely locally without any server or other infrastructure. '
-      'This can be used to run one-off Tasks to process data locally.')
   parser.add_argument(
       '-S', '--server', action='store_true',
       help='Run Turbinia Server indefinitely')
@@ -133,10 +130,6 @@ def main():
   parser.add_argument(
       '-p', '--poll_interval', default=60, type=int,
       help='Number of seconds to wait between polling for task state info')
-  parser.add_argument(
-      '-t', '--task',
-      help='The name of a single Task to run locally (must be used with '
-      '--run_local).')
   parser.add_argument(
       '-T', '--debug_tasks', action='store_true',
       help='Show debug output for all supported tasks', default=False)
@@ -387,11 +380,6 @@ def main():
 
   args = parser.parse_args()
 
-  # (jorlamd): Importing recipe_helpers late to avoid a bug where
-  # client.TASK_MAP is imported early rendering the check for worker
-  # status not possible.
-  from turbinia.lib import recipe_helpers
-
   # Load the config before final logger setup so we can the find the path to the
   # log file.
   try:
@@ -512,16 +500,7 @@ def main():
   # Create Client object
   client = None
   if args.command not in ('psqworker', 'server'):
-    client = TurbiniaClientProvider.get_turbinia_client(args.run_local)
-
-  # Make sure run_local flags aren't conflicting with other server/client flags
-  if args.run_local and (server_flags_set or worker_flags_set):
-    log.error('--run_local flag is not compatible with server/worker flags')
-    sys.exit(1)
-
-  if args.run_local and not args.task:
-    log.error('--run_local flag requires --task flag')
-    sys.exit(1)
+    client = TurbiniaClientProvider.get_turbinia_client()
 
   # Set zone/project to defaults if flags are not set, and also copy remote
   # disk if needed.
@@ -866,10 +845,7 @@ def main():
     log.info(
         'Run command "turbiniactl status -r {0:s}" to see the status of'
         ' this request and associated tasks'.format(request.request_id))
-    if not args.run_local:
-      client.send_request(request)
-    else:
-      log.debug('--run_local specified so not sending request to server')
+    client.send_request(request)
 
     if args.wait:
       log.info(
@@ -885,16 +861,6 @@ def main():
             region=region, request_id=request.request_id,
             all_fields=args.all_fields))
 
-  if args.run_local and not evidence_:
-    log.error('Evidence must be specified if using --run_local')
-    sys.exit(1)
-  if args.run_local and evidence_.cloud_only:
-    log.error('--run_local cannot be used with Cloud only Evidence types')
-    sys.exit(1)
-  if args.run_local and evidence_:
-    result = client.run_local_task(args.task, request)
-    log.info('Task execution result: {0:s}'.format(result))
-
   log.info('Done.')
   sys.exit(0)

diff --git a/turbinia/worker.py b/turbinia/worker.py
index 86ebda7ed..4134927dc 100644
--- a/turbinia/worker.py
+++ b/turbinia/worker.py
@@ -24,19 +24,27 @@ from prometheus_client import start_http_server
 from turbinia import config
 from turbinia.config import logger
-from turbinia.client import BaseTurbiniaClient
-from turbinia import task_manager
+from turbinia import task_utils
 from turbinia import TurbiniaException
 from turbinia.lib import docker_manager
 from turbinia.jobs import manager as job_manager
+from turbinia.tcelery import TurbiniaCelery
 
 config.LoadConfig()
-if config.TASK_MANAGER.lower() == 'psq':
+task_manager_type = config.TASK_MANAGER.lower()
+if task_manager_type == 'psq':
   import psq
 
   from google.cloud import exceptions
   from google.cloud import datastore
   from google.cloud import pubsub
+elif task_manager_type == 'celery':
+  from celery import states as celery_states
+  from turbinia import tcelery as turbinia_celery
+else:
+  raise TurbiniaException(
+      'Unknown task manager {0:s} found, please update config to use "psq" or '
+      '"celery"'.format(task_manager_type))
 
 log = logging.getLogger('turbinia')
@@ -180,21 +188,16 @@ def register_job_timeouts(dependencies):
     job_manager.JobsManager.RegisterTimeout(job, timeout)
 
 
-class TurbiniaCeleryWorker(BaseTurbiniaClient):
-  """Turbinia Celery Worker class.
-
-  Attributes:
-    worker (celery.app): Celery worker app
-  """
+class TurbiniaWorkerBase:
+  """Base class for Turbinia Workers."""
 
   def __init__(self, jobs_denylist=None, jobs_allowlist=None):
-    """Initialization for celery worker.
+    """Initialization for Turbinia Worker.
 
     Args:
       jobs_denylist (Optional[list[str]]): Jobs we will exclude from running
       jobs_allowlist (Optional[list[str]]): The only Jobs we will include to run
     """
-    super(TurbiniaCeleryWorker, self).__init__()
     setup()
     # Deregister jobs from denylist/allowlist.
     job_manager.JobsManager.DeregisterJobs(jobs_denylist, jobs_allowlist)
@@ -222,15 +225,15 @@ def __init__(self, jobs_denylist=None, jobs_allowlist=None):
     check_directory(config.MOUNT_DIR_PREFIX)
     check_directory(config.OUTPUT_DIR)
     check_directory(config.TMP_DIR)
+    register_job_timeouts(dependencies)
 
     jobs = job_manager.JobsManager.GetJobNames()
     log.info(
-        'Dependency check complete. The following jobs will be enabled '
+        'Dependency check complete. The following jobs are enabled '
         'for this worker: {0:s}'.format(','.join(jobs)))
-    self.worker = self.task_manager.celery.app
 
-  def start(self):
-    """Start Turbinia Celery Worker."""
+  def _monitoring_setup(self):
+    """Sets up monitoring server."""
     if config.PROMETHEUS_ENABLED:
       if config.PROMETHEUS_PORT and config.PROMETHEUS_ADDR:
         log.info('Starting Prometheus endpoint.')
         start_http_server(
             port=config.PROMETHEUS_PORT, addr=config.PROMETHEUS_ADDR)
       else:
         log.info('Prometheus enabled but port or address not set!')
+
+  def _backend_setup(self):
+    """Sets up the required backend dependencies for the worker."""
+    raise NotImplementedError
+
+  def start(self):
+    """Start Turbinia Worker."""
+    raise NotImplementedError
+
+
+class TurbiniaCeleryWorker(TurbiniaWorkerBase):
+  """Turbinia Celery Worker class.
+
+  Attributes:
+    worker (celery.app): Celery worker app
+    celery (TurbiniaCelery): Turbinia Celery object
+  """
+
+  def __init__(self, *args, **kwargs):
+    super(TurbiniaCeleryWorker, self).__init__(*args, **kwargs)
+    self.worker = None
+    self.celery = None
+
+  def _backend_setup(self):
+    self.celery = turbinia_celery.TurbiniaCelery()
+    self.celery.setup()
+    self.worker = self.celery.app
 
   def start(self):
     """Start Turbinia Celery Worker."""
     log.info('Running Turbinia Celery Worker.')
-    self.worker.task(task_manager.task_runner, name='task_runner')
+    self._monitoring_setup()
+    self._backend_setup()
+    self.worker.task(task_utils.task_runner, name='task_runner')
     argv = ['celery', 'worker', '--loglevel=info', '--pool=solo']
     self.worker.start(argv)
 
 
-class TurbiniaPsqWorker:
+class TurbiniaPsqWorker(TurbiniaWorkerBase):
   """Turbinia PSQ Worker class.
 
   Attributes:
@@ -255,14 +290,12 @@ class TurbiniaPsqWorker:
     TurbiniaException: When errors occur
   """
 
-  def __init__(self, jobs_denylist=None, jobs_allowlist=None):
-    """Initialization for PSQ Worker.
+  def __init__(self, *args, **kwargs):
+    super(TurbiniaPsqWorker, self).__init__(*args, **kwargs)
+    self.worker = None
+    self.psq = None
 
-    Args:
-      jobs_denylist (Optional[list[str]]): Jobs we will exclude from running
-      jobs_allowlist (Optional[list[str]]): The only Jobs we will include to run
-    """
-    setup()
+  def _backend_setup(self):
     psq_publisher = pubsub.PublisherClient()
     psq_subscriber = pubsub.SubscriberClient()
     datastore_client = datastore.Client(project=config.TURBINIA_PROJECT)
@@ -274,50 +307,12 @@ def __init__(self, jobs_denylist=None, jobs_allowlist=None):
       msg = 'Error creating PSQ Queue: {0:s}'.format(str(e))
       log.error(msg)
       raise TurbiniaException(msg)
-
-    # Deregister jobs from denylist/allowlist.
-    job_manager.JobsManager.DeregisterJobs(jobs_denylist, jobs_allowlist)
-    disabled_jobs = list(config.DISABLED_JOBS) if config.DISABLED_JOBS else []
-    disabled_jobs = [j.lower() for j in disabled_jobs]
-    # Only actually disable jobs that have not been allowlisted.
-    if jobs_allowlist:
-      disabled_jobs = list(set(disabled_jobs) - set(jobs_allowlist))
-    if disabled_jobs:
-      log.info(
-          'Disabling non-allowlisted jobs configured to be disabled in the '
-          'config file: {0:s}'.format(', '.join(disabled_jobs)))
-      job_manager.JobsManager.DeregisterJobs(jobs_denylist=disabled_jobs)
-
-    # Check for valid dependencies/directories.
-    dependencies = config.ParseDependencies()
-    if config.DOCKER_ENABLED:
-      try:
-        check_docker_dependencies(dependencies)
-      except TurbiniaException as e:
-        log.warning(
-            "DOCKER_ENABLED=True is set in the config, but there is an error checking for the docker daemon: {0:s}"
-        ).format(str(e))
-    check_system_dependencies(dependencies)
-    check_directory(config.MOUNT_DIR_PREFIX)
-    check_directory(config.OUTPUT_DIR)
-    check_directory(config.TMP_DIR)
-    register_job_timeouts(dependencies)
-
-    jobs = job_manager.JobsManager.GetJobNames()
-    log.info(
-        'Dependency check complete. The following jobs are enabled '
-        'for this worker: {0:s}'.format(','.join(jobs)))
 
     log.info('Starting PSQ listener on queue {0:s}'.format(self.psq.name))
     self.worker = psq.Worker(queue=self.psq)
 
   def start(self):
     """Start Turbinia PSQ Worker."""
-    if config.PROMETHEUS_ENABLED:
-      if config.PROMETHEUS_PORT and config.PROMETHEUS_ADDR:
-        log.info('Starting Prometheus endpoint.')
-        start_http_server(
-            port=config.PROMETHEUS_PORT, addr=config.PROMETHEUS_ADDR)
-      else:
-        log.info('Prometheus enabled but port or address not set!')
     log.info('Running Turbinia PSQ Worker.')
+    self._monitoring_setup()
+    self._backend_setup()
     self.worker.listen()
\ No newline at end of file

diff --git a/turbinia/workers/__init__.py b/turbinia/workers/__init__.py
index 94ea7f384..7eaa227d4 100644
--- a/turbinia/workers/__init__.py
+++ b/turbinia/workers/__init__.py
@@ -41,6 +41,7 @@ from turbinia.processors import resource_manager
 from turbinia import output_manager
 from turbinia import state_manager
+from turbinia import task_utils
 from turbinia import TurbiniaException
 from turbinia import log_and_report
 from turbinia.lib import docker_manager
@@ -206,7 +207,7 @@ def close(self, task, success, status=None):
       if evidence.source_path:
         if os.path.exists(evidence.source_path):
           self.saved_paths.append(evidence.source_path)
-          if not task.run_local and evidence.copyable:
+          if evidence.copyable:
             task.output_manager.save_evidence(evidence, self)
       else:
         self.log(
@@ -256,8 +257,7 @@ def close(self, task, success, status=None):
       with open(logfile, 'w') as f:
         f.write('\n'.join(self._log))
         f.write('\n')
-      if not task.run_local:
-        task.output_manager.save_local_file(logfile, self)
+      task.output_manager.save_local_file(logfile, self)
 
     self.closed = True
     log.debug('Result close successful. Status is [{0:s}]'.format(self.status))
@@ -396,7 +396,6 @@ class TurbiniaTask:
     output_manager (OutputManager): The object that manages saving output.
     result (TurbiniaTaskResult): A TurbiniaTaskResult object.
     request_id (str): The id of the initial request to process this evidence.
-    run_local (bool): Whether we are running locally without a Worker or not.
     state_key (str): A key used to manage task state
     stub (psq.task.TaskResult|celery.app.Task): The task manager
         implementation specific task stub that exists server side to keep a
@@ -442,7 +441,6 @@ def __init__(
     self.output_manager = output_manager.OutputManager()
     self.result = None
     self.request_id = request_id
-    self.run_local = False
     self.state_key = None
     self.stub = None
     self.tmp_dir = None
@@ -473,23 +471,7 @@ def deserialize(cls, input_dict):
     Returns:
       TurbiniaTask: Deserialized object.
     """
-    from turbinia import client  # Avoid circular imports
-
-    type_ = input_dict['name']
-    try:
-      task = getattr(sys.modules['turbinia.client'], type_)()
-    except AttributeError:
-      message = (
-          "Could not import {0:s} object! Make sure it is imported where "
-          "this method is defined.".format(type_))
-      log.error(message)
-      raise TurbiniaException(message)
-    task.__dict__.update(input_dict)
-    task.output_manager = output_manager.OutputManager()
-    task.output_manager.__dict__.update(input_dict['output_manager'])
-    task.last_update = datetime.strptime(
-        input_dict['last_update'], DATETIME_FORMAT)
-    return task
+    return task_utils.task_deserialize(input_dict)
 
   @classmethod
   def check_worker_role(cls):
@@ -694,8 +676,7 @@ def execute(
               level=logging.DEBUG)
           continue
         result.log('Output log file found at {0:s}'.format(file_))
-        if not self.run_local:
-          self.output_manager.save_local_file(file_, result)
+        self.output_manager.save_local_file(file_, result)
 
     if ret not in success_codes:
       message = 'Execution of [{0!s}] failed with status {1:d}'.format(cmd, ret)
@@ -711,8 +692,7 @@ def execute(
               level=logging.DEBUG)
           continue
         result.log('Output save file at {0:s}'.format(file_))
-        if not self.run_local:
-          self.output_manager.save_local_file(file_, result)
+        self.output_manager.save_local_file(file_, result)
 
     for evidence in new_evidence:
       # If the local path is set in the Evidence, we check to make sure that
@@ -758,9 +738,8 @@ def setup(self, evidence):
     if not self.result:
       self.result = self.create_result(input_evidence=evidence)
 
-    if not self.run_local:
-      if evidence.copyable and not config.SHARED_FILESYSTEM:
-        self.output_manager.retrieve_evidence(evidence)
+    if evidence.copyable and not config.SHARED_FILESYSTEM:
+      self.output_manager.retrieve_evidence(evidence)
 
     if evidence.source_path and not os.path.exists(evidence.source_path):
       raise TurbiniaException(
          'Evidence source path {0:s} does not exist'.format(
              evidence.source_path))
     return self.result
 
-  def setup_metrics(self, task_map=None):
+  def setup_metrics(self, task_list=None):
     """Sets up the application metrics.
 
     Returns early with metrics if they are already setup.
 
     Arguments:
-      task_map(dict): Map of task names to task objects
+      task_list(list): List of Task names
 
     Returns:
       Dict: Mapping of task names to metrics objects.
     """
     global METRICS
 
     if METRICS:
       return METRICS
 
-    if not task_map:
-      # Late import to avoid circular dependencies
-      from turbinia.client import TASK_MAP
-      task_map = TASK_MAP
+    if not task_list:
+      task_loader = task_utils.TaskLoader()
+      task_list = task_loader.get_task_names()
 
-    for task_name in task_map:
+    for task_name in task_list:
       task_name = task_name.lower()
       if task_name in METRICS:
         continue

diff --git a/turbinia/workers/workers_test.py b/turbinia/workers/workers_test.py
index aa177f208..fe1d5b2e4 100644
--- a/turbinia/workers/workers_test.py
+++ b/turbinia/workers/workers_test.py
@@ -343,10 +343,10 @@ def testTurbiniaTaskExecuteEvidenceExistsButEmpty(self, popen_mock):
   @mock.patch('turbinia.workers.Histogram')
   def testTurbiniaSetupMetrics(self, mock_histogram):
     """Tests that metrics are set up correctly."""
-    mock_task_map = {'TestTask1': None, 'TestTask2': None}
+    mock_task_list = ['TestTask1', 'TestTask2']
     mock_histogram.return_value = "test_metrics"
-    metrics = self.task.setup_metrics(task_map=mock_task_map)
-    self.assertEqual(len(metrics), len(mock_task_map))
+    metrics = self.task.setup_metrics(task_list=mock_task_list)
+    self.assertEqual(len(metrics), len(mock_task_list))
     self.assertEqual(metrics['testtask1'], 'test_metrics')
     self.assertIn('testtask1', metrics)
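
Editor's usage sketch (not part of the patch): after this refactor, Task name
validation and (de)serialization live in turbinia.task_utils, so callers no
longer import turbinia.client.TASK_MAP (and, with it, every Task class). A
minimal, hypothetical round-trip, assuming a loadable Turbinia config and that
TurbiniaTask.serialize() returns a dict with the 'name', 'last_update' and
'output_manager' keys that task_deserialize() expects:

    from turbinia import task_utils

    loader = task_utils.TaskLoader()
    # Name checks are case-insensitive and import no Task modules.
    if loader.check_task_name('PlasoTask'):
      # get_task() late-imports the Task classes, then instantiates by name.
      task = loader.get_task('PlasoTask')
      serialized = task.serialize()  # assumed to include 'name'/'last_update'
      # task_deserialize() rebuilds the Task and copies its state back in,
      # dropping the serialized output_manager so the fresh one is kept.
      restored = task_utils.task_deserialize(serialized)

This is the same split the patch applies elsewhere: validate_recipe() and
setup_metrics() only need Task names, so they call check_task_name() and
get_task_names(), while the worker-side task_runner() is the only path that
actually instantiates Tasks via task_deserialize().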