From caa9a3a7818b09b240a41ba31117670c505708dc Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Mon, 25 Nov 2024 23:33:08 +0100 Subject: [PATCH] Support batching in MLIR autodiff operations (2nd try) (#2173) * Reorder generated code for Attributes, Enums, Types and Operations * add width attribute to Forwarddiffop and use it in `HandleAutoDiff` * add width attribute to autodiff as well * formatting --- enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td | 4 ++-- enzyme/Enzyme/MLIR/Dialect/Ops.h | 14 ++++++++------ enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp | 4 ++-- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td index ca7659143370..be139fb3d8ba 100644 --- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td +++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td @@ -80,7 +80,7 @@ def PlaceholderOp : Enzyme_Op<"placeholder", def ForwardDiffOp : Enzyme_Op<"fwddiff", [DeclareOpInterfaceMethods]> { let summary = "Perform forward mode AD on a funcop"; - let arguments = (ins FlatSymbolRefAttr:$fn, Variadic:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity); + let arguments = (ins FlatSymbolRefAttr:$fn, Variadic:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity, DefaultValuedAttr:$width); let results = (outs Variadic:$outputs); let assemblyFormat = [{ @@ -91,7 +91,7 @@ def ForwardDiffOp : Enzyme_Op<"fwddiff", def AutoDiffOp : Enzyme_Op<"autodiff", [DeclareOpInterfaceMethods]> { let summary = "Perform reverse mode AD on a funcop"; - let arguments = (ins FlatSymbolRefAttr:$fn, Variadic:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity); + let arguments = (ins FlatSymbolRefAttr:$fn, Variadic:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity, DefaultValuedAttr:$width); let results = (outs Variadic:$outputs); let assemblyFormat = [{ diff --git a/enzyme/Enzyme/MLIR/Dialect/Ops.h b/enzyme/Enzyme/MLIR/Dialect/Ops.h index 69aa6496b84f..cd2eb1f70d42 100644 --- a/enzyme/Enzyme/MLIR/Dialect/Ops.h +++ b/enzyme/Enzyme/MLIR/Dialect/Ops.h @@ -19,15 +19,17 @@ #include "mlir/Bytecode/BytecodeOpInterface.h" -#define GET_OP_CLASSES -#include "Dialect/EnzymeOps.h.inc" -#define GET_TYPEDEF_CLASSES -#include "Dialect/EnzymeOpsTypes.h.inc" -// #include "Dialect/EnzymeTypes.h.inc" - #include "Dialect/EnzymeEnums.h.inc" #define GET_ATTRDEF_CLASSES #include "Dialect/EnzymeAttributes.h.inc" +#define GET_TYPEDEF_CLASSES +#include "Dialect/EnzymeOpsTypes.h.inc" + +#define GET_OP_CLASSES +#include "Dialect/EnzymeOps.h.inc" + +// #include "Dialect/EnzymeTypes.h.inc" + #endif // ENZYMEOPS_H diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp index bf1ec877f9d3..c3fe53a7c4ea 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp @@ -139,7 +139,7 @@ struct DifferentiatePass : public DifferentiatePassBase { MTypeAnalysis TA; auto type_args = TA.getAnalyzedTypeInfo(fn); bool freeMemory = true; - size_t width = 1; + size_t width = CI.getWidth(); std::vector volatile_args; for (auto &a : fn.getFunctionBody().getArguments()) { @@ -259,7 +259,7 @@ struct DifferentiatePass : public DifferentiatePassBase { MTypeAnalysis TA; auto type_args = TA.getAnalyzedTypeInfo(fn); bool freeMemory = true; - size_t width = 1; + size_t width = CI.getWidth(); std::vector volatile_args; for (auto &a : fn.getFunctionBody().getArguments()) {