Creating Causal Identification module #1166

Open. Wants to merge 17 commits into base: main

Conversation

cetagostini
Contributor

@cetagostini cetagostini commented Nov 4, 2024

Description

Short description: Integration of CausalGraphModel in BaseMMM Class

This update integrates a CausalGraphModel into the BaseMMM class, allowing for automated causal identification based on backdoor criteria, assuming a given Directed Acyclic Graph (DAG).

Summary of Changes

  1. Added Causal Graph Option:

    • The BaseMMM class now accepts an optional dag parameter, which can be provided either as a string (DOT format) or a networkx.DiGraph.
    • If dag is provided, a CausalGraphModel is instantiated to analyze causal relationships and determine necessary adjustment sets.
  2. Automatic Minimal Adjustment Set Handling:

    • The BaseMMM initialization now includes logic to calculate the minimal adjustment set required to estimate the causal effect of the treatment variables (assumed to be media channels) on the outcome.
    • control_columns are automatically updated to include variables from the minimal adjustment set only.
    • If the variable yearly_seasonality is not in the minimal adjustment set, the yearly_seasonality parameter is set to None, effectively disabling it in the model.
  3. Warnings for Missing Adjustment Sets:

    • If a minimal adjustment set cannot be identified, a warning is issued, and no modifications are made during the initialization.
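
The three behaviours listed above can be sketched as a small standalone function. This is only an illustration under stated assumptions, not the code in this PR: the function name `apply_adjustment_set` is hypothetical, and "updated to include variables from the minimal adjustment set only" is read here as filtering the user-supplied controls.

```python
import warnings


def apply_adjustment_set(control_columns, yearly_seasonality, adjustment_set):
    """Return (control_columns, yearly_seasonality) after applying an adjustment set.

    Hypothetical helper: `adjustment_set` is None when no minimal
    adjustment set could be identified.
    """
    if adjustment_set is None:
        # Point 3: warn and leave the configuration untouched.
        warnings.warn(
            "No minimal adjustment set found; configuration left unchanged.",
            UserWarning,
            stacklevel=2,
        )
        return control_columns, yearly_seasonality

    # Point 2a: keep only controls that belong to the minimal adjustment set.
    kept = [col for col in (control_columns or []) if col in adjustment_set]

    # Point 2b: disable seasonality unless the DAG requires adjusting for it.
    if "yearly_seasonality" not in adjustment_set:
        yearly_seasonality = None

    return kept, yearly_seasonality
```

Returning new values instead of mutating the model keeps the rule easy to test in isolation; the real initialization may instead assign to `self.control_columns` directly.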

Code Example

Here's how to initialize BaseMMM with a DAG for causal inference:

dag_str = """
digraph {
    x1 -> y;
    x2 -> y;
    yearly_seasonality -> y;
    event_1 -> y;
    event_2 -> y;
}
"""

mmm = MMM(
    model_config=my_model_config,
    sampler_config=my_sampler_config,
    date_column="date_week",
    adstock=GeometricAdstock(l_max=8),
    saturation=LogisticSaturation(),
    channel_columns=["x1", "x2"],
    control_columns=["event_1", "event_2"],
    yearly_seasonality=2,  # Disabled if 'yearly_seasonality' is not in minimal adjustment set
    dag=dag_str,
    outcome_column="y",
)
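
To see what the backdoor criterion implies for this particular DAG, the sketch below lists the parents of each treatment. A backdoor path must begin with an edge pointing into the treatment, so a treatment with no parents has no backdoor paths and the empty set is a valid adjustment set. The regex parser is a toy assumption that only handles simple `a -> b;` edges; it is not how the PR parses DOT.

```python
import re

dag_str = """
digraph {
    x1 -> y;
    x2 -> y;
    yearly_seasonality -> y;
    event_1 -> y;
    event_2 -> y;
}
"""

# Toy parser: extracts "src -> dst" pairs, ignoring the rest of the DOT grammar.
edges = re.findall(r"(\w+)\s*->\s*(\w+)", dag_str)


def parents(node):
    """Return the set of nodes with an edge pointing into `node`."""
    return {src for src, dst in edges if dst == node}


treatments = ["x1", "x2"]
treatment_parents = {t: parents(t) for t in treatments}
print(treatment_parents)  # {'x1': set(), 'x2': set()}
```

Since neither channel has a parent here, the minimal adjustment set is empty. If the implementation behaves as described in the summary, both the event controls and the yearly seasonality would be dropped for this DAG.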

Related Issue

  • Closes #
  • Related to #

Checklist

Modules affected

  • MMM
  • CLV

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc-marketing--1166.org.readthedocs.build/en/1166/

@github-actions github-actions bot added the MMM label Nov 4, 2024
@cetagostini cetagostini requested review from wd60622 and juanitorduz and removed request for wd60622 November 4, 2024 23:35
@wd60622
Contributor

wd60622 commented Nov 4, 2024

What is z in the 2nd body example? Would that be in the model?

@wd60622 wd60622 added causal inference enhancement New feature or request labels Nov 4, 2024
@cetagostini
Contributor Author

> What is z in the 2nd body example? Would that be in the model?

Old example, I did the correction!

@github-actions github-actions bot added the docs Improvements or additions to documentation label Nov 13, 2024
codecov bot commented Nov 13, 2024

Codecov Report

Attention: Patch coverage is 29.41176% with 36 lines in your changes missing coverage. Please review.

Project coverage is 36.62%. Comparing base (00f84b9) to head (3c4f5c7).
Report is 3 commits behind head on main.

Files with missing lines        Patch %   Lines
pymc_marketing/mmm/causal.py    31.42%    24 Missing ⚠️
pymc_marketing/mmm/mmm.py       7.69%     12 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (00f84b9) and HEAD (3c4f5c7).

HEAD has 2 uploads less than BASE
Flag    BASE (00f84b9)    HEAD (3c4f5c7)
        3                 1
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1166       +/-   ##
===========================================
- Coverage   95.59%   36.62%   -58.97%     
===========================================
  Files          39       40        +1     
  Lines        4066     4117       +51     
===========================================
- Hits         3887     1508     -2379     
- Misses        179     2609     +2430     

@cetagostini cetagostini marked this pull request as draft November 15, 2024 16:44
@github-actions github-actions bot added the tests label Nov 16, 2024
@cetagostini cetagostini marked this pull request as ready for review November 16, 2024 22:22
Collaborator

@juanitorduz juanitorduz left a comment

Cool stuff @cetagostini! I took a quick look at the code and I think it brings a lot of value. Nevertheless, I think we need to find the right level of abstraction and not include this feature in the BaseMMM class (see comment below).

"""

    def __init__(
        self, causal_model: CausalModel, treatment: list[str] | tuple[str], outcome: str
Collaborator

Shall we type treatment as an iterator https://wiki.python.org/moin/Iterator?

Collaborator

(applied to all below)
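
One way to read the typing suggestion above is sketched below. Note that `tuple[str]` denotes a tuple of exactly one string, while `tuple[str, ...]` denotes a tuple of any length. The helper name `normalize_treatment` is hypothetical; the sketch assumes the class only needs to iterate over the names once and store them.

```python
from collections.abc import Iterable


def normalize_treatment(treatment: Iterable[str]) -> list[str]:
    """Materialize any iterable of column names into a list."""
    if isinstance(treatment, str):
        # A bare string is itself an iterable of characters; reject it early
        # so "x1" is not silently split into ["x", "1"].
        raise TypeError("treatment must be an iterable of names, not a single str")
    return list(treatment)
```

Accepting `Iterable[str]` covers lists, tuples, and generators alike, and converting to a list up front avoids exhausting a one-shot iterator later.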

Comment on lines +117 to +127
    dag: str | None = Field(
        None,
        description="Optional DAG provided as a string Dot format for causal identification.",
    ),
    treatment_nodes: list[str] | tuple[str] | None = Field(
        None,
        description="Column names of the variables of interest to identify causal effects on outcome.",
    ),
    outcome_node: str | None = Field(
        None, description="Name of the outcome variable."
    ),
Collaborator

This is where I would like to discuss the API. Our MMM class is already a huge monolith of many components, and I would like us to start modularizing more or even making it a subclass.

For instance, we can keep BaseMMM as it is and have an additional CausalMMM(BaseMMM), and if people want to use this class, they need to install DoWhy. I am personally against adding DoWhy as a required dependency because, in my experience, it sometimes hard-pins some packages, which can make it harder to resolve dependencies. WDYT?

Thoughts @wd60622 ?

Contributor Author

I like this idea. I can work on it quickly, but I will wait for William's comments 🙌🏻 I will probably have a meeting with him on Tuesday!

Contributor

If the dependencies are the issue, then I think we can get away with requiring dowhy and networkx only when the dag is specified. That way, existing models stay backward compatible and do not need the new dependencies. Would checking for these dependencies only when this functionality is used solve your concerns? @juanitorduz

I think going the route of subclassing could just add more code to manage 😢
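
A minimal sketch of the "only required if the dag is specified" idea follows. The function name `check_causal_dependencies` and its `required` parameter are hypothetical and not part of this PR; the sketch relies only on the standard library's `importlib.util.find_spec`.

```python
from importlib.util import find_spec


def check_causal_dependencies(dag, required=("dowhy", "networkx")):
    """Raise ImportError only when a DAG is given and a dependency is missing."""
    if dag is None:
        return  # Causal features unused: no extra packages needed.

    # find_spec returns None for a top-level package that is not installed.
    missing = [name for name in required if find_spec(name) is None]
    if missing:
        raise ImportError(
            "Causal identification requires these optional packages: "
            + ", ".join(missing)
        )
```

Called at the top of the initializer, a check like this leaves models without a `dag` unaffected, while users who pass one get a clear message naming the missing packages.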

Collaborator

> I think going the route of subclassing could just add more code to manage 😢

true ... what would be your suggestion @wd60622 ?

Contributor Author

I think less code to manage is better, and users still import the same MMM class. It is only 20 lines of code, so I don't see it as something crazy. What's your opinion?

Comment on lines +344 to +352
        attrs["dag"] = (
            json.dumps(self.dag, default=default) if hasattr(self, "dag") else "None"
        )
        attrs["treatment_nodes"] = (
            self.treatment_nodes if hasattr(self, "treatment_nodes") else "None"
        )
        attrs["outcome_node"] = (
            self.outcome_node if hasattr(self, "outcome_node") else "None"
        )
Collaborator

I do not think this has to be in the model builder class, which is much more general. For instance, this is not necessary for the CLV module.

Contributor

100%

Contributor Author

Then where should I add them? I ran into issues when I didn't add them here, but let me try again.
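
One possible answer, sketched with stand-in classes: let the general base class build the common attributes and have the MMM-level class extend them with the causal fields. The names `ModelBuilderSketch`, `CausalMMMSketch`, and `create_attrs` are hypothetical and do not claim to mirror the library's real classes or methods.

```python
import json


class ModelBuilderSketch:
    """Stand-in for a general base class that knows nothing about DAGs."""

    def create_attrs(self) -> dict[str, str]:
        return {"model_type": type(self).__name__}


class CausalMMMSketch(ModelBuilderSketch):
    """Stand-in for the MMM-level class that owns the causal fields."""

    def __init__(self, dag=None, treatment_nodes=None, outcome_node=None):
        self.dag = dag
        self.treatment_nodes = treatment_nodes
        self.outcome_node = outcome_node

    def create_attrs(self) -> dict[str, str]:
        # Start from the general attributes, then add the causal ones here.
        attrs = super().create_attrs()
        # json.dumps(None) yields "null", which round-trips through json.loads,
        # unlike the literal string "None".
        attrs["dag"] = json.dumps(self.dag)
        attrs["treatment_nodes"] = json.dumps(self.treatment_nodes)
        attrs["outcome_node"] = json.dumps(self.outcome_node)
        return attrs
```

With this split, a module that never uses a DAG (such as CLV) inherits only the base attributes, and loading a saved model can recover the causal fields with `json.loads`.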

Comment on lines +68 to +69
"networkx",
"dowhy",
Collaborator

I think these should not be core dependencies but rather optional ones (like numpyro for sampling). See the comment above on abstraction.

Contributor Author

@cetagostini cetagostini Nov 21, 2024

I'm following your and @wd60622's idea and moving these to optional dependencies.

review-notebook-app bot commented Nov 17, 2024

juanitorduz commented on 2024-11-17T18:02:02Z
----------------------------------------------------------------

Did anything change in the MMM example notebook?


review-notebook-app bot commented Nov 17, 2024

juanitorduz commented on 2024-11-17T18:20:27Z
----------------------------------------------------------------

Add a subtitle like: "Business problem"


review-notebook-app bot commented Nov 17, 2024

juanitorduz commented on 2024-11-17T18:20:28Z
----------------------------------------------------------------

Shall we remove the first data points? They arise naturally because there is not much history to adstock for the initial points.


review-notebook-app bot commented Nov 17, 2024

juanitorduz commented on 2024-11-17T18:20:28Z
----------------------------------------------------------------

Again, let's remove the first point, because this initial jump is just artificial and looks odd.


review-notebook-app bot commented Nov 17, 2024

juanitorduz commented on 2024-11-17T18:20:29Z
----------------------------------------------------------------

Any idea on the divergences?


review-notebook-app bot commented Nov 17, 2024

juanitorduz commented on 2024-11-17T18:20:30Z
----------------------------------------------------------------

Can we use $x_1$ instead of $x1$?


review-notebook-app bot commented Nov 17, 2024

juanitorduz commented on 2024-11-17T18:20:31Z
----------------------------------------------------------------

Maybe we should display the HDI instead of the heat plots, since the legend here is meaningless and needs fixing (I will fix this soon).


review-notebook-app bot commented Nov 17, 2024

juanitorduz commented on 2024-11-17T18:20:31Z
----------------------------------------------------------------

Do we need to display these warnings?


@juanitorduz
Collaborator

@carlosagostini I really liked the notebook 🚀 !

I think if we improve the level of abstraction (see comments above) and update the notebooks, we can merge this one soon!

@drbenvincent It would be great if you could review the notebook to provide feedback if you have time 🙏 :)

tests/mmm/test_causal.py (outdated; resolved)
                outcome=self.outcome_node,
            )

            self.control_columns = self.causal_graphical_model.compute_adjustment_sets(
Contributor

Am I understanding that this line will modify the control columns for the model? Then we don't need to change the build_model method?

Contributor Author

Correct, that's why I added it here.

Comment on lines +125 to +127
    outcome_node: str | None = Field(
        None, description="Name of the outcome variable."
    ),
Contributor

Do we have to specify this? We already have the output_var. Could that be leveraged?

Contributor Author

Let me check, not necessarily.

Contributor

If we could simplify, that'd be great

Contributor Author

@cetagostini cetagostini Nov 21, 2024

@wd60622 we can, but then I'll need to require that the output variable in the DAG is always y (the value returned by output_var). Names such as sales, revenue, or registrations would not be possible. That's not crazy, and we would avoid one parameter. What do you think?

Contributor

@wd60622 wd60622 left a comment

Some comments and questions

Labels
causal inference docs Improvements or additions to documentation enhancement New feature or request MMM tests
3 participants