Commit

update
sdaza committed Dec 8, 2024
1 parent 5218af0 commit 2342a1f
Showing 1 changed file with 94 additions and 48 deletions.
142 changes: 94 additions & 48 deletions tests/test_experiment_analyzer.py
@@ -1,68 +1,110 @@
 import pytest
-from pyspark.sql import SparkSession
-from pyspark.sql import DataFrame
-from pyspark.sql import functions as F
-from pyspark.sql.types import StructType, StructField, StringType, IntegerType
 from experiment_utils.experiment_analyzer import ExperimentAnalyzer
+from experiment_utils.spark_instance import *

-# Initialize a Spark session
-spark = SparkSession.builder \
-    .appName("Spark Test") \
-    .master("local[*]") \
-    .getOrCreate()
+import numpy as np
+import pandas as pd
+from scipy.stats import truncnorm

-# Create a simple DataFrame
-data = [("Alice", 1), ("Bob", 2), ("Cathy", 3)]
-columns = ["Name", "Id"]
-df = spark.createDataFrame(data, columns)
-
-# Show the DataFrame
-df.show()

 @pytest.fixture
-def sample_data(spark):
-    """Fixture for creating a sample Spark DataFrame."""
-    schema = StructType([
-        StructField("experiment_id", StringType(), True),
-        StructField("treatment", IntegerType(), True),
-        StructField("outcome1", IntegerType(), True),
-        StructField("covariate1", IntegerType(), True),
-    ])
-    data = [
-        ("exp1", 1, 10, 5),
-        ("exp1", 0, 12, 6),
-        ("exp2", 1, 15, 7),
-        ("exp2", 0, 11, 8),
-    ]
-    return spark.createDataFrame(data, schema)
-
-def test_check_input(sample_data):
-    """Test the __check_input method of ExperimentAnalyzer."""
-    outcomes = ["outcome1"]
+def sample_data(
+    n_model=1000,
+    n_random=500,
+    base_model_conversion_mean=0.3,
+    base_model_conversion_variance=0.01,
+    base_random_conversion_mean=0.10,
+    base_random_conversion_variance=0.01,
+    model_treatment_effect=0.05,
+    random_treatment_effect=0.05,
+    random_seed=42,
+):
+
+    np.random.seed(random_seed)
+
+    # Function to get a truncated normal distribution
+    def get_truncated_normal(mean, variance, size):
+        std_dev = np.sqrt(variance)
+        lower, upper = 0, 1
+        a, b = (lower - mean) / std_dev, (upper - mean) / std_dev
+        return truncnorm.rvs(a, b, loc=mean, scale=std_dev, size=size)
+
+    # Generate baseline conversions with a truncated normal distribution
+    base_model_conversion = get_truncated_normal(
+        base_model_conversion_mean, base_model_conversion_variance, n_model
+    )
+    base_random_conversion = get_truncated_normal(
+        base_random_conversion_mean, base_random_conversion_variance, n_random
+    )
+
+    # model group data
+    model_treatment = np.random.binomial(1, 0.8, n_model)
+    model_conversion = (
+        base_model_conversion + model_treatment_effect * model_treatment
+    ) > np.random.rand(n_model)
+
+    model_data = pd.DataFrame(
+        {
+            "experiment": 123,
+            "group": "model",
+            "treatment": model_treatment,
+            "conversion": model_conversion.astype(int),
+            "baseline_conversion": base_model_conversion,
+        }
+    )
+
+    # random group data
+    random_treatment = np.random.binomial(1, 0.5, n_random)
+    random_conversion = (
+        base_random_conversion + random_treatment_effect * random_treatment
+    ) > np.random.rand(n_random)
+    random_data = pd.DataFrame(
+        {
+            "experiment": 123,
+            "group": "random",
+            "treatment": random_treatment,
+            "conversion": random_conversion.astype(int),
+            "baseline_conversion": base_random_conversion,
+        }
+    )
+
+    # Combine data
+    data = pd.concat([model_data, random_data])
+    df = spark.createDataFrame(data)
+
+    return df
+
+
+def test_no_covariates(sample_data):
+    """Test get_effects with no covariates."""
+    outcomes = "conversion"
     treatment_col = "treatment"
-    experiment_identifier = ["experiment_id"]
-    covariates = ["covariate1"]
+    experiment_identifier = "experiment"

     analyzer = ExperimentAnalyzer(
         data=sample_data,
         outcomes=outcomes,
         treatment_col=treatment_col,
-        experiment_identifier=experiment_identifier,
-        covariates=covariates
-    )
+        experiment_identifier=experiment_identifier)

     # This should not raise an error since all columns are present
     try:
-        analyzer._ExperimentAnalyzer__check_input()
+        analyzer.get_effects()
+        analyzer.results
         assert True
     except Exception as e:
-        pytest.fail(f"__check_input raised an exception: {e}")
+        pytest.fail(f"get_effects raised an exception: {e}")

-def test_missing_columns(sample_data):
-    """Test the __check_input method with missing columns."""
-    outcomes = ["outcome1"]
+
+def test_no_adjustment(sample_data):
+    """Test get_effects with no adjustments."""
+    outcomes = "conversion"
     treatment_col = "treatment"
-    experiment_identifier = ["experiment_id"]
-    covariates = ["missing_covariate"]
+    experiment_identifier = "experiment"
+    covariates = "baseline_conversion"

     analyzer = ExperimentAnalyzer(
         data=sample_data,
@@ -72,6 +114,10 @@ def test_missing_columns(sample_data):
         covariates=covariates
     )

-    # Expecting an error due to missing covariate
-    with pytest.raises(Exception, match="The following required columns are missing from the dataframe"):
-        analyzer._ExperimentAnalyzer__check_input()
+    # This should not raise an error since all columns are present
+    try:
+        analyzer.get_effects()
+        analyzer.results
+        assert True
+    except Exception as e:
+        pytest.fail(f"get_effects raised an exception: {e}")
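For readers skimming the change, here is a minimal sketch of how the pieces introduced in this commit fit together. It is not part of the commit: it assumes that experiment_utils.spark_instance exposes a ready-to-use `spark` session (implied by the wildcard import and the fixture's call to spark.createDataFrame) and that ExperimentAnalyzer populates a `results` attribute after get_effects(); the constructor arguments mirror the ones used in the tests above, and the simulated values are purely illustrative.

# Sketch only: mirrors the API usage in the tests above; simulated values are illustrative.
import numpy as np
import pandas as pd

from experiment_utils.experiment_analyzer import ExperimentAnalyzer
from experiment_utils.spark_instance import *  # assumed to provide a `spark` session

rng = np.random.default_rng(0)
n = 200
treatment = rng.binomial(1, 0.5, n)  # 50/50 treatment assignment, as in the "random" group
baseline = rng.uniform(0.1, 0.4, n)  # baseline conversion propensity
conversion = (baseline + 0.05 * treatment > rng.random(n)).astype(int)

pdf = pd.DataFrame(
    {
        "experiment": 123,
        "group": "random",
        "treatment": treatment,
        "conversion": conversion,
        "baseline_conversion": baseline,
    }
)

analyzer = ExperimentAnalyzer(
    data=spark.createDataFrame(pdf),
    outcomes="conversion",
    treatment_col="treatment",
    experiment_identifier="experiment",
    covariates="baseline_conversion",  # optional, as in test_no_adjustment
)
analyzer.get_effects()
print(analyzer.results)  # assumed to hold the estimated effects per experiment

The two tests themselves can be run in the usual way with pytest tests/test_experiment_analyzer.py.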
