Skip to content

Deterministic Split

Functions for deterministic data splitting with configurable hash methods.

HashMethod

Bases: Enum

Available hash methods for deterministic splitting.

Methods:

Name Description
DEFAULT : str

PySpark's default hash function

XXHASH64 : str

xxHash algorithm, generally faster than cryptographic hashes

MD5 : str

MD5 cryptographic hash function

SHA2 : str

SHA-256 cryptographic hash function

Source code in heiwhy/data_split/deterministic_split.py
class HashMethod(Enum):
    """Supported hashing strategies for deterministic data splitting.

    Methods
    -------
    DEFAULT : str
        PySpark's default hash function
    XXHASH64 : str
        xxHash algorithm, generally faster than cryptographic hashes
    MD5 : str
        MD5 cryptographic hash function
    SHA2 : str
        SHA-256 cryptographic hash function
    """

    DEFAULT = "default"
    XXHASH64 = "xxhash64"
    MD5 = "md5"
    SHA2 = "sha2"

    @classmethod
    def from_string(cls, method_name: str) -> "HashMethod":
        """Resolve a method name to its HashMethod member.

        Parameters
        ----------
        method_name : str
            Name of the hash method to use. Case-insensitive.

        Returns
        -------
        HashMethod
            The corresponding HashMethod enum value

        Raises
        ------
        ValueError
            If the method name is not recognized
        """
        try:
            # Member names are upper-case, so normalizing the input makes
            # the lookup case-insensitive.
            return cls[method_name.upper()]
        except KeyError as exc:
            supported = ", ".join(member.value for member in cls)
            raise ValueError(
                f"Unknown hash method: {method_name}. Valid methods are: {supported}"
            ) from exc

from_string(method_name) classmethod

Convert string to HashMethod enum.

Parameters:

Name Type Description Default
method_name str

Name of the hash method to use. Case-insensitive.

required

Returns:

Type Description
HashMethod

The corresponding HashMethod enum value

Raises:

Type Description
ValueError

If the method name is not recognized

Source code in heiwhy/data_split/deterministic_split.py
@classmethod
def from_string(cls, method_name: str) -> "HashMethod":
    """Convert string to HashMethod enum.

    Parameters
    ----------
    method_name : str
        Name of the hash method to use. Case-insensitive.

    Returns
    -------
    HashMethod
        The corresponding HashMethod enum value

    Raises
    ------
    ValueError
        If the method name is not recognized
    """
    try:
        return cls[method_name.upper()]
    except KeyError as err:
        valid_methods = [method.value for method in cls]
        raise ValueError(
            f"Unknown hash method: {method_name}. Valid methods are: {', '.join(valid_methods)}"
        ) from err

deterministic_balanced_split(dataframe, id_column, number_of_splits, output_column='group', group_names=None, hash_method=None)

Assign records to groups based on a hash of the ID column for A/B/n testing.

This function performs deterministic group assignment for A/B/n testing by hashing ID values. The assignment process guarantees several key properties:

  • Deterministic: The same ID will always be assigned to the same group
  • Consistent: The assignment remains stable regardless of dataset size
  • Balanced: Groups are as evenly sized as possible given the hash distribution

Parameters:

Name Type Description Default
dataframe DataFrame

Input DataFrame containing the data to be split

required
id_column str

Name of the column containing unique identifiers

required
number_of_splits int

Number of groups to split the data into

required
output_column str

Name of the output column containing group assignments, by default "group"

'group'
group_names list[str] | None

Custom names for the groups. Must match number_of_splits if provided. If None, groups will be named "group_1", "group_2", etc.

None
hash_method str | HashMethod | None

Hash method to use. Can be specified as a string or HashMethod enum. Valid string values are: "default", "xxhash64", "md5", "sha2". If None, will automatically find the most balanced method.

None

Returns:

Type Description
DataFrame

DataFrame with an additional column containing group assignments

Examples:

