diff --git a/builder/BUILD b/builder/BUILD index a8178b8a..042e4d4e 100644 --- a/builder/BUILD +++ b/builder/BUILD @@ -31,6 +31,7 @@ cc_library( "//base", "//base:chrome_trace", "//runtime", + "//runtime:validate", ], visibility = ["//visibility:public"], ) diff --git a/builder/pipeline.cc b/builder/pipeline.cc index f0baa657..1bb32eb2 100644 --- a/builder/pipeline.cc +++ b/builder/pipeline.cc @@ -23,6 +23,7 @@ #include "runtime/expr.h" #include "runtime/pipeline.h" #include "runtime/print.h" +#include "runtime/validate.h" namespace slinky { @@ -987,6 +988,19 @@ stmt build_pipeline(node_context& ctx, const std::vector& input std::cout << result << std::endl; } + std::vector external; + external.reserve(inputs.size() + outputs.size() + constants.size()); + for (const buffer_expr_ptr& i : inputs) { + external.push_back(i->sym()); + } + for (const buffer_expr_ptr& i : outputs) { + external.push_back(i->sym()); + } + for (const buffer_expr_ptr& i : constants) { + external.push_back(i->sym()); + } + assert(is_valid(result, external, &ctx)); + set_default_print_context(old_context); return result; diff --git a/runtime/BUILD b/runtime/BUILD index 3d2e8bea..b224326f 100644 --- a/runtime/BUILD +++ b/runtime/BUILD @@ -38,3 +38,11 @@ cc_library( deps = [":runtime"], visibility = ["//visibility:public"], ) + +cc_library( + name = "validate", + srcs = ["validate.cc"], + hdrs = ["validate.h"], + deps = [":runtime"], + visibility = ["//visibility:public"], +) diff --git a/runtime/print.cc b/runtime/print.cc index 4693b115..8d1413bb 100644 --- a/runtime/print.cc +++ b/runtime/print.cc @@ -8,8 +8,6 @@ namespace slinky { -std::string to_string(var sym) { return "<" + std::to_string(sym.id) + ">"; } - std::string to_string(memory_type type) { switch (type) { case memory_type::stack: return "stack"; @@ -45,9 +43,30 @@ std::string to_string(intrinsic fn) { } } -std::ostream& operator<<(std::ostream& os, var sym) { return os << to_string(sym); } +std::string to_string(stmt_node_type type) { + switch (type) { + case stmt_node_type::call_stmt: return "call_stmt"; + case stmt_node_type::copy_stmt: return "copy_stmt"; + case stmt_node_type::let_stmt: return "let_stmt"; + case stmt_node_type::block: return "block"; + case stmt_node_type::loop: return "loop"; + case stmt_node_type::allocate: return "allocate"; + case stmt_node_type::make_buffer: return "make_buffer"; + case stmt_node_type::clone_buffer: return "clone_buffer"; + case stmt_node_type::crop_buffer: return "crop_buffer"; + case stmt_node_type::crop_dim: return "crop_dim"; + case stmt_node_type::slice_buffer: return "slice_buffer"; + case stmt_node_type::slice_dim: return "slice_dim"; + case stmt_node_type::transpose: return "transpose"; + case stmt_node_type::check: return "check"; + + default: return ""; + } +} + std::ostream& operator<<(std::ostream& os, memory_type type) { return os << to_string(type); } std::ostream& operator<<(std::ostream& os, intrinsic fn) { return os << to_string(fn); } +std::ostream& operator<<(std::ostream& os, stmt_node_type type) { return os << to_string(type); } std::ostream& operator<<(std::ostream& os, const interval_expr& i) { return os << "[" << i.min << ", " << i.max << "]"; @@ -84,7 +103,7 @@ class printer : public expr_visitor, public stmt_visitor { if (context) { os << context->name(sym); } else { - os << sym; + os << "<" << sym.id << ">"; } return *this; } @@ -328,6 +347,11 @@ void print(std::ostream& os, const stmt& s, const node_context* ctx) { p << s; } +void print(std::ostream& os, var sym, const node_context* ctx) { + printer p(os, ctx ? ctx : default_context); + p << sym; +} + std::ostream& operator<<(std::ostream& os, const expr& e) { print(os, e); return os; @@ -338,13 +362,8 @@ std::ostream& operator<<(std::ostream& os, const stmt& s) { return os; } -std::ostream& operator<<(std::ostream& os, const std::tuple& e) { - print(os, std::get<0>(e), &std::get<1>(e)); - return os; -} - -std::ostream& operator<<(std::ostream& os, const std::tuple& s) { - print(os, std::get<0>(s), &std::get<1>(s)); +std::ostream& operator<<(std::ostream& os, var sym) { + print(os, sym); return os; } diff --git a/runtime/print.h b/runtime/print.h index c68b9b9e..6da58cc4 100644 --- a/runtime/print.h +++ b/runtime/print.h @@ -11,19 +11,17 @@ namespace slinky { void print(std::ostream& os, const expr& e, const node_context* ctx = nullptr); void print(std::ostream& os, const stmt& s, const node_context* ctx = nullptr); +void print(std::ostream& os, var sym, const node_context* ctx = nullptr); +std::ostream& operator<<(std::ostream& os, var sym); std::ostream& operator<<(std::ostream& os, const expr& e); std::ostream& operator<<(std::ostream& os, const stmt& s); - -// Enables std::cout << std::tie(expr, ctx) << ... -std::ostream& operator<<(std::ostream& os, const std::tuple& e); -std::ostream& operator<<(std::ostream& os, const std::tuple& s); - -std::ostream& operator<<(std::ostream& os, var sym); std::ostream& operator<<(std::ostream& os, const interval_expr& i); std::ostream& operator<<(std::ostream& os, const box_expr& i); + std::ostream& operator<<(std::ostream& os, intrinsic fn); std::ostream& operator<<(std::ostream& os, memory_type type); +std::ostream& operator<<(std::ostream& os, stmt_node_type type); template std::ostream& operator<<(std::ostream& os, const modulus_remainder& i) { @@ -36,9 +34,9 @@ std::ostream& operator<<(std::ostream& os, const dim& d); // It's not legal to overload std::to_string(), or anything else in std; // intended usage here is to do `using std::to_string;` followed by naked // to_string() calls. -std::string to_string(var sym); std::string to_string(intrinsic fn); std::string to_string(memory_type type); +std::string to_string(stmt_node_type type); } // namespace slinky diff --git a/runtime/test/BUILD b/runtime/test/BUILD index 48aa57a7..e08c5523 100644 --- a/runtime/test/BUILD +++ b/runtime/test/BUILD @@ -35,6 +35,16 @@ cc_test( size = "small", ) +cc_test( + name = "validate", + srcs = ["validate.cc"], + deps = [ + "//runtime:validate", + "@googletest//:gtest_main", + ], + size = "small", +) + cc_test( name = "buffer_benchmark", srcs = ["buffer_benchmark.cc"], diff --git a/runtime/test/validate.cc b/runtime/test/validate.cc new file mode 100644 index 00000000..42b7291e --- /dev/null +++ b/runtime/test/validate.cc @@ -0,0 +1,37 @@ +#include + +#include + +#include "runtime/validate.h" +#include "runtime/expr.h" + +namespace slinky { + +namespace { + +node_context ctx; +var x(ctx, "x"); +var y(ctx, "y"); +var z(ctx, "z"); + +} // namespace + +TEST(validate, var) { + std::vector vars = {x, y}; + ASSERT_TRUE(is_valid(x, vars, &ctx)); + ASSERT_TRUE(is_valid(y, vars, &ctx)); + ASSERT_FALSE(is_valid(z, vars, &ctx)); +} + +TEST(validate, buffer) { + std::vector vars = {x, y}; + ASSERT_FALSE(is_valid(let::make(z, x + y, buffer_max(z, 0)), vars, &ctx)); + ASSERT_TRUE(is_valid(allocate::make(z, memory_type::heap, 1, {}, check::make(buffer_max(z, 0))), vars, &ctx)); +} + +TEST(validate, out_of_scope) { + std::vector vars = {x, y}; + ASSERT_FALSE(is_valid(block::make({let_stmt::make(z, x + y, check::make(z)), check::make(z)}), vars, &ctx)); +} + +} // namespace slinky diff --git a/runtime/validate.cc b/runtime/validate.cc new file mode 100644 index 00000000..367a096f --- /dev/null +++ b/runtime/validate.cc @@ -0,0 +1,418 @@ +#include "runtime/validate.h" + +#include "base/span.h" +#include "runtime/depends_on.h" +#include "runtime/expr.h" +#include "runtime/print.h" +#include "runtime/stmt.h" + +namespace slinky { + +namespace { + +class validator : public expr_visitor, public stmt_visitor { + enum variable_state { + unknown, + pointer, + arithmetic, + }; + + const node_context* symbols; + + struct symbol_info { + variable_state state; + stmt decl_stmt; + expr decl_expr; + + symbol_info(variable_state state) : state(state) {} + symbol_info(variable_state state, stmt s) : state(state), decl_stmt(std::move(s)) {} + symbol_info(variable_state state, expr e) : state(state), decl_expr(std::move(e)) {} + }; + symbol_map ctx; + + variable_state state = unknown; + + void print_symbol_info(const symbol_info& s) { + if (s.decl_stmt.defined()) { + std::cerr << "Declared by:" << std::endl; + print(std::cerr, s.decl_stmt, symbols); + } else if (s.decl_expr.defined()) { + std::cerr << "Declared by:" << std::endl; + print(std::cerr, s.decl_expr, symbols); + } else { + std::cerr << "Externally defined symbol" << std::endl; + } + } + +public: + bool error = false; + + validator(span external, const node_context* symbols) : symbols(symbols) { + for (var i : external) { + ctx[i] = unknown; + } + } + + void visit(const variable* x) override { + std::optional x_state = ctx.lookup(x->sym); + if (!x_state) { + std::cerr << "Undefined variable "; + print(std::cerr, x, symbols); + std::cerr << " in context" << std::endl; + error = true; + } + state = x_state->state; + } + void visit(const constant*) override { state = unknown; } + + template + void visit_let(const T* x) { + std::vector> lets; + lets.reserve(x->lets.size()); + + for (size_t i = 0; i < x->lets.size(); ++i) { + check(x->lets[i].second); + lets.push_back(set_value_in_scope(ctx, x->lets[i].first, symbol_info(state, x))); + } + if (!x->body.defined()) { + std::cerr << "Undefined let body in context" << std::endl; + error = true; + return; + } + x->body.accept(this); + } + + void visit(const let* x) override { visit_let(x); } + + void check_arithmetic(const expr& x, bool required = true) { + if (error) return; + + if (x.defined()) { + x.accept(this); + if (state == pointer) { + std::cerr << "Arithmetic on pointer value: "; + print(std::cerr, x, symbols); + std::cerr << std::endl << "In context:" << std::endl; + error = true; + return; + } + } else if (required) { + std::cerr << "Undefined arithmetic expression in context:" << std::endl; + error = true; + return; + } + state = arithmetic; + } + + void check(const expr& x, bool required = true) { + if (error) return; + if (x.defined()) { + x.accept(this); + } else if (required) { + std::cerr << "Undefined expression in context:" << std::endl; + error = true; + } + } + + template + void visit_binary_arithmetic(const T* x) { + check_arithmetic(x->a, false); + check_arithmetic(x->b, false); + } + + template + void visit_binary(const T* x) { + check(x->a, false); + check(x->b, false); + } + + void visit(const add* x) override { visit_binary_arithmetic(x); } + void visit(const sub* x) override { visit_binary_arithmetic(x); } + void visit(const mul* x) override { visit_binary_arithmetic(x); } + void visit(const div* x) override { visit_binary_arithmetic(x); } + void visit(const mod* x) override { visit_binary_arithmetic(x); } + void visit(const class min* x) override { visit_binary_arithmetic(x); } + void visit(const class max* x) override { visit_binary_arithmetic(x); } + void visit(const equal* x) override { visit_binary(x); } + void visit(const not_equal* x) override { visit_binary(x); } + void visit(const less* x) override { visit_binary_arithmetic(x); } + void visit(const less_equal* x) override { visit_binary_arithmetic(x); } + void visit(const logical_and* x) override { visit_binary(x); } + void visit(const logical_or* x) override { visit_binary(x); } + void visit(const logical_not* x) override { check(x->a); } + void visit(const class select* x) override { + check(x->condition); + check(x->true_value, false); + check(x->false_value, false); + } + + void check_pointer(const expr& x) { + if (error) return; + + if (x.defined()) { + x.accept(this); + if (state == arithmetic) { + std::cerr << "Expression " << x << " is arithmetic, expected pointer" << std::endl; + std::cerr << std::endl << "In context:" << std::endl; + error = true; + } + } else { + std::cerr << "Undefined pointer expression in context:" << std::endl; + } + } + + void check_pointer(var sym) { + std::optional state = ctx.lookup(sym); + if (state && state->state == arithmetic) { + std::cerr << "Arithmetic symbol "; + print(std::cerr, var(sym), symbols); + std::cerr << " used as a pointer" << std::endl; + print_symbol_info(*state); + error = true; + } + state = pointer; + } + + void visit(const call* x) override { + if (error) return; + + switch (x->intrinsic) { + case intrinsic::negative_infinity: + case intrinsic::positive_infinity: + case intrinsic::indeterminate: + std::cerr << "Cannot evaluate " << x->intrinsic << std::endl; + error = true; + return; + + case intrinsic::abs: check_arithmetic(x->args[0]); return; + + case intrinsic::buffer_rank: + case intrinsic::buffer_elem_size: + case intrinsic::buffer_size_bytes: + if (x->args.size() != 1) { + std::cerr << "Wrong number of arguments for buffer intrinsic " << x->intrinsic << std::endl; + error = true; + return; + } + check_pointer(x->args[0]); + state = arithmetic; + return; + case intrinsic::buffer_min: + case intrinsic::buffer_max: + case intrinsic::buffer_stride: + case intrinsic::buffer_fold_factor: + if (x->args.size() != 2) { + std::cerr << "Wrong number of arguments for buffer intrinsic " << x->intrinsic << std::endl; + error = true; + return; + } + check_pointer(x->args[0]); + check_arithmetic(x->args[1]); + state = arithmetic; + return; + + case intrinsic::buffer_at: + check_pointer(x->args[0]); + for (size_t i = 1; i < x->args.size(); ++i) { + check_arithmetic(x->args[i], false); + } + state = arithmetic; + return; + case intrinsic::free: + if (x->args.size() != 1) { + std::cerr << "Wrong number of arguments for buffer intrinsic " << x->intrinsic << std::endl; + error = true; + return; + } + check_pointer(x->args[0]); + state = arithmetic; + return; + case intrinsic::and_then: + case intrinsic::or_else: + for (const expr& i : x->args) { + check_arithmetic(i); + } + state = arithmetic; + return; + case intrinsic::define_undef: + if (x->args.size() != 2) { + std::cerr << "Wrong number of arguments for buffer intrinsic " << x->intrinsic << std::endl; + error = true; + return; + } + check(x->args[0]); + check(x->args[1]); + return; + case intrinsic::semaphore_init: + if (x->args.size() != 2) { + std::cerr << "Wrong number of arguments for buffer intrinsic " << x->intrinsic << std::endl; + error = true; + return; + } + check(x->args[0]); + check_arithmetic(x->args[1], /*required=*/false); + state = arithmetic; + return; + case intrinsic::semaphore_wait: + case intrinsic::semaphore_signal: + if (x->args.size() % 2 != 0) { + std::cerr << "Wrong number of arguments for buffer intrinsic " << x->intrinsic << std::endl; + error = true; + return; + } + for (std::size_t i = 0; i < x->args.size(); i += 2) { + check(x->args[i]); + check_arithmetic(x->args[i + 1], /*required=*/false); + } + return; + case intrinsic::trace_begin: + if (x->args.size() != 1) { + std::cerr << "Wrong number of arguments for buffer intrinsic " << x->intrinsic << std::endl; + error = true; + return; + } + return; + case intrinsic::trace_end: + if (x->args.size() != 1) { + std::cerr << "Wrong number of arguments for buffer intrinsic " << x->intrinsic << std::endl; + error = true; + return; + } + state = arithmetic; + return; + } + } + + void visit(const let_stmt* x) override { visit_let(x); } + + void check(const stmt& s, bool required = true) { + if (s.defined()) { + s.accept(this); + } else if (required) { + std::cerr << "Undefined statement " << std::endl; + error = true; + } + if (error) { + std::cerr << " " << s.type() << std::endl; + } + } + + void check(const stmt& s, var sym, bool required = true) { + if (s.defined()) { + s.accept(this); + } else if (required) { + std::cerr << "Undefined statement " << std::endl; + error = true; + } + if (error) { + std::cerr << " " << s.type() << " " << sym << std::endl; + } + } + + void visit(const block* x) override { + for (const stmt& i : x->stmts) { + check(i); + } + } + void visit(const loop* x) override { + check_arithmetic(x->bounds.min); + check_arithmetic(x->bounds.max); + check_arithmetic(x->step); + auto s = set_value_in_scope(ctx, x->sym, {arithmetic, x}); + check(x->body); + } + void visit(const call_stmt* x) override { + for (var b : x->inputs) { + check_pointer(b); + } + for (var b : x->outputs) { + check_pointer(b); + } + } + void visit(const copy_stmt* x) override { + check_pointer(x->src); + check_pointer(x->dst); + } + + void check_arithmetic(const interval_expr& b, bool required = true) { + check_arithmetic(b.min, required); + check_arithmetic(b.max, required); + } + + void visit(const allocate* x) override { + for (const dim_expr& d : x->dims) { + check_arithmetic(d.bounds, /*required=*/true); + check_arithmetic(d.stride, /*required=*/false); + check_arithmetic(d.fold_factor, /*required=*/false); + } + auto s = set_value_in_scope(ctx, x->sym, {pointer, x}); + check(x->body, x->sym); + } + void visit(const make_buffer* x) override { + check_arithmetic(x->base); // We treat pointers to data as arithmetic + check_arithmetic(x->elem_size); + for (const dim_expr& d : x->dims) { + check_arithmetic(d.bounds, /*required=*/true); + check_arithmetic(d.stride, /*required=*/true); + check_arithmetic(d.fold_factor, /*required=*/false); + } + auto s = set_value_in_scope(ctx, x->sym, {pointer, x}); + check(x->body, x->sym); + } + + void visit(const crop_buffer* x) override { + check_pointer(x->src); + for (const interval_expr& i : x->bounds) { + check_arithmetic(i, /*required=*/false); + } + auto s = set_value_in_scope(ctx, x->sym, {pointer, x}); + check(x->body, x->sym); + } + void visit(const crop_dim* x) override { + check_pointer(x->src); + check_arithmetic(x->bounds); + auto s = set_value_in_scope(ctx, x->sym, {pointer, x}); + check(x->body, x->sym); + } + void visit(const slice_buffer* x) override { + check_pointer(x->src); + for (const expr& i : x->at) { + check_arithmetic(i, /*required=*/false); + } + auto s = set_value_in_scope(ctx, x->sym, {pointer, x}); + check(x->body, x->sym); + } + void visit(const slice_dim* x) override { + check_pointer(x->src); + check_arithmetic(x->at); + auto s = set_value_in_scope(ctx, x->sym, {pointer, x}); + check(x->body, x->sym); + } + void visit(const transpose* x) override { + check_pointer(x->src); + auto s = set_value_in_scope(ctx, x->sym, {pointer, x}); + check(x->body, x->sym); + } + void visit(const clone_buffer* x) override { + check_pointer(x->src); + auto s = set_value_in_scope(ctx, x->sym, {pointer, x}); + check(x->body, x->sym); + } + void visit(const class check* x) override { check(x->condition); } +}; + +} // namespace + +bool is_valid(const expr& e, span external, const node_context* symbols) { + validator v(external, symbols); + e.accept(&v); + return !v.error; +} + +bool is_valid(const stmt& s, span external, const node_context* symbols) { + validator v(external, symbols); + s.accept(&v); + return !v.error; +} + +} // namespace slinky diff --git a/runtime/validate.h b/runtime/validate.h new file mode 100644 index 00000000..1336619e --- /dev/null +++ b/runtime/validate.h @@ -0,0 +1,15 @@ +#ifndef SLINKY_RUNTIME_VALIDATE_H +#define SLINKY_RUNTIME_VALIDATE_H + +#include "base/span.h" +#include "runtime/expr.h" +#include "runtime/stmt.h" + +namespace slinky { + +bool is_valid(const expr& e, span external, const node_context* symbols); +bool is_valid(const stmt& s, span external, const node_context* symbols); + +} // namespace slinky + +#endif // SLINKY_RUNTIME_VALIDATE_H