https://github.com/meta-pytorch/spmd_types/blob/main/DESIGN....
I'd also recommend the series of posts by Edward Yang:
- https://blog.ezyang.com/2026/01/global-vs-local-spmd/
- https://blog.ezyang.com/2026/01/jax-sharding-type-system/
- https://blog.ezyang.com/2026/02/dtensor-erasure/
- https://blog.ezyang.com/2026/02/replicate-forwards-partial-b...
Also interesting:
https://github.com/meta-pytorch/autoparallel
> AutoParallel is a PyTorch library that automatically shards and parallelizes models for distributed training. Given a model and a device mesh, it uses linear programming to find an optimal sharding strategy (FSDP, tensor parallelism, or a mix) and applies it — no manual parallelism code required.
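To get a feel for the idea, here is a toy sketch (in plain Python, not AutoParallel's actual API) of what "pick an optimal sharding strategy under a constraint" means: among strategies that fit a memory budget, choose the one with the lowest communication cost. The strategy names and cost numbers are made up for illustration; AutoParallel solves a much richer version of this per-operator with a linear program.

```python
# Toy stand-in for AutoParallel's optimization step: all names and
# numbers below are hypothetical, for illustration only.

# Rough per-strategy costs (arbitrary units) for some model on a mesh:
# replicate keeps full weights everywhere (high memory, no comm);
# FSDP and tensor parallelism shard weights but add communication.
STRATEGIES = {
    "replicate":       {"memory": 8.0, "comm": 0.0},
    "fsdp":            {"memory": 1.0, "comm": 2.0},
    "tensor_parallel": {"memory": 1.0, "comm": 3.0},
}

def choose_strategy(memory_budget: float) -> str:
    """Brute-force stand-in for the LP solve: among strategies that
    fit the memory budget, minimize communication cost."""
    feasible = {name: c for name, c in STRATEGIES.items()
                if c["memory"] <= memory_budget}
    return min(feasible, key=lambda name: feasible[name]["comm"])

print(choose_strategy(memory_budget=2.0))   # fsdp: fits, cheapest comm
print(choose_strategy(memory_budget=10.0))  # replicate: fits, zero comm
```

The real library makes this decision jointly across the whole graph, where the choice for one operator changes the resharding cost of its neighbors, which is what makes it an optimization problem rather than a per-layer lookup.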