
Exploit Symmetry Savings

When to use this page

Use this page to reduce FLOP costs when working with symmetric tensors.

Prerequisites

Why symmetry matters

Many computations in mechanistic estimation involve symmetric tensors: covariance matrices, quadratic forms, and higher-order moment tensors. A symmetric tensor stores each value more than once (an n×n symmetric matrix has only n(n+1)/2 unique entries instead of n²), and mechestim exploits this redundancy to charge a lower FLOP cost.


SymmetricTensor — first-class symmetric tensors

SymmetricTensor is an ndarray subclass that carries symmetry metadata. When you pass a SymmetricTensor to any mechestim operation, the cost is automatically reduced based on the number of unique elements.
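As a rough guide to what "number of unique elements" means: a group of k interchangeable axes, each of size n, has C(n + k - 1, k) unique entries (one per multiset of k indices). Independent groups multiply, and axes outside every group count in full. The sketch below is plain Python. `unique_elements` is a hypothetical helper written for this page, not part of mechestim, and it assumes the FLOP counts follow this combinatorial count (consistent with the 10*11/2 = 55 figure quoted on this page).

```python
from math import comb, prod

def unique_elements(shape, symmetric_axes):
    """Count unique entries of a tensor whose axes in each group are
    interchangeable. Hypothetical helper, not part of mechestim."""
    # Accept a single group like (0, 1) or a list of groups like [(0, 1), (2, 3)]
    if symmetric_axes and isinstance(symmetric_axes[0], int):
        groups = [tuple(symmetric_axes)]
    else:
        groups = [tuple(g) for g in symmetric_axes]
    grouped = {ax for g in groups for ax in g}
    count = 1
    for g in groups:
        n = shape[g[0]]
        assert all(shape[ax] == n for ax in g), "symmetric axes must share a size"
        # Number of multisets of size len(g) drawn from n index values
        count *= comb(n + len(g) - 1, len(g))
    # Axes outside every group contribute their full size
    count *= prod(shape[ax] for ax in range(len(shape)) if ax not in grouped)
    return count

print(unique_elements((10, 10), (0, 1)))                # 55 (vs 100 dense)
print(unique_elements((10, 10, 10), (0, 1, 2)))         # 220 (vs 1000 dense)
print(unique_elements((4, 4, 5, 5), [(0, 1), (2, 3)]))  # 10 * 15 = 150 (vs 400 dense)
```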

Creating a SymmetricTensor

import mechestim as me
import numpy as np

# Wrap an existing symmetric array
data = np.array([[2.0, 1.0], [1.0, 3.0]])
S = me.as_symmetric(data, symmetric_axes=(0, 1))

# symmetric_axes=(0, 1) means axes 0 and 1 are symmetric
# For partial symmetry on a 4-tensor:
# S = me.as_symmetric(data, symmetric_axes=[(0, 1), (2, 3)])

as_symmetric validates that the data is actually symmetric (within tolerance atol=1e-6, rtol=1e-5). If it isn't, SymmetryError is raised.
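A minimal sketch of a tolerance check for the matrix case. It assumes the check uses the same |a - b| <= atol + rtol * |b| form as numpy.allclose; the page only states the two tolerance values, so the exact formula is an assumption. `looks_symmetric` and `symmetry_deviation` are hypothetical names, and a real check would also handle groups of three or more axes.

```python
def symmetry_deviation(m):
    """Largest |m[i][j] - m[j][i]| over a square nested-list matrix."""
    n = len(m)
    return max((abs(m[i][j] - m[j][i]) for i in range(n) for j in range(i + 1, n)),
               default=0.0)

def looks_symmetric(m, atol=1e-6, rtol=1e-5):
    """allclose-style comparison of m with its transpose (assumed formula)."""
    n = len(m)
    return all(abs(m[i][j] - m[j][i]) <= atol + rtol * abs(m[j][i])
               for i in range(n) for j in range(n))

ok = [[2.0, 1.0], [1.0, 3.0]]
noisy = [[2.0, 1.0 + 1e-9], [1.0, 3.0]]  # round-off-sized asymmetry
bad = [[2.0, 1.5], [1.0, 3.0]]           # genuinely asymmetric

print(looks_symmetric(ok), looks_symmetric(noisy), looks_symmetric(bad))  # True True False
print(symmetry_deviation(bad))  # 0.5
```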

Checking symmetry

Use me.is_symmetric() to check whether data is symmetric without raising an exception:

# Standalone function — works on any ndarray
if me.is_symmetric(result, symmetric_axes=(0, 1)):
    result = me.as_symmetric(result, symmetric_axes=(0, 1))

# Custom tolerance
me.is_symmetric(data, (0, 1), atol=1e-3)

# Multiple groups
me.is_symmetric(data, [(0, 1), (2, 3)])

SymmetricTensor also has an .is_symmetric() method. Called without arguments, it re-checks the carried axes. You can also pass different axes to check:

S = me.as_symmetric(data, symmetric_axes=(0, 1))
S.is_symmetric()                          # True — checks (0, 1)
S.is_symmetric(symmetric_axes=(0, 1, 2))  # check a different set of axes (needs a tensor with 3+ dims)

Automatic cost savings

Once a tensor is wrapped as SymmetricTensor, all downstream operations automatically get reduced costs:

with me.BudgetContext(flop_budget=10**8) as budget:
    S = me.as_symmetric(np.eye(10), symmetric_axes=(0, 1))

    # Pointwise: costs 55 FLOPs (10*11/2 unique elements) instead of 100
    result = me.exp(S)
    print(f"exp cost: {budget.flops_used}")  # 55

    # Solve: uses Cholesky cost (n^3/3) instead of LU cost (2n^3/3)
    x = me.linalg.solve(S, np.ones(10))

Symmetry propagation

Symmetry metadata propagates automatically through operations using conservative algebraic rules. No runtime symmetry checks are performed — only index arithmetic on the metadata.

| Operation | Result type | Rule |
|---|---|---|
| me.exp(S), me.log(S), ... | SymmetricTensor | Unary pointwise: same groups pass through |
| me.add(S, T) (same groups) | SymmetricTensor | Binary pointwise: groups are intersected, so identical groups survive unchanged |
| me.add(S, T) (different groups) | SymmetricTensor or ndarray | Only symmetry present in both inputs survives; plain ndarray if nothing is left |
| S * scalar | SymmetricTensor | Scalar ops preserve all groups |
| S.copy() | SymmetricTensor | Metadata preserved |
| S[0] (integer index) | SymmetricTensor or ndarray | Indexed dim removed from its group; the remaining dims keep symmetry if 2+ survive |
| S[0:k] (slice) | SymmetricTensor or ndarray | A dim whose size changes is pulled from its group; a slice that keeps the full size leaves the group intact |
| me.sum(S, axis=0) | SymmetricTensor or ndarray | Reduced dim removed from its group; remaining dims renumbered |
| me.sum(S) (all axes) | plain ndarray | Scalar result, no symmetry |
| S @ T | plain ndarray | ST != (ST)^T in general |
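Why matmul strips symmetry: for symmetric S and T, (ST)^T = T^T S^T = TS, which equals ST only when the two matrices commute. A plain-Python counterexample with two small symmetric matrices (the `matmul` and `transpose` helpers are written here for illustration):

```python
def matmul(a, b):
    """Naive matrix product of nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(m):
    return [list(row) for row in zip(*m)]

S = [[1, 2], [2, 1]]  # symmetric
T = [[0, 1], [1, 3]]  # symmetric
P = matmul(S, T)
print(P)                  # [[2, 7], [1, 5]]
print(P == transpose(P))  # False: the product of two symmetric matrices is not symmetric
```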

Slicing rules in detail

When you slice a SymmetricTensor, symmetry propagates based on what happens to each dimension:

import itertools
import mechestim as me
import numpy as np

# Build a 4D tensor with full ((0,1,2,3)) symmetry, shape (10,10,10,10),
# by averaging a random tensor over all 24 axis permutations
raw = np.random.randn(10, 10, 10, 10)
sym_data = sum(raw.transpose(p) for p in itertools.permutations(range(4))) / 24
A = me.as_symmetric(sym_data, symmetric_axes=(0, 1, 2, 3))

A[0]          # shape (10,10,10): dim 0 removed → ((0,1,2)) symmetry
A[0:5]        # shape (5,10,10,10): dim 0 resized, pulled from group → ((1,2,3))
A[0:10]       # shape (10,10,10,10): same size, group intact → ((0,1,2,3))
A[0:5, 0:5]   # shape (5,5,10,10): conservative, both pulled → ((2,3))
arr = np.array([0, 2, 4])
A[arr]        # advanced indexing → plain ndarray (symmetry stripped)
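The slicing rules are pure index arithmetic on the metadata. A minimal sketch, assuming groups are stored as tuples of axis numbers. `after_integer_index` and `after_slices` are hypothetical helpers, not mechestim functions, and they cover only integer indices and basic slices.

```python
def after_integer_index(groups, axis):
    """S[..., 0, ...] on `axis`: that dim disappears and later axes shift down."""
    out = []
    for g in groups:
        kept = [ax - (ax > axis) for ax in g if ax != axis]
        if len(kept) >= 2:  # a group needs 2+ axes to carry any symmetry
            out.append(tuple(kept))
    return out

def after_slices(groups, old_shape, new_shape):
    """Basic slicing: any dim whose size changed is pulled from its group."""
    resized = {ax for ax, (o, n) in enumerate(zip(old_shape, new_shape)) if o != n}
    out = []
    for g in groups:
        kept = tuple(ax for ax in g if ax not in resized)
        if len(kept) >= 2:
            out.append(kept)
    return out

full, shape = [(0, 1, 2, 3)], (10, 10, 10, 10)
print(after_integer_index(full, 0))                # [(0, 1, 2)]
print(after_slices(full, shape, (5, 10, 10, 10)))  # [(1, 2, 3)]
print(after_slices(full, shape, shape))            # [(0, 1, 2, 3)]
print(after_slices(full, shape, (5, 5, 10, 10)))   # [(2, 3)]
```

An empty list corresponds to the plain-ndarray case. Pulling every resized dim, even when two dims are sliced identically as in A[0:5, 0:5], is what makes the rule conservative.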

Reduction rules in detail

A.sum(axis=0)              # ((0,1,2,3)) → ((0,1,2)) on 3D result
A.sum(axis=(0,1))          # ((0,1,2,3)) → ((0,1)) on 2D result
A.sum(axis=0, keepdims=True)  # dim 0 now size 1, pulled → ((1,2,3)) on 4D result
A.sum()                    # scalar → no symmetry
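Reductions follow the same bookkeeping. A sketch under the same assumption (groups as tuples of axis numbers); `after_sum` is a hypothetical helper, not a mechestim function.

```python
def after_sum(groups, ndim, axis=None, keepdims=False):
    """Metadata-only sketch of reducing a tensor along `axis`."""
    if axis is None:
        axis = tuple(range(ndim))  # reduce every axis
    elif isinstance(axis, int):
        axis = (axis,)
    reduced = {a % ndim for a in axis}  # normalize negative axes
    out = []
    for g in groups:
        if keepdims:
            # Reduced dims survive with size 1, so they are pulled from the group
            kept = [ax for ax in g if ax not in reduced]
        else:
            # Reduced dims vanish; later axis numbers shift down
            kept = [ax - sum(r < ax for r in reduced) for ax in g if ax not in reduced]
        if len(kept) >= 2:
            out.append(tuple(kept))
    return out

full = [(0, 1, 2, 3)]
print(after_sum(full, 4, axis=0))                 # [(0, 1, 2)]
print(after_sum(full, 4, axis=(0, 1)))            # [(0, 1)]
print(after_sum(full, 4, axis=0, keepdims=True))  # [(1, 2, 3)]
print(after_sum(full, 4))                         # []: scalar, no symmetry
```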

Binary intersection rules

# Both have ((0,1)) → intersection is ((0,1))
me.add(S1, S2)

# S1 has ((0,1),(2,3)), S2 has ((0,1)) → intersection is ((0,1))
me.add(S1, S2)

# S has ((0,1)), B is plain ndarray → intersection is empty → plain ndarray
me.add(S, B)

# Broadcasting: if a dim is stretched (size 1→n), it's pulled from its group
# before intersection
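Intersection can be modeled as pairwise set intersection: a permutation of axes is allowed in the result only if both inputs allow it. This is a simplified sketch, assuming both operands have the same number of dims (no rank-extending broadcast). `intersect_groups` is a hypothetical helper, and a plain ndarray is represented by an empty group list.

```python
def intersect_groups(groups_a, shape_a, groups_b, shape_b):
    """Sketch: symmetry surviving an elementwise op on two same-ndim operands."""
    out_shape = tuple(max(x, y) for x, y in zip(shape_a, shape_b))

    def pull_stretched(groups, shape):
        # A dim broadcast from size 1 to n is pulled from its group first
        stretched = {ax for ax, (s, o) in enumerate(zip(shape, out_shape))
                     if s == 1 and o > 1}
        return [set(g) - stretched for g in groups]

    result = []
    for ga in pull_stretched(groups_a, shape_a):
        for gb in pull_stretched(groups_b, shape_b):
            common = ga & gb  # permutations allowed by BOTH inputs
            if len(common) >= 2:
                result.append(tuple(sorted(common)))
    return result

shape = (4, 4, 6, 6)
print(intersect_groups([(0, 1)], shape, [(0, 1)], shape))          # [(0, 1)]
print(intersect_groups([(0, 1), (2, 3)], shape, [(0, 1)], shape))  # [(0, 1)]
print(intersect_groups([(0, 1)], shape, [], shape))                # []: plain ndarray
# Broadcasting: dims 0 and 1 of the first operand stretch from 1 to 3
print(intersect_groups([(0, 1), (2, 3)], (1, 1, 4, 4),
                       [(0, 1), (2, 3)], (3, 3, 4, 4)))            # [(2, 3)]
```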

Symmetry loss warnings

When an operation causes symmetry to be lost (partially or fully), mechestim emits a SymmetryLossWarning. This helps you spot places where you might want to manually re-tag with as_symmetric().

# Warnings are shown once per call site (Python default filter)
S = me.as_symmetric(data, symmetric_axes=(0, 1))
row = S[0]  # SymmetryLossWarning: Symmetry lost along dims (0, 1): ...

# Suppress warnings globally
me.configure(symmetry_warnings=False)

# Or use standard Python warning filters
import warnings
warnings.filterwarnings("ignore", category=me.SymmetryLossWarning)
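To silence the warning around one block of code rather than globally, the standard library's `warnings.catch_warnings()` restores the previous filter state on exit. The sketch below defines a stand-in `SymmetryLossWarning` and a hypothetical `lossy_op` so it runs without mechestim; in real code you would filter on `me.SymmetryLossWarning` instead.

```python
import warnings

class SymmetryLossWarning(UserWarning):
    """Stand-in for me.SymmetryLossWarning so this sketch runs on its own."""

def lossy_op():
    # Hypothetical operation that drops symmetry metadata and warns about it
    warnings.warn("Symmetry lost along dims (0, 1)", SymmetryLossWarning, stacklevel=2)
    return "result"

# Suppress only inside this block; the global filters are restored afterwards
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=SymmetryLossWarning)
    lossy_op()

# Record the warning instead, e.g. to assert on it in a test
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always", category=SymmetryLossWarning)
    lossy_op()
print([w.category.__name__ for w in caught])  # ['SymmetryLossWarning']
```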

Symmetry-aware linalg

Several linalg operations use cheaper algorithms when given a SymmetricTensor:

| Operation | Cost with symmetric input | Cost with plain ndarray input |
|---|---|---|
| me.linalg.solve(S, b) | n^3/3 + n*nrhs (Cholesky) | 2n^3/3 + n^2*nrhs (LU) |
| me.linalg.det(S) | n^3/3 (Cholesky) | n^3 (LU) |
| me.linalg.inv(S) | n^3/3 + n^3/2 | n^3 |

me.linalg.inv(S) returns a SymmetricTensor (the inverse of a symmetric matrix is symmetric).
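To see how much each row of the linalg cost table saves, the formulas can be evaluated directly. This sketch only transcribes the table (the helper name `linalg_costs` is hypothetical); it says nothing about how mechestim rounds or charges fractional FLOPs.

```python
def linalg_costs(n, nrhs=1):
    """(symmetric cost, plain cost) per operation, transcribed from the table."""
    return {
        "solve": (n**3 / 3 + n * nrhs, 2 * n**3 / 3 + n**2 * nrhs),
        "det": (n**3 / 3, n**3),
        "inv": (n**3 / 3 + n**3 / 2, n**3),
    }

for op, (sym, plain) in linalg_costs(n=10).items():
    print(f"{op:5s} symmetric={sym:8.1f} plain={plain:8.1f} ratio={plain / sym:.2f}x")
```

For n = 10 the ratios come out to roughly 2.2x for solve, 3x for det, and 1.2x for inv: the n^3/2 term dominates the symmetric inverse, so inv saves the least.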

End-to-end example

import mechestim as me
import numpy as np

n, d = 10, 100

with me.BudgetContext(flop_budget=10**8) as budget:
    X = np.random.randn(d, n)

    # Build the Gram matrix X^T X (an unnormalized covariance); einsum returns a SymmetricTensor
    C = me.einsum('ki,kj->ij', X, X, symmetric_axes=[(0, 1)])

    # Chain of unary ops — symmetry preserved, each costs n*(n+1)/2
    C_exp = me.exp(C)
    C_log = me.log(C_exp)

    # Shift by n*I so C_pd is safely positive definite; solve then uses Cholesky cost automatically
    C_pd = C + me.multiply(
        me.as_symmetric(np.eye(n), symmetric_axes=(0, 1)),
        np.asarray(float(n))
    )
    x = me.linalg.solve(C_pd, np.ones(n))

    print(budget.summary())

Symmetry in einsum

When you pass a SymmetricTensor to me.einsum, the path optimizer reads its symmetry metadata automatically. It uses that symmetry to choose the best contraction order and charges reduced costs based on unique elements.

S = me.as_symmetric(data, symmetric_axes=(0, 1))  # 10x10, 55 unique elements
v = np.ones(10)
result = me.einsum('ij,j->i', S, v)  # costs based on unique elements

Use symmetric_axes when the output of your einsum is symmetric. The result is returned as a SymmetricTensor for downstream savings:

# C[i,j] == C[j,i] — declare output axes (0, 1) as symmetric
C = me.einsum('ki,kj->ij', X, X, symmetric_axes=[(0, 1)])
# C is now a SymmetricTensor — downstream ops get automatic savings
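Declaring symmetric_axes=[(0, 1)] for 'ki,kj->ij' is safe because C[i, j] = sum_k X[k, i] X[k, j] and C[j, i] multiply the same pairs. The plain-Python sketch below (a hand-written `gram` loop standing in for einsum) shows that when both entries are summed in the same k order they agree bit for bit, since floating-point multiplication is commutative. Optimized kernels may sum in different orders, so in practice expect agreement only within round-off.

```python
import random

def gram(X):
    """C[i][j] = sum_k X[k][i] * X[k][j], i.e. einsum('ki,kj->ij', X, X)."""
    d, n = len(X), len(X[0])
    return [[sum(X[k][i] * X[k][j] for k in range(d)) for j in range(n)]
            for i in range(n)]

random.seed(0)
X = [[random.gauss(0.0, 1.0) for _ in range(4)] for _ in range(20)]  # d=20, n=4
C = gram(X)
# Same k order and commutative multiplication, so C[i][j] == C[j][i] exactly
print(all(C[i][j] == C[j][i] for i in range(4) for j in range(4)))  # True
```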

For higher-order tensors where only some axes are symmetric, declare multiple groups:

# 4-tensor where axes (0,1) and (2,3) are each separately symmetric
result = me.einsum('...', *operands, symmetric_axes=[(0, 1), (2, 3)])

Symmetry propagation through contraction paths

When contracting multiple tensors with me.einsum, the path optimizer is fully symmetry-aware. Symmetry influences two things:

  1. Contraction order. The optimizer uses symmetric costs to choose which pair of tensors to contract at each step. A contraction that looks sub-optimal under dense cost estimates may become the best choice once symmetry savings are factored in.

  2. Symmetry propagation. Intermediates' symmetry is tracked and influences future ordering decisions. After each contraction, the result's symmetry is computed (by restricting each symmetric group to the indices that survive) and fed back to the optimizer.

The cost formula distinguishes between summed and surviving indices. For example, given S₃ symmetry on {i,j,k} where i is summed out, only the S₂ subgroup on {j,k} contributes a symmetry reduction; the summed index i does not.

Symmetry degrades along the contraction chain as free indices are consumed:

ijk (S₃) + ai → ajk (S₂ on j,k) → + bj → abk (none) → + ck → abc (none)

The early steps benefit from symmetry savings; later steps operate on dense intermediates.
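The degradation chain can be reproduced with the restriction rule described above: keep only the indices of each group that appear in the step's output, and drop groups with fewer than two members. A minimal sketch with index groups as strings; `output_symmetry` is a hypothetical helper, and it tracks only the symmetry carried along the chain (A, B, and C are dense).

```python
def output_symmetry(groups, output):
    """Restrict each index group to the indices that survive into `output`."""
    out = []
    for g in groups:
        kept = "".join(ix for ix in g if ix in output)
        if len(kept) >= 2:  # a single index carries no symmetry
            out.append(kept)
    return out

steps = [("ijk", "ai", "ajk"), ("ajk", "bj", "abk"), ("abk", "ck", "abc")]
sym = ["ijk"]  # T is S3-symmetric in i, j, k
for left, right, result in steps:
    sym = output_symmetry(sym, result)
    print(f"{left},{right}->{result}: {sym or 'none'}")
# ijk,ai->ajk: ['jk']
# ajk,bj->abk: none
# abk,ck->abc: none
```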

Example

import mechestim as me
import numpy as np

n = 100
T_data = np.random.randn(n, n, n)
T_data = (T_data + T_data.transpose(1, 0, 2) + T_data.transpose(2, 1, 0)
          + T_data.transpose(0, 2, 1) + T_data.transpose(1, 2, 0)
          + T_data.transpose(2, 0, 1)) / 6

T = me.as_symmetric(T_data, symmetric_axes=(0, 1, 2))  # S₃ symmetric
A = me.ones((n, n))
B = me.ones((n, n))
C = me.ones((n, n))

path, info = me.einsum_path('ijk,ai,bj,ck->abc', T, A, B, C)
print(info)

Output:

Step  Subscript         FLOPs  Dense FLOPs  Symmetry Savings
────  ────────────────  ─────  ───────────  ────────────────
0     ijk,ai->ajk       ...    ...          ~3x (S₃ input)
1     ajk,bj->abk       ...    ...          ~2x (S₂ input)
2     abk,ck->abc       ...    ...          1x  (dense)
────  ────────────────  ─────  ───────────  ────────────────
Total                   ...    ...          ...x speedup

Each step's input_symmetries and output_symmetry fields in StepInfo describe exactly which symmetry was present and how it reduced cost. See PathInfo / StepInfo API for the full dataclass reference.


Common pitfalls

Symptom: SymmetryError: Tensor not symmetric along axes (0, 1): max deviation = 0.5 (tolerance: atol=1e-06, rtol=1e-05)

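Fix: The data must be symmetric within tolerance. Check the computation that produced it: a large deviation usually means the tensor is genuinely not symmetric. If the asymmetry is only floating-point round-off, use a more numerically stable method, or symmetrize explicitly before wrapping (for a matrix, (data + data.T) / 2).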

Symptom: Expected SymmetricTensor output but got plain ndarray

Fix: Symmetry propagation is conservative — it never claims symmetry that isn't guaranteed. Operations like matmul (S @ T) always strip symmetry. Slicing and reductions strip symmetry only when the affected dims belong to a symmetric group and the operation leaves fewer than 2 dims in that group. If you know the result is symmetric, re-wrap with me.as_symmetric().

Symptom: SymmetryLossWarning appearing unexpectedly

Fix: This warning tells you symmetry metadata was lost. Check whether the result is still symmetric — if so, re-tag with as_symmetric(). To suppress: me.configure(symmetry_warnings=False) or use warnings.filterwarnings("ignore", category=me.SymmetryLossWarning).