Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add validator #459

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions builder/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ cc_library(
"//base",
"//base:chrome_trace",
"//runtime",
"//runtime:validate",
],
visibility = ["//visibility:public"],
)
Expand Down
14 changes: 14 additions & 0 deletions builder/pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "runtime/expr.h"
#include "runtime/pipeline.h"
#include "runtime/print.h"
#include "runtime/validate.h"

namespace slinky {

Expand Down Expand Up @@ -987,6 +988,19 @@ stmt build_pipeline(node_context& ctx, const std::vector<buffer_expr_ptr>& input
std::cout << result << std::endl;
}

std::vector<var> external;
external.reserve(inputs.size() + outputs.size() + constants.size());
for (const buffer_expr_ptr& i : inputs) {
external.push_back(i->sym());
}
for (const buffer_expr_ptr& i : outputs) {
external.push_back(i->sym());
}
for (const buffer_expr_ptr& i : constants) {
external.push_back(i->sym());
}
assert(is_valid(result, external, &ctx));

set_default_print_context(old_context);

return result;
Expand Down
8 changes: 8 additions & 0 deletions runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,11 @@ cc_library(
deps = [":runtime"],
visibility = ["//visibility:public"],
)

cc_library(
name = "validate",
srcs = ["validate.cc"],
hdrs = ["validate.h"],
deps = [":runtime"],
visibility = ["//visibility:public"],
)
41 changes: 30 additions & 11 deletions runtime/print.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

namespace slinky {

std::string to_string(var sym) { return "<" + std::to_string(sym.id) + ">"; }

std::string to_string(memory_type type) {
switch (type) {
case memory_type::stack: return "stack";
Expand Down Expand Up @@ -45,9 +43,30 @@ std::string to_string(intrinsic fn) {
}
}

std::ostream& operator<<(std::ostream& os, var sym) { return os << to_string(sym); }
std::string to_string(stmt_node_type type) {
switch (type) {
case stmt_node_type::call_stmt: return "call_stmt";
case stmt_node_type::copy_stmt: return "copy_stmt";
case stmt_node_type::let_stmt: return "let_stmt";
case stmt_node_type::block: return "block";
case stmt_node_type::loop: return "loop";
case stmt_node_type::allocate: return "allocate";
case stmt_node_type::make_buffer: return "make_buffer";
case stmt_node_type::clone_buffer: return "clone_buffer";
case stmt_node_type::crop_buffer: return "crop_buffer";
case stmt_node_type::crop_dim: return "crop_dim";
case stmt_node_type::slice_buffer: return "slice_buffer";
case stmt_node_type::slice_dim: return "slice_dim";
case stmt_node_type::transpose: return "transpose";
case stmt_node_type::check: return "check";

default: return "<invalid stmt_node_type>";
}
}

std::ostream& operator<<(std::ostream& os, memory_type type) { return os << to_string(type); }
std::ostream& operator<<(std::ostream& os, intrinsic fn) { return os << to_string(fn); }
std::ostream& operator<<(std::ostream& os, stmt_node_type type) { return os << to_string(type); }

std::ostream& operator<<(std::ostream& os, const interval_expr& i) {
return os << "[" << i.min << ", " << i.max << "]";
Expand Down Expand Up @@ -84,7 +103,7 @@ class printer : public expr_visitor, public stmt_visitor {
if (context) {
os << context->name(sym);
} else {
os << sym;
os << "<" << sym.id << ">";
}
return *this;
}
Expand Down Expand Up @@ -328,6 +347,11 @@ void print(std::ostream& os, const stmt& s, const node_context* ctx) {
p << s;
}

void print(std::ostream& os, var sym, const node_context* ctx) {
printer p(os, ctx ? ctx : default_context);
p << sym;
}

std::ostream& operator<<(std::ostream& os, const expr& e) {
print(os, e);
return os;
Expand All @@ -338,13 +362,8 @@ std::ostream& operator<<(std::ostream& os, const stmt& s) {
return os;
}

std::ostream& operator<<(std::ostream& os, const std::tuple<const expr&, const node_context&>& e) {
print(os, std::get<0>(e), &std::get<1>(e));
return os;
}

std::ostream& operator<<(std::ostream& os, const std::tuple<const stmt&, const node_context&>& s) {
print(os, std::get<0>(s), &std::get<1>(s));
std::ostream& operator<<(std::ostream& os, var sym) {
print(os, sym);
return os;
}

Expand Down
12 changes: 5 additions & 7 deletions runtime/print.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,17 @@ namespace slinky {

void print(std::ostream& os, const expr& e, const node_context* ctx = nullptr);
void print(std::ostream& os, const stmt& s, const node_context* ctx = nullptr);
void print(std::ostream& os, var sym, const node_context* ctx = nullptr);

std::ostream& operator<<(std::ostream& os, var sym);
std::ostream& operator<<(std::ostream& os, const expr& e);
std::ostream& operator<<(std::ostream& os, const stmt& s);

// Enables std::cout << std::tie(expr, ctx) << ...
std::ostream& operator<<(std::ostream& os, const std::tuple<const expr&, const node_context&>& e);
std::ostream& operator<<(std::ostream& os, const std::tuple<const stmt&, const node_context&>& s);

std::ostream& operator<<(std::ostream& os, var sym);
std::ostream& operator<<(std::ostream& os, const interval_expr& i);
std::ostream& operator<<(std::ostream& os, const box_expr& i);

std::ostream& operator<<(std::ostream& os, intrinsic fn);
std::ostream& operator<<(std::ostream& os, memory_type type);
std::ostream& operator<<(std::ostream& os, stmt_node_type type);

template <typename T>
std::ostream& operator<<(std::ostream& os, const modulus_remainder<T>& i) {
Expand All @@ -36,9 +34,9 @@ std::ostream& operator<<(std::ostream& os, const dim& d);
// It's not legal to overload std::to_string(), or anything else in std;
// intended usage here is to do `using std::to_string;` followed by naked
// to_string() calls.
std::string to_string(var sym);
std::string to_string(intrinsic fn);
std::string to_string(memory_type type);
std::string to_string(stmt_node_type type);

} // namespace slinky

Expand Down
10 changes: 10 additions & 0 deletions runtime/test/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,16 @@ cc_test(
size = "small",
)

cc_test(
name = "validate",
srcs = ["validate.cc"],
deps = [
"//runtime:validate",
"@googletest//:gtest_main",
],
size = "small",
)

cc_test(
name = "buffer_benchmark",
srcs = ["buffer_benchmark.cc"],
Expand Down
37 changes: 37 additions & 0 deletions runtime/test/validate.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#include <gtest/gtest.h>

#include <cassert>

#include "runtime/validate.h"
#include "runtime/expr.h"

namespace slinky {

namespace {

node_context ctx;
var x(ctx, "x");
var y(ctx, "y");
var z(ctx, "z");

} // namespace

TEST(validate, var) {
std::vector<var> vars = {x, y};
ASSERT_TRUE(is_valid(x, vars, &ctx));
ASSERT_TRUE(is_valid(y, vars, &ctx));
ASSERT_FALSE(is_valid(z, vars, &ctx));
}

TEST(validate, buffer) {
std::vector<var> vars = {x, y};
ASSERT_FALSE(is_valid(let::make(z, x + y, buffer_max(z, 0)), vars, &ctx));
ASSERT_TRUE(is_valid(allocate::make(z, memory_type::heap, 1, {}, check::make(buffer_max(z, 0))), vars, &ctx));
}

TEST(validate, out_of_scope) {
std::vector<var> vars = {x, y};
ASSERT_FALSE(is_valid(block::make({let_stmt::make(z, x + y, check::make(z)), check::make(z)}), vars, &ctx));
}

} // namespace slinky
Loading
Loading