Skip to content

FLOP Cost Query API

Pre-execution cost estimation functions. These are pure functions that compute FLOP costs from shapes without executing anything or consuming budget.

Quick example

# Quick tour of the mechestim.flops cost queries. None of these calls
# executes an operation or consumes budget; they only compute costs.
import mechestim as me

# Einsum cost
# A 256x256 @ 256x256 matmul: 256**3 multiply-adds, each counted as 2 FLOPs.
cost = me.flops.einsum_cost('ij,jk->ik', shapes=[(256, 256), (256, 256)])
print(f"Matmul: {cost:,} FLOPs")  # 33,554,432

# SVD cost
# Modeled as m * n * k = 256 * 256 * 10.
cost = me.flops.svd_cost(m=256, n=256, k=10)
print(f"SVD: {cost:,} FLOPs")  # 655,360

# Pointwise (unary/binary) cost
# One FLOP per element: 256 * 256.
cost = me.flops.pointwise_cost(shape=(256, 256))
print(f"Pointwise: {cost:,} FLOPs")  # 65,536

# Reduction cost
# One FLOP per input element (1000 * 100), regardless of the reduced axis.
cost = me.flops.reduction_cost(input_shape=(1000, 100))
print(f"Reduction: {cost:,} FLOPs")  # 100,000

API Reference

mechestim._flops

FLOP cost calculators for mechestim operations.

einsum_cost(subscripts, shapes, operand_symmetries=None)

FLOP cost of an einsum operation.

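Delegates to contract_path from mechestim's internal opt_einsum module (mechestim._opt_einsum), which uses flop_count with op_factor (multiply-add = 2 FLOPs for inner products). The returned value is the optimized_cost of the chosen contraction path.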

Parameters:

Name Type Description Default
subscripts str

Einsum subscript string.

required
shapes list of tuple of int

Shapes of the input operands.

required
operand_symmetries list of SymmetryInfo or None

Symmetry information for each input operand.

None

Returns:

Type Description
int

Estimated FLOP count.

Source code in src/mechestim/_flops.py
def einsum_cost(
    subscripts: str,
    shapes: list[tuple[int, ...]],
    operand_symmetries: "list[SymmetryInfo | None] | None" = None,
) -> int:
    """FLOP cost of an einsum operation.

    The cost comes from ``contract_path`` in ``mechestim._opt_einsum`` (an
    opt_einsum module), whose ``flop_count`` applies ``op_factor`` so that a
    multiply-add in an inner product counts as 2 FLOPs.

    Parameters
    ----------
    subscripts : str
        Einsum subscript string.
    shapes : list of tuple of int
        Shapes of the input operands.
    operand_symmetries : list of SymmetryInfo or None, optional
        Symmetry information for each input operand.

    Returns
    -------
    int
        Estimated FLOP count (the ``optimized_cost`` of the contraction path).
    """
    from mechestim._opt_einsum import contract_path

    index_syms = None
    if operand_symmetries and any(sym is not None for sym in operand_symmetries):
        # contract_path takes symmetry as groups of index *labels*, whereas
        # SymmetryInfo.symmetric_axes lists axis *positions*; translate.
        lhs = subscripts.replace(" ", "").split("->")[0]
        index_syms = []
        # NOTE(review): zip stops at the shorter sequence, so a length
        # mismatch between operand_symmetries and the operands is silent.
        for sym, labels in zip(operand_symmetries, lhs.split(",")):
            if sym is None:
                index_syms.append(None)
                continue
            label_groups = []
            for axes in sym.symmetric_axes:
                # A group with fewer than two axes carries no symmetry.
                if len(axes) < 2:
                    continue
                label_groups.append(frozenset(labels[ax] for ax in axes))
            index_syms.append(label_groups or None)

    _, info = contract_path(
        subscripts, *shapes, shapes=True, input_symmetries=index_syms
    )
    return info.optimized_cost

parse_einsum_subscripts(subscripts)

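Parse an einsum subscript string into input and output index lists. Spaces are ignored. If the string has no '->', the output is every label that appears exactly once across all inputs, in sorted order (NumPy's implicit mode).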

Parameters:

Name Type Description Default
subscripts str

Einsum subscript string (e.g., 'ij,jk->ik').

required

Returns:

Name Type Description
inputs list of list of str

Index labels for each input operand.

output list of str

Index labels for the output.

Source code in src/mechestim/_flops.py
def parse_einsum_subscripts(subscripts: str) -> tuple[list[list[str]], list[str]]:
    """Split an einsum subscript string into per-operand and output labels.

    Spaces are ignored. With an explicit ``->``, the output labels are the
    characters after the arrow. Without one (NumPy's implicit mode), the
    output is every label occurring exactly once across all inputs, sorted.

    Parameters
    ----------
    subscripts : str
        Einsum subscript string (e.g., ``'ij,jk->ik'``).

    Returns
    -------
    inputs : list of list of str
        Index labels for each input operand.
    output : list of str
        Index labels for the output.
    """
    spec = subscripts.replace(" ", "")
    if "->" in spec:
        # Unpacking into exactly two names raises ValueError on extra arrows.
        lhs, rhs = spec.split("->")
        return [list(term) for term in lhs.split(",")], list(rhs)

    terms = spec.split(",")
    occurrences = Counter(label for term in terms for label in term)
    singletons = sorted(label for label, seen in occurrences.items() if seen == 1)
    return [list(term) for term in terms], singletons

pointwise_cost(shape, symmetry_info=None)

FLOP cost of a pointwise (element-wise) operation.

Parameters:

Name Type Description Default
shape tuple of int

Shape of the array.

required
symmetry_info SymmetryInfo or None

If provided, only unique elements are counted.

None

Returns:

Type Description
int

Estimated FLOP count (one per element, or one per unique element when symmetry_info is given), never less than 1.

Source code in src/mechestim/_flops.py
def pointwise_cost(
    shape: tuple[int, ...], symmetry_info: "SymmetryInfo | None" = None
) -> int:
    """FLOP cost of a pointwise (element-wise) operation.

    Parameters
    ----------
    shape : tuple of int
        Shape of the array.
    symmetry_info : SymmetryInfo or None, optional
        If provided, only unique elements are counted and ``shape`` is not
        consulted.

    Returns
    -------
    int
        Estimated FLOP count (one per element, or one per unique element),
        never less than 1.
    """
    from math import prod

    if symmetry_info is not None:
        return max(symmetry_info.unique_elements, 1)
    # Every operation is charged at least one FLOP, even on an empty array
    # (a zero-length axis); a scalar shape ``()`` has prod == 1.
    return max(prod(shape), 1)

reduction_cost(input_shape, axis=None, symmetry_info=None)

FLOP cost of a reduction operation.

Parameters:

Name Type Description Default
input_shape tuple of int

Shape of the input array.

required
axis int or None

Axis along which to reduce. If None, reduce over all elements.

None
symmetry_info SymmetryInfo or None

If provided, only unique elements are counted.

None

Returns:

Type Description
int

Estimated FLOP count (one per element, or one per unique element when symmetry_info is given), never less than 1.

Notes

The axis parameter is accepted for API consistency but does not affect the result: a reduction always touches every element regardless of which axis is reduced, so the cost is prod(input_shape) (floored at 1), or the number of unique elements when symmetry_info is provided.

Source code in src/mechestim/_flops.py
def reduction_cost(
    input_shape: tuple[int, ...],
    axis: int | None = None,
    symmetry_info: "SymmetryInfo | None" = None,
) -> int:
    """FLOP cost of a reduction operation.

    Parameters
    ----------
    input_shape : tuple of int
        Shape of the input array.
    axis : int or None, optional
        Axis along which to reduce. If None, reduce over all elements.
    symmetry_info : SymmetryInfo or None, optional
        If provided, only unique elements are counted.

    Returns
    -------
    int
        Estimated FLOP count (one per element).

    Notes
    -----
    The ``axis`` parameter is accepted for API consistency but does not
    affect the result: a reduction always touches every element regardless
    of which axis is reduced, so the cost is always ``prod(input_shape)``.
    """
    if symmetry_info is not None:
        return max(symmetry_info.unique_elements, 1)
    result = 1
    for dim in input_shape:
        result *= dim
    return max(result, 1)

search_cost(queries, sorted_size)

FLOP cost of binary search.

Parameters:

Name Type Description Default
queries int

Number of search queries.

required
sorted_size int

Size of the sorted array being searched.

required

Returns:

Type Description
int

Estimated FLOP count: queries * ceil(log2(sorted_size)), never less than 1; a non-positive queries count is charged 1.

Source code in src/mechestim/_flops.py
def search_cost(queries: int, sorted_size: int) -> int:
    """FLOP cost of binary search.

    Each query is charged ``ceil(log2(sorted_size))`` comparisons.

    Parameters
    ----------
    queries : int
        Number of search queries.
    sorted_size : int
        Size of the sorted array being searched.

    Returns
    -------
    int
        Estimated FLOP count: ``queries * ceil(log2(sorted_size))``, never
        less than 1. A non-positive ``queries`` is charged exactly 1.
    """
    if queries > 0:
        per_query = _ceil_log2(sorted_size)
        return max(queries * per_query, 1)
    return 1

sort_cost(n)

FLOP cost of comparison-based sort.

Parameters:

Name Type Description Default
n int

Number of elements to sort.

required

Returns:

Type Description
int

Estimated FLOP count: n * ceil(log2(n)), never less than 1; a non-positive n is charged 1.

Source code in src/mechestim/_flops.py
def sort_cost(n: int) -> int:
    """FLOP cost of comparison-based sort.

    Charges ``ceil(log2(n))`` comparisons per element.

    Parameters
    ----------
    n : int
        Number of elements to sort.

    Returns
    -------
    int
        Estimated FLOP count: ``n * ceil(log2(n))``, never less than 1.
        A non-positive ``n`` is charged exactly 1.
    """
    # _ceil_log2 is only consulted for a positive element count.
    return max(n * _ceil_log2(n), 1) if n > 0 else 1

svd_cost(m, n, k=None)

FLOP cost of a (truncated) SVD.

Parameters:

Name Type Description Default
m int

Number of rows.

required
n int

Number of columns.

required
k int or None

Number of singular values/vectors to compute. Defaults to min(m, n).

None

Returns:

Type Description
int

Estimated FLOP count: m * n * k.

Notes

Based on Golub-Reinsch bidiagonalization.

Source code in src/mechestim/_flops.py
def svd_cost(m: int, n: int, k: int | None = None) -> int:
    """FLOP cost of a (truncated) SVD.

    Parameters
    ----------
    m : int
        Number of rows.
    n : int
        Number of columns.
    k : int or None, optional
        Number of singular values/vectors to compute. Defaults to min(m, n).

    Returns
    -------
    int
        Estimated FLOP count: m * n * k.

    Notes
    -----
    Based on Golub-Reinsch bidiagonalization.
    """
    if k is None:
        k = min(m, n)
    return m * n * k