
ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKeKVvH67TN+aMN0jjau8SCHQo5XcniG73fxKc32aF6I ezyang@ezyang-mac
https://github.com/pytorch/pytorch/issues/163449
https://github.com/pytorch/pytorch/issues/163457
https://github.com/pytorch/pytorch/issues/163420
https://github.com/pytorch/pytorch/issues/163300
https://github.com/pytorch/pytorch/issues/162723
import torch
import unittest
from torch import Tensor
from torch.distributed.tensor import (
    DTensor,
    DeviceMesh,
    distribute_tensor,
    init_device_mesh,
    Partial,
    Replicate,
    Shard,
)

# Shorthand placements used throughout this snippet.
R = Replicate()
S = Shard

# arange_nd and mesh (a DeviceMesh with dims named "m", "n", "k") are assumed
# to be defined elsewhere in the full gist; the preview is truncated.
x = DTensor.from_local(arange_nd(15), mesh["m", "n", "k"], [R, R, R])
# Eliminate M
x = DTensor.from_local(x.redistribute(placements=[R, R, S(0)]).to_local(), mesh["m", "n"]) # shard K
x = DTensor.from_local(x.redistribute(placements=[R, S(0)]).to_local(), mesh["m"]) # shard N
x = x.redistribute(placements=[S(0)]).to_local() # shard M
x = DTensor.from_local(x, mesh["n"], [S(0)]).redistribute(placements=[R]) # unshard N
x = DTensor.from_local(x.to_local(), mesh["n", "k"], [R, S(0)]).redistribute(placements=[R, R]) # unshard K
# Eliminate N
x = DTensor.from_local(x.redistribute(placements=[R, S(0)]).to_local(), mesh["n"]) # shard K
x = x.redistribute(placements=[S(0)]).to_local() # shard N
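For reference, a minimal sketch of the setup the snippet above assumes; the single-process fake process group, the 2 x 2 x 2 mesh shape, and the arange_nd helper are guesses for illustration, not the gist's code:

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.testing._internal.distributed.fake_pg import FakeStore

def arange_nd(n):
    # Guess at the helper: n consecutive values as a float tensor.
    return torch.arange(n, dtype=torch.float32)

# Single-process stand-in for the mesh sliced above as mesh["m", "n", "k"];
# the shape is a placeholder.
dist.init_process_group("fake", store=FakeStore(), rank=0, world_size=8)
mesh = init_device_mesh("cpu", (2, 2, 2), mesh_dim_names=("m", "n", "k"))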
ezyang / gist:15791ae363900f42c704c09ca34346e3
Created October 29, 2025 19:02
Matrix-of-matrices tensor render
def render(tensor, cell_width=None):
    """
    Print a tensor following the matrix-of-matrices algorithm.

    Args:
        tensor: A tensor-like object with .shape attribute and indexing
        cell_width: Width for each cell (calculated globally if None)

    Returns:
        None; the rendering is printed to stdout.
    """
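The preview cuts off before the function body. As a rough illustration only, here is a minimal sketch of one way a matrix-of-matrices render can work, treating a 4-D tensor as an outer grid of inner matrices; this is a guess at the idea, not the gist's implementation:

import torch

def render_blocks(tensor, cell_width=None):
    """Render a 4-D tensor (G0, G1, R, C) as a G0 x G1 grid of R x C blocks."""
    g0, g1, r, c = tensor.shape
    if cell_width is None:
        cell_width = max(len(f"{v:g}") for v in tensor.reshape(-1).tolist())
    for i in range(g0):
        rows = []
        for ri in range(r):
            cells = [
                " ".join(f"{tensor[i, j, ri, ci].item():>{cell_width}g}" for ci in range(c))
                for j in range(g1)
            ]
            rows.append(" | ".join(cells))
        print("\n".join(rows))
        if i < g0 - 1:
            print("-" * len(rows[0]))

render_blocks(torch.arange(24.0).reshape(2, 2, 2, 3))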
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate
from torch.distributed.tensor.placement_types import Shard
from torch.testing._internal.distributed.fake_pg import FakeStore

world_size = 4
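The preview stops at world_size = 4; a minimal sketch of how this FakeStore-based single-process setup typically continues (the mesh shape, example tensor, and print are assumptions, not the gist's code):

store = FakeStore()
dist.init_process_group("fake", store=store, rank=0, world_size=world_size)

# A 1-D mesh over the fake world; DTensor ops run without real communication.
mesh = init_device_mesh("cpu", (world_size,))

# Build a DTensor from a per-rank local shard and inspect the global view.
local = torch.randn(8, 8)
dt = DTensor.from_local(local, mesh, [Shard(0)])
print(dt.placements, dt.shape)  # global dim 0 is world_size * 8

dist.destroy_process_group()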
(verl) [ezyang@devgpu086.cco2 ~/local/verl/verl/examples/ppo_trainer (main)]$ pp bash run_deepseek7b_llm.sh
+ python3 -m verl.trainer.main_ppo algorithm.adv_estimator=gae data.train_files=/home/ezyang/local/data/gsm8k/train.parquet data.val_files=/home/ezyang/local/data/gsm8k/test.parquet data.train_batch_size=1024 data.max_prompt_length=512 data.max_response_length=512 data.filter_overlong_prompts=True data.truncation=error actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat actor_rollout_ref.actor.optim.lr=1e-6 actor_rollout_ref.model.use_remove_padding=True actor_rollout_ref.actor.ppo_mini_batch_size=256 actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 actor_rollout_ref.actor.fsdp_config.param_offload=False actor_rollout_ref.actor.fsdp_config.optimizer_offload=False actor_rollout_ref.actor.use_kl_loss=False actor_rollout_ref.model.enable_gradient_checkpointing=True actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 actor_rollout_ref.rollout.tensor_model_parallel_size=4 actor_ro

I really like Scuba (Meta's internal real-time database system). The distributed, real-time database part of Scuba is quite difficult (and expensive) to replicate, but I also really like Scuba's UI for doing queries, and I have found myself wishing I had access to it even for "small" databases, e.g., a sqlite dataset I want to explore.

Here's a screenshot of this UI from https://research.facebook.com/publications/scuba-diving-into-data-at-facebook/:

[screenshot: the Scuba query UI]

Pivotal ideas:

  • Time series by default. In the dedicated "time series" view, there are many features specifically oriented towards tables that represent events occurring over time: the start, end, compare, aggregate, and granularity fields all specially privilege the timestamp field. In fact, you can't log events to Scuba's backing data store without a timestamp; they always come with one. (Scuba a
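To make the "time series by default" idea concrete, here is a small sketch of the kind of bucketed query that view boils down to, against a made-up SQLite events table (table name, columns, and values are illustrative only):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE events (ts INTEGER NOT NULL, metric TEXT, value REAL);
INSERT INTO events VALUES
  (1700000000, 'latency_ms', 12.0),
  (1700000030, 'latency_ms', 18.0),
  (1700000090, 'latency_ms', 25.0);
""")

# Start, end, aggregate, and granularity all hang off the timestamp column,
# mirroring how Scuba's time series view privileges the ts field.
start, end, granularity = 1700000000, 1700000120, 60  # seconds
rows = conn.execute(
    """
    SELECT (ts / ?) * ? AS bucket, COUNT(*) AS n, AVG(value) AS avg_value
    FROM events
    WHERE ts >= ? AND ts < ? AND metric = 'latency_ms'
    GROUP BY bucket
    ORDER BY bucket
    """,
    (granularity, granularity, start, end),
).fetchall()
print(rows)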
<think>
I need to add a progress bar to the application that shows progress reading the compressed dump files. The bar should display:
1. A visual progress bar
2. Bytes/sec processing rate
3. Bytes processed and percent completion
4. Update every 100ms
Let's analyze the code to understand where to integrate the progress bar:
1. The application processes Reddit data (comments and submissions) from compressed JSON files
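The preview cuts off here. As a sketch of what the four requirements above amount to, here is a minimal Python progress reporter wrapped around a compressed-file read; the real dumps are presumably zstd, but stdlib gzip is used to keep the sketch dependency-free, and the path in the usage comment is hypothetical:

import gzip
import os
import sys
import time

def read_with_progress(path, chunk_size=1 << 20, interval=0.1):
    """Stream a gzip file, reporting progress over the compressed bytes."""
    total = os.path.getsize(path)
    start = time.monotonic()
    last = 0.0
    with open(path, "rb") as raw, gzip.GzipFile(fileobj=raw) as f:
        while chunk := f.read(chunk_size):
            yield chunk
            now = time.monotonic()
            if now - last >= interval:  # update at most every 100ms
                last = now
                done = raw.tell()  # compressed bytes consumed (approximate: GzipFile reads ahead)
                rate = done / max(now - start, 1e-9)
                frac = done / total if total else 1.0
                bar = "#" * int(30 * frac) + "-" * (30 - int(30 * frac))
                sys.stderr.write(
                    f"\r[{bar}] {done}/{total} bytes ({frac:6.1%}) {rate / 1e6:6.2f} MB/s"
                )
                sys.stderr.flush()
    sys.stderr.write("\n")

# Usage (hypothetical file name):
# for chunk in read_with_progress("RC_2024-01.json.gz"):
#     handle(chunk)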