Exploit Symmetry Savings
When to use this page
Use this page to reduce FLOP costs when working with symmetric tensors.
Prerequisites
Why symmetry matters
Many computations in mechanistic estimation involve symmetric tensors — covariance matrices, quadratic forms, higher-order moment tensors. When a tensor is symmetric, there is redundancy in the computation that mechestim can exploit to reduce the FLOP cost.
SymmetricTensor — first-class symmetric tensors
SymmetricTensor is an ndarray subclass that carries symmetry metadata. When you pass a SymmetricTensor to any mechestim operation, the cost is automatically reduced based on the number of unique elements.
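To make "number of unique elements" concrete, here is a minimal sketch of the counting argument, independent of mechestim's internals. A group of k interchangeable axes of length n has C(n+k-1, k) distinct entries, and axes outside every group contribute their full length. The helper name `unique_elements` is hypothetical, and mechestim's actual cost model may differ in detail.

```python
from math import comb, prod

def unique_elements(shape, groups):
    """Count distinct entries of a tensor whose axis groups are symmetric.

    `groups` is a list of tuples of axes; axes in one group are interchangeable.
    Hypothetical helper for illustration, not part of mechestim.
    """
    grouped = {axis for group in groups for axis in group}
    # Axes outside every group are dense: they contribute their full length.
    count = prod(shape[axis] for axis in range(len(shape)) if axis not in grouped)
    for group in groups:
        n, k = shape[group[0]], len(group)
        # Number of multisets of size k drawn from n index values.
        count *= comb(n + k - 1, k)
    return count

print(unique_elements((10, 10), [(0, 1)]))         # 55
print(unique_elements((10, 10, 10), [(0, 1, 2)]))  # 220
```

The first result matches the 10*11/2 figure quoted for a 10x10 symmetric matrix.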
Creating a SymmetricTensor
import mechestim as me
import numpy as np
# Wrap an existing symmetric array
data = np.array([[2.0, 1.0], [1.0, 3.0]])
S = me.as_symmetric(data, symmetric_axes=(0, 1))
# symmetric_axes=(0, 1) means axes 0 and 1 are symmetric
# For partial symmetry on a 4-tensor:
# S = me.as_symmetric(data, symmetric_axes=[(0, 1), (2, 3)])
as_symmetric validates that the data is actually symmetric (within tolerance atol=1e-6, rtol=1e-5). If it isn't, SymmetryError is raised.
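The kind of check described here can be sketched in plain NumPy. This is an assumption about how such a validation could work, not mechestim's actual implementation; `check_symmetric` is a hypothetical name, and only the default tolerances are taken from the text above.

```python
import numpy as np

def check_symmetric(data, axes, atol=1e-6, rtol=1e-5):
    """Return True if `data` is unchanged (within tolerance) when `axes` are permuted.

    Swaps of neighbouring axes in the group generate every permutation,
    so in exact arithmetic checking those swaps is sufficient.
    Illustrative sketch only.
    """
    data = np.asarray(data)
    # Axes of different length can never be interchangeable.
    if len({data.shape[a] for a in axes}) != 1:
        return False
    for a, b in zip(axes, axes[1:]):
        if not np.allclose(data, np.swapaxes(data, a, b), atol=atol, rtol=rtol):
            return False
    return True

print(check_symmetric([[2.0, 1.0], [1.0, 3.0]], (0, 1)))  # True
print(check_symmetric([[2.0, 1.0], [1.5, 3.0]], (0, 1)))  # False
```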
Checking symmetry
Use me.is_symmetric() to check whether data is symmetric without raising an exception:
# Standalone function — works on any ndarray
if me.is_symmetric(result, symmetric_axes=(0, 1)):
    result = me.as_symmetric(result, symmetric_axes=(0, 1))
# Custom tolerance
me.is_symmetric(data, (0, 1), atol=1e-3)
# Multiple groups
me.is_symmetric(data, [(0, 1), (2, 3)])
SymmetricTensor also has an .is_symmetric() method. Called without arguments, it re-checks the carried axes. You can also pass different axes to check:
S = me.as_symmetric(data, symmetric_axes=(0, 1))
S.is_symmetric() # True — checks (0, 1)
S.is_symmetric(symmetric_axes=(0, 1, 2)) # check a different set of axes (valid only when the tensor has an axis 2)
Automatic cost savings
Once a tensor is wrapped as SymmetricTensor, all downstream operations automatically get reduced costs:
with me.BudgetContext(flop_budget=10**8) as budget:
    S = me.as_symmetric(np.eye(10), symmetric_axes=(0, 1))
    # Pointwise: costs 55 FLOPs (10*11/2 unique elements) instead of 100
    result = me.exp(S)
    print(f"exp cost: {budget.flops_used}") # 55
    # Solve: uses Cholesky cost (n^3/3) instead of LU cost (2n^3/3)
    x = me.linalg.solve(S, np.ones(10))
Symmetry propagation
Symmetry metadata propagates automatically through operations using conservative algebraic rules. No runtime symmetry checks are performed — only index arithmetic on the metadata.
| Operation | Result type | Rule |
|---|---|---|
| me.exp(S), me.log(S), ... | SymmetricTensor | Unary pointwise: same groups pass through |
| me.add(S, T) (same groups) | SymmetricTensor | Binary pointwise: groups are intersected |
| me.add(S, T) (different groups) | depends | Only groups present in both inputs survive |
| S * scalar | SymmetricTensor | Scalar ops preserve all groups |
| S.copy() | SymmetricTensor | Metadata preserved |
| S[0] (integer index) | SymmetricTensor or ndarray | Indexed dim removed from group; remaining dims keep symmetry if 2+ survive |
| S[0:k] (slice) | SymmetricTensor or ndarray | Resized dim pulled from group; same-size dims keep symmetry |
| me.sum(S, axis=0) | SymmetricTensor or ndarray | Reduced dim removed from group; remaining dims renumbered |
| me.sum(S) (all axes) | plain ndarray | Scalar result, no symmetry |
| S @ T | plain ndarray | AB != (AB)^T in general |
Slicing rules in detail
When you slice a SymmetricTensor, symmetry propagates based on what happens to each dimension:
import mechestim as me
import numpy as np
# sym_data is a placeholder for a fully symmetric array of shape (10,10,10,10)
# 4D tensor with full ((0,1,2,3)) symmetry
A = me.as_symmetric(sym_data, symmetric_axes=(0, 1, 2, 3))
A[0] # shape (10,10,10): dim 0 removed → ((0,1,2)) symmetry
A[0:5] # shape (5,10,10,10): dim 0 resized, pulled from group → ((1,2,3))
A[0:10] # shape (10,10,10,10): same size, group intact → ((0,1,2,3))
A[0:5, 0:5] # shape (5,5,10,10): conservative, both pulled → ((2,3))
# arr is a placeholder for an integer index array
A[arr] # advanced indexing → plain ndarray (symmetry stripped)
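The slicing rules above amount to index arithmetic on the metadata. The following sketch reproduces them for basic indexing (integers and slices) in pure Python. The function name `propagate_index` and the representation of groups as tuples of axis numbers are assumptions made for illustration; mechestim's real implementation may differ.

```python
def propagate_index(shape, groups, index):
    """Symmetry groups of the result of basic indexing (ints and slices only).

    Hypothetical sketch of the rules described above, not mechestim code.
    """
    if not isinstance(index, tuple):
        index = (index,)
    # Pad with full slices so every dimension has an entry.
    index = index + (slice(None),) * (len(shape) - len(index))
    new_axis = {}    # old axis -> axis number in the result
    changed = set()  # axes whose slice is not the identity slice
    next_axis = 0
    for axis, (size, idx) in enumerate(zip(shape, index)):
        if isinstance(idx, int):
            continue  # an integer index removes the dimension
        new_axis[axis] = next_axis
        next_axis += 1
        if range(*idx.indices(size)) != range(size):
            changed.add(axis)  # resized (or reordered) dim leaves its group
    result = []
    for group in groups:
        kept = tuple(new_axis[a] for a in group if a in new_axis and a not in changed)
        if len(kept) >= 2:  # a group needs at least two axes to mean anything
            result.append(kept)
    return result

shape, groups = (10, 10, 10, 10), [(0, 1, 2, 3)]
print(propagate_index(shape, groups, 0))                          # [(0, 1, 2)]
print(propagate_index(shape, groups, slice(0, 5)))                # [(1, 2, 3)]
print(propagate_index(shape, groups, slice(0, 10)))               # [(0, 1, 2, 3)]
print(propagate_index(shape, groups, (slice(0, 5), slice(0, 5)))) # [(2, 3)]
```

The last case shows the conservative choice: both sliced dims are dropped from the group even though they were resized identically.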
Reduction rules in detail
A.sum(axis=0) # ((0,1,2,3)) → ((0,1,2)) on 3D result
A.sum(axis=(0,1)) # ((0,1,2,3)) → ((0,1)) on 2D result
A.sum(axis=0, keepdims=True) # dim 0 now size 1, pulled → ((1,2,3)) on 4D result
A.sum() # scalar → no symmetry
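The reduction rules can be sketched the same way. The helper below is hypothetical (`propagate_sum` is not a mechestim function) and assumes non-negative axis numbers; it only mirrors the four cases listed above.

```python
def propagate_sum(ndim, groups, axis=None, keepdims=False):
    """Symmetry groups of the result of a sum over `axis`.

    Hypothetical sketch of the reduction rules, not mechestim code.
    """
    if axis is None:
        return []  # full reduction: nothing left to be symmetric
    reduced = {axis} if isinstance(axis, int) else set(axis)
    survivors = [a for a in range(ndim) if a not in reduced]
    if keepdims:
        # Reduced dims stay as size-1 axes, so numbering is unchanged.
        renumber = {a: a for a in survivors}
    else:
        # Reduced dims disappear, so later axes shift down.
        renumber = {old: new for new, old in enumerate(survivors)}
    result = []
    for group in groups:
        kept = tuple(renumber[a] for a in group if a in renumber)
        if len(kept) >= 2:
            result.append(kept)
    return result

groups = [(0, 1, 2, 3)]
print(propagate_sum(4, groups, axis=0))                 # [(0, 1, 2)]
print(propagate_sum(4, groups, axis=(0, 1)))            # [(0, 1)]
print(propagate_sum(4, groups, axis=0, keepdims=True))  # [(1, 2, 3)]
print(propagate_sum(4, groups))                         # []
```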
Binary intersection rules
# Both have ((0,1)) → intersection is ((0,1))
me.add(S1, S2)
# S1 has ((0,1),(2,3)), S2 has ((0,1)) → intersection is ((0,1))
me.add(S1, S2)
# S has ((0,1)), B is plain ndarray → intersection is empty → plain ndarray
me.add(S, B)
# Broadcasting: if a dim is stretched (size 1→n), it's pulled from its group
# before intersection
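One way to read "groups are intersected" is as a pairwise set intersection of the axis groups, keeping only intersections with at least two axes. The sketch below assumes that reading; `intersect_groups` is a hypothetical name, and broadcasting is not modelled.

```python
def intersect_groups(groups_a, groups_b):
    """Axis groups under which both operands of a pointwise op are symmetric.

    Hypothetical sketch; a plain ndarray is represented by an empty list.
    """
    result = []
    for ga in groups_a:
        for gb in groups_b:
            common = tuple(sorted(set(ga) & set(gb)))
            if len(common) >= 2:
                result.append(common)
    return result

print(intersect_groups([(0, 1)], [(0, 1)]))          # [(0, 1)]
print(intersect_groups([(0, 1), (2, 3)], [(0, 1)]))  # [(0, 1)]
print(intersect_groups([(0, 1)], []))                # []
```

The result is conservative by construction: an axis pair survives only if swapping it leaves both inputs unchanged.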
Symmetry loss warnings
When an operation causes symmetry to be lost (partially or fully), mechestim emits a SymmetryLossWarning. This helps you spot places where you might want to manually re-tag with as_symmetric().
# Warnings are shown once per call site (Python default filter)
S = me.as_symmetric(data, symmetric_axes=(0, 1))
row = S[0] # SymmetryLossWarning: Symmetry lost along dims (0, 1): ...
# Suppress warnings globally
me.configure(symmetry_warnings=False)
# Or use standard Python warning filters
import warnings
warnings.filterwarnings("ignore", category=me.SymmetryLossWarning)
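If you would rather silence the warning for a single region of code than for the whole process, the standard library's `warnings.catch_warnings` context manager restores the previous filters on exit. The sketch below uses a stand-in class, `DemoSymmetryLossWarning`, so that it runs without mechestim; in real code the category would be me.SymmetryLossWarning.

```python
import warnings

class DemoSymmetryLossWarning(UserWarning):
    """Stand-in for me.SymmetryLossWarning; hypothetical, for illustration."""

def lossy_op():
    # Plays the role of an operation that drops symmetry metadata.
    warnings.warn("Symmetry lost along dims (0, 1)", DemoSymmetryLossWarning)

# With an "always" filter the warning is recorded.
with warnings.catch_warnings(record=True) as seen:
    warnings.simplefilter("always")
    lossy_op()
shown = len(seen)

# Ignore only this category, and only inside the block.
with warnings.catch_warnings(record=True) as seen:
    warnings.simplefilter("always")
    warnings.filterwarnings("ignore", category=DemoSymmetryLossWarning)
    lossy_op()
    warnings.warn("unrelated", UserWarning)  # other categories still get through
remaining = [w.category.__name__ for w in seen]

print(shown, remaining)  # 1 ['UserWarning']
```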
Symmetry-aware linalg
Several linalg operations use cheaper algorithms when given a SymmetricTensor:
| Operation | Cost with symmetric input | Cost without |
|---|---|---|
| me.linalg.solve(S, b) | n^3/3 + n*nrhs (Cholesky) | 2n^3/3 + n^2*nrhs (LU) |
| me.linalg.det(S) | n^3/3 (Cholesky) | n^3 (LU) |
| me.linalg.inv(S) | n^3/3 + n^3/2 | n^3 |
me.linalg.inv(S) returns a SymmetricTensor (the inverse of a symmetric matrix is symmetric).
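To get a feel for the size of these savings, the formulas in the table can be evaluated directly. The functions below simply transcribe the table at face value; their names are hypothetical and they are not part of mechestim.

```python
def solve_cost(n, nrhs=1, symmetric=False):
    """FLOP estimate for solve, transcribed from the table above."""
    if symmetric:
        return n**3 / 3 + n * nrhs       # Cholesky column
    return 2 * n**3 / 3 + n**2 * nrhs    # LU column

def det_cost(n, symmetric=False):
    """FLOP estimate for det, transcribed from the table above."""
    return n**3 / 3 if symmetric else n**3

def inv_cost(n, symmetric=False):
    """FLOP estimate for inv, transcribed from the table above."""
    return (n**3 / 3 + n**3 / 2) if symmetric else n**3

n = 100
for name, fn in [("solve", solve_cost), ("det", det_cost), ("inv", inv_cost)]:
    dense, sym = fn(n), fn(n, symmetric=True)
    print(f"{name}: dense={dense:.0f} symmetric={sym:.0f} ratio={dense / sym:.2f}")
```

Under these formulas the saving is largest for det (a factor of 3) and smallest for inv (a factor of 1.2).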
End-to-end example
import mechestim as me
import numpy as np
n, d = 10, 100
with me.BudgetContext(flop_budget=10**8) as budget:
    X = np.random.randn(d, n)
    # Build covariance: einsum returns a SymmetricTensor
    C = me.einsum('ki,kj->ij', X, X, symmetric_axes=[(0, 1)])
    # Chain of unary ops: symmetry preserved, each costs n*(n+1)/2
    C_exp = me.exp(C)
    C_log = me.log(C_exp)
    # Shift the diagonal by n; both operands are symmetric, so the sum stays symmetric
    C_pd = C + me.multiply(
        me.as_symmetric(np.eye(n), symmetric_axes=(0, 1)),
        np.asarray(float(n))
    )
    # Solve: uses Cholesky cost automatically
    x = me.linalg.solve(C_pd, np.ones(n))
    print(budget.summary())
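The mathematical claims behind this example can be checked with plain NumPy, without any FLOP accounting. This is a sanity-check sketch, not a mechestim workflow: it confirms that the Gram matrix is symmetric and that the diagonal shift makes it positive definite, which is what a Cholesky-based solve requires.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 10, 100
X = rng.standard_normal((d, n))

C = np.einsum('ki,kj->ij', X, X)   # same contraction as above, i.e. X.T @ X
assert np.allclose(C, C.T)         # Gram matrices are symmetric

C_pd = C + n * np.eye(n)           # diagonal shift keeps symmetry
L = np.linalg.cholesky(C_pd)       # raises LinAlgError unless C_pd is positive definite
x = np.linalg.solve(C_pd, np.ones(n))
print(np.allclose(C_pd @ x, np.ones(n)))  # True
```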
Symmetry in einsum
When you pass a SymmetricTensor to me.einsum, the path optimizer handles everything automatically. It uses symmetry to choose the best contraction order and charges reduced costs based on unique elements.
S = me.as_symmetric(data, symmetric_axes=(0, 1)) # 10x10, 55 unique elements
v = np.ones(10)
result = me.einsum('ij,j->i', S, v) # costs based on unique elements
Use symmetric_axes when the output of your einsum is symmetric. The result is returned as a SymmetricTensor for downstream savings:
# C[i,j] == C[j,i] — declare output axes (0, 1) as symmetric
C = me.einsum('ki,kj->ij', X, X, symmetric_axes=[(0, 1)])
# C is now a SymmetricTensor — downstream ops get automatic savings
For higher-order tensors where only some axes are symmetric, declare multiple groups:
# 4-tensor where axes (0,1) and (2,3) are each separately symmetric
result = me.einsum('...', *operands, symmetric_axes=[(0, 1), (2, 3)])
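As an illustration of what "separately symmetric" groups mean, the NumPy sketch below builds a 4-tensor as the outer product of two symmetric matrices. The matrices A and B are arbitrary example data. The result is symmetric within axes (0, 1) and within axes (2, 3), but not across the two groups, so declaring [(0, 1), (2, 3)] would be correct while declaring (0, 1, 2, 3) would not.

```python
import numpy as np

A = np.array([[2.0, 1.0], [1.0, 3.0]])  # symmetric
B = np.array([[1.0, 4.0], [4.0, 5.0]])  # symmetric
T = np.einsum('ij,kl->ijkl', A, B)      # T[i,j,k,l] = A[i,j] * B[k,l]

within_01 = np.allclose(T, T.transpose(1, 0, 2, 3))  # swap axes 0 and 1
within_23 = np.allclose(T, T.transpose(0, 1, 3, 2))  # swap axes 2 and 3
across = np.allclose(T, T.transpose(2, 1, 0, 3))     # swap axes 0 and 2
print(within_01, within_23, across)  # True True False
```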
Symmetry propagation through contraction paths
When contracting multiple tensors with me.einsum, the path optimizer is fully symmetry-aware. Symmetry influences two things:
- Contraction order. The optimizer uses symmetric costs to choose which pair of tensors to contract at each step. A contraction that looks sub-optimal under dense cost estimates may become the best choice once symmetry savings are factored in.
- Symmetry propagation. Intermediates' symmetry is tracked and influences future ordering decisions. After each contraction, the result's symmetry is computed (by restricting each symmetric group to the indices that survive) and fed back to the optimizer.
The cost formula correctly distinguishes between summed and surviving indices. For example, given S_3 on {i,j,k} where i is summed out, only the S_2 subgroup on {j,k} contributes a symmetry reduction; the summed index i does not.
Symmetry degrades along the contraction chain as free indices are consumed: the early steps benefit from symmetry savings, while later steps operate on dense intermediates.
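The restriction rule described above can be traced by hand for a chain of pairwise contractions. The sketch below is a simplified model: it assumes that only the first operand carries symmetry, that the other operands are dense, and that symmetric groups are written as tuples of index letters. `restrict_groups` is a hypothetical helper, not the optimizer's actual code.

```python
def restrict_groups(groups, surviving):
    """Keep, from each symmetric index group, only the indices that survive."""
    result = []
    for group in groups:
        kept = tuple(ix for ix in group if ix in surviving)
        if len(kept) >= 2:  # a single leftover index carries no symmetry
            result.append(kept)
    return result

steps = ["ijk,ai->ajk", "ajk,bj->abk", "abk,ck->abc"]
groups = [("i", "j", "k")]  # S3 symmetry on the first operand
history = []
for step in steps:
    output = step.split("->")[1]
    groups = restrict_groups(groups, output)
    history.append(groups)
    print(f"{step}: output symmetry {groups}")
```

After the first step only the pair (j, k) remains symmetric, and after the second step the intermediate is dense.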
Example
import mechestim as me
import numpy as np
n = 100
T_data = np.random.randn(n, n, n)
T_data = (T_data + T_data.transpose(1, 0, 2) + T_data.transpose(2, 1, 0)
+ T_data.transpose(0, 2, 1) + T_data.transpose(1, 2, 0)
+ T_data.transpose(2, 0, 1)) / 6
T = me.as_symmetric(T_data, symmetric_axes=(0, 1, 2)) # S₃ symmetric
A = me.ones((n, n))
B = me.ones((n, n))
C = me.ones((n, n))
path, info = me.einsum_path('ijk,ai,bj,ck->abc', T, A, B, C)
print(info)
Output:
Step Subscript FLOPs Dense FLOPs Symmetry Savings
──── ──────────────── ───── ─────────── ────────────────
0 ijk,ai->ajk ... ... ~3x (S₃ input)
1 ajk,bj->abk ... ... ~2x (S₂ input)
2 abk,ck->abc ... ... 1x (dense)
──── ──────────────── ───── ─────────── ────────────────
Total ... ... ...x speedup
Each step's input_symmetries and output_symmetry fields in StepInfo describe exactly which symmetry was present and how it reduced cost. See PathInfo / StepInfo API for the full dataclass reference.
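The six-term average used to build T_data in this example generalizes to any order by averaging over all axis permutations. The NumPy sketch below shows one way to write that; `symmetrize` is a hypothetical helper name, and it assumes all axes have the same length.

```python
from itertools import permutations

import numpy as np

def symmetrize(t):
    """Average `t` over every permutation of its axes (all axes equal length)."""
    perms = list(permutations(range(t.ndim)))
    return sum(t.transpose(p) for p in perms) / len(perms)

rng = np.random.default_rng(0)
T_sym = symmetrize(rng.standard_normal((4, 4, 4)))

# The result is invariant under every axis permutation.
print(all(np.allclose(T_sym, T_sym.transpose(p)) for p in permutations(range(3))))  # True
```

The number of terms grows as the factorial of the tensor order, so this is practical only for low-order tensors.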
Common pitfalls
Symptom: SymmetryError: Tensor not symmetric along axes (0, 1): max deviation = 0.5 (tolerance: atol=1e-06, rtol=1e-05)
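Fix: The data must be symmetric within tolerance. A deviation this large usually points to a bug in the computation that produced the array. If the deviation is only floating-point noise, symmetrize explicitly before wrapping, for example (A + A.T) / 2 for a matrix.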
Symptom: Expected SymmetricTensor output but got plain ndarray
Fix: Symmetry propagation is conservative — it never claims symmetry that isn't guaranteed. Operations like matmul (S @ T) always strip symmetry. Slicing and reductions strip symmetry only when the affected dims belong to a symmetric group and the operation leaves fewer than 2 dims in that group. If you know the result is symmetric, re-wrap with me.as_symmetric().
Symptom: SymmetryLossWarning appearing unexpectedly
Fix: This warning tells you symmetry metadata was lost. Check whether the result is still symmetric — if so, re-tag with as_symmetric(). To suppress: me.configure(symmetry_warnings=False) or use warnings.filterwarnings("ignore", category=me.SymmetryLossWarning).
Related pages
- Use Einsum — einsum basics, multi-operand contractions, and path inspection
- Use Linear Algebra — linalg operations and costs
- FLOP Counting Model — how costs are calculated
- Symmetric Tensors API: PathInfo, StepInfo, and SymmetricTensor