# Source code for showerdata.cluster_module

"""Clustering hits into readout cells using a regular grid."""

import argparse
import multiprocessing
import time

import numpy as np
import numpy.typing as npt

from . import core as showerdata
from . import detector
from .core import Showers

__all__ = ["cluster"]


def _release_numbers(version: str) -> tuple[int, ...]:
    """Extract the leading numeric release components of a version string.

    Parsing stops at the first component that is not purely numeric, so
    pre-release, development and local suffixes can never be mistaken for
    release numbers: ``"2.4rc1"`` yields ``(2, 4)`` and ``"1.0+local.9"``
    yields ``(1, 0)``.
    """
    numbers: list[int] = []
    for component in version.strip().lstrip("v").split("."):
        digits = ""
        for char in component:
            if not char.isdecimal():
                break
            digits += char
        if not digits:
            # e.g. "dev0" or an empty component: the release segment is over.
            break
        numbers.append(int(digits))
        if len(digits) < len(component):
            # A suffix such as "rc1" or "+local" ends the release segment.
            break
    return tuple(numbers)


def _version_ge(v1: str, v2: str) -> bool:
    """Check if version string v1 is greater than or equal to v2.

    Only the numeric release segment is compared (an optional leading ``v``
    is ignored). Pre-release, development and local suffixes are disregarded,
    so ``"2.3.0rc1"`` compares equal to ``"2.3"``. The shorter release tuple
    is padded with zeros before comparison.

    Args:
        v1: Version string to test.
        v2: Version string to compare against.

    Returns:
        True if the release segment of ``v1`` is greater than or equal to
        the release segment of ``v2``.
    """
    parts1 = _release_numbers(v1)
    parts2 = _release_numbers(v2)

    # Pad the shorter version with zeros
    length = max(len(parts1), len(parts2))
    parts1 += (0,) * (length - len(parts1))
    parts2 += (0,) * (length - len(parts2))

    return parts1 >= parts2


def _cluster_shower_part(
    pos: npt.NDArray[np.float32],
    e: npt.NDArray[np.float32],
    cell_size: float,
    shift: npt.NDArray[np.float32],
    t: npt.NDArray[np.float32] | None = None,
) -> tuple[
    npt.NDArray[np.float32], npt.NDArray[np.float32], npt.NDArray[np.float32] | None
]:
    """Cluster a part of a shower (ECAL or HCAL)."""
    if len(pos) == 0:
        return (
            np.empty((0, 3), dtype=np.float32),
            np.empty((0,), dtype=np.float32),
            None if t is None else np.empty((0,), dtype=np.float32),
        )
    pos = pos.copy()
    pos[:, :2] += shift
    pos[:, :2] /= cell_size
    pos_idx = np.empty_like(pos, dtype=np.int32)
    pos_idx = np.floor(pos, out=pos_idx, casting="unsafe")
    pos_idx = pos_idx[:, [2, 0, 1]]  # z, x, y

    # Prior to numpy 2.3 np.unique did not have a 'sorted' argument
    if _version_ge(np.__version__, "2.3"):
        unique_idx, inverse_idx = np.unique(
            pos_idx, axis=0, return_inverse=True, sorted=True
        )
    else:
        unique_idx, inverse_idx = np.unique(pos_idx, axis=0, return_inverse=True)
    unique_idx = unique_idx[:, [1, 2, 0]]  # x, y, z
    e_clustered = np.zeros((len(unique_idx),), dtype=np.float32)
    np.add.at(e_clustered, inverse_idx, e)
    if t is not None:
        t_clustered = np.full((len(unique_idx),), np.inf, dtype=np.float32)
        np.minimum.at(t_clustered, inverse_idx, t)
    else:
        t_clustered = None
    pos_clustered = unique_idx.astype(np.float32)
    pos_clustered[:, :2] += 1 / 2
    pos_clustered[:, :2] *= cell_size
    pos_clustered[:, :2] -= shift

    return pos_clustered, e_clustered, t_clustered


