Skip to content

Commit

Permalink
feat(Analytics):refacto
Browse files Browse the repository at this point in the history
  • Loading branch information
cheikhgwane committed Jul 9, 2024
1 parent faa113c commit fdd912e
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 67 deletions.
61 changes: 30 additions & 31 deletions hexa/core/analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,42 +26,41 @@ def track(
"""
from hexa.pipelines.authentication import PipelineRunUser

mixpanel = None
if settings.MIXPANEL_TOKEN:
mixpanel = Mixpanel(token=settings.MIXPANEL_TOKEN, consumer=mixpanel_consumer)

# First check if we have a PipelineUser and get the associated user if it exits
# user will be None in case of pipeline triggered via schedule or webhook
tracked_user = None
if request:
tracked_user = (
request.user.pipeline_run.user
if isinstance(request.user, PipelineRunUser)
else request.user
)

parsed = user_agent_parser.Parse(request.headers["User-Agent"])
properties.update(
{
"$browser": parsed["user_agent"]["family"],
"$device": parsed["device"]["family"],
"$os": parsed["os"]["family"],
"ip": request.META["REMOTE_ADDR"],
}
)
# First check if we have a PipelineUser and get the associated user if it exits
# user will be None in case of pipeline triggered via schedule or webhook
tracked_user = None
if request:
tracked_user = (
request.user.pipeline_run.user
if isinstance(request.user, PipelineRunUser)
else request.user
)

if tracked_user is None or tracked_user.analytics_enabled:
try:
mixpanel.track(
distinct_id=str(tracked_user.id) if tracked_user else None,
event_name=event,
properties=properties,
parsed = user_agent_parser.Parse(request.headers["User-Agent"])
properties.update(
{
"$browser": parsed["user_agent"]["family"],
"$device": parsed["device"]["family"],
"$os": parsed["os"]["family"],
"ip": request.META["REMOTE_ADDR"],
}
)
except Exception as e:
capture_exception(e)

if tracked_user is None or tracked_user.analytics_enabled:
try:
mixpanel.track(
distinct_id=str(tracked_user.id) if tracked_user else None,
event_name=event,
properties=properties,
)
except Exception as e:
capture_exception(e)


def track_user(user: User):
def set_user_properties(user: User):
if settings.MIXPANEL_TOKEN and user.analytics_enabled:
try:
mixpanel = Mixpanel(
Expand All @@ -72,8 +71,8 @@ def track_user(user: User):
properties={
"$email": user.email,
"$name": user.display_name,
"staff_status": user.is_staff,
"superuser_status": user.is_superuser,
"is_staff": user.is_staff,
"is_superuser": user.is_superuser,
"email_domain": user.email.split("@")[1],
"features_flag": [
f.feature.code for f in user.featureflag_set.all()
Expand Down
8 changes: 4 additions & 4 deletions hexa/core/tests/test_analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from django.test import RequestFactory

from hexa.core.analytics import track, track_user
from hexa.core.analytics import set_user_properties, track
from hexa.core.test import TestCase
from hexa.files.tests.mocks.mockgcp import mock_gcp_storage
from hexa.pipelines.authentication import PipelineRunUser
Expand Down Expand Up @@ -238,15 +238,15 @@ def test_create_user_profile(
mixpanel_token = "token"
# mock the flush method of the BufferedConsumer
with self.settings(MIXPANEL_TOKEN=mixpanel_token):
track_user(self.USER)
set_user_properties(self.USER)

mock_mixpanel_instance.people_set_once.assert_called_once_with(
distinct_id=str(self.USER.id),
properties={
"$email": self.USER.email,
"$name": self.USER.display_name,
"staff_status": self.USER.is_staff,
"superuser_status": self.USER.is_superuser,
"is_staff": self.USER.is_staff,
"is_superuser": self.USER.is_superuser,
"email_domain": "bluesquarehub.com",
"features_flag": [],
},
Expand Down
5 changes: 1 addition & 4 deletions hexa/datasets/tests/test_schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from unittest import mock

from django.conf import settings
from django.db import IntegrityError

Expand Down Expand Up @@ -330,8 +328,7 @@ class DatasetVersionTest(GraphQLTestCase, DatasetTestMixin):
def setUpTestData(cls):
get_storage().create_bucket(settings.WORKSPACE_DATASETS_BUCKET)

@mock.patch("hexa.datasets.schema.mutations.track")
def test_create_dataset_version(self, mocked_track):
def test_create_dataset_version(self):
superuser = self.create_user("[email protected]", is_superuser=True)

workspace = self.create_workspace(superuser, "Workspace", "Description")
Expand Down
21 changes: 7 additions & 14 deletions hexa/pipelines/tests/test_schema/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,7 @@ def setUpTestData(cls):
role=WorkspaceMembershipRole.VIEWER,
)

@patch("hexa.pipelines.schema.mutations.track")
def test_create_pipeline(self, mocked_track):
def test_create_pipeline(self):
self.assertEqual(0, len(Pipeline.objects.all()))

self.client.force_login(self.USER_ROOT)
Expand Down Expand Up @@ -671,8 +670,7 @@ def test_delete_pipeline(self):
self.assertEqual(True, r["data"]["deletePipeline"]["success"])
self.assertEqual(0, len(Pipeline.objects.filter_for_user(user=self.USER_ROOT)))

@patch("hexa.pipelines.schema.mutations.track")
def test_pipeline_new_run(self, mocked_track):
def test_pipeline_new_run(self):
self.assertEqual(0, len(PipelineRun.objects.all()))
self.test_create_pipeline_version()
self.assertEqual(1, len(Pipeline.objects.all()))
Expand All @@ -698,8 +696,7 @@ def test_pipeline_new_run(self, mocked_track):
)
self.assertEqual(1, len(PipelineRun.objects.all()))

@patch("hexa.pipelines.schema.mutations.track")
def test_pipeline_new_run_with_version_config(self, mocked_track):
def test_pipeline_new_run_with_version_config(self):
self.assertEqual(0, len(PipelineRun.objects.all()))
pipeline_version_config = {"param1": "param1_data"}
self.test_create_pipeline_version(
Expand Down Expand Up @@ -731,8 +728,7 @@ def test_pipeline_new_run_with_version_config(self, mocked_track):
pipeline_version_config, r["data"]["runPipeline"]["run"]["config"]
)

@patch("hexa.pipelines.schema.mutations.track")
def test_pipeline_new_run_with_pipeline_run_config(self, mocked_track):
def test_pipeline_new_run_with_pipeline_run_config(self):
self.assertEqual(0, len(PipelineRun.objects.all()))
pipeline_version_config = {"param1": "param1_data"}
self.test_create_pipeline_version(
Expand Down Expand Up @@ -763,8 +759,7 @@ def test_pipeline_new_run_with_pipeline_run_config(self, mocked_track):
self.assertEqual(1, len(PipelineRun.objects.all()))
self.assertEqual(pipeline_run_config, r["data"]["runPipeline"]["run"]["config"])

@patch("hexa.pipelines.schema.mutations.track")
def test_pipeline_new_run_with_empty_pipeline_run_config(self, mocked_track):
def test_pipeline_new_run_with_empty_pipeline_run_config(self):
self.assertEqual(0, len(PipelineRun.objects.all()))
pipeline_version_config = {"param1": "param1_data", "param2": "param2_data"}
self.test_create_pipeline_version(
Expand Down Expand Up @@ -1938,8 +1933,7 @@ def test_upload_unschedulable_pipeline(self):
r["data"]["uploadPipeline"],
)

@patch("hexa.pipelines.schema.mutations.track")
def test_pipeline_new_run_with_timeout(self, mocked_track):
def test_pipeline_new_run_with_timeout(self):
self.test_create_pipeline()

code1 = Pipeline.objects.filter_for_user(user=self.USER_ROOT).first().code
Expand Down Expand Up @@ -2001,8 +1995,7 @@ def test_pipeline_new_run_with_timeout(self, mocked_track):
r["data"]["runPipeline"],
)

@patch("hexa.pipelines.schema.mutations.track")
def test_pipeline_new_run_default_timeout(self, mocked_track):
def test_pipeline_new_run_default_timeout(self):
self.assertEqual(0, len(PipelineRun.objects.all()))
self.test_create_pipeline_version()
self.assertEqual(1, len(Pipeline.objects.all()))
Expand Down
18 changes: 6 additions & 12 deletions hexa/pipelines/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ def test_run_pipeline_not_enabled(self):
self.assertEqual(r.status_code, 400)
self.assertEqual(r.json(), {"error": "Pipeline has no webhook enabled"})

@patch("hexa.pipelines.views.track")
def test_run_pipeline_notebook_webhook(self, mocked_track):
def test_run_pipeline_notebook_webhook(self):
pipeline = Pipeline.objects.create(
code="new_pipeline",
name="notebook.ipynb",
Expand All @@ -120,8 +119,7 @@ def test_run_pipeline_notebook_webhook(self, mocked_track):
self.assertEqual(str(pipeline.last_run.id), response.json()["run_id"])
self.assertEqual(pipeline.last_run.trigger_mode, PipelineRunTrigger.WEBHOOK)

@patch("hexa.pipelines.views.track")
def test_run_pipeline_valid(self, mocked_track):
def test_run_pipeline_valid(self):
self.assertEqual(self.PIPELINE.last_run, None)
response = self.client.post(
reverse(
Expand All @@ -136,8 +134,7 @@ def test_run_pipeline_valid(self, mocked_track):
self.PIPELINE.last_run.trigger_mode, PipelineRunTrigger.WEBHOOK
)

@patch("hexa.pipelines.views.track")
def test_run_pipeline_old_token(self, mocked_track):
def test_run_pipeline_old_token(self):
self.assertEqual(self.PIPELINE.last_run, None)
old_token = self.PIPELINE.webhook_token

Expand All @@ -164,8 +161,7 @@ def test_run_pipeline_old_token(self, mocked_track):
self.assertEqual(response.status_code, 404)
self.assertEqual(response.json(), {"error": "Pipeline not found"})

@patch("hexa.pipelines.views.track")
def test_run_pipeline_specific_version(self, mocked_track):
def test_run_pipeline_specific_version(self):
response = self.client.post(
reverse(
"pipelines:run_with_version",
Expand All @@ -190,8 +186,7 @@ def test_run_pipeline_invalid_version(self):
self.assertEqual(response.status_code, 404)
self.assertEqual(response.json(), {"error": "Pipeline version not found"})

@patch("hexa.pipelines.views.track")
def test_run_pipeline_with_multiple_config(self, mocked_track):
def test_run_pipeline_with_multiple_config(self):
self.assert200withConfig(
[
{
Expand Down Expand Up @@ -423,8 +418,7 @@ def test_urlencoded_boolean_parameter(self):
content_type="application/x-www-form-urlencoded",
)

@patch("hexa.pipelines.views.track")
def test_send_mail_notifications(self, mocked_track):
def test_send_mail_notifications(self):
endpoint_url = reverse(
"pipelines:run",
args=[self.PIPELINE.webhook_token],
Expand Down
4 changes: 2 additions & 2 deletions hexa/user_management/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from django.dispatch import receiver
from django.http import HttpRequest

from hexa.core.analytics import track_user
from hexa.core.analytics import set_user_properties
from hexa.user_management.models import User


@receiver(user_logged_in, sender=User, dispatch_uid="user_logged_in_handler")
def user_logged_in_handler(sender: type, request: HttpRequest, user: User, **kwargs):
track_user(user)
set_user_properties(user)

0 comments on commit fdd912e

Please sign in to comment.