diff --git a/enzyme/Enzyme/CApi.cpp b/enzyme/Enzyme/CApi.cpp index 4e583273698..ca71867462f 100644 --- a/enzyme/Enzyme/CApi.cpp +++ b/enzyme/Enzyme/CApi.cpp @@ -333,9 +333,11 @@ void EnzymeRegisterAllocationHandler(char *Name, CustomShadowAlloc AHandle, return unwrap( AHandle(wrap(&B), wrap(CI), Args.size(), refs.data(), gutils)); }; - shadowErasers[Name] = [=](IRBuilder<> &B, Value *ToFree) -> llvm::CallInst * { - return cast_or_null(unwrap(FHandle(wrap(&B), wrap(ToFree)))); - }; + if (FHandle) + shadowErasers[Name] = [=](IRBuilder<> &B, + Value *ToFree) -> llvm::CallInst * { + return cast_or_null(unwrap(FHandle(wrap(&B), wrap(ToFree)))); + }; } void EnzymeRegisterCallHandler(char *Name, diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 7c012506e37..8dfd7ae8104 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -9325,7 +9325,16 @@ llvm::CallInst *freeKnownAllocation(llvm::IRBuilder<> &builder, } if (allocationfn == "julia.gc_alloc_obj" || allocationfn == "jl_gc_alloc_typed" || - allocationfn == "ijl_gc_alloc_typed") + allocationfn == "ijl_gc_alloc_typed" || + allocationfn == "jl_alloc_array_1d" || + allocationfn == "ijl_alloc_array_1d" || + allocationfn == "jl_alloc_array_2d" || + allocationfn == "ijl_alloc_array_2d" || + allocationfn == "jl_alloc_array_3d" || + allocationfn == "ijl_alloc_array_3d" || allocationfn == "jl_new_array" || + allocationfn == "ijl_new_array" || + allocationfn == "jl_alloc_genericmemory" || + allocationfn == "ijl_alloc_genericmemory") return nullptr; if (allocationfn == "enzyme_allocator") { diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index c77d86efc55..16ad4b07172 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -5243,6 +5243,22 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { TypeTree(BaseType::Integer).Only(-1, &call), &call); return; } + if (funcName == "julia.except_enter" || funcName == "ijl_excstack_state" || + funcName == "jl_excstack_state") { + updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); + return; + } + if (funcName == "jl_array_copy" || funcName == "ijl_array_copy" || + funcName == "jl_inactive_inout" || + funcName == "jl_genericmemory_copy_slice" || + funcName == "ijl_genericmemory_copy_slice") { + if (directions & DOWN) + updateAnalysis(&call, getAnalysis(call.getOperand(0)), &call); + if (directions & UP) + updateAnalysis(call.getOperand(0), getAnalysis(&call), &call); + return; + } + if (isAllocationFunction(funcName, TLI)) { size_t Idx = 0; for (auto &Arg : ci->args()) {