Skip to content

Commit

Permalink
[stablehlo] fix: enhance torch's index-like op lowering to stablehlo'…
Browse files Browse the repository at this point in the history
…s gather/scatter (#3829)

In torch.index_put like ops, `values` is only required to be
broadcastable to `input[indices]`, rather than exact dimension match.
This patch fixes the problem by add additional
stablehlo.dynamic_broadcast_in_dim before creating stablehlo.scatter op.
BTW, this patch also enhance the `getBroadcastResultShape` utility in
hlo namespace.
  • Loading branch information
Vremold authored Nov 5, 2024
1 parent 4c1518d commit b75d0e3
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
Value promoteType(PatternRewriter &rewriter, Location loc, Value input,
Type outElementType);

FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
Operation *op, ArrayRef<Value> tensors,
size_t dimSizeIndexBits);
FailureOr<std::pair<Value, SmallVector<int64_t>>>
getBroadcastResultShape(PatternRewriter &rewriter, Operation *op,
ArrayRef<Value> tensors, size_t dimSizeIndexBits);

Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
TensorType outType,
Expand Down
90 changes: 61 additions & 29 deletions lib/Conversion/TorchToStablehlo/GatherScatter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,16 +220,10 @@ namespace {
FailureOr<Value> broadcastAndConcatIndices(Operation *op,
ConversionPatternRewriter &rewriter,
SmallVector<Value> indexTensors,
llvm::ArrayRef<int64_t> inputShape,
size_t dimSizeIndexBits,
int &maxIndexRank) {
// Step 1: broadcast indices tensors
SmallVector<int64_t> indicesShape;
SmallVector<int64_t> expandShape;
SmallVector<int64_t> concatShape;

bool allIndexStaticShape = true;
Value bcastSizeTensor;

// concat index tensor into to indices tensor for concat
for (size_t i = 0; i < indexTensors.size(); i++) {
Expand All @@ -242,20 +236,15 @@ FailureOr<Value> broadcastAndConcatIndices(Operation *op,
maxIndexRank = std::max(maxIndexRank, (int)indexTensorType.getRank());
}

if (!allIndexStaticShape) {
auto bcastSizeTensorInfo = hlo::getBroadcastResultShape(
rewriter, op, indexTensors, dimSizeIndexBits);
if (failed(bcastSizeTensorInfo)) {
return failure();
}
bcastSizeTensor = *bcastSizeTensorInfo;
}

for (int i = 0; i < maxIndexRank; i++) {
indicesShape.push_back(inputShape[i]);
expandShape.push_back(inputShape[i]);
concatShape.push_back(inputShape[i]);
auto bcastSizeInfo = hlo::getBroadcastResultShape(rewriter, op, indexTensors,
dimSizeIndexBits);
if (failed(bcastSizeInfo)) {
return failure();
}
Value bcastSizeTensor = (*bcastSizeInfo).first;
auto indicesShape = (*bcastSizeInfo).second;
SmallVector<int64_t> expandShape(indicesShape.begin(), indicesShape.end());
SmallVector<int64_t> concatShape(indicesShape.begin(), indicesShape.end());
expandShape.push_back(1);
concatShape.push_back(indexTensors.size());

Expand Down Expand Up @@ -879,7 +868,6 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
auto inputTensorType = cast<RankedTensorType>(input.getType());
auto outType =
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
auto outShape = outType.getShape();
Value indexList = op.getIndices();
SmallVector<Value> indicesTorchType;
if (!getListConstructElements(indexList, indicesTorchType))
Expand All @@ -890,9 +878,8 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
indicesTorchType);

int maxIndexRank = -1;
auto gatherIndicesInfo =
broadcastAndConcatIndices(op, rewriter, indexTensors, outShape,
options.dimSizeIndexBits, maxIndexRank);
auto gatherIndicesInfo = broadcastAndConcatIndices(
op, rewriter, indexTensors, options.dimSizeIndexBits, maxIndexRank);
if (failed(gatherIndicesInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to generate broadcasted indices");
Expand Down Expand Up @@ -949,6 +936,8 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
auto outType =
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
auto inputType = cast<RankedTensorType>(input.getType());
auto inputShape = inputType.getShape();
auto inputRank = inputType.getRank();
auto valuesType = cast<RankedTensorType>(values.getType());
int64_t valueRank = valuesType.getRank();
auto valuesShape = valuesType.getShape();
Expand All @@ -968,15 +957,58 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
indicesTorchType);

int maxIndexRank = -1;
auto scatterIndicesInfo =
broadcastAndConcatIndices(op, rewriter, indexTensors, valuesShape,
options.dimSizeIndexBits, maxIndexRank);
auto scatterIndicesInfo = broadcastAndConcatIndices(
op, rewriter, indexTensors, options.dimSizeIndexBits, maxIndexRank);
if (failed(scatterIndicesInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to generate broadcasted indices");
}
auto scatterIndices = *scatterIndicesInfo;

// broadcast `values` tensor to match expectedValuesShape.
SmallVector<int64_t> scatterIndicesDims;
for (int64_t i = 0; i < maxIndexRank; ++i) {
scatterIndicesDims.push_back(i);
}
auto expectedValuesShapeTensorInfo =
hlo::getDimSizesOfTensor(rewriter, op, scatterIndices, scatterIndicesDims,
options.dimSizeIndexBits);
if (failed(expectedValuesShapeTensorInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get shape of broadcasted indices");
}
auto expectedValuesShapeTensors = *expectedValuesShapeTensorInfo;
SmallVector<int64_t> trailingInputDims;
for (int64_t i = indexCnt; i < inputRank; ++i) {
trailingInputDims.push_back(i);
}
auto trailingInputShapeTensorInfo = hlo::getDimSizesOfTensor(
rewriter, op, input, trailingInputDims, options.dimSizeIndexBits);
if (failed(trailingInputShapeTensorInfo)) {
return rewriter.notifyMatchFailure(op, "failed to get shape of input");
}
expectedValuesShapeTensors.append((*trailingInputShapeTensorInfo).begin(),
(*trailingInputShapeTensorInfo).end());

llvm::ArrayRef<int64_t> scatterIndicesShape =
(cast<RankedTensorType>(scatterIndices.getType())).getShape();
SmallVector<int64_t> expectedValuesShape(
scatterIndicesShape.begin(), scatterIndicesShape.begin() + maxIndexRank);
for (int64_t i = indexCnt; i < inputRank; i++) {
expectedValuesShape.push_back(inputShape[i]);
}

valuesType =
RankedTensorType::get(expectedValuesShape, valuesType.getElementType());
values =
hlo::promoteAndBroadcast(rewriter, values, valuesType,
rewriter
.create<tensor::FromElementsOp>(
op->getLoc(), expectedValuesShapeTensors)
.getResult());
valueRank = valuesType.getRank();
valuesShape = valuesType.getShape();

// create stablehlo::ScatterOp
int64_t indexVecDim = maxIndexRank;
SmallVector<int64_t> scatterDimOperandDimMap;
Expand Down Expand Up @@ -1216,9 +1248,9 @@ Value getSummand(ConversionPatternRewriter &rewriter, Operation *op,
SmallVector<Value> indexTensors{Nidx, CIdx, idxY, idxX};

int maxIndexRank = -1;
auto gatherIndicesInfo = broadcastAndConcatIndices(
input.getDefiningOp(), rewriter, indexTensors, outType.getShape(),
dimSizeIndexBits, maxIndexRank);
auto gatherIndicesInfo =
broadcastAndConcatIndices(input.getDefiningOp(), rewriter, indexTensors,
dimSizeIndexBits, maxIndexRank);
auto gatherIndices = *gatherIndicesInfo;
int64_t numIndicesDim = indexTensors.size();
int64_t indexVecDim = maxIndexRank;
Expand Down
28 changes: 18 additions & 10 deletions lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -322,9 +322,9 @@ getDimIndexOfTensor(PatternRewriter &rewriter, Operation *op, Value value) {
return getDimIndexOfTensor(rewriter, op, value, dims);
}

FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
Operation *op, ArrayRef<Value> tensors,
size_t dimSizeIndexBits) {
FailureOr<std::pair<Value, SmallVector<int64_t>>>
getBroadcastResultShape(PatternRewriter &rewriter, Operation *op,
ArrayRef<Value> tensors, size_t dimSizeIndexBits) {
SmallVector<ArrayRef<int64_t>> tensorSizes;

int maxRank = 0;
Expand All @@ -337,10 +337,11 @@ FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
}

SmallVector<Value> bcastSizeTensors;
SmallVector<int64_t> bcastSizes;
for (int outDim = 0; outDim < maxRank; ++outDim) { // loop dimensions.
int dynamicDimCnt = 0;
int staticDimCnt = 0;
int64_t staticDimSize;
int64_t dimSize = -1;
Value dimSizeTensor = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(),
rewriter.getIntegerAttr(rewriter.getIntegerType(dimSizeIndexBits), 1));
Expand All @@ -351,12 +352,16 @@ FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
continue;

// dim size: 1
if (tensorSizes[i][inDim] == 1)
if (tensorSizes[i][inDim] == 1) {
if (dimSize == -1)
dimSize = 1;
continue;
}
// dim size: dynamic
if (tensorSizes[i][inDim] == ShapedType::kDynamic ||
tensorSizes[i][inDim] == kUnknownSize) {
dynamicDimCnt++;
dimSize = ShapedType::kDynamic;
auto dimSizeTensorInfo = hlo::getDimSizesOfTensor(
rewriter, op, tensors[i], {inDim}, dimSizeIndexBits);
if (failed(dimSizeTensorInfo)) {
Expand All @@ -371,12 +376,12 @@ FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
return failure();
}
// we already found static dim size not equal with this, fail.
if (staticDimCnt > 0 && staticDimSize != tensorSizes[i][inDim]) {
if (staticDimCnt > 0 && dimSize != tensorSizes[i][inDim]) {
return failure();
}

staticDimCnt++;
staticDimSize = tensorSizes[i][inDim];
dimSize = tensorSizes[i][inDim];
auto dimSizeTensorInfo = hlo::getDimSizesOfTensor(
rewriter, op, tensors[i], {inDim}, dimSizeIndexBits);
if (failed(dimSizeTensorInfo)) {
Expand All @@ -389,12 +394,15 @@ FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
// if (dynamicDimCnt > 1) {
// return failure();
// }

bcastSizes.push_back(dimSize);
bcastSizeTensors.push_back(dimSizeTensor);
}
std::reverse(bcastSizes.begin(), bcastSizes.end());
std::reverse(bcastSizeTensors.begin(), bcastSizeTensors.end());
return rewriter.create<tensor::FromElementsOp>(op->getLoc(), bcastSizeTensors)
.getResult();
return std::pair<Value, SmallVector<int64_t>>(
rewriter.create<tensor::FromElementsOp>(op->getLoc(), bcastSizeTensors)
.getResult(),
bcastSizes);
}

FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
Expand Down
1 change: 1 addition & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,7 @@
"IndexPutImpl3DFloatAccumulateModule_basic",
"IndexPutImpl3DFloatNonAccumulateModule_basic",
"IndexPutImplIndexWithNoneModule_basic",
"IndexPutWithNoneAndBroadcastModule_basic",
"IndexSelectRank0IdxModule_basic",
"IndexTensorNegativeIndexModule_basic",
"IntFloatModule_basic",
Expand Down

0 comments on commit b75d0e3

Please sign in to comment.