Skip to content

Commit

Permalink
[cont.bind] squashed
Browse files Browse the repository at this point in the history
  • Loading branch information
frank-emrich committed Feb 29, 2024
1 parent 1ef95ae commit d427b84
Show file tree
Hide file tree
Showing 30 changed files with 499 additions and 132 deletions.
1 change: 1 addition & 0 deletions scripts/fuzz_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ def is_git_repo():
'typed_continuations.wast',
'typed_continuations_resume.wast',
'typed_continuations_contnew.wast',
'typed_continuations_contbind.wast',
# New EH implementation is in progress
'exception-handling.wast',
'translate-eh-old-to-new.wast',
Expand Down
1 change: 1 addition & 0 deletions scripts/gen-s-parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,7 @@
("return_call_ref", "makeCallRef(s, /*isReturn=*/true)"),
# Typed continuations instructions
("cont.new", "makeContNew(s)"),
("cont.bind", "makeContBind(s)"),
("resume", "makeResume(s)"),
# GC
("i31.new", "makeRefI31(s)"), # deprecated
Expand Down
35 changes: 27 additions & 8 deletions src/gen-s-parser.inc
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,17 @@ switch (buf[0]) {
default: goto parse_error;
}
}
case 'o':
if (op == "cont.new"sv) { return makeContNew(s); }
goto parse_error;
case 'o': {
switch (buf[5]) {
case 'b':
if (op == "cont.bind"sv) { return makeContBind(s); }
goto parse_error;
case 'n':
if (op == "cont.new"sv) { return makeContNew(s); }
goto parse_error;
default: goto parse_error;
}
}
default: goto parse_error;
}
}
Expand Down Expand Up @@ -3853,12 +3861,23 @@ switch (buf[0]) {
default: goto parse_error;
}
}
case 'o':
if (op == "cont.new"sv) {
CHECK_ERR(makeContNew(ctx, pos, annotations));
return Ok{};
case 'o': {
switch (buf[5]) {
case 'b':
if (op == "cont.bind"sv) {
CHECK_ERR(makeContBind(ctx, pos, annotations));
return Ok{};
}
goto parse_error;
case 'n':
if (op == "cont.new"sv) {
CHECK_ERR(makeContNew(ctx, pos, annotations));
return Ok{};
}
goto parse_error;
default: goto parse_error;
}
goto parse_error;
}
default: goto parse_error;
}
}
Expand Down
1 change: 1 addition & 0 deletions src/ir/ReFinalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ void ReFinalize::visitStringSliceIter(StringSliceIter* curr) {
}

void ReFinalize::visitContNew(ContNew* curr) { curr->finalize(); }
void ReFinalize::visitContBind(ContBind* curr) { curr->finalize(); }
void ReFinalize::visitResume(Resume* curr) { curr->finalize(); }

void ReFinalize::visitExport(Export* curr) { WASM_UNREACHABLE("unimp"); }
Expand Down
10 changes: 10 additions & 0 deletions src/ir/cost.h
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,16 @@ struct CostAnalyzer : public OverriddenVisitor<CostAnalyzer, CostType> {
return 8 + visit(curr->ref) + visit(curr->num);
}

