Deterministic Split
The deterministic_balanced_split function splits your data into balanced groups based on a specified ID column. This is particularly useful for A/B testing scenarios where you need control and treatment groups of near-equal size, with assignments that remain stable across runs.
Basic Usage
Here is a simple example using a DataFrame containing outlet information for an A/B test:
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Sample data
data = [
(10005718, "Beer Place Names"),
(3022100630, "The Beer Trader"),
(3021600840, "Brew Buyback"),
(10004683, "The Hop Exchange"),
(10005299, "Craft Collectors"),
(3022201956, "The Ale Market"),
(10005929, "Suds & Supply"),
(10117676, "Keg & Cash"),
(3021701802, "The Beer Bank"),
(3021602324, "Liquid Gold Buyers"),
(10005335, "The Beer Broker"),
(3021602854, "Hop Hustle"),
(10006770, "Liquid Assets"),
(10004748, "The Draft Vault"),
(3021702073, "Can & Cash"),
(3021800406, "Barrel & Bills"),
(3022200976, "The Beer Spot"),
(3022101082, "Craft Cashout"),
(3022100407, "Brew Bucks"),
(3022100578, "The Growler Market")
]
df = spark.createDataFrame(data, ["OutletID", "OutletName"])
Once you have the DataFrame, you can assign groups:
from heiwhy.data_split import deterministic_balanced_split
# Basic split into control and treatment groups
result_df = deterministic_balanced_split(
dataframe=df,
id_column="OutletID",
number_of_splits=2
)
# View the results
result_df.groupBy("group").count().orderBy("group").show()
When running this code, you will see the following logs indicating that the function is automatically selecting an optimal hash method:
2025-03-03 11:00:22.529 | INFO | heiwhy.data_split.deterministic_split:deterministic_balanced_split:232 - Splitting data into 2 groups based on OutletID
2025-03-03 11:00:22.530 | WARNING | heiwhy.data_split.deterministic_split:deterministic_balanced_split:243 -
No hash method specified. The optimal method will be determined automatically.
Note that this may take longer to run, if you would like to reduce runtime, specify a hash method.
2025-03-03 11:00:22.921 | INFO | heiwhy.data_split.deterministic_split:deterministic_balanced_split:252 - Automatically selected hash method: md5
This will produce output similar to:
| group | count |
|---|---|
| group_1 | 10 |
| group_2 | 10 |
As shown in the logs, when no hash method is specified, the function will automatically determine the optimal method to use. However, this automatic selection process may increase the runtime. To optimise performance, you can explicitly specify a hash method as shown in the Advanced Options section.
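The library's documentation does not describe how the optimal method is determined, but one plausible approach (purely illustrative, not heiwhy's actual implementation) is to hash every ID with each candidate algorithm and keep the one whose group counts are most balanced:

```python
import hashlib
from collections import Counter

def pick_hash_method(ids, number_of_splits, candidates=("md5", "sha1", "sha256")):
    """Illustrative sketch: choose the candidate hash whose group
    counts deviate least from a perfectly even split."""
    best_method, best_spread = None, None
    for method in candidates:
        counts = Counter(
            int(hashlib.new(method, str(i).encode()).hexdigest(), 16) % number_of_splits
            for i in ids
        )
        # Spread = difference between the largest and smallest group size.
        sizes = [counts.get(g, 0) for g in range(number_of_splits)]
        spread = max(sizes) - min(sizes)
        if best_spread is None or spread < best_spread:
            best_method, best_spread = method, spread
    return best_method

method = pick_hash_method(range(100), 2)
```

Scanning every candidate requires hashing the ID column once per algorithm, which is consistent with the warning that automatic selection takes longer than specifying a method up front.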
Customising Group Names
You can customise the group names using the group_names parameter to make the groups more meaningful:
# Split with custom group names for A/B test
result_df = deterministic_balanced_split(
dataframe=df,
id_column="OutletID",
number_of_splits=2,
group_names=["control", "treatment"]
)
# View the results
result_df.groupBy("group").count().orderBy("group").show()
This will produce:
| group | count |
|---|---|
| control | 10 |
| treatment | 10 |
Multiple Treatment Groups
You can also split your data into multiple groups for testing different treatments:
# Split into control and two treatment groups
result_df = deterministic_balanced_split(
dataframe=df,
id_column="OutletID",
number_of_splits=3,
group_names=["control", "treatment_a", "treatment_b"]
)
# View the results
result_df.groupBy("group").count().orderBy("group").show()
When running this code, you will see the following logs:
2025-03-03 12:13:27.448 | INFO | heiwhy.data_split.deterministic_split:deterministic_balanced_split:232 - Splitting data into 3 groups based on OutletID
2025-03-03 12:13:27.448 | INFO | heiwhy.data_split.deterministic_split:deterministic_balanced_split:262 - Group names: ['control', 'treatment_a', 'treatment_b']
This will produce:
| group | count |
|---|---|
| control | 7 |
| treatment_a | 7 |
| treatment_b | 6 |
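The uneven counts above follow from balancing 20 outlets across 3 groups: group sizes can differ by at most one. A quick way to compute the expected sizes (illustrative arithmetic, not part of heiwhy):

```python
def balanced_sizes(n_rows, n_splits):
    """Split n_rows as evenly as possible: the first `remainder`
    groups receive one extra row."""
    base, remainder = divmod(n_rows, n_splits)
    return [base + 1 if i < remainder else base for i in range(n_splits)]

balanced_sizes(20, 3)  # [7, 7, 6]
```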
Customising Output Column Name
You can also specify a custom name for the group column using the output_column parameter:
# Split with custom group column name
result_df = deterministic_balanced_split(
dataframe=df,
id_column="OutletID",
number_of_splits=2,
output_column="test_group",
group_names=["control", "treatment"]
)
# View the results
result_df.groupBy("test_group").count().orderBy("test_group").show()
This will produce:
| test_group | count |
|---|---|
| control | 10 |
| treatment | 10 |
Advanced Options
The function also supports additional parameters:
- hash_method: the hashing algorithm to use (e.g., "md5", "sha256")
For example:
result_df = deterministic_balanced_split(
dataframe=df,
id_column="OutletID",
number_of_splits=2,
group_names=["control", "treatment"],
hash_method="md5"
)
When specifying the hash method, you will see logs confirming your choice:
2025-03-03 12:13:27.448 | INFO | heiwhy.data_split.deterministic_split:deterministic_balanced_split:232 - Splitting data into 2 groups based on OutletID
2025-03-03 12:13:27.448 | INFO | heiwhy.data_split.deterministic_split:deterministic_balanced_split:254 - Using specified hash method: md5
2025-03-03 12:13:27.448 | INFO | heiwhy.data_split.deterministic_split:deterministic_balanced_split:262 - Group names: ['control', 'treatment']
Notes
- The function ensures that the split is deterministic, meaning the same input will always produce the same output
- The distribution across groups is balanced as much as possible
- The function works with any column type, not just numeric values
- The default group names are "group_1", "group_2", etc.
- The default group column name is "group"
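The determinism guarantee can be illustrated with a minimal stand-alone sketch (using md5 hashing; heiwhy's exact assignment logic may differ): hashing an ID always yields the same digest, so the same ID always lands in the same group, in any session, on any cluster.

```python
import hashlib

def assign_group(outlet_id, number_of_splits, hash_method="md5"):
    """Map an ID to a group index by hashing its string form."""
    digest = hashlib.new(hash_method, str(outlet_id).encode()).hexdigest()
    return int(digest, 16) % number_of_splits

# The same ID produces the same group on every call.
first = assign_group(10005718, 2)
second = assign_group(10005718, 2)
```

Because the assignment depends only on the ID's hash, it also explains why the function works with any column type: the value is converted to a stable representation before hashing.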