Skip to content

Commit

Permalink
Support batching in MLIR autodiff operations (2nd try) (#2173)
Browse files Browse the repository at this point in the history
* Reorder generated code for Attributes, Enums, Types and Operations

* add width attribute to Forwarddiffop and use it in `HandleAutoDiff`

* add width attribute to autodiff as well

* formatting
  • Loading branch information
jumerckx authored Nov 25, 2024
1 parent 3a64b16 commit caa9a3a
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 10 deletions.
4 changes: 2 additions & 2 deletions enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def PlaceholderOp : Enzyme_Op<"placeholder",
def ForwardDiffOp : Enzyme_Op<"fwddiff",
[DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "Perform forward mode AD on a funcop";
let arguments = (ins FlatSymbolRefAttr:$fn, Variadic<AnyType>:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity);
let arguments = (ins FlatSymbolRefAttr:$fn, Variadic<AnyType>:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity, DefaultValuedAttr<I64Attr, "1">:$width);
let results = (outs Variadic<AnyType>:$outputs);

let assemblyFormat = [{
Expand All @@ -91,7 +91,7 @@ def ForwardDiffOp : Enzyme_Op<"fwddiff",
def AutoDiffOp : Enzyme_Op<"autodiff",
[DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "Perform reverse mode AD on a funcop";
let arguments = (ins FlatSymbolRefAttr:$fn, Variadic<AnyType>:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity);
let arguments = (ins FlatSymbolRefAttr:$fn, Variadic<AnyType>:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity, DefaultValuedAttr<I64Attr, "1">:$width);
let results = (outs Variadic<AnyType>:$outputs);

let assemblyFormat = [{
Expand Down
14 changes: 8 additions & 6 deletions enzyme/Enzyme/MLIR/Dialect/Ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,17 @@

#include "mlir/Bytecode/BytecodeOpInterface.h"

#define GET_OP_CLASSES
#include "Dialect/EnzymeOps.h.inc"
#define GET_TYPEDEF_CLASSES
#include "Dialect/EnzymeOpsTypes.h.inc"
// #include "Dialect/EnzymeTypes.h.inc"

#include "Dialect/EnzymeEnums.h.inc"

#define GET_ATTRDEF_CLASSES
#include "Dialect/EnzymeAttributes.h.inc"

#define GET_TYPEDEF_CLASSES
#include "Dialect/EnzymeOpsTypes.h.inc"

#define GET_OP_CLASSES
#include "Dialect/EnzymeOps.h.inc"

// #include "Dialect/EnzymeTypes.h.inc"

#endif // ENZYMEOPS_H
4 changes: 2 additions & 2 deletions enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ struct DifferentiatePass : public DifferentiatePassBase<DifferentiatePass> {
MTypeAnalysis TA;
auto type_args = TA.getAnalyzedTypeInfo(fn);
bool freeMemory = true;
size_t width = 1;
size_t width = CI.getWidth();

std::vector<bool> volatile_args;
for (auto &a : fn.getFunctionBody().getArguments()) {
Expand Down Expand Up @@ -259,7 +259,7 @@ struct DifferentiatePass : public DifferentiatePassBase<DifferentiatePass> {
MTypeAnalysis TA;
auto type_args = TA.getAnalyzedTypeInfo(fn);
bool freeMemory = true;
size_t width = 1;
size_t width = CI.getWidth();

std::vector<bool> volatile_args;
for (auto &a : fn.getFunctionBody().getArguments()) {
Expand Down

0 comments on commit caa9a3a

Please sign in to comment.