CostType visitContBind(ContBind* curr) {
// Inspired by struct.new: The only cost of cont.bind is that it may need to
// allocate a buffer to hold the arguments.
CostType ret = 4;
ret += visit(curr->cont);
for (auto* arg : curr->operands) {
ret += visit(arg);
}
return ret;
}
CostType visitContNew(ContNew* curr) {
// Some arbitrary "high" value, reflecting that this may allocate a stack
return 14 + visit(curr->func);
Expand Down
4 changes: 4 additions & 0 deletions src/ir/effects.h
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,10 @@ class EffectAnalyzer {
parent.implicitTrap = true;
}

void visitContBind(ContBind* curr) {
// traps when curr->cont is null ref.
parent.implicitTrap = true;
}
void visitContNew(ContNew* curr) {
// traps when curr->func is null ref.
parent.implicitTrap = true;
Expand Down
3 changes: 3 additions & 0 deletions src/ir/module-utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,9 @@ struct CodeScanner
counts.include(get->type);
} else if (auto* set = curr->dynCast<ArraySet>()) {
counts.note(set->ref->type);
} else if (auto* contBind = curr->dynCast<ContBind>()) {
counts.note(contBind->contTypeBefore);
counts.note(contBind->contTypeAfter);
} else if (auto* contNew = curr->dynCast<ContNew>()) {
counts.note(contNew->contType);
} else if (auto* resume = curr->dynCast<Resume>()) {
Expand Down
4 changes: 4 additions & 0 deletions src/ir/possible-contents.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1200,6 +1200,10 @@ struct InfoCollector

void visitReturn(Return* curr) { addResult(curr->value); }

void visitContBind(ContBind* curr) {
// TODO: optimize when possible
addRoot(curr);
}
void visitContNew(ContNew* curr) {
// TODO: optimize when possible
addRoot(curr);
Expand Down
1 change: 1 addition & 0 deletions src/ir/subtype-exprs.h
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ struct SubtypingDiscoverer : public OverriddenVisitor<SubType> {
void visitStringSliceWTF(StringSliceWTF* curr) {}
void visitStringSliceIter(StringSliceIter* curr) {}

void visitContBind(ContBind* curr) { WASM_UNREACHABLE("not implemented"); }
void visitContNew(ContNew* curr) { WASM_UNREACHABLE("not implemented"); }
void visitResume(Resume* curr) { WASM_UNREACHABLE("not implemented"); }
};
Expand Down
12 changes: 12 additions & 0 deletions src/parser/contexts.h
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,11 @@ struct NullInstrParserCtx {
return Ok{};
}
template<typename HeapTypeT>
Result<>
makeContBind(Index, const std::vector<Annotation>&, HeapTypeT, HeapTypeT) {
return Ok{};
}
template<typename HeapTypeT>
Result<> makeContNew(Index, const std::vector<Annotation>&, HeapTypeT) {
return Ok{};
}
Expand Down Expand Up @@ -2523,6 +2528,13 @@ struct ParseDefsCtx : TypeParserCtx<ParseDefsCtx> {
return withLoc(pos, irBuilder.makeStringSliceIter());
}

Result<> makeContBind(Index pos,
const std::vector<Annotation>& annotations,
HeapType contTypeBefore,
HeapType contTypeAfter) {
return withLoc(pos, irBuilder.makeContBind(contTypeBefore, contTypeAfter));
}

Result<> makeContNew(Index pos,
const std::vector<Annotation>& annotations,
HeapType type) {
Expand Down
15 changes: 15 additions & 0 deletions src/parser/parsers.h
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,8 @@ Result<> makeStringSliceWTF(Ctx&,
template<typename Ctx>
Result<> makeStringSliceIter(Ctx&, Index, const std::vector<Annotation>&);
template<typename Ctx>
Result<> makeContBind(Ctx&, Index, const std::vector<Annotation>&);
template<typename Ctx>
Result<> makeContNew(Ctx*, Index, const std::vector<Annotation>&);
template<typename Ctx>
Result<> makeResume(Ctx&, Index, const std::vector<Annotation>&);
Expand Down Expand Up @@ -2431,6 +2433,19 @@ Result<> makeStringSliceIter(Ctx& ctx,
return ctx.makeStringSliceIter(pos, annotations);
}

// contbind ::= 'cont.bind' typeidx typeidx
template<typename Ctx>
Result<>
makeContBind(Ctx& ctx, Index pos, const std::vector<Annotation>& annotations) {
auto typeBefore = typeidx(ctx);
CHECK_ERR(typeBefore);

auto typeAfter = typeidx(ctx);
CHECK_ERR(typeAfter);

return ctx.makeContBind(pos, annotations, *typeBefore, *typeAfter);
}

template<typename Ctx>
Result<>
makeContNew(Ctx& ctx, Index pos, const std::vector<Annotation>& annotations) {
Expand Down
7 changes: 6 additions & 1 deletion src/passes/Print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2378,7 +2378,12 @@ struct PrintExpressionContents
void visitStringSliceIter(StringSliceIter* curr) {
printMedium(o, "stringview_iter.slice");
}

void visitContBind(ContBind* curr) {
printMedium(o, "cont.bind ");
printHeapType(curr->contTypeBefore);
o << ' ';
printHeapType(curr->contTypeAfter);
}
void visitContNew(ContNew* curr) {
printMedium(o, "cont.new ");
printHeapType(curr->contType);
Expand Down
1 change: 1 addition & 0 deletions src/passes/TypeGeneralizing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,7 @@ struct TransferFn : OverriddenVisitor<TransferFn> {
void visitStringSliceWTF(StringSliceWTF* curr) { WASM_UNREACHABLE("TODO"); }
void visitStringSliceIter(StringSliceIter* curr) { WASM_UNREACHABLE("TODO"); }

void visitContBind(ContBind* curr) { WASM_UNREACHABLE("TODO"); }
void visitContNew(ContNew* curr) { WASM_UNREACHABLE("TODO"); }
void visitResume(Resume* curr) { WASM_UNREACHABLE("TODO"); }
};
Expand Down
2 changes: 2 additions & 0 deletions src/wasm-binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -1299,6 +1299,7 @@ enum ASTNodes {

// typed continuation opcodes
ContNew = 0xe0,
ContBind = 0xe1,
Resume = 0xe3,

};
Expand Down Expand Up @@ -1928,6 +1929,7 @@ class WasmBinaryReader {
void visitRefAsCast(RefCast* curr, uint32_t code);
void visitRefAs(RefAs* curr, uint8_t code);
void visitContNew(ContNew* curr);
void visitContBind(ContBind* curr);
void visitResume(Resume* curr);

[[noreturn]] void throwError(std::string text);
Expand Down
12 changes: 12 additions & 0 deletions src/wasm-builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -1198,6 +1198,18 @@ class Builder {
return ret;
}

ContBind* makeContBind(HeapType contTypeBefore,
HeapType contTypeAfter,
const std::vector<Expression*>& operands,
Expression* cont) {
auto* ret = wasm.allocator.alloc<ContBind>();
ret->contTypeBefore = contTypeBefore;
ret->contTypeAfter = contTypeAfter;
ret->operands.set(operands);
ret->cont = cont;
ret->finalize();
return ret;
}
ContNew* makeContNew(HeapType contType, Expression* func) {
auto* ret = wasm.allocator.alloc<ContNew>();
ret->contType = contType;
Expand Down
9 changes: 9 additions & 0 deletions src/wasm-delegations-fields.def
Original file line number Diff line number Diff line change
Expand Up @@ -946,6 +946,15 @@ switch (DELEGATE_ID) {
break;
}

case Expression::Id::ContBindId: {
DELEGATE_START(ContBind);
DELEGATE_FIELD_CHILD(ContBind, cont);
DELEGATE_FIELD_CHILD_VECTOR(ContBind, operands);
DELEGATE_FIELD_HEAPTYPE(ContBind, contTypeAfter);
DELEGATE_FIELD_HEAPTYPE(ContBind, contTypeBefore);
DELEGATE_END(ContBind);
break;
}
case Expression::Id::ContNewId: {
DELEGATE_START(ContNew);
DELEGATE_FIELD_CHILD(ContNew, func);
Expand Down
1 change: 1 addition & 0 deletions src/wasm-delegations.def
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ DELEGATE(StringIterNext);
DELEGATE(StringIterMove);
DELEGATE(StringSliceWTF);
DELEGATE(StringSliceIter);
DELEGATE(ContBind);
DELEGATE(ContNew);
DELEGATE(Resume);

Expand Down
2 changes: 2 additions & 0 deletions src/wasm-interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -2401,6 +2401,7 @@ class ConstantExpressionRunner : public ExpressionRunner<SubType> {
}
return ExpressionRunner<SubType>::visitRefAs(curr);
}
Flow visitContBind(ContBind* curr) { WASM_UNREACHABLE("unimplemented"); }
Flow visitContNew(ContNew* curr) { WASM_UNREACHABLE("unimplemented"); }
Flow visitResume(Resume* curr) { WASM_UNREACHABLE("unimplemented"); }

Expand Down Expand Up @@ -3976,6 +3977,7 @@ class ModuleRunnerBase : public ExpressionRunner<SubType> {
multiValues.pop_back();
return ret;
}
Flow visitContBind(ContBind* curr) { return Flow(NONCONSTANT_FLOW); }
Flow visitContNew(ContNew* curr) { return Flow(NONCONSTANT_FLOW); }
Flow visitResume(Resume* curr) { return Flow(NONCONSTANT_FLOW); }

Expand Down
3 changes: 3 additions & 0 deletions src/wasm-ir-builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ class IRBuilder : public UnifiedExpressionVisitor<IRBuilder, Result<>> {
[[nodiscard]] Result<> makeStringIterMove(StringIterMoveOp op);
[[nodiscard]] Result<> makeStringSliceWTF(StringSliceWTFOp op);
[[nodiscard]] Result<> makeStringSliceIter();
[[nodiscard]] Result<> makeContBind(HeapType contTypeBefore,
HeapType contTypeAfter);
[[nodiscard]] Result<> makeContNew(HeapType ct);
[[nodiscard]] Result<> makeResume(HeapType ct,
const std::vector<Name>& tags,
Expand Down Expand Up @@ -252,6 +254,7 @@ class IRBuilder : public UnifiedExpressionVisitor<IRBuilder, Result<>> {
[[nodiscard]] Result<> visitThrow(Throw*);
[[nodiscard]] Result<> visitStringNew(StringNew*);
[[nodiscard]] Result<> visitStringEncode(StringEncode*);
[[nodiscard]] Result<> visitContBind(ContBind*);
[[nodiscard]] Result<> visitResume(Resume*);
[[nodiscard]] Result<> visitTupleMake(TupleMake*);
[[nodiscard]] Result<>
Expand Down
1 change: 1 addition & 0 deletions src/wasm-s-parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ class SExpressionWasmBuilder {
Expression* makeStringIterMove(Element& s, StringIterMoveOp op);
Expression* makeStringSliceWTF(Element& s, StringSliceWTFOp op);
Expression* makeStringSliceIter(Element& s);
Expression* makeContBind(Element& s);
Expression* makeContNew(Element& s);
Expression* makeResume(Element& s);

Expand Down
13 changes: 13 additions & 0 deletions src/wasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,7 @@ class Expression {
StringIterMoveId,
StringSliceWTFId,
StringSliceIterId,
ContBindId,
ContNewId,
ResumeId,
NumExpressionIds
Expand Down Expand Up @@ -1998,6 +1999,18 @@ class StringSliceIter
void finalize();
};

class ContBind : public SpecificExpression<Expression::ContBindId> {
public:
ContBind(MixedArena& allocator) : operands(allocator) {}

HeapType contTypeBefore;
HeapType contTypeAfter;
ExpressionList operands;
Expression* cont;

void finalize();
};

class ContNew : public SpecificExpression<Expression::ContNewId> {
public:
ContNew() = default;
Expand Down
35 changes: 35 additions & 0 deletions src/wasm/wasm-binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4046,6 +4046,10 @@ BinaryConsts::ASTNodes WasmBinaryReader::readExpression(Expression*& curr) {
visitCallRef(call);
break;
}
case BinaryConsts::ContBind: {
visitContBind((curr = allocator.alloc<ContBind>())->cast<ContBind>());
break;
}
case BinaryConsts::ContNew: {
auto contNew = allocator.alloc<ContNew>();
curr = contNew;
Expand Down Expand Up @@ -7768,6 +7772,37 @@ void WasmBinaryReader::visitRefAs(RefAs* curr, uint8_t code) {
curr->finalize();
}

void WasmBinaryReader::visitContBind(ContBind* curr) {
BYN_TRACE("zz node: ContBind\n");

auto contTypeBeforeIndex = getU32LEB();
curr->contTypeBefore = getTypeByIndex(contTypeBeforeIndex);

auto contTypeAfterIndex = getU32LEB();
curr->contTypeAfter = getTypeByIndex(contTypeAfterIndex);

for (auto& ct : {curr->contTypeBefore, curr->contTypeAfter}) {
if (!ct.isContinuation()) {
throwError("non-continuation type in cont.bind instruction " +
ct.toString());
}
}

curr->cont = popNonVoidExpression();

size_t paramsBefore =
curr->contTypeBefore.getContinuation().type.getSignature().params.size();
size_t paramsAfter =
curr->contTypeAfter.getContinuation().type.getSignature().params.size();
size_t numArgs = paramsBefore - paramsAfter;
curr->operands.resize(numArgs);
for (size_t i = 0; i < numArgs; i++) {
curr->operands[numArgs - i - 1] = popNonVoidExpression();
}

curr->finalize();
}

void WasmBinaryReader::visitContNew(ContNew* curr) {
BYN_TRACE("zz node: ContNew\n");

Expand Down
Loading

0 comments on commit d427b84

Please sign in to comment.