Skip to content

Budget

Budget management for FLOP counting. BudgetContext is the core context manager that tracks operation costs and enforces limits.

Quick example

import mechestim as me

# A budget context with an explicit namespace: operations run inside the
# `with` body are charged against it.
with me.BudgetContext(flop_budget=10**7, namespace="forward") as ctx:
    weight = me.ones((256, 256))
    inputs = me.ones((256,))
    hidden = me.einsum('ij,j->i', weight, inputs)

    print(f"Used: {ctx.flops_used:,} / {ctx.flop_budget:,}")

# Print the session-wide summary covering every namespace.
me.budget_summary()

# The same information as a plain dict, for programmatic use.
summary_data = me.budget_summary_dict()

API Reference

mechestim._budget.BudgetContext

Context manager for FLOP budget enforcement.

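Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `flop_budget` | `int` | Maximum number of FLOPs allowed. Must be > 0. | *required* |
| `flop_multiplier` | `float` | Multiplier applied to all FLOP costs. Default 1. | `1.0` |
| `quiet` | `bool` | If `True`, suppress the version/budget banner printed to stderr when the context is entered. | `False` |
| `namespace` | `str \| None` | Label attached to every recorded operation and shown in the summary header. | `None` |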
Source code in src/mechestim/_budget.py
class BudgetContext:
    """Context manager for FLOP budget enforcement.

    While active, counted operations are charged against this context via
    :meth:`deduct`; an operation whose cost would exceed the remaining budget
    raises ``BudgetExhaustedError``. Instances can also decorate functions so
    that every call runs inside the context.

    Parameters
    ----------
    flop_budget : int
        Maximum number of FLOPs allowed. Must be > 0.
    flop_multiplier : float, optional
        Multiplier applied to all FLOP costs. Default 1.
    quiet : bool, optional
        If ``True``, suppress the version/budget banner written to stderr
        when the context is entered. Default ``False``.
    namespace : str or None, optional
        Label attached to every recorded operation and shown in the summary
        header. Default ``None``.

    Raises
    ------
    ValueError
        If ``flop_budget`` is not > 0.
    """

    def __init__(
        self,
        flop_budget: int,
        flop_multiplier: float = 1.0,
        quiet: bool = False,
        namespace: str | None = None,
    ):
        if flop_budget <= 0:
            raise ValueError(f"flop_budget must be > 0, got {flop_budget}")
        self._flop_budget = flop_budget
        self._flop_multiplier = flop_multiplier
        self._flops_used = 0
        self._op_log: list[OpRecord] = []
        self._quiet = quiet
        self._namespace = namespace
        # Budget that was active before __enter__; restored by __exit__.
        self._previous_budget: BudgetContext | None = None

    @property
    def flop_budget(self) -> int:
        """Total FLOP budget this context was created with."""
        return self._flop_budget

    @property
    def flops_used(self) -> int:
        """FLOPs charged so far (after applying the multiplier)."""
        return self._flops_used

    @property
    def flops_remaining(self) -> int:
        """FLOPs still available: ``flop_budget - flops_used``."""
        return self._flop_budget - self._flops_used

    @property
    def flop_multiplier(self) -> float:
        """Multiplier applied to every cost passed to :meth:`deduct`."""
        return self._flop_multiplier

    @property
    def op_log(self) -> list[OpRecord]:
        """Records of every successfully charged operation, in call order."""
        return self._op_log

    @property
    def namespace(self) -> str | None:
        """Optional label attached to this context's records and summary."""
        return self._namespace

    def deduct(
        self, op_name: str, *, flop_cost: int, subscripts: str | None, shapes: tuple
    ) -> None:
        """Deduct FLOPs from the budget.

        The cost is scaled by ``flop_multiplier`` and truncated to an int;
        the scaled cost is what is charged and recorded.

        Parameters
        ----------
        op_name : str
            Name the operation is charged under.
        flop_cost : int
            Unscaled cost of the operation.
        subscripts : str or None
            Subscripts string, stored unchanged in the record.
        shapes : tuple
            Shape information, stored unchanged in the record.

        Raises
        ------
        BudgetExhaustedError
            If the scaled cost exceeds ``flops_remaining``. Nothing is
            charged or recorded in that case.
        """
        multiplier = self._flop_multiplier
        if isinstance(flop_cost, int) and (
            isinstance(multiplier, int)
            or (isinstance(multiplier, float) and multiplier.is_integer())
        ):
            # Integral multiplier (including the default 1.0): stay in exact
            # integer arithmetic. Routing through float would silently round
            # costs larger than 2**53.
            adjusted_cost = flop_cost * int(multiplier)
        else:
            adjusted_cost = int(flop_cost * multiplier)
        if adjusted_cost > self.flops_remaining:
            raise BudgetExhaustedError(
                op_name, flop_cost=adjusted_cost, flops_remaining=self.flops_remaining
            )
        self._flops_used += adjusted_cost
        self._op_log.append(
            OpRecord(
                op_name=op_name,
                subscripts=subscripts,
                shapes=shapes,
                flop_cost=adjusted_cost,
                cumulative=self._flops_used,
                namespace=self._namespace,
            )
        )

    def summary(self) -> str:
        """Return a pretty-printed FLOP budget summary.

        Shows total/used/remaining FLOPs, followed by a per-operation
        breakdown sorted by descending cost, with call counts.
        """
        header = "mechestim FLOP Budget Summary"
        if self._namespace:
            header += f" [{self._namespace}]"
        lines = [
            header,
            "=" * len(header),
            f"  Total budget:  {self._flop_budget:>14,}",
            f"  Used:          {self._flops_used:>14,}  ({100 * self._flops_used / self._flop_budget:.1f}%)",
            f"  Remaining:     {self.flops_remaining:>14,}  ({100 * self.flops_remaining / self._flop_budget:.1f}%)",
            "",
            "  By operation:",
        ]
        from collections import Counter

        cost_by_op: dict[str, int] = {}
        count_by_op: Counter[str] = Counter()
        for rec in self._op_log:
            cost_by_op[rec.op_name] = cost_by_op.get(rec.op_name, 0) + rec.flop_cost
            count_by_op[rec.op_name] += 1
        for op_name, cost in sorted(cost_by_op.items(), key=lambda x: -x[1]):
            # Guard against division by zero when nothing has been charged.
            pct = 100 * cost / self._flops_used if self._flops_used > 0 else 0
            lines.append(
                f"    {op_name:<16} {cost:>12,}  ({pct:5.1f}%)  [{count_by_op[op_name]} call{'s' if count_by_op[op_name] != 1 else ''}]"
            )
        return "\n".join(lines)

    def __enter__(self) -> BudgetContext:
        """Make this context the active budget and optionally print a banner.

        Raises
        ------
        RuntimeError
            If another BudgetContext (other than the global default) is
            already active.
        """
        current = get_active_budget()
        if current is not None and current is not _global_default:
            raise RuntimeError("Cannot nest BudgetContexts")
        self._previous_budget = current  # save (may be global default or None)
        _thread_local.active_budget = self
        if not self._quiet:
            import sys

            import mechestim

            print(
                f"mechestim {mechestim.__version__} "
                f"(numpy {mechestim.__numpy_version__} backend) | "
                f"budget: {self._flop_budget:.2e} FLOPs",
                file=sys.stderr,
            )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Record this context in the session accumulator and restore the
        previously active budget. Exceptions are never suppressed."""
        try:
            _accumulator.record(self)
        finally:
            # Restore unconditionally so a failure while recording cannot
            # leave this context stuck as the active budget.
            _thread_local.active_budget = self._previous_budget
        return None

    def __call__(self, func):
        """Use BudgetContext as a decorator: each call to ``func`` runs
        inside ``with self:``."""

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return wrapper

__call__(func)

Use BudgetContext as a decorator.

Source code in src/mechestim/_budget.py
def __call__(self, func):
    """Wrap *func* so that every invocation executes inside this context.

    Parameters
    ----------
    func : callable
        The function being decorated.

    Returns
    -------
    callable
        A wrapper carrying ``func``'s metadata (via ``functools.wraps``) that
        enters ``self`` before calling ``func`` and exits it afterwards.
    """

    def _run_within_budget(*args, **kwargs):
        with self:
            return func(*args, **kwargs)

    return functools.wraps(func)(_run_within_budget)

deduct(op_name, *, flop_cost, subscripts, shapes)

Deduct FLOPs from the budget.

Source code in src/mechestim/_budget.py
def deduct(
    self, op_name: str, *, flop_cost: int, subscripts: str | None, shapes: tuple
) -> None:
    """Deduct FLOPs from the budget.

    The cost is scaled by ``flop_multiplier`` and truncated to an int; the
    scaled cost is what is charged and recorded.

    Parameters
    ----------
    op_name : str
        Name the operation is charged under.
    flop_cost : int
        Unscaled cost of the operation.
    subscripts : str or None
        Subscripts string, stored unchanged in the record.
    shapes : tuple
        Shape information, stored unchanged in the record.

    Raises
    ------
    BudgetExhaustedError
        If the scaled cost exceeds ``flops_remaining``. Nothing is charged
        or recorded in that case.
    """
    multiplier = self._flop_multiplier
    if isinstance(flop_cost, int) and (
        isinstance(multiplier, int)
        or (isinstance(multiplier, float) and multiplier.is_integer())
    ):
        # Integral multiplier (including the default 1.0): stay in exact
        # integer arithmetic. Routing through float would silently round
        # costs larger than 2**53.
        adjusted_cost = flop_cost * int(multiplier)
    else:
        adjusted_cost = int(flop_cost * multiplier)
    if adjusted_cost > self.flops_remaining:
        raise BudgetExhaustedError(
            op_name, flop_cost=adjusted_cost, flops_remaining=self.flops_remaining
        )
    self._flops_used += adjusted_cost
    self._op_log.append(
        OpRecord(
            op_name=op_name,
            subscripts=subscripts,
            shapes=shapes,
            flop_cost=adjusted_cost,
            cumulative=self._flops_used,
            namespace=self._namespace,
        )
    )

summary()

Return a pretty-printed FLOP budget summary.

Source code in src/mechestim/_budget.py
def summary(self) -> str:
    """Build a human-readable report of budget usage.

    The report lists the total budget, plus FLOPs used and remaining (each
    with its share of the total). It then gives a per-operation breakdown
    sorted by descending cost, with call counts.

    Returns
    -------
    str
        The multi-line summary text.
    """
    from collections import Counter

    total = self._flop_budget
    used = self._flops_used
    left = self.flops_remaining

    title = "mechestim FLOP Budget Summary"
    if self._namespace:
        title += f" [{self._namespace}]"

    out = [title, "=" * len(title)]
    out.append(f"  Total budget:  {total:>14,}")
    out.append(f"  Used:          {used:>14,}  ({100 * used / total:.1f}%)")
    out.append(f"  Remaining:     {left:>14,}  ({100 * left / total:.1f}%)")
    out.append("")
    out.append("  By operation:")

    totals: dict[str, int] = {}
    calls: Counter[str] = Counter()
    for record in self._op_log:
        name = record.op_name
        totals[name] = totals.get(name, 0) + record.flop_cost
        calls[name] += 1

    ranked = sorted(totals.items(), key=lambda item: -item[1])
    for name, cost in ranked:
        # Avoid dividing by zero when nothing has been charged yet.
        share = 100 * cost / used if used > 0 else 0
        n_calls = calls[name]
        plural = "s" if n_calls != 1 else ""
        out.append(
            f"    {name:<16} {cost:>12,}  ({share:5.1f}%)  [{n_calls} call{plural}]"
        )
    return "\n".join(out)

mechestim._budget.OpRecord

Bases: NamedTuple

Record of a single counted operation.

Source code in src/mechestim/_budget.py
class OpRecord(NamedTuple):
    """Record of a single counted operation."""

    op_name: str
    subscripts: str | None
    shapes: tuple
    flop_cost: int
    cumulative: int
    namespace: str | None = None

mechestim._budget.budget(flop_budget, flop_multiplier=1.0, quiet=False, namespace=None)

Create a BudgetContext usable as both a context manager and decorator.

Source code in src/mechestim/_budget.py
def budget(
    flop_budget: int,
    flop_multiplier: float = 1.0,
    quiet: bool = False,
    namespace: str | None = None,
) -> BudgetContext:
    """Build a :class:`BudgetContext` from the given settings.

    The returned object works both in a ``with`` statement and as a function
    decorator.

    Parameters
    ----------
    flop_budget : int
        Maximum number of FLOPs allowed.
    flop_multiplier : float, optional
        Multiplier applied to every FLOP cost. Default ``1.0``.
    quiet : bool, optional
        Forwarded to ``BudgetContext``. Default ``False``.
    namespace : str or None, optional
        Forwarded to ``BudgetContext``. Default ``None``.

    Returns
    -------
    BudgetContext
        A new, not-yet-entered context.
    """
    options = dict(flop_multiplier=flop_multiplier, quiet=quiet, namespace=namespace)
    return BudgetContext(flop_budget, **options)

mechestim._budget.budget_summary_dict(by_namespace=False)

Return aggregated budget data across all recorded contexts.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `by_namespace` | `bool` | If `True`, include a `"by_namespace"` key with per-namespace breakdowns. Default `False`. | `False` |

Returns:

| Type | Description |
|------|-------------|
| `dict` | Dictionary with keys `"flop_budget"`, `"flops_used"`, `"flops_remaining"`, `"operations"`, and optionally `"by_namespace"`. |

Source code in src/mechestim/_budget.py
def budget_summary_dict(by_namespace: bool = False) -> dict:
    """Collect the session's budget data as a dictionary.

    If the implicit global-default budget has charged any FLOPs, it is folded
    into a temporary copy of the accumulator. The shared accumulator itself
    is not modified.

    Parameters
    ----------
    by_namespace : bool, optional
        If ``True``, include a ``"by_namespace"`` key with per-namespace
        breakdowns. Default ``False``.

    Returns
    -------
    dict
        Dictionary with keys ``"flop_budget"``, ``"flops_used"``,
        ``"flops_remaining"``, ``"operations"``, and optionally
        ``"by_namespace"``.
    """
    default_in_use = _global_default is not None and _global_default.flops_used > 0
    if not default_in_use:
        return _accumulator.get_data(by_namespace=by_namespace)

    snapshot = BudgetAccumulator()
    snapshot._records = [*_accumulator._records]
    snapshot.record(_global_default)
    return snapshot.get_data(by_namespace=by_namespace)

mechestim._budget.budget_reset()

Clear all accumulated budget data. Core library only.

Source code in src/mechestim/_budget.py
def budget_reset() -> None:
    """Discard every budget record gathered so far in this session.

    Delegates to the module-level accumulator's ``reset()``. Core library only.
    """
    _accumulator.reset()

mechestim._display.render_budget_summary()

Return a Rich renderable if Rich is installed, otherwise plain text.

Source code in src/mechestim/_display.py
def render_budget_summary():
    """Build the session summary in the richest available format.

    Returns
    -------
    object
        A Rich renderable when ``rich`` can be imported, otherwise the
        plain-text summary.
    """
    # The Rich branch also covers building the summary, so an ImportError
    # raised from inside _rich_summary() falls back to plain text too.
    try:
        import rich  # noqa: F401

        return _rich_summary()
    except ImportError:
        pass
    return _plain_text_summary()

mechestim._display.budget_summary()

Print or return the session-wide budget summary.

Source code in src/mechestim/_display.py
def budget_summary():
    """Print or return the session-wide budget summary.

    Plain-text summaries (used when Rich is not installed) are always
    printed. Returning a ``str`` under IPython would display its repr, with
    quotes and escaped newlines, instead of the formatted text.

    Rich renderables are returned under IPython so the frontend can display
    them. Elsewhere they are printed through a Rich ``Console``.

    Returns
    -------
    Rich renderable or None
        The renderable when running under IPython with Rich available,
        otherwise ``None``.
    """
    result = render_budget_summary()
    if isinstance(result, str):
        print(result)
        return None
    try:
        get_ipython  # noqa: F821
    except NameError:
        from rich.console import Console

        Console().print(result)
        return None
    return result

mechestim._display.budget_live()

Return a live-updating budget display context manager.

Source code in src/mechestim/_display.py
def budget_live():
    """Return a live-updating budget display context manager.

    With Rich installed, the summary is rebuilt on every refresh of a Rich
    ``Live`` display while the context is active. Without Rich, a
    ``_PlainTextLive`` instance is returned instead.
    """
    # Keep the try minimal: only the optional import decides the fallback.
    try:
        from rich.live import Live
    except ImportError:
        return _PlainTextLive()

    class _CurrentSummary:
        """Renderable that regenerates the summary each time Rich draws it."""

        def __rich__(self):
            return _rich_summary()

    class _RichBudgetLive:
        def __init__(self):
            self._live = None

        def __enter__(self):
            # A renderable built once at enter time would freeze the display
            # at that moment. __rich__ is re-evaluated on every refresh, so
            # the view tracks budget usage as it changes.
            self._live = Live(_CurrentSummary(), refresh_per_second=2)
            self._live.__enter__()
            return self

        def __exit__(self, *args):
            if self._live is not None:
                # Draw the final state before stopping the display.
                self._live.refresh()
                self._live.__exit__(*args)
                self._live = None
            return None

    return _RichBudgetLive()