Skip to content

Commit

Permalink
tblgen: Implement SelectIfComplex (#2183)
Browse files Browse the repository at this point in the history
* selectifcomplex

* remove ConjIfComplex

this is unused and can now be implemented using SelectIfComplex

* used primal
  • Loading branch information
Pangoraw authored Nov 29, 2024
1 parent 095ee7e commit 7b97a9b
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 25 deletions.
9 changes: 4 additions & 5 deletions enzyme/Enzyme/MLIR/Implementations/Common.td
Original file line number Diff line number Diff line change
Expand Up @@ -103,18 +103,17 @@ def SelectIfActive : Operation</*primal*/0, /*shadow*/0, /*custom*/1> {

}

def SelectIfComplex : Operation</*primal*/1, /*shadow*/0, /*custom*/0> {

}

class ConstantFP<string val, string dialect_, string op_, string type_=""> : Operation</*primal*/0, /*shadow*/0> {
string value = val;
string dialect = dialect_;
string opName = op_;
string type = type_;
}

class ConjIfComplex<string dialect_, string op_> : Operation</*primal*/1, /*shadow*/0> {
string dialect = dialect_;
string opName = op_;
}

def ResultTypes : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, "op->getResultTypes()">;

def TypeOf : Operation</*primal*/0, /*shadow*/0> {
Expand Down
65 changes: 45 additions & 20 deletions enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -488,45 +488,70 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os,
os << curIndent << INDENT << "imVal;\n";
os << curIndent << "})";
return true;
} else if (opName == "ConjIfComplex" ||
Def->isSubClassOf("ConjIfComplex")) {
if (resultRoot->getNumArgs() != 1)
} else if (opName == "SelectIfComplex" ||
Def->isSubClassOf("SelectIfComplex")) {
if (resultRoot->getNumArgs() != 3)
PrintFatalError(pattern->getLoc(),
"only three op ConjIfComplex supported");
"only three op SelectIfComplex supported");

os << "({\n";
os << curIndent << INDENT << "// Computing ConjIfComplex\n";
os << curIndent << INDENT << "// Computing SelectIfComplex\n";
if (intrinsic == MLIRDerivatives)
os << curIndent << INDENT << "mlir::Value imVal";
os << curIndent << INDENT << "mlir::Value imVal = ";
else
os << curIndent << INDENT << "llvm::Value *imVal";

os << curIndent << INDENT << "if (!gutils->isConstantValue(";
os << curIndent << INDENT << "llvm::Value *imVal = ";

if (isa<UnsetInit>(resultRoot->getArg(0)) && resultRoot->getArgName(0)) {
auto name = resultRoot->getArgName(0)->getAsUnquotedString();
auto [ord, isVec, ext] =
nameToOrdinal.lookup(name, pattern, resultRoot);
os << ord;
assert(!ext.size());
os << ord;
os << ";\n";
os << ord << ";\n";
} else {
handle(curIndent + INDENT + INDENT, argPattern + "_cic", os, pattern,
resultRoot->getArg(0), builder, nameToOrdinal, lookup, retidx,
origName, newFromOriginal, intrinsic);
os << ";\n";
}

os << " (isa<ComplexType>(imVal.getType()) || "
os << curIndent << INDENT
<< "if (isa<ComplexType>(imVal.getType()) || "
"(isa<TensorType>(imVal.getType()) && "
"isa<ComplexType>(cast<TensorType>(imVal.getType()).getElementType("
")))) ? ";
os << builder << ".create<"
<< cast<StringInit>(Def->getValueInit("dialect"))->getValue()
<< "::" << cast<StringInit>(Def->getValueInit("opName"))->getValue()
<< ">(op.getLoc(), imVal.getType(), imVal) : imVal;\n";
os << curIndent << "})";
")))) {\n";

os << curIndent << INDENT << INDENT << "imVal = ";
if (isa<UnsetInit>(resultRoot->getArg(1)) && resultRoot->getArgName(1)) {
auto name = resultRoot->getArgName(1)->getAsUnquotedString();
auto [ord, isVec, ext] =
nameToOrdinal.lookup(name, pattern, resultRoot);
assert(!ext.size());
os << ord << ";\n";
} else {
handle(curIndent + INDENT + INDENT, argPattern + "_cic", os, pattern,
resultRoot->getArg(1), builder, nameToOrdinal, lookup, retidx,
origName, newFromOriginal, intrinsic);
os << ";\n";
}

os << curIndent << INDENT << "} else {\n";

os << curIndent << INDENT << INDENT << "imVal = ";
if (isa<UnsetInit>(resultRoot->getArg(2)) && resultRoot->getArgName(2)) {
auto name = resultRoot->getArgName(2)->getAsUnquotedString();
auto [ord, isVec, ext] =
nameToOrdinal.lookup(name, pattern, resultRoot);
assert(!ext.size());
os << ord << ";\n";
} else {
handle(curIndent + INDENT + INDENT, argPattern + "_cic", os, pattern,
resultRoot->getArg(2), builder, nameToOrdinal, lookup, retidx,
origName, newFromOriginal, intrinsic);
os << ";\n";
}

os << curIndent << INDENT << "}\n";
os << curIndent << INDENT << "imVal;";
os << curIndent << INDENT << "})\n";
return true;
} else if (opName == "ConstantFP" || Def->isSubClassOf("ConstantFP")) {
auto value = dyn_cast<StringInit>(Def->getValueInit("value"));
Expand Down

0 comments on commit 7b97a9b

Please sign in to comment.