@DanFu09 Thanks for open-sourcing the code!
I see that in your previous fly repo (https://github.com/HazyResearch/fly), you used `cast_inputs=torch.float16` for `BlockdiagButterflyMultiply`, but changed it to `torch.bfloat16` here. Is there a specific reason for the switch (e.g., fp16 training not converging because of fp16's narrower dynamic range)?
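For reference, this is the pattern I mean (a minimal sketch of the decorator usage, not the exact code from either repo):

```python
import torch

class BlockdiagButterflyMultiply(torch.autograd.Function):
    @staticmethod
    # fly used cast_inputs=torch.float16; this repo casts to bf16 instead
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.bfloat16)
    def forward(ctx, x, w1_bfly, w2_bfly):
        ...
```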
Also, are there opportunities to fuse the two `bmm` operations into a single kernel? It's hard to tell exactly which kernel torch dispatches to, though.
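To make the question concrete, here is roughly the two-`bmm` structure I have in mind (a simplified reference sketch of a block-diagonal butterfly multiply; the shapes and the permutation between the two multiplies are illustrative, not the repo's exact code):

```python
import torch

def blockdiag_butterfly_multiply_ref(x, w1_bfly, w2_bfly):
    # x: (batch, n) with n = k * p
    # w1_bfly: (k, q, p) -- k diagonal blocks of shape (q, p)
    # w2_bfly: (l, s, r) -- l diagonal blocks of shape (s, r), with k * q == l * r
    batch, n = x.shape
    k, q, p = w1_bfly.shape
    l, s, r = w2_bfly.shape
    assert n == k * p and k * q == l * r
    # bmm #1: apply the k blocks of w1 to the k chunks of x
    out1 = torch.bmm(x.reshape(batch, k, p).transpose(0, 1),  # (k, batch, p)
                     w1_bfly.transpose(-1, -2))               # -> (k, batch, q)
    # permutation between the two block-diagonal multiplies
    out1 = out1.transpose(0, 1).reshape(batch, r, l).transpose(-1, -2)  # (batch, l, r)
    # bmm #2: apply the l blocks of w2
    out2 = torch.bmm(out1.transpose(0, 1),        # (l, batch, r)
                     w2_bfly.transpose(-1, -2))   # -> (l, batch, s)
    return out2.transpose(0, 1).reshape(batch, l * s)
```

Since the only thing between the two `bmm`s is a reshape/transpose, it seems like a fused kernel (e.g., in Triton) could avoid materializing the intermediate. For identifying the current kernels, `torch.profiler` (or nsys) should show which cuBLAS batched-GEMM each `bmm` dispatches to.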