diff --git a/src/splunk_otel/profile.py b/src/splunk_otel/profile.py index 53d660d..b68b095 100644 --- a/src/splunk_otel/profile.py +++ b/src/splunk_otel/profile.py @@ -10,12 +10,10 @@ import opentelemetry.context import wrapt -from opentelemetry._logs import Logger, SeverityNumber, get_logger, set_logger_provider +from opentelemetry._logs import get_logger, Logger, SeverityNumber from opentelemetry.context import Context -from opentelemetry.exporter.otlp.proto.grpc._log_exporter import OTLPLogExporter from opentelemetry.instrumentation.version import __version__ as version -from opentelemetry.sdk._logs import LoggerProvider, LogRecord -from opentelemetry.sdk._logs._internal.export import BatchLogRecordProcessor +from opentelemetry.sdk._logs import LogRecord from opentelemetry.sdk.resources import Resource from opentelemetry.trace import TraceFlags from opentelemetry.trace.propagation import _SPAN_KEY @@ -24,6 +22,7 @@ from splunk_otel.env import Env _SERVICE_NAME_ATTR = "service.name" + _SPLUNK_DISTRO_VERSION_ATTR = "splunk.distro.version" _DEFAULT_OTEL_SERVICE_NAME = "unknown_service" _NO_SERVICE_NAME_WARNING = """service.name attribute is not set, your service is unnamed and will be difficult to identify. @@ -31,30 +30,43 @@ E.g. `OTEL_SERVICE_NAME=""`""" _DEFAULT_SERVICE_NAME = "unnamed-python-service" -_profiling_timer = None - +_profile_timer = None _pylogger = logging.getLogger(__name__) def start_profiling(): - tcm = ThreadContextMapping() + tcm = _ThreadContextMapping() tcm.wrap_context_methods() period_millis = 100 - resource = mk_resource(Env().getval("OTEL_SERVICE_NAME")) + resource = _mk_resource(Env().getval("OTEL_SERVICE_NAME")) logger = get_logger("splunk-profiler") - scraper = ProfilingScraper(resource, tcm.get_thread_states(), period_millis, logger) + scraper = _ProfileScraper(resource, tcm.get_thread_states(), period_millis, logger) - global _profiling_timer # noqa PLW0603 - _profiling_timer = PeriodicTimer(period_millis, scraper.tick) - _profiling_timer.start() + global _profile_timer # noqa PLW0603 + _profile_timer = _PeriodicTimer(period_millis, scraper.tick) + _profile_timer.start() def stop_profiling(): - _profiling_timer.stop() + _profile_timer.stop() + + +def _mk_resource(service_name) -> Resource: + if service_name: + resolved_name = service_name + else: + _pylogger.warning(_NO_SERVICE_NAME_WARNING) + resolved_name = _DEFAULT_SERVICE_NAME + return Resource.create( + { + _SPLUNK_DISTRO_VERSION_ATTR: version, + _SERVICE_NAME_ATTR: resolved_name, + } + ) -class ThreadContextMapping: +class _ThreadContextMapping: def __init__(self): self.thread_states = {} @@ -111,12 +123,12 @@ def wrapper(wrapped, _instance, args, kwargs): return wrapper -def collect_stacktraces(): +def _collect_stacktraces(): out = [] frames = sys._current_frames() # noqa SLF001 for thread_id, frame in frames.items(): - stack_summary = extract_stack_summary(frame) + stack_summary = _extract_stack_summary(frame) frames = [(sf.filename, sf.name, sf.lineno) for sf in stack_summary] out.append( { @@ -127,14 +139,14 @@ def collect_stacktraces(): return out -class ProfilingScraper: +class _ProfileScraper: def __init__( self, resource, thread_states, period_millis, logger: Logger, - collect_stacktraces_func=collect_stacktraces, + collect_stacktraces_func=_collect_stacktraces, time_func=time.time, ): self.resource = resource @@ -155,8 +167,8 @@ def mk_log_record(self, stacktraces): time_seconds = self.time() - pb_profile = stacktraces_to_cpu_profile(stacktraces, self.thread_states, self.period_millis, time_seconds) - pb_profile_str = pb_profile_to_str(pb_profile) + pb_profile = _stacktraces_to_cpu_profile(stacktraces, self.thread_states, self.period_millis, time_seconds) + pb_profile_str = _pb_profile_to_str(pb_profile) return LogRecord( timestamp=int(time_seconds * 1e9), @@ -175,7 +187,15 @@ def mk_log_record(self, stacktraces): ) -class PeriodicTimer: +def _pb_profile_to_str(pb_profile) -> str: + serialized = pb_profile.SerializeToString() + compressed = gzip.compress(serialized) + b64encoded = base64.b64encode(compressed) + return b64encoded.decode() + + +class _PeriodicTimer: + def __init__(self, period_millis, target): self.period_seconds = period_millis / 1e3 self.target = target @@ -197,21 +217,7 @@ def stop(self): self.thread.join() -def mk_resource(service_name) -> Resource: - if service_name: - resolved_name = service_name - else: - _pylogger.warning(_NO_SERVICE_NAME_WARNING) - resolved_name = _DEFAULT_SERVICE_NAME - return Resource.create( - { - _SPLUNK_DISTRO_VERSION_ATTR: version, - _SERVICE_NAME_ATTR: resolved_name, - } - ) - - -class StringTable: +class _StringTable: def __init__(self): self.strings = OrderedDict() @@ -229,7 +235,7 @@ def keys(self): return list(self.strings.keys()) -def get_location(functions_table, str_table, locations_table, frame): +def _get_location(functions_table, str_table, locations_table, frame): (file_name, function_name, line_no) = frame key = f"{file_name}:{function_name}:{line_no}" location = locations_table.get(key) @@ -237,21 +243,21 @@ def get_location(functions_table, str_table, locations_table, frame): if location is None: location = profile_pb2.Location() location.id = len(locations_table) + 1 - line = get_line(functions_table, str_table, file_name, function_name, line_no) + line = _get_line(functions_table, str_table, file_name, function_name, line_no) location.line.append(line) locations_table[key] = location return location -def get_line(functions_table, str_table, file_name, function_name, line_no): +def _get_line(functions_table, str_table, file_name, function_name, line_no): line = profile_pb2.Line() - line.function_id = get_function(functions_table, str_table, file_name, function_name).id + line.function_id = _get_function(functions_table, str_table, file_name, function_name).id line.line = line_no if line_no != 0 else -1 return line -def get_function(functions_table, str_table, file_name, function_name): +def _get_function(functions_table, str_table, file_name, function_name): key = f"{file_name}:{function_name}" func = functions_table.get(key) @@ -267,8 +273,15 @@ def get_function(functions_table, str_table, file_name, function_name): return func -def stacktraces_to_cpu_profile(stacktraces, thread_states, period_millis, time_seconds): - str_table = StringTable() +def _extract_stack_summary(frame): + stack_iterator = traceback.walk_stack(frame) + out = StackSummary.extract(stack_iterator, limit=None, lookup_lines=False) + out.reverse() + return out + + +def _stacktraces_to_cpu_profile(stacktraces, thread_states, period_millis, time_seconds): + str_table = _StringTable() locations_table = OrderedDict() functions_table = OrderedDict() @@ -319,7 +332,7 @@ def stacktraces_to_cpu_profile(stacktraces, thread_states, period_millis, time_s location_ids = [] for frame in reversed(stacktrace["frames"]): - location = get_location(functions_table, str_table, locations_table, frame) + location = _get_location(functions_table, str_table, locations_table, frame) location_ids.append(location.id) sample.location_id.extend(location_ids) @@ -333,38 +346,3 @@ def stacktraces_to_cpu_profile(stacktraces, thread_states, period_millis, time_s pb_profile.location.extend(list(locations_table.values())) return pb_profile - - -def pb_profile_to_str(pb_profile) -> str: - serialized = pb_profile.SerializeToString() - compressed = gzip.compress(serialized) - b64encoded = base64.b64encode(compressed) - return b64encoded.decode() - - -def pb_profile_from_str(stringified: str) -> profile_pb2.Profile: - byte_array = base64.b64decode(stringified) - decompressed = gzip.decompress(byte_array) - out = profile_pb2.Profile() - out.ParseFromString(decompressed) - return out - - -def extract_stack_summary(frame): - stack_iterator = traceback.walk_stack(frame) - out = StackSummary.extract(stack_iterator, limit=None, lookup_lines=False) - out.reverse() - return out - - -def configure_otel(): - logger_provider = LoggerProvider() - logger_provider.add_log_record_processor(BatchLogRecordProcessor(OTLPLogExporter())) - set_logger_provider(logger_provider) - - -if __name__ == "__main__": - configure_otel() - start_profiling() - time.sleep(12) - stop_profiling() diff --git a/tests/test_profile.py b/tests/test_profile.py index 3f6352d..0ab9b7d 100644 --- a/tests/test_profile.py +++ b/tests/test_profile.py @@ -1,3 +1,5 @@ +import base64 +import gzip import json import random import time @@ -7,8 +9,9 @@ from google.protobuf.json_format import MessageToDict from opentelemetry._logs import Logger from opentelemetry.sdk.resources import Resource + from splunk_otel import profile_pb2 -from splunk_otel.profile import ProfilingScraper, pb_profile_from_str, pb_profile_to_str, stacktraces_to_cpu_profile +from splunk_otel.profile import _pb_profile_to_str, _ProfileScraper, _stacktraces_to_cpu_profile @pytest.fixture @@ -39,22 +42,22 @@ def load_json(fname): def test_basic_proto_serialization(): # noinspection PyUnresolvedReferences profile = profile_pb2.Profile() - serialized = pb_profile_to_str(profile) - decoded_profile = pb_profile_from_str(serialized) + serialized = _pb_profile_to_str(profile) + decoded_profile = _pb_profile_from_str(serialized) assert profile == decoded_profile def test_stacktraces_to_cpu_profile(stacktraces_fixture, pb_profile_fixture, thread_states_fixture): time_seconds = 1726760000 # corresponds to the timestamp in the fixture interval_millis = 100 - profile = stacktraces_to_cpu_profile(stacktraces_fixture, thread_states_fixture, interval_millis, time_seconds) + profile = _stacktraces_to_cpu_profile(stacktraces_fixture, thread_states_fixture, interval_millis, time_seconds) assert pb_profile_fixture == MessageToDict(profile) def test_profile_scraper(stacktraces_fixture): time_seconds = 1726760000 - logger = FakeLogger() - ps = ProfilingScraper( + logger = _FakeLogger() + ps = _ProfileScraper( Resource({}), {}, 100, @@ -67,11 +70,19 @@ def test_profile_scraper(stacktraces_fixture): log_record = logger.log_records[0] assert log_record.timestamp == int(time_seconds * 1e9) - assert len(MessageToDict(pb_profile_from_str(log_record.body))) == 4 # sanity check + assert len(MessageToDict(_pb_profile_from_str(log_record.body))) == 4 # sanity check assert log_record.attributes["profiling.data.total.frame.count"] == 30 -def do_work(time_ms): +def _pb_profile_from_str(stringified: str) -> profile_pb2.Profile: + byte_array = base64.b64decode(stringified) + decompressed = gzip.decompress(byte_array) + out = profile_pb2.Profile() + out.ParseFromString(decompressed) + return out + + +def _do_work(time_ms): now = time.time() target = now + time_ms / 1000.0 @@ -89,7 +100,8 @@ def do_work(time_ms): return total -class FakeLogger(Logger): +class _FakeLogger(Logger): + def __init__(self): super().__init__("fake-logger") self.log_records = []