This is a brain dump of what is missing from torchao.float8 to support training with rowwise scaling, to help anyone who wants to jump in and build it.
Already done:
1. torch._scaled_mm supports rowwise scaling.
2. Inductor supports rowwise-scaled gemms in max-autotune mode (I haven't personally tested this yet).
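For reference, the math a rowwise-scaled gemm computes can be emulated in plain Python. This is a minimal sketch, not the torch._scaled_mm kernel. `quantize_e4m3` and `rowwise_scaled_mm` are hypothetical helpers. The sketch makes three assumptions: values are rounded to the nearest float8 e4m3fn value and saturated at ±448; each row of `a` and each column of `b` gets its own scale, computed over the contraction dimension; and the reciprocal (dequantization) scales are applied to the high-precision accumulator.

```python
import math

E4M3_MAX = 448.0  # largest finite float8 e4m3fn value


def quantize_e4m3(x: float) -> float:
    """Hypothetical emulation: round x to the nearest e4m3fn value, saturating at +/-448."""
    if x == 0.0 or math.isnan(x):
        return x
    sign = -1.0 if x < 0 else 1.0
    mag = min(abs(x), E4M3_MAX)
    if mag < 2.0 ** -6:
        step = 2.0 ** -9  # subnormal spacing
    else:
        _, e = math.frexp(mag)  # mag = m * 2**e with 0.5 <= m < 1
        step = math.ldexp(1.0, e - 4)  # 3 mantissa bits -> 8 steps per binade
    # round() rounds half to even, i.e. round-to-nearest-even on the mantissa
    return sign * min(round(mag / step) * step, E4M3_MAX)


def rowwise_scaled_mm(a, b):
    """Hypothetical emulation of out = a @ b with one scale per row of a and per column of b."""
    m, k, n = len(a), len(b), len(b[0])
    # Each row of a: amax over the K (contraction) dim maps to E4M3_MAX.
    a_scale = [E4M3_MAX / max(max(abs(v) for v in row), 1e-12) for row in a]
    # Each column of b: amax again taken over the K dim.
    b_scale = [E4M3_MAX / max(max(abs(b[kk][j]) for kk in range(k)), 1e-12) for j in range(n)]
    a_q = [[quantize_e4m3(a[i][kk] * a_scale[i]) for kk in range(k)] for i in range(m)]
    b_q = [[quantize_e4m3(b[kk][j] * b_scale[j]) for j in range(n)] for kk in range(k)]
    # Accumulate in high precision, then apply the reciprocal scales per output element.
    return [
        [sum(a_q[i][kk] * b_q[kk][j] for kk in range(k)) / (a_scale[i] * b_scale[j])
         for j in range(n)]
        for i in range(m)
    ]


a = [[1.0, 2.0, -3.0], [0.5, 4.0, 1.5]]
b = [[1.0, -2.0], [0.25, 3.0], [2.0, 0.5]]
approx = rowwise_scaled_mm(a, b)  # close to the exact a @ b, up to float8 rounding
```

Because every scale is constant along the contraction dimension, it factors out of each dot product. That is what lets the kernel apply the scales once per output element.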
Needed:
1. Float8Tensor needs to work with rowwise scales. We had an unlanded PR on float8_experimental that did this ([wip] add axiswise granularity to Float8Tensor, meta-pytorch/float8_experimental#352); we just never found the time to land it. You can reuse that PR or do something similar. Note that #819 ([Float8Quant] Add rowwise scaling option to float8 dyanmic quant) landed recently and added float8 rowwise scaling for inference, so being consistent with it where applicable would be nice.
2. Float8Linear needs to be configurable with rowwise scales for each argument, and the scaling needs to respect that config, validated by tests and benchmarks. This would require changes to torchao.float8.config.py and torchao.float8.float8_linear.py.
3. After (1) and (2), we could make each gemm configurable, so that some of them can be left in high precision.
4. Performance fixes throughout torchao.float8 and inductor, if needed, depending on how well inductor generates the scaling code.
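Part of why (1) is more than a one-line change: a linear layer runs three gemms in training, and a rowwise scale only factors out of a dot product if it is constant along that gemm's contraction dimension. The sketch below is plain Python and tracks axes only. `GEMMS`, `amax_along` and `reduction_dims_per_tensor` are hypothetical names for illustration, not torchao APIs. It assumes `output = input @ weight.t()` with `input` of shape (M, K), `weight` of shape (N, K) and `grad_output` of shape (M, N). For each operand, it records the dimension its amax must be reduced over.

```python
# Hypothetical table: for each gemm of a linear layer, the dimension of each
# 2D operand that its rowwise scale must be reduced over (the contraction dim).
# Shapes: input (M, K), weight (N, K), grad_output (M, N).
GEMMS = {
    # output (M, N) = input (M, K) @ weight.t() (K, N): contract over K
    "output": {"input": 1, "weight": 1},
    # grad_input (M, K) = grad_output (M, N) @ weight (N, K): contract over N
    "grad_input": {"grad_output": 1, "weight": 0},
    # grad_weight (N, K) = grad_output.t() (N, M) @ input (M, K): contract over M
    "grad_weight": {"grad_output": 0, "input": 0},
}


def amax_along(x, dim):
    """Per-slice amax of a 2D list, reducing over `dim` (1: one value per row, 0: one per column)."""
    if dim == 1:
        return [max(abs(v) for v in row) for row in x]
    return [max(abs(row[j]) for row in x) for j in range(len(x[0]))]


def reduction_dims_per_tensor(gemms):
    """Collect, for every tensor, the set of dims it must be scaled along."""
    dims = {}
    for operands in gemms.values():
        for name, dim in operands.items():
            dims.setdefault(name, set()).add(dim)
    return dims


needed = reduction_dims_per_tensor(GEMMS)
```

Every tensor ends up needing scales along both of its axes. With a tensorwise scale, the same float8 data works for both gemms a tensor appears in, because a single scale is constant along every axis. With rowwise scales, each of `input`, `weight` and `grad_output` would need two differently scaled float8 versions.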
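Items (2) and (3) could take roughly the shape of the config below. Every name here (`Granularity`, `GemmCast`, `RowwiseTrainingConfig`, `float8_gemms`) is hypothetical and does not exist in torchao.float8.config. The sketch assumes the float8 gemm kernel requires both operands of a gemm to use the same granularity, which is why mismatched operands are rejected. Setting a gemm to `None` means leaving it in high precision.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional


class Granularity(Enum):  # hypothetical, not a torchao enum
    TENSORWISE = "tensorwise"  # one scale per tensor
    ROWWISE = "rowwise"  # one scale per row/column, along the contraction dim


@dataclass(frozen=True)
class GemmCast:
    """Hypothetical per-gemm setting with one granularity per operand."""
    a: Granularity = Granularity.TENSORWISE
    b: Granularity = Granularity.TENSORWISE

    def __post_init__(self):
        # Assumption: the kernel does not support mixing granularities within one gemm.
        if self.a is not self.b:
            raise ValueError(f"mixed granularities {self.a} / {self.b} not supported")


@dataclass(frozen=True)
class RowwiseTrainingConfig:
    """Hypothetical config; None for a gemm means run it in high precision."""
    output: Optional[GemmCast] = field(default_factory=GemmCast)
    grad_input: Optional[GemmCast] = field(default_factory=GemmCast)
    grad_weight: Optional[GemmCast] = field(default_factory=GemmCast)

    def float8_gemms(self):
        """Names of the gemms that run in float8, in forward/backward order."""
        return [name for name in ("output", "grad_input", "grad_weight")
                if getattr(self, name) is not None]


rowwise = GemmCast(a=Granularity.ROWWISE, b=Granularity.ROWWISE)
cfg = RowwiseTrainingConfig(output=rowwise, grad_input=rowwise, grad_weight=None)
```

Keeping the per-gemm settings as separate fields mirrors the three gemms of a linear layer. This lets tests and benchmarks sweep each one independently.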