ConstantFieldPropagation: Add a variation that picks between 2 values using RefTest (WebAssembly#6692)

CFP focuses on finding when a field always contains a constant, and then
replacing a struct.get of that field with the constant. If we find there are
two constant values, then in some cases we can still optimize, if we have a
way to pick between them. All we have is the struct.get and its reference, so
we must use a ref.test:

   (struct.get $T x (..ref..))
     =>
   (select
     (..constant1..)
     (..constant2..)
     (ref.test $U (..ref..))
   )

This is valid if, of all the subtypes of $T, those that pass the test have
constant1 in that field, and those that fail the test have constant2. For
example, a simple case is where $T has two subtypes, $T is never created
itself, and each of the two subtypes has a different constant value.
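
For example (an illustrative sketch; the type names and values here are
hypothetical), given

   (type $T (sub (struct (field i32))))
   (type $A (sub final $T (struct (field i32))))
   (type $B (sub final $T (struct (field i32))))

where every struct.new of $A writes 42 to the field and every struct.new
of $B writes 1337, a read through $T

   (struct.get $T 0 (..ref..))

could become

   (select
     (i32.const 42)     ;; the $A value
     (i32.const 1337)   ;; the $B value
     (ref.test (ref $A) (..ref..))
   )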

This is a somewhat risky operation, as ref.test is not necessarily cheap.
To mitigate that, this is a new pass, --cfp-reftest, that is not run by
default, and we only optimize when we can use a ref.test on what we expect
will be a final type (ref.test on a final type can be faster in VMs).
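
Since the pass is opt-in, it must be requested explicitly, e.g. as
wasm-opt --cfp-reftest in.wasm -o out.wasm (the file names here are
placeholders); the default --cfp pass is unchanged.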
kripken authored Jun 27, 2024
1 parent 53712b6 commit cdf8139
Showing 7 changed files with 1,723 additions and 16 deletions.
267 changes: 251 additions & 16 deletions src/passes/ConstantFieldPropagation.cpp
@@ -23,6 +23,30 @@
// write to that field of a different value (even using a subtype of T), then
// anywhere we see a get of that field we can place a ref.func of F.
//
// A variation of this pass also uses ref.test to optimize. This is riskier, as
// adding a ref.test means we are adding a non-trivial amount of work, and
// whether it helps overall depends on subsequent optimizations, so we do not do
// it by default. In this variation, if we inferred a field has exactly two
// possible values, and we can differentiate between them using a ref.test, then
// we do
//
// (struct.get $T x (..ref..))
// =>
// (select
// (..constant1..)
// (..constant2..)
// (ref.test $U (..ref..))
// )
//
// This is valid if, of all the subtypes of $T, those that pass the test have
// constant1 in that field, and those that fail the test have constant2. For
// example, a simple case is where $T has two subtypes, $T is never created
// itself, and each of the two subtypes has a different constant value. (Note
// that we do similar things in e.g. GlobalStructInference, where we turn a
// struct.get into a select, but the risk there is much lower since the
// condition for the select is something like a ref.eq - very cheap - while here
// we emit a ref.test which in general is as expensive as a cast.)
//
// FIXME: This pass assumes a closed world. When we start to allow multi-module
// wasm GC programs we need to check for type escaping.
//
@@ -34,6 +58,7 @@
#include "ir/struct-utils.h"
#include "ir/utils.h"
#include "pass.h"
#include "support/small_vector.h"
#include "wasm-builder.h"
#include "wasm-traversal.h"
#include "wasm.h"
@@ -73,17 +98,30 @@ struct FunctionOptimizer : public WalkerPass<PostWalker<FunctionOptimizer>> {
// Only modifies struct.get operations.
bool requiresNonNullableLocalFixups() override { return false; }

// We receive the propagated infos, that is, info about field types in a form
// that takes into account subtypes for quick computation, and also the raw
// subtyping and new infos (information about struct.news).
std::unique_ptr<Pass> create() override {
- return std::make_unique<FunctionOptimizer>(infos);
+ return std::make_unique<FunctionOptimizer>(
+   propagatedInfos, subTypes, rawNewInfos, refTest);
}

- FunctionOptimizer(PCVStructValuesMap& infos) : infos(infos) {}
+ FunctionOptimizer(const PCVStructValuesMap& propagatedInfos,
+                   const SubTypes& subTypes,
+                   const PCVStructValuesMap& rawNewInfos,
+                   bool refTest)
+   : propagatedInfos(propagatedInfos), subTypes(subTypes),
+     rawNewInfos(rawNewInfos), refTest(refTest) {}

void visitStructGet(StructGet* curr) {
auto type = curr->ref->type;
if (type == Type::unreachable) {
return;
}
auto heapType = type.getHeapType();
if (!heapType.isStruct()) {
return;
}

Builder builder(*getModule());

@@ -92,8 +130,8 @@ struct FunctionOptimizer : public WalkerPass<PostWalker<FunctionOptimizer>> {
// as if nothing was ever noted for that field.
PossibleConstantValues info;
assert(!info.hasNoted());
- auto iter = infos.find(type.getHeapType());
- if (iter != infos.end()) {
+ auto iter = propagatedInfos.find(heapType);
+ if (iter != propagatedInfos.end()) {
// There is information on this type, fetch it.
info = iter->second[curr->index];
}
@@ -113,25 +151,204 @@ struct FunctionOptimizer : public WalkerPass<PostWalker<FunctionOptimizer>> {
return;
}

- // If the value is not a constant, then it is unknown and we must give up.
+ // If the value is not a constant, then it is unknown and we must give up
+ // on simply applying a constant. However, we can try to use a ref.test, if
+ // that is allowed.
if (!info.isConstant()) {
if (refTest) {
optimizeUsingRefTest(curr);
}
return;
}

// We can do this! Replace the get with a trap on a null reference using a
// ref.as_non_null (we need to trap as the get would have done so), plus the
// constant value. (Leave it to further optimizations to get rid of the
// ref.)
- Expression* value = info.makeExpression(*getModule());
- auto field = GCTypeUtils::getField(type, curr->index);
- assert(field);
- value =
-   Bits::makePackedFieldGet(value, *field, curr->signed_, *getModule());
+ auto* value = makeExpression(info, heapType, curr);
replaceCurrent(builder.makeSequence(
builder.makeDrop(builder.makeRefAs(RefAsNonNull, curr->ref)), value));
changed = true;
}

// Given information about a constant value, and the struct type and StructGet
// that reads it, create an expression for that value.
Expression* makeExpression(const PossibleConstantValues& info,
HeapType type,
StructGet* curr) {
auto* value = info.makeExpression(*getModule());
auto field = GCTypeUtils::getField(type, curr->index);
assert(field);
return Bits::makePackedFieldGet(value, *field, curr->signed_, *getModule());
}

void optimizeUsingRefTest(StructGet* curr) {
auto refType = curr->ref->type;
auto refHeapType = refType.getHeapType();

// We only handle immutable fields in this function, as we will be looking
// at |rawNewInfos|. That is, we are trying to see when a type and its
// subtypes have different values (so that we can differentiate between them
// using a ref.test), and those differences are lost in |propagatedInfos|,
// which has propagated values across types so that we can do a single check
// to see what value could be there. So we need to use something more
// precise, |rawNewInfos|, which tracks the values written to struct.news,
// where we know the type exactly (unlike with a struct.set). But for that
// reason the field must be immutable, so that it is valid to only look at
// the struct.news. (A more complex flow analysis could do better here, but
// would be far beyond the scope of this pass.)
if (GCTypeUtils::getField(refType, curr->index)->mutable_ == Mutable) {
return;
}

// We seek two possible constant values. For each we track the constant and
// the types that have that constant. For example, if we have types A, B, C
// and A and B have 42 in their field, and C has 1337, then we'd have this:
//
// values = [ { 42, [A, B] }, { 1337, [C] } ];
struct Value {
PossibleConstantValues constant;
// Use a SmallVector as we'll only have 2 Values, and so the stack usage
// here is fixed.
SmallVector<HeapType, 10> types;

// Whether this slot is used. If so, |constant| has a value, and |types|
// is not empty.
bool used() const {
if (constant.hasNoted()) {
assert(!types.empty());
return true;
}
assert(types.empty());
return false;
}
} values[2];

// Handle one of the subtypes of the relevant type. We check what value it
// has for the field, and update |values|. If we hit a problem, we mark
// ourselves as having failed.
auto fail = false;
auto handleType = [&](HeapType type, Index depth) {
if (fail) {
// TODO: Add a mechanism to halt |iterSubTypes| in the middle, as once
// we fail there is no point in iterating further.
return;
}

auto iter = rawNewInfos.find(type);
if (iter == rawNewInfos.end()) {
// This type has no struct.news, so we can ignore it: it is abstract.
return;
}

auto value = iter->second[curr->index];
if (!value.isConstant()) {
// The value here is not constant, so give up entirely.
fail = true;
return;
}

// Consider the constant value compared to previous ones.
for (Index i = 0; i < 2; i++) {
if (!values[i].used()) {
// There is nothing in this slot: place this value there.
values[i].constant = value;
values[i].types.push_back(type);
break;
}

// There is something in this slot. If we have the same value, append.
if (values[i].constant == value) {
values[i].types.push_back(type);
break;
}

// Otherwise, this value is different from values[i], which is fine: we
// can add it as the second value in the next loop iteration - at least,
// we can do that if there is another iteration: if this is already the
// last one, we've failed to find only two values.
if (i == 1) {
fail = true;
return;
}
}
};
subTypes.iterSubTypes(refHeapType, handleType);

if (fail) {
return;
}

// We either filled slot 0, or we did not, and if we did not then we
// cannot have filled slot 1 after it.
assert(values[0].used() || !values[1].used());

if (!values[1].used()) {
// We did not see two constant values (we might have seen just one, or
// even no constant values at all).
return;
}

// We have exactly two values to pick between. We can pick between those
// values using a single ref.test if the two sets of types are actually
// disjoint. In general we could compute the LUB of each set and see if it
// overlaps with the other, but for efficiency we only want to do this
// optimization if the type we test on is closed/final, since ref.test on a
// final type can be fairly fast (perhaps constant time). We therefore check
// whether one of the sets of types contains a single type that is final, and
// if so then we'll test on it. (However, see a few lines below on how we
// test for finality.)
// TODO: Consider adding a variation on this pass that uses non-final types.
auto isProperTestType = [&](const Value& value) -> std::optional<HeapType> {
auto& types = value.types;
if (types.size() != 1) {
// Too many types.
return {};
}

auto type = types[0];
// Do not test finality using isOpen(), as that may only be applied late
// in the optimization pipeline. We are in closed-world here, so just
// see if there are subtypes in practice (if not, this can be marked as
// final later, and we assume optimistically that it will).
if (!subTypes.getImmediateSubTypes(type).empty()) {
// There are subtypes.
return {};
}

// Success, we can test on this.
return type;
};

// Look for the index in |values| to test on.
Index testIndex;
if (auto test = isProperTestType(values[0])) {
testIndex = 0;
} else if (auto test = isProperTestType(values[1])) {
testIndex = 1;
} else {
// We failed to find a simple way to separate the types.
return;
}

// Success! We can replace the struct.get with a select between the two
// values (plus a trap on null), using the proper ref.test as the condition.
Builder builder(*getModule());

auto& testIndexTypes = values[testIndex].types;
assert(testIndexTypes.size() == 1);
auto testType = testIndexTypes[0];

auto* nnRef = builder.makeRefAs(RefAsNonNull, curr->ref);

replaceCurrent(builder.makeSelect(
builder.makeRefTest(nnRef, Type(testType, NonNullable)),
makeExpression(values[testIndex].constant, refHeapType, curr),
makeExpression(values[1 - testIndex].constant, refHeapType, curr)));

changed = true;
}

void doWalkFunction(Function* func) {
WalkerPass<PostWalker<FunctionOptimizer>>::doWalkFunction(func);

@@ -143,7 +360,10 @@ struct FunctionOptimizer : public WalkerPass<PostWalker<FunctionOptimizer>> {
}

private:
- PCVStructValuesMap& infos;
+ const PCVStructValuesMap& propagatedInfos;
+ const SubTypes& subTypes;
+ const PCVStructValuesMap& rawNewInfos;
+ const bool refTest;

bool changed = false;
};
@@ -193,6 +413,11 @@ struct ConstantFieldPropagation : public Pass {
// Only modifies struct.get operations.
bool requiresNonNullableLocalFixups() override { return false; }

// Whether we are optimizing using ref.test, see above.
const bool refTest;

ConstantFieldPropagation(bool refTest) : refTest(refTest) {}

void run(Module* module) override {
if (!module->features.hasGC()) {
return;
@@ -214,8 +439,16 @@
BoolStructValuesMap combinedCopyInfos;
functionCopyInfos.combineInto(combinedCopyInfos);

// Prepare data we will need later.
SubTypes subTypes(*module);

PCVStructValuesMap rawNewInfos;
if (refTest) {
// The refTest optimizations require the raw new infos (see above), but we
// can skip copying here if we'll never read this.
rawNewInfos = combinedNewInfos;
}

// Handle subtyping. |combinedInfo| so far contains data that represents
// each struct.new and struct.set's operation on the struct type used in
// that instruction. That is, if we do a struct.set to type T, the value was
@@ -288,17 +521,19 @@

// Optimize.
// TODO: Skip this if we cannot optimize anything
- FunctionOptimizer(combinedInfos).run(runner, module);

// TODO: Actually remove the field from the type, where possible? That might
// be best in another pass.
+ FunctionOptimizer(combinedInfos, subTypes, rawNewInfos, refTest)
+   .run(runner, module);
}
};

} // anonymous namespace

Pass* createConstantFieldPropagationPass() {
- return new ConstantFieldPropagation();
+ return new ConstantFieldPropagation(false);
}

Pass* createConstantFieldPropagationRefTestPass() {
return new ConstantFieldPropagation(true);
}

} // namespace wasm
3 changes: 3 additions & 0 deletions src/passes/pass.cpp
@@ -121,6 +121,9 @@ void PassRegistry::registerPasses() {
registerPass("cfp",
"propagate constant struct field values",
createConstantFieldPropagationPass);
registerPass("cfp-reftest",
"propagate constant struct field values, using ref.test",
createConstantFieldPropagationRefTestPass);
registerPass(
"dce", "removes unreachable code", createDeadCodeEliminationPass);
registerPass("dealign",
1 change: 1 addition & 0 deletions src/passes/passes.h
@@ -32,6 +32,7 @@ Pass* createCodeFoldingPass();
Pass* createCodePushingPass();
Pass* createConstHoistingPass();
Pass* createConstantFieldPropagationPass();
Pass* createConstantFieldPropagationRefTestPass();
Pass* createDAEPass();
Pass* createDAEOptimizingPass();
Pass* createDataFlowOptsPass();
3 changes: 3 additions & 0 deletions test/lit/help/wasm-opt.test
@@ -103,6 +103,9 @@
;; CHECK-NEXT: --cfp propagate constant struct field
;; CHECK-NEXT: values
;; CHECK-NEXT:
;; CHECK-NEXT: --cfp-reftest propagate constant struct field
;; CHECK-NEXT: values, using ref.test
;; CHECK-NEXT:
;; CHECK-NEXT: --coalesce-locals reduce # of locals by coalescing
;; CHECK-NEXT:
;; CHECK-NEXT: --coalesce-locals-learning reduce # of locals by coalescing
3 changes: 3 additions & 0 deletions test/lit/help/wasm2js.test
@@ -57,6 +57,9 @@
;; CHECK-NEXT: --cfp propagate constant struct field
;; CHECK-NEXT: values
;; CHECK-NEXT:
;; CHECK-NEXT: --cfp-reftest propagate constant struct field
;; CHECK-NEXT: values, using ref.test
;; CHECK-NEXT:
;; CHECK-NEXT: --coalesce-locals reduce # of locals by coalescing
;; CHECK-NEXT:
;; CHECK-NEXT: --coalesce-locals-learning reduce # of locals by coalescing