Make stuff private, move things around
pmcollins committed Nov 16, 2024
1 parent d501e92 commit a3657b6
Showing 2 changed files with 78 additions and 88 deletions.
src/splunk_otel/profile.py: 136 changes (57 additions, 79 deletions)
@@ -10,12 +10,10 @@

import opentelemetry.context
import wrapt
from opentelemetry._logs import Logger, SeverityNumber, get_logger, set_logger_provider
from opentelemetry._logs import get_logger, Logger, SeverityNumber
from opentelemetry.context import Context
from opentelemetry.exporter.otlp.proto.grpc._log_exporter import OTLPLogExporter
from opentelemetry.instrumentation.version import __version__ as version
from opentelemetry.sdk._logs import LoggerProvider, LogRecord
from opentelemetry.sdk._logs._internal.export import BatchLogRecordProcessor
from opentelemetry.sdk._logs import LogRecord
from opentelemetry.sdk.resources import Resource
from opentelemetry.trace import TraceFlags
from opentelemetry.trace.propagation import _SPAN_KEY
@@ -24,37 +22,51 @@
from splunk_otel.env import Env

_SERVICE_NAME_ATTR = "service.name"

_SPLUNK_DISTRO_VERSION_ATTR = "splunk.distro.version"
_DEFAULT_OTEL_SERVICE_NAME = "unknown_service"
_NO_SERVICE_NAME_WARNING = """service.name attribute is not set, your service is unnamed and will be difficult to identify.
set your service name using the OTEL_SERVICE_NAME environment variable.
E.g. `OTEL_SERVICE_NAME="<YOUR_SERVICE_NAME_HERE>"`"""
_DEFAULT_SERVICE_NAME = "unnamed-python-service"

_profiling_timer = None

_profile_timer = None
_pylogger = logging.getLogger(__name__)


def start_profiling():
tcm = ThreadContextMapping()
tcm = _ThreadContextMapping()
tcm.wrap_context_methods()

period_millis = 100
resource = mk_resource(Env().getval("OTEL_SERVICE_NAME"))
resource = _mk_resource(Env().getval("OTEL_SERVICE_NAME"))
logger = get_logger("splunk-profiler")
scraper = ProfilingScraper(resource, tcm.get_thread_states(), period_millis, logger)
scraper = _ProfileScraper(resource, tcm.get_thread_states(), period_millis, logger)

global _profiling_timer # noqa PLW0603
_profiling_timer = PeriodicTimer(period_millis, scraper.tick)
_profiling_timer.start()
global _profile_timer # noqa PLW0603
_profile_timer = _PeriodicTimer(period_millis, scraper.tick)
_profile_timer.start()


def stop_profiling():
_profiling_timer.stop()
_profile_timer.stop()


def _mk_resource(service_name) -> Resource:
if service_name:
resolved_name = service_name
else:
_pylogger.warning(_NO_SERVICE_NAME_WARNING)
resolved_name = _DEFAULT_SERVICE_NAME
return Resource.create(
{
_SPLUNK_DISTRO_VERSION_ATTR: version,
_SERVICE_NAME_ATTR: resolved_name,
}
)


class ThreadContextMapping:
class _ThreadContextMapping:
def __init__(self):
self.thread_states = {}

@@ -111,12 +123,12 @@ def wrapper(wrapped, _instance, args, kwargs):
return wrapper


def collect_stacktraces():
def _collect_stacktraces():
out = []
frames = sys._current_frames() # noqa SLF001

