
BladeDISC 0.3.0: Announce PyTorch 2.0 Compilation Support

@tanyokwok released this on 08 Dec 02:15

In the last release, v0.2.0, we mainly delivered the GPU AStitch optimization. Now we are proud to announce the release of BladeDISC v0.3.0.

Highlights

We have done the following in the last six months:

  • Added initial support for PyTorch 2.0 compilation;
  • Contributed TorchToMHLO to Torch-MLIR, together with the ByteDance AML team;
  • Added quantization compilation;
  • Added compilation support for the Alibaba ARM-based Yitian 710;
  • Improved memory-intensive kernel code generation on GPGPU;
  • Added shape constraint IR and optimizations.

PyTorch 2.0 and Dynamic Compilation

Over the past half year, to better support PyTorch dynamic compilation, we have:

  • Kept focusing on the features needed for PyTorch 2.0;
  • Collaborated with the Torch-MLIR community;
  • Refined the architecture of TorchBlade compilation.

TorchDynamo Compilation

One can now enable BladeDISC compilation in PyTorch 2.0 by changing just two lines:

import torch
import torch_blade  # one extra line to register the 'disc' backend
model = ...
compiled_model = torch.compile(model, backend='disc')
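
For a complete runnable example, the sketch below compiles a small toy model and runs inference with it; the model definition and input shapes are illustrative only and not taken from the release.

import torch
import torch_blade  # registers the 'disc' backend with TorchDynamo

# Toy model for illustration only.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 128),
    torch.nn.GELU(),
    torch.nn.Linear(128, 10),
).eval()

compiled_model = torch.compile(model, backend='disc')

with torch.no_grad():
    out = compiled_model(torch.randn(8, 64))  # compilation is triggered on the first call
print(out.shape)  # torch.Size([8, 10])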

TorchBenchmark

We added PyTorch Benchmark as a compass for the optimization quality and robustness of BladeDISC across various models. See the accompanying summary reports (with BlaDNN).

Torch-MLIR (MHLO) and Dynamic Shapes

BladeDISC is closely tied to the mlir-hlo project. Part of its building blocks, including the MHLO op definitions, TF-to-MHLO conversions, and some general-purpose passes, have been upstreamed to the mlir-hlo repository. In this release, the BladeDISC dev team cooperated with the community to add Torch-to-MHLO conversion to Torch-MLIR, with a particular focus on fully dynamic shape support. See the RFC: llvm/torch-mlir#999. We welcome community developers who are interested in joining this effort.

TorchBlade now converts PyTorch workloads to MHLO via Torch-MLIR and then compiles the MHLO modules with the BladeDISC compiler, as sketched below.
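
The snippet below is a rough sketch of the Torch-MLIR side of that flow, assuming the torch_mlir Python API of this period (the MHLO output type was later renamed to StableHLO); the toy module and shapes are illustrative only.

import torch
import torch_mlir  # Torch-MLIR Python bindings

# Toy module used only to illustrate the lowering path.
class MatMulAdd(torch.nn.Module):
    def forward(self, x, y):
        return torch.matmul(x, y) + 1.0

# Lower the module to the MHLO dialect; the resulting MLIR module is what
# the BladeDISC compiler consumes.
module = torch_mlir.compile(
    MatMulAdd(),
    [torch.randn(4, 8), torch.randn(8, 16)],
    output_type=torch_mlir.OutputType.MHLO,
)
print(module)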

PyTorch Compiled Training

This release introduces PyTorch compiled training based on PyTorch 2.0; you can find examples under BladeDISC/examples/PyTorch/Train. The feature is not yet stable and is under heavy development. Please keep watching if you are interested.
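
As a rough sketch of what such a training loop can look like (the model, data, and hyperparameters below are placeholders, not taken from the examples in the repository):

import torch
import torch_blade  # registers the 'disc' backend

# Placeholder model, loss, and optimizer for illustration.
model = torch.nn.Linear(32, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()

compiled_model = torch.compile(model, backend='disc')

for step in range(10):
    x = torch.randn(16, 32)
    target = torch.randint(0, 2, (16,))
    optimizer.zero_grad()
    loss = loss_fn(compiled_model(x), target)
    loss.backward()   # the backward pass also goes through the compiled graph
    optimizer.step()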

EasyCV/NLP Model Compilation

  • BEVFormer is a pure-vision model for autonomous driving. BladeDISC speeds it up by 1.42x in EasyCV.
  • This release also supports diffusion models: BladeDISC accelerates end-to-end inference by 3x in PAI-Diffusion (EasyNLP).

Quantization (Experimental)

We have completed a series of preliminary explorations on combining compilation and quantization, including early solutions and performance verification on multiple hardware platforms such as x86 and ARM. The following table summarizes the results.

Model       Shape   Device                              Before (PyTorch FP32)   After (INT8 + compilation)
bert-mini   8*64    g6r / Ampere Altra / 1 core         135.9 ms                39.6 ms
bert-mini   8*64    g8m / Yitian / 1 core               127.8 ms                31.1 ms
bert-mini   8*64    hfg7 / Cooper Lake 8369 / 1 core    37.5 ms                 21.5 ms

We will support more hardware (e.g., CUDA) and provide concrete examples of how to quantize PyTorch/TensorFlow models in the near future, and we will keep improving the inference performance of quantized models.
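
BladeDISC's own quantization workflow is not spelled out in these notes; as a placeholder until the concrete examples land, the sketch below uses stock PyTorch FX post-training quantization to illustrate the general int8-then-compile pipeline. Whether such a quantized graph can be fed to the 'disc' backend directly is an assumption.

import torch
import torch_blade  # registers the 'disc' backend
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

# Toy float model and calibration data, for illustration only.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(8, 64),)

# Standard PyTorch FX post-training quantization ('qnnpack' targets ARM).
qconfig_mapping = get_default_qconfig_mapping("qnnpack")
prepared = prepare_fx(model, qconfig_mapping, example_inputs)
for _ in range(16):
    prepared(torch.randn(8, 64))        # calibration passes
quantized = convert_fx(prepared)        # int8 model

# Assumption: hand the quantized module to the 'disc' backend for compilation.
compiled = torch.compile(quantized, backend='disc')
print(compiled(*example_inputs).shape)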

Improvement in Compilation

Alibaba ARM-Based Yitian 710

We have further improved support for ARM-based CPUs (especially Alibaba's Yitian 710):

  • Added support for BF16/int8 GEMM/Conv, making full use of the capabilities of the Yitian hardware;
  • Made a series of enhancements to the ARM Compute Library to solve usability issues in dynamic-shape and high-concurrency scenarios;
  • Improved the quality of CodeGen for memory-intensive operators, including Stitch-CPU's support for operator reshaping and op-duplication strategies.

Improvement on Mem-Intensive CodeGen

This release provides a series of in-depth optimizations for the code generation of memory-intensive computations on GPUs; they can bring up to a 2x performance gain on a single LayerNorm layer in inference scenarios. The feature can be enabled by setting the environment variable: export DISC_MEM_INTENSIVE_OPT_EXPERIMENTAL=true.
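
For illustration, one way to try this on a standalone LayerNorm is sketched below; the model and shapes are placeholders, and only the environment variable comes from these notes (here set via os.environ before importing torch_blade, mirroring the export command).

import os
os.environ["DISC_MEM_INTENSIVE_OPT_EXPERIMENTAL"] = "true"  # enable the new codegen path

import torch
import torch_blade  # registers the 'disc' backend

# Standalone LayerNorm with illustrative shapes.
layer_norm = torch.nn.LayerNorm(1024).cuda().eval()
compiled = torch.compile(layer_norm, backend='disc')

with torch.no_grad():
    out = compiled(torch.randn(32, 128, 1024, device='cuda'))
print(out.shape)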

Shape Constraint IR

We have completed the design and development of the shape constraint IR. Introducing shape constraints into the IR makes it convenient for us to fully exploit the structural shape constraints contained in the computation graph, which helps the optimization of dynamic shape compilation. You can refer to the design document if you are interested.

Custom Pattern Matching

We rebuilt the process of integrating a custom library call into BladeDISC on top of PDL, which greatly reduces the related development effort.

In the new method, one only needs to provide a PDL pattern description file and a kernel that conforms to the BladeDISC runtime interface; the pattern replacement and the corresponding kernel call can then be realized without recompiling BladeDISC.

We have already used this mechanism in quantization compilation; you can refer to the examples here and here. In the future, we will extend it further with the help of PDL and the transform dialect so that the CodeGen strategy for a specific pattern can also be injected.

Runtime Abstraction Layer

  • Support for large model weights
  • Improved concurrency performance

Ongoing Work

  • High-performance GEMM kernel CodeGen based on CUTLASS
  • MLIR transform-dialect-based CodeGen
  • Accelerating sparse recommendation models in TensorFlow