Plan Your Budget

When to use this page

Use this page to learn how to query operation costs before running them.

Prerequisites

Cost query functions

These functions work outside a BudgetContext — they compute costs from shapes without executing anything.

import mechestim as me

# Einsum cost
cost = me.flops.einsum_cost('ij,jk->ik', shapes=[(256, 256), (256, 256)])
print(f"Matmul cost: {cost:,}")         # 33,554,432 (2 × 256³)

# SVD cost
cost = me.flops.svd_cost(m=256, n=256, k=10)
print(f"SVD cost: {cost:,}")            # 655,360

# Pointwise cost (unary/binary ops)
cost = me.flops.pointwise_cost(shape=(256, 256))
print(f"Pointwise cost: {cost:,}")      # 65,536

# Reduction cost
cost = me.flops.reduction_cost(input_shape=(256, 256))
print(f"Reduction cost: {cost:,}")      # 65,536
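The printed values fit simple closed-form rules:
- Two-operand einsums cost twice the product of every index size.
- The SVD call costs m × n × k.
- Pointwise and reduction ops cost one FLOP per element.

The plain-Python sketch below reproduces those numbers. The rules are inferred from the outputs on this page, not taken from mechestim's source, and the function names are hypothetical. The library may apply different rules to other operations or to einsums with three or more operands.

```python
import math


def naive_einsum_cost(subscripts, shapes):
    # Hypothetical helper: 2 x (product of the size of every distinct index)
    input_specs = subscripts.split("->")[0].split(",")
    sizes = {}
    for labels, shape in zip(input_specs, shapes):
        for label, dim in zip(labels, shape):
            sizes[label] = dim
    return 2 * math.prod(sizes.values())


def svd_cost(m, n, k):
    # Inferred from the example above: m x n x k
    return m * n * k


def pointwise_cost(shape):
    # One FLOP per output element
    return math.prod(shape)


def reduction_cost(input_shape):
    # One FLOP per input element
    return math.prod(input_shape)


print(naive_einsum_cost("ij,jk->ik", [(256, 256), (256, 256)]))  # 33554432
print(svd_cost(256, 256, 10))                                     # 655360
print(pointwise_cost((256, 256)))                                 # 65536
print(reduction_cost((256, 256)))                                 # 65536
```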

Budget breakdown example

Plan a multi-step computation before executing:

import mechestim as me

# Plan
steps = [
    ("einsum ij,j->i", me.flops.einsum_cost('ij,j->i', shapes=[(256, 256), (256,)])),
    ("ReLU (maximum)", me.flops.pointwise_cost(shape=(256,))),
    ("sum reduction", me.flops.reduction_cost(input_shape=(256,))),
]

total = sum(cost for _, cost in steps)
print(f"{'Operation':<20} {'FLOPs':>12}")
print("-" * 34)
for name, cost in steps:
    print(f"{name:<20} {cost:>12,}")
print("-" * 34)
print(f"{'Total':<20} {total:>12,}")

Output:

Operation                   FLOPs
----------------------------------
einsum ij,j->i            131,072
ReLU (maximum)                256
sum reduction                 256
----------------------------------
Total                     131,584
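Two quick readings of this table help when sizing a budget: each step's share of the total, and how the total grows with the vector width n. The sketch below uses the per-step numbers from the output above. The scaling formula 2n² + n + n is an assumption: it adds the inferred costs of a matrix-vector einsum, an elementwise ReLU and a sum. The function planned_cost is hypothetical.

```python
# Per-step FLOP counts copied from the output table above
steps = [
    ("einsum ij,j->i", 131_072),
    ("ReLU (maximum)", 256),
    ("sum reduction", 256),
]
total = sum(cost for _, cost in steps)

# Fraction of the planned budget spent in each step
for name, cost in steps:
    print(f"{name:<20} {cost / total:>8.2%}")


def planned_cost(n):
    # Hypothetical: matvec (2 * n * n) + ReLU (n) + sum (n)
    return 2 * n * n + n + n


print(planned_cost(256))  # 131584, matches the table total
print(planned_cost(512))  # 525312, roughly 4x for twice the width
```

The matrix-vector einsum accounts for about 99.6% of the plan, so it is the step worth optimizing or budgeting carefully.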

Multi-operand einsum planning with einsum_path

For multi-operand einsums (3+ operands), me.einsum_path() is more informative than me.flops.einsum_cost() because it shows the step-by-step contraction breakdown with per-step symmetry savings:

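import itertools
import mechestim as me
import numpy as np

# A random tensor is not symmetric; average it over all 3! axis
# permutations so the declared symmetry actually holds
X = np.random.randn(50, 50, 50)
X = sum(X.transpose(p) for p in itertools.permutations(range(3))) / 6
T = me.as_symmetric(X, symmetric_axes=(0, 1, 2))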
A = me.ones((50, 50))
B = me.ones((50, 50))
C = me.ones((50, 50))

path, info = me.einsum_path('ijk,ai,bj,ck->abc', T, A, B, C)

print(f"Optimized cost: {info.optimized_cost:,}")
print(f"Naive cost:     {info.naive_cost:,}")
print(f"Speedup:        {info.speedup:.1f}x")
print(f"Largest intermediate: {info.largest_intermediate:,} elements")
print(info)  # full per-step table

me.einsum_path() has zero budget cost — it plans the contraction path without executing anything. Use it alongside me.flops.einsum_cost() for comprehensive planning.
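The plain-Python arithmetic below shows why contraction order matters. It compares one naive pass over all six indices with three pairwise contractions: T with A, then with B, then with C. It assumes a cost rule of 2 × the product of the sizes of all indices in a step and ignores the symmetry of T. These are therefore not the numbers info.optimized_cost reports, since mechestim's own cost model and symmetry savings will differ. The last lines count the unique entries of a fully symmetric 50 × 50 × 50 tensor, which hints at where symmetry savings can come from.

```python
import math

# Every index in 'ijk,ai,bj,ck->abc' has size 50 in the example above
sizes = dict.fromkeys("ijkabc", 50)


def step_cost(subscripts):
    # Assumed rule: 2 x product of the sizes of all distinct indices in the step
    labels = {c for c in subscripts if c.isalpha()}
    return 2 * math.prod(sizes[c] for c in labels)


naive = step_cost("ijk,ai,bj,ck->abc")
pairwise = [
    step_cost("ijk,ai->ajk"),  # contract T with A over i
    step_cost("ajk,bj->abk"),  # then with B over j
    step_cost("abk,ck->abc"),  # then with C over k
]
print(f"Naive:    {naive:,}")                     # Naive:    31,250,000,000
print(f"Pairwise: {sum(pairwise):,}")             # Pairwise: 37,500,000
print(f"Speedup:  {naive / sum(pairwise):.0f}x")  # Speedup:  833x

# Unique entries of a fully symmetric n x n x n tensor: C(n + 2, 3)
n = 50
print(math.comb(n + 2, 3), "of", n**3)  # 22100 of 125000
```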

Using namespaces to track phases

Use the namespace parameter to label different computation phases:

import mechestim as me

# total is the planned FLOP count from the budget breakdown example above
with me.BudgetContext(flop_budget=total, namespace="forward") as budget:
    # forward pass here
    ...

with me.BudgetContext(flop_budget=total, namespace="backward") as budget:
    # backward pass here
    ...

# Session-wide summary across all phases
me.budget_summary()

me.budget_summary_dict(by_namespace=True) returns a dict with per-namespace breakdowns for programmatic analysis.
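The namespace idea can be pictured as a session-level tally keyed by phase name. The toy class below is a hypothetical stand-in written only to illustrate that bookkeeping. ToyBudgetContext, charge() and summary_by_namespace() are invented names, not mechestim APIs. Raising an error when a charge would exceed flop_budget is an assumption about how a budget limit behaves.

```python
from collections import defaultdict

# Hypothetical session-wide ledger: namespace -> FLOPs spent
SESSION_LEDGER = defaultdict(int)


class ToyBudgetContext:
    """Illustrative stand-in for a namespaced FLOP budget (not mechestim)."""

    def __init__(self, flop_budget, namespace):
        self.flop_budget = flop_budget
        self.namespace = namespace
        self.spent = 0

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        return False  # never swallow exceptions

    def charge(self, flops):
        # Assumption: refuse any charge that would overrun the budget
        if self.spent + flops > self.flop_budget:
            raise RuntimeError(f"{self.namespace}: budget of {self.flop_budget:,} exceeded")
        self.spent += flops
        SESSION_LEDGER[self.namespace] += flops


def summary_by_namespace():
    # Plain dict copy of the per-namespace totals
    return dict(SESSION_LEDGER)


with ToyBudgetContext(flop_budget=131_584, namespace="forward") as budget:
    for flops in (131_072, 256, 256):
        budget.charge(flops)

with ToyBudgetContext(flop_budget=131_584, namespace="backward") as budget:
    budget.charge(131_072)

print(summary_by_namespace())  # {'forward': 131584, 'backward': 131072}
```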