Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[COST-4915] Add unattributed distribution to cost model form #5072

Merged
merged 10 commits into from
May 10, 2024
14 changes: 14 additions & 0 deletions koku/api/metrics/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,17 @@
"default_cost_type": "Infrastructure",
},
]

# Keys of a cost model's distribution_info mapping. The string values are the
# same names DistributionSerializer declares as its fields.
PLATFORM_COST = "platform_cost"
# NOTE: the constant is named WORKER_UNALLOCATED but the stored key is "worker_cost".
WORKER_UNALLOCATED = "worker_cost"
NETWORK_UNATTRIBUTED = "network_unattributed"
STORAGE_UNATTRIBUTED = "storage_unattributed"
DISTRIBUTION_TYPE = "distribution_type"

# Distribution options applied when a cost model supplies no distribution_info:
# CPU_DISTRIBUTION as the distribution type, platform and worker-unallocated
# distribution switched on, network/storage unattributed distribution switched off.
# This dict is a shared module-level object; callers should not mutate it in place.
DEFAULT_DISTRIBUTION_INFO = {
DISTRIBUTION_TYPE: CPU_DISTRIBUTION,
PLATFORM_COST: True,
WORKER_UNALLOCATED: True,
NETWORK_UNATTRIBUTED: False,
STORAGE_UNATTRIBUTED: False,
}
26 changes: 10 additions & 16 deletions koku/cost_models/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,22 +58,21 @@ class MarkupSerializer(serializers.Serializer):
class DistributionSerializer(BaseSerializer):
    """Serializer for the distribution options of a cost model.

    Every field is optional. ``validate`` fills in whichever options were not
    supplied from ``metric_constants.DEFAULT_DISTRIBUTION_INFO``.
    """

    distribution_type = serializers.ChoiceField(choices=metric_constants.DISTRIBUTION_CHOICES, required=False)
    platform_cost = serializers.BooleanField(required=False)
    worker_cost = serializers.BooleanField(required=False)
    network_unattributed = serializers.BooleanField(required=False)
    storage_unattributed = serializers.BooleanField(required=False)

    def validate(self, data):
        """Fill in any missing distribution option with its default value.

        Args:
            data (dict): The validated distribution options that were supplied.

        Returns:
            dict: ``data`` with every key of ``DEFAULT_DISTRIBUTION_INFO`` present.
        """
        # Take each missing value from the defaults table instead of deriving a
        # boolean from the key name: distribution_type is a choice, not a flag,
        # so a missing distribution_type must become the default distribution
        # type rather than False.
        # Filling `data` (instead of returning DEFAULT_DISTRIBUTION_INFO itself)
        # also keeps the shared module-level defaults from being handed out by
        # reference and mutated later.
        for key, default_value in metric_constants.DEFAULT_DISTRIBUTION_INFO.items():
            data.setdefault(key, default_value)
        return data


Expand Down Expand Up @@ -477,12 +476,7 @@ def validate(self, data):
data["currency"] = get_currency(self.context.get("request"))

if not data.get("distribution_info"):
data["distribution_info"] = {
"distribution_type": data.get("distribution", metric_constants.CPU_DISTRIBUTION),
"platform_cost": True,
"worker_cost": True,
}

data["distribution_info"] = metric_constants.DEFAULT_DISTRIBUTION_INFO
if (
data.get("markup")
and not data.get("rates")
Expand Down
35 changes: 14 additions & 21 deletions koku/cost_models/test/test_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@

from api.iam.test.iam_test_case import IamTestCase
from api.metrics import constants as metric_constants
from api.metrics.constants import DEFAULT_DISTRIBUTION_INFO
from api.metrics.constants import SOURCE_TYPE_MAP
from api.provider.models import Provider
from api.utils import get_currency
from cost_models.models import CostModel
from cost_models.models import CostModelMap
from cost_models.serializers import CostModelSerializer
from cost_models.serializers import DistributionSerializer
from cost_models.serializers import RateSerializer
from cost_models.serializers import UUIDKeyRelatedField

Expand Down Expand Up @@ -844,47 +846,38 @@ def test_valid_distribution_info_keys(self):
if serializer.is_valid(raise_exception=True):
instance = serializer.save()
self.assertIsNotNone(instance)
# Add in default options
valid_distrib_obj[metric_constants.NETWORK_UNATTRIBUTED] = False
valid_distrib_obj[metric_constants.STORAGE_UNATTRIBUTED] = False
self.assertEqual(instance.distribution_info, valid_distrib_obj)

def test_invalid_distribution_info_keys(self):
"""Test that unknown distribution_info keys are dropped by DistributionSerializer instead of failing validation."""

invalid_distrib_info_keys = {"bad_key": "", "badder_key": True, "worker_cost": False}
self.ocp_data["distribution_info"] = invalid_distrib_info_keys
self.assertEqual(self.ocp_data["distribution_info"], invalid_distrib_info_keys)
bad_key1 = "bad_key"
bad_key2 = "worst_key"
invalid_distrib_info_keys = {bad_key1: "", bad_key2: True, "worker_cost": False}
with tenant_context(self.tenant):
serializer = CostModelSerializer(data=self.ocp_data, context=self.request_context)
with self.assertRaises(serializers.ValidationError):
serializer.is_valid(raise_exception=True)
serializer = DistributionSerializer(data=invalid_distrib_info_keys)
self.assertTrue(serializer.is_valid(raise_exception=True))
self.assertNotIn(bad_key1, serializer.data)
self.assertNotIn(bad_key2, serializer.data)

def test_none_distribution_info_returns_defaults(self):
"""Test that a missing (None) distribution_info falls back to the default options."""

default_distrib_info_obj = {
"distribution_type": metric_constants.CPU_DISTRIBUTION,
"platform_cost": True,
"worker_cost": True,
}
with tenant_context(self.tenant):
instance = None
serializer = CostModelSerializer(data=self.ocp_data, context=self.request_context)
if serializer.is_valid(raise_exception=True):
instance = serializer.save()
self.assertIsNotNone(instance)
self.assertEqual(instance.distribution_info, default_distrib_info_obj)
self.assertEqual(instance.distribution_info, DEFAULT_DISTRIBUTION_INFO)

def test_empty_distribution_info_returns_defaults(self):
"""Test that an empty distribution_info object is saved with the default options."""

default_distrib_info_obj = {
"distribution_type": metric_constants.CPU_DISTRIBUTION,
"platform_cost": True,
"worker_cost": True,
}
# Explicitly empty, as opposed to absent from the request data.
self.ocp_data["distribution_info"] = {}
with tenant_context(self.tenant):
instance = None
serializer = CostModelSerializer(data=self.ocp_data, context=self.request_context)
if serializer.is_valid(raise_exception=True):
instance = serializer.save()
self.assertEqual(instance.distribution_info, default_distrib_info_obj)
self.assertEqual(instance.distribution_info, DEFAULT_DISTRIBUTION_INFO)
44 changes: 19 additions & 25 deletions koku/masu/database/ocp_report_db_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from trino.exceptions import TrinoExternalError

from api.common import log_json
from api.metrics import constants as metric_constants
from api.metrics.constants import DEFAULT_DISTRIBUTION_TYPE
from api.provider.models import Provider
from koku.database import SQLScriptAtomicExecutorMixin
Expand Down Expand Up @@ -392,19 +393,24 @@ def populate_markup_cost(self, markup, start_date, end_date, cluster_id):
),
)

def populate_platform_and_worker_distributed_cost_sql(
self, start_date, end_date, provider_uuid, distribution_info
):
def populate_distributed_cost_sql(self, start_date, end_date, provider_uuid, distribution_info):
"""
Populate the platform cost distribution of a customer.
Populate the distribution cost model options.

args:
start_date (datetime, str): The start_date to calculate monthly_cost.
end_date (datetime, str): The end_date to calculate monthly_cost.
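distribution_info (dict): Distribution options of the cost model: the "distribution_type" (ex. memory) plus a boolean per cost to distribute (ex. "platform_cost")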
provider_uuid (str): The str of the provider UUID
"""
distribute_mapping = {}

key_to_file_mapping = {
metric_constants.PLATFORM_COST: "distribute_platform_cost.sql",
metric_constants.WORKER_UNALLOCATED: "distribute_worker_cost.sql",
# metric_constants.STORAGE_UNATTRIBUTED: "distribute_unattributed_storage_cost.sql",
# metric_constants.NETWORK_UNATTRIBUTED: "distribute_unattributed_network_cost.sql",
samdoran marked this conversation as resolved.
Show resolved Hide resolved
}

distribution = distribution_info.get("distribution_type", DEFAULT_DISTRIBUTION_TYPE)
table_name = self._table_map["line_item_daily_summary"]
report_period = self.report_periods_for_provider_uuid(provider_uuid, start_date)
Expand All @@ -415,26 +421,14 @@ def populate_platform_and_worker_distributed_cost_sql(
return

report_period_id = report_period.id
distribute_mapping = {
"platform_cost": {
"sql_file": "distribute_platform_cost.sql",
"log_msg": {
True: "distributing platform cost",
False: "removing platform_distributed cost model rate type",
},
},
"worker_cost": {
"sql_file": "distribute_worker_cost.sql",
"log_msg": {
True: "distributing worker unallocated cost",
False: "removing worker_distributed cost model rate type",
},
},
}

for cost_model_key, metadata in distribute_mapping.items():
for cost_model_key, sql_file in key_to_file_mapping.items():
populate = distribution_info.get(cost_model_key, False)
# if populate is false we only execute the delete sql.
if populate:
log_msg = f"distributing {cost_model_key}"
else:
# if populate is false we only execute the delete sql.
log_msg = f"removing {cost_model_key} distribution"
sql_params = {
"start_date": start_date,
"end_date": end_date,
Expand All @@ -445,9 +439,9 @@ def populate_platform_and_worker_distributed_cost_sql(
"populate": populate,
}

sql = pkgutil.get_data("masu.database", f"sql/openshift/cost_model/{metadata['sql_file']}")
sql = pkgutil.get_data("masu.database", f"sql/openshift/cost_model/distribute_cost/{sql_file}")
sql = sql.decode("utf-8")
LOG.info(log_json(msg=metadata["log_msg"][populate], context=sql_params))
LOG.info(log_json(msg=log_msg, context=sql_params))
self._prepare_and_execute_raw_sql_query(table_name, sql, sql_params, operation="INSERT")

def populate_monthly_cost_sql(self, cost_type, rate_type, rate, start_date, end_date, distribution, provider_uuid):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ user_defined_project_sum as (
AND report_period_id = {{report_period_id}}
AND lids.namespace != 'Worker unallocated'
AND lids.namespace != 'Platform unallocated'
AND lids.namespace != 'Storage unattributed'
AND lids.namespace != 'Network unattributed'
AND (cost_category_id IS NULL OR cat.name != 'Platform')
GROUP BY usage_start, cluster_id, source_uuid
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ user_defined_project_sum as (
AND report_period_id = {{report_period_id}}
AND lids.namespace != 'Worker unallocated'
AND lids.namespace != 'Platform unallocated'
AND lids.namespace != 'Storage unattributed'
AND lids.namespace != 'Network unattributed'
AND (cost_category_id IS NULL OR cat.name != 'Platform')
GROUP BY usage_start, cluster_id, source_uuid
),
Expand Down
4 changes: 1 addition & 3 deletions koku/masu/processor/ocp/ocp_cost_model_cost_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,9 +457,7 @@ def update_summary_cost_model_costs(self, start_date, end_date):

with OCPReportDBAccessor(self._schema) as accessor:

accessor.populate_platform_and_worker_distributed_cost_sql(
start_date, end_date, self._provider_uuid, self._distribution_info
)
accessor.populate_distributed_cost_sql(start_date, end_date, self._provider_uuid, self._distribution_info)
accessor.populate_ui_summary_tables(start_date, end_date, self._provider.uuid)
report_period = accessor.report_periods_for_provider_uuid(self._provider_uuid, start_date)
if report_period:
Expand Down
8 changes: 4 additions & 4 deletions koku/masu/test/database/test_ocp_report_db_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,19 +930,19 @@ def test_populate_usage_costs_new_columns_no_report_period(self):
acc.populate_usage_costs("", "", start_date, end_date, self.provider_uuid)
self.assertIn("no report period for OCP provider", logger.output[0])

def test_populate_platform_and_worker_distributed_cost_sql_no_report_period(self):
def test_populate_distributed_cost_sql_no_report_period(self):
"""Test that distributing costs without a matching report period returns None instead of raising."""
start_date = "2000-01-01"
end_date = "2000-02-01"
with self.accessor as acc:
result = acc.populate_platform_and_worker_distributed_cost_sql(
result = acc.populate_distributed_cost_sql(
start_date, end_date, self.provider_uuid, {"platform_cost": True}
)
self.assertIsNone(result)

@patch("masu.database.ocp_report_db_accessor.pkgutil.get_data")
@patch("masu.database.ocp_report_db_accessor.OCPReportDBAccessor._execute_raw_sql_query")
def test_populate_platform_and_worker_distributed_cost_sql_called(self, mock_sql_execute, mock_data_get):
def test_populate_distributed_cost_sql_called(self, mock_sql_execute, mock_data_get):
"""Test that the platform distribution is called."""

def get_pkgutil_values(file):
Expand Down Expand Up @@ -972,7 +972,7 @@ def get_pkgutil_values(file):

with self.accessor as acc:
acc.prepare_query = mock_jinja
acc.populate_platform_and_worker_distributed_cost_sql(
acc.populate_distributed_cost_sql(
start_date, end_date, self.ocp_test_provider_uuid, {"worker_cost": True, "platform_cost": True}
)
expected_calls = [
Expand Down
Loading