Symmetric Tensors
First-class symmetric tensor support for automatic FLOP cost reductions.
SymmetricTensor is an ndarray subclass that carries symmetry metadata
through operations. When passed to any mechestim operation, the cost is
automatically reduced based on the number of unique elements.
See Exploit Symmetry Savings for usage patterns.
mechestim._symmetric
Symmetric tensor support: SymmetryInfo, SymmetricTensor, as_symmetric.
SymmetricTensor
Bases: ndarray
An ndarray that carries symmetry metadata.
Do not instantiate directly; use as_symmetric instead.
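A minimal sketch of how an ndarray subclass can carry metadata through views, assuming NumPy's standard `__array_finalize__` hook. The class name `TaggedArray` and the `_symmetric_axes` attribute are hypothetical stand-ins, not mechestim's implementation, which also has to adjust the metadata on reductions and slicing.

```python
import numpy as np


class TaggedArray(np.ndarray):
    """Hypothetical ndarray subclass that remembers its symmetry groups."""

    def __new__(cls, data, symmetric_axes=None):
        # View the input as this subclass, then attach the metadata.
        obj = np.asarray(data).view(cls)
        obj._symmetric_axes = symmetric_axes
        return obj

    def __array_finalize__(self, obj):
        # NumPy calls this for views and templates; copy metadata if present.
        if obj is None:
            return
        self._symmetric_axes = getattr(obj, "_symmetric_axes", None)

    @property
    def symmetric_axes(self):
        return self._symmetric_axes


a = TaggedArray(np.eye(3), symmetric_axes=[(0, 1)])
view = a.view()        # a view of the subclass keeps the metadata
plain = np.asarray(a)  # np.asarray returns a base-class ndarray
```

Blindly inheriting the metadata is only safe for operations that leave the axes untouched, which is why the module documents separate propagation helpers for reductions and slicing.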
Source code in src/mechestim/_symmetric.py
symmetric_axes
property
Symmetry groups carried by this tensor.
symmetry_info
property
Return a SymmetryInfo for this tensor.
is_symmetric(symmetric_axes=None, *, atol=1e-06, rtol=1e-05)
Check whether the data satisfies the given (or carried) symmetry.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| symmetric_axes | tuple or list of tuples | Axes to check. If None, checks the axes already carried by this tensor. | None |
| atol | float | Absolute tolerance for the symmetry comparison. | 1e-06 |
| rtol | float | Relative tolerance for the symmetry comparison. | 1e-05 |
Source code in src/mechestim/_symmetric.py
SymmetryInfo
dataclass
Metadata about tensor symmetry groups.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| symmetric_axes | list of tuple of int | Groups of dimension indices that are symmetric under permutation. | required |
| shape | tuple of int | Full tensor shape. | required |
Source code in src/mechestim/_symmetric.py
symmetry_factor
property
Product of factorial(len(group)) for each group.
unique_elements
property
Number of unique elements accounting for symmetry.
For each symmetric group of k dims each of size n, the count is C(n + k - 1, k). Free (non-symmetric) dims contribute their full size. The total is the product.
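The two counting rules above can be checked with a short standalone sketch. The function names mirror the property names but are hypothetical free functions written for illustration, not the library's code.

```python
from math import comb, factorial, prod


def unique_elements(shape, symmetric_axes):
    """Product of C(n + k - 1, k) per group and the full size of free dims."""
    grouped = {ax for group in symmetric_axes for ax in group}
    count = 1
    for group in symmetric_axes:
        n = shape[group[0]]  # all dims in a group share the same size
        k = len(group)
        count *= comb(n + k - 1, k)
    for ax, size in enumerate(shape):
        if ax not in grouped:
            count *= size  # free (non-symmetric) dim
    return count


def symmetry_factor(symmetric_axes):
    """Product of factorial(len(group)) over all groups."""
    return prod(factorial(len(group)) for group in symmetric_axes)


# A 4 x 4 x 5 tensor symmetric in its first two axes:
# C(4 + 2 - 1, 2) = 10 unique pairs, times 5 for the free axis.
n_unique = unique_elements((4, 4, 5), [(0, 1)])
factor = symmetry_factor([(0, 1)])
```

For this shape the dense tensor has 80 elements, while only 50 are unique.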
as_symmetric(data, symmetric_axes)
Wrap data as a SymmetricTensor after validating symmetry.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | ndarray | The tensor data. | required |
| symmetric_axes | tuple of int or list of tuple of int | A single symmetry group, or a list of groups. | required |
Returns:
| Type | Description |
|---|---|
| SymmetricTensor | |
Raises:
| Type | Description |
|---|---|
| SymmetryError | If the data does not satisfy the claimed symmetry. |
Source code in src/mechestim/_symmetric.py
intersect_symmetry(dims_a, dims_b, shape_a, shape_b, output_shape)
Intersect symmetry groups for binary ops, accounting for broadcasting.
Source code in src/mechestim/_symmetric.py
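The core idea behind intersecting symmetry groups can be sketched for the simplest case: two operands of the same rank with no broadcasting. The helper `intersect_groups` below is hypothetical and deliberately ignores the shape arguments that the real function takes. It relies on the fact that an elementwise result is symmetric only under axis permutations that both inputs share.

```python
def intersect_groups(groups_a, groups_b):
    """Keep only axis sets that are symmetric in both operands.

    Simplified sketch: assumes equal rank and no broadcasting.
    """
    result = []
    for group_a in groups_a:
        for group_b in groups_b:
            common = tuple(sorted(set(group_a) & set(group_b)))
            if len(common) >= 2:  # a group needs at least two axes
                result.append(common)
    return result


# A is symmetric in all three axes, B only in axes 1 and 2.
shared = intersect_groups([(0, 1, 2)], [(1, 2)])

# Groups that overlap in a single axis leave no symmetry behind.
nothing = intersect_groups([(0, 1)], [(1, 2)])
```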
is_symmetric(data, symmetric_axes, *, atol=1e-06, rtol=1e-05)
Check whether data is symmetric along the given axes.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | ndarray | The tensor data. | required |
| symmetric_axes | tuple of int or list of tuple of int | A single symmetry group, or a list of groups. | required |
| atol | float | Absolute tolerance for the symmetry comparison. | 1e-06 |
| rtol | float | Relative tolerance for the symmetry comparison. | 1e-05 |
Returns:
| Type | Description |
|---|---|
| bool | |
Source code in src/mechestim/_symmetric.py
propagate_symmetry_reduce(symmetric_axes, ndim, axis, keepdims=False)
Compute new symmetry groups after a reduction.
Returns None if no symmetry survives.
Source code in src/mechestim/_symmetric.py
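One plausible way to propagate groups through a reduction over a single axis is sketched below. The helper `reduce_groups` is hypothetical and handles only an integer `axis`; the library's own logic may differ. The reduced axis leaves its group, the remaining axes are renumbered unless `keepdims` is set, and groups with fewer than two axes are dropped.

```python
def reduce_groups(symmetric_axes, ndim, axis, keepdims=False):
    """Return surviving symmetry groups, or None if none survive."""
    axis = axis % ndim  # normalise negative axes
    new_groups = []
    for group in symmetric_axes:
        kept = [ax for ax in group if ax != axis]
        if not keepdims:
            # Axes after the reduced one shift down by one.
            kept = [ax - 1 if ax > axis else ax for ax in kept]
        if len(kept) >= 2:
            new_groups.append(tuple(kept))
    return new_groups or None


# Summing a fully symmetric rank-3 tensor over its last axis
# leaves a symmetric matrix.
after_sum = reduce_groups([(0, 1, 2)], ndim=3, axis=2)

# Summing a symmetric matrix over one of its axes leaves a plain vector.
no_symmetry = reduce_groups([(0, 1)], ndim=2, axis=0)
```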
propagate_symmetry_slice(symmetric_axes, shape, key)
Compute new symmetry groups after __getitem__(key).
Returns None if no symmetry survives (caller should return plain ndarray).
Source code in src/mechestim/_symmetric.py
validate_symmetry(data, symmetric_axes)
Validate that data has the claimed symmetry.
For each group, checks that all dims have equal sizes and that all pairwise transpositions are satisfied within tolerance.
Raises:
| Type | Description |
|---|---|
| SymmetryError | If the data is not symmetric along the claimed axes. |
Source code in src/mechestim/_symmetric.py
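The check described above (equal sizes within a group, then every pairwise transposition) can be sketched with plain NumPy. The helper `check_symmetry` is hypothetical, returns a bool instead of raising SymmetryError, and assumes `np.allclose` as the comparison, which the source does not confirm.

```python
from itertools import combinations

import numpy as np


def check_symmetry(data, symmetric_axes, atol=1e-6, rtol=1e-5):
    """Return True if data is invariant under swapping axes in each group."""
    for group in symmetric_axes:
        # All dims in a group must have the same size.
        if len({data.shape[ax] for ax in group}) != 1:
            return False
        # Every pairwise transposition must reproduce the data.
        for i, j in combinations(group, 2):
            swapped = np.swapaxes(data, i, j)
            if not np.allclose(data, swapped, atol=atol, rtol=rtol):
                return False
    return True


sym = np.array([[1.0, 2.0], [2.0, 3.0]])
not_sym = np.arange(4.0).reshape(2, 2)

ok = check_symmetry(sym, [(0, 1)])
bad = check_symmetry(not_sym, [(0, 1)])
```

Checking all pairwise swaps is sufficient because transpositions generate every permutation of the group.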
PathInfo
Contraction path with per-step diagnostics. Returned by me.einsum_path().
mechestim.PathInfo
dataclass
Information about a contraction path with per-step symmetry diagnostics.
Source code in src/mechestim/_opt_einsum/_contract.py
input_subscripts = ''
class-attribute
instance-attribute
Comma-separated input subscripts, e.g. "ij,jk,kl".
largest_intermediate
instance-attribute
Number of elements in the largest intermediate tensor.
naive_cost
instance-attribute
Naive (single-step) FLOP cost (opt_einsum convention with op_factor).
opt_cost
property
Legacy: opt_einsum-style cost (using flop_count with op_factor).
optimized_cost
instance-attribute
Sum of per-step costs (opt_einsum convention with op_factor).
output_subscript = ''
class-attribute
instance-attribute
Output subscript, e.g. "il".
path
instance-attribute
The optimized contraction path (list of index-tuples).
size_dict = field(default_factory=dict)
class-attribute
instance-attribute
Mapping from index label to dimension size.
speedup
instance-attribute
naive_cost / optimized_cost.
steps
instance-attribute
Per-step diagnostics.
| Field | Type | Description |
|---|---|---|
| path | list[tuple[int, ...]] | Sequence of contraction index groups |
| steps | list[StepInfo] | Per-step diagnostics |
| naive_cost | int | FLOP cost without path optimization |
| optimized_cost | int | FLOP cost along the optimal path |
| largest_intermediate | int | Max number of elements in any intermediate tensor |
| speedup | float | naive_cost / optimized_cost |
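To illustrate what a contraction path looks like, the sketch below uses NumPy's own `np.einsum_path` rather than `me.einsum_path()`. NumPy returns a list and a text report, not a PathInfo, so this only shows the shape of the `path` field: a sequence of operand-index tuples, one per pairwise contraction. The array shapes are arbitrary choices for the example.

```python
import numpy as np

a = np.ones((8, 64))
b = np.ones((64, 8))
c = np.ones((8, 64))

# NumPy prefixes the path with the literal string 'einsum_path'.
path, report = np.einsum_path("ij,jk,kl->il", a, b, c, optimize="optimal")
steps = path[1:]  # the contraction steps themselves

print(steps)   # one tuple of operand positions per pairwise step
print(report)  # NumPy's summary of naive vs optimized FLOP estimates
```

Three operands always need two pairwise contractions, so `steps` has two entries here.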
StepInfo
Per-step contraction info within a PathInfo. Each step represents one
pairwise contraction along the optimal path.
mechestim.StepInfo
dataclass
Per-step diagnostics for a contraction path.
Source code in src/mechestim/_opt_einsum/_contract.py
blas_type = False
class-attribute
instance-attribute
BLAS classification for this step (e.g. 'GEMM', 'SYMM', False).
dense_flop_cost
instance-attribute
FLOP cost without symmetry (opt_einsum convention: includes op_factor).
flop_cost
instance-attribute
Symmetry-aware FLOP cost (opt_einsum convention: includes op_factor).
input_shapes
instance-attribute
Shapes of the input operands for this step.
input_symmetries
instance-attribute
IndexSymmetry for each input in this step.
output_shape
instance-attribute
Shape of the output operand for this step.
output_symmetry
instance-attribute
IndexSymmetry of the output, or None.
subscript
instance-attribute
Einsum subscript for this step, e.g. "ijk,ai->ajk".
symmetry_savings
instance-attribute
Fraction saved: 1 - (flop_cost / dense_flop_cost). Zero when no symmetry.
| Field | Type | Description |
|---|---|---|
| subscript | str | Einsum subscript for this pairwise step (e.g., 'ijk,ai->ajk') |
| flop_cost | int | Symmetry-aware FLOP cost of this step |
| dense_flop_cost | int | FLOP cost without symmetry savings |
| symmetry_savings | float | 1 - (flop_cost / dense_flop_cost): fraction of cost saved by symmetry |
| input_symmetries | list[IndexSymmetry \| None] | Symmetry of each input to this step |
| output_symmetry | IndexSymmetry \| None | Symmetry of the step's output (propagated to next step) |
| input_shapes | list[tuple[int, ...]] | Shapes of input operands |
| output_shape | tuple[int, ...] | Shape of the output tensor |
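The relationship between the two cost fields and symmetry_savings can be worked through with made-up numbers. `FakeStep` below is a hypothetical stand-in for StepInfo that keeps only the two cost fields, and the costs are invented for the example.

```python
from dataclasses import dataclass


@dataclass
class FakeStep:
    """Hypothetical stand-in holding only the cost fields of a step."""

    flop_cost: int
    dense_flop_cost: int

    @property
    def symmetry_savings(self) -> float:
        # Fraction of the dense cost that symmetry removes.
        return 1 - (self.flop_cost / self.dense_flop_cost)


steps = [
    FakeStep(flop_cost=500, dense_flop_cost=1000),    # symmetric input
    FakeStep(flop_cost=2000, dense_flop_cost=2000),   # no symmetry
]

savings = [step.symmetry_savings for step in steps]
total_cost = sum(step.flop_cost for step in steps)  # sum of per-step costs
```

A step with no exploitable symmetry has equal costs and therefore zero savings.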