
Deterministic Split

The deterministic_balanced_split function provides a way to split your data into balanced groups based on a specified column. This is particularly useful for A/B testing scenarios where you need to create control and treatment groups while ensuring an even distribution of data across groups.
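To make "deterministic" and "balanced" concrete, here is a minimal pure-Python sketch of one way such a split could work. It is an illustration under stated assumptions, not heiwhy's implementation. The helper name `sketch_balanced_split` is hypothetical. The sketch orders IDs by the MD5 digest of their string form and then deals them out round-robin, so the same IDs always land in the same groups and group sizes differ by at most one.

```python
import hashlib


def sketch_balanced_split(ids, number_of_splits):
    """Hypothetical sketch (not heiwhy's code): assign unique IDs to balanced groups.

    IDs are ordered by the MD5 digest of their string form, then dealt out
    round-robin, so group sizes never differ by more than one.
    """
    def digest(value):
        # MD5 of the string form yields the same ordering on every run
        return hashlib.md5(str(value).encode("utf-8")).hexdigest()

    ordered = sorted(ids, key=digest)
    return {
        value: f"group_{rank % number_of_splits + 1}"
        for rank, value in enumerate(ordered)
    }


assignment = sketch_balanced_split([101, 102, 103, 104, 105], 2)
print(sorted(assignment.values()))
# ['group_1', 'group_1', 'group_1', 'group_2', 'group_2']
```

Sorting by a hash rather than by the raw ID keeps the order unrelated to how IDs were issued. Round-robin dealing is what guarantees balance.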

Basic Usage

Here is a simple example using a DataFrame containing outlet information for an A/B test:

Create example DataFrame
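from pyspark.sql import SparkSession

# Reuse the active Spark session, or create one if none exists
spark = SparkSession.builder.getOrCreate()

# Sample data: (OutletID, OutletName)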
data = [
    (10005718, "Beer Place Names"),
    (3022100630, "The Beer Trader"),
    (3021600840, "Brew Buyback"),
    (10004683, "The Hop Exchange"),
    (10005299, "Craft Collectors"),
    (3022201956, "The Ale Market"),
    (10005929, "Suds & Supply"),
    (10117676, "Keg & Cash"),
    (3021701802, "The Beer Bank"),
    (3021602324, "Liquid Gold Buyers"),
    (10005335, "The Beer Broker"),
    (3021602854, "Hop Hustle"),
    (10006770, "Liquid Assets"),
    (10004748, "The Draft Vault"),
    (3021702073, "Can & Cash"),
    (3021800406, "Barrel & Bills"),
    (3022200976, "The Beer Spot"),
    (3022101082, "Craft Cashout"),
    (3022100407, "Brew Bucks"),
    (3022100578, "The Growler Market")
]

df = spark.createDataFrame(data, ["OutletID", "OutletName"])

Once you have the DataFrame, you can assign groups:

Basic group assignment
from heiwhy.data_split import deterministic_balanced_split

# Basic split into control and treatment groups
result_df = deterministic_balanced_split(
    dataframe=df,
    id_column="OutletID",
    number_of_splits=2
)

# View the results
result_df.groupBy("group").count().orderBy("group").show()

When running this code, you will see the following logs indicating that the function is automatically selecting an optimal hash method:

Automatic hash method selection logs
2025-03-03 11:00:22.529 | INFO     | heiwhy.data_split.deterministic_split:deterministic_balanced_split:232 - Splitting data into 2 groups based on OutletID
2025-03-03 11:00:22.530 | WARNING  | heiwhy.data_split.deterministic_split:deterministic_balanced_split:243 -
No hash method specified. The optimal method will be determined automatically.
Note that this may take longer to run, if you would like to reduce runtime, specify a hash method.

2025-03-03 11:00:22.921 | INFO     | heiwhy.data_split.deterministic_split:deterministic_balanced_split:252 - Automatically selected hash method: md5

This will produce output similar to:

+-------+-----+
|  group|count|
+-------+-----+
|group_1|   10|
|group_2|   10|
+-------+-----+

As shown in the logs, when no hash method is specified, the function determines one automatically. This selection step adds runtime. To reduce runtime, pass the hash_method parameter explicitly, as shown in the Advanced Options section.
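The selection criterion heiwhy uses is not documented on this page. The sketch below is therefore a hypothetical illustration of why automatic selection can be slow. It assumes hash-modulo bucketing and an assumed candidate list (`CANDIDATE_METHODS`), and it picks the method whose buckets are most even. Each candidate costs a full pass over the IDs. The helpers `group_sizes` and `pick_hash_method` are illustrative names, not library functions.

```python
import hashlib
from collections import Counter

CANDIDATE_METHODS = ["md5", "sha1", "sha256"]  # assumed candidate list


def group_sizes(ids, number_of_splits, hash_method):
    """Count how many IDs land in each bucket under hash-modulo assignment."""
    counts = Counter()
    for value in ids:
        digest = hashlib.new(hash_method, str(value).encode("utf-8")).hexdigest()
        counts[int(digest, 16) % number_of_splits] += 1
    # Include empty buckets so the spread below is computed correctly
    return [counts[bucket] for bucket in range(number_of_splits)]


def pick_hash_method(ids, number_of_splits, candidates=CANDIDATE_METHODS):
    """Return the candidate whose buckets have the smallest max-min spread.

    Every candidate needs a full pass over the IDs, which is why automatic
    selection costs more than passing a hash method explicitly.
    """
    def spread(method):
        sizes = group_sizes(ids, number_of_splits, method)
        return max(sizes) - min(sizes)

    # min() keeps the first candidate on ties, so the choice is deterministic
    return min(candidates, key=spread)


outlet_ids = [10005718, 3022100630, 3021600840, 10004683, 10005299]
chosen = pick_hash_method(outlet_ids, 2)
print(chosen, group_sizes(outlet_ids, 2, chosen))
```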

Customising Group Names

You can customise the group names using the group_names parameter to make the groups more meaningful:

Custom group names
# Split with custom group names for A/B test
result_df = deterministic_balanced_split(
    dataframe=df,
    id_column="OutletID",
    number_of_splits=2,
    group_names=["control", "treatment"]
)

# View the results
result_df.groupBy("group").count().orderBy("group").show()

This will produce:

+---------+-----+
|    group|count|
+---------+-----+
|  control|   10|
|treatment|   10|
+---------+-----+
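The page does not say how group_names is validated. The following hypothetical helper (`resolve_group_names`, not part of heiwhy) sketches reasonable behaviour under that assumption. It returns the documented defaults ("group_1", "group_2", ...) when no names are given. Otherwise it requires exactly one unique name per split and maps the k-th bucket to the k-th name.

```python
def resolve_group_names(number_of_splits, group_names=None):
    """Hypothetical helper: return one label per split.

    Defaults mirror the documented names ("group_1", "group_2", ...);
    custom names must be unique and match number_of_splits.
    """
    if group_names is None:
        return [f"group_{i + 1}" for i in range(number_of_splits)]
    if len(group_names) != number_of_splits:
        raise ValueError(
            f"Expected {number_of_splits} group names, got {len(group_names)}"
        )
    if len(set(group_names)) != len(group_names):
        raise ValueError("Group names must be unique")
    return list(group_names)


# Bucket index k (0-based) receives the k-th name
labels = resolve_group_names(2, ["control", "treatment"])
print(labels)  # ['control', 'treatment']
```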

Multiple Treatment Groups

You can also split your data into multiple groups for testing different treatments:

Multiple treatment groups
# Split into control and two treatment groups
result_df = deterministic_balanced_split(
    dataframe=df,
    id_column="OutletID",
    number_of_splits=3,
    group_names=["control", "treatment_a", "treatment_b"]
)

# View the results
result_df.groupBy("group").count().orderBy("group").show()

When running this code, you will see logs including the following:

Multiple treatment groups logs
2025-03-03 12:13:27.448 | INFO     | heiwhy.data_split.deterministic_split:deterministic_balanced_split:232 - Splitting data into 3 groups based on OutletID
2025-03-03 12:13:27.448 | INFO     | heiwhy.data_split.deterministic_split:deterministic_balanced_split:262 - Group names: ['control', 'treatment_a', 'treatment_b']

This will produce:

+-----------+-----+
|      group|count|
+-----------+-----+
|    control|    7|
|treatment_a|    7|
|treatment_b|    6|
+-----------+-----+
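The 20 outlets do not divide evenly into three groups, so the 7/7/6 result is the most balanced split possible. Integer division makes the arithmetic explicit: divmod(20, 3) gives a base size of 6 with a remainder of 2, and the two leftover rows go to two of the groups. The helper name `expected_group_sizes` is illustrative. It computes only the sizes, not which group receives the extra rows.

```python
def expected_group_sizes(n_rows, number_of_splits):
    """Sizes of a maximally balanced split: the remainder is spread one row at a time."""
    base, remainder = divmod(n_rows, number_of_splits)
    return [base + 1 if i < remainder else base for i in range(number_of_splits)]


print(expected_group_sizes(20, 3))  # [7, 7, 6]
print(expected_group_sizes(20, 2))  # [10, 10]
```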

Customising Output Column Name

You can also specify a custom name for the group column using the output_column parameter:

Custom column name
# Split with custom group column name
result_df = deterministic_balanced_split(
    dataframe=df,
    id_column="OutletID",
    number_of_splits=2,
    output_column="test_group",
    group_names=["control", "treatment"]
)

# View the results
result_df.groupBy("test_group").count().orderBy("test_group").show()

This will produce:

+----------+-----+
|test_group|count|
+----------+-----+
|   control|   10|
| treatment|   10|
+----------+-----+

Advanced Options

The function also accepts the following optional parameter:

  • hash_method: Specify the hashing algorithm to use (e.g., "md5", "sha256")

For example:

Specific hash method example
result_df = deterministic_balanced_split(
    dataframe=df,
    id_column="OutletID",
    number_of_splits=2,
    group_names=["control", "treatment"],
    hash_method="md5"
)

When specifying the hash method, you will see logs confirming your choice:

Specified hash method logs
2025-03-03 12:13:27.448 | INFO     | heiwhy.data_split.deterministic_split:deterministic_balanced_split:232 - Splitting data into 2 groups based on OutletID
2025-03-03 12:13:27.448 | INFO     | heiwhy.data_split.deterministic_split:deterministic_balanced_split:254 - Using specified hash method: md5
2025-03-03 12:13:27.448 | INFO     | heiwhy.data_split.deterministic_split:deterministic_balanced_split:262 - Group names: ['control', 'treatment']
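Pinning hash_method may also matter for reproducibility, not only runtime. The sketch below assumes a hash-modulo scheme, which is not necessarily heiwhy's. Under that scheme, the same method always puts an ID in the same bucket, while a different method can place it elsewhere. If the method were re-chosen on each run, an outlet could therefore switch groups. The `bucket` helper is hypothetical.

```python
import hashlib


def bucket(value, number_of_splits, hash_method):
    """Hypothetical hash-modulo bucket for one ID under a given hash method."""
    digest = hashlib.new(hash_method, str(value).encode("utf-8")).hexdigest()
    return int(digest, 16) % number_of_splits


ids = [10005718, 3022100630, 3021600840, 10004683, 10005299, 3022201956]
md5_run_1 = [bucket(v, 2, "md5") for v in ids]
md5_run_2 = [bucket(v, 2, "md5") for v in ids]
sha256_run = [bucket(v, 2, "sha256") for v in ids]

# Same method, same input: identical buckets on every run
print(md5_run_1 == md5_run_2)
# A different method may move some IDs to other buckets
print(md5_run_1, sha256_run)
```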

Notes

  • The function ensures that the split is deterministic, meaning the same input will always produce the same output
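  • Groups are as balanced as possible: when the number of rows does not divide evenly by number_of_splits, group sizes differ slightly (for example, 20 outlets split three ways gives 7, 7 and 6)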
  • The function works with any column type, not just numeric values
  • The default group names are "group_1", "group_2", etc.
  • The default group column name is "group"
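Determinism and support for any column type interact. A common approach, assumed here rather than confirmed for heiwhy, is to hash each value's string form with a stable algorithm. Python's built-in hash() is unsuitable for this because it is salted per process for str and bytes (see PYTHONHASHSEED). The hypothetical `stable_hash` helper shows the idea. One side effect of hashing string forms is that the int 10005718 and the string "10005718" hash identically, which is harmless within a single-typed column.

```python
import datetime
import hashlib


def stable_hash(value):
    """Hash any value via its string form; identical in every Python process.

    Python's built-in hash() is salted per process for str and bytes,
    so it cannot give run-to-run determinism.
    """
    return int(hashlib.md5(str(value).encode("utf-8")).hexdigest(), 16)


# Integers, strings, dates and floats all reduce to a stable integer
for value in [10005718, "Beer Place Names", datetime.date(2025, 3, 3), 3.5]:
    print(type(value).__name__, stable_hash(value) % 3)
```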