From b3519e1464440cb9ef7201d796a30ab03d168d06 Mon Sep 17 00:00:00 2001 From: Dillon Date: Tue, 9 Apr 2024 18:19:41 -0700 Subject: [PATCH 1/2] WIP validator --- runtime/evaluate.cc | 311 ++++++++++++++++++++++++++++++++++++++++++++ runtime/evaluate.h | 3 + 2 files changed, 314 insertions(+) diff --git a/runtime/evaluate.cc b/runtime/evaluate.cc index 599c0fee..c55941c4 100644 --- a/runtime/evaluate.cc +++ b/runtime/evaluate.cc @@ -856,4 +856,315 @@ class constant_evaluator : public expr_visitor { std::optional evaluate_constant(const expr& e) { return constant_evaluator().eval(e); } +namespace { + +class validator : public expr_visitor, public stmt_visitor { + enum variable_state { + unknown, + pointer, + arithmetic, + }; + + const node_context* symbols; + + struct symbol_info { + variable_state state; + stmt decl_stmt; + expr decl_expr; + + symbol_info(variable_state state) : state(state) {} + symbol_info(variable_state state, stmt s) : state(state), decl_stmt(std::move(s)) {} + symbol_info(variable_state state, expr e) : state(state), decl_expr(std::move(e)) {} + }; + symbol_map ctx; + + variable_state state = unknown; + + void print_symbol_info(const symbol_info& s) { + if (s.decl_stmt.defined()) { + std::cerr << "Declared by:" << std::endl; + print(std::cerr, s.decl_stmt, symbols); + } else if (s.decl_expr.defined()) { + std::cerr << "Declared by:" << std::endl; + print(std::cerr, s.decl_expr, symbols); + } else { + std::cerr << "Externally defined symbol" << std::endl; + } + } + +public: + bool error = false; + + validator(span inputs, const node_context* symbols) : symbols(symbols) { + for (var i : inputs) { + ctx[i] = unknown; + } + } + + void visit(const variable* x) override { + if (!ctx.contains(x->name)) { + std::cerr << "Undefined variable "; + print(std::cerr, x, symbols); + std::cerr << " in context" << std::endl; + error = true; + } + state = unknown; + } + void visit(const constant*) override { state = unknown; } + + template + void visit_let(const T* x) { + std::vector> lets; + lets.reserve(op->lets.size()); + + for (size_t i = 0; i < op->lets.size(); ++i) { + if (!op->lets[i].defined()) { + std::cerr << "Undefined variable "; + print(std::cerr, x, symbols); + std::cerr << " in context" << std::endl; + error = true; + return; + } + } + if (!x->body.defined()) { + std::cerr << "Undefined let body in context" << std::endl; + error = true; + return; + } + x->body.accept(this); + } + + void visit(const let* x) override { visit_let(x); } + + void check_arithmetic(const expr& x, bool required = true) { + if (error) return; + + if (x.defined()) { + x.accept(this); + if (state == pointer) { + std::cerr << "Arithmetic on pointer value: "; + print(std::cerr, x, symbols); + std::cerr << std::endl << "In context:" << std::endl; + error = true; + } + } else if (required) { + std::cerr << "Undefined expression in context:" << std::endl; + } + } + + void check(const expr& x) { + if (error) return; + x.accept(this); + } + + template + void visit_binary_arithmetic(const T* x) { + check_arithmetic(x->a); + check_arithmetic(x->b); + } + + template + void visit_binary(const T* x) { + check(x->a); + check(x->b); + } + + void visit(const add* x) override { visit_binary_arithmetic(x); } + void visit(const sub* x) override { visit_binary_arithmetic(x); } + void visit(const mul* x) override { visit_binary_arithmetic(x); } + void visit(const div* x) override { visit_binary_arithmetic(x); } + void visit(const mod* x) override { visit_binary_arithmetic(x); } + void visit(const class min* x) override { visit_binary_arithmetic(x); } + void visit(const class max* x) override { visit_binary_arithmetic(x); } + void visit(const equal* x) override { visit_binary(x); } + void visit(const not_equal* x) override { visit_binary(x); } + void visit(const less* x) override { visit_binary_arithmetic(x); } + void visit(const less_equal* x) override { visit_binary_arithmetic(x); } + void visit(const logical_and* x) override { visit_binary(x); } + void visit(const logical_or* x) override { visit_binary(x); } + void visit(const logical_not* x) override { check(x->x); } + void visit(const class select* x) override { + check(x->condition); + check(x->true_value); + check(x->false_value); + } + + void check_pointer(const expr& x) { + if (error) return; + + if (x.defined()) { + x.accept(this); + if (state == arithmetic) { + std::cerr << "Expression " << x << " is arithmetic, expected pointer" << std::endl; + std::cerr << std::endl << "In context:" << std::endl; + error = true; + } + } else { + std::cerr << "Undefined expression in context:" << std::endl; + } + } + + void check_pointer(var name) { + std::optional state = ctx.lookup(name); + if (state && state->state == arithmetic) { + std::cerr << "Arithmetic symbol "; + print(std::cerr, var(name), symbols); + std::cerr << " used as a pointer" << std::endl; + print_symbol_info(*state); + error = true; + } + } + + void visit(const call* x) override { + if (error) return; + + switch (x->intrinsic) { + case intrinsic::negative_infinity: + case intrinsic::positive_infinity: + case intrinsic::indeterminate: + std::cerr << "Cannot evaluate " << x->intrinsic << std::endl; + error = true; + return; + + case intrinsic::abs: check_arithmetic(x->args[0]); return; + + case intrinsic::buffer_rank: + case intrinsic::buffer_elem_size: + case intrinsic::buffer_size_bytes: + case intrinsic::buffer_base: // We treat pointers to data as arithmetic + if (x->args.size() != 1) { + std::cerr << "Wrong number of arguments for buffer intrinsic " << x->intrinsic << std::endl; + error = true; + return; + } + check_pointer(x->args[0]); + state = arithmetic; + return; + case intrinsic::buffer_min: + case intrinsic::buffer_max: + case intrinsic::buffer_stride: + case intrinsic::buffer_fold_factor: + case intrinsic::buffer_extent: + if (x->args.size() != 2) { + std::cerr << "Wrong number of arguments for buffer intrinsic " << x->intrinsic << std::endl; + error = true; + return; + } + check_pointer(x->args[0]); + check_arithmetic(x->args[1]); + state = arithmetic; + return; + + case intrinsic::buffer_at: + check_pointer(x->args[0]); + for (std::size_t i = 1; i < x->args.size(); ++i) { + check_arithmetic(x->args[i]); + } + state = arithmetic; + return; + } + } + + void visit(const let_stmt* x) override { visit_let(x); } + + void check(const stmt& s, bool required = true) { + if (error) return; + + if (s.defined()) { + s.accept(this); + } else if (required) { + std::cerr << "Undefined statement " << std::endl; + error = true; + } + } + + void visit(const block* x) override { + check(x->a, /*required=*/false); + check(x->b, /*required=*/false); + } + void visit(const loop* x) override { + check_arithmetic(x->bounds.min); + check_arithmetic(x->bounds.max); + auto s = set_value_in_scope(ctx, x->name, {arithmetic, x}); + check(x->body); + } + void visit(const if_then_else* x) override { + check_arithmetic(x->condition); + check(x->true_body, /*required=*/false); + check(x->false_body, /*required=*/false); + } + void visit(const call_func* x) override { + for (var b : x->buffer_args) { + check_pointer(b); + } + } + + void check_arithmetic(const interval_expr& b, bool required = true) { + check_arithmetic(b.min, required); + check_arithmetic(b.max, required); + } + void check_arithmetic(const dim_expr& d, bool required = true) { + check_arithmetic(d.bounds, required); + check_arithmetic(d.stride, required); + check_arithmetic(d.fold_factor, required); + } + + void visit(const allocate* x) override { + for (const dim_expr& i : x->dims) { + check_arithmetic(i); + } + auto s = set_value_in_scope(ctx, x->name, {pointer, x}); + check(x->body); + } + void visit(const make_buffer* x) override { + check_arithmetic(x->base); // We treat pointers to data as arithmetic + check_arithmetic(x->elem_size); + for (const dim_expr& i : x->dims) { + check_arithmetic(i); + } + auto s = set_value_in_scope(ctx, x->name, {pointer, x}); + check(x->body); + } + void visit(const crop_buffer* x) override { + check_pointer(x->name); + for (const interval_expr& i : x->bounds) { + check_arithmetic(i, /*required=*/false); + } + check(x->body); + } + void visit(const crop_dim* x) override { + check_pointer(x->name); + check_arithmetic(x->bounds); + check(x->body); + } + void visit(const slice_buffer* x) override { + check_pointer(x->name); + for (const expr& i : x->at) { + check_arithmetic(i, /*required=*/false); + } + check(x->body); + } + void visit(const slice_dim* x) override { + check_pointer(x->name); + check_arithmetic(x->at); + check(x->body); + } + void visit(const truncate_rank* x) override { check_pointer(x->name); } + void visit(const class check* x) override { check(x->condition); } +}; + +} // namespace + +bool is_valid(const expr& e, span inputs, const node_context* symbols) { + validator v(inputs, symbols); + e.accept(&v); + return !v.error; +} + +bool is_valid(const stmt& s, span inputs, const node_context* symbols) { + validator v(inputs, symbols); + s.accept(&v); + return !v.error; +} + } // namespace slinky diff --git a/runtime/evaluate.h b/runtime/evaluate.h index e467bbe4..2c3fd653 100644 --- a/runtime/evaluate.h +++ b/runtime/evaluate.h @@ -55,6 +55,9 @@ std::optional evaluate_constant(const expr& e); // Returns true if `fn` can be evaluated. bool can_evaluate(intrinsic fn); +bool is_valid(const expr& e, span inputs, const node_context* symbols); +bool is_valid(const stmt& s, span inputs, const node_context* symbols); + } // namespace slinky #endif // SLINKY_RUNTIME_EVALUATE_H From 2857f7ad246162181b43a9a62dbd3003995fe044 Mon Sep 17 00:00:00 2001 From: Dillon Date: Wed, 2 Oct 2024 00:04:26 -0700 Subject: [PATCH 2/2] Fix up and move validator --- builder/BUILD | 1 + builder/pipeline.cc | 14 ++ runtime/BUILD | 8 + runtime/evaluate.cc | 311 ----------------------------- runtime/evaluate.h | 3 - runtime/print.cc | 41 ++-- runtime/print.h | 12 +- runtime/test/BUILD | 10 + runtime/test/validate.cc | 37 ++++ runtime/validate.cc | 418 +++++++++++++++++++++++++++++++++++++++ runtime/validate.h | 15 ++ 11 files changed, 538 insertions(+), 332 deletions(-) create mode 100644 runtime/test/validate.cc create mode 100644 runtime/validate.cc create mode 100644 runtime/validate.h diff --git a/builder/BUILD b/builder/BUILD index a8178b8a..042e4d4e 100644 --- a/builder/BUILD +++ b/builder/BUILD @@ -31,6 +31,7 @@ cc_library( "//base", "//base:chrome_trace", "//runtime", + "//runtime:validate", ], visibility = ["//visibility:public"], ) diff --git a/builder/pipeline.cc b/builder/pipeline.cc index f0baa657..1bb32eb2 100644 --- a/builder/pipeline.cc +++ b/builder/pipeline.cc @@ -23,6 +23,7 @@ #include "runtime/expr.h" #include "runtime/pipeline.h" #include "runtime/print.h" +#include "runtime/validate.h" namespace slinky { @@ -987,6 +988,19 @@ stmt build_pipeline(node_context& ctx, const std::vector& input std::cout << result << std::endl; } + std::vector external; + external.reserve(inputs.size() + outputs.size() + constants.size()); + for (const buffer_expr_ptr& i : inputs) { + external.push_back(i->sym()); + } + for (const buffer_expr_ptr& i : outputs) { + external.push_back(i->sym()); + } + for (const buffer_expr_ptr& i : constants) { + external.push_back(i->sym()); + } + assert(is_valid(result, external, &ctx)); + set_default_print_context(old_context); return result; diff --git a/runtime/BUILD b/runtime/BUILD index 3d2e8bea..b224326f 100644 --- a/runtime/BUILD +++ b/runtime/BUILD @@ -38,3 +38,11 @@ cc_library( deps = [":runtime"], visibility = ["//visibility:public"], ) + +cc_library( + name = "validate", + srcs = ["validate.cc"], + hdrs = ["validate.h"], + deps = [":runtime"], + visibility = ["//visibility:public"], +) diff --git a/runtime/evaluate.cc b/runtime/evaluate.cc index c55941c4..599c0fee 100644 --- a/runtime/evaluate.cc +++ b/runtime/evaluate.cc @@ -856,315 +856,4 @@ class constant_evaluator : public expr_visitor { std::optional evaluate_constant(const expr& e) { return constant_evaluator().eval(e); } -namespace { - -class validator : public expr_visitor, public stmt_visitor { - enum variable_state { - unknown, - pointer, - arithmetic, - }; - - const node_context* symbols; - - struct symbol_info { - variable_state state; - stmt decl_stmt; - expr decl_expr; - - symbol_info(variable_state state) : state(state) {} - symbol_info(variable_state state, stmt s) : state(state), decl_stmt(std::move(s)) {} - symbol_info(variable_state state, expr e) : state(state), decl_expr(std::move(e)) {} - }; - symbol_map ctx; - - variable_state state = unknown; - - void print_symbol_info(const symbol_info& s) { - if (s.decl_stmt.defined()) { - std::cerr << "Declared by:" << std::endl; - print(std::cerr, s.decl_stmt, symbols); - } else if (s.decl_expr.defined()) { - std::cerr << "Declared by:" << std::endl; - print(std::cerr, s.decl_expr, symbols); - } else { - std::cerr << "Externally defined symbol" << std::endl; - } - } - -public: - bool error = false; - - validator(span inputs, const node_context* symbols) : symbols(symbols) { - for (var i : inputs) { - ctx[i] = unknown; - } - } - - void visit(const variable* x) override { - if (!ctx.contains(x->name)) { - std::cerr << "Undefined variable "; - print(std::cerr, x, symbols); - std::cerr << " in context" << std::endl; - error = true; - } - state = unknown; - } - void visit(const constant*) override { state = unknown; } - - template - void visit_let(const T* x) { - std::vector> lets; - lets.reserve(op->lets.size()); - - for (size_t i = 0; i < op->lets.size(); ++i) { - if (!op->lets[i].defined()) { - std::cerr << "Undefined variable "; - print(std::cerr, x, symbols); - std::cerr << " in context" << std::endl; - error = true; - return; - } - } - if (!x->body.defined()) { - std::cerr << "Undefined let body in context" << std::endl; - error = true; - return; - } - x->body.accept(this); - } - - void visit(const let* x) override { visit_let(x); } - - void check_arithmetic(const expr& x, bool required = true) { - if (error) return; - - if (x.defined()) { - x.accept(this); - if (state == pointer) { - std::cerr << "Arithmetic on pointer value: "; - print(std::cerr, x, symbols); - std::cerr << std::endl << "In context:" << std::endl; - error = true; - } - } else if (required) { - std::cerr << "Undefined expression in context:" << std::endl; - } - } - - void check(const expr& x) { - if (error) return; - x.accept(this); - } - - template - void visit_binary_arithmetic(const T* x) { - check_arithmetic(x->a); - check_arithmetic(x->b); - } - - template - void visit_binary(const T* x) { - check(x->a); - check(x->b); - } - - void visit(const add* x) override { visit_binary_arithmetic(x); } - void visit(const sub* x) override { visit_binary_arithmetic(x); } - void visit(const mul* x) override { visit_binary_arithmetic(x); } - void visit(const div* x) override { visit_binary_arithmetic(x); } - void visit(const mod* x) override { visit_binary_arithmetic(x); } - void visit(const class min* x) override { visit_binary_arithmetic(x); } - void visit(const class max* x) override { visit_binary_arithmetic(x); } - void visit(const equal* x) override { visit_binary(x); } - void visit(const not_equal* x) override { visit_binary(x); } - void visit(const less* x) override { visit_binary_arithmetic(x); } - void visit(const less_equal* x) override { visit_binary_arithmetic(x); } - void visit(const logical_and* x) override { visit_binary(x); } - void visit(const logical_or* x) override { visit_binary(x); } - void visit(const logical_not* x) override { check(x->x); } - void visit(const class select* x) override { - check(x->condition); - check(x->true_value); - check(x->false_value); - } - - void check_pointer(const expr& x) { - if (error) return; - - if (x.defined()) { - x.accept(this); - if (state == arithmetic) { - std::cerr << "Expression " << x << " is arithmetic, expected pointer" << std::endl; - std::cerr << std::endl << "In context:" << std::endl; - error = true; - } - } else { - std::cerr << "Undefined expression in context:" << std::endl; - } - } - - void check_pointer(var name) { - std::optional state = ctx.lookup(name); - if (state && state->state == arithmetic) { - std::cerr << "Arithmetic symbol "; - print(std::cerr, var(name), symbols); - std::cerr << " used as a pointer" << std::endl; - print_symbol_info(*state); - error = true; - } - } - - void visit(const call* x) override { - if (error) return; - - switch (x->intrinsic) { - case intrinsic::negative_infinity: - case intrinsic::positive_infinity: - case intrinsic::indeterminate: - std::cerr << "Cannot evaluate " << x->intrinsic << std::endl; - error = true; - return; - - case intrinsic::abs: check_arithmetic(x->args[0]); return; - - case intrinsic::buffer_rank: - case intrinsic::buffer_elem_size: - case intrinsic::buffer_size_bytes: - case intrinsic::buffer_base: // We treat pointers to data as arithmetic - if (x->args.size() != 1) { - std::cerr << "Wrong number of arguments for buffer intrinsic " << x->intrinsic << std::endl; - error = true; - return; - } - check_pointer(x->args[0]); - state = arithmetic; - return; - case intrinsic::buffer_min: - case intrinsic::buffer_max: - case intrinsic::buffer_stride: - case intrinsic::buffer_fold_factor: - case intrinsic::buffer_extent: - if (x->args.size() != 2) { - std::cerr << "Wrong number of arguments for buffer intrinsic " << x->intrinsic << std::endl; - error = true; - return; - } - check_pointer(x->args[0]); - check_arithmetic(x->args[1]); - state = arithmetic; - return; - - case intrinsic::buffer_at: - check_pointer(x->args[0]); - for (std::size_t i = 1; i < x->args.size(); ++i) { - check_arithmetic(x->args[i]); - } - state = arithmetic; - return; - } - } - - void visit(const let_stmt* x) override { visit_let(x); } - - void check(const stmt& s, bool required = true) { - if (error) return; - - if (s.defined()) { - s.accept(this); - } else if (required) { - std::cerr << "Undefined statement " << std::endl; - error = true; - } - } - - void visit(const block* x) override { - check(x->a, /*required=*/false); - check(x->b, /*required=*/false); - } - void visit(const loop* x) override { - check_arithmetic(x->bounds.min); - check_arithmetic(x->bounds.max); - auto s = set_value_in_scope(ctx, x->name, {arithmetic, x}); - check(x->body); - } - void visit(const if_then_else* x) override { - check_arithmetic(x->condition); - check(x->true_body, /*required=*/false); - check(x->false_body, /*required=*/false); - } - void visit(const call_func* x) override { - for (var b : x->buffer_args) { - check_pointer(b); - } - } - - void check_arithmetic(const interval_expr& b, bool required = true) { - check_arithmetic(b.min, required); - check_arithmetic(b.max, required); - } - void check_arithmetic(const dim_expr& d, bool required = true) { - check_arithmetic(d.bounds, required); - check_arithmetic(d.stride, required); - check_arithmetic(d.fold_factor, required); - } - - void visit(const allocate* x) override { - for (const dim_expr& i : x->dims) { - check_arithmetic(i); - } - auto s = set_value_in_scope(ctx, x->name, {pointer, x}); - check(x->body); - } - void visit(const make_buffer* x) override { - check_arithmetic(x->base); // We treat pointers to data as arithmetic - check_arithmetic(x->elem_size); - for (const dim_expr& i : x->dims) { - check_arithmetic(i); - } - auto s = set_value_in_scope(ctx, x->name, {pointer, x}); - check(x->body); - } - void visit(const crop_buffer* x) override { - check_pointer(x->name); - for (const interval_expr& i : x->bounds) { - check_arithmetic(i, /*required=*/false); - } - check(x->body); - } - void visit(const crop_dim* x) override { - check_pointer(x->name); - check_arithmetic(x->bounds); - check(x->body); - } - void visit(const slice_buffer* x) override { - check_pointer(x->name); - for (const expr& i : x->at) { - check_arithmetic(i, /*required=*/false); - } - check(x->body); - } - void visit(const slice_dim* x) override { - check_pointer(x->name); - check_arithmetic(x->at); - check(x->body); - } - void visit(const truncate_rank* x) override { check_pointer(x->name); } - void visit(const class check* x) override { check(x->condition); } -}; - -} // namespace - -bool is_valid(const expr& e, span inputs, const node_context* symbols) { - validator v(inputs, symbols); - e.accept(&v); - return !v.error; -} - -bool is_valid(const stmt& s, span inputs, const node_context* symbols) { - validator v(inputs, symbols); - s.accept(&v); - return !v.error; -} - } // namespace slinky diff --git a/runtime/evaluate.h b/runtime/evaluate.h index 2c3fd653..e467bbe4 100644 --- a/runtime/evaluate.h +++ b/runtime/evaluate.h @@ -55,9 +55,6 @@ std::optional evaluate_constant(const expr& e); // Returns true if `fn` can be evaluated. bool can_evaluate(intrinsic fn); -bool is_valid(const expr& e, span inputs, const node_context* symbols); -bool is_valid(const stmt& s, span inputs, const node_context* symbols); - } // namespace slinky #endif // SLINKY_RUNTIME_EVALUATE_H diff --git a/runtime/print.cc b/runtime/print.cc index 4693b115..8d1413bb 100644 --- a/runtime/print.cc +++ b/runtime/print.cc @@ -8,8 +8,6 @@ namespace slinky { -std::string to_string(var sym) { return "<" + std::to_string(sym.id) + ">"; } - std::string to_string(memory_type type) { switch (type) { case memory_type::stack: return "stack"; @@ -45,9 +43,30 @@ std::string to_string(intrinsic fn) { } } -std::ostream& operator<<(std::ostream& os, var sym) { return os << to_string(sym); } +std::string to_string(stmt_node_type type) { + switch (type) { + case stmt_node_type::call_stmt: return "call_stmt"; + case stmt_node_type::copy_stmt: return "copy_stmt"; + case stmt_node_type::let_stmt: return "let_stmt"; + case stmt_node_type::block: return "block"; + case stmt_node_type::loop: return "loop"; + case stmt_node_type::allocate: return "allocate"; + case stmt_node_type::make_buffer: return "make_buffer"; + case stmt_node_type::clone_buffer: return "clone_buffer"; + case stmt_node_type::crop_buffer: return "crop_buffer"; + case stmt_node_type::crop_dim: return "crop_dim"; + case stmt_node_type::slice_buffer: return "slice_buffer"; + case stmt_node_type::slice_dim: return "slice_dim"; + case stmt_node_type::transpose: return "transpose"; + case stmt_node_type::check: return "check"; + + default: return ""; + } +} + std::ostream& operator<<(std::ostream& os, memory_type type) { return os << to_string(type); } std::ostream& operator<<(std::ostream& os, intrinsic fn) { return os << to_string(fn); } +std::ostream& operator<<(std::ostream& os, stmt_node_type type) { return os << to_string(type); } std::ostream& operator<<(std::ostream& os, const interval_expr& i) { return os << "[" << i.min << ", " << i.max << "]"; @@ -84,7 +103,7 @@ class printer : public expr_visitor, public stmt_visitor { if (context) { os << context->name(sym); } else { - os << sym; + os << "<" << sym.id << ">"; } return *this; } @@ -328,6 +347,11 @@ void print(std::ostream& os, const stmt& s, const node_context* ctx) { p << s; } +void print(std::ostream& os, var sym, const node_context* ctx) { + printer p(os, ctx ? ctx : default_context); + p << sym; +} + std::ostream& operator<<(std::ostream& os, const expr& e) { print(os, e); return os; @@ -338,13 +362,8 @@ std::ostream& operator<<(std::ostream& os, const stmt& s) { return os; } -std::ostream& operator<<(std::ostream& os, const std::tuple& e) { - print(os, std::get<0>(e), &std::get<1>(e)); - return os; -} - -std::ostream& operator<<(std::ostream& os, const std::tuple& s) { - print(os, std::get<0>(s), &std::get<1>(s)); +std::ostream& operator<<(std::ostream& os, var sym) { + print(os, sym); return os; } diff --git a/runtime/print.h b/runtime/print.h index c68b9b9e..6da58cc4 100644 --- a/runtime/print.h +++ b/runtime/print.h @@ -11,19 +11,17 @@ namespace slinky { void print(std::ostream& os, const expr& e, const node_context* ctx = nullptr); void print(std::ostream& os, const stmt& s, const node_context* ctx = nullptr); +void print(std::ostream& os, var sym, const node_context* ctx = nullptr); +std::ostream& operator<<(std::ostream& os, var sym); std::ostream& operator<<(std::ostream& os, const expr& e); std::ostream& operator<<(std::ostream& os, const stmt& s); - -// Enables std::cout << std::tie(expr, ctx) << ... -std::ostream& operator<<(std::ostream& os, const std::tuple& e); -std::ostream& operator<<(std::ostream& os, const std::tuple& s); - -std::ostream& operator<<(std::ostream& os, var sym); std::ostream& operator<<(std::ostream& os, const interval_expr& i); std::ostream& operator<<(std::ostream& os, const box_expr& i); + std::ostream& operator<<(std::ostream& os, intrinsic fn); std::ostream& operator<<(std::ostream& os, memory_type type); +std::ostream& operator<<(std::ostream& os, stmt_node_type type); template std::ostream& operator<<(std::ostream& os, const modulus_remainder& i) { @@ -36,9 +34,9 @@ std::ostream& operator<<(std::ostream& os, const dim& d); // It's not legal to overload std::to_string(), or anything else in std; // intended usage here is to do `using std::to_string;` followed by naked // to_string() calls. -std::string to_string(var sym); std::string to_string(intrinsic fn); std::string to_string(memory_type type); +std::string to_string(stmt_node_type type); } // namespace slinky diff --git a/runtime/test/BUILD b/runtime/test/BUILD index 48aa57a7..e08c5523 100644 --- a/runtime/test/BUILD +++ b/runtime/test/BUILD @@ -35,6 +35,16 @@ cc_test( size = "small", ) +cc_test( + name = "validate", + srcs = ["validate.cc"], + deps = [ + "//runtime:validate", + "@googletest//:gtest_main", + ], + size = "small", +) + cc_test( name = "buffer_benchmark", srcs = ["buffer_benchmark.cc"], diff --git a/runtime/test/validate.cc b/runtime/test/validate.cc new file mode 100644 index 00000000..42b7291e --- /dev/null +++ b/runtime/test/validate.cc @@ -0,0 +1,37 @@ +#include + +#include + +#include "runtime/validate.h" +#include "runtime/expr.h" + +namespace slinky { + +namespace { + +node_context ctx; +var x(ctx, "x"); +var y(ctx, "y"); +var z(ctx, "z"); + +} // namespace + +TEST(validate, var) { + std::vector vars = {x, y}; + ASSERT_TRUE(is_valid(x, vars, &ctx)); + ASSERT_TRUE(is_valid(y, vars, &ctx)); + ASSERT_FALSE(is_valid(z, vars, &ctx)); +} + +TEST(validate, buffer) { + std::vector vars = {x, y}; + ASSERT_FALSE(is_valid(let::make(z, x + y, buffer_max(z, 0)), vars, &ctx)); + ASSERT_TRUE(is_valid(allocate::make(z, memory_type::heap, 1, {}, check::make(buffer_max(z, 0))), vars, &ctx)); +} + +TEST(validate, out_of_scope) { + std::vector vars = {x, y}; + ASSERT_FALSE(is_valid(block::make({let_stmt::make(z, x + y, check::make(z)), check::make(z)}), vars, &ctx)); +} + +} // namespace slinky diff --git a/runtime/validate.cc b/runtime/validate.cc new file mode 100644 index 00000000..367a096f --- /dev/null +++ b/runtime/validate.cc @@ -0,0 +1,418 @@ +#include "runtime/validate.h" + +#include "base/span.h" +#include "runtime/depends_on.h" +#include "runtime/expr.h" +#include "runtime/print.h" +#include "runtime/stmt.h" + +namespace slinky { + +namespace { + +class validator : public expr_visitor, public stmt_visitor { + enum variable_state { + unknown, + pointer, + arithmetic, + }; + + const node_context* symbols; + + struct symbol_info { + variable_state state; + stmt decl_stmt; + expr decl_expr; + + symbol_info(variable_state state) : state(state) {} + symbol_info(variable_state state, stmt s) : state(state), decl_stmt(std::move(s)) {} + symbol_info(variable_state state, expr e) : state(state), decl_expr(std::move(e)) {} + }; + symbol_map ctx; + + variable_state state = unknown; + + void print_symbol_info(const symbol_info& s) { + if (s.decl_stmt.defined()) { + std::cerr << "Declared by:" << std::endl; + print(std::cerr, s.decl_stmt, symbols); + } else if (s.decl_expr.defined()) { + std::cerr << "Declared by:" << std::endl; + print(std::cerr, s.decl_expr, symbols); + } else { + std::cerr << "Externally defined symbol" << std::endl; + } + } + +public: + bool error = false; + + validator(span external, const node_context* symbols) : symbols(symbols) { + for (var i : external) { + ctx[i] = unknown; + } + } + + void visit(const variable* x) override { + std::optional x_state = ctx.lookup(x->sym); + if (!x_state) { + std::cerr << "Undefined variable "; + print(std::cerr, x, symbols); + std::cerr << " in context" << std::endl; + error = true; + } + state = x_state->state; + } + void visit(const constant*) override { state = unknown; } + + template + void visit_let(const T* x) { + std::vector> lets; + lets.reserve(x->lets.size()); + + for (size_t i = 0; i < x->lets.size(); ++i) { + check(x->lets[i].second); + lets.push_back(set_value_in_scope(ctx, x->lets[i].first, symbol_info(state, x))); + } + if (!x->body.defined()) { + std::cerr << "Undefined let body in context" << std::endl; + error = true; + return; + } + x->body.accept(this); + } + + void visit(const let* x) override { visit_let(x); } + + void check_arithmetic(const expr& x, bool required = true) { + if (error) return; + + if (x.defined()) { + x.accept(this); + if (state == pointer) { + std::cerr << "Arithmetic on pointer value: "; + print(std::cerr, x, symbols); + std::cerr << std::endl << "In context:" << std::endl; + error = true; + return; + } + } else if (required) { + std::cerr << "Undefined arithmetic expression in context:" << std::endl; + error = true; + return; + } + state = arithmetic; + } + + void check(const expr& x, bool required = true) { + if (error) return; + if (x.defined()) { + x.accept(this); + } else if (required) { + std::cerr << "Undefined expression in context:" << std::endl; + error = true; + } + } + + template + void visit_binary_arithmetic(const T* x) { + check_arithmetic(x->a, false); + check_arithmetic(x->b, false); + } + + template + void visit_binary(const T* x) { + check(x->a, false); + check(x->b, false); + } + + void visit(const add* x) override { visit_binary_arithmetic(x); } + void visit(const sub* x) override { visit_binary_arithmetic(x); } + void visit(const mul* x) override { visit_binary_arithmetic(x); } + void visit(const div* x) override { visit_binary_arithmetic(x); } + void visit(const mod* x) override { visit_binary_arithmetic(x); } + void visit(const class min* x) override { visit_binary_arithmetic(x); } + void visit(const class max* x) override { visit_binary_arithmetic(x); } + void visit(const equal* x) override { visit_binary(x); } + void visit(const not_equal* x) override { visit_binary(x); } + void visit(const less* x) override { visit_binary_arithmetic(x); } + void visit(const less_equal* x) override { visit_binary_arithmetic(x); } + void visit(const logical_and* x) override { visit_binary(x); } + void visit(const logical_or* x) override { visit_binary(x); } + void visit(const logical_not* x) override { check(x->a); } + void visit(const class select* x) override { + check(x->condition); + check(x->true_value, false); + check(x->false_value, false); + } + + void check_pointer(const expr& x) { + if (error) return; + + if (x.defined()) { + x.accept(this); + if (state == arithmetic) { + std::cerr << "Expression " << x << " is arithmetic, expected pointer" << std::endl; + std::cerr << std::endl << "In context:" << std::endl; + error = true; + } + } else { + std::cerr << "Undefined pointer expression in context:" << std::endl; + } + } + + void check_pointer(var sym) { + std::optional state = ctx.lookup(sym); + if (state && state->state == arithmetic) { + std::cerr << "Arithmetic symbol "; + print(std::cerr, var(sym), symbols); + std::cerr << " used as a pointer" << std::endl; + print_symbol_info(*state); + error = true; + } + state = pointer; + } + + void visit(const call* x) override { + if (error) return; + + switch (x->intrinsic) { + case intrinsic::negative_infinity: + case intrinsic::positive_infinity: + case intrinsic::indeterminate: + std::cerr << "Cannot evaluate " << x->intrinsic << std::endl; + error = true; + return; + + case intrinsic::abs: check_arithmetic(x->args[0]); return; + + case intrinsic::buffer_rank: + case intrinsic::buffer_elem_size: + case intrinsic::buffer_size_bytes: + if (x->args.size() != 1) { + std::cerr << "Wrong number of arguments for buffer intrinsic " << x->intrinsic << std::endl; + error = true; + return; + } + check_pointer(x->args[0]); + state = arithmetic; + return; + case intrinsic::buffer_min: + case intrinsic::buffer_max: + case intrinsic::buffer_stride: + case intrinsic::buffer_fold_factor: + if (x->args.size() != 2) { + std::cerr << "Wrong number of arguments for buffer intrinsic " << x->intrinsic << std::endl; + error = true; + return; + } + check_pointer(x->args[0]); + check_arithmetic(x->args[1]); + state = arithmetic; + return; + + case intrinsic::buffer_at: + check_pointer(x->args[0]); + for (size_t i = 1; i < x->args.size(); ++i) { + check_arithmetic(x->args[i], false); + } + state = arithmetic; + return; + case intrinsic::free: + if (x->args.size() != 1) { + std::cerr << "Wrong number of arguments for buffer intrinsic " << x->intrinsic << std::endl; + error = true; + return; + } + check_pointer(x->args[0]); + state = arithmetic; + return; + case intrinsic::and_then: + case intrinsic::or_else: + for (const expr& i : x->args) { + check_arithmetic(i); + } + state = arithmetic; + return; + case intrinsic::define_undef: + if (x->args.size() != 2) { + std::cerr << "Wrong number of arguments for buffer intrinsic " << x->intrinsic << std::endl; + error = true; + return; + } + check(x->args[0]); + check(x->args[1]); + return; + case intrinsic::semaphore_init: + if (x->args.size() != 2) { + std::cerr << "Wrong number of arguments for buffer intrinsic " << x->intrinsic << std::endl; + error = true; + return; + } + check(x->args[0]); + check_arithmetic(x->args[1], /*required=*/false); + state = arithmetic; + return; + case intrinsic::semaphore_wait: + case intrinsic::semaphore_signal: + if (x->args.size() % 2 != 0) { + std::cerr << "Wrong number of arguments for buffer intrinsic " << x->intrinsic << std::endl; + error = true; + return; + } + for (std::size_t i = 0; i < x->args.size(); i += 2) { + check(x->args[i]); + check_arithmetic(x->args[i + 1], /*required=*/false); + } + return; + case intrinsic::trace_begin: + if (x->args.size() != 1) { + std::cerr << "Wrong number of arguments for buffer intrinsic " << x->intrinsic << std::endl; + error = true; + return; + } + return; + case intrinsic::trace_end: + if (x->args.size() != 1) { + std::cerr << "Wrong number of arguments for buffer intrinsic " << x->intrinsic << std::endl; + error = true; + return; + } + state = arithmetic; + return; + } + } + + void visit(const let_stmt* x) override { visit_let(x); } + + void check(const stmt& s, bool required = true) { + if (s.defined()) { + s.accept(this); + } else if (required) { + std::cerr << "Undefined statement " << std::endl; + error = true; + } + if (error) { + std::cerr << " " << s.type() << std::endl; + } + } + + void check(const stmt& s, var sym, bool required = true) { + if (s.defined()) { + s.accept(this); + } else if (required) { + std::cerr << "Undefined statement " << std::endl; + error = true; + } + if (error) { + std::cerr << " " << s.type() << " " << sym << std::endl; + } + } + + void visit(const block* x) override { + for (const stmt& i : x->stmts) { + check(i); + } + } + void visit(const loop* x) override { + check_arithmetic(x->bounds.min); + check_arithmetic(x->bounds.max); + check_arithmetic(x->step); + auto s = set_value_in_scope(ctx, x->sym, {arithmetic, x}); + check(x->body); + } + void visit(const call_stmt* x) override { + for (var b : x->inputs) { + check_pointer(b); + } + for (var b : x->outputs) { + check_pointer(b); + } + } + void visit(const copy_stmt* x) override { + check_pointer(x->src); + check_pointer(x->dst); + } + + void check_arithmetic(const interval_expr& b, bool required = true) { + check_arithmetic(b.min, required); + check_arithmetic(b.max, required); + } + + void visit(const allocate* x) override { + for (const dim_expr& d : x->dims) { + check_arithmetic(d.bounds, /*required=*/true); + check_arithmetic(d.stride, /*required=*/false); + check_arithmetic(d.fold_factor, /*required=*/false); + } + auto s = set_value_in_scope(ctx, x->sym, {pointer, x}); + check(x->body, x->sym); + } + void visit(const make_buffer* x) override { + check_arithmetic(x->base); // We treat pointers to data as arithmetic + check_arithmetic(x->elem_size); + for (const dim_expr& d : x->dims) { + check_arithmetic(d.bounds, /*required=*/true); + check_arithmetic(d.stride, /*required=*/true); + check_arithmetic(d.fold_factor, /*required=*/false); + } + auto s = set_value_in_scope(ctx, x->sym, {pointer, x}); + check(x->body, x->sym); + } + + void visit(const crop_buffer* x) override { + check_pointer(x->src); + for (const interval_expr& i : x->bounds) { + check_arithmetic(i, /*required=*/false); + } + auto s = set_value_in_scope(ctx, x->sym, {pointer, x}); + check(x->body, x->sym); + } + void visit(const crop_dim* x) override { + check_pointer(x->src); + check_arithmetic(x->bounds); + auto s = set_value_in_scope(ctx, x->sym, {pointer, x}); + check(x->body, x->sym); + } + void visit(const slice_buffer* x) override { + check_pointer(x->src); + for (const expr& i : x->at) { + check_arithmetic(i, /*required=*/false); + } + auto s = set_value_in_scope(ctx, x->sym, {pointer, x}); + check(x->body, x->sym); + } + void visit(const slice_dim* x) override { + check_pointer(x->src); + check_arithmetic(x->at); + auto s = set_value_in_scope(ctx, x->sym, {pointer, x}); + check(x->body, x->sym); + } + void visit(const transpose* x) override { + check_pointer(x->src); + auto s = set_value_in_scope(ctx, x->sym, {pointer, x}); + check(x->body, x->sym); + } + void visit(const clone_buffer* x) override { + check_pointer(x->src); + auto s = set_value_in_scope(ctx, x->sym, {pointer, x}); + check(x->body, x->sym); + } + void visit(const class check* x) override { check(x->condition); } +}; + +} // namespace + +bool is_valid(const expr& e, span external, const node_context* symbols) { + validator v(external, symbols); + e.accept(&v); + return !v.error; +} + +bool is_valid(const stmt& s, span external, const node_context* symbols) { + validator v(external, symbols); + s.accept(&v); + return !v.error; +} + +} // namespace slinky diff --git a/runtime/validate.h b/runtime/validate.h new file mode 100644 index 00000000..1336619e --- /dev/null +++ b/runtime/validate.h @@ -0,0 +1,15 @@ +#ifndef SLINKY_RUNTIME_VALIDATE_H +#define SLINKY_RUNTIME_VALIDATE_H + +#include "base/span.h" +#include "runtime/expr.h" +#include "runtime/stmt.h" + +namespace slinky { + +bool is_valid(const expr& e, span external, const node_context* symbols); +bool is_valid(const stmt& s, span external, const node_context* symbols); + +} // namespace slinky + +#endif // SLINKY_RUNTIME_VALIDATE_H