def _calc_random_shift(cell_size: float, random_shift: bool) -> npt.NDArray[np.float32]:
    """Return the (x, y) grid offset to use for one clustering pass.

    With ``random_shift`` disabled the offset is zero. Otherwise both
    components are drawn from NumPy's global random state and scaled so
    that they lie within half a cell on either side of zero.
    """
    if not random_shift:
        return np.zeros(2, dtype=np.float32)

    unit_draw = np.random.rand(2).astype(np.float32)
    half_cell = cell_size / 2
    return (unit_draw * cell_size) - half_cell


def _process_shower(
    shower: Showers,
    random_shift: bool,
    detector_config: detector.DetectorGeometry,
) -> Showers:
    """Cluster the hits of one shower into ECAL and HCAL readout cells.

    Only the first entry along the leading axis of ``shower.points`` is
    read. Hits with non-positive energy are discarded, the remaining hits
    are split into an ECAL and an HCAL part by their layer index (third
    coordinate) and each part is clustered on its own grid. The result has
    the same number of point slots as the input; unused slots stay zero.

    Args:
        shower: Shower whose points are clustered.
        random_shift: Whether each grid gets its own random offset.
        detector_config: Provides ``num_layers_ecal``, ``ecal_cell_size``
            and ``hcal_cell_size``.

    Returns:
        A new ``Showers`` object with clustered points; energies, pdg,
        directions and shower ids are taken over from ``shower``.

    Raises:
        RuntimeError: If times were supplied but a clustered part came back
            without them.
    """
    has_time = shower.points.shape[2] > 4
    energy = shower.points[0, :, 3]
    active = energy > 0
    xyz = shower.points[0, :, :3][active]
    energy = energy[active]
    times = shower.points[0, :, 4][active] if has_time else None

    # Layers below the ECAL layer count belong to the ECAL, the rest to the HCAL.
    layer_boundary = detector_config.num_layers_ecal - 0.5
    in_ecal = xyz[:, 2] < layer_boundary
    in_hcal = xyz[:, 2] >= layer_boundary

    # ECAL first, then HCAL: this fixes the order of the random draws and of
    # the cells in the output.
    clustered_parts = []
    for selection, size in (
        (in_ecal, detector_config.ecal_cell_size),
        (in_hcal, detector_config.hcal_cell_size),
    ):
        clustered_parts.append(
            _cluster_shower_part(
                pos=xyz[selection],
                e=energy[selection],
                cell_size=size,
                shift=_calc_random_shift(size, random_shift),
                t=None if times is None else times[selection],
            )
        )

    xyz_out = np.concatenate([part[0] for part in clustered_parts], axis=0)
    energy_out = np.concatenate([part[1] for part in clustered_parts], axis=0)

    # Keep the input's number of point slots so the output stays zero padded.
    num_features = 5 if has_time else 4
    merged = np.zeros((1, len(active), num_features), dtype=np.float32)
    merged[0, : len(xyz_out), :3] = xyz_out
    merged[0, : len(energy_out), 3] = energy_out
    if has_time:
        part_times = [part[2] for part in clustered_parts]
        if any(entry is None for entry in part_times):
            raise RuntimeError("Time information missing in one of the shower parts.")
        time_out = np.concatenate(part_times, axis=0)
        merged[0, : len(time_out), 4] = time_out

    return Showers(
        points=merged,
        energies=shower.energies,
        pdg=shower.pdg,
        directions=shower.directions,
        shower_ids=shower.shower_ids,
    )


def _reseed_worker_rng() -> None:
    """Give a freshly started pool worker its own NumPy random state."""
    # Without an argument the legacy global generator is reseeded from fresh
    # OS entropy.
    np.random.seed()