>>> # Example 1: Automatic hash method selection
>>> df = spark.createDataFrame(
...     data=[(1,), (2,), (3,)],
...     schema=["user_id"]
... )
>>> result = deterministic_balanced_split(
...     dataframe=df,
...     id_column="user_id",
...     number_of_splits=2
... )
>>> # Example 2: Specify hash method as string with custom group names
>>> result = deterministic_balanced_split(
...     dataframe=df,
...     id_column="user_id",
...     number_of_splits=2,
...     group_names=["control", "treatment"],
...     hash_method="xxhash64",
...     output_column="experiment_group"
... )
Source code in heiwhy/data_split/deterministic_split.py
def deterministic_balanced_split(
    dataframe: DataFrame,
    id_column: str,
    number_of_splits: int,
    output_column: str = "group",
    group_names: Optional[list[str]] = None,
    hash_method: Union[str, HashMethod, None] = None,
) -> DataFrame:
    """Assign records to groups based on a hash of the ID column for A/B/n testing.

    This function performs deterministic group assignment for A/B/n testing by hashing
    ID values. The assignment process guarantees several key properties:

    - Deterministic: The same ID will always be assigned to the same group
    - Consistent: The assignment remains stable regardless of dataset size
    - Balanced: Groups are as evenly sized as possible given the hash distribution

    Parameters
    ----------
    dataframe : pyspark.sql.DataFrame
        Input DataFrame containing the data to be split
    id_column : str
        Name of the column containing unique identifiers
    number_of_splits : int
        Number of groups to split the data into. Must be at least 2.
    output_column : str, optional
        Name of the output column containing group assignments, by default "group"
    group_names : list[str] | None, optional
        Custom names for the groups. Must match number_of_splits if provided.
        If None, groups will be named "group_1", "group_2", etc.
    hash_method : str | HashMethod | None, optional
        Hash method to use. Can be specified as a string or HashMethod enum.
        Valid string values are: "default", "xxhash64", "md5", "sha2".
        If None, will automatically find the most balanced method.

    Returns
    -------
    pyspark.sql.DataFrame
        DataFrame with an additional column containing group assignments

    Raises
    ------
    ValueError
        If ``number_of_splits`` is less than 2, or if ``group_names`` is given
        but its length does not equal ``number_of_splits``.

    Examples
    --------
    >>> # Example 1: Automatic hash method selection
    >>> df = spark.createDataFrame(
    ...     data=[(1,), (2,), (3,)],
    ...     schema=["user_id"]
    ... )
    >>> result = deterministic_balanced_split(
    ...     dataframe=df,
    ...     id_column="user_id",
    ...     number_of_splits=2
    ... )

    >>> # Example 2: Specify hash method as string with custom group names
    >>> result = deterministic_balanced_split(
    ...     dataframe=df,
    ...     id_column="user_id",
    ...     number_of_splits=2,
    ...     group_names=["control", "treatment"],
    ...     hash_method="xxhash64",
    ...     output_column="experiment_group"
    ... )
    """

    if number_of_splits < 2:
        raise ValueError(f"Number of splits must be at least 2, got {number_of_splits}")

    logger.info(f"Splitting data into {number_of_splits} groups based on {id_column}")

    # Validate group_names up front so we fail before any expensive hashing work.
    if group_names is not None and len(group_names) != number_of_splits:
        raise ValueError(
            f"Number of group names ({len(group_names)}) must match number_of_splits ({number_of_splits})"
        )

    hash_method = _get_hash_method(hash_method)

    if hash_method is None:
        logger.warning(
            dedent(
                """
                No hash method specified. The optimal method will be determined automatically.
                Note that this may take longer to run, if you would like to reduce runtime, specify a hash method.
                """
            )
        )
        hash_method = _find_best_hash_method(dataframe, id_column, number_of_splits)
        logger.info(f"Automatically selected hash method: {hash_method.value}")
    else:
        logger.info(f"Using specified hash method: {hash_method.value}")

    hash_func = _get_hash_function(hash_method, id_column)
    # Map the (possibly negative) hash into the 1..number_of_splits range.
    split_number = (F.abs(hash_func) % F.lit(number_of_splits)) + 1

    if group_names is not None:
        # Build a literal map {1: name_1, 2: name_2, ...}. Keys are kept as
        # integers and the lookup uses the integer split number directly:
        # the previous cast of the key to string relied on Spark's implicit
        # string->int coercion and breaks under ANSI mode.
        map_entries = [
            literal
            for index, name in enumerate(group_names, start=1)
            for literal in (F.lit(index), F.lit(name))
        ]
        mapping_expr = F.create_map(*map_entries)
        dataframe_grouped = dataframe.withColumn(output_column, mapping_expr[split_number])
        logger.info(f"Group names: {group_names}")
    else:
        dataframe_grouped = dataframe.withColumn(output_column, F.concat(F.lit("group_"), split_number.cast("string")))

    return dataframe_grouped