for thread_id, frame in frames.items():
stack_summary = extract_stack_summary(frame)
stack_summary = _extract_stack_summary(frame)
frames = [(sf.filename, sf.name, sf.lineno) for sf in stack_summary]
out.append(
{
@@ -127,14 +139,14 @@ def collect_stacktraces():
return out


class ProfilingScraper:
class _ProfileScraper:
def __init__(
self,
resource,
thread_states,
period_millis,
logger: Logger,
collect_stacktraces_func=collect_stacktraces,
collect_stacktraces_func=_collect_stacktraces,
time_func=time.time,
):
self.resource = resource
@@ -155,8 +167,8 @@ def mk_log_record(self, stacktraces):

time_seconds = self.time()

pb_profile = stacktraces_to_cpu_profile(stacktraces, self.thread_states, self.period_millis, time_seconds)
pb_profile_str = pb_profile_to_str(pb_profile)
pb_profile = _stacktraces_to_cpu_profile(stacktraces, self.thread_states, self.period_millis, time_seconds)
pb_profile_str = _pb_profile_to_str(pb_profile)

return LogRecord(
timestamp=int(time_seconds * 1e9),
@@ -175,7 +187,15 @@ def mk_log_record(self, stacktraces):
)


class PeriodicTimer:
def _pb_profile_to_str(pb_profile) -> str:
serialized = pb_profile.SerializeToString()
compressed = gzip.compress(serialized)
b64encoded = base64.b64encode(compressed)
return b64encoded.decode()


class _PeriodicTimer:

def __init__(self, period_millis, target):
self.period_seconds = period_millis / 1e3
self.target = target
@@ -197,21 +217,7 @@ def stop(self):
self.thread.join()


def mk_resource(service_name) -> Resource:
if service_name:
resolved_name = service_name
else:
_pylogger.warning(_NO_SERVICE_NAME_WARNING)
resolved_name = _DEFAULT_SERVICE_NAME
return Resource.create(
{
_SPLUNK_DISTRO_VERSION_ATTR: version,
_SERVICE_NAME_ATTR: resolved_name,
}
)


class StringTable:
class _StringTable:
def __init__(self):
self.strings = OrderedDict()

@@ -229,29 +235,29 @@ def keys(self):
return list(self.strings.keys())


def get_location(functions_table, str_table, locations_table, frame):
def _get_location(functions_table, str_table, locations_table, frame):
(file_name, function_name, line_no) = frame
key = f"{file_name}:{function_name}:{line_no}"
location = locations_table.get(key)

if location is None:
location = profile_pb2.Location()
location.id = len(locations_table) + 1
line = get_line(functions_table, str_table, file_name, function_name, line_no)
line = _get_line(functions_table, str_table, file_name, function_name, line_no)
location.line.append(line)
locations_table[key] = location

return location


def get_line(functions_table, str_table, file_name, function_name, line_no):
def _get_line(functions_table, str_table, file_name, function_name, line_no):
line = profile_pb2.Line()
line.function_id = get_function(functions_table, str_table, file_name, function_name).id
line.function_id = _get_function(functions_table, str_table, file_name, function_name).id
line.line = line_no if line_no != 0 else -1
return line


def get_function(functions_table, str_table, file_name, function_name):
def _get_function(functions_table, str_table, file_name, function_name):
key = f"{file_name}:{function_name}"
func = functions_table.get(key)

@@ -267,8 +273,15 @@ def get_function(functions_table, str_table, file_name, function_name):
return func


def stacktraces_to_cpu_profile(stacktraces, thread_states, period_millis, time_seconds):
str_table = StringTable()
def _extract_stack_summary(frame):
stack_iterator = traceback.walk_stack(frame)
out = StackSummary.extract(stack_iterator, limit=None, lookup_lines=False)
out.reverse()
return out


def _stacktraces_to_cpu_profile(stacktraces, thread_states, period_millis, time_seconds):
str_table = _StringTable()
locations_table = OrderedDict()
functions_table = OrderedDict()

@@ -319,7 +332,7 @@ def stacktraces_to_cpu_profile(stacktraces, thread_states, period_millis, time_s
location_ids = []

for frame in reversed(stacktrace["frames"]):
location = get_location(functions_table, str_table, locations_table, frame)
location = _get_location(functions_table, str_table, locations_table, frame)
location_ids.append(location.id)

sample.location_id.extend(location_ids)
@@ -333,38 +346,3 @@ def stacktraces_to_cpu_profile(stacktraces, thread_states, period_millis, time_s
pb_profile.location.extend(list(locations_table.values()))

return pb_profile


def pb_profile_to_str(pb_profile) -> str:
serialized = pb_profile.SerializeToString()
compressed = gzip.compress(serialized)
b64encoded = base64.b64encode(compressed)
return b64encoded.decode()


def pb_profile_from_str(stringified: str) -> profile_pb2.Profile:
byte_array = base64.b64decode(stringified)
decompressed = gzip.decompress(byte_array)
out = profile_pb2.Profile()
out.ParseFromString(decompressed)
return out


def extract_stack_summary(frame):
stack_iterator = traceback.walk_stack(frame)
out = StackSummary.extract(stack_iterator, limit=None, lookup_lines=False)
out.reverse()
return out


def configure_otel():
logger_provider = LoggerProvider()
logger_provider.add_log_record_processor(BatchLogRecordProcessor(OTLPLogExporter()))
set_logger_provider(logger_provider)


if __name__ == "__main__":
configure_otel()
start_profiling()
time.sleep(12)
stop_profiling()
tests/test_profile.py: 30 changes (21 additions, 9 deletions)
@@ -1,3 +1,5 @@
import base64
import gzip
import json
import random
import time
@@ -7,8 +9,9 @@
from google.protobuf.json_format import MessageToDict
from opentelemetry._logs import Logger
from opentelemetry.sdk.resources import Resource

from splunk_otel import profile_pb2
from splunk_otel.profile import ProfilingScraper, pb_profile_from_str, pb_profile_to_str, stacktraces_to_cpu_profile
from splunk_otel.profile import _pb_profile_to_str, _ProfileScraper, _stacktraces_to_cpu_profile


@pytest.fixture
@@ -39,22 +42,22 @@ def load_json(fname):
def test_basic_proto_serialization():
# noinspection PyUnresolvedReferences
profile = profile_pb2.Profile()
serialized = pb_profile_to_str(profile)
decoded_profile = pb_profile_from_str(serialized)
serialized = _pb_profile_to_str(profile)
decoded_profile = _pb_profile_from_str(serialized)
assert profile == decoded_profile


def test_stacktraces_to_cpu_profile(stacktraces_fixture, pb_profile_fixture, thread_states_fixture):
time_seconds = 1726760000 # corresponds to the timestamp in the fixture
interval_millis = 100
profile = stacktraces_to_cpu_profile(stacktraces_fixture, thread_states_fixture, interval_millis, time_seconds)
profile = _stacktraces_to_cpu_profile(stacktraces_fixture, thread_states_fixture, interval_millis, time_seconds)
assert pb_profile_fixture == MessageToDict(profile)


def test_profile_scraper(stacktraces_fixture):
time_seconds = 1726760000
logger = FakeLogger()
ps = ProfilingScraper(
logger = _FakeLogger()
ps = _ProfileScraper(
Resource({}),
{},
100,
@@ -67,11 +70,19 @@ def do_work(time_ms):
log_record = logger.log_records[0]

assert log_record.timestamp == int(time_seconds * 1e9)
assert len(MessageToDict(pb_profile_from_str(log_record.body))) == 4 # sanity check
assert len(MessageToDict(_pb_profile_from_str(log_record.body))) == 4 # sanity check
assert log_record.attributes["profiling.data.total.frame.count"] == 30


def do_work(time_ms):
def _pb_profile_from_str(stringified: str) -> profile_pb2.Profile:
byte_array = base64.b64decode(stringified)
decompressed = gzip.decompress(byte_array)
out = profile_pb2.Profile()
out.ParseFromString(decompressed)
return out


def _do_work(time_ms):
now = time.time()
target = now + time_ms / 1000.0

@@ -89,7 +100,8 @@ def do_work(time_ms):
return total


class FakeLogger(Logger):
class _FakeLogger(Logger):

def __init__(self):
super().__init__("fake-logger")
self.log_records = []
