layer_norm_grad for npu #10560

Draft: wants to merge 2 commits into base: master
15 changes: 3 additions & 12 deletions oneflow/core/autograd/gradient_funcs/layer_norm.cpp
@@ -107,22 +107,13 @@ Maybe<void> LayerNorm::Apply(const LayerNormCaptureState* ctx, const TensorTuple
   std::shared_ptr<Tensor> mean = saved_tensors.at(ctx->mean_index);
   std::shared_ptr<Tensor> inv_variance = saved_tensors.at(ctx->inv_variance_index);

-  if (ctx->has_affine) {
-    // Use LayerNormParamGrad(Tensor dy, Tensor x, Tensor mean, Tensor inv_variance,
-    // Int64 begin_params_axis)
-    const auto& results =
-        JUST(functional::LayerNormParamGrad(dy, x, mean, inv_variance, begin_params_axis));
-    in_grads->at(1) = results->at(0);  // For gamma.
-    in_grads->at(2) = results->at(1);  // For beta.
-  }
+  CHECK(ctx->has_affine) << "LayerNorm::Apply must has_affine for NPU GPT2 test";
   if (ctx->x_requires_grad) {
     if (ctx->scale) {
       std::shared_ptr<Tensor> gamma = saved_tensors.at(ctx->gamma_index);
-      in_grads->at(0) = JUST(functional::LayerNormAffineGrad(dy, x, mean, inv_variance, gamma,
-                                                             begin_norm_axis, ctx->epsilon));
+      *in_grads = *JUST(functional::LayerNormAffineGrad(dy, x, mean, inv_variance, gamma, begin_norm_axis, begin_params_axis, ctx->epsilon));
     } else {
-      in_grads->at(0) =
-          JUST(functional::LayerNormGrad(dy, x, mean, inv_variance, begin_norm_axis, ctx->epsilon));
+      *in_grads = *JUST(functional::LayerNormGrad(dy, x, mean, inv_variance, begin_norm_axis, begin_params_axis, ctx->epsilon));
     }
   }
   return Maybe<void>::Ok();
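For reference, the tuple assignment above collapses what the removed branch spelled out slot by slot. A minimal sketch of the equivalent element-wise mapping, assuming `in_grads` is ordered (x, gamma, beta) like the forward inputs and the returned tuple follows the op's output order (dx, gamma_diff, beta_diff); this is an illustration, not code from the PR:

```cpp
// Equivalent to `*in_grads = *JUST(...)` under the assumed ordering.
const auto& results = JUST(functional::LayerNormAffineGrad(
    dy, x, mean, inv_variance, gamma, begin_norm_axis, begin_params_axis, ctx->epsilon));
in_grads->at(0) = results->at(0);  // dx         -> gradient w.r.t. x
in_grads->at(1) = results->at(1);  // gamma_diff -> gradient w.r.t. gamma
in_grads->at(2) = results->at(2);  // beta_diff  -> gradient w.r.t. beta
```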
4 changes: 2 additions & 2 deletions oneflow/core/functional/functional_api.yaml
@@ -1551,11 +1551,11 @@
   bind_python: True

 - name: "layer_norm_grad"
-  signature: "Tensor (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Int64 begin_norm_axis, Double epsilon) => LayerNormGrad"
+  signature: "TensorTuple (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Int64 begin_norm_axis, Int64 begin_params_axis, Double epsilon) => LayerNormGrad"
   bind_python: False

 - name: "layer_norm_affine_grad"
-  signature: "Tensor (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Tensor gamma, Int64 begin_norm_axis, Double epsilon) => LayerNormAffineGrad"
+  signature: "TensorTuple (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Tensor gamma, Int64 begin_norm_axis, Int64 begin_params_axis, Double epsilon) => LayerNormAffineGrad"
   bind_python: False

 - name: "layer_norm_param_grad"
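Both entries keep bind_python: False, so only C++ call sites are affected: the generated functionals now return a Maybe of TensorTuple instead of a single tensor. A hedged usage sketch (the axis and epsilon values are placeholders, and the tuple order is assumed to match the op's output list dx, gamma_diff, beta_diff):

```cpp
// Hypothetical call site inside OneFlow C++ code; not part of this PR.
const auto& grads = JUST(functional::LayerNormGrad(
    dy, x, mean, inv_variance, /*begin_norm_axis=*/2, /*begin_params_axis=*/2, /*epsilon=*/1e-5));
const auto& dx = grads->at(0);          // same shape as x
const auto& gamma_diff = grads->at(1);  // shape of the normalized suffix of x
const auto& beta_diff = grads->at(2);
```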
24 changes: 14 additions & 10 deletions oneflow/core/functional/impl/nn_grad_functor.cpp
@@ -940,16 +940,18 @@ class LayerNormGradFunctor {
                          .Input("mean")
                          .Input("inv_variance")
                          .Output("dx")
+                         .Output("gamma_diff")
+                         .Output("beta_diff")
                          .Build());
   }
-  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,
+  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& dy,
                            const std::shared_ptr<one::Tensor>& x,
                            const std::shared_ptr<one::Tensor>& mean,
                            const std::shared_ptr<one::Tensor>& inv_variance,
-                           const int64_t& begin_norm_axis, const double& epsilon) const {
-    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("begin_norm_axis", "epsilon");
-    attrs.SetAllAttrs(begin_norm_axis, epsilon);
-    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, x, mean, inv_variance}, attrs);
+                           const int64_t& begin_norm_axis, const int64_t& begin_params_axis, const double& epsilon) const {
+    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("begin_norm_axis", "begin_params_axis", "epsilon");
+    attrs.SetAllAttrs(begin_norm_axis, begin_params_axis, epsilon);
+    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {dy, x, mean, inv_variance}, attrs);
   }

  private:
@@ -966,17 +968,19 @@ class LayerNormAffineGradFunctor {
                          .Input("inv_variance")
                          .Input("gamma")
                          .Output("dx")
+                         .Output("gamma_diff")
+                         .Output("beta_diff")
                          .Build());
   }
-  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,
+  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& dy,
                            const std::shared_ptr<one::Tensor>& x,
                            const std::shared_ptr<one::Tensor>& mean,
                            const std::shared_ptr<one::Tensor>& inv_variance,
                            const std::shared_ptr<one::Tensor>& gamma,
-                           const int64_t& begin_norm_axis, const double& epsilon) const {
-    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("begin_norm_axis", "epsilon");
-    attrs.SetAllAttrs(begin_norm_axis, epsilon);
-    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, x, mean, inv_variance, gamma}, attrs);
+                           const int64_t& begin_norm_axis, const int64_t& begin_params_axis, const double& epsilon) const {
+    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("begin_norm_axis", "begin_params_axis", "epsilon");
+    attrs.SetAllAttrs(begin_norm_axis, begin_params_axis, epsilon);
+    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {dy, x, mean, inv_variance, gamma}, attrs);
   }

  private:
8 changes: 6 additions & 2 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -7056,14 +7056,18 @@ def OneFlow_LayerNormGradOp : OneFlow_BaseOp<"layer_norm_grad", [NoMemoryEffect,
     Optional<OneFlow_Tensor>:$_add_to_output
   );
   let output = (outs
-    OneFlow_Tensor:$dx
+    OneFlow_Tensor:$dx,
+    OneFlow_Tensor:$gamma_diff,
+    OneFlow_Tensor:$beta_diff
   );
   let attrs = (ins
     DefaultValuedAttr<SI64Attr, "0">:$begin_norm_axis,
+    DefaultValuedAttr<SI64Attr, "0">:$begin_params_axis,
     DefaultValuedAttr<F64Attr, "0.">:$epsilon
   );
   let trait_attrs = (ins
-    DenseI32ArrayAttr:$operand_segment_sizes
+    DenseI32ArrayAttr:$operand_segment_sizes,
+    DenseI32ArrayAttr:$result_segment_sizes
   );
   let has_logical_tensor_desc_infer_fn = 1;
   let has_physical_tensor_desc_infer_fn = 1;
58 changes: 57 additions & 1 deletion oneflow/user/ops/layer_norm_op.cpp
@@ -141,6 +141,35 @@ oneflow::DataType InferBnParamDataType(const DataType x_data_type) {
     const auto& add_to_output = ctx->InputTensorDesc("_add_to_output", 0);
     CHECK_EQ_OR_RETURN(add_to_output.shape(), dx->shape());
   }
+
+  auto has_tensor = [ctx](const std::string& bn) -> bool {
+    bool ret = false;
+    for (const auto& t : ctx->inputs()) {
+      if (bn == t.first) { return true; }
+    }
+    for (const auto& t : ctx->outputs()) {
+      if (bn == t.first) { return true; }
+    }
+    return ret;
+  };
+  const int64_t begin_params_axis = ctx->Attr<int64_t>("begin_params_axis");
+  const bool has_beta_diff = has_tensor("beta_diff");
+  const bool has_gamma_diff = has_tensor("gamma_diff");
+  CHECK_GE_OR_RETURN(begin_params_axis, 1);
+  CHECK_LT_OR_RETURN(begin_params_axis, dy.shape().NumAxes());
+  DimVector param_shape_dim_vec;
+  param_shape_dim_vec.insert(param_shape_dim_vec.end(),
+                             dy.shape().dim_vec().cbegin() + begin_params_axis,
+                             dy.shape().dim_vec().cend());
+  const Shape param_shape(param_shape_dim_vec);
+  if (has_beta_diff) {
+    user_op::TensorDesc* beta_diff = ctx->MutOutputTensorDesc("beta_diff", 0);
+    beta_diff->set_shape(param_shape);
+  }
+  if (has_gamma_diff) {
+    user_op::TensorDesc* gamma_diff = ctx->MutOutputTensorDesc("gamma_diff", 0);
+    gamma_diff->set_shape(param_shape);
+  }
   return Maybe<void>::Ok();
 }

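The shape inference added above simply keeps the trailing dimensions of dy. A self-contained sketch of the same arithmetic with hypothetical values (plain std::vector instead of OneFlow's Shape types):

```cpp
#include <cstdint>
#include <vector>

// Keep dims [begin_params_axis, rank), mirroring param_shape_dim_vec above.
std::vector<int64_t> ParamShape(const std::vector<int64_t>& dy_shape, int64_t begin_params_axis) {
  return {dy_shape.begin() + begin_params_axis, dy_shape.end()};
}

// ParamShape({8, 512, 768}, 2) == {768}; with begin_params_axis = 1 it would be {512, 768}.
// The CHECK_GE/CHECK_LT above reject begin_params_axis outside [1, rank).
```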
@@ -154,10 +183,16 @@ oneflow::DataType InferBnParamDataType(const DataType x_data_type) {
     broadcast_args.emplace_back(user_op::OpArg("gamma", 0));
   }
   int64_t begin_norm_axis = ctx->Attr<int64_t>("begin_norm_axis");
+  int64_t begin_params_axis = ctx->Attr<int64_t>("begin_params_axis");
+  CHECK_EQ(begin_norm_axis, begin_params_axis)
+      << "begin_norm_axis and begin_params_axis must be equal, but got "
+      << begin_norm_axis << " and " << begin_params_axis;
   for (int i = 0; i < begin_norm_axis; ++i) {
     ctx->NewBuilder()
         .Split(ctx->inputs(), i)
-        .Split(ctx->outputs(), i)
+        .Split(user_op::OpArg("dx", 0), i)
+        .PartialSum(user_op::OpArg("gamma_diff", 0))
+        .PartialSum(user_op::OpArg("beta_diff", 0))
         .Broadcast(broadcast_args)
         .Build();
   }
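Why PartialSum for the parameter gradients: splitting the leading axes gives each device only a slice of the rows that the gamma_diff/beta_diff reductions run over, so each device produces a partial row-sum and the true gradients are the element-wise sum across devices, while dx stays split like x. A standalone reference sketch of the backward math on a flattened [rows, cols] view (hypothetical helper, not OneFlow code, assuming begin_params_axis == begin_norm_axis as the CHECK_EQ above enforces):

```cpp
#include <cstdint>
#include <vector>

void LayerNormGradReference(const std::vector<float>& dy, const std::vector<float>& x,
                            const std::vector<float>& mean, const std::vector<float>& inv_variance,
                            const std::vector<float>& gamma, int64_t rows, int64_t cols,
                            std::vector<float>* dx, std::vector<float>* gamma_diff,
                            std::vector<float>* beta_diff) {
  dx->assign(rows * cols, 0.f);
  gamma_diff->assign(cols, 0.f);
  beta_diff->assign(cols, 0.f);
  for (int64_t r = 0; r < rows; ++r) {
    const float m = mean[r];
    const float inv_std = inv_variance[r];
    // Per-row reductions of g = gamma * dy against 1 and x_hat.
    float sum_g = 0.f, sum_g_xhat = 0.f;
    for (int64_t c = 0; c < cols; ++c) {
      const float x_hat = (x[r * cols + c] - m) * inv_std;
      const float g = gamma[c] * dy[r * cols + c];
      sum_g += g;
      sum_g_xhat += g * x_hat;
      (*gamma_diff)[c] += dy[r * cols + c] * x_hat;  // accumulated over rows
      (*beta_diff)[c] += dy[r * cols + c];           // accumulated over rows
    }
    for (int64_t c = 0; c < cols; ++c) {
      const float x_hat = (x[r * cols + c] - m) * inv_std;
      const float g = gamma[c] * dy[r * cols + c];
      // dx = inv_std * (g - mean(g) - x_hat * mean(g * x_hat)), means taken over cols.
      (*dx)[r * cols + c] = inv_std * (g - (sum_g + x_hat * sum_g_xhat) / cols);
    }
  }
}
```

The `+=` accumulation over r is exactly the reduction that becomes a PartialSum once the rows are split across devices.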
@@ -187,6 +222,27 @@ oneflow::DataType InferBnParamDataType(const DataType x_data_type) {
         << "InferDataType Failed. Expected " << DataType_Name(dx->data_type()) << ", but got "
         << DataType_Name(add_to_output.data_type());
   }
+
+  auto has_tensor = [ctx](const std::string& bn) -> bool {
+    bool ret = false;
+    for (auto& t : ctx->inputs()) {
+      if (bn == t.first) { return true; }
+    }
+    for (auto& t : ctx->outputs()) {
+      if (bn == t.first) { return true; }
+    }
+    return ret;
+  };
+  const bool has_beta_diff = has_tensor("beta_diff");
+  const bool has_gamma_diff = has_tensor("gamma_diff");
+  if (has_beta_diff) {
+    user_op::TensorDesc* beta_diff = ctx->MutOutputTensorDesc("beta_diff", 0);
+    beta_diff->set_data_type(dy.data_type());
+  }
+  if (has_gamma_diff) {
+    user_op::TensorDesc* gamma_diff = ctx->MutOutputTensorDesc("gamma_diff", 0);
+    gamma_diff->set_data_type(dy.data_type());
+  }
   return Maybe<void>::Ok();
 }
