diff --git a/api_client/python/timesketch_api_client/scenario.py b/api_client/python/timesketch_api_client/scenario.py new file mode 100644 index 0000000000..c2e12a62b7 --- /dev/null +++ b/api_client/python/timesketch_api_client/scenario.py @@ -0,0 +1,229 @@ +# Copyright 2024 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Timesketch API client library for working with scenarios.""" + + +import logging + +from . import error +from . import resource + + +logger = logging.getLogger("timesketch_api.scenario") + + +class Scenario(resource.BaseResource): + """Timesketch scenario object. + + Attributes: + id: The ID of the scenario. + api: An instance of TimesketchApi object. + """ + + def __init__(self, sketch_id, scenario_id, api): + """Initializes the Scenario object. + + Args: + scenario_id: Primary key ID of the scenario. + api: An instance of TiscmesketchApi object. + sketch_id: ID of a sketch. + """ + self.id = scenario_id + self.api = api + self.sketch_id = sketch_id + super().__init__( + api=api, resource_uri=f"sketches/{self.sketch_id}/scenarios/{self.id}/" + ) + + @property + def name(self): + """Property that returns the scenario name. + + Returns: + Scenario name as string. + """ + scenario = self.lazyload_data() + return scenario["objects"][0]["name"] + + @property + def scenario_id(self): + """Property that returns the scenario id. + + Returns: + Scenario id as integer. + """ + scenario = self.lazyload_data() + return scenario["objects"][0]["id"] + + @property + def dfiq_id(self): + """Property that returns the dfiq id. + + Returns: + dfiq id as string. + """ + scenario = self.lazyload_data() + return scenario["objects"][0]["dfiq_identifier"] + + @property + def description(self): + """Property that returns the scenario description. + + Returns: + Description as string. + """ + scenario = self.lazyload_data() + return scenario["objects"][0]["description"] + + def to_dict(self): + """Returns a dict representation of the scenario.""" + return self.lazyload_data() + + +class ScenarioTemplateList(resource.BaseResource): + """Timesketch scenario template list. + + Attributes: + api: An instance of TimesketchApi object. + """ + + def __init__(self, api): + """Initializes the ScenarioList object. + + Args: + api: An instance of TimesketchApi object. + """ + self.api = api + super().__init__(api=api, resource_uri="scenarios/") + + def get(self): + """ + Retrieves a list of scenario templates. + + Returns: + list: A list of Scenario tempaltes. + """ + resource_url = f"{self.api.api_root}/scenarios/" + response = self.api.session.get(resource_url) + response_json = error.get_response_json(response, logger) + scenario_objects = response_json.get("objects", []) + return scenario_objects + + +class Question(resource.BaseResource): + """Timesketch question object. + + Attributes: + id: The ID of the question. + api: An instance of TimesketchApi object. + """ + + def __init__(self, sketch_id, question_id, api): + """Initializes the question object. + + Args: + question_id: Primary key ID of the scenario. + api: An instance of TiscmesketchApi object. + sketch_id: ID of a sketch. + """ + self.id = question_id + self.api = api + self.sketch_id = sketch_id + super().__init__( + api=api, resource_uri=f"sketches/{self.sketch_id}/questions/{self.id}/" + ) + + @property + def name(self): + """Property that returns the question name. + + Returns: + Question name as string. + """ + question = self.lazyload_data() + return question["objects"][0]["name"] + + @property + def question_id(self): + """Property that returns the question id. + + Returns: + Question id as integer. + """ + question = self.lazyload_data() + return question["objects"][0]["id"] + + @property + def dfiq_id(self): + """Property that returns the question template id. + + Returns: + Question ID as string. + """ + question = self.lazyload_data() + return question["objects"][0]["dfiq_identifier"] + + @property + def description(self): + """Property that returns the question description. + + Returns: + Question description as string. + """ + question = self.lazyload_data() + return question["objects"][0]["description"] + + @property + def approaches(self): + """Property that returns the question approaches. + + Returns: + Question approaches as list of dict. + """ + question = self.lazyload_data() + return question["objects"][0]["approaches"] + + def to_dict(self): + """Returns a dict representation of the question.""" + return self.lazyload_data() + + +class QuestionTemplateList(resource.BaseResource): + """Timesketch question template list. + + Attributes: + api: An instance of TimesketchApi object. + """ + + def __init__(self, api): + """Initializes the QuestionList object. + + Args: + api: An instance of TimesketchApi object. + """ + self.api = api + super().__init__(api=api, resource_uri="questions/") + + def get(self): + """ + Retrieves a list of question templates. + + Returns: + list: A list of question tempaltes. + """ + resource_url = f"{self.api.api_root}/questions/" + response = self.api.session.get(resource_url) + response_json = error.get_response_json(response, logger) + scenario_objects = response_json.get("objects", []) + return scenario_objects diff --git a/api_client/python/timesketch_api_client/scenario_test.py b/api_client/python/timesketch_api_client/scenario_test.py new file mode 100644 index 0000000000..0be15e7da0 --- /dev/null +++ b/api_client/python/timesketch_api_client/scenario_test.py @@ -0,0 +1,85 @@ +# Copyright 2024 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for the Timesketch API client""" + +import unittest +import mock + +from . import client +from . import test_lib +from . import scenario as scenario_lib + + +class ScenarioTest(unittest.TestCase): + """Test Scenario object.""" + + @mock.patch("requests.Session", test_lib.mock_session) + def setUp(self): + """Setup test case.""" + self.api_client = client.TimesketchApi("http://127.0.0.1", "test", "test") + self.sketch = self.api_client.get_sketch(1) + + def test_scenario_to_dict(self): + """Test Scenario object to dict.""" + scenario = self.sketch.list_scenarios()[0] + self.assertIsInstance(scenario.to_dict(), dict) + + +class ScenarioListTest(unittest.TestCase): + """Test ScenarioList object.""" + + @mock.patch("requests.Session", test_lib.mock_session) + def setUp(self): + """Setup test case.""" + self.api_client = client.TimesketchApi("http://127.0.0.1", "test", "test") + + def test_scenario_list(self): + """Test ScenarioList object.""" + scenario_list = scenario_lib.ScenarioTemplateList(self.api_client).get() + self.assertIsInstance(scenario_list, list) + self.assertEqual(len(scenario_list), 2) + self.assertEqual(scenario_list[0]["name"], "Test Scenario") + self.assertEqual(scenario_list[1]["id"], "S0002") + + +class QuestionTest(unittest.TestCase): + """Test Question object.""" + + @mock.patch("requests.Session", test_lib.mock_session) + def setUp(self): + """Setup test case.""" + self.api_client = client.TimesketchApi("http://127.0.0.1", "test", "test") + self.sketch = self.api_client.get_sketch(1) + + def test_question_to_dict(self): + """Test Question object to dict.""" + scenario = self.sketch.list_questions()[0] + self.assertIsInstance(scenario.to_dict(), dict) + + +class QuestionListTest(unittest.TestCase): + """Test QuestionList object.""" + + @mock.patch("requests.Session", test_lib.mock_session) + def setUp(self): + """Setup test case.""" + self.api_client = client.TimesketchApi("http://127.0.0.1", "test", "test") + + def test_question_list(self): + """Test QuestionList object.""" + question_list = scenario_lib.QuestionTemplateList(self.api_client).get() + self.assertIsInstance(question_list, list) + self.assertEqual(len(question_list), 2) + self.assertEqual(question_list[0]["name"], "Test question?") + self.assertEqual(question_list[1]["id"], "Q0002") diff --git a/api_client/python/timesketch_api_client/sketch.py b/api_client/python/timesketch_api_client/sketch.py index 5eac0b50ad..12c1a11e41 100644 --- a/api_client/python/timesketch_api_client/sketch.py +++ b/api_client/python/timesketch_api_client/sketch.py @@ -32,6 +32,7 @@ from . import searchtemplate from . import story from . import timeline +from . import scenario as scenario_lib logger = logging.getLogger("timesketch_api.sketch") @@ -1610,28 +1611,159 @@ def search_by_label( as_pandas=as_pandas, ) - def add_scenario(self, scenario_name): - """Adds a investigative scenario to the sketch. + def add_scenario(self, dfiq_id, display_name=None): + """Adds an investigative scenario to the sketch. Args: - scenario_name (str): Name of the scenario to add. + dfiq_id (str): ID of the DFIQ scenario template to add. + display_name (str): [Optional] Name of the scenario to add. Raises: RuntimeError: If sketch is archived. Returns: - Dictionary with scenario. + Scenario object. """ if self.is_archived(): raise RuntimeError("Unable to add a scenario to an archived sketch") - form_data = {"scenario_name": scenario_name} + if not display_name: + form_data = {"dfiq_id": dfiq_id} + else: + form_data = {"dfiq_id": dfiq_id, "display_name": display_name} resource_url = "{0:s}/sketches/{1:d}/scenarios/".format( self.api.api_root, self.id ) response = self.api.session.post(resource_url, json=form_data) - return error.get_response_json(response, logger) + response_json = error.get_response_json(response, logger) + scenario_objects = response_json.get("objects") + if not scenario_objects: + raise RuntimeError( + f"Failed to add the scenario {dfiq_id} to the sketch {self.id}." + ) + + if not len(scenario_objects) == 1: + raise RuntimeError( + f"Failed to add the scenario {dfiq_id} to the sketch {self.id}." + ) + scenario_data = scenario_objects[0] + scenario = scenario_lib.Scenario( + scenario_id=scenario_data.get("id", -1), + sketch_id=self.id, + api=self.api, + ) + + return scenario + + def list_scenarios(self): + """Get a list of all scenarios that are attached to the sketch. + + Returns: + List of scenarios (instances of Scenario objects) + """ + if self.is_archived(): + raise RuntimeError("Unable to list scenarios on an archived sketch.") + + scenario_list = [] + resource_url = "{0:s}/sketches/{1:d}/scenarios/".format( + self.api.api_root, self.id + ) + response = self.api.session.get(resource_url) + response_json = error.get_response_json(response, logger) + scenario_objects = response_json.get("objects") + if not scenario_objects: + return scenario_list + + if not len(scenario_objects) == 1: + return scenario_list + for scenario_dict in scenario_objects[0]: + scenario_list.append( + scenario_lib.Scenario( + scenario_id=scenario_dict.get("id", -1), + sketch_id=self.id, + api=self.api, + ) + ) + return scenario_list + + def add_question(self, dfiq_id=None, question_text=None): + """Adds an investigative question to the sketch. + + Args: + dfiq_id (str): [Optional] ID of the DFIQ question template to add. + question_text (str): [Optional] Question text to add. + + Raises: + RuntimeError: If sketch is archived or input is missing. + + Returns: + Question object. + """ + if self.is_archived(): + raise RuntimeError("Unable to add a question to an archived sketch!") + + if dfiq_id: + form_data = {"template_id": dfiq_id} + elif question_text: + form_data = {"question_text": question_text} + else: + raise RuntimeError("Either dfiq_id or question_text are required!") + + resource_url = "{0:s}/sketches/{1:d}/questions/".format( + self.api.api_root, self.id + ) + response = self.api.session.post(resource_url, json=form_data) + response_json = error.get_response_json(response, logger) + question_objects = response_json.get("objects") + if not question_objects: + raise RuntimeError( + f"Failed to add the scenario {dfiq_id} to the sketch {self.id}." + ) + + if not len(question_objects) == 1: + raise RuntimeError( + f"Failed to add the scenario {dfiq_id} to the sketch {self.id}." + ) + question_data = question_objects[0] + question = scenario_lib.Question( + question_id=question_data.get("id", -1), + sketch_id=self.id, + api=self.api, + ) + + return question + + def list_questions(self): + """Get a list of all questions that are attached to the sketch. + + Returns: + List of questions (instances of Question objects) + """ + if self.is_archived(): + raise RuntimeError("Unable to list questions on an archived sketch.") + + question_list = [] + resource_url = "{0:s}/sketches/{1:d}/questions/".format( + self.api.api_root, self.id + ) + response = self.api.session.get(resource_url) + response_json = error.get_response_json(response, logger) + question_objects = response_json.get("objects") + if not question_objects: + return question_list + + if not len(question_objects) == 1: + return question_list + for question_dict in question_objects[0]: + question_list.append( + scenario_lib.Question( + question_id=question_dict.get("id", -1), + sketch_id=self.id, + api=self.api, + ) + ) + return question_list def add_event(self, message, date, timestamp_desc, attributes=None, tags=None): """Adds an event to the sketch specific timeline. diff --git a/api_client/python/timesketch_api_client/sketch_test.py b/api_client/python/timesketch_api_client/sketch_test.py index ef41f1491b..2dfd4cf366 100644 --- a/api_client/python/timesketch_api_client/sketch_test.py +++ b/api_client/python/timesketch_api_client/sketch_test.py @@ -21,6 +21,7 @@ from . import search from . import test_lib from . import timeline as timeline_lib +from . import scenario as scenario_lib class SketchTest(unittest.TestCase): @@ -95,3 +96,49 @@ def test_list_aggregations(self): self.assertEqual( aggregations[1].description, "Aggregating values of a particular field" ) + + def test_list_scenarios(self): + """Test the Sketch list_scenarios method.""" + scenarios = self.sketch.list_scenarios() + self.assertIsInstance(scenarios, list) + self.assertEqual(len(scenarios), 1) + scenario = scenarios[0] + self.assertIsInstance(scenario, scenario_lib.Scenario) + self.assertEqual(scenario.id, 1) + self.assertEqual(scenario.name, "Test Scenario") + self.assertEqual(scenario.scenario_id, 1) + self.assertEqual(scenario.dfiq_id, "S0001") + self.assertEqual(scenario.description, "Scenario description!") + + def test_add_scenario(self): + """Test the Sketch add_scenario method.""" + scenario = self.sketch.add_scenario(dfiq_id="S0001") + self.assertIsInstance(scenario, scenario_lib.Scenario) + self.assertEqual(scenario.id, 1) + self.assertEqual(scenario.name, "Test Scenario") + self.assertEqual(scenario.scenario_id, 1) + self.assertEqual(scenario.dfiq_id, "S0001") + self.assertEqual(scenario.description, "Scenario description!") + + def test_list_questions(self): + """Test the Sketch list_questions method.""" + questions = self.sketch.list_questions() + self.assertIsInstance(questions, list) + self.assertEqual(len(questions), 1) + question = questions[0] + self.assertIsInstance(question, scenario_lib.Question) + self.assertEqual(question.id, 1) + self.assertEqual(question.name, "Test Question?") + self.assertEqual(question.question_id, 1) + self.assertEqual(question.dfiq_id, "Q0001") + self.assertEqual(question.description, "Test Question Description") + + def test_add_question(self): + """Test the Sketch add_question method.""" + question = self.sketch.add_question(dfiq_id="Q0001") + self.assertIsInstance(question, scenario_lib.Question) + self.assertEqual(question.id, 1) + self.assertEqual(question.name, "Test Question?") + self.assertEqual(question.question_id, 1) + self.assertEqual(question.dfiq_id, "Q0001") + self.assertEqual(question.description, "Test Question Description") diff --git a/api_client/python/timesketch_api_client/test_lib.py b/api_client/python/timesketch_api_client/test_lib.py index 6cbd2c827d..7219c3b4ce 100644 --- a/api_client/python/timesketch_api_client/test_lib.py +++ b/api_client/python/timesketch_api_client/test_lib.py @@ -47,6 +47,7 @@ def get(*args, **kwargs): # pylint: disable=unused-argument def post(self, *args, **kwargs): """Mock POST request handler.""" + kwargs["method"] = "POST" if self._post_done: return mock_response(*args, empty=True) return mock_response(*args, **kwargs) @@ -640,7 +641,125 @@ def json(self): aggregation_group = {"meta": {"command": "list_groups"}, "objects": []} - # Register API endpoints to the correct mock response data. + mock_sketch_scenario_response = { + "meta": {}, + "objects": [ + [ + { + "description": "Scenario description!", + "dfiq_identifier": "S0001", + "display_name": "Test Scenario", + "id": 1, + "name": "Test Scenario", + } + ] + ], + } + + mock_scenario_response = { + "meta": {}, + "objects": [ + { + "description": "Scenario description!", + "dfiq_identifier": "S0001", + "display_name": "Test Scenario", + "id": 1, + "name": "Test Scenario", + } + ], + } + + mock_scenario_templates_response = { + "objects": [ + { + "child_ids": ["F0001", "F0002"], + "description": "Scenario description!", + "id": "S0001", + "name": "Test Scenario", + "parent_ids": [], + "tags": ["test"], + }, + { + "child_ids": ["F1007"], + "description": "Scenario description 2!", + "id": "S0002", + "name": "Test Scenario 2", + "parent_ids": [], + "tags": [], + }, + ] + } + + mock_sketch_questions_response = { + "meta": {}, + "objects": [ + [ + { + "approaches": [ + { + "description": "Test Approach Description", + "display_name": "Test Approach", + "id": 26, + "name": "Test Approach", + "search_templates": [], + } + ], + "conclusions": [], + "description": "Test Question Description", + "dfiq_identifier": "Q0001", + "display_name": "Test Question?", + "id": 1, + "name": "Test Question?", + } + ] + ], + } + + mock_question_response = { + "meta": {}, + "objects": [ + { + "approaches": [ + { + "description": "Test Approach Description", + "display_name": "Test Approach", + "id": 26, + "name": "Test Approach", + "search_templates": [], + } + ], + "conclusions": [], + "description": "Test Question Description", + "dfiq_identifier": "Q0001", + "display_name": "Test Question?", + "id": 1, + "name": "Test Question?", + } + ], + } + + mock_question_templates_response = { + "objects": [ + { + "child_ids": ["Q0001.01"], + "description": "Test Question Description", + "id": "Q0001", + "name": "Test question?", + "parent_ids": ["F0001"], + "tags": ["test"], + }, + { + "child_ids": ["Q0002.01"], + "description": "Second Test Question Description", + "id": "Q0002", + "name": "Second question?", + "parent_ids": ["F0001"], + "tags": ["test"], + }, + ] + } + + # Register API endpoints to the correct mock response data for GET requests. url_router = { "http://127.0.0.1": MockResponse(text_data=auth_text_data), "http://127.0.0.1/api/v1/sketches/": MockResponse(json_data=sketch_list_data), @@ -710,9 +829,49 @@ def json(self): "http://127.0.0.1/api/v1/sketches/1/aggregation/explore/": MockResponse( json_data=aggregation_chart_data ), + "http://127.0.0.1/api/v1/sketches/1/scenarios/": MockResponse( + json_data=mock_sketch_scenario_response + ), + "http://127.0.0.1/api/v1/sketches/1/scenarios/1/": MockResponse( + json_data=mock_scenario_response + ), + "http://127.0.0.1/api/v1/scenarios/": MockResponse( + json_data=mock_scenario_templates_response + ), + "http://127.0.0.1/api/v1/sketches/1/questions/": MockResponse( + json_data=mock_sketch_questions_response + ), + "http://127.0.0.1/api/v1/questions/": MockResponse( + json_data=mock_question_templates_response + ), + "http://127.0.0.1/api/v1/sketches/1/questions/1/": MockResponse( + json_data=mock_question_response + ), + } + + # Register API endpoints to the correct mock response data for POST requests. + post_url_router = { + "http://127.0.0.1/api/v1/sketches/1/event/attributes/": MockResponse( + json_data=add_event_attribute_data + ), + "http://127.0.0.1/api/v1/sketches/1/aggregation/explore/": MockResponse( + json_data=aggregation_chart_data + ), + "http://127.0.0.1/api/v1/sketches/1/scenarios/": MockResponse( + json_data=mock_scenario_response + ), + "http://127.0.0.1/api/v1/sketches/1/questions/": MockResponse( + json_data=mock_question_response + ), + "http://127.0.0.1/api/v1/sketches/1/explore/": MockResponse( + json_data=timeline_data + ), } if kwargs.get("empty", False): return MockResponse(text_data=empty_data) + if kwargs.get("method", "").upper() == "POST": + return post_url_router.get(args[0], MockResponse(None, 404)) + return url_router.get(args[0], MockResponse(None, 404)) diff --git a/timesketch/api/v1/resources/scenarios.py b/timesketch/api/v1/resources/scenarios.py index b6a06a8a08..a5e6b7584f 100644 --- a/timesketch/api/v1/resources/scenarios.py +++ b/timesketch/api/v1/resources/scenarios.py @@ -133,17 +133,15 @@ def post(self, sketch_id): if not form: form = request.data - scenario_id = form.get("scenario_id") + dfiq_id = form.get("dfiq_id") display_name = form.get("display_name") scenario = next( - (scenario for scenario in dfiq.scenarios if scenario.id == scenario_id), + (scenario for scenario in dfiq.scenarios if scenario.id == dfiq_id), None, ) if not scenario: - abort( - HTTP_STATUS_CODE_NOT_FOUND, f"No such scenario template: {scenario_id}" - ) + abort(HTTP_STATUS_CODE_NOT_FOUND, f"No such scenario template: {dfiq_id}") if not display_name: display_name = scenario.name @@ -499,19 +497,6 @@ def post(self, sketch_id): scenario = Scenario.get_by_id(scenario_id) if scenario_id else None facet = Facet.get_by_id(facet_id) if facet_id else None - if not question_text: - abort(HTTP_STATUS_CODE_BAD_REQUEST, "Question is missing") - - if scenario: - if scenario.sketch.id != sketch.id: - abort( - HTTP_STATUS_CODE_FORBIDDEN, "Scenario is not part of this sketch." - ) - - if facet: - if facet.sketch.id != sketch.id: - abort(HTTP_STATUS_CODE_FORBIDDEN, "Facet is not part of this sketch.") - if template_id: dfiq = load_dfiq_from_config() if not dfiq: @@ -560,6 +545,22 @@ def post(self, sketch_id): new_question.approaches.append(approach_sql) else: + if not question_text: + abort(HTTP_STATUS_CODE_BAD_REQUEST, "Question is missing") + + if scenario: + if scenario.sketch.id != sketch.id: + abort( + HTTP_STATUS_CODE_FORBIDDEN, + "Scenario is not part of this sketch.", + ) + + if facet: + if facet.sketch.id != sketch.id: + abort( + HTTP_STATUS_CODE_FORBIDDEN, "Facet is not part of this sketch." + ) + new_question = InvestigativeQuestion.get_or_create( name=question_text, display_name=question_text,