From df98f122ad4d37dbb6c27a6732843e63e91cf5a2 Mon Sep 17 00:00:00 2001
From: Jianhua Zheng
Date: Thu, 7 Nov 2024 08:56:45 +0000
Subject: [PATCH 1/2] layer_norm_grad npu

---
 .../autograd/gradient_funcs/layer_norm.cpp   | 15 ++----
 oneflow/core/functional/functional_api.yaml  |  4 +-
 .../core/functional/impl/nn_grad_functor.cpp | 24 +++++----
 oneflow/ir/include/OneFlow/OneFlowUserOps.td |  8 ++-
 oneflow/user/ops/layer_norm_op.cpp           | 50 +++++++++++++++++++
 5 files changed, 75 insertions(+), 26 deletions(-)

diff --git a/oneflow/core/autograd/gradient_funcs/layer_norm.cpp b/oneflow/core/autograd/gradient_funcs/layer_norm.cpp
index 4a0835247ee..d15d3264359 100644
--- a/oneflow/core/autograd/gradient_funcs/layer_norm.cpp
+++ b/oneflow/core/autograd/gradient_funcs/layer_norm.cpp
@@ -107,22 +107,13 @@ Maybe<void> LayerNorm::Apply(const LayerNormCaptureState* ctx, const TensorTuple
   std::shared_ptr<Tensor> mean = saved_tensors.at(ctx->mean_index);
   std::shared_ptr<Tensor> inv_variance = saved_tensors.at(ctx->inv_variance_index);
-  if (ctx->has_affine) {
-    // Use LayerNormParamGrad(Tensor dy, Tensor x, Tensor mean, Tensor inv_variance,
-    // Int64 begin_params_axis)
-    const auto& results =
-        JUST(functional::LayerNormParamGrad(dy, x, mean, inv_variance, begin_params_axis));
-    in_grads->at(1) = results->at(0);  // For gamma.
-    in_grads->at(2) = results->at(1);  // For beta.
-  }
+  CHECK(ctx->has_affine) << "LayerNorm::Apply must has_affine for NPU GPT2 test";
   if (ctx->x_requires_grad) {
     if (ctx->scale) {
       std::shared_ptr<Tensor> gamma = saved_tensors.at(ctx->gamma_index);
-      in_grads->at(0) = JUST(functional::LayerNormAffineGrad(dy, x, mean, inv_variance, gamma,
-                                                             begin_norm_axis, ctx->epsilon));
+      *in_grads = *JUST(functional::LayerNormAffineGrad(dy, x, mean, inv_variance, gamma, begin_norm_axis, begin_params_axis, ctx->epsilon));
     } else {
-      in_grads->at(0) =
-          JUST(functional::LayerNormGrad(dy, x, mean, inv_variance, begin_norm_axis, ctx->epsilon));
+      *in_grads = *JUST(functional::LayerNormGrad(dy, x, mean, inv_variance, begin_norm_axis, begin_params_axis, ctx->epsilon));
     }
   }
   return Maybe<void>::Ok();
 }
diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml
index 8b05bf73a44..ec944aaf0dd 100644
--- a/oneflow/core/functional/functional_api.yaml
+++ b/oneflow/core/functional/functional_api.yaml
@@ -1551,11 +1551,11 @@
   bind_python: True

 - name: "layer_norm_grad"
-  signature: "Tensor (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Int64 begin_norm_axis, Double epsilon) => LayerNormGrad"
+  signature: "TensorTuple (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Int64 begin_norm_axis, Int64 begin_params_axis, Double epsilon) => LayerNormGrad"
   bind_python: False

 - name: "layer_norm_affine_grad"
-  signature: "Tensor (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Tensor gamma, Int64 begin_norm_axis, Double epsilon) => LayerNormAffineGrad"
+  signature: "TensorTuple (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Tensor gamma, Int64 begin_norm_axis, Int64 begin_params_axis, Double epsilon) => LayerNormAffineGrad"
   bind_python: False

 - name: "layer_norm_param_grad"
diff --git a/oneflow/core/functional/impl/nn_grad_functor.cpp b/oneflow/core/functional/impl/nn_grad_functor.cpp
index e0cf9e2ff34..c91e93c007f 100644
--- a/oneflow/core/functional/impl/nn_grad_functor.cpp
+++ b/oneflow/core/functional/impl/nn_grad_functor.cpp
@@ -940,16 +940,18 @@ class LayerNormGradFunctor {
                          .Input("mean")
                          .Input("inv_variance")
                          .Output("dx")
+                         .Output("gamma_diff")
+                         .Output("beta_diff")
                          .Build());
   }
-  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,
+  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& dy,
                            const std::shared_ptr<one::Tensor>& x,
                            const std::shared_ptr<one::Tensor>& mean,
                            const std::shared_ptr<one::Tensor>& inv_variance,
-                           const int64_t& begin_norm_axis, const double& epsilon) const {
-    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("begin_norm_axis", "epsilon");
-    attrs.SetAllAttrs(begin_norm_axis, epsilon);
-    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, x, mean, inv_variance}, attrs);
+                           const int64_t& begin_norm_axis, const int64_t& begin_params_axis, const double& epsilon) const {
+    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("begin_norm_axis", "begin_params_axis", "epsilon");
+    attrs.SetAllAttrs(begin_norm_axis, begin_params_axis, epsilon);
+    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {dy, x, mean, inv_variance}, attrs);
   }

 private:
@@ -966,17 +968,19 @@ class LayerNormAffineGradFunctor {
                          .Input("inv_variance")
                          .Input("gamma")
                          .Output("dx")
+                         .Output("gamma_diff")
+                         .Output("beta_diff")
                          .Build());
   }
-  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,
+  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& dy,
                            const std::shared_ptr<one::Tensor>& x,
                            const std::shared_ptr<one::Tensor>& mean,
                            const std::shared_ptr<one::Tensor>& inv_variance,
                            const std::shared_ptr<one::Tensor>& gamma,
-                           const int64_t& begin_norm_axis, const double& epsilon) const {
-    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("begin_norm_axis", "epsilon");
-    attrs.SetAllAttrs(begin_norm_axis, epsilon);
-    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, x, mean, inv_variance, gamma}, attrs);
+                           const int64_t& begin_norm_axis, const int64_t& begin_params_axis, const double& epsilon) const {
+    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("begin_norm_axis", "begin_params_axis", "epsilon");
+    attrs.SetAllAttrs(begin_norm_axis, begin_params_axis, epsilon);
+    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {dy, x, mean, inv_variance, gamma}, attrs);
   }

 private:
diff --git a/oneflow/ir/include/OneFlow/OneFlowUserOps.td b/oneflow/ir/include/OneFlow/OneFlowUserOps.td
index b50a1ceceab..cbdba4fba66 100644
--- a/oneflow/ir/include/OneFlow/OneFlowUserOps.td
+++ b/oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -7056,14 +7056,18 @@ def OneFlow_LayerNormGradOp : OneFlow_BaseOp<"layer_norm_grad", [NoMemoryEffect,
     Optional<OneFlow_Tensor>:$_add_to_output
   );
   let output = (outs
-    OneFlow_Tensor:$dx
+    OneFlow_Tensor:$dx,
+    OneFlow_Tensor:$gamma_diff,
+    OneFlow_Tensor:$beta_diff
   );
   let attrs = (ins
     DefaultValuedAttr<SI64Attr, "0">:$begin_norm_axis,
+    DefaultValuedAttr<SI64Attr, "0">:$begin_params_axis,
     DefaultValuedAttr<F64Attr, "0.">:$epsilon
   );
   let trait_attrs = (ins
-    DenseI32ArrayAttr:$operand_segment_sizes
+    DenseI32ArrayAttr:$operand_segment_sizes,
+    DenseI32ArrayAttr:$result_segment_sizes
   );
   let has_logical_tensor_desc_infer_fn = 1;
   let has_physical_tensor_desc_infer_fn = 1;
diff --git a/oneflow/user/ops/layer_norm_op.cpp b/oneflow/user/ops/layer_norm_op.cpp
index 55eec5bb2a9..2015c35c2b6 100644
--- a/oneflow/user/ops/layer_norm_op.cpp
+++ b/oneflow/user/ops/layer_norm_op.cpp
@@ -141,6 +141,35 @@ oneflow::DataType InferBnParamDataType(const DataType x_data_type) {
     const auto& add_to_output = ctx->InputTensorDesc("_add_to_output", 0);
     CHECK_EQ_OR_RETURN(add_to_output.shape(), dx->shape());
   }
+
+  auto has_tensor = [ctx](const std::string& bn) -> bool {
+    bool ret = false;
+    for (const auto& t : ctx->inputs()) {
+      if (bn == t.first) { return true; }
+    }
+    for (const auto& t : ctx->outputs()) {
+      if (bn == t.first) { return true; }
+    }
+    return ret;
+  };
+  const int64_t begin_params_axis = ctx->Attr<int64_t>("begin_params_axis");
+  const bool has_beta_diff = has_tensor("beta_diff");
+  const bool has_gamma_diff = has_tensor("gamma_diff");
+  CHECK_GE_OR_RETURN(begin_params_axis, 1);
+  CHECK_LT_OR_RETURN(begin_params_axis, dy.shape().NumAxes());
+  DimVector param_shape_dim_vec;
+  param_shape_dim_vec.insert(param_shape_dim_vec.end(),
+                             dy.shape().dim_vec().cbegin() + begin_params_axis,
+                             dy.shape().dim_vec().cend());
+  const Shape param_shape(param_shape_dim_vec);
+  if (has_beta_diff) {
+    user_op::TensorDesc* beta_diff = ctx->MutOutputTensorDesc("beta_diff", 0);
+    beta_diff->set_shape(param_shape);
+  }
+  if (has_gamma_diff) {
+    user_op::TensorDesc* gamma_diff = ctx->MutOutputTensorDesc("gamma_diff", 0);
+    gamma_diff->set_shape(param_shape);
+  }
   return Maybe<void>::Ok();
 }

@@ -187,6 +216,27 @@ oneflow::DataType InferBnParamDataType(const DataType x_data_type) {
         << "InferDataType Failed. Expected " << DataType_Name(dx->data_type()) << ", but got "
         << DataType_Name(add_to_output.data_type());
   }
+
+  auto has_tensor = [ctx](const std::string& bn) -> bool {
+    bool ret = false;
+    for (auto& t : ctx->inputs()) {
+      if (bn == t.first) { return true; }
+    }
+    for (auto& t : ctx->outputs()) {
+      if (bn == t.first) { return true; }
+    }
+    return ret;
+  };
+  const bool has_beta_diff = has_tensor("beta_diff");
+  const bool has_gamma_diff = has_tensor("gamma_diff");
+  if (has_beta_diff) {
+    user_op::TensorDesc* beta_diff = ctx->MutOutputTensorDesc("beta_diff", 0);
+    beta_diff->set_data_type(dy.data_type());
+  }
+  if (has_gamma_diff) {
+    user_op::TensorDesc* gamma_diff = ctx->MutOutputTensorDesc("gamma_diff", 0);
+    gamma_diff->set_data_type(dy.data_type());
+  }
   return Maybe<void>::Ok();
 }


From dfe78fb4f91690d5111bb18d7616a54a858e566d Mon Sep 17 00:00:00 2001
From: XIE Xuan
Date: Mon, 11 Nov 2024 10:22:35 +0800
Subject: [PATCH 2/2] fix layernorm grad sbp (#10561)

Fix the SBP settings for LayerNormGradOp to ensure correct gradient
aggregation for gamma_diff and beta_diff.

Changes:
- Updated the SBP strategy in LayerNormGradOp: gamma_diff and beta_diff now
  use PartialSum instead of Split to avoid dimension mismatches during
  distributed training.
- Added a consistency check for begin_norm_axis and begin_params_axis: the
  two are required to be equal so that the normalization and parameter
  dimensions stay aligned.
---
 oneflow/user/ops/layer_norm_op.cpp | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/oneflow/user/ops/layer_norm_op.cpp b/oneflow/user/ops/layer_norm_op.cpp
index 2015c35c2b6..a61001a3d40 100644
--- a/oneflow/user/ops/layer_norm_op.cpp
+++ b/oneflow/user/ops/layer_norm_op.cpp
@@ -183,10 +183,16 @@ oneflow::DataType InferBnParamDataType(const DataType x_data_type) {
     broadcast_args.emplace_back(user_op::OpArg("gamma", 0));
   }
   int64_t begin_norm_axis = ctx->Attr<int64_t>("begin_norm_axis");
+  int64_t begin_params_axis = ctx->Attr<int64_t>("begin_params_axis");
+  CHECK_EQ(begin_norm_axis, begin_params_axis)
+      << "begin_norm_axis and begin_params_axis must be equal, but got "
+      << begin_norm_axis << " and " << begin_params_axis;
   for (int i = 0; i < begin_norm_axis; ++i) {
     ctx->NewBuilder()
         .Split(ctx->inputs(), i)
-        .Split(ctx->outputs(), i)
+        .Split(user_op::OpArg("dx", 0), i)
+        .PartialSum(user_op::OpArg("gamma_diff", 0))
+        .PartialSum(user_op::OpArg("beta_diff", 0))
         .Broadcast(broadcast_args)
         .Build();
   }
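
Note (not part of the patches above): the PartialSum choice for gamma_diff and
beta_diff follows from the math of the parameter gradients. When dy/x are split
along an axis i < begin_norm_axis (a batch axis), each rank reduces only its own
rows, so its gamma_diff/beta_diff are partial sums that must be added across
ranks, while dx stays split along the same axis. The standalone C++ sketch below
checks this numerically; it has no OneFlow dependency, and the 4x3 input, the
epsilon value, and the two-rank row split are made-up illustration values.

// layer_norm_partial_sum_check.cpp
// For layer norm normalized over the last axis:
//   beta_diff[j]  = sum_i dy[i][j]
//   gamma_diff[j] = sum_i dy[i][j] * x_hat[i][j], with x_hat = (x - mean) * inv_variance
// Splitting the rows across ranks yields per-rank partial sums whose element-wise
// sum equals the full gradient, which is what SBP PartialSum expresses.
#include <cmath>
#include <cstdio>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Parameter gradients accumulated over rows [row_begin, row_end) only.
void ParamGrad(const Matrix& x, const Matrix& dy, int row_begin, int row_end,
               std::vector<double>* gamma_diff, std::vector<double>* beta_diff) {
  const int cols = static_cast<int>(x[0].size());
  gamma_diff->assign(cols, 0.0);
  beta_diff->assign(cols, 0.0);
  for (int i = row_begin; i < row_end; ++i) {
    double mean = 0.0;
    for (double v : x[i]) { mean += v; }
    mean /= cols;
    double var = 0.0;
    for (double v : x[i]) { var += (v - mean) * (v - mean); }
    var /= cols;
    const double inv_variance = 1.0 / std::sqrt(var + 1e-5);  // epsilon chosen for illustration
    for (int j = 0; j < cols; ++j) {
      const double x_hat = (x[i][j] - mean) * inv_variance;
      (*gamma_diff)[j] += dy[i][j] * x_hat;
      (*beta_diff)[j] += dy[i][j];
    }
  }
}

int main() {
  // Made-up 4x3 input and upstream gradient.
  const Matrix x = {{1, 2, 3}, {4, 6, 5}, {0, -1, 2}, {3, 3, 7}};
  const Matrix dy = {{0.1, -0.2, 0.3}, {1, 0, -1}, {0.5, 0.5, 0.5}, {-0.3, 0.2, 0.1}};

  std::vector<double> g_full, b_full;  // single-device result
  ParamGrad(x, dy, 0, 4, &g_full, &b_full);

  // Two "ranks", each owning two rows of the batch axis.
  std::vector<double> g0, b0, g1, b1;
  ParamGrad(x, dy, 0, 2, &g0, &b0);
  ParamGrad(x, dy, 2, 4, &g1, &b1);

  for (std::size_t j = 0; j < g_full.size(); ++j) {
    std::printf("gamma_diff[%zu]: full=%.6f partial_sum=%.6f\n", j, g_full[j], g0[j] + g1[j]);
    std::printf("beta_diff[%zu]:  full=%.6f partial_sum=%.6f\n", j, b_full[j], b0[j] + b1[j]);
  }
  return 0;
}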