def cluster(
    showers: Showers,
    random_shift: bool = True,
    detector_config: detector.DetectorGeometry = detector.get_ILD_geometry(),
    processes: int = 1,
) -> Showers:
    """Cluster hits into readout cells using a regular grid.

    Each shower obtained by iterating over ``showers`` is clustered
    independently and the results are concatenated again.

    Args:
        showers: Showers to cluster.
        random_shift: Whether to apply a random shift to the grid (default: True).
        detector_config: Simplified detector description (default: ILD).
        processes: Number of parallel processes to use (default: 1, i.e. no
            parallelism). Values below 2 run sequentially in this process.

    Returns:
        Clustered showers.
    """
    if processes > 1:
        # With the "fork" start method every worker inherits an identical
        # copy of NumPy's global random state and would therefore draw the
        # same sequence of grid shifts. Reseeding each worker on start-up
        # keeps the shifts independent between workers and between pools.
        with multiprocessing.Pool(
            processes, initializer=_reseed_worker_rng
        ) as pool:
            processed_showers = pool.starmap(
                _process_shower,
                [(shower, random_shift, detector_config) for shower in showers],
            )
    else:
        processed_showers = [
            _process_shower(shower, random_shift, detector_config)
            for shower in showers
        ]
    return showerdata.concatenate(processed_showers)
def add_parser_arguments(parser: "argparse.ArgumentParser") -> None:
    """Add arguments for the clustering module to an argparse parser.

    Registers the positional ``input`` and ``output`` paths and the options
    ``--no-random-shift``, ``--batch-size`` and ``--processes``.
    """
    parser.add_argument(
        "input", type=str, help="Input file containing showers to be clustered."
    )
    parser.add_argument(
        "output", type=str, help="Output file to save the clustered showers."
    )
    parser.add_argument(
        "--no-random-shift",
        action="store_true",
        help="Disable random shift of the grid.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=10000,
        help="Number of showers to process in each batch.",
    )
    parser.add_argument(
        "--processes",
        type=int,
        default=1,
        help="Number of parallel processes to use (default: 1, i.e. no parallelism).",
    )


def initialize_parser() -> argparse.ArgumentParser:
    """Initialize an argparse parser for the clustering module."""
    parser = argparse.ArgumentParser(
        description="Cluster hits into readout cells using a regular grid."
    )
    add_parser_arguments(parser)
    return parser


def main(args: argparse.Namespace) -> None:
    """Main function to run clustering from command line arguments.

    Reads the input file in batches of ``args.batch_size`` showers, clusters
    each batch with the ILD geometry and writes it to the same index range
    of the output file, printing the elapsed time after every batch.

    Args:
        args: Namespace providing ``input``, ``output``, ``no_random_shift``,
            ``batch_size`` and ``processes``.

    Raises:
        ValueError: If ``args.batch_size`` is not a positive integer.
    """
    # Reject an unusable batch size before any file is opened: a step of 0
    # makes range() raise and a negative step would silently process nothing.
    if args.batch_size < 1:
        raise ValueError(
            f"--batch-size must be a positive integer, got {args.batch_size}."
        )

    # perf_counter is monotonic, so the reported elapsed time cannot jump
    # when the system clock is adjusted.
    start_time = time.perf_counter()
    shape = showerdata.get_file_shape(args.input)
    with (
        showerdata.ShowerDataFile(args.input, "r") as in_file,
        showerdata.ShowerDataFile(args.output, "w", shape=shape) as out_file,
    ):
        out_file.attrs["clustered_for"] = "ILD"
        for start in range(0, shape[0], args.batch_size):
            end = min(start + args.batch_size, shape[0])
            showers = in_file[start:end]
            clustered_showers = cluster(
                showers,
                random_shift=not args.no_random_shift,
                detector_config=detector.get_ILD_geometry(),
                processes=args.processes,
            )
            out_file[start:end] = clustered_showers
            print(
                f"[{time.perf_counter() - start_time:8.2f}s] Processed showers {start} to {end}."
            )


if __name__ == "__main__":
    main(initialize_parser().parse_args())