diff --git a/CMakeLists.txt b/CMakeLists.txt index 86b2562ca8b..4b9d59c11ed 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -110,6 +110,7 @@ set(CMAKE_TEST_CCXX_FLAGS) # TESTS specifics string(TOUPPER "${CMAKE_BUILD_TYPE}" UPPERCASE_CMAKE_BUILD_TYPE) include("cmake/dnnl_compat.cmake") +include("cmake/mkldnn_compat.cmake") include("cmake/utils.cmake") include("cmake/options.cmake") diff --git a/cmake/TBB.cmake b/cmake/TBB.cmake index 3723d0de305..d2547cdf7a4 100644 --- a/cmake/TBB.cmake +++ b/cmake/TBB.cmake @@ -26,7 +26,10 @@ include("cmake/Threading.cmake") macro(handle_tbb_target) if(TBB_FOUND) set_property(TARGET TBB::tbb PROPERTY "MAP_IMPORTED_CONFIG_RELWITHMDD" "DEBUG") - include_directories_with_host_compiler(${_tbb_include_dirs}) + foreach(inc_dir ${_tbb_include_dirs}) + include_directories(BEFORE SYSTEM ${inc_dir}) + append_host_compiler_options(CMAKE_CXX_FLAGS "-I${inc_dir}") + endforeach() list(APPEND EXTRA_SHARED_LIBS ${TBB_IMPORTED_TARGETS}) # Print TBB location @@ -56,7 +59,7 @@ macro(handle_tbb_target) append_to_windows_path_list(CTESTCONFIG_PATH "${_tbb_redist_dir}") endmacro() -if(NOT DNNL_CPU_THREADING_RUNTIME STREQUAL "TBB") +if(NOT "${DNNL_CPU_THREADING_RUNTIME}" MATCHES "^(TBB|TBB_AUTO)$") return() endif() diff --git a/cmake/Threading.cmake b/cmake/Threading.cmake index b7d458f7f63..aa018745ea1 100644 --- a/cmake/Threading.cmake +++ b/cmake/Threading.cmake @@ -39,13 +39,12 @@ list(APPEND EXTRA_SHARED_LIBS "${CMAKE_THREAD_LIBS_INIT}") # A macro to avoid code duplication macro(find_package_tbb) - set(_cmake_proj_dir "${PROJECT_SOURCE_DIR}/cmake") if(WIN32) - find_package(TBB ${ARGN} COMPONENTS tbb HINTS ${_cmake_proj_dir}/win) + find_package(TBB ${ARGN} COMPONENTS tbb) elseif(APPLE) - find_package(TBB ${ARGN} COMPONENTS tbb HINTS ${_cmake_proj_dir}/mac) + find_package(TBB ${ARGN} COMPONENTS tbb) elseif(UNIX) - find_package(TBB ${ARGN} COMPONENTS tbb HINTS ${_cmake_proj_dir}/lnx) + find_package(TBB ${ARGN} COMPONENTS tbb) endif() if(TBB_FOUND) diff --git a/cmake/gen_mkldnn_compat_cmakes.cmake b/cmake/gen_mkldnn_compat_cmakes.cmake new file mode 100644 index 00000000000..3975f18e562 --- /dev/null +++ b/cmake/gen_mkldnn_compat_cmakes.cmake @@ -0,0 +1,46 @@ +#=============================================================================== +# Copyright 2019-2020 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#=============================================================================== + +# Creates cmake config for MKLDNN based on oneDNN one +# (by replacing DNNL with MKLDNN) +# Parameters: +# DIR -- path to cmake install dir + +set(DNNL_DIR ${DIR}/dnnl) +set(MKLDNN_DIR ${DIR}/mkldnn) + +file(MAKE_DIRECTORY ${MKLDNN_DIR}) + +file(GLOB_RECURSE fs "${DNNL_DIR}/*") +foreach(f ${fs}) + # set the destination + file(RELATIVE_PATH frel ${DNNL_DIR} ${f}) + string(REGEX REPLACE "dnnl" "mkldnn" dest_rel "${frel}") + set(dest "${MKLDNN_DIR}/${dest_rel}") + # message(STATUS "file: ${f} --> ${frel} --> ${dest_rel} --> ${dest}") + + # read and change the content of the file + file(STRINGS ${f} contents NEWLINE_CONSUME) + string(REGEX REPLACE "DNNL" "MKLDNN" contents "${contents}") + string(REGEX REPLACE "dnnl" "mkldnn" contents "${contents}") + foreach (ext "a" "so" "dylib" "dll" "lib") + string(REGEX REPLACE "mkldnn[.]${ext}" "dnnl.${ext}" contents "${contents}") + endforeach() + string(REGEX REPLACE "lmkldnn" "ldnnl" contents "${contents}") + + # store the result + file(WRITE ${dest} ${contents}) +endforeach() diff --git a/cmake/lnx/TBBConfig.cmake b/cmake/lnx/TBBConfig.cmake deleted file mode 100644 index bedbff68e39..00000000000 --- a/cmake/lnx/TBBConfig.cmake +++ /dev/null @@ -1,183 +0,0 @@ -#=============================================================================== -# Copyright 2017-2020 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -#=============================================================================== - -# TBB_FOUND should not be set explicitly. It is defined automatically by CMake. -# Handling of TBB_VERSION is in TBBConfigVersion.cmake. 
- -if (NOT TBB_FIND_COMPONENTS) - set(TBB_FIND_COMPONENTS "tbb;tbbmalloc;tbbmalloc_proxy") - foreach (_tbb_component ${TBB_FIND_COMPONENTS}) - set(TBB_FIND_REQUIRED_${_tbb_component} 1) - endforeach() -endif() - -# Add components with internal dependencies: tbbmalloc_proxy -> tbbmalloc -list(FIND TBB_FIND_COMPONENTS tbbmalloc_proxy _tbbmalloc_proxy_ix) -if (NOT _tbbmalloc_proxy_ix EQUAL -1) - list(FIND TBB_FIND_COMPONENTS tbbmalloc _tbbmalloc_ix) - if (_tbbmalloc_ix EQUAL -1) - list(APPEND TBB_FIND_COMPONENTS tbbmalloc) - set(TBB_FIND_REQUIRED_tbbmalloc ${TBB_FIND_REQUIRED_tbbmalloc_proxy}) - endif() -endif() - -# oneDNN changes: use TBBROOT to locate Intel TBB -# get_filename_component(_tbb_root "${CMAKE_CURRENT_LIST_FILE}" PATH) -# get_filename_component(_tbb_root "${_tbb_root}" PATH) -if (NOT TBBROOT) - if(DEFINED ENV{TBBROOT}) - set (TBBROOT $ENV{TBBROOT}) - endif() -endif() - -set(_tbb_root ${TBBROOT}) - -set(_tbb_x32_subdir ia32) -set(_tbb_x64_subdir intel64) - -if (CMAKE_SIZEOF_VOID_P EQUAL 8) - set(_tbb_arch_subdir ${_tbb_x64_subdir}) -else() - set(_tbb_arch_subdir ${_tbb_x32_subdir}) -endif() - -if (CMAKE_CXX_COMPILER_LOADED) - set(_tbb_compiler_id ${CMAKE_CXX_COMPILER_ID}) - set(_tbb_compiler_ver ${CMAKE_CXX_COMPILER_VERSION}) -elseif (CMAKE_C_COMPILER_LOADED) - set(_tbb_compiler_id ${CMAKE_C_COMPILER_ID}) - set(_tbb_compiler_ver ${CMAKE_C_COMPILER_VERSION}) -endif() - -# For non-GCC compilers try to find version of system GCC to choose right compiler subdirectory. -if (NOT _tbb_compiler_id STREQUAL "GNU") - execute_process(COMMAND gcc --version OUTPUT_VARIABLE _tbb_gcc_ver_output ERROR_QUIET) - string(REGEX REPLACE ".*gcc.* ([0-9]+\\.[0-9]+)\\.[0-9]+.*" "\\1" _tbb_compiler_ver "${_tbb_gcc_ver_output}") - if (NOT _tbb_compiler_ver) - message(FATAL_ERROR "This Intel TBB package is intended to be used only environment with available 'gcc'") - endif() - unset(_tbb_gcc_ver_output) -endif() - -if (EXISTS "${_tbb_root}/lib/${_tbb_arch_subdir}") - set(_tbb_lib ${_tbb_root}/lib/${_tbb_arch_subdir}) - set(_tbb_inc ${_tbb_root}/include) - - file(GLOB _tbb_gcc_versions_available RELATIVE ${_tbb_lib} ${_tbb_lib}/*) - # shall we check _tbb_gcc_versions_available is not empty? 
- foreach (_tbb_gcc_version ${_tbb_gcc_versions_available}) - string(SUBSTRING ${_tbb_gcc_version} 3 -1 _tbb_gcc_version_number) - if (NOT _tbb_compiler_ver VERSION_LESS _tbb_gcc_version_number) - set(_tbb_compiler_subdir ${_tbb_gcc_version}) - endif() - endforeach() -else() - if (TBBROOT) - set(__tbb_hint_path "${TBBROOT}") - else() - set(__tbb_hint_path "/non/existing/path") - endif() - - # try to find TBB in the system - find_library(_tbb_lib NAMES tbb - HINTS "${__tbb_hint_path}" - PATH_SUFFIXES lib lib64) - find_path(_tbb_inc NAMES tbb.h - HINTS "${__tbb_hint_path}" - PATH_SUFFIXES include tbb include/tbb) - unset(__tbb_hint_path) - - if (NOT _tbb_lib OR NOT _tbb_inc) - message("FATAL_ERROR" "Cannot find TBB") - endif() - - get_filename_component(_tbb_lib "${_tbb_lib}" PATH) - get_filename_component(_tbb_inc "${_tbb_inc}" PATH) - - set(_tbb_arch_subdir "") - set(_tbb_compiler_subdir "") -endif() - -unset(_tbb_gcc_version_number) -unset(_tbb_compiler_id) -unset(_tbb_compiler_ver) - -# Now we check that all the needed component are present -get_filename_component(_tbb_lib_path "${_tbb_lib}/${_tbb_compiler_subdir}" ABSOLUTE) - -if (TBB_FOUND) - return() -endif() - -foreach (_tbb_soversion 2 12) -foreach (_tbb_component ${TBB_FIND_COMPONENTS}) - set(_tbb_release_lib - "${_tbb_lib_path}/lib${_tbb_component}.so.${_tbb_soversion}") - set(_tbb_debug_lib - "${_tbb_lib_path}/lib${_tbb_component}_debug.so.${_tbb_soversion}") - - # oneDNN change: check library existence (BUILD_MODE related only, not both) - string(TOUPPER "${CMAKE_BUILD_TYPE}" UPPERCASE_CMAKE_BUILD_TYPE) - if (UPPERCASE_CMAKE_BUILD_TYPE STREQUAL "DEBUG") - if (EXISTS "${_tbb_debug_lib}") - set(_lib_exists TRUE) - elseif (EXISTS "${_tbb_release_lib}") - message(FATAL_ERROR - "Intel TBB release library is found here: ${_tbb_release_lib}. 
" - "But the debug library - (lib${_tbb_component}_debug.so.${_tbb_soversion}) is missing.") - endif() - else() - if (EXISTS "${_tbb_release_lib}") - set(_lib_exists TRUE) - endif() - endif() - - if (_lib_exists) - if (NOT TARGET TBB::${_tbb_component}) - add_library(TBB::${_tbb_component} SHARED IMPORTED) - set_target_properties(TBB::${_tbb_component} PROPERTIES - IMPORTED_CONFIGURATIONS "RELEASE;DEBUG" - IMPORTED_LOCATION_RELEASE "${_tbb_release_lib}" - IMPORTED_LOCATION_DEBUG "${_tbb_debug_lib}" - INTERFACE_INCLUDE_DIRECTORIES "${_tbb_inc}") - - # Add internal dependencies for imported targets: TBB::tbbmalloc_proxy -> TBB::tbbmalloc - if (_tbb_component STREQUAL tbbmalloc_proxy) - set_target_properties(TBB::tbbmalloc_proxy PROPERTIES INTERFACE_LINK_LIBRARIES TBB::tbbmalloc) - endif() - - list(APPEND TBB_IMPORTED_TARGETS TBB::${_tbb_component}) - set(TBB_${_tbb_component}_FOUND 1) - endif() - break() - endif() -endforeach() -endforeach() - -if (NOT _lib_exists AND TBB_FIND_REQUIRED AND TBB_FIND_REQUIRED_${_tbb_component}) - message(FATAL_ERROR "Missed required Intel TBB component: ${_tbb_component}") -endif() - -unset(_tbb_x32_subdir) -unset(_tbb_x64_subdir) -unset(_tbb_arch_subdir) -unset(_tbb_compiler_subdir) -unset(_tbbmalloc_proxy_ix) -unset(_tbbmalloc_ix) -unset(_tbb_lib_path) -unset(_tbb_release_lib) -unset(_tbb_debug_lib) diff --git a/cmake/mac/TBBConfig.cmake b/cmake/mac/TBBConfig.cmake deleted file mode 100644 index 7bb9af865e2..00000000000 --- a/cmake/mac/TBBConfig.cmake +++ /dev/null @@ -1,127 +0,0 @@ -#=============================================================================== -# Copyright 2017-2020 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -#=============================================================================== - -# TBB_FOUND should not be set explicitly. It is defined automatically by CMake. -# Handling of TBB_VERSION is in TBBConfigVersion.cmake. - -if (NOT TBB_FIND_COMPONENTS) - set(TBB_FIND_COMPONENTS "tbb;tbbmalloc;tbbmalloc_proxy") - foreach (_tbb_component ${TBB_FIND_COMPONENTS}) - set(TBB_FIND_REQUIRED_${_tbb_component} 1) - endforeach() -endif() - -# Add components with internal dependencies: tbbmalloc_proxy -> tbbmalloc -list(FIND TBB_FIND_COMPONENTS tbbmalloc_proxy _tbbmalloc_proxy_ix) -if (NOT _tbbmalloc_proxy_ix EQUAL -1) - list(FIND TBB_FIND_COMPONENTS tbbmalloc _tbbmalloc_ix) - if (_tbbmalloc_ix EQUAL -1) - list(APPEND TBB_FIND_COMPONENTS tbbmalloc) - set(TBB_FIND_REQUIRED_tbbmalloc ${TBB_FIND_REQUIRED_tbbmalloc_proxy}) - endif() -endif() - -# oneDNN changes: use TBBROOT to locate Intel TBB -# get_filename_component(_tbb_root "${CMAKE_CURRENT_LIST_FILE}" PATH) -# get_filename_component(_tbb_root "${_tbb_root}" PATH) -if (NOT TBBROOT) - if(DEFINED ENV{TBBROOT}) - set (TBBROOT $ENV{TBBROOT}) - else() - message("FATAL_ERROR" "TBBROOT is unset") - endif() -endif() - -set(_tbb_root ${TBBROOT}) - -set(_tbb_x32_subdir .) -set(_tbb_x64_subdir .) 
- -if (CMAKE_SIZEOF_VOID_P EQUAL 8) - set(_tbb_arch_subdir ${_tbb_x64_subdir}) -else() - set(_tbb_arch_subdir ${_tbb_x32_subdir}) -endif() - -set(_tbb_compiler_subdir .) - -get_filename_component(_tbb_lib_path "${_tbb_root}/lib/${_tbb_arch_subdir}/${_tbb_compiler_subdir}" ABSOLUTE) - -if (TBB_FOUND) - return() -endif() - -foreach (_tbb_lib_version .12 "") -foreach (_tbb_component ${TBB_FIND_COMPONENTS}) - set(_tbb_release_lib "${_tbb_lib_path}/lib${_tbb_component}${_tbb_lib_version}.dylib") - set(_tbb_debug_lib "${_tbb_lib_path}/lib${_tbb_component}_debug${_tbb_lib_version}.dylib") - - # oneDNN change: check library existence (BUILD_MODE related only, not both) - string(TOUPPER "${CMAKE_BUILD_TYPE}" UPPERCASE_CMAKE_BUILD_TYPE) - if (UPPERCASE_CMAKE_BUILD_TYPE STREQUAL "DEBUG") - if (EXISTS "${_tbb_debug_lib}") - set(_lib_exists TRUE) - elseif (EXISTS "${_tbb_release_lib}") - message(FATAL_ERROR - "Intel TBB release library is found here: ${_tbb_release_lib}. " - "But the debug library - (lib${_tbb_component}_debug${_tbb_lib_version}.dylib) is missing.") - endif() - else() - if (EXISTS "${_tbb_release_lib}") - set(_lib_exists TRUE) - endif() - endif() - - if (_lib_exists) - if (NOT TARGET TBB::${_tbb_component}) - add_library(TBB::${_tbb_component} SHARED IMPORTED) - set_target_properties(TBB::${_tbb_component} PROPERTIES - IMPORTED_CONFIGURATIONS "RELEASE;DEBUG" - IMPORTED_LOCATION_RELEASE "${_tbb_release_lib}" - IMPORTED_LOCATION_DEBUG "${_tbb_debug_lib}" - INTERFACE_INCLUDE_DIRECTORIES "${_tbb_root}/include") - - # Add internal dependencies for imported targets: TBB::tbbmalloc_proxy -> TBB::tbbmalloc - if (_tbb_component STREQUAL tbbmalloc_proxy) - set_target_properties(TBB::tbbmalloc_proxy PROPERTIES INTERFACE_LINK_LIBRARIES TBB::tbbmalloc) - endif() - - list(APPEND TBB_IMPORTED_TARGETS TBB::${_tbb_component}) - set(TBB_${_tbb_component}_FOUND 1) - endif() - break() - endif() -endforeach() -endforeach() - -foreach (_tbb_component ${TBB_FIND_COMPONENTS}) - if (NOT TARGET TBB::${_tbb_component} AND TBB_FIND_REQUIRED AND TBB_FIND_REQUIRED_${_tbb_component}) - message(FATAL_ERROR "Missed required Intel TBB component: ${_tbb_component}") - endif() -endforeach() - -unset(_tbb_x32_subdir) -unset(_tbb_x64_subdir) -unset(_tbb_arch_subdir) -unset(_tbb_compiler_subdir) -unset(_tbbmalloc_proxy_ix) -unset(_tbbmalloc_ix) -unset(_tbb_lib_path) -unset(_tbb_release_lib) -unset(_tbb_debug_lib) -unset(_tbb_lib_version) -unset(_lib_exists) diff --git a/cmake/mkldnn_compat.cmake b/cmake/mkldnn_compat.cmake new file mode 100644 index 00000000000..7bbce23f35b --- /dev/null +++ b/cmake/mkldnn_compat.cmake @@ -0,0 +1,88 @@ +#=============================================================================== +# Copyright 2019-2020 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#===============================================================================
+
+# Provides compatibility with Intel MKL-DNN build options
+#===============================================================================
+
+# If the DNNL_* variable is unset, copy the value from the corresponding
+# MKLDNN_* variable
+macro(mkldnn_compat_var dnnl_var mkldnn_var props)
+    if (DEFINED ${mkldnn_var} AND NOT DEFINED ${dnnl_var})
+        if ("${props}" STREQUAL "CACHE STRING")
+            set(${dnnl_var} "${${mkldnn_var}}" CACHE STRING "" FORCE)
+        elseif ("${props}" STREQUAL "CACHE BOOL")
+            set(${dnnl_var} "${${mkldnn_var}}" CACHE BOOL "" FORCE)
+        else()
+            set(${dnnl_var} "${${mkldnn_var}}")
+        endif()
+        message(STATUS "Intel MKL-DNN compat: "
+            "set ${dnnl_var} to ${mkldnn_var} with value `${${dnnl_var}}`")
+    endif()
+endmacro()
+
+set(COMPAT_CACHE_BOOL_VARS
+    "VERBOSE"
+    "ENABLE_CONCURRENT_EXEC"
+    "BUILD_EXAMPLES"
+    "BUILD_TESTS"
+    "BUILD_FOR_CI"
+    "WERROR"
+    "ENABLE_JIT_PROFILING"
+    )
+
+set(COMPAT_CACHE_STRING_VARS
+    "LIBRARY_TYPE"
+    "INSTALL_MODE"
+    "ARCH_OPT_FLAGS"
+    "CPU_RUNTIME"
+    "GPU_RUNTIME"
+    "USE_CLANG_SANITIZER"
+    )
+
+# Map MKLDNN_ to DNNL_ options
+
+foreach (var ${COMPAT_CACHE_BOOL_VARS})
+    mkldnn_compat_var("DNNL_${var}" "MKLDNN_${var}" "CACHE BOOL")
+endforeach()
+mkldnn_compat_var(_DNNL_USE_MKL _MKLDNN_USE_MKL "CACHE BOOL")
+
+foreach (var ${COMPAT_CACHE_STRING_VARS})
+    mkldnn_compat_var("DNNL_${var}" "MKLDNN_${var}" "CACHE STRING")
+endforeach()
+
+# Handle legacy options: MKLDNN_THREADING, MKLDNN_GPU_BACKEND, and
+# MKLDNN_GPU_RUNTIME=OPENCL.
+
+if(MKLDNN_THREADING)
+    set(DNNL_CPU_RUNTIME "${MKLDNN_THREADING}" CACHE STRING "" FORCE)
+    message(STATUS "Using the obsolete way to specify the CPU runtime. "
+        "Use DNNL_CPU_RUNTIME=${DNNL_CPU_RUNTIME} instead.")
+endif()
+
+if(MKLDNN_GPU_BACKEND)
+    if (MKLDNN_GPU_BACKEND STREQUAL "OPENCL")
+        set(MKLDNN_GPU_BACKEND "OCL" CACHE STRING "" FORCE)
+        message(STATUS "Using the obsolete way to specify the OpenCL runtime. "
+            "Use DNNL_GPU_RUNTIME=OCL instead.")
+    endif()
+    set(DNNL_GPU_RUNTIME "${MKLDNN_GPU_BACKEND}" CACHE STRING "" FORCE)
+    message(STATUS "Using the obsolete way to specify the GPU runtime. "
+        "Use DNNL_GPU_RUNTIME=${DNNL_GPU_RUNTIME} instead.")
+endif()
+
+if (MKLDNN_GPU_RUNTIME STREQUAL "OPENCL")
+    set(DNNL_GPU_RUNTIME "OCL" CACHE STRING "" FORCE)
+    message(STATUS "Using the obsolete way to specify the OpenCL runtime. 
" + "Use DNNL_GPU_RUNTIME=OCL instead.") +endif() diff --git a/cmake/options.cmake b/cmake/options.cmake index bd62e6b4142..cf965d91a4d 100644 --- a/cmake/options.cmake +++ b/cmake/options.cmake @@ -184,7 +184,7 @@ set(DNNL_CPU_RUNTIME "OMP" CACHE STRING To use Threading Building Blocks (TBB) one should also set TBBROOT (either environment variable or CMake option) to the library location.") -if(NOT "${DNNL_CPU_RUNTIME}" MATCHES "^(NONE|OMP|TBB|SEQ|THREADPOOL|DPCPP|SYCL)$") +if(NOT "${DNNL_CPU_RUNTIME}" MATCHES "^(NONE|OMP|TBB|TBB_AUTO|SEQ|THREADPOOL|DPCPP|SYCL)$") message(FATAL_ERROR "Unsupported CPU runtime: ${DNNL_CPU_RUNTIME}") endif() diff --git a/cmake/platform.cmake b/cmake/platform.cmake index 56ba84ac569..b8220cabf41 100644 --- a/cmake/platform.cmake +++ b/cmake/platform.cmake @@ -45,6 +45,11 @@ if($ENV{ONEDNN_WERROR}) set(DNNL_WERROR $ENV{ONEDNN_WERROR}) endif() +# Compatibility with Intel MKL-DNN +if($ENV{MKLDNN_WERROR}) + set(DNNL_WERROR $ENV{MKLDNN_WERROR}) +endif() + if($ENV{DNNL_WERROR}) set(DNNL_WERROR $ENV{DNNL_WERROR}) endif() diff --git a/cmake/win/TBBConfig.cmake b/cmake/win/TBBConfig.cmake deleted file mode 100644 index 623147f53ac..00000000000 --- a/cmake/win/TBBConfig.cmake +++ /dev/null @@ -1,164 +0,0 @@ -#=============================================================================== -# Copyright 2017-2021 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -#=============================================================================== - -# TBB_FOUND should not be set explicitly. It is defined automatically by CMake. -# Handling of TBB_VERSION is in TBBConfigVersion.cmake. - -if (NOT TBB_FIND_COMPONENTS) - set(TBB_FIND_COMPONENTS "tbb;tbbmalloc;tbbmalloc_proxy") - foreach (_tbb_component ${TBB_FIND_COMPONENTS}) - set(TBB_FIND_REQUIRED_${_tbb_component} 1) - endforeach() -endif() - -# Add components with internal dependencies: tbbmalloc_proxy -> tbbmalloc -list(FIND TBB_FIND_COMPONENTS tbbmalloc_proxy _tbbmalloc_proxy_ix) -if (NOT _tbbmalloc_proxy_ix EQUAL -1) - list(FIND TBB_FIND_COMPONENTS tbbmalloc _tbbmalloc_ix) - if (_tbbmalloc_ix EQUAL -1) - list(APPEND TBB_FIND_COMPONENTS tbbmalloc) - set(TBB_FIND_REQUIRED_tbbmalloc ${TBB_FIND_REQUIRED_tbbmalloc_proxy}) - endif() -endif() - -# oneDNN changes: use TBBROOT to locate Intel TBB -# get_filename_component(_tbb_root "${CMAKE_CURRENT_LIST_FILE}" PATH) -# get_filename_component(_tbb_root "${_tbb_root}" PATH) -if (NOT TBBROOT) - if(DEFINED ENV{TBBROOT}) - set (TBBROOT $ENV{TBBROOT}) - else() - message("FATAL_ERROR" "TBBROOT is unset") - endif() -endif() - -set(_tbb_root ${TBBROOT}) - -set(_tbb_x32_subdir ia32) -set(_tbb_x64_subdir intel64) - -if (CMAKE_SIZEOF_VOID_P EQUAL 8) - set(_tbb_arch_subdir ${_tbb_x64_subdir}) -else() - set(_tbb_arch_subdir ${_tbb_x32_subdir}) -endif() - -# Workaround: 3.19.0 and 3.19.1 versions don't define MSVC_VERSION. -# The workaround is to assume that vc14 is used. 
-set(_tbb_detect_msvc_version FALSE) -if (NOT ${CMAKE_VERSION} VERSION_EQUAL "3.19.0" AND NOT ${CMAKE_VERSION} VERSION_EQUAL "3.19.1") - set(_tbb_detect_msvc_version TRUE) -endif() - -# Detect the most relevant MSVC subdirectory -set(_tbb_msvc_1700_subdir vc11) -set(_tbb_msvc_1800_subdir vc12) -set(_tbb_msvc_1900_subdir vc14) - -# oneDNN changes: if the project is not with MSVC, try to use MSVC 1900 -set(_tbb_msvc_ver 1900) - -if (_tbb_detect_msvc_version) - if (MSVC) - set(_tbb_msvc_ver ${MSVC_VERSION}) - endif() - if (MSVC_VERSION VERSION_LESS 1700) - message(FATAL_ERROR "This Intel TBB package is intended to be used only in the project with MSVC version 1700 (vc11) or higher") - elseif (MSVC_VERSION VERSION_GREATER 1900) - set(_tbb_msvc_ver 1900) - endif() -endif() -set(_tbb_compiler_subdir ${_tbb_msvc_${_tbb_msvc_ver}_subdir}) -unset(_tbb_msvc_1700_subdir) -unset(_tbb_msvc_1800_subdir) -unset(_tbb_msvc_1900_subdir) - -if (WINDOWS_STORE) - set(_tbb_compiler_subdir ${_tbb_compiler_subdir}_ui) -endif() - -#set conveniance variable to locate TBB files (these are used for a PSXE install) -get_filename_component(_tbb_lib_path "${_tbb_root}/lib/${_tbb_arch_subdir}/${_tbb_compiler_subdir}" ABSOLUTE) -get_filename_component(_tbb_inc_path "${_tbb_root}/include/" ABSOLUTE) - -if (TBB_FOUND) - return() -endif() - -foreach (_tbb_lib_version 12 "") -foreach (_tbb_component ${TBB_FIND_COMPONENTS}) - set(_tbb_release_lib "${_tbb_lib_path}/${_tbb_component}${_tbb_lib_version}.lib") - set(_tbb_debug_lib "${_tbb_lib_path}/${_tbb_component}${_tbb_lib_version}_debug.lib") - - # oneDNN change: check library existence (BUILD_MODE related only, not both) - string(TOUPPER "${CMAKE_BUILD_TYPE}" UPPERCASE_CMAKE_BUILD_TYPE) - if (UPPERCASE_CMAKE_BUILD_TYPE STREQUAL "DEBUG") - if (EXISTS "${_tbb_debug_lib}") - set(_lib_exists TRUE) - elseif (EXISTS "${_tbb_release_lib}") - message(FATAL_ERROR - "Intel TBB release library is found here: ${_tbb_release_lib}. 
" - "But the debug library - (lib${_tbb_component}${tbb_lib_version}_debug.lib) is missing.") - endif() - else() - if (EXISTS "${_tbb_release_lib}") - set(_lib_exists TRUE) - endif() - endif() - - if (_lib_exists) - if (NOT TARGET TBB::${_tbb_component}) - add_library(TBB::${_tbb_component} SHARED IMPORTED) - set_target_properties(TBB::${_tbb_component} PROPERTIES - IMPORTED_CONFIGURATIONS "RELEASE;DEBUG" - IMPORTED_LOCATION_RELEASE "${_tbb_release_lib}" - IMPORTED_LOCATION_DEBUG "${_tbb_debug_lib}" - INTERFACE_INCLUDE_DIRECTORIES "${_tbb_inc_path}" - IMPORTED_IMPLIB_RELEASE "${_tbb_release_lib}" - IMPORTED_IMPLIB_DEBUG "${_tbb_debug_lib}" - INTERFACE_COMPILE_DEFINITIONS "__TBB_NO_IMPLICIT_LINKAGE=1") - - # Add internal dependencies for imported targets: TBB::tbbmalloc_proxy -> TBB::tbbmalloc - if (_tbb_component STREQUAL tbbmalloc_proxy) - set_target_properties(TBB::tbbmalloc_proxy PROPERTIES INTERFACE_LINK_LIBRARIES TBB::tbbmalloc) - endif() - - list(APPEND TBB_IMPORTED_TARGETS TBB::${_tbb_component}) - set(TBB_${_tbb_component}_FOUND 1) - endif() - break() - endif() -endforeach() -endforeach() - -foreach (_tbb_component ${TBB_FIND_COMPONENTS}) - if (NOT TARGET TBB::${_tbb_component} AND TBB_FIND_REQUIRED AND TBB_FIND_REQUIRED_${_tbb_component}) - message(FATAL_ERROR "Missed required Intel TBB component: ${_tbb_component}") - endif() -endforeach() - -unset(_tbb_x32_subdir) -unset(_tbb_x64_subdir) -unset(_tbb_arch_subdir) -unset(_tbb_compiler_subdir) -unset(_tbbmalloc_proxy_ix) -unset(_tbbmalloc_ix) -unset(_tbb_lib_path) -unset(_tbb_release_lib) -unset(_tbb_debug_lib) -unset(_tbb_lib_version) -unset(_lib_exists) diff --git a/doc/advanced/transition-to-v1.md b/doc/advanced/transition-to-v1.md new file mode 100644 index 00000000000..3637fb999ce --- /dev/null +++ b/doc/advanced/transition-to-v1.md @@ -0,0 +1,550 @@ +Transition from v0.x to v1.x {#dev_guide_transition_to_v1} +========================================================== + +> **NOTE** +> +> Starting with version 1.4 +> Intel(R) Math Kernel Library for Deep Neural Networks (Intel(R) MKL-DNN) +> is renamed to oneAPI Deep Neural Network Library (oneDNN). +> For consistency, only this guide uses Intel MKL-DNN nomenclature. + +## Introduction + +This article describes user-visible and some important internal changes to +Intel MKL-DNN that occurred between v0.20 and v1.0. + +The v0.x branch ([mnt-v0](https://github.com/oneapi-src/oneDNN/tree/mnt-v0)) is +deprecated and users are strongly encouraged to migrate to +[v1.x](https://github.com/oneapi-src/oneDNN). + +@sa +Discussion on the API changes occurred in PR #384: +[RFC: API changes for the upcoming v1.0](https://github.com/oneapi-src/oneDNN/pull/384). + +## Summary of Changes + +We tried to keep changes minimal to make migration as simple as possible. In +particular, the Intel MKL-DNN programming model stays the same. Nevertheless, +the new version brings a lot of incompatible changes requiring developers to +revisit significant portions of the integrated code. + +All changes can be split into the following groups: +1. Minor API changes +2. Improving the library robustness +3. Simplified execution model +4. Changes in memory description +5. Changes in the build system + +These groups are discussed in detail below. + +## 1. Minor API Changes + +### 1.1. 
Remove deprecated functionality
+
+| Deprecated functionality | Replacement
+| :--- | :---
+| ReLU primitive | [Eltwise](@ref dnnl::eltwise_forward) with algorithm kind [ReLU](@ref dnnl::algorithm::eltwise_relu)
+| ConvolutionReLU (single primitive) | Convolution with ReLU as a [post operation](@ref dev_guide_attributes_post_ops)
+| Double precision scales | Single precision scales
+| RNN backward pd w/o forward pd hint | RNN backward pd w/ forward pd hint
+| `mkldnn_omit_stats` batch norm. flag | `mkldnn_use_global_stats`
+| `mkldnn_eltwise_desc_t.negative_slope` | `mkldnn_eltwise_desc_t.alpha`
+| `mkldnn_rnn_cell_flags_t` | Not available anymore -- RNN primitives are separated into RNN, LSTM, and GRU
+| `mkldnn_padding_kind_t` | Not used anymore
+
+The complete list of the removed C functions:
+~~~cpp
+    mkldnn_relu_forward_desc_init(...);
+    mkldnn_relu_backward_desc_init(...);
+    mkldnn_convolution_relu_desc_init(...);
+    mkldnn_rnn_cell_desc_init(...);
+    mkldnn_rnn_cell_get_gates_count(...);
+    mkldnn_rnn_cell_get_states_count(...);
+    mkldnn_rnn_forward_desc_init(...);
+    mkldnn_rnn_backward_desc_init(...);
+~~~
+
+The complete list of the removed C++ classes and functions:
+~~~cpp
+    struct mkldnn::convolution_relu_forward {}
+    struct mkldnn::relu_forward {}
+    struct mkldnn::relu_backward {}
+    struct mkldnn::rnn_cell {}
+    struct mkldnn::rnn_forward {}
+    struct mkldnn::rnn_backward {}
+
+    mkldnn::sum::primitive_desc(const memory::desc &output, std::vector<float> scale, std::vector<memory::primitive_desc> inputs);
+    mkldnn::sum::primitive_desc(std::vector<float> scale, std::vector<memory::primitive_desc> inputs);
+    mkldnn::eltwise_forward::desc(prop_kind aprop_kind, const memory::desc &src_desc, T negative_slope);
+    mkldnn::eltwise_backward::desc(const memory::desc &diff_data_desc, const memory::desc &data_desc, T negative_slope);
+~~~
+
+### 1.2. Rename `foo_v2()` to `foo()` and remove old `foo()` (C API only)
+
+Functions like:
+~~~cpp
+    mkldnn_primitive_desc_create_v2(...);
+~~~
+were renamed to:
+~~~cpp
+    mkldnn_primitive_desc_create(...);
+~~~
+
+In v0.x, the `foo_v2()` functions were typically used to pass
+[attributes](@ref dev_guide_attributes), and `foo()` assumed empty attributes.
+In v1.0, the attributes parameter is mandatory. A user can still pass `NULL` to
+indicate that the default (empty) attributes should be used.
+
+The list of functions that had the `_v2` suffix:
+
+~~~cpp
+    mkldnn_primitive_desc_iterator_create_v2(...);
+    mkldnn_primitive_desc_create_v2(...);
+    mkldnn_reorder_primitive_desc_create_v2(...);
+~~~
+
+### 1.3. Remove s16 (int16_t) data type support
+
+The experimental `s16` data type is no longer supported and has been removed.
+
+### 1.4. Disallow setting the rounding mode
+
+The rounding mode, which used to be part of the attributes, has been dropped.
+All computations respect the MXCSR register when performing rounding. Unless
+the rounding mode is set explicitly, rounding to the nearest even integer
+(RNE) is used.
+
+### 1.5. Rename a few types, enumerations, and functions
+
+#### 1.5.1. Types
+
+| API | v0.x | v1.0
+| :-- | :-- | :--
+| C | mkldnn_batch_normalization_flag_t | [mkldnn_normalization_flags_t](@ref dnnl_normalization_flags_t)
+| C | mkldnn_format_t | [mkldnn_format_tag_t](@ref dnnl_format_tag_t)
+| C++ | mkldnn::batch_normalization_flag | [mkldnn::normalization_flags](@ref dnnl::normalization_flags::use_global_stats)
+| C++ | mkldnn::memory::format | [mkldnn::memory::format_tag](@ref dnnl::memory::format_tag)
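+
+For example, here is how a typical v0.x memory descriptor initialization maps
+to v1.0 after these renames (a minimal sketch; the dimensions are arbitrary,
+and the scoped enumeration values are covered in section 1.6 below):
+
+~~~cpp
+    // v0.x: mkldnn::memory::format and unscoped enumeration values
+    auto md_v0 = mkldnn::memory::desc({2, 16, 7, 7},
+            mkldnn::memory::f32, mkldnn::memory::nchw);
+
+    // v1.0: mkldnn::memory::format_tag and scoped enumeration values
+    auto md_v1 = mkldnn::memory::desc({2, 16, 7, 7},
+            mkldnn::memory::data_type::f32, mkldnn::memory::format_tag::nchw);
+~~~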
+#### 1.5.2. Enumerations
+
+| API | v0.x | v1.0
+| :-- | :-- | :--
+| C | mkldnn_fuse_bn_relu | [mkldnn_fuse_norm_relu](@ref dnnl_fuse_norm_relu)
+| C++ | mkldnn::fuse_bn_relu | [mkldnn::normalization_flags::fuse_norm_relu](@ref dnnl::normalization_flags::fuse_norm_relu)
+| C++ | mkldnn::query::eengine | [mkldnn::query::engine](@ref dnnl::query::engine)
+
+#### 1.5.3. Functions
+
+| API | v0.x | v1.0
+| :-- | :-- | :--
+| C | mkldnn_memory_desc_init() | [mkldnn_memory_desc_init_by_tag()](@ref dnnl_memory_desc_init_by_tag)
+
+### 1.6. Unscoped enumerations become scoped (C++ API only)
+
+All `enum`s became `enum class`es. This requires the following changes:
+
+| Type | Value in v0.x | Value in v1.0
+| :-- | :-- | :--
+| mkldnn::prop_kind | mkldnn::forward_inference | [mkldnn::prop_kind::forward_inference](@ref dnnl::prop_kind::forward_inference)
+| mkldnn::algorithm | mkldnn::eltwise_tanh | [mkldnn::algorithm::eltwise_tanh](@ref dnnl::algorithm::eltwise_tanh)
+| mkldnn::normalization_flags | mkldnn::fuse_bn_relu | [mkldnn::normalization_flags::fuse_norm_relu](@ref dnnl::normalization_flags::fuse_norm_relu)
+| mkldnn::query | mkldnn::eengine | [mkldnn::query::engine](@ref dnnl::query::engine)
+| mkldnn::memory::data_type | mkldnn::memory::f32 | [mkldnn::memory::data_type::f32](@ref dnnl::memory::data_type::f32)
+| mkldnn::memory::format_tag | mkldnn::memory::nchw | [mkldnn::memory::format_tag::nchw](@ref dnnl::memory::format_tag::nchw)
+
+### 1.7. Remove view primitive
+
+Version 0.x had an implementation of view that was simply an alias for memory.
+In Intel MKL-DNN v1.0, we removed view as a type and replaced it with a
+memory descriptor directly. In order to initialize sub-memory, use
+[mkldnn::memory::desc::submemory_desc()](@ref dnnl::memory::desc::submemory_desc()).
+
+@sa
+For more details, refer to section
+[4. View rework](https://github.com/oneapi-src/oneDNN/tree/rfc-api-changes-v1.0/doc/rfc/api-v1.0#4-view-rework)
+of the [RFC for v1.0](https://github.com/oneapi-src/oneDNN/pull/384).
+
+### 1.8. RNN-specific changes
+
+Each type of [RNN](@ref dnnl_api_rnn) (Vanilla RNN, LSTM, and two types of GRU)
+is now initialized by a separate function/operation descriptor constructor.
+
+For instance, instead of using mkldnn::rnn_forward with a specified RNN cell
+type, a user is expected to use:
+- [mkldnn::vanilla_rnn_forward](@ref dnnl::vanilla_rnn_forward) for Vanilla RNN
+- [mkldnn::lstm_forward](@ref dnnl::lstm_forward) for LSTM
+- [mkldnn::gru_forward](@ref dnnl::gru_forward) for GRU
+- [mkldnn::lbr_gru_forward](@ref dnnl::lbr_gru_forward) for the linear-before-reset variant of GRU
+
+Also, the hidden and cell states in LSTM are now separated. This means that
+instead of one `src_iter` tensor of shape
+`(layers, directions, states, batch, channels)` a user passes a
+`src_iter` tensor of shape `(layers, directions, batch, channels)` for hidden
+states and a
+`src_iter_c` tensor of shape `(layers, directions, batch, channels)` for cell
+states.
+The same applies to `dst_iter`; the hidden state and the cell state are split
+into `dst_iter` and `dst_iter_c` respectively, as the sketch below shows.
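+
+A sketch of how the separated LSTM states could be described in v1.0 (the
+dimension values are illustrative; `using namespace mkldnn` is assumed):
+
+~~~cpp
+    // layers = 1, directions = 1, batch = 32, channels = 512
+    memory::dims iter_dims = {1, 1, 32, 512};
+
+    // v0.x packed the hidden and cell states into a single src_iter tensor of
+    // shape (layers, directions, states, batch, channels); v1.0 splits them
+    auto src_iter_md = memory::desc(iter_dims, memory::data_type::f32,
+            memory::format_tag::ldnc); // hidden state
+    auto src_iter_c_md = memory::desc(iter_dims, memory::data_type::f32,
+            memory::format_tag::ldnc); // cell state
+~~~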
+### 1.9. GEMM API changes
+
+Intel MKL-DNN provides three GEMM-like functions:
+- [mkldnn_sgemm()](@ref dnnl_sgemm) -- Single precision matrix-matrix multiply
+- [mkldnn_gemm_u8s8s32()](@ref dnnl_gemm_u8s8s32) -- u8/s8 integer matrix-matrix multiply
+- [mkldnn_gemm_s8s8s32()](@ref dnnl_gemm_s8s8s32) -- s8/s8 integer matrix-matrix multiply
+
+With version 1.0 we switched from a Fortran-style to a C-style API, meaning that
+the parameters are passed by value rather than by address, and matrices are
+assumed to be in row-major format rather than column-major format.
+
+Moreover, to broaden the applicability of integer matrix-matrix multiply
+functions we changed the formula from:
+\f[
+    C_{s32} =
+        \alpha
+        \cdot
+        (op(A_{i8}) + o_A) \cdot (op(B_{s8}) + o_B)
+        + \beta \cdot C_{s32}
+        + o_C
+\f]
+to
+\f[
+    C_{s32} =
+        \alpha
+        \cdot
+        (op(A_{i8}) - o_A) \cdot (op(B_{s8}) - o_B)
+        + \beta \cdot C_{s32}
+        + o_C
+\f]
+
+where for both [mkldnn_gemm_u8s8s32()](@ref dnnl_gemm_u8s8s32) and
+[mkldnn_gemm_s8s8s32()](@ref dnnl_gemm_s8s8s32) the types of the
+offsets for matrices A and B correspond to the types of the matrices
+themselves; that is:
+- `typeof(o_A) == typeof(*A)` and
+- `typeof(o_B) == typeof(*B)`.
+
+### 1.10. Primitive descriptor queries for memory descriptors
+
+In version 0.x, when querying a primitive descriptor for a memory descriptor
+that is not used, the C API returned NULL and the C++ API threw an exception.
+In version 1.0, both the C and C++ APIs return a zero memory descriptor.
+
+A zero memory descriptor means that the number of dimensions equals 0 and all
+the fields are set to zero. A memory object created with such a memory
+descriptor does not require any buffer allocations.
+
+These changes simplify the code that handles
+[workspace](@ref dev_guide_inference_and_training_aspects_workspace) or
+[scratchpad](@ref dev_guide_attributes_scratchpad):
+
+~~~cpp
+    // The code works fine even if scratchpad is not required.
+    // In this case the memory would be just zero memory.
+
+    auto scratchpad_md = pd.scratchpad_desc();
+    auto scratchpad = memory(scratchpad_md, pd.get_engine());
+
+    primitive.execute(stream, {
+            ...,
+            {MKLDNN_ARG_SCRATCHPAD, scratchpad}});
+~~~
+
+### 1.11. Default constructors for C++ classes (C++ API only)
+
+In Intel MKL-DNN v1.0, all C++ objects (primitives, memory objects, engines,
+and streams) now have default empty constructors. This enables defining the
+object first and initializing it later on. An attempt to use any methods of an
+uninitialized object results in an exception being thrown.
+
+This improvement can be especially useful when Intel MKL-DNN objects are
+members of the user's classes. For example:
+
+~~~cpp
+    class RELU_layer {
+    public:
+        RELU_layer() {} // no need to initialize eltwise here
+
+        void init() {
+            ...
+            // deferred initialization
+            eltwise = eltwise_forward(...);
+        }
+
+    private:
+        eltwise_forward eltwise;
+    };
+~~~
+
+## 2. Improving the Library Robustness
+
+### 2.1. Memory allocation in the C API
+
+In Intel MKL-DNN v1.0, constructing a memory object using the special value
+`MKLDNN_MEMORY_ALLOCATE` for a handle results in the buffer being allocated by
+the library. This makes the behavior of the C API memory object constructor
+aligned with its C++ API `mkldnn::memory` counterpart. Note that the C++ API
+memory object class still has an extra constructor that does not take a handle
+at all, and asks the library to allocate the buffer (that is, the same behavior
+as calling with the handle equal to `MKLDNN_MEMORY_ALLOCATE`).
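+
+For example (a sketch; the engine and the memory descriptor `md` are assumed
+to be created already -- the full signature of the function is shown in
+section 3.1 below):
+
+~~~cpp
+    mkldnn_memory_t mem;
+    // the library allocates the buffer and owns it
+    mkldnn_memory_create(&mem, &md, engine, MKLDNN_MEMORY_ALLOCATE);
+~~~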
+### 2.2. Explicit scratchpad management
+
+Intel MKL-DNN primitives may require temporary
+[scratchpad memory](@ref dev_guide_attributes_scratchpad) for storing
+intermediate computational results. For instance, convolution backward by
+weights typically requires extra space to perform a reduction of the
+`diff_weights` computed by different threads (the work is divided across
+images). Starting with version 1.0, the library supports two modes:
+1. Implicit scratchpad, managed by the library (**default**).
+   See [mkldnn::scratchpad_mode::library](@ref dnnl::scratchpad_mode::library).
+2. Explicit scratchpad, provided by the user.
+   See [mkldnn::scratchpad_mode::user](@ref dnnl::scratchpad_mode::user).
+
+The former mode matches the behavior of Intel MKL-DNN v0.x. It is kept for
+user convenience and cases in which memory is not a concern.
+
+In the explicit scratchpad mode, a new `mkldnn_query_scratchpad_md` query
+returns the amount of scratchpad memory needed for a primitive, and the user
+is responsible for allocating and providing the scratchpad memory to the
+primitive at runtime. The explicit scratchpad mode should be *explicitly*
+enabled by passing an attribute with `mkldnn::scratchpad_mode::user` to
+primitive descriptors.
+
+@warning
+[Scratchpad](@ref dev_guide_attributes_scratchpad) memory is not the same as
+[workspace](@ref dev_guide_inference_and_training_aspects_workspace).
+
+With explicit scratchpad it is possible to make Intel MKL-DNN primitives
+stateless and hence thread safe: the same primitive can be executed in multiple
+independent threads as long as different threads use different scratchpads.
+
+However, if a user chooses the implicit scratchpad mode, there is no
+thread-safety guarantee.
+
+## 3. Simplified Execution Model
+
+This is the most notable change in the library. The main idea was to change the
+execution API so that memory arguments are specified at primitive execution
+time and not at primitive creation time. This leads to the following changes.
+
+### 3.1. Memory is not a primitive anymore
+
+In version 0.x, memory was a special kind of primitive. With the new API,
+memory becomes a distinct data type. Moreover, a memory primitive descriptor
+becomes redundant and has been dropped. The functions that used memory
+primitive descriptors now take a memory descriptor and (optionally) an engine,
+if the latter cannot be inferred.
+
+These changes bring new data types and functions, such as:
+
+~~~cpp
+    #define MKLDNN_MEMORY_ALLOCATE ((void *)-1)
+    #define MKLDNN_MEMORY_NONE ((void *)0)
+
+    struct mkldnn_memory_t; // memory type, no longer the same as mkldnn_primitive_t
+
+    // create a memory object
+    // handle can:
+    // - point to user-allocated memory, i.e. be a valid handle. In this case
+    //   the library does not own the allocated memory.
+    // - be MKLDNN_MEMORY_ALLOCATE to ask the library to allocate and
+    //   attach memory. In this case the library owns the allocated memory.
+    // - be MKLDNN_MEMORY_NONE to create mkldnn_memory w/o attached memory.
+    mkldnn_status_t mkldnn_memory_create(mkldnn_memory_t *mem,
+            const mkldnn_memory_desc_t *md, mkldnn_engine_t engine,
+            void *handle);
+~~~
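+
+For instance, attaching a user-allocated buffer might look as follows (a
+sketch; `md` and `engine` are assumed to be initialized):
+
+~~~cpp
+    // user-allocated buffer; its size must match the memory descriptor
+    void *user_buf = malloc(mkldnn_memory_desc_get_size(&md));
+
+    mkldnn_memory_t mem;
+    // the library does not take ownership of user_buf
+    mkldnn_memory_create(&mem, &md, engine, user_buf);
+~~~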
+### 3.2. Operation primitives cannot be used as inputs (use memory instead)
+
+Version 0.x allowed passing an operation primitive as an input to another
+primitive. For instance, a convolution primitive could be passed as an input to
+a subsequent ReLU. During the execution, the ReLU primitive queried the
+convolution for its output memory and used it as an input.
+
+In version 1.0, only memory objects can be passed as inputs and outputs of
+primitives.
+
+### 3.3. Remove the `mkldnn_primitive_at_t` type
+
+Another consequence is that `mkldnn_primitive_at_t`, which is logically
+equivalent to `{primitive, output_index}`, becomes redundant. Previously the
+type was used to specify the exact memory to use (if a primitive had several
+outputs).
+
+### 3.4. Passing stream and input/output memories at primitive execution
+
+Finally, users are now able to run primitives directly by calling an `execute`
+function instead of putting primitives into a stream and running the latter.
+This change affects how primitives interact with streams and input/output
+memory objects: with the new API they become arguments to be passed to the
+primitive execution function.
+
+The change significantly simplifies primitive creation, which now requires a
+primitive descriptor only:
+
+~~~cpp
+    mkldnn_status_t mkldnn_primitive_create(mkldnn_primitive_t *primitive,
+            const_mkldnn_primitive_desc_t pd);
+~~~
+
+To remove the ambiguity in which order input and output memories need to be
+passed, we introduced a map-like argument in which each memory argument is
+paired with a tag indicating what kind of argument it is: destination, source,
+weights, and so on.
+
+~~~cpp
+    // types
+    #define MKLDNN_ARG_SRC_0 1
+    #define MKLDNN_ARG_SRC MKLDNN_ARG_SRC_0
+    #define MKLDNN_ARG_FROM MKLDNN_ARG_SRC_0
+    // ...
+
+    // C API
+    typedef struct {
+        int arg; // MKLDNN_ARG_SRC, ...
+        mkldnn_memory_t memory;
+    } mkldnn_exec_arg_t;
+
+    mkldnn_status_t mkldnn_primitive_execute(mkldnn_primitive_t prim,
+            mkldnn_stream_t stream, int nargs, const mkldnn_exec_arg_t *args);
+
+    // C++ API
+    convolution_forward::execute(mkldnn::stream &stream,
+            const std::map<int, mkldnn::memory> &exec_args);
+    // ... other primitives ...
+
+
+    // example C, convolution forward w/ bias
+    mkldnn_exec_arg_t conv_exec_args[] = {
+            {MKLDNN_ARG_SRC, src_mem},
+            {MKLDNN_ARG_WEIGHTS, weights_mem},
+            {MKLDNN_ARG_BIAS, bias_mem},
+            {MKLDNN_ARG_DST, dst_mem},
+    };
+    mkldnn_primitive_execute(conv_fwd, stream, 4, conv_exec_args);
+
+
+    // example C++, in-place eltwise
+    eltwise.execute(stream, {{MKLDNN_ARG_SRC, mem}, {MKLDNN_ARG_DST, mem}});
+~~~
+
+### 3.5 Short summary
+
+The example below shows conceptual code transformations between versions. The
+C++ API is used for brevity.
+
+#### Version 0.x:
+~~~cpp
+    // create a convolution, specify all inputs and outputs
+    auto conv = convolution(conv_pd,
+            {src_mem, 0}, {wei_mem, 0}, dst_conv_mem);
+
+    // create a relu (note that one of the inputs is the convolution)
+    auto relu = relu(relu_pd,
+            {conv, 0}, dst_relu_mem);
+
+    // create a stream, submit convolution and relu, and wait for the result
+    stream().submit({conv, relu}).wait();
+~~~
+
+#### Version 1.0:
+~~~cpp
+    // create convolution and relu. no inputs/outputs
+    auto conv = convolution(conv_pd);
+    auto relu = relu(relu_pd);
+
+    // create stream (based on engine)
+    stream s(engine, 0);
+
+    // execute the convolution with given inputs, outputs
+    conv.execute(s, {
+            {MKLDNN_ARG_SRC, src_mem},
+            {MKLDNN_ARG_WEIGHTS, wei_mem},
+            {MKLDNN_ARG_DST, dst_conv_mem}});
+
+    // execute the relu; cannot pass convolution as an input, only memory is allowed
+    relu.execute(s, {
+            {MKLDNN_ARG_SRC, dst_conv_mem},
+            {MKLDNN_ARG_DST, dst_relu_mem}});
+
+    s.wait(); // wait for async streams
+~~~
+
+## 4. Changes in Memory Description
+
+The way of describing memory format in version 0.x had multiple issues. From
+the user's perspective, the main issues were:
+- Some memory formats were missing. For example, the `iohw` format was not
+  available.
+- There were multiple ambiguous ways to describe memory. For example, `oihw`
+  described memory in the same way as `nchw`, but these formats were different
+  (see [gh#153](https://github.com/oneapi-src/oneDNN/issues/153)).
+- Support for custom formats was limited.
+- Support for memory views was limited.
+
+There were more substantial issues from the library development perspective:
+code bloat to support special cases, etc.
+
+We addressed the issues above by reworking memory descriptors. From the user's
+perspective, the main changes are:
+1. Memory descriptors support arbitrary strides for plain layouts. For
+   example, initializing a memory descriptor with `strides={h*w, o*h*w, w, 1}`
+   should be a valid way to define the `iohw` format even if Intel MKL-DNN
+   does not support it explicitly (see the sketch at the end of this section).
+   Functions to use:
+   - C++ API: [mkldnn::memory::desc::desc(dims, data_type, strides)](@ref dnnl::memory::desc::desc),
+   - C API: [mkldnn_memory_desc_init_by_strides()](@ref dnnl_memory_desc_init_by_strides).
+2. Dimensions are of type `int64_t` instead of `int`, and the maximum number
+   of tensor dimensions is decreased from 16 to 12. The `mkldnn_strides_t`
+   type is removed; use `mkldnn_dims_t` instead.
+3. The `memory_desc_t.format` field is replaced with
+   `memory_desc_t.format_kind`, which also has different semantics.
+
+While the first two items are self-explanatory, the last one requires some
+elaboration.
+
+In version 0.x, most memory formats could be described directly by using
+appropriate format names (for example, `nchw`) that fully describe how data is
+laid out in memory. However, Intel MKL-DNN also had the `blocked` memory format
+and the corresponding `memory_desc_t.layout_desc.blocking_desc` structure,
+which could describe a memory format in a unified fashion by specifying block
+sizes and strides. The original idea was to use format tags like `nchw` during
+memory descriptor initialization only, and always use the `blocked` format
+internally. Unfortunately, that was never implemented.
+
+With the new design, Intel MKL-DNN starts distinguishing between the actual
+memory format and convenience memory format tags that can be used to describe
+memory format concisely.
+
+Users are still able to initialize memory descriptors with format tags like
+`nchw` using [mkldnn::memory::desc::desc(dims, data_type, format_tag)](@ref dnnl::memory::desc::desc)
+or [mkldnn_memory_desc_init_by_tag()](@ref dnnl_memory_desc_init_by_tag),
+but the `memory_desc_t.format_kind` is set
+to a canonicalized kind like `blocked`, and the format name is not recorded in
+the memory descriptor structure. Initialization with strides will always result
+in the `blocked` format. The API also uses different types for memory format
+tags and kinds to aid correctness.
+
+For more details, refer to the
+[Memory descriptor article](https://github.com/oneapi-src/oneDNN/blob/rfc-api-changes-v1.0/doc/rfc/api-v1.0/rfc_memory_desc.md)
+of the [RFC for v1.0](https://github.com/oneapi-src/oneDNN/pull/384).
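+
+As a short illustration of item 1 above, an `iohw` weights descriptor can be
+created via strides even though there is no corresponding format tag (a
+sketch; the dimension values are arbitrary):
+
+~~~cpp
+    int o = 32, i = 16, h = 3, w = 3;
+    mkldnn::memory::dims dims = {o, i, h, w};
+    // iohw: i is the outermost dimension, w is the innermost one
+    mkldnn::memory::dims strides = {h * w, o * h * w, w, 1};
+    auto iohw_md = mkldnn::memory::desc(dims,
+            mkldnn::memory::data_type::f32, strides);
+~~~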
+## 5. Changes in the Build System
+
+The build options were slightly changed in the new version of Intel MKL-DNN.
+That was done mainly to avoid name collisions with other projects that include
+Intel MKL-DNN as a subproject and to accommodate future extensions to the
+library. The changes are:
+
+| Old option | New option | Notes |
+| :-- | :-- | :-- |
+| WITH_EXAMPLE | MKLDNN_BUILD_EXAMPLES | |
+| WITH_TEST | MKLDNN_BUILD_TESTS | |
+| MKLDNN_THREADING | MKLDNN_CPU_RUNTIME | |
+| MKLDNN_USE_MKL | N/A | Intel MKL-DNN does not use Intel MKL anymore |
+| VTUNEROOT | N/A | Not required, as Intel MKL-DNN contains all the necessary code internally |
+
+By default, the `-Werror` flag is disabled. `MKLDNN_WERROR` controls the
+behavior.
+
+For more information about build options, refer to @ref dev_guide_build_options.
diff --git a/include/mkldnn.h b/include/mkldnn.h
new file mode 100644
index 00000000000..27ce73b84b4
--- /dev/null
+++ b/include/mkldnn.h
@@ -0,0 +1,26 @@
+/*******************************************************************************
+* Copyright 2019-2020 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+// Header file ensures the backwards compatibility with previous namings.
+
+#ifndef MKLDNN_H
+#define MKLDNN_H
+
+#include "mkldnn_dnnl_mangling.h"
+
+#include "dnnl.h"
+
+#endif /* MKLDNN_H */
diff --git a/include/mkldnn.hpp b/include/mkldnn.hpp
new file mode 100644
index 00000000000..c3940f9076c
--- /dev/null
+++ b/include/mkldnn.hpp
@@ -0,0 +1,26 @@
+/*******************************************************************************
+* Copyright 2019-2020 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+// Header file ensures the backwards compatibility with previous namings.
+
+#ifndef MKLDNN_HPP
+#define MKLDNN_HPP
+
+#include "mkldnn_dnnl_mangling.h"
+
+#include "dnnl.hpp"
+
+#endif /* MKLDNN_HPP */
diff --git a/include/mkldnn_config.h b/include/mkldnn_config.h
new file mode 100644
index 00000000000..799d0c80788
--- /dev/null
+++ b/include/mkldnn_config.h
@@ -0,0 +1,26 @@
+/*******************************************************************************
+* Copyright 2019-2020 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +// Header file ensures the backwards compatibility with previous namings. + +#ifndef MKLDNN_CONFIG_H +#define MKLDNN_CONFIG_H + +#include "mkldnn_dnnl_mangling.h" + +#include "dnnl_config.h" + +#endif /* MKLDNN_CONFIG_H */ diff --git a/include/mkldnn_debug.h b/include/mkldnn_debug.h new file mode 100644 index 00000000000..b1863d11404 --- /dev/null +++ b/include/mkldnn_debug.h @@ -0,0 +1,26 @@ +/******************************************************************************* +* Copyright 2019-2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +// Header file ensures the backwards compatibility with previous namings. + +#ifndef MKLDNN_DEBUG_H +#define MKLDNN_DEBUG_H + +#include "mkldnn_dnnl_mangling.h" + +#include "dnnl_debug.h" + +#endif /* MKLDNN_DEBUG_H */ diff --git a/include/mkldnn_dnnl_mangling.h b/include/mkldnn_dnnl_mangling.h new file mode 100644 index 00000000000..d076ff26b91 --- /dev/null +++ b/include/mkldnn_dnnl_mangling.h @@ -0,0 +1,758 @@ +/******************************************************************************* +* Copyright 2019-2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +// Mangle mkldnn entities to dnnl ones to preserve source-code level backwards +// compatibility. The compatibility will be dropped in oneDNN v2.0. +// Please switch to the new names as soon as possible. 
+ +#ifndef MKLDNN_DNNL_MANGLING_H +#define MKLDNN_DNNL_MANGLING_H + +#define MKLDNN_API DNNL_API +#define MKLDNN_ARG_BIAS DNNL_ARG_BIAS +#define MKLDNN_ARG_DIFF_BIAS DNNL_ARG_DIFF_BIAS +#define MKLDNN_ARG_DIFF_DST DNNL_ARG_DIFF_DST +#define MKLDNN_ARG_DIFF_DST_0 DNNL_ARG_DIFF_DST_0 +#define MKLDNN_ARG_DIFF_DST_1 DNNL_ARG_DIFF_DST_1 +#define MKLDNN_ARG_DIFF_DST_2 DNNL_ARG_DIFF_DST_2 +#define MKLDNN_ARG_DIFF_DST_ITER DNNL_ARG_DIFF_DST_ITER +#define MKLDNN_ARG_DIFF_DST_ITER_C DNNL_ARG_DIFF_DST_ITER_C +#define MKLDNN_ARG_DIFF_DST_LAYER DNNL_ARG_DIFF_DST_LAYER +#define MKLDNN_ARG_DIFF_SCALE_SHIFT DNNL_ARG_DIFF_SCALE_SHIFT +#define MKLDNN_ARG_DIFF_SRC DNNL_ARG_DIFF_SRC +#define MKLDNN_ARG_DIFF_SRC_0 DNNL_ARG_DIFF_SRC_0 +#define MKLDNN_ARG_DIFF_SRC_1 DNNL_ARG_DIFF_SRC_1 +#define MKLDNN_ARG_DIFF_SRC_2 DNNL_ARG_DIFF_SRC_2 +#define MKLDNN_ARG_DIFF_SRC_ITER DNNL_ARG_DIFF_SRC_ITER +#define MKLDNN_ARG_DIFF_SRC_ITER_C DNNL_ARG_DIFF_SRC_ITER_C +#define MKLDNN_ARG_DIFF_SRC_LAYER DNNL_ARG_DIFF_SRC_LAYER +#define MKLDNN_ARG_DIFF_WEIGHTS DNNL_ARG_DIFF_WEIGHTS +#define MKLDNN_ARG_DIFF_WEIGHTS_0 DNNL_ARG_DIFF_WEIGHTS_0 +#define MKLDNN_ARG_DIFF_WEIGHTS_1 DNNL_ARG_DIFF_WEIGHTS_1 +#define MKLDNN_ARG_DIFF_WEIGHTS_ITER DNNL_ARG_DIFF_WEIGHTS_ITER +#define MKLDNN_ARG_DIFF_WEIGHTS_LAYER DNNL_ARG_DIFF_WEIGHTS_LAYER +#define MKLDNN_ARG_DST DNNL_ARG_DST +#define MKLDNN_ARG_DST_0 DNNL_ARG_DST_0 +#define MKLDNN_ARG_DST_1 DNNL_ARG_DST_1 +#define MKLDNN_ARG_DST_2 DNNL_ARG_DST_2 +#define MKLDNN_ARG_DST_ITER DNNL_ARG_DST_ITER +#define MKLDNN_ARG_DST_ITER_C DNNL_ARG_DST_ITER_C +#define MKLDNN_ARG_DST_LAYER DNNL_ARG_DST_LAYER +#define MKLDNN_ARG_FROM DNNL_ARG_FROM +#define MKLDNN_ARG_MEAN DNNL_ARG_MEAN +#define MKLDNN_ARG_MULTIPLE_DST DNNL_ARG_MULTIPLE_DST +#define MKLDNN_ARG_MULTIPLE_SRC DNNL_ARG_MULTIPLE_SRC +#define MKLDNN_ARG_SCALE_SHIFT DNNL_ARG_SCALE_SHIFT +#define MKLDNN_ARG_SCRATCHPAD DNNL_ARG_SCRATCHPAD +#define MKLDNN_ARG_SRC DNNL_ARG_SRC +#define MKLDNN_ARG_SRC_0 DNNL_ARG_SRC_0 +#define MKLDNN_ARG_SRC_1 DNNL_ARG_SRC_1 +#define MKLDNN_ARG_SRC_2 DNNL_ARG_SRC_2 +#define MKLDNN_ARG_SRC_ITER DNNL_ARG_SRC_ITER +#define MKLDNN_ARG_SRC_ITER_C DNNL_ARG_SRC_ITER_C +#define MKLDNN_ARG_SRC_LAYER DNNL_ARG_SRC_LAYER +#define MKLDNN_ARG_TO DNNL_ARG_TO +#define MKLDNN_ARG_VARIANCE DNNL_ARG_VARIANCE +#define MKLDNN_ARG_WEIGHTS DNNL_ARG_WEIGHTS +#define MKLDNN_ARG_WEIGHTS_0 DNNL_ARG_WEIGHTS_0 +#define MKLDNN_ARG_WEIGHTS_1 DNNL_ARG_WEIGHTS_1 +#define MKLDNN_ARG_WEIGHTS_ITER DNNL_ARG_WEIGHTS_ITER +#define MKLDNN_ARG_WEIGHTS_LAYER DNNL_ARG_WEIGHTS_LAYER +#define MKLDNN_ARG_WORKSPACE DNNL_ARG_WORKSPACE +#define MKLDNN_CPU_RUNTIME DNNL_CPU_RUNTIME +#define MKLDNN_DEFINE_BITMASK_OPS DNNL_DEFINE_BITMASK_OPS +#define MKLDNN_GPU_RUNTIME DNNL_GPU_RUNTIME +#define MKLDNN_JIT_DUMP DNNL_JIT_DUMP +#define MKLDNN_MAX_NDIMS DNNL_MAX_NDIMS +#define MKLDNN_MEMORY_ALLOCATE DNNL_MEMORY_ALLOCATE +#define MKLDNN_MEMORY_NONE DNNL_MEMORY_NONE +#define MKLDNN_RNN_MAX_N_PARTS DNNL_RNN_MAX_N_PARTS +#define MKLDNN_RUNTIME_NONE DNNL_RUNTIME_NONE +#define MKLDNN_RUNTIME_OCL DNNL_RUNTIME_OCL +#define MKLDNN_RUNTIME_OMP DNNL_RUNTIME_OMP +#define MKLDNN_RUNTIME_SEQ DNNL_RUNTIME_SEQ +#define MKLDNN_RUNTIME_TBB DNNL_RUNTIME_TBB +#define MKLDNN_RUNTIME_SYCL DNNL_RUNTIME_SYCL +#define MKLDNN_WITH_SYCL DNNL_WITH_SYCL +#define MKLDNN_VERBOSE DNNL_VERBOSE +#define MKLDNN_VERSION_HASH DNNL_VERSION_HASH +#define MKLDNN_VERSION_MAJOR DNNL_VERSION_MAJOR +#define MKLDNN_VERSION_MINOR DNNL_VERSION_MINOR +#define MKLDNN_VERSION_PATCH DNNL_VERSION_PATCH +#define 
const_mkldnn_engine_t const_dnnl_engine_t +#define const_mkldnn_memory_t const_dnnl_memory_t +#define const_mkldnn_op_desc_t const_dnnl_op_desc_t +#define const_mkldnn_post_ops_t const_dnnl_post_ops_t +#define const_mkldnn_primitive_attr_t const_dnnl_primitive_attr_t +#define const_mkldnn_primitive_desc_iterator_t \ + const_dnnl_primitive_desc_iterator_t +#define const_mkldnn_primitive_desc_t const_dnnl_primitive_desc_t +#define const_mkldnn_primitive_t const_dnnl_primitive_t +#define const_mkldnn_stream_t const_dnnl_stream_t +#define mkldnn dnnl +#define mkldnn_ dnnl_ +#define mkldnn_ABc16a16b dnnl_ABc16a16b +#define mkldnn_ABc4a4b dnnl_ABc4a4b +#define mkldnn_ABc16b16a dnnl_ABc16b16a +#define mkldnn_ABc4b16a4b dnnl_ABc4b16a4b +#define mkldnn_ABc4b4a dnnl_ABc4b4a +#define mkldnn_ABc8a16b2a dnnl_ABc8a16b2a +#define mkldnn_ABc8a8b dnnl_ABc8a8b +#define mkldnn_ABc8b16a2b dnnl_ABc8b16a2b +#define mkldnn_ABc8b8a dnnl_ABc8b8a +#define mkldnn_ABcd16a16b dnnl_ABcd16a16b +#define mkldnn_ABcd16b16a dnnl_ABcd16b16a +#define mkldnn_ABcd2a8b8a2b dnnl_ABcd2a8b8a2b +#define mkldnn_ABcd32a32b dnnl_ABcd32a32b +#define mkldnn_ABcd4a8b8a4b dnnl_ABcd4a8b8a4b +#define mkldnn_ABcd4b16a4b dnnl_ABcd4b16a4b +#define mkldnn_OIhw16i16o4i dnnl_ABcd16b16a4b +#define mkldnn_OIhw16i16o2i dnnl_ABcd16b16a2b +#define mkldnn_ABcd4b4a dnnl_ABcd4b4a +#define mkldnn_ABcd4a4b dnnl_ABcd4a4b +#define mkldnn_ABcd8a16b2a dnnl_ABcd8a16b2a +#define mkldnn_ABcd8a8b dnnl_ABcd8a8b +#define mkldnn_ABcd8b16a2b dnnl_ABcd8b16a2b +#define mkldnn_ABcd8b8a dnnl_ABcd8b8a +#define mkldnn_ABcde16a16b dnnl_ABcde16a16b +#define mkldnn_ABcde16b16a dnnl_ABcde16b16a +#define mkldnn_ABcde4b4a dnnl_ABcde4b4a +#define mkldnn_ABcde4a4b dnnl_ABcde4a4b +#define mkldnn_ABcde8a16b2a dnnl_ABcde8a16b2a +#define mkldnn_ABcde8a8b dnnl_ABcde8a8b +#define mkldnn_ABcde8b16a2b dnnl_ABcde8b16a2b +#define mkldnn_ABcde4b16a4b dnnl_ABcde4b16a4b +#define mkldnn_ABcde8b8a dnnl_ABcde8b8a +#define mkldnn_Abc16a dnnl_Abc16a +#define mkldnn_Abc4a dnnl_Abc4a +#define mkldnn_Abcd16a dnnl_Abcd16a +#define mkldnn_Abcd4a dnnl_Abcd4a +#define mkldnn_Abcde16a dnnl_Abcde16a +#define mkldnn_Abcde4a dnnl_Abcde4a +#define mkldnn_Abcde8a dnnl_Abcde8a +#define mkldnn_Abcdef16a dnnl_Abcdef16a +#define mkldnn_Acb16a dnnl_Acb16a +#define mkldnn_Acb4a dnnl_Acb4a +#define mkldnn_Acb8a dnnl_Acb8a +#define mkldnn_Acdb16a dnnl_Acdb16a +#define mkldnn_Acdb32a dnnl_Acdb32a +#define mkldnn_Acdb4a dnnl_Acdb4a +#define mkldnn_Acdb8a dnnl_Acdb8a +#define mkldnn_Acdeb16a dnnl_Acdeb16a +#define mkldnn_Acdeb4a dnnl_Acdeb4a +#define mkldnn_Acdeb8a dnnl_Acdeb8a +#define mkldnn_BAc16a16b dnnl_BAc16a16b +#define mkldnn_BAc16b16a dnnl_BAc16b16a +#define mkldnn_BAc8a16b2a dnnl_BAc8a16b2a +#define mkldnn_BAcd16a16b dnnl_BAcd16a16b +#define mkldnn_BAcd16b16a dnnl_BAcd16b16a +#define mkldnn_BAcd8a16b2a dnnl_BAcd8a16b2a +#define mkldnn_BAcde16b16a dnnl_BAcde16b16a +#define mkldnn_BAcde16a16b dnnl_BAcde16a16b +#define mkldnn_BAcde8a16b2a dnnl_BAcde8a16b2a +#define mkldnn_Goidhw16g dnnl_Goidhw16g +#define mkldnn_Goihw16g dnnl_Goihw16g +#define mkldnn_Goihw8g dnnl_Goihw8g +#define mkldnn_Goiw16g dnnl_Goiw16g +#define mkldnn_IOdhw16i16o dnnl_IOdhw16i16o +#define mkldnn_IOdhw16o16i dnnl_IOdhw16o16i +#define mkldnn_IOdhw8o16i2o dnnl_IOdhw8o16i2o +#define mkldnn_IOhw16i16o dnnl_IOhw16i16o +#define mkldnn_IOhw16o16i dnnl_IOhw16o16i +#define mkldnn_IOhw8o16i2o dnnl_IOhw8o16i2o +#define mkldnn_IOw16i16o dnnl_IOw16i16o +#define mkldnn_IOw16o16i dnnl_IOw16o16i +#define mkldnn_IOw8o16i2o dnnl_IOw8o16i2o +#define 
mkldnn_NCdhw16n16c dnnl_NCdhw16n16c
+#define mkldnn_NChw16n16c dnnl_NChw16n16c
+#define mkldnn_NChw32n32c dnnl_NChw32n32c
+#define mkldnn_NCw16n16c dnnl_NCw16n16c
+#define mkldnn_OIdhw16i16o dnnl_OIdhw16i16o
+#define mkldnn_OIdhw16o16i dnnl_OIdhw16o16i
+#define mkldnn_OIdhw4i4o dnnl_OIdhw4i4o
+#define mkldnn_OIdhw4o4i dnnl_OIdhw4o4i
+#define mkldnn_OIdhw8i16o2i dnnl_OIdhw8i16o2i
+#define mkldnn_OIdhw4i16o4i dnnl_OIdhw4i16o4i
+#define mkldnn_OIdhw8i8o dnnl_OIdhw8i8o
+#define mkldnn_OIdhw8o16i2o dnnl_OIdhw8o16i2o
+#define mkldnn_OIdhw8o8i dnnl_OIdhw8o8i
+#define mkldnn_OIhw16i16o dnnl_OIhw16i16o
+#define mkldnn_OIhw16o16i dnnl_OIhw16o16i
+#define mkldnn_OIhw2o8i8o2i dnnl_OIhw2o8i8o2i
+#define mkldnn_OIhw4i16o4i dnnl_OIhw4i16o4i
+#define mkldnn_OIhw4i4o dnnl_OIhw4i4o
+#define mkldnn_OIhw4o4i dnnl_OIhw4o4i
+#define mkldnn_OIhw4o8i8o4i dnnl_OIhw4o8i8o4i
+#define mkldnn_OIhw8i16o2i dnnl_OIhw8i16o2i
+#define mkldnn_OIhw8i8o dnnl_OIhw8i8o
+#define mkldnn_OIhw8o16i2o dnnl_OIhw8o16i2o
+#define mkldnn_OIhw8o8i dnnl_OIhw8o8i
+#define mkldnn_OIw16i16o dnnl_OIw16i16o
+#define mkldnn_OIw16o16i dnnl_OIw16o16i
+#define mkldnn_OIw4i16o4i dnnl_OIw4i16o4i
+#define mkldnn_OIw4i4o dnnl_OIw4i4o
+#define mkldnn_OIw4o4i dnnl_OIw4o4i
+#define mkldnn_OIw8i16o2i dnnl_OIw8i16o2i
+#define mkldnn_OIw8i8o dnnl_OIw8i8o
+#define mkldnn_OIw8o16i2o dnnl_OIw8o16i2o
+#define mkldnn_OIw8o8i dnnl_OIw8o8i
+#define mkldnn_Odhwi16o dnnl_Odhwi16o
+#define mkldnn_Odhwi4o dnnl_Odhwi4o
+#define mkldnn_Odhwi8o dnnl_Odhwi8o
+#define mkldnn_Ohwi16o dnnl_Ohwi16o
+#define mkldnn_Ohwi32o dnnl_Ohwi32o
+#define mkldnn_Ohwi4o dnnl_Ohwi4o
+#define mkldnn_Ohwi8o dnnl_Ohwi8o
+#define mkldnn_Oidhw16o dnnl_Oidhw16o
+#define mkldnn_Oidhw4o dnnl_Oidhw4o
+#define mkldnn_Oihw16o dnnl_Oihw16o
+#define mkldnn_Oihw4o dnnl_Oihw4o
+#define mkldnn_Oiw16o dnnl_Oiw16o
+#define mkldnn_Oiw4o dnnl_Oiw4o
+#define mkldnn_Owi16o dnnl_Owi16o
+#define mkldnn_Owi4o dnnl_Owi4o
+#define mkldnn_Owi8o dnnl_Owi8o
+#define mkldnn_a dnnl_a
+#define mkldnn_aBCd16b16c dnnl_aBCd16b16c
+#define mkldnn_aBCd16c16b dnnl_aBCd16c16b
+#define mkldnn_aBCd4c16b4c dnnl_aBCd4c16b4c
+#define mkldnn_aBCd4c4b dnnl_aBCd4c4b
+#define mkldnn_aBCd4b4c dnnl_aBCd4b4c
+#define mkldnn_aBCd8b16c2b dnnl_aBCd8b16c2b
+#define mkldnn_aBCd8b8c dnnl_aBCd8b8c
+#define mkldnn_aBCd8c16b2c dnnl_aBCd8c16b2c
+#define mkldnn_aBCd8c8b dnnl_aBCd8c8b
+#define mkldnn_aBCde16b16c dnnl_aBCde16b16c
+#define mkldnn_aBCde16c16b dnnl_aBCde16c16b
+#define mkldnn_aBCde2b8c8b2c dnnl_aBCde2b8c8b2c
+#define mkldnn_aBCde2c8b4c dnnl_aBCde2c8b4c
+#define mkldnn_gOIhw16i16o4i dnnl_aBCde16c16b4c
+#define mkldnn_gOIhw16i16o2i dnnl_aBCde16c16b2c
+#define mkldnn_aBCde4b4c dnnl_aBCde4b4c
+#define mkldnn_aBCde4b8c8b4c dnnl_aBCde4b8c8b4c
+#define mkldnn_aBCde4c16b4c dnnl_aBCde4c16b4c
+#define mkldnn_aBCde4c4b dnnl_aBCde4c4b
+#define mkldnn_aBCde8b16c2b dnnl_aBCde8b16c2b
+#define mkldnn_aBCde8b8c dnnl_aBCde8b8c
+#define mkldnn_aBCde8c16b2c dnnl_aBCde8c16b2c
+#define mkldnn_aBCde8c8b dnnl_aBCde8c8b
+#define mkldnn_aBCdef16b16c dnnl_aBCdef16b16c
+#define mkldnn_aBCdef16c16b dnnl_aBCdef16c16b
+#define mkldnn_aBCdef4c4b dnnl_aBCdef4c4b
+#define mkldnn_aBCdef4b4c dnnl_aBCdef4b4c
+#define mkldnn_aBCdef8b16c2b dnnl_aBCdef8b16c2b
+#define mkldnn_aBCdef8b8c dnnl_aBCdef8b8c
+#define mkldnn_aBCdef8c16b2c dnnl_aBCdef8c16b2c
+#define mkldnn_aBCdef4c16b4c dnnl_aBCdef4c16b4c
+#define mkldnn_aBCdef8c8b dnnl_aBCdef8c8b
+#define mkldnn_aBc16b dnnl_aBc16b
+#define mkldnn_aBc4b dnnl_aBc4b
+#define mkldnn_aBc8b dnnl_aBc8b
+#define mkldnn_aBcd16b dnnl_aBcd16b
+#define mkldnn_aBcd4b dnnl_aBcd4b +#define mkldnn_aBcd8b dnnl_aBcd8b +#define mkldnn_aBcde16b dnnl_aBcde16b +#define mkldnn_aBcde4b dnnl_aBcde4b +#define mkldnn_aBcde8b dnnl_aBcde8b +#define mkldnn_aBcdef16b dnnl_aBcdef16b +#define mkldnn_aBcdef4b dnnl_aBcdef4b +#define mkldnn_aBdc16b dnnl_aBdc16b +#define mkldnn_aBdc4b dnnl_aBdc4b +#define mkldnn_aBdc8b dnnl_aBdc8b +#define mkldnn_aBdec16b dnnl_aBdec16b +#define mkldnn_aBdec32b dnnl_aBdec32b +#define mkldnn_aBdec4b dnnl_aBdec4b +#define mkldnn_aBdec8b dnnl_aBdec8b +#define mkldnn_aBdefc16b dnnl_aBdefc16b +#define mkldnn_aBdefc4b dnnl_aBdefc4b +#define mkldnn_aBdefc8b dnnl_aBdefc8b +#define mkldnn_aCBd16b16c dnnl_aCBd16b16c +#define mkldnn_aCBd16c16b dnnl_aCBd16c16b +#define mkldnn_aCBd8b16c2b dnnl_aCBd8b16c2b +#define mkldnn_aCBde16b16c dnnl_aCBde16b16c +#define mkldnn_aCBde16c16b dnnl_aCBde16c16b +#define mkldnn_aCBde8b16c2b dnnl_aCBde8b16c2b +#define mkldnn_aCBdef16c16b dnnl_aCBdef16c16b +#define mkldnn_aCBdef16b16c dnnl_aCBdef16b16c +#define mkldnn_aCBdef8b16c2b dnnl_aCBdef8b16c2b +#define mkldnn_ab dnnl_ab +#define mkldnn_abc dnnl_abc +#define mkldnn_abcd dnnl_abcd +#define mkldnn_abcde dnnl_abcde +#define mkldnn_abcdef dnnl_abcdef +#define mkldnn_abdec dnnl_abdec +#define mkldnn_acb dnnl_acb +#define mkldnn_acbde dnnl_acbde +#define mkldnn_acdb dnnl_acdb +#define mkldnn_acdeb dnnl_acdeb +#define mkldnn_alg_kind2str dnnl_alg_kind2str +#define mkldnn_alg_kind_t dnnl_alg_kind_t +#define mkldnn_alg_kind_undef dnnl_alg_kind_undef +#define mkldnn_any_engine dnnl_any_engine +#define mkldnn_ba dnnl_ba +#define mkldnn_bac dnnl_bac +#define mkldnn_bacd dnnl_bacd +#define mkldnn_backward dnnl_backward +#define mkldnn_backward_bias dnnl_backward_bias +#define mkldnn_backward_data dnnl_backward_data +#define mkldnn_backward_weights dnnl_backward_weights +#define mkldnn_batch_normalization dnnl_batch_normalization +#define mkldnn_batch_normalization_backward_desc_init \ + dnnl_batch_normalization_backward_desc_init +#define mkldnn_batch_normalization_desc_t dnnl_batch_normalization_desc_t +#define mkldnn_batch_normalization_forward_desc_init \ + dnnl_batch_normalization_forward_desc_init +#define mkldnn_bca dnnl_bca +#define mkldnn_bcda dnnl_bcda +#define mkldnn_bcdea dnnl_bcdea +#define mkldnn_bf16 dnnl_bf16 +#define mkldnn_bidirectional_concat dnnl_bidirectional_concat +#define mkldnn_bidirectional_sum dnnl_bidirectional_sum +#define mkldnn_blocked dnnl_blocked +#define mkldnn_blocking_desc_t dnnl_blocking_desc_t +#define mkldnn_cba dnnl_cba +#define mkldnn_cdba dnnl_cdba +#define mkldnn_cdeba dnnl_cdeba +#define mkldnn_chwn dnnl_chwn +#define mkldnn_cn dnnl_cn +#define mkldnn_concat dnnl_concat +#define mkldnn_concat_primitive_desc_create dnnl_concat_primitive_desc_create +#define mkldnn_config dnnl_config +#define mkldnn_convolution dnnl_convolution +#define mkldnn_convolution_auto dnnl_convolution_auto +#define mkldnn_convolution_backward_data_desc_init \ + dnnl_convolution_backward_data_desc_init +#define mkldnn_convolution_backward_weights_desc_init \ + dnnl_convolution_backward_weights_desc_init +#define mkldnn_convolution_desc_t dnnl_convolution_desc_t +#define mkldnn_convolution_direct dnnl_convolution_direct +#define mkldnn_convolution_forward_desc_init dnnl_convolution_forward_desc_init +#define mkldnn_convolution_winograd dnnl_convolution_winograd +#define mkldnn_cpu dnnl_cpu +#define mkldnn_data_type_t dnnl_data_type_t +#define mkldnn_data_type_undef dnnl_data_type_undef +#define mkldnn_decab dnnl_decab +#define mkldnn_deconvolution 
dnnl_deconvolution +#define mkldnn_deconvolution_backward_data_desc_init \ + dnnl_deconvolution_backward_data_desc_init +#define mkldnn_deconvolution_backward_weights_desc_init \ + dnnl_deconvolution_backward_weights_desc_init +#define mkldnn_deconvolution_desc_t dnnl_deconvolution_desc_t +#define mkldnn_deconvolution_direct dnnl_deconvolution_direct +#define mkldnn_deconvolution_forward_desc_init \ + dnnl_deconvolution_forward_desc_init +#define mkldnn_deconvolution_winograd dnnl_deconvolution_winograd +#define mkldnn_dhwio dnnl_dhwio +#define mkldnn_dilated_convolution_backward_data_desc_init \ + dnnl_dilated_convolution_backward_data_desc_init +#define mkldnn_dilated_convolution_backward_weights_desc_init \ + dnnl_dilated_convolution_backward_weights_desc_init +#define mkldnn_dilated_convolution_forward_desc_init \ + dnnl_dilated_convolution_forward_desc_init +#define mkldnn_dilated_deconvolution_backward_data_desc_init \ + dnnl_dilated_deconvolution_backward_data_desc_init +#define mkldnn_dilated_deconvolution_backward_weights_desc_init \ + dnnl_dilated_deconvolution_backward_weights_desc_init +#define mkldnn_dilated_deconvolution_forward_desc_init \ + dnnl_dilated_deconvolution_forward_desc_init +#define mkldnn_dim_t dnnl_dim_t +#define mkldnn_dims_t dnnl_dims_t +#define mkldnn_dt2str dnnl_dt2str +#define mkldnn_eltwise dnnl_eltwise +#define mkldnn_eltwise_abs dnnl_eltwise_abs +#define mkldnn_eltwise_backward_desc_init dnnl_eltwise_backward_desc_init +#define mkldnn_eltwise_bounded_relu dnnl_eltwise_bounded_relu +#define mkldnn_eltwise_desc_t dnnl_eltwise_desc_t +#define mkldnn_eltwise_elu dnnl_eltwise_elu +#define mkldnn_eltwise_exp dnnl_eltwise_exp +#define mkldnn_eltwise_forward_desc_init dnnl_eltwise_forward_desc_init +#define mkldnn_eltwise_gelu dnnl_eltwise_gelu +#define mkldnn_eltwise_linear dnnl_eltwise_linear +#define mkldnn_eltwise_logistic dnnl_eltwise_logistic +#define mkldnn_eltwise_relu dnnl_eltwise_relu +#define mkldnn_eltwise_soft_relu dnnl_eltwise_soft_relu +#define mkldnn_eltwise_sqrt dnnl_eltwise_sqrt +#define mkldnn_eltwise_square dnnl_eltwise_square +#define mkldnn_eltwise_swish dnnl_eltwise_swish +#define mkldnn_eltwise_tanh dnnl_eltwise_tanh +#define mkldnn_engine dnnl_engine +#define mkldnn_engine_create dnnl_engine_create +#define mkldnn_engine_create_ocl dnnl_ocl_interop_engine_create +#define mkldnn_engine_destroy dnnl_engine_destroy +#define mkldnn_engine_get_count dnnl_engine_get_count +#define mkldnn_engine_get_kind dnnl_engine_get_kind +#define mkldnn_engine_get_ocl_context dnnl_ocl_interop_engine_get_context +#define mkldnn_engine_get_ocl_device dnnl_ocl_interop_get_device +#define mkldnn_engine_kind2str dnnl_engine_kind2str +#define mkldnn_engine_kind_t dnnl_engine_kind_t +#define mkldnn_engine_t dnnl_engine_t +#define mkldnn_exec_arg_t dnnl_exec_arg_t +#define mkldnn_f16 dnnl_f16 +#define mkldnn_f32 dnnl_f32 +#define mkldnn_fmt_kind2str dnnl_fmt_kind2str +#define mkldnn_fmt_tag2str dnnl_fmt_tag2str +#define mkldnn_format_kind_any dnnl_format_kind_any +#define mkldnn_format_kind_rnn_packed dnnl_format_kind_rnn_packed +#define mkldnn_format_kind_t dnnl_format_kind_t +#define mkldnn_format_kind_undef dnnl_format_kind_undef +#define mkldnn_format_kind_wino dnnl_format_kind_wino +#define mkldnn_format_tag_any dnnl_format_tag_any +#define mkldnn_format_tag_last dnnl_format_tag_last +#define mkldnn_format_tag_t dnnl_format_tag_t +#define mkldnn_format_tag_undef dnnl_format_tag_undef +#define mkldnn_forward dnnl_forward +#define mkldnn_forward_inference 
dnnl_forward_inference +#define mkldnn_forward_scoring dnnl_forward_scoring +#define mkldnn_forward_training dnnl_forward_training +#define mkldnn_fuse_norm_relu dnnl_fuse_norm_relu +#define mkldnn_gIOdhw16i16o dnnl_gIOdhw16i16o +#define mkldnn_gIOdhw16o16i dnnl_gIOdhw16o16i +#define mkldnn_gIOdhw8o16i2o dnnl_gIOdhw8o16i2o +#define mkldnn_gIOhw16i16o dnnl_gIOhw16i16o +#define mkldnn_gIOhw16o16i dnnl_gIOhw16o16i +#define mkldnn_gIOhw8o16i2o dnnl_gIOhw8o16i2o +#define mkldnn_gIOw16i16o dnnl_gIOw16i16o +#define mkldnn_gIOw16o16i dnnl_gIOw16o16i +#define mkldnn_gIOw8o16i2o dnnl_gIOw8o16i2o +#define mkldnn_gOIdhw16i16o dnnl_gOIdhw16i16o +#define mkldnn_gOIdhw16o16i dnnl_gOIdhw16o16i +#define mkldnn_gOIdhw4i4o dnnl_gOIdhw4i4o +#define mkldnn_gOIdhw4o4i dnnl_gOIdhw4o4i +#define mkldnn_gOIdhw8i16o2i dnnl_gOIdhw8i16o2i +#define mkldnn_gOIdhw4i16o4i dnnl_gOIdhw4i16o4i +#define mkldnn_gOIdhw8i8o dnnl_gOIdhw8i8o +#define mkldnn_gOIdhw8o16i2o dnnl_gOIdhw8o16i2o +#define mkldnn_gOIdhw8o8i dnnl_gOIdhw8o8i +#define mkldnn_gOIhw16i16o dnnl_gOIhw16i16o +#define mkldnn_gOIhw16o16i dnnl_gOIhw16o16i +#define mkldnn_gOIhw2i8o4i dnnl_gOIhw2i8o4i +#define mkldnn_gOIhw2o8i8o2i dnnl_gOIhw2o8i8o2i +#define mkldnn_gOIhw4i16o4i dnnl_gOIhw4i16o4i +#define mkldnn_gOIhw4i4o dnnl_gOIhw4i4o +#define mkldnn_gOIhw4o4i dnnl_gOIhw4o4i +#define mkldnn_gOIhw4o8i8o4i dnnl_gOIhw4o8i8o4i +#define mkldnn_gOIhw8i16o2i dnnl_gOIhw8i16o2i +#define mkldnn_gOIhw8i8o dnnl_gOIhw8i8o +#define mkldnn_gOIhw8o16i2o dnnl_gOIhw8o16i2o +#define mkldnn_gOIhw8o8i dnnl_gOIhw8o8i +#define mkldnn_gOIw16i16o dnnl_gOIw16i16o +#define mkldnn_gOIw16o16i dnnl_gOIw16o16i +#define mkldnn_gOIw4i16o4i dnnl_gOIw4i16o4i +#define mkldnn_gOIw4i4o dnnl_gOIw4i4o +#define mkldnn_gOIw4o4i dnnl_gOIw4o4i +#define mkldnn_gOIw8i16o2i dnnl_gOIw8i16o2i +#define mkldnn_gOIw8i8o dnnl_gOIw8i8o +#define mkldnn_gOIw8o16i2o dnnl_gOIw8o16i2o +#define mkldnn_gOIw8o8i dnnl_gOIw8o8i +#define mkldnn_gOdhwi16o dnnl_gOdhwi16o +#define mkldnn_gOdhwi4o dnnl_gOdhwi4o +#define mkldnn_gOdhwi8o dnnl_gOdhwi8o +#define mkldnn_gOhwi16o dnnl_gOhwi16o +#define mkldnn_gOhwi32o dnnl_gOhwi32o +#define mkldnn_gOhwi4o dnnl_gOhwi4o +#define mkldnn_gOhwi8o dnnl_gOhwi8o +#define mkldnn_gOidhw16o dnnl_gOidhw16o +#define mkldnn_gOidhw4o dnnl_gOidhw4o +#define mkldnn_gOihw16o dnnl_gOihw16o +#define mkldnn_gOihw4o dnnl_gOihw4o +#define mkldnn_gOiw16o dnnl_gOiw16o +#define mkldnn_gOiw4o dnnl_gOiw4o +#define mkldnn_gOwi16o dnnl_gOwi16o +#define mkldnn_gOwi4o dnnl_gOwi4o +#define mkldnn_gOwi8o dnnl_gOwi8o +#define mkldnn_gemm dnnl_gemm +#define mkldnn_gemm_s8s8s32 dnnl_gemm_s8s8s32 +#define mkldnn_gemm_u8s8s32 dnnl_gemm_u8s8s32 +#define mkldnn_giohw dnnl_giohw +#define mkldnn_goidhw dnnl_goidhw +#define mkldnn_goihw dnnl_goihw +#define mkldnn_goiw dnnl_goiw +#define mkldnn_gpu dnnl_gpu +#define mkldnn_gru_backward_desc_init dnnl_gru_backward_desc_init +#define mkldnn_gru_forward_desc_init dnnl_gru_forward_desc_init +#define mkldnn_hwigo dnnl_hwigo +#define mkldnn_hwio dnnl_hwio +#define mkldnn_idhwo dnnl_idhwo +#define mkldnn_ihwo dnnl_ihwo +#define mkldnn_inner_product dnnl_inner_product +#define mkldnn_inner_product_backward_data_desc_init \ + dnnl_inner_product_backward_data_desc_init +#define mkldnn_inner_product_backward_weights_desc_init \ + dnnl_inner_product_backward_weights_desc_init +#define mkldnn_inner_product_desc_t dnnl_inner_product_desc_t +#define mkldnn_inner_product_forward_desc_init \ + dnnl_inner_product_forward_desc_init +#define mkldnn_invalid_arguments dnnl_invalid_arguments +#define 
mkldnn_io dnnl_io +#define mkldnn_iohw dnnl_iohw +#define mkldnn_iterator_ends dnnl_iterator_ends +#define mkldnn_iwo dnnl_iwo +#define mkldnn_layer_normalization dnnl_layer_normalization +#define mkldnn_layer_normalization_backward_desc_init \ + dnnl_layer_normalization_backward_desc_init +#define mkldnn_layer_normalization_desc_t dnnl_layer_normalization_desc_t +#define mkldnn_layer_normalization_forward_desc_init \ + dnnl_layer_normalization_forward_desc_init +#define mkldnn_lbr_gru dnnl_lbr_gru +#define mkldnn_lbr_gru_backward_desc_init dnnl_lbr_gru_backward_desc_init +#define mkldnn_lbr_gru_forward_desc_init dnnl_lbr_gru_forward_desc_init +#define mkldnn_ldgo dnnl_ldgo +#define mkldnn_ldgoi dnnl_ldgoi +#define mkldnn_ldgoi_p dnnl_ldgoi_p +#define mkldnn_ldigo dnnl_ldigo +#define mkldnn_ldigo_p dnnl_ldigo_p +#define mkldnn_ldnc dnnl_ldnc +#define mkldnn_lrn dnnl_lrn +#define mkldnn_lrn_across_channels dnnl_lrn_across_channels +#define mkldnn_lrn_backward_desc_init dnnl_lrn_backward_desc_init +#define mkldnn_lrn_desc_t dnnl_lrn_desc_t +#define mkldnn_lrn_forward_desc_init dnnl_lrn_forward_desc_init +#define mkldnn_lrn_within_channel dnnl_lrn_within_channel +#define mkldnn_lstm_backward_desc_init dnnl_lstm_backward_desc_init +#define mkldnn_lstm_forward_desc_init dnnl_lstm_forward_desc_init +#define mkldnn_md2dim_str dnnl_md2dim_str +#define mkldnn_md2fmt_str dnnl_md2fmt_str +#define mkldnn_memory dnnl_memory +#define mkldnn_memory_create dnnl_memory_create +#define mkldnn_memory_desc_equal dnnl_memory_desc_equal +#define mkldnn_memory_desc_get_size dnnl_memory_desc_get_size +#define mkldnn_memory_desc_init_by_strides dnnl_memory_desc_init_by_strides +#define mkldnn_memory_desc_init_by_tag dnnl_memory_desc_init_by_tag +#define mkldnn_memory_desc_init_submemory dnnl_memory_desc_init_submemory +#define mkldnn_memory_desc_t dnnl_memory_desc_t +#define mkldnn_memory_destroy dnnl_memory_destroy +#define mkldnn_memory_extra_desc_t dnnl_memory_extra_desc_t +#define mkldnn_memory_extra_flag_compensation_conv_s8s8 \ + dnnl_memory_extra_flag_compensation_conv_s8s8 +#define mkldnn_memory_extra_flag_none dnnl_memory_extra_flag_none +#define mkldnn_memory_extra_flag_scale_adjust \ + dnnl_memory_extra_flag_scale_adjust +#define mkldnn_memory_extra_flags_t dnnl_memory_extra_flags_t +#define mkldnn_memory_get_data_handle dnnl_memory_get_data_handle +#define mkldnn_memory_get_engine dnnl_memory_get_engine +#define mkldnn_memory_get_memory_desc dnnl_memory_get_memory_desc +#define mkldnn_memory_get_ocl_mem_object dnnl_ocl_interop_memory_get_mem_object +#define mkldnn_memory_map_data dnnl_memory_map_data +#define mkldnn_memory_set_data_handle dnnl_memory_set_data_handle +#define mkldnn_memory_set_ocl_mem_object dnnl_ocl_interop_memory_set_mem_object +#define mkldnn_memory_t dnnl_memory_t +#define mkldnn_memory_unmap_data dnnl_memory_unmap_data +#define mkldnn_nCdhw16c dnnl_nCdhw16c +#define mkldnn_nCdhw4c dnnl_nCdhw4c +#define mkldnn_nCdhw8c dnnl_nCdhw8c +#define mkldnn_nChw16c dnnl_nChw16c +#define mkldnn_nChw4c dnnl_nChw4c +#define mkldnn_nChw8c dnnl_nChw8c +#define mkldnn_nCw16c dnnl_nCw16c +#define mkldnn_nCw4c dnnl_nCw4c +#define mkldnn_nCw8c dnnl_nCw8c +#define mkldnn_nc dnnl_nc +#define mkldnn_ncdhw dnnl_ncdhw +#define mkldnn_nchw dnnl_nchw +#define mkldnn_ncw dnnl_ncw +#define mkldnn_ndhwc dnnl_ndhwc +#define mkldnn_nhwc dnnl_nhwc +#define mkldnn_normalization_flags2str dnnl_normalization_flags2str +#define mkldnn_normalization_flags_t dnnl_normalization_flags_t +#define mkldnn_not_required 
dnnl_not_required +#define mkldnn_nt dnnl_nt +#define mkldnn_ntc dnnl_ntc +#define mkldnn_nwc dnnl_nwc +#define mkldnn_odhwi dnnl_odhwi +#define mkldnn_ohwi dnnl_ohwi +#define mkldnn_oi dnnl_oi +#define mkldnn_oidhw dnnl_oidhw +#define mkldnn_oihw dnnl_oihw +#define mkldnn_oiw dnnl_oiw +#define mkldnn_op_desc_t dnnl_op_desc_t +#define mkldnn_out_of_memory dnnl_out_of_memory +#define mkldnn_owi dnnl_owi +#define mkldnn_packed_format_undef dnnl_packed_format_undef +#define mkldnn_pooling dnnl_pooling +#define mkldnn_pooling_avg dnnl_pooling_avg +#define mkldnn_pooling_avg_exclude_padding dnnl_pooling_avg_exclude_padding +#define mkldnn_pooling_avg_include_padding dnnl_pooling_avg_include_padding +#define mkldnn_pooling_backward_desc_init dnnl_pooling_backward_desc_init +#define mkldnn_pooling_desc_t dnnl_pooling_desc_t +#define mkldnn_pooling_forward_desc_init dnnl_pooling_forward_desc_init +#define mkldnn_pooling_max dnnl_pooling_max +#define mkldnn_post_ops dnnl_post_ops +#define mkldnn_post_ops_append_eltwise dnnl_post_ops_append_eltwise +#define mkldnn_post_ops_append_sum dnnl_post_ops_append_sum +#define mkldnn_post_ops_create dnnl_post_ops_create +#define mkldnn_post_ops_destroy dnnl_post_ops_destroy +#define mkldnn_post_ops_get_kind dnnl_post_ops_get_kind +#define mkldnn_post_ops_get_params_eltwise dnnl_post_ops_get_params_eltwise +#define mkldnn_post_ops_get_params_sum dnnl_post_ops_get_params_sum +#define mkldnn_post_ops_len dnnl_post_ops_len +#define mkldnn_post_ops_t dnnl_post_ops_t +#define mkldnn_prim_kind2str dnnl_prim_kind2str +#define mkldnn_primitive dnnl_primitive +#define mkldnn_primitive_attr dnnl_primitive_attr +#define mkldnn_primitive_attr_clone dnnl_primitive_attr_clone +#define mkldnn_primitive_attr_create dnnl_primitive_attr_create +#define mkldnn_primitive_attr_destroy dnnl_primitive_attr_destroy +#define mkldnn_primitive_attr_get_output_scales \ + dnnl_primitive_attr_get_output_scales +#define mkldnn_primitive_attr_get_post_ops dnnl_primitive_attr_get_post_ops +#define mkldnn_primitive_attr_get_scratchpad_mode \ + dnnl_primitive_attr_get_scratchpad_mode +#define mkldnn_primitive_attr_set_output_scales \ + dnnl_primitive_attr_set_output_scales +#define mkldnn_primitive_attr_set_post_ops dnnl_primitive_attr_set_post_ops +#define mkldnn_primitive_attr_set_rnn_data_qparams \ + dnnl_primitive_attr_set_rnn_data_qparams +#define mkldnn_primitive_attr_set_rnn_weights_qparams \ + dnnl_primitive_attr_set_rnn_weights_qparams +#define mkldnn_primitive_attr_set_scratchpad_mode \ + dnnl_primitive_attr_set_scratchpad_mode +#define mkldnn_primitive_attr_t dnnl_primitive_attr_t +#define mkldnn_primitive_create dnnl_primitive_create +#define mkldnn_primitive_desc dnnl_primitive_desc +#define mkldnn_primitive_desc_clone dnnl_primitive_desc_clone +#define mkldnn_primitive_desc_create dnnl_primitive_desc_create +#define mkldnn_primitive_desc_destroy dnnl_primitive_desc_destroy +#define mkldnn_primitive_desc_get_attr dnnl_primitive_desc_get_attr +#define mkldnn_primitive_desc_iterator dnnl_primitive_desc_iterator +#define mkldnn_primitive_desc_iterator_create \ + dnnl_primitive_desc_iterator_create +#define mkldnn_primitive_desc_iterator_destroy \ + dnnl_primitive_desc_iterator_destroy +#define mkldnn_primitive_desc_iterator_fetch dnnl_primitive_desc_iterator_fetch +#define mkldnn_primitive_desc_iterator_next dnnl_primitive_desc_iterator_next +#define mkldnn_primitive_desc_iterator_t dnnl_primitive_desc_iterator_t +#define mkldnn_primitive_desc_query dnnl_primitive_desc_query +#define 
mkldnn_primitive_desc_query_md dnnl_primitive_desc_query_md +#define mkldnn_primitive_desc_query_pd dnnl_primitive_desc_query_pd +#define mkldnn_primitive_desc_query_s32 dnnl_primitive_desc_query_s32 +#define mkldnn_primitive_desc_t dnnl_primitive_desc_t +#define mkldnn_primitive_destroy dnnl_primitive_destroy +#define mkldnn_primitive_execute dnnl_primitive_execute +#define mkldnn_primitive_get_primitive_desc dnnl_primitive_get_primitive_desc +#define mkldnn_primitive_kind_t dnnl_primitive_kind_t +#define mkldnn_primitive_t dnnl_primitive_t +#define mkldnn_prop_kind2str dnnl_prop_kind2str +#define mkldnn_prop_kind_t dnnl_prop_kind_t +#define mkldnn_prop_kind_undef dnnl_prop_kind_undef +#define mkldnn_query_batch_normalization_d dnnl_query_batch_normalization_d +#define mkldnn_query_convolution_d dnnl_query_convolution_d +#define mkldnn_query_deconvolution_d dnnl_query_deconvolution_d +#define mkldnn_query_diff_dst_md dnnl_query_diff_dst_md +#define mkldnn_query_diff_src_md dnnl_query_diff_src_md +#define mkldnn_query_diff_weights_md dnnl_query_diff_weights_md +#define mkldnn_query_dst_md dnnl_query_dst_md +#define mkldnn_query_eltwise_d dnnl_query_eltwise_d +#define mkldnn_query_engine dnnl_query_engine +#define mkldnn_query_gemm_d dnnl_query_gemm_d +#define mkldnn_query_impl_info_str dnnl_query_impl_info_str +#define mkldnn_query_inner_product_d dnnl_query_inner_product_d +#define mkldnn_query_layer_normalization_d dnnl_query_layer_normalization_d +#define mkldnn_query_lrn_d dnnl_query_lrn_d +#define mkldnn_query_memory_consumption_s64 dnnl_query_memory_consumption_s64 +#define mkldnn_query_num_of_inputs_s32 dnnl_query_num_of_inputs_s32 +#define mkldnn_query_num_of_outputs_s32 dnnl_query_num_of_outputs_s32 +#define mkldnn_query_op_d dnnl_query_op_d +#define mkldnn_query_pooling_d dnnl_query_pooling_d +#define mkldnn_query_primitive_kind dnnl_query_primitive_kind +#define mkldnn_query_rnn_d dnnl_query_rnn_d +#define mkldnn_query_scratchpad_engine dnnl_query_scratchpad_engine +#define mkldnn_query_scratchpad_md dnnl_query_scratchpad_md +#define mkldnn_query_shuffle_d dnnl_query_shuffle_d +#define mkldnn_query_softmax_d dnnl_query_softmax_d +#define mkldnn_query_some_d dnnl_query_some_d +#define mkldnn_query_some_md dnnl_query_some_md +#define mkldnn_query_src_md dnnl_query_src_md +#define mkldnn_query_t dnnl_query_t +#define mkldnn_query_time_estimate_f64 dnnl_query_time_estimate_f64 +#define mkldnn_query_undef dnnl_query_undef +#define mkldnn_query_weights_md dnnl_query_weights_md +#define mkldnn_query_workspace_md dnnl_query_workspace_md +#define mkldnn_reorder dnnl_reorder +#define mkldnn_reorder_primitive_desc_create dnnl_reorder_primitive_desc_create +#define mkldnn_rnn dnnl_rnn +#define mkldnn_rnn_desc_t dnnl_rnn_desc_t +#define mkldnn_rnn_direction2str dnnl_rnn_direction2str +#define mkldnn_rnn_direction_t dnnl_rnn_direction_t +#define mkldnn_rnn_flags2str dnnl_rnn_flags2str +#define mkldnn_rnn_flags_t dnnl_rnn_flags_t +#define mkldnn_rnn_flags_undef dnnl_rnn_flags_undef +#define mkldnn_rnn_packed_desc_t dnnl_rnn_packed_desc_t +#define mkldnn_rnn_packed_memory_format_t dnnl_rnn_packed_memory_format_t +#define mkldnn_runtime_error dnnl_runtime_error +#define mkldnn_s32 dnnl_s32 +#define mkldnn_s8 dnnl_s8 +#define mkldnn_scratchpad_mode2str dnnl_scratchpad_mode2str +#define mkldnn_scratchpad_mode_library dnnl_scratchpad_mode_library +#define mkldnn_scratchpad_mode_t dnnl_scratchpad_mode_t +#define mkldnn_scratchpad_mode_user dnnl_scratchpad_mode_user +#define mkldnn_set_jit_dump 
dnnl_set_jit_dump +#define mkldnn_set_verbose dnnl_set_verbose +#define mkldnn_sgemm dnnl_sgemm +#define mkldnn_shuffle dnnl_shuffle +#define mkldnn_shuffle_backward_desc_init dnnl_shuffle_backward_desc_init +#define mkldnn_shuffle_desc_t dnnl_shuffle_desc_t +#define mkldnn_shuffle_forward_desc_init dnnl_shuffle_forward_desc_init +#define mkldnn_softmax dnnl_softmax +#define mkldnn_softmax_backward_desc_init dnnl_softmax_backward_desc_init +#define mkldnn_softmax_desc_t dnnl_softmax_desc_t +#define mkldnn_softmax_forward_desc_init dnnl_softmax_forward_desc_init +#define mkldnn_status2str dnnl_status2str +#define mkldnn_status_t dnnl_status_t +#define mkldnn_stream dnnl_stream +#define mkldnn_stream_create dnnl_stream_create +#define mkldnn_stream_create_ocl dnnl_ocl_interop_stream_create +#define mkldnn_stream_default_flags dnnl_stream_default_flags +#define mkldnn_stream_destroy dnnl_stream_destroy +#define mkldnn_stream_flags_t dnnl_stream_flags_t +#define mkldnn_stream_get_ocl_command_queue \ + dnnl_ocl_interop_stream_get_command_queue +#define mkldnn_stream_in_order dnnl_stream_in_order +#define mkldnn_stream_out_of_order dnnl_stream_out_of_order +#define mkldnn_stream_t dnnl_stream_t +#define mkldnn_stream_wait dnnl_stream_wait +#define mkldnn_success dnnl_success +#define mkldnn_sum dnnl_sum +#define mkldnn_sum_primitive_desc_create dnnl_sum_primitive_desc_create +#define mkldnn_tn dnnl_tn +#define mkldnn_tnc dnnl_tnc +#define mkldnn_types dnnl_types +#define mkldnn_u8 dnnl_u8 +#define mkldnn_undefined_primitive dnnl_undefined_primitive +#define mkldnn_unidirectional dnnl_unidirectional +#define mkldnn_unidirectional_left2right dnnl_unidirectional_left2right +#define mkldnn_unidirectional_right2left dnnl_unidirectional_right2left +#define mkldnn_unimplemented dnnl_unimplemented +#define mkldnn_use_global_stats dnnl_use_global_stats +#define mkldnn_use_scaleshift dnnl_use_scaleshift +#define mkldnn_vanilla_gru dnnl_vanilla_gru +#define mkldnn_vanilla_lstm dnnl_vanilla_lstm +#define mkldnn_vanilla_rnn dnnl_vanilla_rnn +#define mkldnn_vanilla_rnn_backward_desc_init \ + dnnl_vanilla_rnn_backward_desc_init +#define mkldnn_vanilla_rnn_forward_desc_init dnnl_vanilla_rnn_forward_desc_init +#define mkldnn_version dnnl_version +#define mkldnn_version_t dnnl_version_t +#define mkldnn_wino_desc_t dnnl_wino_desc_t +#define mkldnn_wino_memory_format_t dnnl_wino_memory_format_t +#define mkldnn_wino_undef dnnl_wino_undef +#define mkldnn_wino_wei_OBaaIBOIio dnnl_wino_wei_OBaaIBOIio +#define mkldnn_wino_wei_aaOBiOo dnnl_wino_wei_aaOBiOo +#define mkldnn_wino_wei_aaOIoi dnnl_wino_wei_aaOIoi +#define mkldnn_wino_wei_aaOio dnnl_wino_wei_aaOio +#define mkldnn_wio dnnl_wio +#define mkldnn_x dnnl_x + +#endif /* MKLDNN_DNNL_MANGLING_H */ diff --git a/include/mkldnn_types.h b/include/mkldnn_types.h new file mode 100644 index 00000000000..fc1538ce6a2 --- /dev/null +++ b/include/mkldnn_types.h @@ -0,0 +1,26 @@ +/******************************************************************************* +* Copyright 2019-2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+// This header file ensures backwards compatibility with the previous naming.
+
+#ifndef MKLDNN_TYPES_H
+#define MKLDNN_TYPES_H
+
+#include "mkldnn_dnnl_mangling.h"
+
+#include "dnnl_types.h"
+
+#endif /* MKLDNN_TYPES_H */
diff --git a/include/mkldnn_version.h b/include/mkldnn_version.h
new file mode 100644
index 00000000000..8b427b2fcc2
--- /dev/null
+++ b/include/mkldnn_version.h
@@ -0,0 +1,26 @@
+/*******************************************************************************
+* Copyright 2019-2020 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+// This header file ensures backwards compatibility with the previous naming.
+
+#ifndef MKLDNN_VERSION_H
+#define MKLDNN_VERSION_H
+
+#include "mkldnn_dnnl_mangling.h"
+
+#include "dnnl_version.h"
+
+#endif /* MKLDNN_VERSION_H */
diff --git a/include/oneapi/dnnl/dnnl.h b/include/oneapi/dnnl/dnnl.h
index 782637f885a..1285e4cf8a2 100644
--- a/include/oneapi/dnnl/dnnl.h
+++ b/include/oneapi/dnnl/dnnl.h
@@ -23,6 +23,7 @@
 #include "oneapi/dnnl/dnnl_config.h"
 #include "oneapi/dnnl/dnnl_types.h"
 #include "oneapi/dnnl/dnnl_version.h"
+#include <stdbool.h>
 
 #ifdef __cplusplus
 extern "C" {
@@ -532,6 +533,15 @@ dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points(
         dnnl_primitive_attr_t attr, int arg, dnnl_dim_t count, int mask,
         const int32_t *zero_points);
 
+dnnl_status_t DNNL_API dnnl_primitive_attr_set_output_compensations(
+        dnnl_primitive_attr_t attr, int count, int mask);
+
+dnnl_status_t DNNL_API dnnl_primitive_attr_set_input_zero_points(
+        dnnl_primitive_attr_t attr, int count, int mask);
+
+dnnl_status_t DNNL_API dnnl_primitive_attr_set_weights_zero_points(
+        dnnl_primitive_attr_t attr, int count, int mask);
+
 /// Returns primitive attributes post-ops.
 ///
 /// @warning
@@ -765,6 +775,14 @@ dnnl_status_t DNNL_API dnnl_post_ops_get_params_eltwise(
         const_dnnl_post_ops_t post_ops, int index, float *scale,
         dnnl_alg_kind_t *alg_kind, float *alpha, float *beta);
 
+/** Appends a depthwise (DW) convolution post operation with the given
+ * geometry (@p in_h, @p in_w, @p ker_h, @p ker_w, @p str_h, @p str_w) and
+ * input data type @p in_dt to the @p post_ops.
+ *
+ * The kind of this post operation is #dnnl_convolution.
+ */
+dnnl_status_t DNNL_API dnnl_post_ops_append_dw_conv(
+        dnnl_post_ops_t post_ops, int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w, dnnl_data_type_t in_dt);
+
 /// Appends a depthwise post-op convolution with stride 1.
///
/// This post-op can only be fused with a 2D 1x1 convolution (convolution with
@@ -956,6 +974,18 @@ dnnl_status_t DNNL_API dnnl_post_ops_append_prelu(
 dnnl_status_t DNNL_API dnnl_post_ops_get_params_prelu(
         const_dnnl_post_ops_t post_ops, int index, int *mask);
 
+dnnl_status_t DNNL_API dnnl_post_ops_append_depthwise(
+        dnnl_post_ops_t post_ops, dnnl_alg_kind_t alg, size_t offset_size, const size_t* offset);
+
+dnnl_status_t DNNL_API dnnl_post_ops_append_quantization(
+        dnnl_post_ops_t post_ops, dnnl_alg_kind_t alg,
+        size_t per_channel_size, const bool* per_channel,
+        size_t all_default_size, const bool* all_default,
+        size_t offset_size, const size_t* offset);
+
+dnnl_status_t DNNL_API dnnl_post_ops_append_binarization(
+        dnnl_post_ops_t post_ops, dnnl_alg_kind_t alg, const float* weights_data, const float* output_mask);
+
 /// @} dnnl_api_attributes
 
 /// @} dnnl_api_primitives
@@ -1244,6 +1274,9 @@ dnnl_status_t DNNL_API dnnl_memory_get_data_handle(
 dnnl_status_t DNNL_API dnnl_memory_set_data_handle(
         dnnl_memory_t memory, void *handle);
 
+dnnl_status_t DNNL_API dnnl_memory_set_data_handle_no_pads_proc(
+        dnnl_memory_t memory, void *handle);
+
 /// Sets the underlying memory buffer.
 ///
 /// @param memory Memory object.
@@ -1255,6 +1288,9 @@ dnnl_status_t DNNL_API dnnl_memory_set_data_handle(
 dnnl_status_t DNNL_API dnnl_memory_set_data_handle_v2(
         dnnl_memory_t memory, void *handle, dnnl_stream_t stream);
 
+dnnl_status_t DNNL_API dnnl_memory_set_data_handle_v2_no_pads_proc(
+        dnnl_memory_t memory, void *handle, dnnl_stream_t stream);
+
 /// Destroys a memory object.
 ///
 /// @param memory Memory object to destroy.
diff --git a/include/oneapi/dnnl/dnnl.hpp b/include/oneapi/dnnl/dnnl.hpp
index cc052547227..77988f7116e 100644
--- a/include/oneapi/dnnl/dnnl.hpp
+++ b/include/oneapi/dnnl/dnnl.hpp
@@ -29,6 +29,7 @@
 #include
 #include
 #include
+#include <array>
 #include
 #include "oneapi/dnnl/dnnl.h"
@@ -40,7 +41,7 @@
 // gcc < 5 does not define __cpp_exceptions but __EXCEPTIONS,
 // Microsoft C++ Compiler does not provide an option to disable exceptions
 #ifndef DNNL_ENABLE_EXCEPTIONS
-#if __cpp_exceptions || __EXCEPTIONS \
+#if defined(__cpp_exceptions) || defined(__EXCEPTIONS) \
         || (defined(_MSC_VER) && !defined(__clang__))
 #define DNNL_ENABLE_EXCEPTIONS 1
 #else
@@ -313,6 +314,9 @@ struct primitive : public handle {
         reduction = dnnl_reduction,
         /// A PReLU primitive.
prelu = dnnl_prelu,
+        depthwise = dnnl_depthwise,
+        quantization = dnnl_quantization,
+        binarization = dnnl_binarization,
     };
 
     using handle::handle;
@@ -548,6 +552,12 @@ enum class algorithm {
     eltwise_round = dnnl_eltwise_round,
     /// Elementwise: hardswish
     eltwise_hardswish = dnnl_eltwise_hardswish,
+    /// Elementwise: hsigmoid
+    eltwise_hsigmoid = dnnl_eltwise_hsigmoid,
+    /// Elementwise: round_half_to_even
+    eltwise_round_half_to_even = dnnl_eltwise_round_half_to_even,
+    /// Elementwise: round_half_away_from_zero
+    eltwise_round_half_away_from_zero = dnnl_eltwise_round_half_away_from_zero,
     /// Elementwise: rectified linear unit (ReLU) (dst for backward)
     eltwise_relu_use_dst_for_bwd = dnnl_eltwise_relu_use_dst_for_bwd,
     /// Elementwise: hyperbolic tangent non-linearity (tanh) (dst for backward)
@@ -611,6 +621,8 @@ enum class algorithm {
     binary_eq = dnnl_binary_eq,
     /// Binary not equal
     binary_ne = dnnl_binary_ne,
+    /// Binary prelu
+    binary_prelu = dnnl_binary_prelu,
     /// Nearest Neighbor resampling method
     resampling_nearest = dnnl_resampling_nearest,
     /// Linear (Bilinear, Trilinear) resampling method
@@ -633,6 +645,13 @@ enum class algorithm {
     reduction_norm_lp_power_p_max = dnnl_reduction_norm_lp_power_p_max,
     /// Reduction using norm_lp_power_p_sum operation
     reduction_norm_lp_power_p_sum = dnnl_reduction_norm_lp_power_p_sum,
+
+    depthwise_scale_shift = dnnl_depthwise_scale_shift,
+    depthwise_prelu = dnnl_depthwise_prelu,
+
+    quantization_quantize_dequantize = dnnl_quantization_quantize_dequantize,
+    quantization_quantize = dnnl_quantization_quantize,
+    binarization_depthwise = dnnl_binarization_depthwise,
 };
 
 /// Converts algorithm kind enum value from C++ API to C API type.
@@ -1189,6 +1208,8 @@ struct memory : public handle {
         s8 = dnnl_s8,
         /// 8-bit unsigned integer.
         u8 = dnnl_u8,
+        /// 1-bit integer.
+        bin = dnnl_bin
     };
 
     /// Returns size of data type in bytes.
@@ -1570,6 +1591,8 @@ struct memory : public handle {
         aBCd4b4c = dnnl_aBCd4b4c,
         ABcd8a16b2a = dnnl_ABcd8a16b2a,
         ABcd8a8b = dnnl_ABcd8a8b,
+        ABcd8a32b = dnnl_ABcd8a32b,
+        ABcd16a32b = dnnl_ABcd16a32b,
         ABcd8a4b = dnnl_ABcd8a4b,
         ABcd8a2b = dnnl_ABcd8a2b,
         /// 4D tensor blocked by 2nd dimension with block size 8
@@ -1679,6 +1702,8 @@ struct memory : public handle {
         BAcde16b16a = dnnl_BAcde16b16a,
         BAcde16a16b = dnnl_BAcde16a16b,
         aBdec32b = dnnl_aBdec32b,
+        Abcdef4a = dnnl_Abcdef4a,
+        Abcdef8a = dnnl_Abcdef8a,
         Abcdef16a = dnnl_Abcdef16a,
         Abcdef32a = dnnl_Abcdef32a,
         Acdb32a = dnnl_Acdb32a,
@@ -1864,6 +1889,8 @@ struct memory : public handle {
         IOdhw16i16o = dnnl_IOdhw16i16o,
         gIOhw16i16o = dnnl_gIOhw16i16o,
         gOhwi32o = dnnl_gOhwi32o,
+        Goidhw4g = dnnl_Goidhw4g,
+        Goidhw8g = dnnl_Goidhw8g,
         Goidhw16g = dnnl_Goidhw16g,
         IOw16o16i = dnnl_IOw16o16i,
         OIw16i16o = dnnl_OIw16i16o,
@@ -1922,6 +1949,8 @@ struct memory : public handle {
         OIhw8i8o = dnnl_OIhw8i8o,
         OIhw8o16i2o = dnnl_OIhw8o16i2o,
         OIhw8o8i = dnnl_OIhw8o8i,
+        OIhw8o32i = dnnl_OIhw8o32i,
+        OIhw16o32i = dnnl_OIhw16o32i,
         OIhw8o4i = dnnl_OIhw8o4i,
         OIhw2i8o4i = dnnl_OIhw2i8o4i,
         IOdhw16o16i = dnnl_IOdhw16o16i,
@@ -2740,6 +2769,12 @@ struct memory : public handle {
                 "could not set native handle of a memory object");
     }
 
+    void set_data_handle_no_pads_proc(void *handle) const {
+        error::wrap_c_api(
+                dnnl_memory_set_data_handle_v2_no_pads_proc(get(), handle, nullptr),
+                "could not set native handle of a memory object");
+    }
+
     /// Maps a memory object and returns a host-side pointer to a memory
     /// buffer with a copy of its contents.
///
@@ -3202,6 +3237,12 @@ struct post_ops : public handle {
                 "could not append a binary post-op");
     }
 
+    void append_dw_conv(int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w, dnnl_data_type_t in_dt) {
+        error::wrap_c_api(dnnl_post_ops_append_dw_conv(get(),
+                in_h, in_w, ker_h, ker_w, str_h, str_w, in_dt),
+                "could not append dw conv");
+    }
+
     /// Returns the parameters of a binary post-op.
     ///
     /// @param index Index of the binary post-op.
@@ -3282,6 +3323,23 @@ struct post_ops : public handle {
         error::wrap_c_api(dnnl_post_ops_get_params_prelu(get(), index, &mask),
                 "could not get parameters of a binary post-op");
     }
+
+    void append_depthwise(algorithm alg, const std::array<size_t, 2>& offset) {
+        error::wrap_c_api(dnnl_post_ops_append_depthwise(get(), convert_to_c(alg), offset.size(), offset.data()),
+                "could not append depthwise");
+    }
+
+    void append_quantization(algorithm alg, const std::array<bool, 6>& per_channel, const std::array<bool, 6>& all_default,
+            const std::array<size_t, 6>& offset) {
+        error::wrap_c_api(dnnl_post_ops_append_quantization(get(), convert_to_c(alg), per_channel.size(), per_channel.data(),
+                all_default.size(), all_default.data(), offset.size(), offset.data()),
+                "could not append quantization");
+    }
+
+    void append_binarization(algorithm alg, const float* weights_data, const float* output_mask) {
+        error::wrap_c_api(dnnl_post_ops_append_binarization(get(), convert_to_c(alg), weights_data, output_mask),
+                "could not append binarization");
+    }
 };
 
 /// @cond DO_NOT_DOCUMENT_THIS
@@ -3523,6 +3581,24 @@ struct primitive_attr : public handle {
                 "could not set zero points primitive attribute");
     }
 
+    void set_output_compensations(dnnl_dim_t count, int mask)
+    {
+        error::wrap_c_api(dnnl_primitive_attr_set_output_compensations(get(), count, mask),
+                "could not set int output compensations");
+    }
+
+    void set_input_zero_points(dnnl_dim_t count, int mask)
+    {
+        error::wrap_c_api(dnnl_primitive_attr_set_input_zero_points(get(), count, mask),
+                "could not set int input zero_points");
+    }
+
+    void set_weights_zero_points(dnnl_dim_t count, int mask)
+    {
+        error::wrap_c_api(dnnl_primitive_attr_set_weights_zero_points(get(), count, mask),
+                "could not set int weights zero_points");
+    }
+
     /// Returns post-ops previously set via set_post_ops().
     ///
     /// @returns Post-ops.
diff --git a/include/oneapi/dnnl/dnnl_types.h b/include/oneapi/dnnl/dnnl_types.h
index ae96f0f4020..b314f07070a 100644
--- a/include/oneapi/dnnl/dnnl_types.h
+++ b/include/oneapi/dnnl/dnnl_types.h
@@ -74,6 +74,8 @@ typedef enum {
     dnnl_s8 = 5,
     /// 8-bit unsigned integer.
     dnnl_u8 = 6,
+    /// 1-bit integer.
+    dnnl_bin = 7,
 } dnnl_data_type_t;
 
 /// Memory format kind
@@ -277,6 +279,8 @@ typedef enum {
     dnnl_ABcd8a16b2a,
     dnnl_ABcd2b8a4b,
     dnnl_ABcd8a8b,
+    dnnl_ABcd8a32b,
+    dnnl_ABcd16a32b,
     dnnl_ABcd8a4b,
     /// 4D tensor blocked by 2nd dimension with block size 8
     dnnl_aBcd8b,
@@ -389,6 +393,8 @@ typedef enum {
     dnnl_aCBdef16c16b,
     dnnl_aBdefc4b,
     dnnl_aBdefc8b,
+    dnnl_Abcdef4a,
+    dnnl_Abcdef8a,
     dnnl_Abcdef16a,
     dnnl_Abcdef32a,
     dnnl_aBedc16b,
@@ -962,6 +968,8 @@ typedef enum {
     dnnl_OIhw2i8o4i = dnnl_ABcd2b8a4b,
     dnnl_IOhw8o16i2o = dnnl_BAcd8a16b2a,
     dnnl_OIhw8o8i = dnnl_ABcd8a8b,
+    dnnl_OIhw8o32i = dnnl_ABcd8a32b,
+    dnnl_OIhw16o32i = dnnl_ABcd16a32b,
     dnnl_OIhw8o4i = dnnl_ABcd8a4b,
     dnnl_Owhi16o = dnnl_Adcb16a,
 
@@ -1106,6 +1114,8 @@ typedef enum {
     dnnl_gIOdhw8o16i2o = dnnl_aCBdef8b16c2b,
     dnnl_gOIdhw8o8i = dnnl_aBCdef8b8c,
     dnnl_gOIdhw8o4i = dnnl_aBCdef8b4c,
+    dnnl_Goidhw4g = dnnl_Abcdef4a,
+    dnnl_Goidhw8g = dnnl_Abcdef8a,
     dnnl_Goidhw16g = dnnl_Abcdef16a,
     dnnl_Goidhw32g = dnnl_Abcdef32a,
     dnnl_gOIdhw2i4o2i = dnnl_aBCdef2c4b2c,
@@ -1356,6 +1366,12 @@ typedef enum {
     dnnl_deconvolution,
     /// An element-wise primitive.
     dnnl_eltwise,
+    /// A depthwise primitive.
+    dnnl_depthwise,
+    /// A quantization primitive.
+    dnnl_quantization,
+    /// A binarization primitive.
+    dnnl_binarization,
     /// A softmax primitive.
     dnnl_softmax,
     /// A pooling primitive.
@@ -1454,6 +1470,12 @@ typedef enum {
     dnnl_eltwise_mish = 0x60,
     /// Eltwise: hardswish
     dnnl_eltwise_hardswish = 0x70,
+    /// Eltwise: hsigmoid
+    dnnl_eltwise_hsigmoid = 0x71,
+    /// Eltwise: round_half_to_even
+    dnnl_eltwise_round_half_to_even = 0x80,
+    /// Eltwise: round_half_away_from_zero
+    dnnl_eltwise_round_half_away_from_zero = 0x81,
     /// Eltwise: ReLU (dst for backward)
     dnnl_eltwise_relu_use_dst_for_bwd = 0x100,
     /// Eltwise: hyperbolic tangent non-linearity (tanh) (dst for backward)
@@ -1518,6 +1540,8 @@ typedef enum {
     dnnl_binary_eq = 0x1fffa,
     /// Binary not equal
     dnnl_binary_ne = 0x1fffb,
+    /// Binary prelu
+    dnnl_binary_prelu = 0x1fffc,
     /// Nearest Neighbor Resampling Method
     dnnl_resampling_nearest = 0x2fff0,
     /// Linear Resampling Method
@@ -1540,6 +1564,13 @@ typedef enum {
     dnnl_reduction_norm_lp_power_p_max,
     /// Reduction using lp norm without final pth-root
     dnnl_reduction_norm_lp_power_p_sum,
+
+    dnnl_depthwise_scale_shift = 0x3fff0,
+    dnnl_depthwise_prelu = 0x3fff1,
+
+    dnnl_quantization_quantize_dequantize = 0x4fff0,
+    dnnl_quantization_quantize = 0x4fff1,
+    dnnl_binarization_depthwise = 0x4fff2,
 } dnnl_alg_kind_t;
 
 /// Flags for normalization primitives.
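The enum values above, together with the post_ops wrappers added to dnnl.hpp earlier in this patch, are what a framework integration would use to fuse the fork-specific operations into a primitive attribute. A minimal sketch follows; the std::array sizes mirror the wrapper signatures, while treating the size_t entries as offsets into a packed per-post-op data buffer is an assumption about the implementation, not something this patch states:

    #include "oneapi/dnnl/dnnl.hpp"

    dnnl::primitive_attr make_fused_attr() {
        dnnl::post_ops ops;
        // Fused per-channel PReLU; the two offsets would locate the slope
        // (and unused bias) data in the post-op buffer (assumed semantics).
        ops.append_depthwise(dnnl::algorithm::depthwise_prelu, {0, 0});
        // Fused fake-quantization with six per-tensor parameter slots.
        ops.append_quantization(dnnl::algorithm::quantization_quantize_dequantize,
                {true, true, true, true, true, true},       // per_channel
                {false, false, false, false, false, false}, // all_default
                {0, 0, 0, 0, 0, 0});                        // offsets
        dnnl::primitive_attr attr;
        attr.set_post_ops(ops);
        return attr;
    }

The resulting attr is then passed to a primitive descriptor (for example, a convolution) in the usual way.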
@@ -2965,6 +2996,9 @@ typedef const struct dnnl_stream *const_dnnl_stream_t;
 /// TBB runtime (CPU only)
 #define DNNL_RUNTIME_TBB 4u
 
+/// TBB runtime with auto partitioning (CPU only)
+#define DNNL_RUNTIME_TBB_AUTO 5u
+
 /// Threadpool runtime (CPU only)
 #define DNNL_RUNTIME_THREADPOOL 8u
 
@@ -3048,6 +3082,8 @@ typedef enum {
     /// Intel AMX with 8-bit integer and bfloat16 support
     dnnl_cpu_isa_avx512_core_amx = 0x3e7,
 
+    dnnl_cpu_isa_avx512_vpopcnt = 0x6e7,
+
     /// Intel AVX2 and Intel Deep Learning Boost (Intel DL Boost) support
     dnnl_cpu_isa_avx2_vnni = 0x407,
 
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 7a8fbe1caaf..71f0a98082c 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -169,8 +169,73 @@ install(EXPORT ${LIB_EXPORT_NAME}
         NAMESPACE ${LIB_NAMESPACE}
         DESTINATION ${LIB_CONFIG_INSTALL_DIR})
 
+
+# Do not create compat symlinks to libraries and CMake config file when a custom
+# library name is specified
+if(DNNL_LIBRARY_NAME MATCHES "(dnnl|dnnld)")
+    # Intel MKL-DNN compat cmake files
+    install(CODE "execute_process(COMMAND ${CMAKE_COMMAND}
+        -DDIR=\$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}/cmake
+        -P ${PROJECT_SOURCE_DIR}/cmake/gen_mkldnn_compat_cmakes.cmake)")
+
+    # Intel MKL-DNN compat libraries
+    if (WIN32)
+        if (NOT MINGW)
+            set(prefix "")
+            set(ext ".lib")
+        else()
+            set(prefix "lib")
+            if (DNNL_LIBRARY_TYPE STREQUAL "SHARED")
+                set(ext ".dll.a")
+            else()
+                set(ext ".a")
+            endif()
+        endif()
+        add_custom_target(compat_libs ALL
+            ${CMAKE_COMMAND} -E copy
+            $<TARGET_FILE_DIR:${DNNL_LIBRARY_NAME}>/$<TARGET_FILE_NAME:${DNNL_LIBRARY_NAME}>
+            $<TARGET_FILE_DIR:${DNNL_LIBRARY_NAME}>/${prefix}mkldnn${ext}
+            # Workaround for MSB8065 warning.
+            COMMAND ${CMAKE_COMMAND} -E touch "CMakeFiles/compat_libs"
+            DEPENDS ${DNNL_LIBRARY_NAME})
+        install(FILES $<TARGET_FILE_DIR:${DNNL_LIBRARY_NAME}>/${prefix}mkldnn${ext}
+                DESTINATION ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR})
+    else()
+        if(DNNL_LIBRARY_TYPE STREQUAL "SHARED")
+            set_ternary(ext APPLE ".dylib" ".so")
+            set(vers ".${DNNL_VERSION_MAJOR};.${DNNL_VERSION_MAJOR}.${DNNL_VERSION_MINOR}")
+        else()
+            set(ext ".a")
+            set(vers "")
+        endif()
+        foreach(ver "" ${vers})
+            set_ternary(ext_and_ver APPLE "${ver}${ext}" "${ext}${ver}")
+            get_property(lib_location TARGET ${DNNL_LIBRARY_NAME} PROPERTY LIBRARY_OUTPUT_DIRECTORY)
+            if(lib_location)
+                set(compat_link "${lib_location}/libmkldnn${ext_and_ver}")
+            else()
+                set(compat_link "${CMAKE_CURRENT_BINARY_DIR}/libmkldnn${ext_and_ver}")
+            endif()
+            if(CMAKE_HOST_SYSTEM_NAME STREQUAL Windows)
+                add_custom_command(OUTPUT ${compat_link}
+                    COMMAND ${CMAKE_COMMAND} -E copy libdnnl${ext_and_ver} ${compat_link}
+                    DEPENDS ${DNNL_LIBRARY_NAME})
+            else()
+                add_custom_command(OUTPUT ${compat_link}
+                    # to make the next command work fine
+                    COMMAND ${CMAKE_COMMAND} -E remove -f ${compat_link}
+                    COMMAND ${CMAKE_COMMAND} -E create_symlink libdnnl${ext_and_ver} ${compat_link}
+                    DEPENDS ${DNNL_LIBRARY_NAME})
+            endif()
+            add_custom_target(compat_libs${ver} ALL DEPENDS ${compat_link})
+            install(FILES ${compat_link} DESTINATION ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR})
+        endforeach()
+    endif()
+endif()
+
 # Install custom find modules for transitive dependencies
-if(DNNL_CPU_THREADING_RUNTIME STREQUAL "TBB")
+set(LIB_CONFIG_INSTALL_DIR_COMPAT "${CMAKE_INSTALL_LIBDIR}/cmake/mkldnn")
+if("${DNNL_CPU_THREADING_RUNTIME}" MATCHES "^(TBB|TBB_AUTO)$")
     if(WIN32)
         install(FILES "../cmake/win/TBBConfig.cmake" RENAME "FindTBB.cmake"
                 DESTINATION ${LIB_CONFIG_INSTALL_DIR})
@@ -181,12 +246,16 @@ if(DNNL_CPU_THREADING_RUNTIME STREQUAL "TBB")
         install(FILES "../cmake/lnx/TBBConfig.cmake" RENAME "FindTBB.cmake"
                 DESTINATION ${LIB_CONFIG_INSTALL_DIR})
endif() + install(FILES "${CMAKE_INSTALL_PREFIX}/${LIB_CONFIG_INSTALL_DIR}/FindTBB.cmake" + DESTINATION ${LIB_CONFIG_INSTALL_DIR_COMPAT}) endif() if(DNNL_GPU_RUNTIME STREQUAL "OCL") install(FILES "../cmake/FindOpenCL.cmake" DESTINATION ${LIB_CONFIG_INSTALL_DIR}) + install(FILES "../cmake/FindOpenCL.cmake" + DESTINATION ${LIB_CONFIG_INSTALL_DIR_COMPAT}) endif() if(DNNL_WITH_SYCL) diff --git a/src/common/binary.cpp b/src/common/binary.cpp index a593c19069b..ac7e3f394a1 100644 --- a/src/common/binary.cpp +++ b/src/common/binary.cpp @@ -34,7 +34,7 @@ status_t dnnl_binary_desc_init(binary_desc_t *binary_desc, alg_kind_t alg_kind, bool args_ok = true && !any_null(binary_desc, src0_md, src1_md, dst_md) && one_of(alg_kind, binary_add, binary_mul, binary_max, binary_min, binary_div, binary_sub, binary_ge, binary_gt, binary_le, - binary_lt, binary_eq, binary_ne); + binary_lt, binary_eq, binary_ne, binary_prelu); if (!args_ok) return invalid_arguments; auto bod = binary_desc_t(); diff --git a/src/common/c_types_map.hpp b/src/common/c_types_map.hpp index 6e2740f9a60..d71386744ea 100644 --- a/src/common/c_types_map.hpp +++ b/src/common/c_types_map.hpp @@ -91,6 +91,9 @@ const alg_kind_t eltwise_pow = dnnl_eltwise_pow; const alg_kind_t eltwise_gelu_tanh = dnnl_eltwise_gelu_tanh; const alg_kind_t eltwise_gelu_erf = dnnl_eltwise_gelu_erf; const alg_kind_t eltwise_hardswish = dnnl_eltwise_hardswish; +const alg_kind_t eltwise_hsigmoid = dnnl_eltwise_hsigmoid; +const alg_kind_t eltwise_round_half_to_even = dnnl_eltwise_round_half_to_even; +const alg_kind_t eltwise_round_half_away_from_zero = dnnl_eltwise_round_half_away_from_zero; const alg_kind_t eltwise_relu_use_dst_for_bwd = dnnl_eltwise_relu_use_dst_for_bwd; const alg_kind_t eltwise_tanh_use_dst_for_bwd @@ -126,6 +129,7 @@ const alg_kind_t binary_le = dnnl_binary_le; const alg_kind_t binary_lt = dnnl_binary_lt; const alg_kind_t binary_eq = dnnl_binary_eq; const alg_kind_t binary_ne = dnnl_binary_ne; +const alg_kind_t binary_prelu = dnnl_binary_prelu; const alg_kind_t resampling_nearest = dnnl_resampling_nearest; const alg_kind_t resampling_linear = dnnl_resampling_linear; const alg_kind_t reduction_max = dnnl_reduction_max; @@ -139,6 +143,11 @@ const alg_kind_t reduction_norm_lp_power_p_max = dnnl_reduction_norm_lp_power_p_max; const alg_kind_t reduction_norm_lp_power_p_sum = dnnl_reduction_norm_lp_power_p_sum; +const alg_kind_t depthwise_scale_shift = dnnl_depthwise_scale_shift; +const alg_kind_t depthwise_prelu = dnnl_depthwise_prelu; +const alg_kind_t quantization_quantize_dequantize = dnnl_quantization_quantize_dequantize; +const alg_kind_t quantization_quantize = dnnl_quantization_quantize; +const alg_kind_t binarization_depthwise = dnnl_binarization_depthwise; } // namespace alg_kind using data_type_t = dnnl_data_type_t; @@ -150,6 +159,7 @@ const data_type_t f32 = dnnl_f32; const data_type_t s32 = dnnl_s32; const data_type_t s8 = dnnl_s8; const data_type_t u8 = dnnl_u8; +const data_type_t bin = dnnl_bin; } // namespace data_type using fpmath_mode_t = dnnl_fpmath_mode_t; @@ -349,6 +359,8 @@ const format_tag_t aBCd4b4c = dnnl_aBCd4b4c; const format_tag_t ABcd8a16b2a = dnnl_ABcd8a16b2a; const format_tag_t BAcd8a16b2a = dnnl_BAcd8a16b2a; const format_tag_t ABcd8a8b = dnnl_ABcd8a8b; +const format_tag_t ABcd8a32b = dnnl_ABcd8a32b; +const format_tag_t ABcd16a32b = dnnl_ABcd16a32b; const format_tag_t ABcd8a4b = dnnl_ABcd8a4b; const format_tag_t ABcd8a2b = dnnl_ABcd8a2b; const format_tag_t aBcd8b = dnnl_aBcd8b; @@ -509,6 +521,8 @@ const format_tag_t 
ABcd40a32b = dnnl_ABcd40a32b; const format_tag_t ABcde40a32b = dnnl_ABcde40a32b; const format_tag_t BAcde16b16a = dnnl_BAcde16b16a; const format_tag_t aBdec32b = dnnl_aBdec32b; +const format_tag_t Abcdef4a = dnnl_Abcdef4a; +const format_tag_t Abcdef8a = dnnl_Abcdef8a; const format_tag_t Abcdef16a = dnnl_Abcdef16a; const format_tag_t Abcdef32a = dnnl_Abcdef32a; const format_tag_t Acdb32a = dnnl_Acdb32a; @@ -758,6 +772,8 @@ const format_tag_t IOhw16i16o = dnnl_IOhw16i16o; const format_tag_t Ohwi32o = dnnl_Ohwi32o; const format_tag_t gIOhw16i16o = dnnl_gIOhw16i16o; const format_tag_t gOhwi32o = dnnl_gOhwi32o; +const format_tag_t Goidhw4g = dnnl_Goidhw4g; +const format_tag_t Goidhw8g = dnnl_Goidhw8g; const format_tag_t Goidhw16g = dnnl_Goidhw16g; const format_tag_t IOw16o16i = dnnl_IOw16o16i; const format_tag_t IOw16i16o = dnnl_IOw16i16o; @@ -831,6 +847,8 @@ const format_tag_t OIhw8i8o = dnnl_OIhw8i8o; const format_tag_t OIhw8o16i2o = dnnl_OIhw8o16i2o; const format_tag_t IOhw8o16i2o = dnnl_IOhw8o16i2o; const format_tag_t OIhw8o8i = dnnl_OIhw8o8i; +const format_tag_t OIhw8o32i = dnnl_OIhw8o32i; +const format_tag_t OIhw16o32i = dnnl_OIhw16o32i; const format_tag_t OIhw8o4i = dnnl_OIhw8o4i; const format_tag_t Owhi16o = dnnl_Owhi16o; const format_tag_t Odhwi16o = dnnl_Odhwi16o; @@ -1191,6 +1209,7 @@ enum runtime_kind_t { dnnl_runtime_seq, dnnl_runtime_omp, dnnl_runtime_tbb, + dnnl_runtime_tbb_auto, dnnl_runtime_threadpool, dnnl_runtime_ocl, dnnl_runtime_sycl, @@ -1201,6 +1220,7 @@ const runtime_kind_t none = dnnl_runtime_none; const runtime_kind_t seq = dnnl_runtime_seq; const runtime_kind_t omp = dnnl_runtime_omp; const runtime_kind_t tbb = dnnl_runtime_tbb; +const runtime_kind_t tbb_auto = dnnl_runtime_tbb_auto; const runtime_kind_t threadpool = dnnl_runtime_threadpool; const runtime_kind_t ocl = dnnl_runtime_ocl; const runtime_kind_t sycl = dnnl_runtime_sycl; @@ -1231,6 +1251,9 @@ const primitive_kind_t logsoftmax = dnnl_logsoftmax; const primitive_kind_t matmul = dnnl_matmul; const primitive_kind_t resampling = dnnl_resampling; const primitive_kind_t reduction = dnnl_reduction; +const primitive_kind_t depthwise = dnnl_depthwise; +const primitive_kind_t quantization = dnnl_quantization; +const primitive_kind_t binarization = dnnl_binarization; // Internal only primitive kinds. 
const primitive_kind_t internal_only_start = (primitive_kind_t)(1 << 12); diff --git a/src/common/convolution.cpp b/src/common/convolution.cpp index 7d667fe65c3..f60d2cf8628 100644 --- a/src/common/convolution.cpp +++ b/src/common/convolution.cpp @@ -118,7 +118,10 @@ status_t conv_desc_init(convolution_desc_t *conv_desc, prop_kind_t prop_kind, int ker_range = 1 + (ker - 1) * (dil + 1); if (str < 1) return invalid_arguments; - consistency = consistency && dil >= 0 && pad_l >= 0 && pad_r + str > 0 + consistency = consistency + && dil >= 0 + && pad_l >= 0 + // && pad_r + str > 0 // TODO: [dmitrygo] Commented as WA to support dw conv fusing && (src - ker_range + pad_l + pad_r) / str + 1 == dst; } if (!consistency) return invalid_arguments; diff --git a/src/common/convolution_pd.hpp b/src/common/convolution_pd.hpp index a2b1880d27c..caabc1241a5 100644 --- a/src/common/convolution_pd.hpp +++ b/src/common/convolution_pd.hpp @@ -289,7 +289,7 @@ struct convolution_fwd_pd_t : public convolution_pd_t { int n_inputs() const override { return 2 + with_bias() + attr_post_op_dw_inputs() + n_binary_po_inputs() - + n_prelu_po_inputs(); + + n_prelu_po_inputs() + n_depthwise_po_inputs() + n_quantization_po_inputs(); } int n_outputs() const override { return 1; } @@ -319,8 +319,7 @@ struct convolution_fwd_pd_t : public convolution_pd_t { const auto &po = attr_.post_ops_; int conv = po.find(primitive_kind::convolution); if (conv == -1) return 0; - return po.entry_[conv].depthwise_conv.bias_dt == data_type::undef ? 1 - : 2; + return 2; } }; @@ -359,7 +358,9 @@ struct convolution_bwd_data_pd_t : public convolution_pd_t { return &glob_zero_md; } - int n_inputs() const override { return 2 + with_bias(); } + int n_inputs() const override { + return 2 + with_bias() + n_depthwise_po_inputs() + n_quantization_po_inputs(); + } int n_outputs() const override { return 1; } virtual bool support_bias() const { return false; } diff --git a/src/common/deconvolution_pd.hpp b/src/common/deconvolution_pd.hpp index 34e041d943e..d719abac14f 100644 --- a/src/common/deconvolution_pd.hpp +++ b/src/common/deconvolution_pd.hpp @@ -203,7 +203,7 @@ struct deconvolution_fwd_pd_t : public deconvolution_pd_t { } int n_inputs() const override { - return 2 + with_bias() + n_binary_po_inputs(); + return 2 + with_bias() + n_binary_po_inputs() + n_depthwise_po_inputs() + n_quantization_po_inputs(); } int n_outputs() const override { return 1; } diff --git a/src/common/dnnl_debug.cpp b/src/common/dnnl_debug.cpp index 56eb2f5d88e..e0fe11461da 100644 --- a/src/common/dnnl_debug.cpp +++ b/src/common/dnnl_debug.cpp @@ -40,6 +40,7 @@ const char *dnnl_runtime2str(unsigned runtime) { case DNNL_RUNTIME_SEQ: return "sequential"; case DNNL_RUNTIME_OMP: return "OpenMP"; case DNNL_RUNTIME_TBB: return "TBB"; + case DNNL_RUNTIME_TBB_AUTO: return "TBB_AUTO"; case DNNL_RUNTIME_OCL: return "OpenCL"; case DNNL_RUNTIME_THREADPOOL: return "threadpool"; #ifdef DNNL_WITH_SYCL diff --git a/src/common/dnnl_debug_autogenerated.cpp b/src/common/dnnl_debug_autogenerated.cpp index 17944a9a044..af64baa0c5e 100644 --- a/src/common/dnnl_debug_autogenerated.cpp +++ b/src/common/dnnl_debug_autogenerated.cpp @@ -43,6 +43,7 @@ const char *dnnl_dt2str(dnnl_data_type_t v) { if (v == dnnl_s32) return "s32"; if (v == dnnl_s8) return "s8"; if (v == dnnl_u8) return "u8"; + if (v == dnnl_bin) return "bin"; assert(!"unknown dt"); return "unknown dt"; } @@ -150,6 +151,8 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) { if (v == dnnl_ABcd8a16b2a) return "ABcd8a16b2a"; if (v == 
dnnl_ABcd2b8a4b) return "ABcd2b8a4b"; if (v == dnnl_ABcd8a8b) return "ABcd8a8b"; + if (v == dnnl_ABcd8a32b) return "ABcd8a32b"; + if (v == dnnl_ABcd16a32b) return "ABcd16a32b"; if (v == dnnl_ABcd8a4b) return "ABcd8a4b"; if (v == dnnl_aBcd8b) return "aBcd8b"; if (v == dnnl_aBCd4c8b2c) return "aBCd4c8b2c"; @@ -248,6 +251,8 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) { if (v == dnnl_aCBdef16c16b) return "aCBdef16c16b"; if (v == dnnl_aBdefc4b) return "aBdefc4b"; if (v == dnnl_aBdefc8b) return "aBdefc8b"; + if (v == dnnl_Abcdef4a) return "Abcdef4a"; + if (v == dnnl_Abcdef8a) return "Abcdef8a"; if (v == dnnl_Abcdef16a) return "Abcdef16a"; if (v == dnnl_Abcdef32a) return "Abcdef32a"; if (v == dnnl_aBedc16b) return "aBedc16b"; @@ -707,6 +712,8 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) { if (v == dnnl_OIhw8o16i2o) return "OIhw8o16i2o"; if (v == dnnl_OIhw2i8o4i) return "OIhw2i8o4i"; if (v == dnnl_IOhw8o16i2o) return "IOhw8o16i2o"; + if (v == dnnl_OIhw8o32i) return "OIhw8o32i"; + if (v == dnnl_OIhw16o32i) return "OIhw16o32i"; if (v == dnnl_OIhw8o8i) return "OIhw8o8i"; if (v == dnnl_OIhw8o4i) return "OIhw8o4i"; if (v == dnnl_Owhi16o) return "Owhi16o"; @@ -842,6 +849,8 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) { if (v == dnnl_gIOdhw8o16i2o) return "gIOdhw8o16i2o"; if (v == dnnl_gOIdhw8o8i) return "gOIdhw8o8i"; if (v == dnnl_gOIdhw8o4i) return "gOIdhw8o4i"; + if (v == dnnl_Goidhw4g) return "Goidhw4g"; + if (v == dnnl_Goidhw8g) return "Goidhw8g"; if (v == dnnl_Goidhw16g) return "Goidhw16g"; if (v == dnnl_Goidhw32g) return "Goidhw32g"; if (v == dnnl_gOIdhw2i4o2i) return "gOIdhw2i4o2i"; @@ -1077,6 +1086,8 @@ const char *dnnl_prim_kind2str(dnnl_primitive_kind_t v) { if (v == dnnl_pooling_v2) return "pooling_v2"; if (v == dnnl_reduction) return "reduction"; if (v == dnnl_prelu) return "prelu"; + if (v == dnnl_depthwise) return "depthwise"; + if (v == dnnl_quantization) return "quantization"; if (v == dnnl_primitive_kind_max) return "primitive_kind_max"; assert(!"unknown prim_kind"); return "unknown prim_kind"; @@ -1112,6 +1123,9 @@ const char *dnnl_alg_kind2str(dnnl_alg_kind_t v) { if (v == dnnl_eltwise_logsigmoid) return "eltwise_logsigmoid"; if (v == dnnl_eltwise_mish) return "eltwise_mish"; if (v == dnnl_eltwise_hardswish) return "eltwise_hardswish"; + if (v == dnnl_eltwise_hsigmoid) return "eltwise_hsigmoid"; + if (v == dnnl_eltwise_round_half_to_even) return "eltwise_round_half_to_even"; + if (v == dnnl_eltwise_round_half_away_from_zero) return "eltwise_round_half_away_from_zero"; if (v == dnnl_eltwise_relu_use_dst_for_bwd) return "eltwise_relu_use_dst_for_bwd"; if (v == dnnl_eltwise_tanh_use_dst_for_bwd) return "eltwise_tanh_use_dst_for_bwd"; if (v == dnnl_eltwise_elu_use_dst_for_bwd) return "eltwise_elu_use_dst_for_bwd"; @@ -1141,6 +1155,7 @@ const char *dnnl_alg_kind2str(dnnl_alg_kind_t v) { if (v == dnnl_binary_lt) return "binary_lt"; if (v == dnnl_binary_eq) return "binary_eq"; if (v == dnnl_binary_ne) return "binary_ne"; + if (v == dnnl_binary_prelu) return "binary_prelu"; if (v == dnnl_resampling_nearest) return "resampling_nearest"; if (v == dnnl_resampling_linear) return "resampling_linear"; if (v == dnnl_reduction_max) return "reduction_max"; @@ -1152,6 +1167,11 @@ const char *dnnl_alg_kind2str(dnnl_alg_kind_t v) { if (v == dnnl_reduction_norm_lp_sum) return "reduction_norm_lp_sum"; if (v == dnnl_reduction_norm_lp_power_p_max) return "reduction_norm_lp_power_p_max"; if (v == dnnl_reduction_norm_lp_power_p_sum) return "reduction_norm_lp_power_p_sum"; + if (v == dnnl_depthwise_scale_shift) return "depthwise_scale_shift"; + if (v == dnnl_depthwise_prelu) return "depthwise_prelu"; + if (v == dnnl_quantization_quantize_dequantize) return "quantization_quantize_dequantize"; + if (v == dnnl_quantization_quantize) return "quantization_quantize"; + if (v == dnnl_binarization_depthwise) return "binarization_depthwise"; assert(!"unknown alg_kind"); return "unknown alg_kind"; }
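For reference, the two new rounding algorithms differ only in how exact .5 ties are resolved; the following standalone sketch mirrors the semantics of the math_utils.hpp helpers added further below (illustrative code, not the library's exact implementation):

#include <cmath>

// Illustrative tie-breaking behavior of the two rounding modes:
// round_half_away_from_zero(2.5f) == 3.0f, round_half_to_even(2.5f) == 2.0f.
inline float round_half_away_from_zero(float s) {
    return ::roundf(s); // C's roundf already rounds halves away from zero
}

inline float round_half_to_even(float s) {
    float r = ::roundf(s);
    float d = s - r;
    // On a .5 tie where roundf picked an odd value, step to the even neighbor.
    if ((d == 0.5f || d == -0.5f) && ::fmodf(r, 2.0f) != 0.0f) return s + d;
    return r;
}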
diff --git a/src/common/dnnl_thread.cpp b/src/common/dnnl_thread.cpp index 991dafd8f5c..a3e725da224 100644 --- a/src/common/dnnl_thread.cpp +++ b/src/common/dnnl_thread.cpp @@ -29,15 +29,6 @@ namespace dnnl { namespace impl { -static int adjust_num_threads(int nthr, dim_t work_amount) { - if (nthr == 0) nthr = dnnl_get_current_num_threads(); -#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP - return (work_amount == 1 || omp_in_parallel()) ? 1 : nthr; -#else - return (int)std::min((dim_t)nthr, work_amount); -#endif -} - void parallel(int nthr, const std::function<void(int, int)> &f) { nthr = adjust_num_threads(nthr, INT64_MAX); #if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_SEQ @@ -82,6 +73,9 @@ void parallel(int nthr, const std::function<void(int, int)> &f) { #endif }, tbb::static_partitioner()); +#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB_AUTO + tbb::parallel_for( + 0, nthr, [&](int ithr) { f(ithr, nthr); }); #elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL using namespace dnnl::impl::threadpool_utils; dnnl::threadpool_interop::threadpool_iface *tp = get_active_threadpool();
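The dnnl_thread.cpp hunk above is the heart of the new TBB_AUTO runtime: plain TBB pins one closure invocation per thread index with tbb::static_partitioner, while TBB_AUTO omits the partitioner so TBB's auto-partitioner may split and steal ranges. A minimal standalone comparison (stock TBB, illustrative function names):

#include <functional>
#include "tbb/parallel_for.h"

// DNNL_RUNTIME_TBB: one fixed chunk per index, no work stealing.
void run_tbb_static(int nthr, const std::function<void(int, int)> &f) {
    tbb::parallel_for(
            0, nthr, [&](int ithr) { f(ithr, nthr); },
            tbb::static_partitioner());
}

// DNNL_RUNTIME_TBB_AUTO: default partitioner, chunking is up to TBB.
void run_tbb_auto(int nthr, const std::function<void(int, int)> &f) {
    tbb::parallel_for(0, nthr, [&](int ithr) { f(ithr, nthr); });
}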
diff --git a/src/common/dnnl_thread.hpp b/src/common/dnnl_thread.hpp index 2520758f4ed..225d2777b12 100644 --- a/src/common/dnnl_thread.hpp +++ b/src/common/dnnl_thread.hpp @@ -47,7 +47,7 @@ inline void dnnl_thr_barrier() { #pragma omp barrier } -#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB +#elif (DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB || DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB_AUTO) #include "tbb/parallel_for.h" #include "tbb/task_arena.h" #define DNNL_THR_SYNC 0 @@ -162,8 +162,11 @@ inline int dnnl_get_current_num_threads() { #define OMP_GET_NUM_THREADS() 1 #endif -// MSVC still supports omp 2.0 only -#if defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER) +/* MSVC still supports omp 2.0 only, + however VS2019 also now offers SIMD functionality + with the -openmp:experimental compilation switch that enables additional OpenMP features + not available when using the -openmp switch */ +#if defined(_MSC_VER) && (_MSC_VER < 1900) && !defined(__clang__) && !defined(__INTEL_COMPILER) #define collapse(x) #define PRAGMA_OMP_SIMD(...) #else @@ -182,6 +185,10 @@ inline int dnnl_get_current_num_threads() { #define simdlen(x) #endif // long simdlen if +#if defined(DNNL_ENABLE_ITT_TASKS) +#include "common/ittnotify.hpp" +#endif + namespace dnnl { namespace impl { @@ -328,11 +335,186 @@ void parallel_nd_in_omp(Args &&... args) { for_nd(omp_get_thread_num(), omp_get_num_threads(), utils::forward<Args>(args)...); #elif (DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB \ + || DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB_AUTO \ || DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL) assert(!"parallel_nd_in_omp() is not supported by this DNNL_CPU_RUNTIME"); #endif } +static inline int adjust_num_threads(int nthr, dim_t work_amount) { + if (nthr == 0) nthr = dnnl_get_current_num_threads(); +#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP + return (work_amount == 1 || omp_in_parallel()) ? 1 : nthr; +#else + return (int)std::min((dim_t)nthr, work_amount); +#endif +} + +template <typename F> +void parallel_legacy(int nthr, F f) { + nthr = adjust_num_threads(nthr, INT64_MAX); +#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_SEQ + assert(nthr == 1); +f(0, 1); +#else +#if defined(DNNL_ENABLE_ITT_TASKS) + auto task_primitive_kind = itt::primitive_task_get_current_kind(); +bool itt_enable = itt::get_itt(itt::__itt_task_level_high); +#endif + if (nthr == 1) { + f(0, 1); + return; + } +#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP + #pragma omp parallel num_threads(nthr) +{ +int nthr_ = omp_get_num_threads(); +int ithr_ = omp_get_thread_num(); +assert(nthr_ == nthr); +#if defined(DNNL_ENABLE_ITT_TASKS) +if (ithr_ && itt_enable) itt::primitive_task_start(task_primitive_kind); +#endif +f(ithr_, nthr_); +#if defined(DNNL_ENABLE_ITT_TASKS) +if (ithr_ && itt_enable) itt::primitive_task_end(); +#endif +} +#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB + tbb::parallel_for( + 0, nthr, + [&](int ithr) { +#if defined(DNNL_ENABLE_ITT_TASKS) + bool mark_task = itt::primitive_task_get_current_kind() + == primitive_kind::undefined; + if (mark_task && itt_enable) + itt::primitive_task_start(task_primitive_kind); +#endif + f(ithr, nthr); +#if defined(DNNL_ENABLE_ITT_TASKS) + if (mark_task && itt_enable) itt::primitive_task_end(); +#endif + }, + tbb::static_partitioner()); +#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB_AUTO + tbb::parallel_for( +0, nthr, [&](int ithr) { f(ithr, nthr); }); +#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL +using namespace dnnl::impl::threadpool_utils; +dnnl::threadpool_interop::threadpool_iface *tp = get_active_threadpool(); +if (!tp || dnnl_in_parallel()) { +threadpool_utils::deactivate_threadpool(); +for (int ithr = 0; ithr < nthr; ithr++) { +f(ithr, nthr); +} +threadpool_utils::activate_threadpool(tp); +} else { +bool async = tp->get_flags() + & dnnl::threadpool_interop::threadpool_iface::ASYNCHRONOUS; +counting_barrier_t b; +if (async) b.init(nthr); +tp->parallel_for(nthr, [&, tp](int ithr, int nthr) { +bool is_master = threadpool_utils::get_active_threadpool() == tp; +if (!is_master) { + threadpool_utils::activate_threadpool(tp); +#if defined(DNNL_ENABLE_ITT_TASKS) + if (itt_enable) itt::primitive_task_start(task_primitive_kind); +#endif +} +f(ithr, nthr); +if (!is_master) { +#if defined(DNNL_ENABLE_ITT_TASKS) + if (itt_enable) itt::primitive_task_end(); +#endif + threadpool_utils::deactivate_threadpool(); +} +if (async) b.notify(); +}); +if (async) b.wait(); +} +#endif +#endif +} + +template <typename T0, typename F> +void for_nd_legacy(const int ithr, const int nthr, const T0 &D0, F f) { + T0 start {0}, end {0}; + balance211(D0, nthr, ithr, start, end); + for (T0 d0 = start; d0 < end; ++d0) + f(d0); +} + +template <typename T0, typename T1, typename T2, typename T3, typename F> +void for_nd_legacy(const int ithr, const int nthr, const T0 &D0, const T1 &D1, + const T2 &D2, const T3 &D3, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3; + if (work_amount == 0) return; + size_t start {0}, end {0}; + balance211(work_amount, nthr, ithr, start, end); + + T0 d0 {0}; + T1 d1 {0}; + T2 d2 {0}; + T3 d3 {0}; + utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3); + for (size_t iwork = start; iwork < end; ++iwork) { + f(d0, d1, d2, d3); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3); + } +} + +template <typename T0, typename T1, typename T2, typename T3, typename T4, typename T5, typename F> +void for_nd_legacy(const int ithr, const int nthr, const T0 &D0, const T1 &D1, + const T2 &D2, const T3 &D3, const T4 &D4, const T5 &D5, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5; + if (work_amount == 0) return; + size_t start {0}, end {0}; + balance211(work_amount, nthr, ithr, start, end); + + T0 d0 {0}; + T1 d1 {0}; + T2 d2 {0}; + T3 d3 {0}; + T4 d4 {0}; + T5 d5 {0}; + utils::nd_iterator_init( + start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5); + for (size_t iwork = start; iwork < end; ++iwork) { + f(d0, d1, d2, d3, d4, d5); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5); + } +} + +template <typename T0, typename F> +void parallel_nd_legacy(const T0 &D0, F f) { + const size_t work_amount = (size_t)D0; + int nthr = adjust_num_threads(dnnl_get_current_num_threads(), work_amount); + if (nthr) + parallel_legacy(nthr, [&](int ithr, int nthr) { for_nd_legacy(ithr, nthr, D0, f); }); +} + +template <typename T0, typename T1, typename T2, typename T3, typename F> +void parallel_nd_legacy(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3; + int nthr = adjust_num_threads(dnnl_get_current_num_threads(), work_amount); + if (nthr) + parallel_legacy(nthr, [&](int ithr, int nthr) { + for_nd_legacy(ithr, nthr, D0, D1, D2, D3, f); + }); +} + +template <typename T0, typename T1, typename T2, typename T3, typename T4, typename T5, typename F> +void parallel_nd_legacy(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, + const T4 &D4, const T5 &D5, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5; + int nthr = adjust_num_threads(dnnl_get_current_num_threads(), work_amount); + if (nthr) + parallel_legacy(nthr, [&](int ithr, int nthr) { + for_nd_legacy(ithr, nthr, D0, D1, D2, D3, D4, D5, f); + }); +} + } // namespace impl } // namespace dnnl
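The for_nd_legacy helpers flatten the iteration space and split it across threads before iterating; a simplified sketch of the splitting contract balance211 provides to them (the first work % nthr threads get one extra item), not its exact implementation:

#include <cstddef>

// Simplified sketch of the balance211 contract used by for_nd_legacy:
// split `work` items over `nthr` threads as evenly as possible.
void split_work(size_t work, int nthr, int ithr, size_t &start, size_t &end) {
    size_t base = work / nthr;
    size_t extra = work % nthr;
    size_t my_count = base + (static_cast<size_t>(ithr) < extra ? 1 : 0);
    start = base * ithr
            + (static_cast<size_t>(ithr) < extra ? (size_t)ithr : extra);
    end = start + my_count;
}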
diff --git a/src/common/dnnl_traits.hpp b/src/common/dnnl_traits.hpp index dcedac78b8b..54f0074afc6 100644 --- a/src/common/dnnl_traits.hpp +++ b/src/common/dnnl_traits.hpp @@ -66,6 +66,10 @@ struct prec_traits<data_type::u8> { typedef uint8_t type; }; +template <> struct prec_traits<data_type::bin> { + typedef uint8_t type; +}; + template <> struct data_traits<float16_t> { static constexpr data_type_t data_type = data_type::f16; 
diff --git a/src/common/eltwise.cpp b/src/common/eltwise.cpp index e21b937e834..f404d64dd02 100644 --- a/src/common/eltwise.cpp +++ b/src/common/eltwise.cpp @@ -38,7 +38,7 @@ status_t eltwise_desc_init(eltwise_desc_t *eltwise_desc, prop_kind_t prop_kind, backward_data) && IMPLICATION( prop_kind == backward_data, diff_data_desc != nullptr) - && IMPLICATION(alg_kind == eltwise_round, + && IMPLICATION(one_of(alg_kind, eltwise_round, eltwise_hsigmoid, eltwise_round_half_away_from_zero, eltwise_round_half_to_even), one_of(prop_kind, forward_training, forward_inference)) && math::is_eltwise_ok(data_desc->data_type, alg_kind, alpha, beta); if (!args_ok) return invalid_arguments; 
diff --git a/src/common/eltwise_pd.hpp b/src/common/eltwise_pd.hpp index 43c9d6fbc7f..9e4fe367699 100644 --- a/src/common/eltwise_pd.hpp +++ b/src/common/eltwise_pd.hpp @@ -137,7 +137,8 @@ struct eltwise_fwd_pd_t : public eltwise_pd_t { return one_of(alg, eltwise_relu, eltwise_tanh, eltwise_elu, eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_swish, eltwise_bounded_relu, eltwise_gelu_tanh, - eltwise_gelu_erf, eltwise_round, eltwise_hardswish) + eltwise_gelu_erf, eltwise_round, eltwise_hardswish, + eltwise_round_half_away_from_zero, eltwise_round_half_to_even) || one_of(alg, eltwise_relu_use_dst_for_bwd, eltwise_tanh_use_dst_for_bwd, eltwise_elu_use_dst_for_bwd, 
diff --git a/src/common/engine.hpp b/src/common/engine.hpp index a4017f39485..679a17955e2 100644 --- a/src/common/engine.hpp +++ b/src/common/engine.hpp @@ -157,6 +157,8 @@ inline runtime_kind_t get_default_runtime(engine_kind_t kind) { return runtime_kind::omp; #elif DNNL_CPU_RUNTIME == DNNL_RUNTIME_TBB return runtime_kind::tbb; +#elif DNNL_CPU_RUNTIME == DNNL_RUNTIME_TBB_AUTO + return runtime_kind::tbb_auto; #elif DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL return runtime_kind::threadpool; #elif DNNL_CPU_RUNTIME == DNNL_RUNTIME_SYCL 
diff --git a/src/common/ittnotify.cpp b/src/common/ittnotify.cpp index bda0a46b844..ab27bee8060 100644 --- a/src/common/ittnotify.cpp +++ b/src/common/ittnotify.cpp @@ -80,6 +80,8 @@ void primitive_task_start(primitive_kind_t kind) { CASE(pooling_v2), CASE(reduction), CASE(prelu), + CASE(depthwise), + CASE(quantization), }; #undef CASE int kind_idx = (int)kind; 
diff --git a/src/common/ittnotify/ittnotify_config.h b/src/common/ittnotify/ittnotify_config.h index 75d9bcb3dae..1ced7285ea3 100644 --- a/src/common/ittnotify/ittnotify_config.h +++ b/src/common/ittnotify/ittnotify_config.h @@ -183,6 +183,10 @@ # define ITT_ARCH_IA32E 2 #endif /* ITT_ARCH_IA32E */ +#ifndef ITT_ARCH_IA64 +# define ITT_ARCH_IA64 3 +#endif /* ITT_ARCH_IA64 */ + #ifndef ITT_ARCH_ARM # define ITT_ARCH_ARM 4 #endif /* ITT_ARCH_ARM */ 
diff --git a/src/common/math_utils.hpp b/src/common/math_utils.hpp index ce999514e6e..0ca40ca9919 100644 --- a/src/common/math_utils.hpp +++ b/src/common/math_utils.hpp @@ -246,7 +246,7 @@ inline U logistic_bwd_use_dst(T dd, T d) { template <typename T, typename U = typename utils::remove_reference<T>::type> inline U soft_relu_fwd(T s) { - float exp_overflow_bound = 88.72283172607421875; + float exp_overflow_bound = 20.f; float in = (float)s; return in < exp_overflow_bound ? (U)(::log1pf(::expf(in))) : (U)in; } @@ -395,6 +395,31 @@ inline U hardswish_bwd(T dd, T s) { : s >= 3.f ? dd : 0.f); } +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U hsigmoid_fwd(T s) { + float v = s + 3.0f; + v = v > 0.0f ? v : 0.0f; + v = v < 6.0f ? v : 6.0f; + return (U)(v / 6.0f); +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U round_half_to_even_fwd(T s) { + float r = ::roundf((float)s); + float d = (float)s - r; + float remainder = ::fmodf(r, 2.0f); + return ((d != 0.5f) && (d != -0.5f)) || (remainder == 0.0f) ? (U)r : + (U)((float)s + d); +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U round_half_away_from_zero_fwd(T s) { + return (U)(::roundf((float)s)); +} + inline bool is_eltwise_ok( data_type_t dt, alg_kind_t alg, float alpha, float beta) { using namespace alg_kind; @@ -407,7 +432,8 @@ inline bool is_eltwise_ok( eltwise_logsigmoid, eltwise_mish, eltwise_logistic, eltwise_exp, eltwise_gelu_tanh, eltwise_hardswish, eltwise_swish, eltwise_log, eltwise_clip, eltwise_clip_v2, - eltwise_pow, eltwise_gelu_erf, eltwise_round) + eltwise_pow, eltwise_gelu_erf, eltwise_round, eltwise_hsigmoid, + eltwise_round_half_away_from_zero, eltwise_round_half_to_even) && IMPLICATION(alg == eltwise_bounded_relu, alpha >= 0) && IMPLICATION( one_of(alg, eltwise_clip, eltwise_clip_v2), beta >= alpha) @@ -431,6 +457,43 @@ inline bool is_eltwise_ok( return eltwise_use_src || eltwise_use_dst; } +inline float get_bias(const char *bias, size_t offset, data_type_t data_type) { + if (!bias) return 0.0f; + +#define CASE(dt) \ + case dt: return (float)((const prec_traits<dt>
::type *)bias)[offset] + switch (data_type) { + CASE(data_type::s8); + CASE(data_type::u8); + CASE(data_type::bf16); + CASE(data_type::s32); + CASE(data_type::f32); + default: assert(!"unimplemented"); + } + return 0; // never happens (should probably be a NaN) +#undef CASE +} + +inline float get_sum(char *sum, size_t offset, data_type_t data_type) +{ + if (!sum) + return 0.0f; + +#define CASE(dt) \ + case dt: return (float)((const prec_traits<dt>
::type *)sum)[offset] + + switch (data_type) { + CASE(data_type::s8); + CASE(data_type::u8); + CASE(data_type::s32); + CASE(data_type::f32); + default: assert(!"unimplemented"); + } + return 0; // never happens (should probably be a NaN) +#undef CASE +} + } // namespace math } // namespace impl } // namespace dnnl diff --git a/src/common/memory.cpp b/src/common/memory.cpp index 91a7967de12..0ef1052ec8d 100644 --- a/src/common/memory.cpp +++ b/src/common/memory.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include "oneapi/dnnl/dnnl.h" #include "oneapi/dnnl/dnnl.hpp" @@ -89,7 +90,7 @@ dnnl_memory::dnnl_memory(dnnl::impl::engine_t *engine, this->reset_memory_storage(std::move(memory_storage)); } -status_t dnnl_memory::set_data_handle(void *handle, stream_t *stream) { +status_t dnnl_memory::set_data_handle(void *handle, stream_t *stream, bool pads_zeroing) { using namespace dnnl::impl; void *old_handle; @@ -98,7 +99,10 @@ status_t dnnl_memory::set_data_handle(void *handle, stream_t *stream) { if (handle != old_handle) { CHECK(memory_storage_->set_data_handle(handle)); } - return status::success; + + memory_arg_t mem_arg = {this, true}; + exec_args_t args = {{0, mem_arg}}; + return pads_zeroing ? zero_pad(exec_ctx_t(stream, std::move(args))) : dnnl_success; } status_t dnnl_memory::reset_memory_storage( @@ -521,7 +525,8 @@ status_t dnnl_memory_create(memory_t **memory, const memory_desc_t *md, : memory_flags_t::use_runtime_ptr; void *handle_ptr = (handle == DNNL_MEMORY_ALLOCATE) ? nullptr : handle; auto _memory = new memory_t(engine, md, flags, handle_ptr); - if (_memory == nullptr) return out_of_memory; + if (_memory == nullptr) + return out_of_memory; if (_memory->memory_storage() == nullptr) { delete _memory; return out_of_memory; @@ -556,11 +561,24 @@ status_t dnnl_memory_set_data_handle(memory_t *memory, void *handle) { return dnnl_memory_set_data_handle_v2(memory, handle, nullptr); } +status_t dnnl_memory_set_data_handle_no_pads_proc(memory_t *memory, void *handle) { + return dnnl_memory_set_data_handle_v2_no_pads_proc(memory, handle, nullptr); +} + status_t dnnl_memory_set_data_handle_v2( memory_t *memory, void *handle, stream_t *stream) { if (any_null(memory)) return invalid_arguments; if (stream) stream->before_exec_hook(); - status_t status = memory->set_data_handle(handle, stream); + status_t status = memory->set_data_handle(handle, stream, true); + if (stream) stream->after_exec_hook(); + return status; +} + +status_t dnnl_memory_set_data_handle_v2_no_pads_proc( + memory_t *memory, void *handle, stream_t *stream) { + if (any_null(memory)) return invalid_arguments; + if (stream) stream->before_exec_hook(); + status_t status = memory->set_data_handle(handle, stream, false); if (stream) stream->after_exec_hook(); return status; } diff --git a/src/common/memory.hpp b/src/common/memory.hpp index 36c77582561..5e0791a4023 100644 --- a/src/common/memory.hpp +++ b/src/common/memory.hpp @@ -74,7 +74,7 @@ struct dnnl_memory : public dnnl::impl::c_compatible { } /** sets data handle */ - dnnl::impl::status_t set_data_handle(void *handle, dnnl_stream *stream); + dnnl::impl::status_t set_data_handle(void *handle, dnnl_stream *stream, bool pads_zeroing); /** zeros padding */ dnnl::impl::status_t zero_pad(const dnnl::impl::exec_ctx_t &ctx) const; diff --git a/src/common/memory_debug.hpp b/src/common/memory_debug.hpp index b964f85660b..d3c7fa740b3 100644 --- a/src/common/memory_debug.hpp +++ b/src/common/memory_debug.hpp @@ -45,7 +45,7 @@ static inline bool is_mem_debug() { // Static 
inline for optimization purposes when memory_debug is disabled static inline bool is_mem_debug_overflow() { if (is_mem_debug()) -#if (DNNL_ENABLE_MEM_DEBUG == DNNL_MEM_DEBUG_UNDERFLOW) +#if (defined(DNNL_ENABLE_MEM_DEBUG) && DNNL_ENABLE_MEM_DEBUG == DNNL_MEM_DEBUG_UNDERFLOW) return false; #else // Default to DNNL_MEM_DEBUG_OVERFLOW as buffer overflows are a 
diff --git a/src/common/memory_desc_wrapper.cpp b/src/common/memory_desc_wrapper.cpp index 8619afc8b87..c96d4d5b671 100644 --- a/src/common/memory_desc_wrapper.cpp +++ b/src/common/memory_desc_wrapper.cpp @@ -26,11 +26,10 @@ namespace dnnl { namespace impl { -status_t fill_blocked(memory_desc_t &md, std::initializer_list<int> perm, - std::initializer_list<int> inner_blks, - std::initializer_list<int> inner_idxs) { +template <typename T> +static status_t fill_blocked_impl(memory_desc_t &md, T&& perm, T&& inner_blks, T&& inner_idxs) { const bool ok = true && perm.size() == (size_t)md.ndims - && inner_blks.size() == inner_idxs.size(); + && inner_blks.size() == inner_idxs.size(); if (!ok) return status::invalid_arguments; md.offset0 = 0; @@ -81,6 +80,18 @@ status_t fill_blocked(memory_desc_t &md, std::initializer_list<int> perm, return status::success; } +status_t fill_blocked(memory_desc_t &md, std::initializer_list<int> perm, + std::initializer_list<int> inner_blks, + std::initializer_list<int> inner_idxs) { + return fill_blocked_impl(md, perm, inner_blks, inner_idxs); +} + +status_t fill_blocked(memory_desc_t &md, std::vector<int>& perm, + std::vector<int>& inner_blks, + std::vector<int>& inner_idxs) { + return fill_blocked_impl(md, perm, inner_blks, inner_idxs); +} + void memory_desc_wrapper::compute_strides_compat(dims_t *strides_compat) const { if (ndims() == 0) return; @@ -125,14 +136,12 @@ void memory_desc_wrapper::compute_strides_compat(dims_t *strides_compat) const { utils::array_copy(strides_compat[1], inner_strides, ndims()); } -status_t memory_desc_wrapper::compute_blocking( - memory_desc_t &memory_desc, format_tag_t tag) { +template <typename F, typename... Args> +status_t process_tag(F f, format_tag_t tag, Args&&... args) { using namespace format_tag; - if (memory_desc.ndims == 0) return status::invalid_arguments; - #define C(tag, ...
/* perm, inner_blks, inner_idxs */) \ - case tag: return fill_blocked(memory_desc, __VA_ARGS__) + case tag: return f(std::forward<Args>(args)..., __VA_ARGS__) switch (tag) { C(a, {0}, {}, {}); @@ -339,6 +348,8 @@ status_t memory_desc_wrapper::compute_blocking( C(ABcd8a16b2a, {0, 1, 2, 3}, {8, 16, 2}, {0, 1, 0}); C(BAcd8a16b2a, {1, 0, 2, 3}, {8, 16, 2}, {0, 1, 0}); C(ABcd8a8b, {0, 1, 2, 3}, {8, 8}, {0, 1}); + C(ABcd8a32b, {0, 1, 2, 3}, {8, 32}, {0, 1}); + C(ABcd16a32b, {0, 1, 2, 3}, {16, 32}, {0, 1}); C(ABcd8a4b, {0, 1, 2, 3}, {8, 4}, {0, 1}); C(ABcd8a2b, {0, 1, 2, 3}, {8, 2}, {0, 1}); C(aBcd8b, {0, 1, 2, 3}, {8}, {1}); @@ -492,6 +503,8 @@ status_t memory_desc_wrapper::compute_blocking( C(aBdec32b, {0, 1, 3, 4, 2}, {32}, {1}); C(aCBdef16c16b, {0, 2, 1, 3, 4, 5}, {16, 16}, {2, 1}); C(aCBdef16b16c, {0, 2, 1, 3, 4, 5}, {16, 16}, {1, 2}); + C(Abcdef4a, {0, 1, 2, 3, 4, 5}, {4}, {0}); + C(Abcdef8a, {0, 1, 2, 3, 4, 5}, {8}, {0}); C(Abcdef16a, {0, 1, 2, 3, 4, 5}, {16}, {0}); C(Abcdef32a, {0, 1, 2, 3, 4, 5}, {32}, {0}); C(aCBd16c16b, {0, 2, 1, 3}, {16, 16}, {2, 1}); @@ -650,6 +663,28 @@ status_t memory_desc_wrapper::compute_blocking( return status::invalid_arguments; } +status_t memory_desc_wrapper::compute_blocking(memory_desc_t &memory_desc, format_tag_t tag) { + using fill_blocked_t = status_t(memory_desc_t&, std::initializer_list<int>, std::initializer_list<int>, std::initializer_list<int>); + if (memory_desc.ndims == 0) return status::invalid_arguments; + return process_tag<fill_blocked_t *>(fill_blocked, tag, memory_desc); +} + +status_t memory_desc_wrapper::compute_blocking(format_tag_t tag, + std::vector<int> &perm, + std::vector<int> &inner_blks, + std::vector<int> &inner_idxs) { + + auto extract_data = [&](std::initializer_list<int> _perm, + std::initializer_list<int> _inner_blks, + std::initializer_list<int> _inner_idxs) -> status_t { + perm = {_perm.begin(), _perm.end()}; + inner_blks = {_inner_blks.begin(), _inner_blks.end()}; + inner_idxs = {_inner_idxs.begin(), _inner_idxs.end()}; + return status::success; + }; + return process_tag(extract_data, tag); +} + } // namespace impl } // namespace dnnl
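The new compute_blocking overload makes the per-tag permutation and blocking data queryable without filling in a memory descriptor; a hypothetical usage sketch (assuming the internal header path and the fork's vector-based overload shown above):

#include <vector>
#include "common/memory_desc_wrapper.hpp"

using namespace dnnl::impl;

// Hypothetical usage: recover the {perm, inner_blks, inner_idxs} triple a
// format tag encodes, e.g. ABcd16a32b -> inner blocks {16, 32} on dims
// {0, 1}, matching the C(ABcd16a32b, ...) table entry above.
void inspect_tag(format_tag_t tag) {
    std::vector<int> perm, inner_blks, inner_idxs;
    status_t st = memory_desc_wrapper::compute_blocking(
            tag, perm, inner_blks, inner_idxs);
    if (st != status::success) return; // tag has no blocked description
    // perm/inner_blks/inner_idxs now describe the physical layout
}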
diff --git a/src/common/memory_desc_wrapper.hpp b/src/common/memory_desc_wrapper.hpp index e60b83291f2..5196f64f099 100644 --- a/src/common/memory_desc_wrapper.hpp +++ b/src/common/memory_desc_wrapper.hpp @@ -28,6 +28,10 @@ namespace dnnl { namespace impl { +status_t fill_blocked(memory_desc_t &md, std::vector<int> &perm, + std::vector<int> &inner_blks, + std::vector<int> &inner_idxs); + /** thin wrapper class over \struct memory_desc_t which allows easy * manipulations with underlying C structure, which is taken by reference */ struct memory_desc_wrapper : public c_compatible { @@ -294,7 +298,8 @@ struct memory_desc_wrapper : public c_compatible { * following statement might be true: lhs == rhs && !lhs.similar_to(rhs) */ /* TODO: revise */ bool similar_to(const memory_desc_wrapper &rhs, bool with_padding = true, - bool with_data_type = true, int dim_start = 0) const; + bool with_data_type = true, int dim_start = 0, bool use_weak_cmp = false, + bool check_off0 = false, uint64_t stride_mask = 0xffffffffffffffff) const; /** returns true if one memory can be reordered to another */ bool consistent_with(const memory_desc_wrapper &rhs) const; @@ -316,6 +321,21 @@ struct memory_desc_wrapper : public c_compatible { return format_tag::undef; } + /** returns matching tag (or undef if match is not found) with taking into + * account strides specified outside */ + template <typename... Tags> + dnnl_format_tag_t stride_relaxed_matches_any_of(const dims_t &strides, Tags... tags) const { + for (const auto &tag : {tags...}) + if (matches_tag(tag, strides)) return tag; + return format_tag::undef; + } + + template <typename... Tags> + dnnl_format_tag_t mb_stride_relaxed_match(Tags... tags) const { + const dims_t skip_mb_stride{-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + return stride_relaxed_matches_any_of(skip_mb_stride, tags...); + } + /* offset section */ /** returns physical offset by logical one. logical offset is represented by @@ -411,6 +431,11 @@ struct memory_desc_wrapper : public c_compatible { static status_t compute_blocking( memory_desc_t &memory_desc, format_tag_t tag); + static status_t compute_blocking(format_tag_t tag, + std::vector<int> &perm, + std::vector<int> &inner_blks, + std::vector<int> &inner_idxs); + private: /* TODO: put logical_offset in utils */ template <typename T> @@ -440,7 +465,7 @@ struct memory_desc_wrapper : public c_compatible { }; inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs, - bool with_padding, bool with_data_type, int dim_start) const { + bool with_padding, bool with_data_type, int dim_start, bool use_weak_cmp, bool check_off0, uint64_t stride_mask) const { using namespace utils; if (one_of(format_kind(), format_kind::undef, format_kind::any)) @@ -451,20 +476,38 @@ inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs, const auto &blk = blocking_desc(); const auto &r_blk = rhs.blocking_desc(); + auto custom_cpm = use_weak_cmp ? array_cmp_weak : array_cmp; + auto cmp_strides = [&]() { + if (0xffffffffffffffff == stride_mask) { + return custom_cpm(blk.strides + ds, r_blk.strides + ds, ndims() - ds); + } else { + for (int i = 0; i < ndims(); ++i) { + if (stride_mask & (1 << i)) { + if (blk.strides[i] != r_blk.strides[i] + && IMPLICATION(use_weak_cmp, (blk.strides[i] != DNNL_RUNTIME_DIM_VAL && r_blk.strides[i] != DNNL_RUNTIME_DIM_VAL))) { + return false; + } + } + } + } + return true; + }; + return ndims() == rhs.ndims() && dim_start <= ndims() /* guard */ && format_kind() == rhs.format_kind() && IMPLICATION(with_data_type, data_type() == rhs.data_type()) - && array_cmp(dims() + ds, rhs.dims() + ds, ndims() - ds) - && array_cmp(blk.strides + ds, r_blk.strides + ds, ndims() - ds) + && custom_cpm(dims() + ds, rhs.dims() + ds, ndims() - ds) + && cmp_strides() && blk.inner_nblks == r_blk.inner_nblks && array_cmp(blk.inner_blks, r_blk.inner_blks, blk.inner_nblks) && array_cmp(blk.inner_idxs, r_blk.inner_idxs, blk.inner_nblks) && IMPLICATION(with_padding, true - && array_cmp(padded_dims() + ds, + && custom_cpm(padded_dims() + ds, rhs.padded_dims() + ds, ndims() - ds) - && array_cmp(padded_offsets() + ds, - rhs.padded_offsets() + ds, ndims() - ds)); + && custom_cpm(padded_offsets() + ds, + rhs.padded_offsets() + ds, ndims() - ds)) + && IMPLICATION(check_off0, (offset0() == DNNL_RUNTIME_DIM_VAL || rhs.offset0() == DNNL_RUNTIME_DIM_VAL || offset0() == rhs.offset0())); } inline bool memory_desc_wrapper::consistent_with( 
diff --git a/src/common/memory_tracking.hpp b/src/common/memory_tracking.hpp index 4f04c8aa89f..06568d169fb 100644 --- a/src/common/memory_tracking.hpp +++ b/src/common/memory_tracking.hpp @@ -272,6 +272,8 @@ enum { // even though they are not in alphabetical order key_nested, key_nested_multiple, + key_dw_conv_buffer, + key_dw_conv_padded_bias, }; enum { 
diff --git a/src/common/memory_zero_pad.cpp b/src/common/memory_zero_pad.cpp index fe059fad7ac..702ba7f9c4b 100644 --- a/src/common/memory_zero_pad.cpp +++ b/src/common/memory_zero_pad.cpp @@ -284,6 +284,7 @@ static status_t zero_pad(const memory_t *memory, const
exec_ctx_t &ctx) { case s32: return typed_zero_pad<s32>(memory, ctx); case s8: return typed_zero_pad<s8>(memory, ctx); case u8: return typed_zero_pad<u8>(memory, ctx); + case bin: return typed_zero_pad<bin>(memory, ctx); default: assert(!"memory is undefined"); return unimplemented; } return unimplemented; 
diff --git a/src/common/nstl.hpp b/src/common/nstl.hpp index a868f1efa64..142c5d027e4 100644 --- a/src/common/nstl.hpp +++ b/src/common/nstl.hpp @@ -247,6 +247,10 @@ class vector : public c_compatible { } void clear() { _impl.clear(); } void push_back(const T &t) { _impl.push_back(t); } + template <typename... Args> + void emplace_back(Args&&... args) { + _impl.emplace_back(std::forward<Args>(args)...); + } void resize(size_type count) { _impl.resize(count); } void reserve(size_type count) { _impl.reserve(count); } }; 
diff --git a/src/common/pooling.cpp b/src/common/pooling.cpp index 5bc28105239..6400881ada6 100644 --- a/src/common/pooling.cpp +++ b/src/common/pooling.cpp @@ -110,13 +110,13 @@ status_t pooling_desc_init(pooling_desc_type *pool_desc, prop_kind_t prop_kind, if ((src - ker_range + pad_l + pad_r) / str + 1 != dst) return invalid_arguments; - - // It's not allowed for pooling window to be totally placed outside - // of real source domain for pooling_avg_exclude_padding algorithm - // due to 0 / 0 ambiguity - if (alg_kind == pooling_avg_exclude_padding - && !(pad_l < ker_range && pad_r < ker_range && dil < src)) - return invalid_arguments; +// The check is disabled in order to support old behavior +// // It's not allowed for pooling window to be totally placed outside +// // of real source domain for pooling_avg_exclude_padding algorithm +// // due to 0 / 0 ambiguity +// if (alg_kind == pooling_avg_exclude_padding +// && !(pad_l < ker_range && pad_r < ker_range && dil < src)) +// return invalid_arguments; } *pool_desc = pd; 
diff --git a/src/common/pooling_pd.hpp b/src/common/pooling_pd.hpp index b3b0880f4a1..56bf98dd362 100644 --- a/src/common/pooling_pd.hpp +++ b/src/common/pooling_pd.hpp @@ -234,7 +234,7 @@ struct pooling_fwd_pd_t : public pooling_pd_t { : &glob_zero_md; } - int n_inputs() const override { return 1 + n_binary_po_inputs(); } + int n_inputs() const override { return 1 + n_binary_po_inputs() + n_depthwise_po_inputs() + n_quantization_po_inputs(); } int n_outputs() const override { return 1 + (!types::is_zero_md(workspace_md())); } 
diff --git a/src/common/primitive.hpp b/src/common/primitive.hpp index 88503e6d044..b94701d8aae 100644 --- a/src/common/primitive.hpp +++ b/src/common/primitive.hpp @@ -248,6 +248,9 @@ status_t primitive_execute( #define CTX_OUT_CLEAN_MEM(type, arg, status) \ static_cast<type>(ctx.host_ptr(arg, true, &status)) +#define CTX_IN_BATCH(arg) \ + ctx.input(arg) ? ctx.input(arg)->md()->ndims != 0 ? ctx.input(arg)->md()->dims[0] : 0 : 0 +
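The nested ternary in CTX_IN_BATCH is easier to read unrolled; a rough functional equivalent (assuming, as the macro implies, that exec_ctx_t::input() returns a null-able memory_t pointer):

// Rough equivalent of CTX_IN_BATCH(arg): the batch dimension (dims[0]) of
// an optional input argument, or 0 when the argument is missing or has a
// zero-dimensional descriptor.
inline dnnl::impl::dim_t ctx_in_batch(
        const dnnl::impl::exec_ctx_t &ctx, int arg) {
    const auto *mem = ctx.input(arg);
    if (!mem) return 0;
    return mem->md()->ndims != 0 ? mem->md()->dims[0] : 0;
}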
// dnnl_primitive is a user facing entity that has an alias primitive_iface_t // for internal use. // The primitive_iface_t is responsible for holding: 
diff --git a/src/common/primitive_attr.cpp b/src/common/primitive_attr.cpp index b43dd5af397..30e3dc3f931 100644 --- a/src/common/primitive_attr.cpp +++ b/src/common/primitive_attr.cpp @@ -56,6 +56,28 @@ status_t scales_t::set(dim_t count, int mask, const float *scales) { return status::success; } +template <typename T> +status_t shifts_t<T>::set(int count, int mask, const T *shifts) { + cleanup(); + + count_ = count; + mask_ = mask; + + if (count_ == 1) { + shifts_ = shifts_buf_; + utils::array_set(shifts_, shifts[0], shifts_buf_size); + } else { + shifts_ = (T *)impl::malloc(count_ * sizeof(*shifts_), 64); + if (shifts_ == nullptr) + return status::out_of_memory; + + for (int c = 0; c < count_; ++c) + shifts_[c] = shifts[c]; + } + + return status::success; +} + status_t arg_scales_t::set( int arg, dim_t count, int mask, const float *scales) { if (!check_arg(arg)) return status::invalid_arguments; @@ -135,6 +157,9 @@ bool primitive_attr_t::has_default_values(dnnl_primitive_attr::skip_mask_t mask, CHECK_MASK(smask_t::oscale, output_scales_); CHECK_MASK(smask_t::scales, scales_); CHECK_MASK(smask_t::zero_points, zero_points_); + CHECK_MASK(smask_t::input_zero_points, input_zero_points_); + CHECK_MASK(smask_t::weights_zero_points, weights_zero_points_); + CHECK_MASK(smask_t::output_compensations, output_compensations_); CHECK_MASK(smask_t::post_ops, post_ops_); CHECK_MASK(smask_t::rnn_data_qparams, rnn_data_qparams_); CHECK_MASK(smask_t::rnn_weights_qparams, rnn_weights_qparams_); @@ -260,7 +285,7 @@ status_t post_ops_t::append_binary( using namespace alg_kind; bool alg_ok = one_of(alg, binary_add, binary_mul, binary_max, binary_min, binary_div, binary_sub, binary_ge, binary_gt, binary_le, binary_lt, - binary_eq, binary_ne); + binary_eq, binary_ne, binary_prelu); if (!alg_ok) return invalid_arguments; if (!memory_desc_sanity_check(user_src1_desc)) return invalid_arguments; @@ -289,6 +314,79 @@ status_t post_ops_t::append_prelu(int mask) { return success; } +status_t post_ops_t::append_depthwise(alg_kind_t alg, size_t offset_size, const size_t* offset) { + using namespace dnnl::impl::alg_kind; + if (len() == post_ops_limit) return out_of_memory; + bool known_alg = one_of(alg, depthwise_scale_shift, depthwise_prelu); + if (!known_alg) + return invalid_arguments; + + entry_.emplace_back(); + auto &e = entry_.back(); + e.kind = primitive_kind::depthwise; + e.depthwise.alg = alg; + array_copy(e.depthwise.offset, offset, offset_size); + + return success; +} + +status_t post_ops_t::append_quantization(alg_kind_t alg, + size_t per_channel_size, const bool* per_channel, + size_t all_default_size, const bool* all_default, + size_t offset_size, const size_t* offset) { + using namespace dnnl::impl::alg_kind; + if (len() == post_ops_limit) return out_of_memory; + bool known_alg = one_of(alg, quantization_quantize_dequantize, quantization_quantize); + if (!known_alg) + return invalid_arguments; + + entry_.emplace_back(); + auto &e = entry_.back(); + e.kind = primitive_kind::quantization; + e.quantization.alg = alg; + + array_copy(e.quantization.per_channel, per_channel, per_channel_size); + array_copy(e.quantization.all_default, all_default, all_default_size); + array_copy(e.quantization.offset, offset, offset_size); + + return success; +} + +status_t post_ops_t::append_binarization(alg_kind_t alg, const float* weights_data, const float* output_mask_data) { + using namespace dnnl::impl::alg_kind; + if (len() == post_ops_limit) return out_of_memory; + bool known_alg =
one_of(alg, binarization_depthwise); + if (!known_alg) + return invalid_arguments; + + entry_.emplace_back(); + auto &e = entry_.back(); + e.kind = primitive_kind::binarization; + e.binarization.alg = alg; + e.binarization.weights_data = weights_data; + e.binarization.output_mask_data = output_mask_data; + + return success; +} + +status_t post_ops_t::append_dw_conv(int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w, + dnnl::impl::data_type_t in_dt) { + if (len() == post_ops_limit) return out_of_memory; + + entry_.emplace_back(); + auto &e = entry_.back(); + e.kind = primitive_kind::convolution; + e.depthwise_conv_old.in_h = in_h; + e.depthwise_conv_old.in_w = in_w; + e.depthwise_conv_old.ker_h = ker_h; + e.depthwise_conv_old.ker_w = ker_w; + e.depthwise_conv_old.str_h = str_h; + e.depthwise_conv_old.str_w = str_w; + e.depthwise_conv_old.in_dt = in_dt; + + return success; +} + bool post_ops_t::defined() const { for (int idx = 0; idx < len(); ++idx) { auto kind = entry_[idx].kind; @@ -300,11 +398,16 @@ bool post_ops_t::defined() const { || is_runtime_value(e.beta)) return false; } else if (kind == primitive_kind::convolution) { - const auto &c = entry_[idx].depthwise_conv; - if (c.scales && is_runtime_value(*(c.scales))) return false; + // convolution is always defined } else if (utils::one_of(kind, primitive_kind::binary, primitive_kind::prelu)) { // binary is always defined + } else if (kind == primitive_kind::depthwise) { + // depthwise is always defined + } else if (kind == primitive_kind::quantization) { + // quantization is always defined + } else if (kind == primitive_kind::binarization) { + // binarization is always defined } else { assert(!"unreachable"); } @@ -363,7 +466,7 @@ status_t primitive_attr_t::set_scratchpad_mode( scratchpad_mode_t scratchpad_mode) { using namespace dnnl::impl::scratchpad_mode; - const bool ok = one_of(scratchpad_mode, library, user); + const bool ok = one_of(scratchpad_mode, scratchpad_mode::library, scratchpad_mode::user); if (!ok) return invalid_arguments; scratchpad_mode_ = scratchpad_mode; @@ -484,6 +587,33 @@ status_t dnnl_primitive_attr_set_zero_points(primitive_attr_t *attr, int arg, return attr->zero_points_.set(arg, count, mask, zero_points); } +status_t dnnl_primitive_attr_set_output_compensations(primitive_attr_t *attr, + int count, int mask) { + bool ok = !any_null(attr) && count > 0 && mask >= 0; + if (!ok) + return invalid_arguments; + + return attr->output_compensations_.set(count, mask); +} + +status_t dnnl_primitive_attr_set_input_zero_points(primitive_attr_t *attr, + int count, int mask) { + bool ok = !any_null(attr) && count > 0 && mask >= 0; + if (!ok) + return invalid_arguments; + + return attr->input_zero_points_.set(count, mask); +} + +status_t dnnl_primitive_attr_set_weights_zero_points(primitive_attr_t *attr, + int count, int mask) { + bool ok = !any_null(attr) && count > 0 && mask >= 0; + if (!ok) + return invalid_arguments; + + return attr->weights_zero_points_.set(count, mask); +} + status_t dnnl_primitive_attr_get_post_ops( const primitive_attr_t *attr, const post_ops_t **post_ops) { if (any_null(attr, post_ops)) return invalid_arguments; @@ -702,6 +832,45 @@ status_t dnnl_post_ops_get_params_prelu( return success; } +status_t dnnl_post_ops_append_depthwise(dnnl_post_ops_t post_ops, dnnl_alg_kind_t alg, size_t offset_size, const size_t* offset) { + if (post_ops == nullptr || offset == nullptr) return invalid_arguments; + + if (offset_size != 2) + return invalid_arguments; + + return post_ops->append_depthwise(alg, 
offset_size, offset); +} + +status_t dnnl_post_ops_append_quantization(post_ops_t *post_ops, alg_kind_t kind, + size_t per_channel_size, const bool* per_channel, + size_t all_default_size, const bool* all_default, + size_t offset_size, const size_t* offset) { + if (post_ops == nullptr || per_channel == nullptr || all_default == nullptr || offset == nullptr) + return invalid_arguments; + + if (per_channel_size != all_default_size || all_default_size != offset_size || offset_size != 6) + return invalid_arguments; + + return post_ops->append_quantization(kind, per_channel_size, per_channel, all_default_size, all_default, offset_size, offset); +} + +status_t dnnl_post_ops_append_binarization(post_ops_t *post_ops, alg_kind_t kind, const float* weights_data, + const float* output_mask_data) { + if (post_ops == nullptr) + return invalid_arguments; + + return post_ops->append_binarization(kind, weights_data, output_mask_data); +} + +status_t dnnl_post_ops_append_dw_conv(post_ops_t *post_ops, + int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w, + dnnl::impl::data_type_t in_dt) { + if (post_ops == nullptr) + return invalid_arguments; + + return post_ops->append_dw_conv(in_h, in_w, ker_h, ker_w, str_h, str_w, in_dt); +} + status_t dnnl_primitive_attr_set_rnn_data_qparams( primitive_attr_t *attr, const float scale, const float shift) { if (attr == nullptr) return invalid_arguments; @@ -769,3 +938,7 @@ status_t DNNL_API dnnl_primitive_attr_set_rnn_tparams( return attr->rnn_tparams_.set(mode, ngates, scales, cscale); } + +template struct dnnl::impl::shifts_t<uint8_t>; +template struct dnnl::impl::shifts_t<int32_t>; +template struct dnnl::impl::shifts_t<float>; \ No newline at end of file 
diff --git a/src/common/primitive_attr.hpp b/src/common/primitive_attr.hpp index 86b4b268ceb..75171f1b991 100644 --- a/src/common/primitive_attr.hpp +++ b/src/common/primitive_attr.hpp @@ -175,6 +175,58 @@ struct scales_t : public c_compatible { DNNL_DISALLOW_COPY_AND_ASSIGN(scales_t); }; +template <typename T> +struct shifts_t: public c_compatible { + shifts_t(): count_(1), mask_(0), shifts_(shifts_buf_) + { set(0); } + + ~shifts_t() { cleanup(); } + + bool operator==(const shifts_t &rhs) const { + bool ret = count_ == rhs.count_ && mask_ == rhs.mask_ + && !utils::any_null(shifts_, rhs.shifts_) + && defined() == rhs.defined() + && IMPLICATION(defined(), + utils::array_cmp(shifts_, rhs.shifts_, count_)); + return ret; + }
+ bool has_default_values() const { + return count_ == 0 && mask_ == 0; + } + + status_t set(dim_t count, int mask) { + count_ = count; + mask_ = mask; + + return status::success; + } + + dim_t count_ = 0; + int mask_ = 0; +}; + } // namespace impl } // namespace dnnl @@ -379,6 +451,52 @@ struct dnnl_post_ops : public dnnl::impl::c_compatible { int mask; }; + struct depthwise_t { + enum depthwise_fields { + scales, + shifts, + + fields_count + }; + + dnnl::impl::alg_kind_t alg; + size_t offset[fields_count]; + }; + + struct quantization_t { + enum quantization_fields { + crop_low, + crop_high, + inp_scale, + inp_shift, + output_scale, + output_shift, + + fields_count + }; + + dnnl::impl::alg_kind_t alg; + bool per_channel[fields_count]; + bool all_default[fields_count]; + size_t offset[fields_count]; + }; + + struct binarization_t { + dnnl::impl::alg_kind_t alg; + const float* weights_data; + const float* output_mask_data; + }; + + struct depthwise_conv_old_t { + int in_h; + int in_w; + int ker_h; + int ker_w; + int str_h; + int str_w; + dnnl::impl::data_type_t in_dt; + }; + dnnl::impl::primitive_kind_t kind = dnnl::impl::primitive_kind::undefined; union { @@ -389,8 +507,12 @@ struct dnnl_post_ops : public dnnl::impl::c_compatible { } sum; eltwise_t eltwise; depthwise_conv_t depthwise_conv; + depthwise_conv_old_t depthwise_conv_old; binary_t binary; prelu_t prelu; + depthwise_t depthwise; + quantization_t quantization; + binarization_t binarization; }; bool is_eltwise(bool require_scale_one = false) const { @@ -428,6 +550,21 @@ struct dnnl_post_ops : public dnnl::impl::c_compatible { return kind == dnnl::impl::primitive_kind::prelu; } + bool is_depthwise() const { + using namespace dnnl::impl; + return kind == primitive_kind::depthwise; + } + + bool is_quantization() const { + using namespace dnnl::impl; + return kind == primitive_kind::quantization; + } + + bool is_binarization() const { + using namespace dnnl::impl; + return kind == primitive_kind::binarization; + } + dnnl::impl::status_t set_depthwise_scales(const float *scales); bool operator==(const entry_t &rhs) const { @@ -448,25 +585,33 @@ struct dnnl_post_ops : public dnnl::impl::c_compatible { && sum.dt == rhs.sum.dt; break; case primitive_kind::convolution: - // Depthwise Only - ret = depthwise_conv.stride == rhs.depthwise_conv.stride - && depthwise_conv.wei_dt - == rhs.depthwise_conv.wei_dt - && depthwise_conv.bias_dt - == rhs.depthwise_conv.bias_dt - && depthwise_conv.dst_dt - == rhs.depthwise_conv.dst_dt - && depthwise_conv.count == rhs.depthwise_conv.count - && depthwise_conv.mask == rhs.depthwise_conv.mask; - if (!ret) break; - - // only call memcmp with valid pointers - if (depthwise_conv.count == 0) break; - ret = !utils::any_null(depthwise_conv.scales, - rhs.depthwise_conv.scales) - && !std::memcmp(depthwise_conv.scales, - rhs.depthwise_conv.scales, - sizeof(float) * depthwise_conv.count); + // todo: [antonvor] uncomment when new behavior of dw convolution fusing from oneDNN 1.6 will be supported +// // Depthwise Only +// ret = depthwise_conv.stride == rhs.depthwise_conv.stride +// && depthwise_conv.wei_dt +// == rhs.depthwise_conv.wei_dt +// && depthwise_conv.bias_dt +// == rhs.depthwise_conv.bias_dt +// && depthwise_conv.dst_dt +// == rhs.depthwise_conv.dst_dt +// && depthwise_conv.count == rhs.depthwise_conv.count +// && depthwise_conv.mask == rhs.depthwise_conv.mask; +// if (!ret) break; +// +// // only call memcmp with valid pointers +// if (depthwise_conv.count == 0) break; +// ret = 
!utils::any_null(depthwise_conv.scales, +// rhs.depthwise_conv.scales) +// && !std::memcmp(depthwise_conv.scales, +// rhs.depthwise_conv.scales, +// sizeof(float) * depthwise_conv.count); + ret = depthwise_conv_old.in_h == rhs.depthwise_conv_old.in_h + && depthwise_conv_old.in_w == rhs.depthwise_conv_old.in_w + && depthwise_conv_old.ker_h == rhs.depthwise_conv_old.ker_h + && depthwise_conv_old.ker_w == rhs.depthwise_conv_old.ker_w + && depthwise_conv_old.str_h == rhs.depthwise_conv_old.str_h + && depthwise_conv_old.str_w == rhs.depthwise_conv_old.str_w + && depthwise_conv_old.in_dt == rhs.depthwise_conv_old.in_dt; break; case primitive_kind::binary: ret = binary.alg == rhs.binary.alg @@ -476,6 +621,21 @@ struct dnnl_post_ops : public dnnl::impl::c_compatible { case primitive_kind::prelu: ret = prelu.mask == rhs.prelu.mask; break; + case primitive_kind::depthwise: + ret = depthwise.alg == rhs.depthwise.alg + && array_cmp(depthwise.offset, rhs.depthwise.offset, depthwise.fields_count); + break; + case primitive_kind::quantization: + ret = quantization.alg == rhs.quantization.alg + && array_cmp(quantization.per_channel, rhs.quantization.per_channel, quantization.fields_count) + && array_cmp(quantization.all_default, rhs.quantization.all_default, quantization.fields_count) + && array_cmp(quantization.offset, rhs.quantization.offset, quantization.fields_count); + break; + case primitive_kind::binarization: + ret = binarization.alg == rhs.binarization.alg + && binarization.weights_data == rhs.binarization.weights_data + && binarization.output_mask_data == rhs.binarization.output_mask_data; + break; default: assert(!"unsupported post_op"); } return ret; @@ -489,9 +649,10 @@ struct dnnl_post_ops : public dnnl::impl::c_compatible { private: void clear() { - if (is_convolution() && depthwise_conv.count - && depthwise_conv.scales) - dnnl::impl::free(depthwise_conv.scales); + // todo: [antonvor] uncomment when new behavior of dw convolution fusing from oneDNN 1.6 will be supported +// if (is_convolution() && depthwise_conv.count +// && depthwise_conv.scales) +// dnnl::impl::free(depthwise_conv.scales); depthwise_conv.scales = nullptr; return; } @@ -502,9 +663,10 @@ struct dnnl_post_ops : public dnnl::impl::c_compatible { // else if(is_relu()) {} seems to be unreliable. memcpying for now.
dnnl::impl::utils::array_copy( (char *)this, (char *)&other, sizeof(*this)); - if (other.is_convolution()) { - return set_depthwise_scales(other.depthwise_conv.scales); - } + // todo: [antonvor] uncomment when new behavior of dw convolution fusing from oneDNN 1.6 will be supported +// if (other.is_convolution()) { +// return set_depthwise_scales(other.depthwise_conv.scales); +// } return dnnl::impl::status::success; } }; @@ -524,6 +686,15 @@ struct dnnl_post_ops : public dnnl::impl::c_compatible { dnnl::impl::status_t append_binary(dnnl::impl::alg_kind_t alg, const dnnl::impl::memory_desc_t *user_src1_desc); dnnl::impl::status_t append_prelu(int mask); + dnnl::impl::status_t append_depthwise(dnnl::impl::alg_kind_t alg, size_t offset_size, const size_t* offset); + dnnl::impl::status_t append_quantization(dnnl::impl::alg_kind_t alg, + size_t per_channel_size, const bool* per_channel, + size_t all_default_size, const bool* all_default, + size_t offset_size, const size_t* offset); + dnnl::impl::status_t append_binarization(dnnl::impl::alg_kind_t alg, const float* weights_data, + const float* output_mask_data); + dnnl::impl::status_t append_dw_conv(int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w, + dnnl::impl::data_type_t in_dt); int find(dnnl::impl::primitive_kind_t kind, int start = 0, int stop = -1) const { @@ -543,6 +714,16 @@ struct dnnl_post_ops : public dnnl::impl::c_compatible { return dst_dt; } + int count(dnnl::impl::primitive_kind_t kind, int start = 0, + int stop = -1) const { + if (stop == -1) stop = len(); + stop = dnnl::impl::nstl::min(stop, len()); + int cnt = 0; + for (int idx = start; idx < stop; ++idx) + if (entry_[idx].kind == kind) cnt++; + return cnt; + } + bool defined() const; int len() const { return (int)entry_.size(); } bool has_default_values() const { return len() == 0; } @@ -621,6 +802,9 @@ struct dnnl_primitive_attr : public dnnl::impl::c_compatible { CHECK(rnn_weights_projection_qparams_.copy_from( other.rnn_weights_projection_qparams_)); CHECK(rnn_tparams_.copy_from(other.rnn_tparams_)); + input_zero_points_ = (other.input_zero_points_); + weights_zero_points_ = (other.weights_zero_points_); + output_compensations_ = (other.output_compensations_); return status::success; } @@ -640,7 +824,10 @@ struct dnnl_primitive_attr : public dnnl::impl::c_compatible { rnn_weights_qparams = 1u << 8, rnn_tparams = 1u << 9, sum_dt = 1u << 10, - rnn_weights_projection_qparams = 1u << 11 + rnn_weights_projection_qparams = 1u << 11, + input_zero_points = 1 << 12, + weights_zero_points = 1 << 13, + output_compensations = 1 << 14 }; /** Returns true if the attributes have default values. 
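Pulling the pieces together: the new quantization post-op carries six fields (crop_low, crop_high, inp_scale, inp_shift, output_scale, output_shift), and the C-API validation shown earlier requires all three descriptor arrays to have exactly six entries. A hypothetical call appending one such post-op (the concrete offset values and their interpretation as indices into a user-managed flat parameter buffer are illustrative assumptions):

#include "oneapi/dnnl/dnnl.h"

// Hypothetical sketch: append a quantize/dequantize post-op. Flags and
// offsets follow the quantization_fields order: crop_low, crop_high,
// inp_scale, inp_shift, output_scale, output_shift.
void append_quant_post_op(dnnl_post_ops_t po) {
    const bool per_channel[6] = {true, true, true, true, true, true};
    const bool all_default[6] = {false, false, false, false, false, false};
    const size_t offset[6] = {0, 16, 32, 48, 64, 80}; // illustrative only

    dnnl_post_ops_append_quantization(po,
            dnnl_quantization_quantize_dequantize, 6, per_channel, 6,
            all_default, 6, offset);
}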
@@ -662,7 +849,10 @@ struct dnnl_primitive_attr : public dnnl::impl::c_compatible { && rnn_weights_qparams_ == rhs.rnn_weights_qparams_ && rnn_weights_projection_qparams_ == rhs.rnn_weights_projection_qparams_ - && rnn_tparams_ == rhs.rnn_tparams_; + && rnn_tparams_ == rhs.rnn_tparams_ + && input_zero_points_ == rhs.input_zero_points_ + && weights_zero_points_ == rhs.weights_zero_points_ + && output_compensations_ == rhs.output_compensations_; return ret; } @@ -705,6 +895,10 @@ struct dnnl_primitive_attr : public dnnl::impl::c_compatible { dnnl::impl::scales_t rnn_weights_projection_qparams_; dnnl::impl::rnn_tparams_t rnn_tparams_; + dnnl::impl::legacy_zero_points_t input_zero_points_; + dnnl::impl::legacy_zero_points_t weights_zero_points_; + dnnl::impl::legacy_zero_points_t output_compensations_; + dnnl_primitive_attr &operator=(const dnnl_primitive_attr &other) = delete; }; 
diff --git a/src/common/primitive_desc.cpp b/src/common/primitive_desc.cpp index 8dfbfb4d0d8..6fcc3436910 100644 --- a/src/common/primitive_desc.cpp +++ b/src/common/primitive_desc.cpp @@ -43,6 +43,12 @@ int primitive_desc_t::n_binary_po_inputs() const { int primitive_desc_t::n_prelu_po_inputs() const { return po_inputs(attr()->post_ops_, primitive_kind::prelu); } +int primitive_desc_t::n_depthwise_po_inputs() const { + return po_inputs(attr()->post_ops_, primitive_kind::depthwise); +} +int primitive_desc_t::n_quantization_po_inputs() const { + return po_inputs(attr()->post_ops_, primitive_kind::quantization); +} status_t dnnl_primitive_desc::create_primitive_iface( std::pair<primitive_iface_t *, bool> &primitive_iface) const { 
diff --git a/src/common/primitive_desc.hpp b/src/common/primitive_desc.hpp index 138caf50459..92606eb1ebc 100644 --- a/src/common/primitive_desc.hpp +++ b/src/common/primitive_desc.hpp @@ -81,6 +81,15 @@ struct primitive_desc_t : public c_compatible { if ((arg & DNNL_ARG_ATTR_ZERO_POINTS) && !attr()->zero_points_.defined(arg)) return arg_usage_t::input; + if ((arg & (DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC)) + && !attr()->input_zero_points_.has_default_values()) + return arg_usage_t::input; + if ((arg & (DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS)) + && !attr()->weights_zero_points_.has_default_values()) + return arg_usage_t::input; + if ((arg & (DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST)) + && !attr()->output_compensations_.has_default_values()) + return arg_usage_t::input; if ((arg == (DNNL_ARG_ATTR_INPUT_SCALES | DNNL_ARG_SRC_0)) && !attr()->scales_.get(DNNL_ARG_SRC_0).defined()) return arg_usage_t::input; @@ -91,12 +100,16 @@ struct primitive_desc_t : public c_compatible { return arg_usage_t::output; for (int idx = 0; idx < attr()->post_ops_.len(); ++idx) { using namespace primitive_kind; - if (post_op_has_proper_input( - attr(), binary, idx, arg, DNNL_ARG_SRC_1) - || post_op_has_proper_input( - attr(), prelu, idx, arg, DNNL_ARG_WEIGHTS)) + if (post_op_has_proper_input(attr(), binary, idx, arg, DNNL_ARG_SRC_1) || + post_op_has_proper_input(attr(), depthwise, idx, arg, DNNL_ARG_SRC_1) || + post_op_has_proper_input(attr(), quantization, idx, arg, DNNL_ARG_SRC_1) || + post_op_has_proper_input(attr(), prelu, idx, arg, DNNL_ARG_WEIGHTS)) return arg_usage_t::input; } + if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS)) + return arg_usage_t::input; + if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS)) + return arg_usage_t::input; return arg_usage_t::unused; } @@ -210,6 +223,8 @@ struct primitive_desc_t : public c_compatible { virtual int n_outputs() const { return 0; } int n_binary_po_inputs() const; int
n_prelu_po_inputs() const; + int n_depthwise_po_inputs() const; + int n_quantization_po_inputs() const; // The `hint_mds(bool is_hint)` returns a vector of memory descriptors // that might affect the equality of primitive descriptors for backward pass. // 
diff --git a/src/common/primitive_hashing.cpp b/src/common/primitive_hashing.cpp index 5220e885ea0..6f4cbdc1c84 100644 --- a/src/common/primitive_hashing.cpp +++ b/src/common/primitive_hashing.cpp @@ -119,185 +119,6 @@ bool key_t::operator==(const key_t &rhs) const { return true; } -// Combine hash of each memory_desc_t data member -size_t get_md_hash(const memory_desc_t &md) { - size_t seed = 0; - seed = get_array_hash(seed, md.dims, md.ndims); - seed = hash_combine(seed, static_cast<size_t>(md.data_type)); - seed = get_array_hash(seed, md.padded_dims, md.ndims); - seed = get_array_hash(seed, md.padded_offsets, md.ndims); - seed = hash_combine(seed, md.offset0); - seed = hash_combine(seed, static_cast<size_t>(md.format_kind)); - // format desc - switch (md.format_kind) { - case format_kind::undef: - case format_kind::any: break; - case format_kind::blocked: - for (int i = 0; i < md.ndims; i++) { - if (md.dims[i] == 1 && md.padded_dims[i] == 1) continue; - seed = hash_combine(seed, md.format_desc.blocking.strides[i]); - } - seed = hash_combine(seed, md.format_desc.blocking.inner_nblks); - seed = get_array_hash(seed, md.format_desc.blocking.inner_blks, - md.format_desc.blocking.inner_nblks); - seed = get_array_hash(seed, md.format_desc.blocking.inner_idxs, - md.format_desc.blocking.inner_nblks); - break; - case format_kind::wino: - seed = hash_combine(seed, - static_cast<size_t>(md.format_desc.wino_desc.wino_format)); - seed = hash_combine(seed, md.format_desc.wino_desc.r); - seed = hash_combine(seed, md.format_desc.wino_desc.alpha); - seed = hash_combine(seed, md.format_desc.wino_desc.ic); - seed = hash_combine(seed, md.format_desc.wino_desc.oc); - seed = hash_combine(seed, md.format_desc.wino_desc.ic_block); - seed = hash_combine(seed, md.format_desc.wino_desc.oc_block); - seed = hash_combine(seed, md.format_desc.wino_desc.ic2_block); - seed = hash_combine(seed, md.format_desc.wino_desc.oc2_block); - seed = hash_combine(seed, md.format_desc.wino_desc.adj_scale); - seed = hash_combine(seed, md.format_desc.wino_desc.size); - break; - case format_kind::rnn_packed: - seed = hash_combine(seed, - static_cast<size_t>(md.format_desc.rnn_packed_desc.format)); - seed = hash_combine(seed, md.format_desc.rnn_packed_desc.n_parts); - seed = hash_combine(seed, md.format_desc.rnn_packed_desc.n); - seed = hash_combine(seed, md.format_desc.rnn_packed_desc.ldb); - { - int n_parts = md.format_desc.rnn_packed_desc.n_parts; - seed = get_array_hash( - seed, md.format_desc.rnn_packed_desc.parts, n_parts); - seed = get_array_hash(seed, - md.format_desc.rnn_packed_desc.part_pack_size, n_parts); - seed = get_array_hash(seed, - md.format_desc.rnn_packed_desc.pack_part, n_parts); - } - seed = hash_combine( - seed, md.format_desc.rnn_packed_desc.offset_compensation); - seed = hash_combine(seed, md.format_desc.rnn_packed_desc.size); - break; - default: assert(!"unknown format_kind"); - } - - if (md.extra.flags != dnnl_memory_extra_flag_none) { - seed = hash_combine(seed, md.extra.flags); - if (md.extra.flags - & (dnnl_memory_extra_flag_compensation_conv_s8s8 - | dnnl_memory_extra_flag_rnn_u8s8_compensation)) { - seed = hash_combine(seed, md.extra.compensation_mask); - } - - if (md.extra.flags & dnnl_memory_extra_flag_scale_adjust) { - seed = hash_combine(seed, md.extra.scale_adjust); - } - - if
(md.extra.flags - & dnnl_memory_extra_flag_compensation_conv_asymmetric_src) { - seed = hash_combine(seed, md.extra.asymm_compensation_mask); - } - } - // Combined hash for a memory descriptor - return seed; -} - -// Combine hash of each primitive_attr_t data member -size_t get_attr_hash(const primitive_attr_t &attr) { - size_t seed = 0; - // scratchpad_mode - seed = hash_combine(seed, static_cast(attr.scratchpad_mode_)); - // fpmath_mode - seed = hash_combine(seed, static_cast(attr.fpmath_mode_)); - - if (!attr.output_scales_.has_default_values()) { - // output_scales: mask - seed = hash_combine(seed, attr.output_scales_.mask_); - // output_scales: count - seed = hash_combine(seed, attr.output_scales_.count_); - // output_scales: scales[:] - seed = get_array_hash( - seed, attr.output_scales_.scales_, attr.output_scales_.count_); - } else if (!attr.scales_.has_default_values()) { - // go through scales for all arguments - for (const auto &p : attr.scales_.scales_) { - seed = hash_combine(seed, p.second.mask_); - seed = hash_combine(seed, p.second.count_); - seed = get_array_hash(seed, p.second.scales_, p.second.count_); - } - } - // zero_points - for (int arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) - if (!attr.zero_points_.has_default_values(arg)) { - dim_t count = 0; - int mask = 0; - const int *zero_points = nullptr; - attr.zero_points_.get(arg, &count, &mask, &zero_points); - // zero_points: count - seed = hash_combine(seed, count); - // zero_points: mask - seed = hash_combine(seed, mask); - // zero_points: zero_points[:] - seed = get_array_hash(seed, zero_points, count); - } - // post_ops: entry[:] - for (int i = 0; i < attr.post_ops_.len(); i++) { - const auto &entry = attr.post_ops_.entry_[i]; - switch (entry.kind) { - case primitive_kind::eltwise: - seed = hash_combine( - seed, static_cast(entry.eltwise.alg)); - seed = hash_combine(seed, entry.eltwise.scale); - seed = hash_combine(seed, entry.eltwise.alpha); - seed = hash_combine(seed, entry.eltwise.beta); - break; - case primitive_kind::sum: - seed = hash_combine(seed, entry.sum.scale); - seed = hash_combine(seed, static_cast(entry.sum.dt)); - break; - case primitive_kind::convolution: - seed = hash_combine( - seed, static_cast(entry.depthwise_conv.stride)); - seed = hash_combine( - seed, static_cast(entry.depthwise_conv.wei_dt)); - seed = hash_combine(seed, - static_cast(entry.depthwise_conv.bias_dt)); - seed = hash_combine( - seed, static_cast(entry.depthwise_conv.dst_dt)); - if (entry.depthwise_conv.scales) { - seed = hash_combine(seed, entry.depthwise_conv.mask); - seed = hash_combine(seed, entry.depthwise_conv.count); - seed = get_array_hash(seed, entry.depthwise_conv.scales, - entry.depthwise_conv.count); - } - break; - case primitive_kind::binary: - seed = hash_combine( - seed, static_cast(entry.binary.alg)); - seed = hash_combine( - seed, get_md_hash(entry.binary.user_src1_desc)); - break; - case primitive_kind::prelu: - seed = hash_combine( - seed, static_cast(entry.prelu.mask)); - break; - default: assert(!"unknown post_op"); - } - } - // rnn_data_qparams: scale, shift - seed = hash_combine(seed, attr.rnn_data_qparams_.scale_); - seed = hash_combine(seed, attr.rnn_data_qparams_.shift_); - if (!attr.rnn_weights_qparams_.has_default_values()) { - // rnn_weights_qparams: mask - seed = hash_combine(seed, attr.rnn_weights_qparams_.mask_); - // rnn_weights_qparams: count - seed = hash_combine(seed, attr.rnn_weights_qparams_.count_); - // rnn_weights_qparams: scales[:] - seed = get_array_hash(seed, 
attr.rnn_weights_qparams_.scales_, - attr.rnn_weights_qparams_.count_); - } - // Combined hash for attributes - return seed; -} - // Functions that compute hash for different op_descs size_t get_desc_hash(const concat_desc_t &desc) { size_t seed = 0; diff --git a/src/common/primitive_hashing.hpp b/src/common/primitive_hashing.hpp index d26e26bb49f..1f6707d0d2d 100644 --- a/src/common/primitive_hashing.hpp +++ b/src/common/primitive_hashing.hpp @@ -20,10 +20,7 @@ #include #include -#include "c_types_map.hpp" -#include "oneapi/dnnl/dnnl.h" -#include "primitive_attr.hpp" -#include "type_helpers.hpp" +#include "primitive_hashing_utils.hpp" #ifdef DNNL_USE_RT_OBJECTS_IN_PRIMITIVE_CACHE #include "engine_id.hpp" @@ -75,8 +72,6 @@ struct key_t { std::thread::id thread_id_; }; -size_t get_md_hash(const memory_desc_t &md); -size_t get_attr_hash(const primitive_attr_t &attr); size_t get_desc_hash(const concat_desc_t &desc); size_t get_desc_hash(const batch_normalization_desc_t &desc); size_t get_desc_hash(const binary_desc_t &desc); @@ -99,31 +94,6 @@ size_t get_desc_hash(const softmax_desc_t &desc); size_t get_desc_hash(const sum_desc_t &desc); size_t get_desc_hash(const zero_pad_desc_t &desc); -template -size_t get_array_hash(size_t seed, const T *v, int size) { - for (int i = 0; i < size; i++) { - seed = hash_combine(seed, v[i]); - } - return seed; -} - -template <> -inline size_t get_array_hash( - size_t seed, const memory_desc_t *v, int size) { - for (int i = 0; i < size; i++) { - seed = hash_combine(seed, get_md_hash(v[i])); - } - return seed; -} - -template <> -inline size_t get_array_hash(size_t seed, const float *v, int size) { - for (int i = 0; i < size; i++) { - seed = hash_combine(seed, float2int(v[i])); - } - return seed; -} - } // namespace primitive_hashing } // namespace impl } // namespace dnnl diff --git a/src/common/primitive_hashing_utils.cpp b/src/common/primitive_hashing_utils.cpp new file mode 100644 index 00000000000..59c4b6482ff --- /dev/null +++ b/src/common/primitive_hashing_utils.cpp @@ -0,0 +1,215 @@ +/******************************************************************************* +* Copyright 2019-2021 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "utils.hpp" +#include "primitive_hashing_utils.hpp" + +namespace dnnl { +namespace impl { +namespace primitive_hashing { + +// Combine hash of each memory_desc_t data member +size_t get_md_hash(const memory_desc_t &md) { + size_t seed = 0; + seed = get_array_hash(seed, md.dims, md.ndims); + seed = hash_combine(seed, static_cast(md.data_type)); + seed = get_array_hash(seed, md.padded_dims, md.ndims); + seed = get_array_hash(seed, md.padded_offsets, md.ndims); + seed = hash_combine(seed, md.offset0); + seed = hash_combine(seed, static_cast(md.format_kind)); + // format desc + switch (md.format_kind) { + case format_kind::undef: + case format_kind::any: break; + case format_kind::blocked: + for (int i = 0; i < md.ndims; i++) { + if (md.dims[i] == 1 && md.padded_dims[i] == 1) continue; + seed = hash_combine(seed, md.format_desc.blocking.strides[i]); + } + seed = hash_combine(seed, md.format_desc.blocking.inner_nblks); + seed = get_array_hash(seed, md.format_desc.blocking.inner_blks, + md.format_desc.blocking.inner_nblks); + seed = get_array_hash(seed, md.format_desc.blocking.inner_idxs, + md.format_desc.blocking.inner_nblks); + break; + case format_kind::wino: + seed = hash_combine(seed, + static_cast(md.format_desc.wino_desc.wino_format)); + seed = hash_combine(seed, md.format_desc.wino_desc.r); + seed = hash_combine(seed, md.format_desc.wino_desc.alpha); + seed = hash_combine(seed, md.format_desc.wino_desc.ic); + seed = hash_combine(seed, md.format_desc.wino_desc.oc); + seed = hash_combine(seed, md.format_desc.wino_desc.ic_block); + seed = hash_combine(seed, md.format_desc.wino_desc.oc_block); + seed = hash_combine(seed, md.format_desc.wino_desc.ic2_block); + seed = hash_combine(seed, md.format_desc.wino_desc.oc2_block); + seed = hash_combine(seed, md.format_desc.wino_desc.adj_scale); + seed = hash_combine(seed, md.format_desc.wino_desc.size); + break; + case format_kind::rnn_packed: + seed = hash_combine(seed, + static_cast(md.format_desc.rnn_packed_desc.format)); + seed = hash_combine(seed, md.format_desc.rnn_packed_desc.n_parts); + seed = hash_combine(seed, md.format_desc.rnn_packed_desc.n); + seed = hash_combine(seed, md.format_desc.rnn_packed_desc.ldb); + { + int n_parts = md.format_desc.rnn_packed_desc.n_parts; + seed = get_array_hash( + seed, md.format_desc.rnn_packed_desc.parts, n_parts); + seed = get_array_hash(seed, + md.format_desc.rnn_packed_desc.part_pack_size, n_parts); + seed = get_array_hash(seed, + md.format_desc.rnn_packed_desc.pack_part, n_parts); + } + seed = hash_combine( + seed, md.format_desc.rnn_packed_desc.offset_compensation); + seed = hash_combine(seed, md.format_desc.rnn_packed_desc.size); + break; + default: assert(!"unknown format_kind"); + } + + if (md.extra.flags != dnnl_memory_extra_flag_none) { + seed = hash_combine(seed, md.extra.flags); + if (md.extra.flags + & (dnnl_memory_extra_flag_compensation_conv_s8s8 + | dnnl_memory_extra_flag_rnn_u8s8_compensation)) { + seed = hash_combine(seed, md.extra.compensation_mask); + } + + if (md.extra.flags & dnnl_memory_extra_flag_scale_adjust) { + seed = hash_combine(seed, md.extra.scale_adjust); + } + + if (md.extra.flags + & dnnl_memory_extra_flag_compensation_conv_asymmetric_src) { + seed = hash_combine(seed, md.extra.asymm_compensation_mask); + } + } + // Combined hash for a memory descriptor + return seed; +} + +// Combine hash of each primitive_attr_t data member +size_t get_attr_hash(const primitive_attr_t &attr) { + 
size_t seed = 0; + // scratchpad_mode + seed = hash_combine(seed, static_cast(attr.scratchpad_mode_)); + // fpmath_mode + seed = hash_combine(seed, static_cast(attr.fpmath_mode_)); + + if (!attr.output_scales_.has_default_values()) { + // output_scales: mask + seed = hash_combine(seed, attr.output_scales_.mask_); + // output_scales: count + seed = hash_combine(seed, attr.output_scales_.count_); + // output_scales: scales[:] + seed = get_array_hash( + seed, attr.output_scales_.scales_, attr.output_scales_.count_); + } else if (!attr.scales_.has_default_values()) { + // go through scales for all arguments + for (const auto &p : attr.scales_.scales_) { + seed = hash_combine(seed, p.second.mask_); + seed = hash_combine(seed, p.second.count_); + seed = get_array_hash(seed, p.second.scales_, p.second.count_); + } + } + // zero_points + for (int arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) + if (!attr.zero_points_.has_default_values(arg)) { + dim_t count = 0; + int mask = 0; + const int *zero_points = nullptr; + attr.zero_points_.get(arg, &count, &mask, &zero_points); + // zero_points: count + seed = hash_combine(seed, count); + // zero_points: mask + seed = hash_combine(seed, mask); + // zero_points: zero_points[:] + seed = get_array_hash(seed, zero_points, count); + } + // post_ops: entry[:] + seed = get_post_op_hash(seed, attr.post_ops_); + // rnn_data_qparams: scale, shift + seed = hash_combine(seed, attr.rnn_data_qparams_.scale_); + seed = hash_combine(seed, attr.rnn_data_qparams_.shift_); + if (!attr.rnn_weights_qparams_.has_default_values()) { + // rnn_weights_qparams: mask + seed = hash_combine(seed, attr.rnn_weights_qparams_.mask_); + // rnn_weights_qparams: count + seed = hash_combine(seed, attr.rnn_weights_qparams_.count_); + // rnn_weights_qparams: scales[:] + seed = get_array_hash(seed, attr.rnn_weights_qparams_.scales_, + attr.rnn_weights_qparams_.count_); + } + // Combined hash for attributes + return seed; +} + +// Combine hash of each post_ops::entry_ +size_t get_post_op_hash(size_t seed, const post_ops_t &post_ops) { + for (int i = 0; i < post_ops.len(); i++) { + const auto &entry = post_ops.entry_[i]; + switch (entry.kind) { + case primitive_kind::eltwise: + seed = hash_combine( + seed, static_cast(entry.eltwise.alg)); + seed = hash_combine(seed, entry.eltwise.scale); + seed = hash_combine(seed, entry.eltwise.alpha); + seed = hash_combine(seed, entry.eltwise.beta); + break; + case primitive_kind::sum: + seed = hash_combine(seed, entry.sum.scale); + seed = hash_combine(seed, static_cast(entry.sum.dt)); + break; + case primitive_kind::convolution: + seed = hash_combine(seed, static_cast(entry.depthwise_conv_old.in_h)); + seed = hash_combine(seed, static_cast(entry.depthwise_conv_old.in_w)); + seed = hash_combine(seed, static_cast(entry.depthwise_conv_old.ker_h)); + seed = hash_combine(seed, static_cast(entry.depthwise_conv_old.ker_w)); + seed = hash_combine(seed, static_cast(entry.depthwise_conv_old.str_h)); + seed = hash_combine(seed, static_cast(entry.depthwise_conv_old.str_w)); + seed = hash_combine(seed, static_cast(entry.depthwise_conv_old.in_dt)); + break; + case primitive_kind::binary: + seed = hash_combine( + seed, static_cast(entry.binary.alg)); + seed = hash_combine( + seed, get_md_hash(entry.binary.user_src1_desc)); + break; + case primitive_kind::prelu: + seed = hash_combine( + seed, static_cast(entry.prelu.mask)); + break; + case primitive_kind::depthwise: + seed = hash_combine(seed, static_cast(entry.depthwise.alg)); + seed = get_array_hash(seed, 
entry.depthwise.offset, entry.depthwise.fields_count); + break; + case primitive_kind::quantization: + seed = hash_combine(seed, static_cast(entry.quantization.alg)); + seed = get_array_hash(seed, entry.quantization.per_channel, entry.quantization.fields_count); + seed = get_array_hash(seed, entry.quantization.all_default, entry.quantization.fields_count); + seed = get_array_hash(seed, entry.quantization.offset, entry.quantization.fields_count); + break; + default: assert(!"unknown post_op"); + } + } + + return seed; +} + +} // namespace primitive_hashing +} // namespace impl +} // namespace dnnl \ No newline at end of file diff --git a/src/common/primitive_hashing_utils.hpp b/src/common/primitive_hashing_utils.hpp new file mode 100644 index 00000000000..361faee16f0 --- /dev/null +++ b/src/common/primitive_hashing_utils.hpp @@ -0,0 +1,67 @@ +/******************************************************************************* +* Copyright 2019-2021 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef COMMON_PRIMITIVE_HASHING_UTILS_HPP +#define COMMON_PRIMITIVE_HASHING_UTILS_HPP + +#include "c_types_map.hpp" +#include "oneapi/dnnl/dnnl.h" +#include "primitive_attr.hpp" +#include "type_helpers.hpp" + +namespace dnnl { +namespace impl { +namespace primitive_hashing { + +size_t get_md_hash(const memory_desc_t &md); +size_t get_attr_hash(const primitive_attr_t &attr); +size_t get_post_op_hash(size_t seed, const post_ops_t &post_ops); + +template +size_t get_array_hash(size_t seed, const T *v, int size) { + for (int i = 0; i < size; i++) { + seed = hash_combine(seed, v[i]); + } + return seed; +} + +template <> +inline size_t get_array_hash( + size_t seed, const memory_desc_t *v, int size) { + for (int i = 0; i < size; i++) { + seed = hash_combine(seed, get_md_hash(v[i])); + } + return seed; +} + +template <> +inline size_t get_array_hash(size_t seed, const float *v, int size) { + for (int i = 0; i < size; i++) { + seed = hash_combine(seed, float2int(v[i])); + } + return seed; +} + +template +size_t get_vector_hash(size_t seed, const std::vector &vec) { + return get_array_hash(seed, vec.data(), vec.size()); +} + +} // namespace primitive_hashing +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/common/reorder.cpp b/src/common/reorder.cpp index 2c7cb108c75..aaef6197e2f 100644 --- a/src/common/reorder.cpp +++ b/src/common/reorder.cpp @@ -69,7 +69,8 @@ status_t reorder_primitive_desc_create(std::shared_ptr &pd, auto s_mdw = memory_desc_wrapper(*src_md); auto d_mdw = memory_desc_wrapper(*dst_md); - if (!s_mdw.consistent_with(d_mdw)) return invalid_arguments; + if (!s_mdw.consistent_with(d_mdw)) + return invalid_arguments; if (attr == nullptr) attr = &default_attr(); diff --git a/src/common/tag_traits.hpp b/src/common/tag_traits.hpp index 50880e57ba0..7ccba7513e0 100644 --- a/src/common/tag_traits.hpp +++ b/src/common/tag_traits.hpp @@ -58,6 +58,7 @@ enum class inner_blk_t { _4c4b, 
    _8a8b,
    _8b8a,
+    _8a32b,
    _8b8c,
    _8c8b,
    _16a16b,
@@ -133,7 +134,7 @@ constexpr int AB_or_BC_blk_off(int x0, int x1) {
    using ib = inner_blk_t;
    static_assert(
            utils::one_of(f, ib::_4a4b, ib::_4b4a, ib::_4b4c, ib::_4c4b,
-                    ib::_8a8b, ib::_8b8a, ib::_8b8c, ib::_8c8b, ib::_16a16b,
+                    ib::_8a8b, ib::_8a32b, ib::_8b8a, ib::_8b8c, ib::_8c8b, ib::_16a16b,
                    ib::_16b64a, ib::_16b32a, ib::_16b16a, ib::_16b16c,
                    ib::_16c16b, ib::_32a32b, ib::_16a2b, ib::_16a4b,
                    ib::_16b2c, ib::_16b4c, ib::_2c8b4c, ib::_8a16b2a,
@@ -155,9 +156,9 @@ constexpr int AB_or_BC_blk_off(int x0, int x1) {
    return false ? 0
            : (f == ib::_4a4b || f == ib::_4b4c) ? 4 * x0 + x1
            : (f == ib::_4b4a || f == ib::_4c4b) ? 4 * x1 + x0
-            : (f == ib::_8a8b || f == ib::_8b8c) ? 8 * x0 + x1
+            : (f == ib::_8a8b || f == ib::_8a32b || f == ib::_8b8c) ? 8 * x0 + x1
            : (f == ib::_8b8a || f == ib::_8c8b) ? 8 * x1 + x0
-            : (f == ib::_16a16b || f == ib::_16b16c) ? 16 * x0 + x1
+            : (f == ib::_16a16b || f == ib::_16a32b || f == ib::_16b16c) ? 16 * x0 + x1
            : (f == ib::_16b64a) ? 64 * x1 + x0
            : (f == ib::_16b32a) ? 32 * x1 + x0
            : (f == ib::_16b16a || f == ib::_16c16b) ? 16 * x1 + x0
@@ -378,6 +379,8 @@ DECL_TRAITS(ABcde16b48a2b, _AB, _16b48a2b, 5);
 DECL_TRAITS(ABcde16b64a2b, _AB, _16b64a2b, 5);
 DECL_TRAITS(ABcd8a16b2a, _AB, _8a16b2a, 4);
 DECL_TRAITS(ABcd8a8b, _AB, _8a8b, 4);
+DECL_TRAITS(ABcd8a32b, _AB, _8a32b, 4);
+DECL_TRAITS(ABcd16a32b, _AB, _16a32b, 4);
 DECL_TRAITS(aBcd8b, _B, _8b, 4);
 DECL_TRAITS(ABcd8b16a2b, _AB, _8b16a2b, 4);
 DECL_TRAITS(ABcd8b32a2b, _AB, _8b32a2b, 4);
@@ -473,6 +476,8 @@ DECL_TRAITS(aBCde2b8c8b2c, _BC, _2b8c8b2c, 5);
 DECL_TRAITS(aBdec32b, _B, _32b, 5);
 DECL_TRAITS(aCBdef16c16b, _BC, _16c16b, 6);
 DECL_TRAITS(aCBdef16b16c, _BC, _16b16c, 6);
+DECL_TRAITS(Abcdef4a, _A, _4a, 6);
+DECL_TRAITS(Abcdef8a, _A, _8a, 6);
 DECL_TRAITS(Abcdef16a, _A, _16a, 6);
 DECL_TRAITS(aCBd16c16b, _BC, _16c16b, 4);
 DECL_TRAITS(aCBde16c16b, _BC, _16c16b, 4);
diff --git a/src/common/type_helpers.hpp b/src/common/type_helpers.hpp
index 72c2bbbfc70..151a32779e4 100644
--- a/src/common/type_helpers.hpp
+++ b/src/common/type_helpers.hpp
@@ -88,6 +88,7 @@ inline size_t data_type_size(data_type_t data_type) {
        case s32: return sizeof(prec_traits<s32>::type);
        case s8: return sizeof(prec_traits<s8>::type);
        case u8: return sizeof(prec_traits<u8>::type);
+        case bin: return sizeof(prec_traits<bin>::type);
        case data_type::undef:
        default: assert(!"unknown data_type");
    }
@@ -184,10 +185,7 @@ inline bool blocking_desc_is_equal(const memory_desc_t &lhs_md,
            && array_cmp(lhs.inner_idxs, rhs.inner_idxs, lhs.inner_nblks);
    if (ignore_strides) return equal;
-    // Check the strides.
-    // Note: for dimensions of size `1` the stride doesn't really matter.
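// [Editor's note] A self-contained sketch (all names hypothetical) of what the
// blocking_desc_is_equal() change around this point amounts to: the removed
// lines used to skip the stride check for dimensions of size 1, so two
// descriptors that differed only in the stride of such a dimension still
// compared equal; with the skip gone, strides must match exactly on every
// dimension.
#include <cstdint>

static bool strides_equal_strict(const int64_t *lhs_strides,
        const int64_t *rhs_strides, int ndims) {
    bool equal = true;
    for (int d = 0; d < ndims; ++d) {
        // The removed guard was: if (dims[d] == 1 && padded_dims[d] == 1) continue;
        equal = equal && lhs_strides[d] == rhs_strides[d];
    }
    return equal;
}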
for (int d = 0; d < lhs_md.ndims; ++d) { - if (lhs_md.dims[d] == 1 && lhs_md.padded_dims[d] == 1) continue; equal = equal && lhs.strides[d] == rhs.strides[d]; } @@ -821,7 +819,7 @@ inline bool memory_desc_sanity_check(int ndims, const dims_t dims, if (ndims == 0) return true; bool ok = dims != nullptr && 0 < ndims && ndims <= DNNL_MAX_NDIMS - && utils::one_of(data_type, f16, bf16, f32, s32, s8, u8); + && utils::one_of(data_type, f16, bf16, f32, s32, s8, u8, bin); if (!ok) return false; bool has_runtime_dims = false; diff --git a/src/common/utils.hpp b/src/common/utils.hpp index 60ce60972c5..2ebc51a4271 100644 --- a/src/common/utils.hpp +++ b/src/common/utils.hpp @@ -213,6 +213,12 @@ inline void array_set(T *arr, const U &val, size_t size) { arr[i] = static_cast(val); } +inline bool array_cmp_weak(const dnnl_dim_t *a1, const dnnl_dim_t *a2, size_t size) { + for (size_t i = 0; i < size; ++i) + if (a1[i] != a2[i] && a1[i] != DNNL_RUNTIME_DIM_VAL && a2[i] != DNNL_RUNTIME_DIM_VAL) return false; + return true; +} + namespace product_impl { template struct int2type {}; @@ -629,11 +635,17 @@ struct setting_t { // Copyright 2005-2014 Daniel James. // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE or copy at http://www.boost.org/LICENSE_1_0.txt) -template +template ::value , int>::type = 0> static size_t hash_combine(size_t seed, const T &v) { return seed ^= std::hash {}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2); } +template ::value , int>::type = 0> +static size_t hash_combine(size_t seed, const T &v) { + using underlying_t = typename std::underlying_type::type; + return hash_combine(seed, static_cast(v)); +} + inline int float2int(float x) { return utils::bit_cast(x); } @@ -709,7 +721,7 @@ struct set_once_before_first_get_setting_t { inline bool is_native_runtime(runtime_kind_t kind) { return utils::one_of(kind, runtime_kind::seq, runtime_kind::omp, - runtime_kind::tbb, runtime_kind::threadpool); + runtime_kind::tbb, runtime_kind::tbb_auto, runtime_kind::threadpool); } } // namespace impl diff --git a/src/common/verbose.cpp b/src/common/verbose.cpp index 8bba501fd3c..22b7fee4ee4 100644 --- a/src/common/verbose.cpp +++ b/src/common/verbose.cpp @@ -421,15 +421,17 @@ std::ostream &operator<<(std::ostream &ss, const primitive_attr_t *attr) { if (s.dt != data_type::undef) ss << ":" << s.dt; } break; case primitive_kind::convolution: { - using namespace data_type; - const auto &c = e.depthwise_conv; - ss << delim << "dw_k3s" << c.stride << "p1"; - if (c.wei_dt == s8 || c.dst_dt != f32) - ss << ":" << c.dst_dt; - if (c.count > 0 && c.wei_dt == s8) { - ss << ":" << c.mask; - if (c.mask == 0) ss << ":" << c.scales[0]; - } +// using namespace data_type; +// const auto &c = e.depthwise_conv; +// ss << delim << "dw_k3s" << c.stride << "p1"; +// if (c.wei_dt == s8 || c.dst_dt != f32) +// ss << ":" << c.dst_dt; +// if (c.count > 0 && c.wei_dt == s8) { +// ss << ":" << c.mask; +// if (c.mask == 0) ss << ":" << c.scales[0]; +// } + const char *alg_str = "depthwise_conv_old"; + ss << delim << alg_str; } break; case primitive_kind::eltwise: { const post_ops_t::entry_t::eltwise_t &ew = e.eltwise; @@ -454,6 +456,14 @@ std::ostream &operator<<(std::ostream &ss, const primitive_attr_t *attr) { ss << delim << "prelu" << ":" << ep.mask; } break; + case primitive_kind::depthwise: { + const post_ops_t::entry_t::depthwise_t &dw = e.depthwise; + ss << delim << dw.alg; + } break; + case primitive_kind::quantization: { + const post_ops_t::entry_t::quantization_t &qt = 
e.quantization; + ss << delim << qt.alg; + } break; default: assert(!"unsupported post op primitive kind!"); break; } delim = attr_delim; diff --git a/src/cpu/binary_injector_utils.cpp b/src/cpu/binary_injector_utils.cpp index 41b267bea70..5ad87992c4b 100644 --- a/src/cpu/binary_injector_utils.cpp +++ b/src/cpu/binary_injector_utils.cpp @@ -30,7 +30,7 @@ std::vector prepare_binary_args(const post_ops_t &post_ops, unsigned idx = first_arg_idx_offset; for (const auto &post_op : post_ops.entry_) { - if (post_op.is_binary()) { + if (post_op.is_binary() || post_op.is_depthwise() || post_op.is_quantization()) { post_ops_binary_rhs_arg_vec.emplace_back(CTX_IN_MEM(const void *, DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1)); } diff --git a/src/cpu/cpu_convolution_list.cpp b/src/cpu/cpu_convolution_list.cpp index 7ab42a55346..cec3c033bad 100644 --- a/src/cpu/cpu_convolution_list.cpp +++ b/src/cpu/cpu_convolution_list.cpp @@ -34,6 +34,7 @@ #include "cpu/x64/gemm_bf16_convolution.hpp" #include "cpu/x64/ip_convolution.hpp" #include "cpu/x64/jit_avx2_1x1_convolution.hpp" +#include "cpu/x64/jit_avx2_1x1_convolution_with_dw_conv.hpp" #include "cpu/x64/jit_avx2_convolution.hpp" #include "cpu/x64/jit_avx512_common_1x1_convolution.hpp" #include "cpu/x64/jit_avx512_common_convolution.hpp" @@ -53,8 +54,10 @@ #include "cpu/x64/jit_sse41_1x1_convolution.hpp" #include "cpu/x64/jit_sse41_convolution.hpp" #include "cpu/x64/jit_uni_dw_convolution.hpp" +#include "cpu/x64/jit_uni_fork_dw_convolution.hpp" #include "cpu/x64/jit_uni_x8s8s32x_1x1_convolution.hpp" #include "cpu/x64/jit_uni_x8s8s32x_convolution.hpp" +#include "cpu/x64/jit_uni_planar_convolution.hpp" using namespace dnnl::impl::cpu::x64; #elif DNNL_AARCH64 #include "cpu/aarch64/jit_sve_512_1x1_convolution.hpp" @@ -84,16 +87,22 @@ const std::map> impl_list_map RE CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t) CPU_INSTANCE_X64(ip_convolution_fwd_t) CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t) + CPU_INSTANCE_AVX512(jit_avx512_common_planar_convolution_fwd_t) CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t) CPU_INSTANCE_AVX512(jit_avx512_common_dw_convolution_fwd_t) + CPU_INSTANCE_AVX512(jit_avx512_common_fork_dw_convolution_fwd_t) CPU_INSTANCE_AVX512(jit_avx512_common_1x1_convolution_fwd_f32_t) CPU_INSTANCE_AVX512(jit_avx512_core_f32_wino_conv_2x3_fwd_t) CPU_INSTANCE_AVX512(jit_avx512_core_f32_wino_conv_4x3_fwd_t) CPU_INSTANCE_AVX512(jit_avx512_common_convolution_winograd_fwd_t) CPU_INSTANCE_AVX512(jit_avx512_common_convolution_fwd_t) + CPU_INSTANCE_AVX2(jit_avx2_planar_convolution_fwd_t) CPU_INSTANCE_AVX2(jit_avx2_dw_convolution_fwd_t) + CPU_INSTANCE_AVX2(jit_avx2_fork_dw_convolution_fwd_t) + CPU_INSTANCE_AVX2(jit_avx2_1x1_convolution_with_dw_conv_fwd_t) CPU_INSTANCE_AVX2(jit_avx2_1x1_convolution_fwd_t) CPU_INSTANCE_SSE41(jit_sse41_dw_convolution_fwd_t) + CPU_INSTANCE_SSE41(jit_sse41_fork_dw_convolution_fwd_t) CPU_INSTANCE_SSE41(jit_sse41_1x1_convolution_fwd_t) CPU_INSTANCE_AVX2(jit_avx2_convolution_fwd_t) CPU_INSTANCE_SSE41(jit_sse41_convolution_fwd_t) @@ -118,6 +127,7 @@ const std::map> impl_list_map RE CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t) CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t) CPU_INSTANCE_AVX512(jit_uni_dw_convolution_fwd_t) + CPU_INSTANCE_AVX512(jit_uni_fork_dw_convolution_fwd_t) CPU_INSTANCE_AVX512(jit_avx512_core_bf16_1x1_convolution_fwd_t) CPU_INSTANCE_AVX512(jit_avx512_core_bf16_convolution_fwd_t) CPU_INSTANCE_AVX512(gemm_bf16_convolution_fwd_t) @@ -134,6 +144,7 @@ const std::map> impl_list_map RE 
CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t) CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t) CPU_INSTANCE_AVX512(jit_uni_dw_convolution_fwd_t) + CPU_INSTANCE_AVX512(jit_uni_fork_dw_convolution_fwd_t) CPU_INSTANCE_AVX512(jit_avx512_core_bf16_1x1_convolution_fwd_t) CPU_INSTANCE_AVX512(jit_avx512_core_bf16_convolution_fwd_t) CPU_INSTANCE_AVX512(gemm_bf16_convolution_fwd_t) @@ -145,13 +156,14 @@ const std::map> impl_list_map RE {{backward_data, f32, f32, f32}, REG_BWD_D_PK({ CPU_INSTANCE_X64(ip_convolution_bwd_data_t) CPU_INSTANCE_AVX512(jit_avx512_common_dw_convolution_bwd_data_t) + CPU_INSTANCE_AVX512(jit_avx512_common_fork_dw_convolution_bwd_data_t) CPU_INSTANCE_AVX512(jit_avx512_common_1x1_convolution_bwd_data_f32_t) CPU_INSTANCE_AVX512(jit_avx512_core_f32_wino_conv_4x3_bwd_data_t) CPU_INSTANCE_AVX512(jit_avx512_common_convolution_winograd_bwd_data_t) CPU_INSTANCE_AVX512(jit_avx512_common_convolution_bwd_data_t) - CPU_INSTANCE_AVX2(jit_avx2_dw_convolution_bwd_data_t) + CPU_INSTANCE_AVX2(jit_avx2_fork_dw_convolution_bwd_data_t) CPU_INSTANCE_AVX2(jit_avx2_1x1_convolution_bwd_data_t) - CPU_INSTANCE_SSE41(jit_sse41_dw_convolution_bwd_data_t) + CPU_INSTANCE_SSE41(jit_sse41_fork_dw_convolution_bwd_data_t) CPU_INSTANCE_AVX2(jit_avx2_convolution_bwd_data_t) CPU_INSTANCE_AARCH64(jit_sve_512_dw_convolution_bwd_data_t) CPU_INSTANCE_AARCH64(jit_sve_512_1x1_convolution_bwd_data_f32_t) @@ -164,6 +176,7 @@ const std::map> impl_list_map RE CPU_INSTANCE_X64(ip_convolution_bwd_data_t) CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_bwd_data_t) CPU_INSTANCE_AVX512(jit_uni_dw_convolution_bwd_data_t) + CPU_INSTANCE_AVX512(jit_uni_fork_dw_convolution_bwd_data_t) CPU_INSTANCE_AVX512(jit_avx512_core_bf16_1x1_convolution_bwd_data_t) CPU_INSTANCE_AVX512(jit_avx512_core_bf16_convolution_bwd_data_t) CPU_INSTANCE_AVX512(gemm_bf16_convolution_bwd_data_t) @@ -173,7 +186,7 @@ const std::map> impl_list_map RE {{backward_data, bf16, bf16, bf16}, REG_BWD_D_PK({ CPU_INSTANCE_X64(ip_convolution_bwd_data_t) CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_bwd_data_t) - CPU_INSTANCE_AVX512(jit_uni_dw_convolution_bwd_data_t) + CPU_INSTANCE_AVX512(jit_uni_fork_dw_convolution_bwd_data_t) CPU_INSTANCE_AVX512(jit_avx512_core_bf16_1x1_convolution_bwd_data_t) CPU_INSTANCE_AVX512(jit_avx512_core_bf16_convolution_bwd_data_t) CPU_INSTANCE_AVX512(gemm_bf16_convolution_bwd_data_t) diff --git a/src/cpu/cpu_inner_product_pd.hpp b/src/cpu/cpu_inner_product_pd.hpp index 7554af4a81f..0d6742d1a3b 100644 --- a/src/cpu/cpu_inner_product_pd.hpp +++ b/src/cpu/cpu_inner_product_pd.hpp @@ -193,8 +193,8 @@ struct cpu_inner_product_fwd_pd_t : public inner_product_fwd_pd_t { /* with batch = 1, no transpose to use the faster gemv kernels */ /* otherwise, we transpose the weights to improve efficiency of * no-copy kernels */ - if (MB() > 1 && transpose_leading_dim(OC(), IC_total())) - transpose_md(weights_md_); +// if (MB() > 1 && transpose_leading_dim(OC(), IC_total())) +// transpose_md(weights_md_); return status::success; }; diff --git a/src/cpu/cpu_pooling_list.cpp b/src/cpu/cpu_pooling_list.cpp index 37c4e09d4ae..ce9e9cbbcd0 100644 --- a/src/cpu/cpu_pooling_list.cpp +++ b/src/cpu/cpu_pooling_list.cpp @@ -54,16 +54,18 @@ const std::map> impl_list_map REG_P CPU_INSTANCE(nchw_pooling_fwd_t) CPU_INSTANCE(nhwc_pooling_fwd_t) CPU_INSTANCE(nhwc_pooling_fwd_t) - CPU_INSTANCE(ref_pooling_fwd_t) - CPU_INSTANCE(ref_pooling_fwd_t) + CPU_INSTANCE(ref_pooling_fwd_t) + CPU_INSTANCE(ref_pooling_fwd_t) /* int */ CPU_INSTANCE_X64(jit_uni_i8i8_pooling_fwd_t) 
CPU_INSTANCE_X64(jit_uni_i8i8_pooling_fwd_t) CPU_INSTANCE_X64(jit_uni_i8i8_pooling_fwd_t) CPU_INSTANCE_AARCH64(jit_uni_i8i8_pooling_fwd_t) - CPU_INSTANCE(ref_pooling_fwd_t) - CPU_INSTANCE(ref_pooling_fwd_t) - CPU_INSTANCE(ref_pooling_fwd_t) + CPU_INSTANCE(ref_pooling_fwd_t) + CPU_INSTANCE(ref_pooling_fwd_t) + CPU_INSTANCE(ref_pooling_fwd_t) + CPU_INSTANCE(ref_pooling_fwd_t) + CPU_INSTANCE(ref_pooling_fwd_t) nullptr, }}, {{backward}, REG_BWD_PK({ diff --git a/src/cpu/cpu_primitive.hpp b/src/cpu/cpu_primitive.hpp index 256acf969cd..b9686738f40 100644 --- a/src/cpu/cpu_primitive.hpp +++ b/src/cpu/cpu_primitive.hpp @@ -55,6 +55,20 @@ if (zero_points_ptr == nullptr) return status::invalid_arguments; \ MAYBE_UNUSED(zero_points_ptr); +#define DEFINE_INPUT_ZERO_POINTS_BUFFER(input_zero_points_ptr, jcp) \ + const uint8_t *input_zero_points_ptr = nullptr; \ + if (jcp.with_input_zp) { \ + input_zero_points_ptr = CTX_IN_MEM(const uint8_t *, DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC); \ + if (input_zero_points_ptr == nullptr) return status::invalid_arguments; \ + } + +#define DEFINE_OUTPUT_COMPENSATION_BUFFER(output_compensation_ptr, jcp) \ + const int32_t *output_compensation_ptr = nullptr; \ + if (jcp.with_input_zp) { \ + output_compensation_ptr = CTX_IN_MEM(const int32_t *, DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST); \ + if (output_compensation_ptr == nullptr) return status::invalid_arguments; \ + } + #define ASSIGN_INPUT_SCALE_VALUE(scale, mem_arg) \ if (pd()->attr()->scales_.get(mem_arg).defined()) { \ scale = pd()->attr()->scales_.get(mem_arg).scales_; \ diff --git a/src/cpu/cpu_softmax_list.cpp b/src/cpu/cpu_softmax_list.cpp index 65617978901..2ff11ce2e14 100644 --- a/src/cpu/cpu_softmax_list.cpp +++ b/src/cpu/cpu_softmax_list.cpp @@ -22,6 +22,7 @@ #if DNNL_X64 #include "cpu/x64/jit_uni_softmax.hpp" +#include "cpu/x64/jit_uni_fork_softmax.hpp" using namespace dnnl::impl::cpu::x64; #elif DNNL_AARCH64 #include "cpu/aarch64/jit_uni_softmax.hpp" @@ -45,6 +46,9 @@ const std::map> impl_list_map REG_S CPU_INSTANCE_X64(jit_uni_softmax_fwd_t) CPU_INSTANCE_X64(jit_uni_softmax_fwd_t) CPU_INSTANCE_X64(jit_uni_softmax_fwd_t) + CPU_INSTANCE_X64(jit_uni_fork_softmax_fwd_t) + CPU_INSTANCE_X64(jit_uni_fork_softmax_fwd_t) + CPU_INSTANCE_X64(jit_uni_fork_softmax_fwd_t) CPU_INSTANCE_AARCH64(jit_uni_softmax_fwd_t) CPU_INSTANCE_AARCH64(jit_uni_softmax_bwd_t) CPU_INSTANCE_AARCH64_ACL(acl_softmax_fwd_t) diff --git a/src/cpu/dw_convolution_utils.hpp b/src/cpu/dw_convolution_utils.hpp index 23d581eee59..bb01662ab4c 100644 --- a/src/cpu/dw_convolution_utils.hpp +++ b/src/cpu/dw_convolution_utils.hpp @@ -39,6 +39,10 @@ inline status_t get_depthwise_conv_desc(convolution_desc_t &cd_dw, || !attr_1x1.post_ops_.entry_[dw_po_index].is_convolution()) return status::invalid_arguments; + // todo: [AV] remove this check when we use original oneDNN dw conv fusing + if (attr_1x1.post_ops_.entry_[dw_po_index].is_convolution()) + return status::unimplemented; + // Create new attributes with scales from depthwise post-op and copy // post-ops after depthwise post-op. 
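// [Editor's note] Hedged usage sketch for the DEFINE_INPUT_ZERO_POINTS_BUFFER /
// DEFINE_OUTPUT_COMPENSATION_BUFFER macros added to cpu_primitive.hpp above. On
// the application side the extra buffers travel through the ordinary exec-arg
// map, keyed by DNNL_ARG_ATTR_ZERO_POINTS combined with the tensor argument
// (matching the arg_usage() changes in primitive_desc.hpp earlier in this
// diff); the primitive, stream, and memory objects below are assumed to exist.
#include <unordered_map>
#include "oneapi/dnnl/dnnl.hpp"

void bind_legacy_zero_points(dnnl::primitive &conv, dnnl::stream &strm,
        dnnl::memory &src, dnnl::memory &wei, dnnl::memory &dst,
        dnnl::memory &input_zp, dnnl::memory &output_comp) {
    std::unordered_map<int, dnnl::memory> args;
    args[DNNL_ARG_SRC] = src;
    args[DNNL_ARG_WEIGHTS] = wei;
    args[DNNL_ARG_DST] = dst;
    // Consumed via DEFINE_INPUT_ZERO_POINTS_BUFFER (per-channel uint8_t).
    args[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC] = input_zp;
    // Consumed via DEFINE_OUTPUT_COMPENSATION_BUFFER (precomputed int32_t).
    args[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST] = output_comp;
    conv.execute(strm, args);
}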
 auto &dw_po = attr_1x1.post_ops_.entry_[dw_po_index].depthwise_conv;
diff --git a/src/cpu/gemm/f32/ref_gemm_f32.cpp b/src/cpu/gemm/f32/ref_gemm_f32.cpp
index e7d69f01727..944df461e3c 100644
--- a/src/cpu/gemm/f32/ref_gemm_f32.cpp
+++ b/src/cpu/gemm/f32/ref_gemm_f32.cpp
@@ -38,7 +38,10 @@ template <typename data_t>
 void copy_A(
         bool isTransA, dim_t K, const data_t *A, const dim_t lda, data_t *ws) {
     for (dim_t k = 0; k < K; k++) {
+#if !defined(_MSC_VER)
+        // Compilation with '#pragma omp simd' in this place leads to fatal error C1001 on VS2019
         PRAGMA_OMP_SIMD()
+#endif
         for (dim_t i = 0; i < unroll_factor<data_t>::m; i++) {
             ws[i] = isTransA ? A[i * lda + k] : A[i + k * lda];
         }
diff --git a/src/cpu/gemm_convolution.cpp b/src/cpu/gemm_convolution.cpp
index b9758ffb933..e4cc6e07c0e 100644
--- a/src/cpu/gemm_convolution.cpp
+++ b/src/cpu/gemm_convolution.cpp
@@ -23,6 +23,7 @@
 #include "common/type_helpers.hpp"
 #include "common/utils.hpp"
 #include "cpu/gemm_convolution.hpp"
+#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"

 namespace dnnl {
 namespace impl {
@@ -51,13 +52,18 @@ status_t gemm_convolution_fwd_t::execute_forward_nspc(
     auto bia_base = CTX_IN_MEM(const data_t *, DNNL_ARG_BIAS);
     auto dst_base = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
+    auto MB = CTX_IN_BATCH(DNNL_ARG_SRC);
+
+    const auto post_ops_binary_rhs_arg_vec
+            = x64::binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx);
+
     auto scratchpad = ctx.get_scratchpad_grantor();
     const conv_gemm_conf_t &jcp = pd()->jcp_;
     std::atomic<status_t> st(status::success);
     parallel(jcp.nthr, [&](const int ithr, const int nthr) {
         status_t st_thr = execute_forward_thr_nspc(ctx, ithr, nthr, src_base,
-                wei_base, bia_base, dst_base, scratchpad);
+                wei_base, bia_base, dst_base, scratchpad, MB, post_ops_binary_rhs_arg_vec);
         if (st_thr != status::success) st = st_thr;
     });
@@ -67,7 +73,8 @@
 status_t gemm_convolution_fwd_t::execute_forward_thr_nspc(const exec_ctx_t &ctx,
         const int ithr, const int nthr, const data_t *src_base,
         const data_t *wei_base, const data_t *bia_base, data_t *dst_base,
-        const memory_tracking::grantor_t &scratchpad) const {
+        const memory_tracking::grantor_t &scratchpad, int MB,
+        const std::vector<const void *> &post_ops_binary_rhs_arg_vec) const {
     const conv_gemm_conf_t &jcp = pd()->jcp_;

     // Src Format: mb-spatial-groups-input_channels
@@ -98,9 +105,9 @@ status_t gemm_convolution_fwd_t::execute_forward_thr_nspc(const exec_ctx_t &ctx,
     const dim_t nb_oh = div_up(jcp.oh, jcp.oh_block);
     const dim_t nb_ow = div_up(jcp.ow, jcp.ow_block);
     // threads share work across mini-batch, groups, and blocked width/height
-    const dim_t work_amount = jcp.mb * jcp.ngroups * nb_oh * nb_ow;
+    const dim_t work_amount = MB * jcp.ngroups * nb_oh * nb_ow;
     balance211(work_amount, nthr, ithr, start, end);
-    nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow);
+    nd_iterator_init(start, n, MB, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow);

     if (jcp.im2col_sz && is_problem_3d) {
         // jit_gemm_convolution_utils::im2col_dt_3d() requires external
@@ -151,68 +158,21 @@ status_t gemm_convolution_fwd_t::execute_forward_thr_nspc(const exec_ctx_t &ctx,
                 &LDC);
         if (st != status::success) return st;
-        if (jcp.with_bias || jcp.with_eltwise || jcp.with_binary) {
-            parallel(0, [&](int ithr, int nthr) {
-                dim_t start, end;
-                balance211(N * jcp.oc, nthr, ithr, start, end);
-
-                const size_t first_oc = start % jcp.oc;
-                const size_t last_oc = (end - 1) % jcp.oc;
-                const size_t first_os = start / jcp.oc;
-                const size_t last_os = (end - 1) / jcp.oc;
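// [Editor's note] For the post_ops_binary_rhs_arg_vec plumbing introduced in
// the gemm_convolution hunks above: prepare_binary_args (extended earlier in
// this diff to also cover depthwise and quantization post-ops) gathers one raw
// pointer per data-carrying post-op, in post-op order, so kernels can walk the
// vector with a running index. A simplified sketch using the internal types,
// not a verbatim copy of the implementation:
//
//     std::vector<const void *> args;
//     unsigned idx = first_arg_idx_offset;
//     for (const auto &post_op : post_ops.entry_) {
//         if (post_op.is_binary() || post_op.is_depthwise()
//                 || post_op.is_quantization())
//             args.emplace_back(CTX_IN_MEM(const void *,
//                     DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1));
//         ++idx;
//     }
//
// The removed bias/eltwise block continues below; its work moves into the
// pp_kernel_ call added later in this hunk.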
- - for (size_t os = first_os; os <= last_os; ++os) { - const size_t start_oc = (os == first_os) ? first_oc : 0; - const size_t end_oc - = (os == last_os) ? last_oc : jcp.oc - 1; - - const data_t *__restrict bia_arr - = bia_base ? bia_base + g * jcp.oc : nullptr; - data_t *__restrict dst_arr = dst + os * dst_os_stride; - - if (jcp.with_bias) { - PRAGMA_OMP_SIMD() - for (size_t oc = start_oc; oc <= end_oc; oc++) { - dst_arr[oc] += bia_arr[oc]; - } - } + if (pp_kernel_) { + const size_t first_oc = g * jcp.oc; + const size_t last_oc = jcp.oc; + const size_t first_os = 0; + const size_t last_os = N; - if (jcp.with_eltwise || jcp.with_binary) { - bool fast_relu_done = false; - if (jcp.with_eltwise && jcp.post_ops.len() == 1) { - // fast branch for ReLU case - const auto &eltwise - = jcp.post_ops.entry_.back().eltwise; - - if (eltwise.alg == alg_kind::eltwise_relu) { - const auto alpha = eltwise.alpha; - const auto scale = eltwise.scale; - PRAGMA_OMP_SIMD() - for (size_t oc = start_oc; oc <= end_oc; - oc++) { - if (dst_arr[oc] < 0) - dst_arr[oc] *= alpha; - dst_arr[oc] *= scale; - } - fast_relu_done = true; - } - } - if (!fast_relu_done) { - ref_post_ops_t::args_t args; - args.ctx = &ctx; - args.dst_md = pd()->dst_md(); - - for (size_t oc = start_oc; oc <= end_oc; oc++) { - args.l_offset = (g * jcp.oc + oc) * jcp.os; - post_ops_->execute(dst_arr[oc], args); - } - } - } - } - }); + const data_t* bias = bia_base ? bia_base + g * jcp.oc: nullptr; + + for (size_t os = first_os; os < last_os; ++os) { + data_t* dst_local = dst + os * dst_os_stride; + (*pp_kernel_)(dst_local, bias, 1, first_oc, last_oc, 1, post_ops_binary_rhs_arg_vec); + } } } - nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow); + nd_iterator_step(n, MB, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow); } return status::success; } @@ -224,14 +184,29 @@ status_t gemm_convolution_fwd_t::execute_forward_ncsp( auto bias = CTX_IN_MEM(const data_t *, DNNL_ARG_BIAS); auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST); + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + + const auto post_ops_binary_rhs_arg_vec + = x64::binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx); + auto col = ctx.get_scratchpad_grantor().get(key_conv_gemm_col); const conv_gemm_conf_t &jcp = this->pd()->jcp_; - const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id; + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + + const size_t src_mb_stride = src_d.blk_off(1); + const size_t src_grp_stride = src_d.blk_off(0, 1) * jcp.ic; + + const size_t dst_mb_stride = dst_d.blk_off(1); + const size_t dst_grp_stride = dst_d.blk_off(0, 1) * jcp.oc; + const size_t weights_oc_size = jcp.ic * jcp.ks; const size_t weights_g_size = weights_oc_size * jcp.oc; const bool is_problem_3d = pd()->ndims() == 5; + src += src_d.off_l(0); + dst += dst_d.off_l(0); assert(IMPLICATION(is_problem_3d, jcp.os_block == jcp.os && jcp.ic_block == jcp.ic @@ -251,7 +226,7 @@ status_t gemm_convolution_fwd_t::execute_forward_ncsp( auto inner_ker = [&](int spatial, const im_pos_t &curr, im_pos_t &prev, im_pos_t &step, const im_pos_t &end) { const data_t *_src - = src + (curr.n * jcp.ngroups + curr.g) * src_step; + = src + curr.n * src_mb_stride + curr.g * src_grp_stride; step.oc = nstl::min( jcp.oc_block, nstl::min(jcp.oc, end.oc) - curr.oc); step.sp = nstl::min(jcp.os_block, @@ -272,10 +247,9 @@ status_t gemm_convolution_fwd_t::execute_forward_ncsp( const data_t one = 1.0; const dim_t M = jcp.os * jcp.od; - const size_t dst_step = jcp.oc * M; const 
dim_t m = step.sp; const dim_t LDA = jcp.im2col_sz ? m : M; - data_t *_dst = dst + (curr.n * jcp.ngroups + curr.g) * dst_step + data_t *_dst = dst + curr.n * dst_mb_stride + curr.g * dst_grp_stride + curr.oc * M + curr.od * jcp.os + curr.sp; const dim_t K = step.ic * jcp.ks; const dim_t LDB = jcp.ic * jcp.ks; @@ -293,61 +267,8 @@ status_t gemm_convolution_fwd_t::execute_forward_ncsp( &LDA, _weights, &LDB, &beta, _dst, &M); if (st != status::success) return st; - if (curr.ic == jcp.ic - step.ic) { - // TODO: for "outer threading" we have parallel section within - // outermost "parallel". It is not good. Consider to use - // "parallel" here with number of threads passed as parameter - const int oc_start = curr.g * jcp.oc + curr.oc; - if (jcp.with_eltwise || jcp.with_binary) { - bool fast_relu_done = false; - if (jcp.with_eltwise && jcp.post_ops.len() == 1) { - // fast branch for ReLU case - const auto &eltwise - = jcp.post_ops.entry_.back().eltwise; - if (eltwise.alg == alg_kind::eltwise_relu) { - parallel_nd(step.oc, [&](dim_t oc) { - data_t b = jcp.with_bias ? bias[oc_start + oc] - : 0; - data_t *d_ = _dst + oc * M; - PRAGMA_OMP_SIMD() - for (int oS = 0; oS < m; ++oS) { - d_[oS] += b; - if (d_[oS] < 0) d_[oS] *= eltwise.alpha; - d_[oS] *= eltwise.scale; - } - }); - fast_relu_done = true; - } - } - if (!fast_relu_done) { - parallel_nd(step.oc, [&](dim_t oc) { - data_t b = jcp.with_bias ? bias[oc_start + oc] : 0; - data_t *d_ = _dst + oc * M; - - ref_post_ops_t::args_t args; - args.ctx = &ctx; - args.dst_md = pd()->dst_md(); - args.l_offset = d_ - dst; - - PRAGMA_OMP_SIMD() - for (int oS = 0; oS < m; ++oS) { - d_[oS] += b; - post_ops_->execute(d_[oS], args); - args.l_offset++; - } - }); - } - - } else if (jcp.with_bias) { - parallel_nd(step.oc, [&](dim_t oc) { - data_t b = bias[oc_start + oc]; - data_t *d_ = _dst + oc * M; - PRAGMA_OMP_SIMD() - for (int oS = 0; oS < m; ++oS) { - d_[oS] += b; - } - }); - } + if (pp_kernel_ && curr.ic == jcp.ic - step.ic) { + (*pp_kernel_)(_dst, bias, m, curr.g * jcp.oc + curr.oc, step.oc, M, post_ops_binary_rhs_arg_vec); } return status::success; @@ -356,11 +277,11 @@ status_t gemm_convolution_fwd_t::execute_forward_ncsp( end.ic = jcp.ic; if (!is_problem_3d) { - dim_t sp_work = jcp.mb * jcp.ngroups * jcp.od * jcp.os; + dim_t sp_work = MB * jcp.ngroups * jcp.od * jcp.os; balance2D(nthr, ithr, sp_work, start.sp, end.sp, jcp.oc, start.oc, end.oc, dim_t(jcp.nthr_oc)); } else { - dim_t sp_work = jcp.mb * jcp.ngroups * jcp.od; + dim_t sp_work = MB * jcp.ngroups * jcp.od; balance2D(nthr, ithr, sp_work, start.sp, end.sp, jcp.oc, start.oc, end.oc, dim_t(jcp.nthr_oc)); start.sp *= jcp.os; @@ -377,7 +298,7 @@ status_t gemm_convolution_fwd_t::execute_forward_ncsp( for (curr.ic = 0; curr.ic < jcp.ic; curr.ic += step.ic) for (int spatial = start.sp; spatial < end.sp; spatial += step.sp) { - nd_iterator_init(spatial, curr.n, jcp.mb, curr.g, + nd_iterator_init(spatial, curr.n, MB, curr.g, jcp.ngroups, curr.od, jcp.od, curr.sp, jcp.os); for (curr.oc = start.oc; curr.oc < end.oc; curr.oc += step.oc) { @@ -391,7 +312,7 @@ status_t gemm_convolution_fwd_t::execute_forward_ncsp( } else if (jcp.loop_order == gemm_loop_lrb) for (int spatial = start.sp; spatial < end.sp; spatial += step.sp) { - nd_iterator_init(spatial, curr.n, jcp.mb, curr.g, jcp.ngroups, + nd_iterator_init(spatial, curr.n, MB, curr.g, jcp.ngroups, curr.od, jcp.od, curr.sp, jcp.os); for (curr.ic = 0; curr.ic < jcp.ic; curr.ic += step.ic) for (curr.oc = start.oc; curr.oc < end.oc; @@ -419,13 +340,18 @@ status_t 
gemm_convolution_bwd_data_t::execute_backward_data_nspc( auto bia_base = CTX_IN_MEM(const data_t *, DNNL_ARG_BIAS); auto diff_src_base = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC); + const auto post_ops_binary_rhs_arg_vec + = x64::binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx); + + auto MB = CTX_IN_BATCH(DNNL_ARG_DIFF_DST); + auto scratchpad = ctx.get_scratchpad_grantor(); const conv_gemm_conf_t &jcp = pd()->jcp_; std::atomic st(status::success); parallel(jcp.nthr, [&](const int ithr, const int nthr) { status_t st_thr = execute_backward_data_thr_nspc(ithr, nthr, - diff_dst_base, wei_base, bia_base, diff_src_base, scratchpad); + diff_dst_base, wei_base, bia_base, diff_src_base, scratchpad, MB, post_ops_binary_rhs_arg_vec); if (st_thr != status::success) st = st_thr; }); @@ -435,7 +361,8 @@ status_t gemm_convolution_bwd_data_t::execute_backward_data_nspc( status_t gemm_convolution_bwd_data_t::execute_backward_data_thr_nspc( const int ithr, const int nthr, const data_t *diff_dst_base, const data_t *wei_base, const data_t *bia_base, data_t *diff_src_base, - const memory_tracking::grantor_t &scratchpad) const { + const memory_tracking::grantor_t &scratchpad, int MB, + const std::vector& post_ops_binary_rhs_arg_vec) const { const conv_gemm_conf_t &jcp = pd()->jcp_; // Diff_dst Format: mb-spatial-groups-output_channels @@ -453,7 +380,9 @@ status_t gemm_convolution_bwd_data_t::execute_backward_data_thr_nspc( const size_t diff_src_os_stride = jcp.ngroups * jcp.ic; // threads share work across mini-batch and groups - const dim_t work_amount = jcp.ngroups * jcp.mb; + const dim_t work_amount = jcp.ngroups * MB; + + const auto &p = pd()->attr()->post_ops_; data_t *__restrict col = scratchpad.get(key_conv_gemm_col) + (ptrdiff_t)ithr * jcp.im2col_sz; @@ -467,7 +396,7 @@ status_t gemm_convolution_bwd_data_t::execute_backward_data_thr_nspc( dim_t start = 0, end = 0; balance211(work_amount, nthr, ithr, start, end); - nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups); + nd_iterator_init(start, n, MB, g, jcp.ngroups); for (dim_t iwork = start; iwork < end; ++iwork) { const data_t *__restrict diff_dst = diff_dst_base @@ -503,7 +432,32 @@ status_t gemm_convolution_bwd_data_t::execute_backward_data_thr_nspc( } }); } - nd_iterator_step(n, jcp.mb, g, jcp.ngroups); + if (p.len() > 0) { + std::size_t post_ops_data_idx = 0; + int depthwise_inj_idx = 0; + for (int i = 0; i < p.len(); i++) { + auto &post_op = p.entry_[i]; + if (post_op.is_depthwise()) { + auto depthwise_base = reinterpret_cast(post_ops_binary_rhs_arg_vec[post_ops_data_idx]); + auto depthwise_weights = depthwise_base + post_op.depthwise.offset[post_op.depthwise.scales]; + auto depthwise_bias = post_op.depthwise.alg == alg_kind::depthwise_scale_shift + ? 
depthwise_base + post_op.depthwise.offset[post_op.depthwise.shifts] + : nullptr; + + parallel_nd(static_cast(jcp.is) * jcp.id, [&](size_t is) { + data_t *__restrict diff_src_arr + = diff_src + is * diff_src_os_stride; + for (int ic = 0; ic < jcp.ic; ic++) { + diff_src_arr[ic] = depthwise_injectors[depthwise_inj_idx]->compute_scalar(diff_src_arr[ic], + depthwise_weights + g * jcp.ic + ic, depthwise_bias + g * jcp.ic + ic); + } + }); + post_ops_data_idx++; + depthwise_inj_idx++; + } + } + } + nd_iterator_step(n, MB, g, jcp.ngroups); } return status::success; } @@ -514,22 +468,34 @@ status_t gemm_convolution_bwd_data_t::execute_backward_data_ncsp( auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS); auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC); + const auto post_ops_binary_rhs_arg_vec + = x64::binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx); + + auto MB = CTX_IN_BATCH(DNNL_ARG_DIFF_DST); + auto col = ctx.get_scratchpad_grantor().get(key_conv_gemm_col); const conv_gemm_conf_t &jcp = this->pd()->jcp_; const dim_t M = jcp.os * jcp.od; - const size_t src_step = (size_t)jcp.ic * jcp.ih * jcp.iw * jcp.id; - const size_t dst_step = (size_t)jcp.oc * M; + const size_t src_step_to_clean = jcp.ic * jcp.ih * jcp.iw * jcp.id; + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const size_t src_step = diff_src_d.blk_off(1) / jcp.ngroups; + const size_t dst_step = diff_dst_d.blk_off(1) / jcp.ngroups; + diff_src += diff_src_d.off_l(0); + diff_dst += diff_dst_d.off_l(0); const size_t weights_g_size = (size_t)jcp.ic * jcp.oc * jcp.ks; const dim_t m = jcp.os_block; const dim_t K = jcp.oc; const dim_t N = jcp.ic * jcp.ks; - const dim_t work_amount = (size_t)jcp.ngroups * jcp.mb; + const dim_t work_amount = (size_t)jcp.ngroups * MB; const bool is_problem_3d = pd()->ndims() == 5; + const auto &p = pd()->attr()->post_ops_; + std::atomic st(status::success); parallel(jcp.nthr, [&](const int ithr, const int nthr) { data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz; @@ -537,14 +503,14 @@ status_t gemm_convolution_bwd_data_t::execute_backward_data_ncsp( dim_t g {0}, n {0}; dim_t start = 0, end = 0; balance211(work_amount, nthr, ithr, start, end); - nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb); + nd_iterator_init(start, g, jcp.ngroups, n, MB); for (dim_t iwork = start; iwork < end; ++iwork) { data_t *_diff_src = diff_src + (n * jcp.ngroups + g) * src_step; if (is_problem_3d && jcp.im2col_sz > 0) { // jit_gemm_convolution_utils::col2im_3d() assumes that the // accumulator is initialized by zeroes - for (size_t i = 0; i < src_step; i++) + for (size_t i = 0; i < src_step_to_clean; i++) _diff_src[i] = (data_t)0; } @@ -577,7 +543,32 @@ status_t gemm_convolution_bwd_data_t::execute_backward_data_ncsp( } } } - nd_iterator_step(g, jcp.ngroups, n, jcp.mb); + if (p.len() > 0) { + std::size_t post_ops_data_idx = 0; + int depthwise_inj_idx = 0; + for (int i = 0; i < p.len(); i++) { + auto &post_op = p.entry_[i]; + if (post_op.is_depthwise()) { + auto depthwise_base = reinterpret_cast(post_ops_binary_rhs_arg_vec[post_ops_data_idx]); + auto depthwise_weights = depthwise_base + post_op.depthwise.offset[post_op.depthwise.scales]; + auto depthwise_bias = post_op.depthwise.alg == alg_kind::depthwise_scale_shift + ? 
depthwise_base + post_op.depthwise.offset[post_op.depthwise.shifts] + : nullptr; + parallel_nd(jcp.ic, [&](const int ic) { + for (int id = 0; id < jcp.id; ++id) { + data_t *d_ = _diff_src + ic * jcp.id * jcp.is + id * jcp.is; + for (int iS = 0; iS < jcp.is; ++iS) { + d_[iS] = depthwise_injectors[depthwise_inj_idx]->compute_scalar(d_[iS], + depthwise_weights + g * jcp.ic + ic, depthwise_bias + g * jcp.ic + ic); + } + } + }); + post_ops_data_idx++; + depthwise_inj_idx++; + } + } + } + nd_iterator_step(g, jcp.ngroups, n, MB); } }); diff --git a/src/cpu/gemm_convolution.hpp b/src/cpu/gemm_convolution.hpp index 5b359072edc..37a3d1809be 100644 --- a/src/cpu/gemm_convolution.hpp +++ b/src/cpu/gemm_convolution.hpp @@ -26,6 +26,8 @@ #include "cpu/gemm_convolution_utils.hpp" #include "cpu/primitive_attr_postops.hpp" +#include "ref_depthwise_injector.hpp" + namespace dnnl { namespace impl { namespace cpu { @@ -50,7 +52,6 @@ struct gemm_convolution_fwd_t : public primitive_t { primitive_attr_t::skip_mask_t::post_ops, f32) && post_ops_ok(); if (!ok) return status::unimplemented; - auto scratchpad = scratchpad_registry().registrar(); return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, *desc(), src_md_, weights_md_, dst_md_, bias_md_, attr_, @@ -61,21 +62,24 @@ struct gemm_convolution_fwd_t : public primitive_t { protected: bool post_ops_ok() const { + using namespace dnnl::impl::primitive_kind; auto const &po = attr()->post_ops_; - auto is_eltwise - = [&](int idx) { return po.entry_[idx].is_eltwise(); }; - auto is_sum = [&](int idx) { return po.entry_[idx].is_sum(); }; - auto is_binary - = [&](int idx) { return po.entry_[idx].is_binary(); }; - - for (int idx = 0; idx < po.len(); idx++) { - bool ok = utils::one_of(true, is_sum(idx), is_binary(idx), - is_eltwise(idx)) - && IMPLICATION(is_sum(idx), idx == 0); - if (!ok) return false; - } - return true; + auto all_post_ops_supported = [&]() { + bool ok = true; + + for (int i = 0; i < po.len(); i++) { + ok = ok && utils::one_of(po.entry_[i].kind, sum, eltwise, depthwise, quantization); + } + return ok; + }; + auto contain = [&](dnnl::impl::primitive_kind_t kind) { return po.find(kind) != -1; }; + auto position = [&](dnnl::impl::primitive_kind_t kind) { return po.find(kind); }; + auto count = [&](dnnl::impl::primitive_kind_t kind) { return po.count(kind); }; + + return all_post_ops_supported() && + count(primitive_kind::sum) <= 1 && + IMPLICATION(contain(primitive_kind::sum), position(primitive_kind::sum) == 0); } }; @@ -83,13 +87,18 @@ struct gemm_convolution_fwd_t : public primitive_t { : primitive_t(apd), post_ops_(nullptr) {} status_t init(engine_t *engine) override { + const auto &post_ops = pd()->attr()->post_ops_; const data_t one = 1.0, zero = 0.0; const auto &jcp = pd()->jcp_; beta_ = jcp.with_sum ? one : zero; - if (jcp.with_eltwise || jcp.with_binary) - CHECK(safe_ptr_assign(post_ops_, new ref_post_ops_t(jcp.post_ops))); - return status::success; + bool has_bias = pd()->with_bias(); + bool has_post_ops = post_ops.len() > 0; + bool has_scale = !pd()->attr()->output_scales_.has_default_values(); + postops_in_ip_ = has_bias || has_post_ops || has_scale; + + CHECK(safe_ptr_assign(pp_kernel_, pp_kernel_t::create(pd(), pd()->jcp_))); + return (pp_kernel_) ? 
pp_kernel_->create_kernel() : status::success; } typedef typename prec_traits::type data_t; @@ -105,9 +114,13 @@ struct gemm_convolution_fwd_t : public primitive_t { status_t execute_forward_thr_nspc(const exec_ctx_t &ctx, const int ithr, const int nthr, const data_t *src_base, const data_t *wei_base, const data_t *bia_base, data_t *dst_base, - const memory_tracking::grantor_t &scratchpad) const; + const memory_tracking::grantor_t &scratchpad, int MB, + const std::vector& post_ops_binary_rhs_arg_vec) const; const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } + using pp_kernel_t = gemm_convolution_utils::pp_kernel_t; + std::unique_ptr pp_kernel_; + bool postops_in_ip_; data_t beta_; std::unique_ptr post_ops_; @@ -127,7 +140,8 @@ struct gemm_convolution_bwd_data_t : public primitive_t { && set_default_alg_kind(alg_kind::convolution_direct) && expect_data_types(data_type::f32, data_type::f32, data_type::undef, data_type::f32, data_type::f32) - && !has_zero_dim_memory() && attr()->has_default_values(); + && !has_zero_dim_memory() + && is_supported_post_ops(); if (!ok) return status::unimplemented; auto scratchpad = scratchpad_registry().registrar(); @@ -137,9 +151,42 @@ struct gemm_convolution_bwd_data_t : public primitive_t { } conv_gemm_conf_t jcp_; + + protected: + virtual bool is_supported_post_ops() const { + const auto &p = this->attr()->post_ops_; + if (p.len() > 1) + return false; + + auto all_post_ops_supported = [&]() { + bool ok = true; + + for (int i = 0; i < p.len(); i++) { + ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::depthwise); + } + return ok; + }; + + return all_post_ops_supported(); + } }; - gemm_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {} + + gemm_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) { + const auto &post_ops = pd()->attr()->post_ops_; + for (int i = 0; i < post_ops.len(); i++) { + auto &post_op = post_ops.entry_[i]; + if (post_op.is_depthwise()) { + depthwise_injectors.push_back(new ref_depthwise_scalar_fwd_t(post_op.depthwise.alg)); + } + } + } + + ~gemm_convolution_bwd_data_t() { + for (auto inj : depthwise_injectors) + delete inj; + depthwise_injectors.clear(); + } typedef typename prec_traits::type data_t; @@ -155,9 +202,12 @@ struct gemm_convolution_bwd_data_t : public primitive_t { status_t execute_backward_data_thr_nspc(const int ithr, const int nthr, const data_t *diff_dst_base, const data_t *wei_base, const data_t *bia_base, data_t *diff_src_base, - const memory_tracking::grantor_t &scratchpad) const; + const memory_tracking::grantor_t &scratchpad, int MB, + const std::vector& post_ops_binary_rhs_arg_vec) const; const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } + + nstl::vector depthwise_injectors; }; struct gemm_convolution_bwd_weights_t : public primitive_t { diff --git a/src/cpu/gemm_convolution_utils.cpp b/src/cpu/gemm_convolution_utils.cpp index 6a397c971ca..b48907a4ccb 100644 --- a/src/cpu/gemm_convolution_utils.cpp +++ b/src/cpu/gemm_convolution_utils.cpp @@ -22,6 +22,10 @@ #include "common/type_helpers.hpp" #include "common/utils.hpp" #include "cpu/gemm_convolution_utils.hpp" + +#include "ref_eltwise.hpp" +#include "ref_depthwise_injector.hpp" + #if DNNL_X64 #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" #endif @@ -29,6 +33,7 @@ #include "cpu/platform.hpp" #if DNNL_X64 +#include "cpu/x64/jit_gemm_convolution_utils.hpp" #include "cpu/x64/cpu_isa_traits.hpp" #endif @@ -50,6 +55,152 @@ single_gemm_conv_chunk_desc_t::single_gemm_conv_chunk_desc_t(dim_t 
d_off, , w_off_(w_off) , w_size_(w_size) {} +namespace gemm_convolution_utils { + +struct ref_pp_kernel_t : pp_kernel_t { + ref_pp_kernel_t(const convolution_pd_t *pd, const conv_gemm_conf_t &jcp) + : pp_kernel_t(pd, jcp) { + for (int i = 0; i < post_ops_.len(); i++) { + auto &post_op = post_ops_.entry_[i]; + if (post_op.is_eltwise()) { + ref_eltwise_injectors_.push_back(new ref_eltwise_scalar_fwd_t(post_op.eltwise)); + } else if (post_op.is_depthwise()) { + ref_depthwise_injectors_.push_back(new ref_depthwise_scalar_fwd_t( + post_op.depthwise.alg)); + } + } + } + ~ref_pp_kernel_t() { + for (auto impl : ref_eltwise_injectors_) + delete impl; + ref_eltwise_injectors_.clear(); + for (auto impl : ref_depthwise_injectors_) + delete impl; + ref_depthwise_injectors_.clear(); + } + + virtual void operator()(float *dst, const float *bias, const int len, const int oc_start, const int oc_work, const int oc_stride, + const std::vector& post_ops_binary_rhs_arg_vec) const override; + +private: + nstl::vector ref_eltwise_injectors_; + nstl::vector ref_depthwise_injectors_; +}; + +void ref_pp_kernel_t::operator()(float *dst, const float *bias, const int len,const int oc_start, const int oc_work, const int oc_stride, + const std::vector& post_ops_binary_rhs_arg_vec) const { + // TODO: for "outer threading" we have parallel section within + // outermost "parallel". It is not good. Consider to use + // "parallel" here with number of threads passed as parameter + const auto &p = post_ops_; + bool need_bias = do_bias_; + if (p.len() > 0) { + std::size_t post_ops_data_idx = 0; + int eltwise_inj_idx = 0; + int depthwise_inj_idx = 0; + + for (int i = 0; i < p.len(); i++) { + auto &post_op = p.entry_[i]; + // todo: sum? + if (post_op.is_eltwise()) { + parallel_nd(oc_work, [&](const int oc) { + float b = need_bias ? bias[oc_start + oc] : 0; + float *d_ = dst + oc * oc_stride; + for (int oS = 0; oS < len; ++oS) { + d_[oS] += b; + d_[oS] = ref_eltwise_injectors_[eltwise_inj_idx]->compute_scalar(d_[oS]); + } + }); + + eltwise_inj_idx++; + need_bias = false; + } else if (post_op.is_depthwise()) { + auto depthwise_base = reinterpret_cast(post_ops_binary_rhs_arg_vec[post_ops_data_idx]); + auto depthwise_weights = depthwise_base + post_op.depthwise.offset[post_op.depthwise.scales]; + auto depthwise_bias = depthwise_base + post_op.depthwise.offset[post_op.depthwise.shifts]; + + parallel_nd(oc_work, [&](const int oc) { + float b = need_bias ? bias[oc_start + oc] : 0; + float *d_ = dst + oc * oc_stride; + for (int oS = 0; oS < len; ++oS) { + d_[oS] += b; + d_[oS] = ref_depthwise_injectors_[depthwise_inj_idx]->compute_scalar(d_[oS], + depthwise_weights + oc_start + oc, + depthwise_bias + oc_start + oc); + } + }); + + post_ops_data_idx++; + depthwise_inj_idx++; + need_bias = false; + } else if (post_op.is_quantization()) { + auto quant = post_op.quantization; + auto quantization_base = reinterpret_cast(post_ops_binary_rhs_arg_vec[post_ops_data_idx]); + auto pcl = quantization_base + post_op.quantization.offset[quant.crop_low]; + auto pch = quantization_base + post_op.quantization.offset[quant.crop_high]; + auto pisc = quantization_base + post_op.quantization.offset[quant.inp_scale]; + auto pish = quantization_base + post_op.quantization.offset[quant.inp_shift]; + auto posc = quantization_base + post_op.quantization.offset[quant.output_scale]; + auto posh = quantization_base + post_op.quantization.offset[quant.output_shift]; + + parallel_nd(oc_work, [&](const int oc) { + float b = need_bias ? 
bias[oc_start + oc] : 0; + float *d_ = dst + oc * oc_stride; + + int cl_idx = !quant.per_channel[quant.crop_low] ? 0 : oc_start + oc; + int ch_idx = !quant.per_channel[quant.crop_high] ? 0 : oc_start + oc; + int isc_idx = !quant.per_channel[quant.inp_scale] ? 0 : oc_start + oc; + int ish_idx = !quant.per_channel[quant.inp_shift] ? 0 : oc_start + oc; + int osc_idx = !quant.per_channel[quant.output_scale] ? 0 : oc_start + oc; + int osh_idx = !quant.per_channel[quant.output_shift] ? 0 : oc_start + oc; + + PRAGMA_OMP_SIMD() + for (int oS = 0; oS < len; ++oS) { + d_[oS] += b; + + d_[oS] = nstl::min(pch[ch_idx], nstl::max(pcl[cl_idx], d_[oS])); + d_[oS] = d_[oS] * pisc[isc_idx] + pish[ish_idx]; + d_[oS] = roundf(d_[oS]); + d_[oS] = d_[oS] * posc[osc_idx] + posh[osh_idx]; + } + }); + + post_ops_data_idx++; + need_bias = false; + } + } + } + + if (need_bias) { + parallel_nd(oc_work, [&](const int oc) { + float b = bias[oc_start + oc]; + float *d_ = dst + oc * oc_stride; + PRAGMA_OMP_SIMD() + for (int oS = 0; oS < len; ++oS) { + d_[oS] += b; + } + }); + } +} + +// Interface section + +pp_kernel_t::pp_kernel_t(const convolution_pd_t *pd, const conv_gemm_conf_t &jcp) + : do_bias_(pd->with_bias()), post_ops_(pd->attr()->post_ops_) {} + +pp_kernel_t *pp_kernel_t::create( + const convolution_pd_t *pd, const conv_gemm_conf_t &jcp) { +#if DNNL_X64 + auto *res + = x64::gemm_convolution_utils::jit_pp_kernel_create(pd, jcp); + if (res) return res; +#endif + + return new ref_pp_kernel_t(pd, jcp); +} + +} // namespace gemm_convolution_utils + namespace jit_gemm_convolution_utils { template @@ -276,7 +427,7 @@ template void transpose_dt(const conv_gemm_conf_t &jcp, /* col[kd][kh][kw][g][ic][od][oh][ow] <-- im2col_dt_3d(im[id][ih][iw][g][ic]) */ template void im2col_dt_3d(const conv_gemm_conf_t &jcp, const void *__restrict _imtr, - orig_col_dt *__restrict _col, dim_t od) { + orig_col_dt *__restrict _col, dim_t od, const uint8_t *__restrict input_zp) { // For performance reasons, use uint16_t as a proxy for bfloat16_t using im_dt = typename utils::conditional::data_type == bf16, @@ -306,15 +457,18 @@ void im2col_dt_3d(const conv_gemm_conf_t &jcp, const void *__restrict _imtr, const dim_t IHW = jcp.ih * jcp.iw; const dim_t OHW = jcp.oh * jcp.ow; + bool with_input_zp = input_zp != nullptr; + if (sd == 1 && sh == 1 && sw == 1 && dd == 1 && dh == 1 && dw == 1) - parallel_nd(jcp.kd, jcp.kh, jcp.kw, jcp.ic, + parallel_nd_legacy(jcp.kd, jcp.kh, jcp.kw, jcp.ic, [&](dim_t kd, dim_t kh, dim_t kw, dim_t ic) { col_dt *__restrict col_loc = col + kd * col_kd_s + kh * col_kh_s + kw * col_kw_s + ic * col_ic_s; const dim_t id = od - fp + kd; if (id < 0 || id >= jcp.id) { + col_dt izp = with_input_zp ? (col_dt)input_zp[ic] : shift; for (ptrdiff_t i = 0; i < OHW; i++) - col_loc[i] = shift; + col_loc[i] = izp; return; } const im_dt *__restrict imtr_loc @@ -336,14 +490,15 @@ void im2col_dt_3d(const conv_gemm_conf_t &jcp, const void *__restrict _imtr, } }); else if (sd == 2 && sh == 2 && sw == 2 && dd == 1 && dh == 1 && dw == 1) - parallel_nd(jcp.kd, jcp.kh, jcp.kw, jcp.ic, + parallel_nd_legacy(jcp.kd, jcp.kh, jcp.kw, jcp.ic, [&](dim_t kd, dim_t kh, dim_t kw, dim_t ic) { col_dt *__restrict col_loc = col + kd * col_kd_s + kh * col_kh_s + kw * col_kw_s + ic * col_ic_s; const dim_t id = od * 2 - fp + kd; if (id < 0 || id >= jcp.id) { + col_dt izp = with_input_zp ? 
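// Illustration, not part of the patch: the quantization post-op above computes,
// per element,
//   d = clamp(d, crop_low[c], crop_high[c]);
//   d = round(d * inp_scale[c] + inp_shift[c]);
//   d = d * output_scale[c] + output_shift[c];
// A self-contained scalar sketch with hypothetical per-tensor values
// (assumes <cmath> for roundf):
static inline float quantize_dequantize(float d) {
    const float crop_low = 0.f, crop_high = 6.f; // saturation bounds
    const float isc = 25.5f, ish = 0.f;          // input scale/shift
    const float osc = 1.f / 25.5f, osh = 0.f;    // output scale/shift
    d = d < crop_low ? crop_low : d > crop_high ? crop_high : d;
    d = roundf(d * isc + ish); // snap to the integer grid
    return d * osc + osh;      // e.g. 5.9f -> 150 -> ~5.882f
}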
(col_dt)input_zp[ic] : shift; for (ptrdiff_t i = 0; i < OHW; i++) - col_loc[i] = shift; + col_loc[i] = izp; return; } const im_dt *__restrict imtr_loc @@ -367,14 +522,15 @@ void im2col_dt_3d(const conv_gemm_conf_t &jcp, const void *__restrict _imtr, } }); else - parallel_nd(jcp.kd, jcp.kh, jcp.kw, jcp.ic, + parallel_nd_legacy(jcp.kd, jcp.kh, jcp.kw, jcp.ic, [&](dim_t kd, dim_t kh, dim_t kw, dim_t ic) { col_dt *__restrict col_loc = col + kd * col_kd_s + kh * col_kh_s + kw * col_kw_s + ic * col_ic_s; const dim_t id = od * sd - fp + kd * dd; if (id < 0 || id >= jcp.id) { + col_dt izp = with_input_zp ? (col_dt)input_zp[ic] : shift; for (ptrdiff_t i = 0; i < OHW; i++) - col_loc[i] = shift; + col_loc[i] = izp; return; } const im_dt *__restrict imtr_loc @@ -401,13 +557,13 @@ void im2col_dt_3d(const conv_gemm_conf_t &jcp, const void *__restrict _imtr, } template void im2col_dt_3d(const conv_gemm_conf_t &jcp, - const void *__restrict im, uint8_t *__restrict col, dim_t od); + const void *__restrict im, uint8_t *__restrict col, dim_t od, const uint8_t *__restrict input_zp); template void im2col_dt_3d(const conv_gemm_conf_t &jcp, - const void *__restrict im, uint8_t *__restrict col, dim_t od); + const void *__restrict im, uint8_t *__restrict col, dim_t od, const uint8_t *__restrict input_zp); template void im2col_dt_3d(const conv_gemm_conf_t &jcp, - const void *__restrict im, float *__restrict col, dim_t od); + const void *__restrict im, float *__restrict col, dim_t od, const uint8_t *__restrict input_zp); template void im2col_dt_3d(const conv_gemm_conf_t &jcp, - const void *__restrict im, bfloat16_t *__restrict col, dim_t od); + const void *__restrict im, bfloat16_t *__restrict col, dim_t od, const uint8_t *__restrict input_zp); /* col[ic][kh][kw][oh][ow] <-- im2col(im[ic][ih][iw]) */ template @@ -516,7 +672,7 @@ void im2col(const conv_gemm_conf_t &jcp, const data_type_t *__restrict im, // Generated code is more optimized for stride_w == 1 // because innermost loop is by width if (sw == 1) - parallel_nd(cb, jcp.kh, jcp.kw, oh_range, + parallel_nd_legacy(cb, jcp.kh, jcp.kw, oh_range, [&](dim_t ic, dim_t kh, dim_t kw, dim_t ohr) { const dim_t oh = ohr + oh_begin; const dim_t ih = oh * sh - tp + kh * dh; @@ -541,7 +697,7 @@ void im2col(const conv_gemm_conf_t &jcp, const data_type_t *__restrict im, } }); else - parallel_nd(cb, jcp.kh, jcp.kw, oh_range, + parallel_nd_legacy(cb, jcp.kh, jcp.kw, oh_range, [&](dim_t ic, dim_t kh, dim_t kw, dim_t ohr) { const dim_t oh = ohr + oh_begin; const dim_t ih = oh * sh - tp + kh * dh; @@ -580,7 +736,7 @@ template void im2col(const conv_gemm_conf_t &jcp, template void im2col_dt(const conv_gemm_conf_t &jcp, const void *__restrict _im, void *__restrict _imtr, orig_col_dt *__restrict _col, dim_t hs, - dim_t hb, dim_t ws, dim_t wb) { + dim_t hb, dim_t ws, dim_t wb, const uint8_t *__restrict input_zp) { // For performance reasons, use uint16_t as a proxy for bfloat16_t using im_dt = typename utils::conditional::data_type == bf16, @@ -603,6 +759,8 @@ void im2col_dt(const conv_gemm_conf_t &jcp, const void *__restrict _im, const dim_t tp = jcp.t_pad; const dim_t lp = jcp.l_pad; + bool with_input_zp = input_zp != nullptr; + if (jcp.outer_threading && sh == 1 && sw == 1 && dh == 1 && dw == 1) { /* im[ih][iw][ic] --> imtr[ic][ih][iw] --> col[kh][kw][ic][oh][ow] */ const dim_t hp = hs - tp; @@ -646,61 +804,103 @@ void im2col_dt(const conv_gemm_conf_t &jcp, const void *__restrict _im, const dim_t ow_start = saturate(dim_t(0), wb, ow_kw); const dim_t ow_end = saturate(dim_t(0), wb, 
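// Illustration, not part of the patch: the padding-fill rule introduced above,
// in isolation. Out-of-image im2col cells take the per-channel input zero point
// instead of the global u8->s8 shift, so the zero-point term cancels in the
// offset GEMM later. Names are hypothetical; shown for the uint8_t col type:
inline void fill_im2col_padding(uint8_t *col, ptrdiff_t n,
        const uint8_t *input_zp, dim_t ic, uint8_t shift) {
    const uint8_t v = input_zp ? input_zp[ic] // zero-point-aware padding
                               : shift;       // legacy signed-input shift (128)
    for (ptrdiff_t i = 0; i < n; ++i)
        col[i] = v;
}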
ow_kw + iwb); for (dim_t ic = 0; ic < jcp.ic; ic++) { + uint8_t izp = with_input_zp ? input_zp[ic] : (uint8_t) 0; const ptrdiff_t col_idx_ic = col_idx_kw + ic * col_ic_str; const dim_t imtr_idx_ic = ic * imtr_ic_stride - imtr_shift; for (dim_t oh = 0; oh < oh_start; oh++) { const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb; - for (dim_t ow = 0; ow < wb; ++ow) - col[col_idx_oh + ow] = shift; + if (with_input_zp) { + for (dim_t ow = 0; ow < wb; ++ow) + col[col_idx_oh + ow] = izp; + } else { + for (dim_t ow = 0; ow < wb; ++ow) + col[col_idx_oh + ow] = shift; + } } for (dim_t oh = oh_start; oh < oh_end; oh++) { const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb; const ptrdiff_t imtr_idx_oh = imtr_idx_ic + oh * iwb; - for (dim_t ow = 0; ow < ow_start; ++ow) - col[col_idx_oh + ow] = shift; - for (dim_t ow = ow_start; ow < ow_end; ++ow) - col[col_idx_oh + ow] - = imtr[imtr_idx_oh + ow] + shift; - for (dim_t ow = ow_end; ow < wb; ++ow) - col[col_idx_oh + ow] = shift; + if (with_input_zp) { + for (dim_t ow = 0; ow < ow_start; ++ow) + col[col_idx_oh + ow] = izp; + for (dim_t ow = ow_start; ow < ow_end; ++ow) + col[col_idx_oh + ow] + = imtr[imtr_idx_oh + ow]; + for (dim_t ow = ow_end; ow < wb; ++ow) + col[col_idx_oh + ow] = izp; + } else { + for (dim_t ow = 0; ow < ow_start; ++ow) + col[col_idx_oh + ow] = shift; + for (dim_t ow = ow_start; ow < ow_end; ++ow) + col[col_idx_oh + ow] + = imtr[imtr_idx_oh + ow] + shift; + for (dim_t ow = ow_end; ow < wb; ++ow) + col[col_idx_oh + ow] = shift; + } } for (dim_t oh = oh_end; oh < hb; oh++) { const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb; - for (dim_t ow = 0; ow < wb; ++ow) - col[col_idx_oh + ow] = shift; + if (with_input_zp) { + for (dim_t ow = 0; ow < wb; ++ow) + col[col_idx_oh + ow] = izp; + } else { + for (dim_t ow = 0; ow < wb; ++ow) + col[col_idx_oh + ow] = shift; + } } } } } } else { - parallel_nd(jcp.kh, jcp.kw, jcp.ic, hb, + parallel_nd_legacy(jcp.kh, jcp.kw, jcp.ic, hb, [&](dim_t kh, dim_t kw, dim_t ic, dim_t oh) { const dim_t hp = tp - kh * dh; const dim_t ih = (oh + hs) * sh - hp; const ptrdiff_t col_idx_base = (((kh * jcp.kw + kw) * jcp.ic + ic) * hb + oh) * wb; + uint8_t izp = with_input_zp ? 
input_zp[ic] : (uint8_t) 0; if (ih < 0 || ih >= jcp.ih) - for (dim_t ow = 0; ow < wb; ow++) - col[col_idx_base + ow] = shift; + if (with_input_zp) { + for (dim_t ow = 0; ow < wb; ow++) + col[col_idx_base + ow] = izp; + } else { + for (dim_t ow = 0; ow < wb; ow++) + col[col_idx_base + ow] = shift; + } else { const dim_t wp = lp - kw * dw; const dim_t ow_start = saturate(dim_t(0), wb, div_up(wp, sw) - ws); const dim_t ow_end = saturate( dim_t(0), wb, div_up(jcp.iw + wp, sw) - ws); - for (dim_t ow = 0; ow < ow_start; ow++) - col[col_idx_base + ow] = shift; - const dim_t iw_base = ws * sw - wp; - const ptrdiff_t im_idx_base = ih * im_ih_stride + ic; - for (dim_t ow = ow_start; ow < ow_end; ow++) { - const dim_t iw = iw_base + ow * sw; - const ptrdiff_t im_idx - = im_idx_base + iw * im_iw_stride; - col[col_idx_base + ow] = im[im_idx] + shift; + if (with_input_zp) { + for (dim_t ow = 0; ow < ow_start; ow++) + col[col_idx_base + ow] = izp; + const dim_t iw_base = ws * sw - wp; + const ptrdiff_t im_idx_base = ih * im_ih_stride + ic; + for (dim_t ow = ow_start; ow < ow_end; ow++) { + const dim_t iw = iw_base + ow * sw; + const ptrdiff_t im_idx + = im_idx_base + iw * im_iw_stride; + col[col_idx_base + ow] = im[im_idx]; + } + for (dim_t ow = ow_end; ow < wb; ow++) + col[col_idx_base + ow] = izp; + } else { + for (dim_t ow = 0; ow < ow_start; ow++) + col[col_idx_base + ow] = shift; + const dim_t iw_base = ws * sw - wp; + const ptrdiff_t im_idx_base = ih * im_ih_stride + ic; + for (dim_t ow = ow_start; ow < ow_end; ow++) { + const dim_t iw = iw_base + ow * sw; + const ptrdiff_t im_idx + = im_idx_base + iw * im_iw_stride; + col[col_idx_base + ow] = im[im_idx] + shift; + } + for (dim_t ow = ow_end; ow < wb; ow++) + col[col_idx_base + ow] = shift; } - for (dim_t ow = ow_end; ow < wb; ow++) - col[col_idx_base + ow] = shift; } }); } @@ -708,17 +908,17 @@ void im2col_dt(const conv_gemm_conf_t &jcp, const void *__restrict _im, template void im2col_dt(const conv_gemm_conf_t &jcp, const void *__restrict im, void *__restrict imtr, - uint8_t *__restrict col, dim_t hs, dim_t hb, dim_t ws, dim_t wb); + uint8_t *__restrict col, dim_t hs, dim_t hb, dim_t ws, dim_t wb, const uint8_t *__restrict input_zp); template void im2col_dt(const conv_gemm_conf_t &jcp, const void *__restrict im, void *__restrict imtr, - uint8_t *__restrict col, dim_t hs, dim_t hb, dim_t ws, dim_t wb); + uint8_t *__restrict col, dim_t hs, dim_t hb, dim_t ws, dim_t wb, const uint8_t *__restrict input_zp); template void im2col_dt(const conv_gemm_conf_t &jcp, const void *__restrict im, void *__restrict imtr, float *__restrict col, - dim_t hs, dim_t hb, dim_t ws, dim_t wb); + dim_t hs, dim_t hb, dim_t ws, dim_t wb, const uint8_t *__restrict input_zp); template void im2col_dt(const conv_gemm_conf_t &jcp, const void *__restrict im, void *__restrict imtr, - bfloat16_t *__restrict col, dim_t hs, dim_t hb, dim_t ws, dim_t wb); + bfloat16_t *__restrict col, dim_t hs, dim_t hb, dim_t ws, dim_t wb, const uint8_t *__restrict input_zp); /* im[id][ih][iw][ic] <-- col2im_dt_3d(col[od][oh][ow][kd][kh][kw][ic]) */ template @@ -1084,16 +1284,14 @@ status_t init_conf(conv_gemm_conf_t &jcp, CHECK(memory_desc_init_by_tag(src_md, desired_src_tag)); src_tag = desired_src_tag; } else { - src_tag = memory_desc_matches_one_of_tag( - src_md, nwc, nhwc, ndhwc, ncw, nchw, ncdhw); + src_tag = src_d.mb_stride_relaxed_match(nwc, nhwc, ndhwc, ncw, nchw, ncdhw); } if (dst_d.format_kind() == format_kind::any) { CHECK(memory_desc_init_by_tag(dst_md, desired_dst_tag)); dst_tag = 
desired_dst_tag; } else { - dst_tag = memory_desc_matches_one_of_tag( - dst_md, nwc, nhwc, ndhwc, ncw, nchw, ncdhw); + dst_tag = dst_d.mb_stride_relaxed_match(nwc, nhwc, ndhwc, ncw, nchw, ncdhw); } if (src_tag == format_tag::undef || dst_tag == format_tag::undef) @@ -1138,6 +1336,21 @@ status_t init_conf(conv_gemm_conf_t &jcp, const bool is_bwd_w = jcp.prop_kind == backward_weights; const bool is_fwd = !is_bwd_d && !is_bwd_w; + jcp.with_input_zp = !attr.input_zero_points_.has_default_values(); + if (jcp.with_input_zp) { + if (attr.input_zero_points_.count_ != 1 && attr.input_zero_points_.count_ != jcp.ic * jcp.ngroups) + return status::unimplemented; + + if (attr.output_compensations_.count_ != jcp.oc * jcp.ngroups) + return status::unimplemented; + } + + jcp.with_weights_zp = !attr.weights_zero_points_.has_default_values(); + if (jcp.with_weights_zp) { + if (attr.weights_zero_points_.count_ != 1 && attr.weights_zero_points_.count_ != jcp.oc * jcp.ngroups) + return status::unimplemented; + } + bool is_int8_conv = (is_fwd ? utils::one_of(src_d.data_type(), s8, u8) : utils::one_of(dst_d.data_type(), s8, u8)) && weights_d.data_type() == s8; @@ -1164,6 +1377,8 @@ status_t init_conf(conv_gemm_conf_t &jcp, jcp.with_binary = binary_ind != -1; const int sum_ind = jcp.post_ops.find(primitive_kind::sum); jcp.with_sum = sum_ind != -1; + const int depthwise_ind = jcp.post_ops.find(primitive_kind::depthwise); + jcp.with_depthwise = depthwise_ind != -1; bool is_bf16_conv = false || (is_fwd @@ -2114,7 +2329,8 @@ status_t init_conf(conv_gemm_conf_t &jcp, if (size) scratchpad.book(key_conv_gemm_zp_src_comp, size); } - if (scratchpad.size() > scratchpad_limit) return status::unimplemented; + // [WA] Disabled condition to prevent fallback on ref convolution implementation +// if (scratchpad.size() > scratchpad_limit) return status::unimplemented; return status::success; } diff --git a/src/cpu/gemm_convolution_utils.hpp b/src/cpu/gemm_convolution_utils.hpp index 3e54deacf1a..424f377423a 100644 --- a/src/cpu/gemm_convolution_utils.hpp +++ b/src/cpu/gemm_convolution_utils.hpp @@ -43,6 +43,7 @@ struct conv_gemm_conf_t { bool with_bias; bool with_eltwise; bool with_binary; + bool with_depthwise; bool with_sum; post_ops_t post_ops; bool is_nspc; @@ -68,6 +69,9 @@ struct conv_gemm_conf_t { data_type_t sum_data_type; size_t dst_os_stride; size_t scale_idx_mult; + + bool with_input_zp; + bool with_weights_zp; }; struct single_gemm_conv_chunk_desc_t { @@ -83,6 +87,28 @@ struct single_gemm_conv_chunk_desc_t { dim_t w_size_ = 0; }; +namespace gemm_convolution_utils { + +struct pp_kernel_t { + static pp_kernel_t *create( + const convolution_pd_t *pd, const conv_gemm_conf_t &jcp); + + virtual ~pp_kernel_t() = default; + + virtual void operator()(float *dst, const float *bias, const int len, const int oc_start, const int oc_work, const int oc_stride, + const std::vector& post_ops_binary_rhs_arg_vec) const = 0; + + virtual status_t create_kernel() { return status::success; } + +protected: + pp_kernel_t(const convolution_pd_t *pd, const conv_gemm_conf_t &jcp); + + bool do_bias_ = false; + post_ops_t post_ops_; +}; + +} // namespace gemm_convolution_utils + namespace jit_gemm_convolution_utils { template void im2col_3d(const conv_gemm_conf_t &jcp, const data_type_t *im, @@ -94,7 +120,7 @@ void transpose_dt(const conv_gemm_conf_t &jcp, const T *__restrict im, template void im2col_dt_3d(const conv_gemm_conf_t &jcp, const void *__restrict im, - col_dt *__restrict col, dim_t od); + col_dt *__restrict col, dim_t od, const uint8_t 
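// Illustration, not part of the patch: the validation rule added to init_conf()
// above, assuming count_ == 1 means a per-tensor value and ic * ngroups means
// per-channel values:
static inline bool input_zp_counts_ok(
        dim_t zp_count, dim_t comp_count, dim_t ic, dim_t oc, dim_t ngroups) {
    const bool zp_ok = zp_count == 1 || zp_count == ic * ngroups;
    const bool comp_ok = comp_count == oc * ngroups; // always per output channel
    return zp_ok && comp_ok;
}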
*__restrict input_zp = nullptr); template void im2col(const conv_gemm_conf_t &jcp, const data_type_t *__restrict im, @@ -103,7 +129,7 @@ void im2col(const conv_gemm_conf_t &jcp, const data_type_t *__restrict im, template void im2col_dt(const conv_gemm_conf_t &jcp, const void *__restrict im, void *__restrict imtr, col_dt *__restrict col, dim_t hs, dim_t hb, - dim_t ws, dim_t wb); + dim_t ws, dim_t wb, const uint8_t *__restrict input_zp = nullptr); template void col2im_dt( diff --git a/src/cpu/gemm_inner_product.cpp b/src/cpu/gemm_inner_product.cpp index 8a073e0d1bd..4052229b7a0 100644 --- a/src/cpu/gemm_inner_product.cpp +++ b/src/cpu/gemm_inner_product.cpp @@ -41,7 +41,8 @@ status_t gemm_inner_product_fwd_t::execute_forward( = binary_injector_utils::prepare_binary_args( this->pd()->attr()->post_ops_, ctx); - const dim_t MB = pd()->MB(); + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + const dim_t OC = pd()->OC(); const dim_t IC = pd()->IC_total_padded(); @@ -86,7 +87,8 @@ status_t gemm_inner_product_bwd_data_t::execute_backward_data( auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS); auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC); - const dim_t MB = pd()->MB(); + auto MB = CTX_IN_BATCH(DNNL_ARG_DIFF_DST); + const dim_t OC = pd()->OC(); const dim_t IC = pd()->IC_total_padded(); diff --git a/src/cpu/gemm_x8s8s32x_convolution.cpp b/src/cpu/gemm_x8s8s32x_convolution.cpp index 861e7e38215..21ab1a7a11a 100644 --- a/src/cpu/gemm_x8s8s32x_convolution.cpp +++ b/src/cpu/gemm_x8s8s32x_convolution.cpp @@ -120,6 +120,11 @@ status_t gemm_x8s8s32x_convolution_fwd_t::execute_forward( = binary_injector_utils::prepare_binary_args( this->pd()->attr()->post_ops_, ctx); + DEFINE_INPUT_ZERO_POINTS_BUFFER(input_zp_base, jcp); + DEFINE_OUTPUT_COMPENSATION_BUFFER(output_compensation_base, jcp); + + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + auto scratchpad = ctx.get_scratchpad_grantor(); assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1)); @@ -133,7 +138,8 @@ status_t gemm_x8s8s32x_convolution_fwd_t::execute_forward( parallel(jcp.nthr, [&](const int ithr, const int nthr) { status_t st_thr = execute_forward_thr(ithr, nthr, src_base, wei_base, bia_base, dst_base, zp, scratchpad, - post_ops_binary_rhs_arg_vec.data(), ctx); + post_ops_binary_rhs_arg_vec.data(), ctx, MB, + input_zp_base, output_compensation_base); if (st_thr != status::success) st = st_thr; }); @@ -153,7 +159,8 @@ status_t gemm_x8s8s32x_convolution_fwd_t::execute_forward_thr(const int ithr, const char *bia_base, void *dst_base, const zero_point_call_params_t &zp, const memory_tracking::grantor_t &scratchpad, - const void *post_ops_binary_rhs_arg_vec, const exec_ctx_t &ctx) const { + const void *post_ops_binary_rhs_arg_vec, const exec_ctx_t &ctx, int MB, + const uint8_t *input_zp_base, const int32_t *output_compensation_base) const { const conv_gemm_conf_t &jcp = this->pd()->jcp_; @@ -182,16 +189,11 @@ status_t gemm_x8s8s32x_convolution_fwd_t::execute_forward_thr(const int ithr, + (ptrdiff_t)ithr * jcp.oh_block * jcp.ow_block * jcp.oc; const int32_t *_wei_comp - = jcp.signed_input ? get_wei_comp(wei_base, wei_md) : nullptr; + = jcp.signed_input ? get_wei_comp(wei_base, wei_md) : + jcp.with_input_zp ? 
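// Note (assumption, not the verbatim definition): CTX_IN_BATCH, used throughout
// this patch, reads the effective batch from the runtime argument instead of
// pd()->MB(), which is what lets a primitive built for one batch size execute
// on a smaller one. A plausible reading of the macro:
//   #define CTX_IN_BATCH(arg) \
//       (ctx.input(arg) ? ctx.input(arg)->md()->dims[0] : 0)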
output_compensation_base : nullptr; - const bool should_apply_zp_src_comp_pad = jcp.zp.src_exists - && jit_gemm_convolution_utils::padding_exists(jcp); - const bool should_apply_zp_src_comp_pad_jit_pp - = should_apply_zp_src_comp_pad - && gemm_x8s8s32x_convolution_utils::mayiuse_jit_pp_kernel(); - const bool should_apply_zp_src_comp_outside_pp - = should_apply_zp_src_comp_pad - && !gemm_x8s8s32x_convolution_utils::mayiuse_jit_pp_kernel(); + const bool should_apply_zp_src_comp_pad_jit_pp = false; + const bool should_apply_zp_src_comp_outside_pp = false; dim_t g {0}, n {0}, ohb {0}, owb {0}; dim_t start = 0, end = 0; @@ -203,11 +205,11 @@ status_t gemm_x8s8s32x_convolution_fwd_t::execute_forward_thr(const int ithr, const dim_t nb_oh = div_up(jcp.oh, jcp.oh_block); const dim_t nb_ow = div_up(jcp.ow, jcp.ow_block); - const dim_t work_amount = jcp.ngroups * jcp.mb * nb_oh * nb_ow; + const dim_t work_amount = jcp.ngroups * MB * nb_oh * nb_ow; balance211(work_amount, nthr, ithr, start, end); - nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow); + nd_iterator_init(start, n, MB, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow); const uint8_t shift = jcp.signed_input ? 128 : 0; - parallel_nd(jcp.im2col_sz, [&](ptrdiff_t i) { col[i] = shift; }); + parallel_nd_legacy(jcp.im2col_sz, [&](ptrdiff_t i) { col[i] = shift; }); status_t st = status::success; @@ -227,6 +229,11 @@ status_t gemm_x8s8s32x_convolution_fwd_t::execute_forward_thr(const int ithr, for (int od = 0; od < jcp.od; od++) { const auto dst_off = n * dst_mb_stride + g * dst_g_stride + ((od * jcp.oh + oh) * jcp.ow + ow) * jcp.dst_os_stride; + + const uint8_t *__restrict input_zp = nullptr; + if (jcp.with_input_zp) + input_zp = input_zp_base + g * jcp.ic; + char *__restrict dst = (char *)dst_base + types::data_type_size(dst_md.data_type()) * dst_off; if (jcp.im2col_sz) { @@ -234,20 +241,20 @@ status_t gemm_x8s8s32x_convolution_fwd_t::execute_forward_thr(const int ithr, case data_type::s8: { if (is_problem_3d) jit_gemm_convolution_utils::im2col_dt_3d(jcp, imtr, col, od); + uint8_t>(jcp, imtr, col, od, input_zp); else jit_gemm_convolution_utils::im2col_dt(jcp, src, imtr, col, oh, h_step, - ow, w_step); + ow, w_step, input_zp); } break; case data_type::u8: { if (is_problem_3d) jit_gemm_convolution_utils::im2col_dt_3d(jcp, imtr, col, od); + uint8_t>(jcp, imtr, col, od, input_zp); else jit_gemm_convolution_utils::im2col_dt(jcp, src, imtr, col, oh, h_step, - ow, w_step); + ow, w_step, input_zp); } break; default: assert(!"unsupported data type"); break; } @@ -265,10 +272,10 @@ status_t gemm_x8s8s32x_convolution_fwd_t::execute_forward_thr(const int ithr, const float onef = 1.f, zerof = 0.f; const char *__restrict src_od = src + od * jcp.oh * jcp.ow * jcp.ngroups * jcp.ic; - st = gemm_s8x8s32("N", BT, jcp.signed_input ? "C" : "F", &M, &N, &K, + st = gemm_s8x8s32("N", BT, (jcp.signed_input || jcp.with_input_zp) ? "C" : "F", &M, &N, &K, &onef, wei, &LDA, &off_a, jcp.im2col_sz ? col : (uint8_t *)src_od, &LDB, &off_b, - &zerof, acc, &M, jcp.signed_input ? wei_comp : &off_c); + &zerof, acc, &M, (jcp.signed_input || jcp.with_input_zp) ? 
wei_comp : &off_c); if (st != status::success) return st; @@ -297,7 +304,7 @@ status_t gemm_x8s8s32x_convolution_fwd_t::execute_forward_thr(const int ithr, *pd()->dst_md(), chunk_desc); }); } - nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow); + nd_iterator_step(n, MB, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow); } return st; @@ -310,6 +317,8 @@ status_t gemm_x8s8s32x_convolution_bwd_data_t::execute_backward_data( auto bia_base = CTX_IN_MEM(const char *, DNNL_ARG_BIAS); auto diff_src_base = CTX_OUT_MEM(char *, DNNL_ARG_DIFF_SRC); + auto MB = CTX_IN_BATCH(DNNL_ARG_DIFF_DST); + auto scratchpad = ctx.get_scratchpad_grantor(); const conv_gemm_conf_t &jcp = this->pd()->jcp_; @@ -318,7 +327,7 @@ status_t gemm_x8s8s32x_convolution_bwd_data_t::execute_backward_data( parallel(jcp.nthr, [&](const int ithr, const int nthr) { status_t st_thr = execute_backward_data_thr(ithr, nthr, diff_dst_base, - wei_base, bia_base, diff_src_base, scratchpad); + wei_base, bia_base, diff_src_base, scratchpad, MB); if (st_thr != status::success) st = st_thr; }); @@ -329,7 +338,7 @@ status_t gemm_x8s8s32x_convolution_bwd_data_t::execute_backward_data( status_t gemm_x8s8s32x_convolution_bwd_data_t::execute_backward_data_thr( const int ithr, const int nthr, const char *diff_dst_base, const int8_t *wei_base, const char *bia_base, char *diff_src_base, - const memory_tracking::grantor_t &scratchpad) const { + const memory_tracking::grantor_t &scratchpad, int MB) const { const conv_gemm_conf_t &jcp = this->pd()->jcp_; const auto diff_dst_md = memory_desc_wrapper(pd()->diff_dst_md()); @@ -350,7 +359,7 @@ status_t gemm_x8s8s32x_convolution_bwd_data_t::execute_backward_data_thr( /* scale_idx_mult = 1 for per_oc scales and 0, otherwise */ const int scale_idx_mult = pd()->attr()->output_scales_.mask_ == (1 << 1); const float *__restrict scales = pd()->attr()->output_scales_.scales_; - const dim_t work_amount = jcp.ngroups * jcp.mb; + const dim_t work_amount = jcp.ngroups * MB; int *__restrict col = scratchpad.get(key_conv_gemm_col) + (ptrdiff_t)ithr * jcp.im2col_sz; @@ -361,7 +370,7 @@ status_t gemm_x8s8s32x_convolution_bwd_data_t::execute_backward_data_thr( dim_t start = 0, end = 0; balance211(work_amount, nthr, ithr, start, end); - nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups); + nd_iterator_init(start, n, MB, g, jcp.ngroups); for (dim_t iwork = start; iwork < end; ++iwork) { const int8_t *__restrict wei = wei_base + g * wei_g_stride; @@ -424,7 +433,7 @@ status_t gemm_x8s8s32x_convolution_bwd_data_t::execute_backward_data_thr( diff_src_md.data_type(), d, diff_src_loc, ic); } }); - nd_iterator_step(n, jcp.mb, g, jcp.ngroups); + nd_iterator_step(n, MB, g, jcp.ngroups); } return status::success; diff --git a/src/cpu/gemm_x8s8s32x_convolution.hpp b/src/cpu/gemm_x8s8s32x_convolution.hpp index 42daae12bce..b8face77b73 100644 --- a/src/cpu/gemm_x8s8s32x_convolution.hpp +++ b/src/cpu/gemm_x8s8s32x_convolution.hpp @@ -66,19 +66,23 @@ struct gemm_x8s8s32x_convolution_fwd_t : public primitive_t { && attr()->has_default_values(skip_mask_t::oscale | skip_mask_t::zero_points_runtime | skip_mask_t::post_ops - | skip_mask_t::sum_dt, + | skip_mask_t::sum_dt + | primitive_attr_t::skip_mask_t::input_zero_points + | primitive_attr_t::skip_mask_t::output_compensations + | primitive_attr_t::skip_mask_t::sum_dt, dst_type) - && attr()->post_ops_.check_sum_consistent_dt(dst_type) - && output_scales_mask_ok() && zero_points_valid(attr()); +// && attr()->post_ops_.check_sum_consistent_dt(dst_type) + && output_scales_mask_ok() && 
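// Why the offset-C GEMM above implements input zero points (a per-tensor zp is
// shown for simplicity): with per-output-channel compensation
//   comp[oc] = -zp * sum_k W[oc][k]
// precomputed on the weights,
//   acc[oc][os] = sum_k W[oc][k] * src[k][os] + comp[oc]
//               = sum_k W[oc][k] * (src[k][os] - zp),
// so the GEMM runs on the raw u8 input yet matches the zero-point-corrected
// result; the "C" offset mode adds comp[oc] to every element of row oc.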
zero_points_valid(attr()) + && post_ops_ok(); if (!ok) return status::unimplemented; auto scratchpad = scratchpad_registry().registrar(); CHECK(jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, *desc(), src_md_, weights_md_, dst_md_, bias_md_, attr_, dnnl_get_max_threads())); - if (!gemm_x8s8s32x_convolution_utils::post_ops_ok( - attr()->post_ops_, &dst_md_)) - return status::unimplemented; +// if (!gemm_x8s8s32x_convolution_utils::post_ops_ok( +// attr()->post_ops_, &dst_md_)) +// return status::unimplemented; return status::success; } @@ -89,6 +93,22 @@ struct gemm_x8s8s32x_convolution_fwd_t : public primitive_t { const auto &mask = attr()->output_scales_.mask_; return mask == 0 || mask == 1 << 1; } + + bool post_ops_ok() const { + using namespace dnnl::impl::primitive_kind; + auto const &po = attr()->post_ops_; + + auto all_post_ops_supported = [&]() { + bool ok = true; + + for (int i = 0; i < po.len(); i++) { + ok = ok && utils::one_of(po.entry_[i].kind, sum, eltwise, depthwise, quantization); + } + return ok; + }; + + return all_post_ops_supported(); + } }; gemm_x8s8s32x_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {} @@ -110,7 +130,8 @@ struct gemm_x8s8s32x_convolution_fwd_t : public primitive_t { void *dst_base, const zero_point_call_params_t &zp, const memory_tracking::grantor_t &scratchpad, const void *post_ops_binary_rhs_arg_vec, - const exec_ctx_t &ctx) const; + const exec_ctx_t &ctx, int MB, + const uint8_t *input_zp_base, const int32_t *output_compensation_base) const; using pp_ker_t = gemm_x8s8s32x_convolution_utils::pp_ker_t; std::unique_ptr pp_ker_; @@ -173,7 +194,7 @@ struct gemm_x8s8s32x_convolution_bwd_data_t : public primitive_t { status_t execute_backward_data_thr(const int ithr, const int nthr, const char *diff_dst_base, const int8_t *wei_base, const char *bia_base, char *diff_src_base, - const memory_tracking::grantor_t &scratchpad) const; + const memory_tracking::grantor_t &scratchpad, int MB) const; const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } }; diff --git a/src/cpu/gemm_x8s8s32x_convolution_utils.cpp b/src/cpu/gemm_x8s8s32x_convolution_utils.cpp index 9dbf0f00336..8cee3d639ab 100644 --- a/src/cpu/gemm_x8s8s32x_convolution_utils.cpp +++ b/src/cpu/gemm_x8s8s32x_convolution_utils.cpp @@ -40,14 +40,28 @@ template struct ref_pp_ker_t : pp_ker_t { ref_pp_ker_t(const convolution_pd_t *pd, const conv_gemm_conf_t &jcp) : pp_ker_t(pd, jcp) { - if (jcp.with_eltwise || jcp.with_binary) { - ref_post_ops_.reset(new ref_post_ops_t(jcp.post_ops)); + for (int i = 0; i < post_ops_.len(); i++) { + auto &post_op = post_ops_.entry_[i]; + if (post_op.is_eltwise()) { + ref_eltwise_injectors_.push_back(new ref_eltwise_scalar_fwd_t(post_op.eltwise)); + } else if (post_op.is_depthwise()) { + ref_depthwise_injectors_.push_back(new ref_depthwise_scalar_fwd_t( + post_op.depthwise.alg)); + } } } + ~ref_pp_ker_t() { + for (auto impl : ref_eltwise_injectors_) + delete impl; + ref_eltwise_injectors_.clear(); + for (auto impl : ref_depthwise_injectors_) + delete impl; + ref_depthwise_injectors_.clear(); + } using acc_data_t = pp_ker_t::acc_data_t; - void operator()(void *dst, const acc_data_t *acc, const char *bias, + void operator()(void *dst, acc_data_t *acc, const char *bias, const float *scales, float sum_scale, float signed_scale, int g, size_t start, size_t end, const zero_point_call_params_t &zp, const void *post_ops_binary_rhs_arg_vec, const void *dst_orig, @@ -55,72 +69,173 @@ struct ref_pp_ker_t : pp_ker_t { const single_gemm_conv_chunk_desc_t 
&chunk_desc) const override; private: - std::unique_ptr ref_post_ops_; + nstl::vector ref_eltwise_injectors_; + nstl::vector ref_depthwise_injectors_; }; template -void ref_pp_ker_t::operator()(void *void_dst, const acc_data_t *acc, - const char *bias, const float *scales, float sum_scale, +void ref_pp_ker_t::operator()(void *void_dst, acc_data_t *acc, const char *bias, const float *scales, float sum_scale, float signed_scale, int g, size_t start, size_t end, const zero_point_call_params_t &zp, - const void * /* post_ops_binary_rhs_arg_vec */, + const void * post_ops_binary_rhs_arg_vec, const void * /* dst_orig */, const exec_ctx_t &ctx, const memory_desc_t &dst_md, const single_gemm_conv_chunk_desc_t &chunk_desc) const { if (end <= start) return; - assert(data_traits::data_type == jcp_.dst_data_type); - - const lldiv_t dv_start = std::div((long long)start, (long long)jcp_.oc); - const lldiv_t dv_end = std::div((long long)(end - 1), (long long)jcp_.oc); - const size_t first_oc = dv_start.rem; - const size_t last_oc = dv_end.rem; - const size_t first_os = dv_start.quot; - const size_t last_os = dv_end.quot; - const int32_t zp_dst_val = jcp_.zp.dst_exists ? *(zp.dst) : 0; + assert(data_traits::data_type == dst_data_type_); + dst_data_t *dst = (dst_data_t *)void_dst; - ref_post_ops_t::args_t args; - args.ctx = &ctx; - args.dst_md = &dst_md; + const size_t first_oc = start % OC_; + const size_t last_oc = (end - 1) % OC_; + const size_t first_os = start / OC_; + const size_t last_os = (end - 1) / OC_; + if (post_ops_.len() == 0) { + for (size_t os = first_os; os <= last_os; os++) { + const size_t start_oc = (os == first_os) ? first_oc : 0; + const size_t end_oc = (os == last_os) ? last_oc : OC_ - 1; + for (size_t oc = start_oc; oc <= end_oc; oc++) { + const size_t acc_off = os * jcp_.oc + oc; + const size_t dst_off = os * dst_os_stride_ + oc; - for (size_t os = first_os; os <= last_os; os++) { - const size_t start_oc = (os == first_os) ? first_oc : 0; - const size_t end_oc = (os == last_os) ? 
last_oc : jcp_.oc - 1; - for (size_t oc = start_oc; oc <= end_oc; oc++) { - const size_t acc_off = os * jcp_.oc + oc; - const size_t dst_off = os * jcp_.dst_os_stride + oc; + float d = (float) (acc[acc_off]); + if (jcp_.signed_input) d *= signed_scale; - int32_t data_s32 = acc[acc_off]; + if (do_bias_) + d += math::get_bias(bias, g * jcp_.oc + oc, bias_data_type_); - if (jcp_.zp.src_exists) { - const auto oc_offset = g * jcp_.oc + oc; - data_s32 += zp.src_comp[oc_offset]; + d *= scales[(g * jcp_.oc + oc) * scale_idx_mult_]; + dst[dst_off] = qz_a1b0()(d); } + } + } else { + float* acc_fp = reinterpret_cast(acc); - float data = static_cast(data_s32); + auto load = [&](int idx, size_t oc, size_t os, size_t acc_off, size_t dst_off) { + float d; + if (idx == 0) { + d = (float) (acc[acc_off]); - if (jcp_.signed_input) data *= signed_scale; + if (jcp_.signed_input) + d *= signed_scale; - if (jcp_.with_bias) { - const float b = io::load_float_value( - jcp_.bias_data_type, bias, g * jcp_.oc + oc); - data += b; - } + if (do_bias_) + d += math::get_bias(bias, g * jcp_.oc + oc, + bias_data_type_); - data *= scales[(g * jcp_.oc + oc) * jcp_.scale_idx_mult]; - if (jcp_.with_sum) - data += sum_scale - * io::load_float_value( - jcp_.sum_data_type, void_dst, dst_off); - if (jcp_.with_eltwise || jcp_.with_binary) { - args.l_offset = (g * jcp_.oc + oc) * jcp_.os; - ref_post_ops_->execute(data, args); + d *= scales[(g * jcp_.oc + oc) * scale_idx_mult_]; + } else { + d = acc_fp[acc_off]; } - if (jcp_.zp.dst_exists) data += zp_dst_val; + return d; + }; + + auto store = [&](int idx, float d, size_t acc_off, size_t dst_off) { + if (idx == post_ops_.len() - 1) + dst[dst_off] = qz_a1b0()(d); + else + acc_fp[acc_off] = d; + }; + + auto post_ops_data_ptrs = reinterpret_cast(post_ops_binary_rhs_arg_vec); + std::size_t post_ops_data_idx = 0; + int eltwise_inj_idx = 0; + int depthwise_inj_idx = 0; + for (int i = 0; i < post_ops_.len(); i++) { + auto &post_op = post_ops_.entry_[i]; + if (post_op.is_eltwise()) { + for (size_t os = first_os; os <= last_os; os++) { + const size_t start_oc = (os == first_os) ? first_oc : 0; + const size_t end_oc = (os == last_os) ? last_oc : OC_ - 1; + for (size_t oc = start_oc; oc <= end_oc; oc++) { + const size_t acc_off = os * jcp_.oc + oc; + const size_t dst_off = os * this->dst_os_stride_ + oc; + + float d = load(i, oc, os, acc_off, dst_off); + + d = ref_eltwise_injectors_[eltwise_inj_idx]->compute_scalar(d); + + store(i, d, acc_off, dst_off); + } + } + eltwise_inj_idx++; + } else if (post_op.is_depthwise()) { + for (size_t os = first_os; os <= last_os; os++) { + const size_t start_oc = (os == first_os) ? first_oc : 0; + const size_t end_oc = (os == last_os) ? 
last_oc : OC_ - 1; + for (size_t oc = start_oc; oc <= end_oc; oc++) { + const size_t acc_off = os * jcp_.oc + oc; + const size_t dst_off = os * this->dst_os_stride_ + oc; + + auto depthwise_base = post_ops_data_ptrs[post_ops_data_idx]; + auto depthwise_weights = depthwise_base + post_op.depthwise.offset[post_op.depthwise.scales]; + auto depthwise_bias = depthwise_base + post_op.depthwise.offset[post_op.depthwise.shifts]; + + float d = load(i, oc, os, acc_off, dst_off); + + d = ref_depthwise_injectors_[depthwise_inj_idx]->compute_scalar(d, depthwise_weights + g * jcp_.oc + oc, + depthwise_bias + g * jcp_.oc + oc); + + store(i, d, acc_off, dst_off); + + } + } + post_ops_data_idx++; + depthwise_inj_idx++; + } else if (post_op.is_quantization()) { + for (size_t os = first_os; os <= last_os; os++) { + const size_t start_oc = (os == first_os) ? first_oc : 0; + const size_t end_oc = (os == last_os) ? last_oc : OC_ - 1; + for (size_t oc = start_oc; oc <= end_oc; oc++) { + const size_t acc_off = os * jcp_.oc + oc; + const size_t dst_off = os * this->dst_os_stride_ + oc; + + auto quant = post_op.quantization; + auto quantization_base = post_ops_data_ptrs[post_ops_data_idx]; + auto pcl = quantization_base + post_op.quantization.offset[quant.crop_low]; + auto pch = quantization_base + post_op.quantization.offset[quant.crop_high]; + auto pisc = quantization_base + post_op.quantization.offset[quant.inp_scale]; + auto pish = quantization_base + post_op.quantization.offset[quant.inp_shift]; + auto posc = quantization_base + post_op.quantization.offset[quant.output_scale]; + auto posh = quantization_base + post_op.quantization.offset[quant.output_shift]; + + float d = load(i, oc, os, acc_off, dst_off); + + int cl_idx = !quant.per_channel[quant.crop_low] ? 0 : g * jcp_.oc + oc; + int ch_idx = !quant.per_channel[quant.crop_high] ? 0 : g * jcp_.oc + oc; + int isc_idx = !quant.per_channel[quant.inp_scale] ? 0 : g * jcp_.oc + oc; + int ish_idx = !quant.per_channel[quant.inp_shift] ? 0 : g * jcp_.oc + oc; + int osc_idx = !quant.per_channel[quant.output_scale] ? 0 : g * jcp_.oc + oc; + int osh_idx = !quant.per_channel[quant.output_shift] ? 0 : g * jcp_.oc + oc; + + d = nstl::min(pch[ch_idx], nstl::max(pcl[cl_idx], d)); + d = d * pisc[isc_idx] + pish[ish_idx]; + d = roundf(d); + d = d * posc[osc_idx] + posh[osh_idx]; + + store(i, d, acc_off, dst_off); - io::store_float_value(jcp_.dst_data_type, data, void_dst, dst_off); + } + } + post_ops_data_idx++; + } else if (post_op.is_sum()) { + for (size_t os = first_os; os <= last_os; os++) { + const size_t start_oc = (os == first_os) ? first_oc : 0; + const size_t end_oc = (os == last_os) ? 
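// The load/store lambda pair above chains post-ops through the s32 accumulator
// reinterpreted as f32 scratch: the first op dequantizes from acc (signed
// scale, bias, and output scale applied once), intermediate results round-trip
// through acc_fp, and only the final op quantizes into dst. In sketch form:
//   d = (i == 0) ? dequantize(acc[acc_off]) : acc_fp[acc_off];
//   ... apply post-op i to d ...
//   if (i == last) dst[dst_off] = quantize(d); else acc_fp[acc_off] = d;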
last_oc : OC_ - 1; + for (size_t oc = start_oc; oc <= end_oc; oc++) { + const size_t acc_off = os * jcp_.oc + oc; + const size_t dst_off = os * this->dst_os_stride_ + oc; + + float d = load(i, oc, os, acc_off, dst_off); + + d += post_op.sum.scale * math::get_sum((char *) dst, dst_off, post_op.sum.dt); + + store(i, d, acc_off, dst_off); + } + } + } } } } @@ -128,7 +243,26 @@ void ref_pp_ker_t::operator()(void *void_dst, const acc_data_t *acc, // Interface section pp_ker_t::pp_ker_t(const convolution_pd_t *pd, const conv_gemm_conf_t &jcp) - : jcp_(jcp) {} + : jcp_(jcp) + , post_ops_(pd->attr()->post_ops_) + , OC_(jcp_.oc) +{ + const auto dst_md = memory_desc_wrapper(pd->dst_md()); + + dst_os_stride_ = dst_md.blocking_desc().strides[pd->ndims() - 1]; + dst_data_type_ = dst_md.data_type(); + + do_scale_ = !pd->attr()->output_scales_.has_default_values(); + if (do_scale_) { + scale_idx_mult_ = (pd->attr()->output_scales_.mask_ == (1 << 1)); + } + + do_bias_ = pd->with_bias(); + if (do_bias_) { + bias_data_type_ = pd->desc()->bias_desc.data_type; + assert(bias_data_type_ != data_type::undef); + } +} pp_ker_t *pp_ker_t::create( const convolution_pd_t *pd, const conv_gemm_conf_t &jcp) { @@ -148,30 +282,6 @@ pp_ker_t *pp_ker_t::create( return nullptr; } -bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_wrapper *dst_d) { -#if DNNL_X64 - return x64::gemm_x8s8s32x_convolution_utils::post_ops_ok(post_ops, dst_d); -#endif - return std::all_of(post_ops.entry_.cbegin(), post_ops.entry_.cend(), - [](const dnnl_post_ops::entry_t &post_op) { - return post_op.is_eltwise() || post_op.is_sum() - || post_op.is_binary(); - }); -} - -bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_t *dst_d) { - const auto dst_md = memory_desc_wrapper(dst_d); - return post_ops_ok(post_ops, &dst_md); -} - -bool mayiuse_jit_pp_kernel() noexcept { -#if DNNL_X64 - return x64::gemm_x8s8s32x_convolution_utils::mayiuse_jit_pp_kernel(); -#else - return false; -#endif -} - } // namespace gemm_x8s8s32x_convolution_utils } // namespace cpu } // namespace impl diff --git a/src/cpu/gemm_x8s8s32x_convolution_utils.hpp b/src/cpu/gemm_x8s8s32x_convolution_utils.hpp index 12c1698f324..c707d5b699f 100644 --- a/src/cpu/gemm_x8s8s32x_convolution_utils.hpp +++ b/src/cpu/gemm_x8s8s32x_convolution_utils.hpp @@ -34,24 +34,31 @@ struct pp_ker_t { typedef typename prec_traits::type acc_data_t; - virtual void operator()(void *dst, const acc_data_t *acc, const char *bias, + virtual void operator()(void *dst, acc_data_t *acc, const char *bias, const float *scales, float sum_scale, float signed_scale, int g, size_t start, size_t end, const zero_point_call_params_t &zp, const void *post_ops_binary_rhs_arg_vec, const void *dst_orig, const exec_ctx_t &ctx, const memory_desc_t &dst_md, const single_gemm_conv_chunk_desc_t &chunk_desc) const = 0; + size_t dst_os_stride_; + virtual status_t create_kernel() { return status::success; } protected: pp_ker_t(const convolution_pd_t *pd, const conv_gemm_conf_t &jcp); const conv_gemm_conf_t &jcp_; -}; + const post_ops_t &post_ops_; + size_t OC_; -bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_wrapper *dst_d); -bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_t *dst_d); -bool mayiuse_jit_pp_kernel() noexcept; + bool do_bias_ = false; + bool do_scale_ = false; + size_t scale_idx_mult_ = 0; + + data_type_t bias_data_type_ = data_type::undef; + data_type_t dst_data_type_ = data_type::undef; +}; } // namespace gemm_x8s8s32x_convolution_utils } // namespace cpu diff --git 
a/src/cpu/gemm_x8s8s32x_inner_product.cpp b/src/cpu/gemm_x8s8s32x_inner_product.cpp index eba10ded0a3..94f9f9c8e7e 100644 --- a/src/cpu/gemm_x8s8s32x_inner_product.cpp +++ b/src/cpu/gemm_x8s8s32x_inner_product.cpp @@ -42,7 +42,8 @@ status_t gemm_x8s8s32x_inner_product_fwd_t::execute_forward( = binary_injector_utils::prepare_binary_args( this->pd()->attr()->post_ops_, ctx); - const dim_t MB = pd()->MB(); + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + const dim_t OC = pd()->OC(); const dim_t IC = pd()->IC(); diff --git a/src/cpu/nchw_pooling.cpp b/src/cpu/nchw_pooling.cpp index 792f97c5892..35575ad2f5a 100644 --- a/src/cpu/nchw_pooling.cpp +++ b/src/cpu/nchw_pooling.cpp @@ -40,14 +40,20 @@ template status_t nchw_pooling_fwd_t::execute_forward( const exec_ctx_t &ctx) const { const auto alg = pd()->desc()->alg_kind; - const auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC); + const auto src_ = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC); auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST); auto ws = CTX_OUT_MEM(unsigned char *, DNNL_ARG_WORKSPACE); + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + const memory_desc_wrapper ws_d(pd()->workspace_md()); + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef; - const dim_t MB = pd()->MB(); + auto src = src_ + src_d.off_l(0); + dst += dst_d.off_l(0); + const dim_t C = pd()->OC(); const dim_t OD = pd()->OD(); const dim_t OH = pd()->OH(); @@ -64,10 +70,11 @@ status_t nchw_pooling_fwd_t::execute_forward( const dim_t padF = pd()->padFront(); const dim_t padT = pd()->padT(); const dim_t padL = pd()->padL(); + const dim_t padB = pd()->padB(); + const dim_t padR = pd()->padR(); + const dim_t padBack = pd()->padBack(); - const auto apply_offset = [](int index, int offset) { - return (index > offset) ? 
index - offset : 0; - }; + const bool do_post_ops = pd()->attr()->post_ops_.len() > 0; const auto set_ws = [=](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow, dim_t value) { @@ -88,6 +95,12 @@ status_t nchw_pooling_fwd_t::execute_forward( const auto ker_max = [=](data_t *d, dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) { + bool is_initialized = false; + + const auto src_offset = (size_t)IW * IH * ID * C * mb + (size_t)IW * IH * ID * c; + const auto local_src = &src[src_offset]; + const auto IWH = (size_t)IW * IH; + for_(dim_t kd = 0; kd < KD; ++kd) for_(dim_t kh = 0; kh < KH; ++kh) for (dim_t kw = 0; kw < KW; ++kw) { @@ -99,80 +112,128 @@ status_t nchw_pooling_fwd_t::execute_forward( if (ih < 0 || ih >= IH) continue; if (iw < 0 || iw >= IW) continue; - const auto src_offset = (size_t)IW * IH * ID * C * mb - + (size_t)IW * IH * ID * c + (size_t)IW * IH * id - + (size_t)IW * ih + (size_t)iw; - const auto &s = src[src_offset]; - if (s > d[0]) { + const auto local_src_offset = IWH * id + (size_t)IW * ih + (size_t)iw; + const auto s = local_src[local_src_offset]; + if (!is_initialized) { d[0] = s; set_ws(mb, c, od, oh, ow, kd * KH * KW + kh * KW + kw); + is_initialized = true; + } else { + if (s > d[0]) { + d[0] = s; + set_ws(mb, c, od, oh, ow, kd * KH * KW + kh * KW + kw); + } } } }; const auto ker_avg = [=](data_t *d, dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) { - const auto id_start = apply_offset(od * SD, padF); - const auto ih_start = apply_offset(oh * SH, padT); - const auto iw_start = apply_offset(ow * SW, padL); - const auto id_end = min(od * SD - padF + KD, ID); - const auto ih_end = min(oh * SH - padT + KH, IH); - const auto iw_end = min(ow * SW - padL + KW, IW); - - const auto num_summands = (alg == alg_kind::pooling_avg_include_padding) + auto id_start = od*SD - padF; + auto ih_start = oh*SH - padT; + auto iw_start = ow*SW - padL; + auto id_end = nstl::min(od*SD - padF + KD, ID + padBack); + auto ih_end = nstl::min(oh*SH - padT + KH, IH + padB); + auto iw_end = nstl::min(ow*SW - padL + KW, IW + padR); + + auto num_summands = (alg == alg_kind::pooling_avg_include_padding) ? 
KD * KW * KH : (id_end - id_start) * (ih_end - ih_start) * (iw_end - iw_start); float d_val = 0; + id_start = nstl::max(id_start, dim_t(0)); + ih_start = nstl::max(ih_start, dim_t(0)); + iw_start = nstl::max(iw_start, dim_t(0)); + + id_end = nstl::min(id_end, ID); + ih_end = nstl::min(ih_end, IH); + iw_end = nstl::min(iw_end, IW); + + if (alg == alg_kind::pooling_avg_exclude_padding) + num_summands = (id_end - id_start)*(ih_end - ih_start)*(iw_end - iw_start); + if (num_summands == 0) return d_val; + + const auto src_offset = (size_t)IW * IH * ID * C * mb + (size_t)IW * IH * ID * c + (size_t)iw_start; + const auto IWH = (size_t)IW * IH; + const dim_t iw_range = iw_end - iw_start; + for_(dim_t id = id_start; id < id_end; ++id) - for_(dim_t ih = ih_start; ih < ih_end; ++ih) - for (dim_t iw = iw_start; iw < iw_end; ++iw) { - const auto src_offset = (size_t)IW * IH * ID * C * mb - + (size_t)IW * IH * ID * c + (size_t)IW * IH * id - + (size_t)IW * ih + (size_t)iw; - d_val += src[src_offset]; + for_(dim_t ih = ih_start; ih < ih_end; ++ih) { + auto local_src_offset = src_offset + IWH * id + (size_t) IW * ih; + const auto tmp_src = &src[local_src_offset]; + for (dim_t iw = 0; iw < iw_range; ++iw) { + d_val += tmp_src[iw]; + } } return d_val / num_summands; }; if (alg == alg_kind::pooling_max) { - parallel_nd(MB, C, OD, OH, OW, - [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) { - const size_t dst_offset = (size_t)OW * OH * OD * C * mb - + (size_t)OW * OH * OD * c + (size_t)OW * OH * od - + (size_t)OW * oh + (size_t)ow; - data_t *d = &dst[dst_offset]; - d[0] = numeric_limits::lowest(); - set_ws(mb, c, od, oh, ow, 0); - ker_max(d, mb, c, od, oh, ow); - - ref_post_ops_t::args_t args; - args.ctx = &ctx; - args.l_offset = dst_offset; - args.dst_md = pd()->dst_md(); - ref_post_ops_.execute(dst[dst_offset], args); - dst[dst_offset] - = saturate_and_round(dst[dst_offset]); - }); + if (do_post_ops) { + parallel_nd(MB, C, OD, OH, OW, + [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) { + const size_t dst_offset = (size_t)OW * OH * OD * C * mb + + (size_t)OW * OH * OD * c + (size_t)OW * OH * od + + (size_t)OW * oh + (size_t)ow; + data_t *d = &dst[dst_offset]; + d[0] = numeric_limits::lowest(); + set_ws(mb, c, od, oh, ow, 0); + ker_max(d, mb, c, od, oh, ow); + + ref_post_ops_t::args_t args; + args.ctx = &ctx; + args.l_offset = dst_offset; + args.dst_md = pd()->dst_md(); + ref_post_ops_.execute(dst[dst_offset], args); + dst[dst_offset] + = saturate_and_round(dst[dst_offset]); + }); + } else { + parallel_nd(MB, C, OD, OH, OW, + [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) { + const size_t dst_offset = (size_t)OW * OH * OD * C * mb + + (size_t)OW * OH * OD * c + (size_t)OW * OH * od + + (size_t)OW * oh + (size_t)ow; + data_t *d = &dst[dst_offset]; + d[0] = numeric_limits::lowest(); + set_ws(mb, c, od, oh, ow, 0); + ker_max(d, mb, c, od, oh, ow); + + dst[dst_offset] + = saturate_and_round(dst[dst_offset]); + }); + } } else { - parallel_nd(MB, C, OD, OH, OW, - [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) { - const size_t dst_offset = (size_t)OW * OH * OD * C * mb - + (size_t)OW * OH * OD * c + (size_t)OW * OH * od - + (size_t)OW * oh + (size_t)ow; - data_t *d = &dst[dst_offset]; - d[0] = 0; - auto res = ker_avg(d, mb, c, od, oh, ow); - - ref_post_ops_t::args_t args; - args.ctx = &ctx; - args.l_offset = dst_offset; - args.dst_md = pd()->dst_md(); - ref_post_ops_.execute(res, args); - d[0] = saturate_and_round(res); - }); + if (do_post_ops) { + parallel_nd(MB, C, OD, OH, OW, + [&](dim_t mb, 
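// A worked instance of the summand logic above, with assumed shapes KD = 1,
// KH = KW = 3, padding 1, stride 1, evaluated at a corner output point:
// pooling_avg_include_padding divides by KD * KH * KW = 9, while
// pooling_avg_exclude_padding first clamps the window to the image and divides
// by the 2 * 2 = 4 in-bounds elements; the re-clamping of id/ih/iw_start..end
// above is also what keeps every source read in bounds in both modes.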
dim_t c, dim_t od, dim_t oh, dim_t ow) { + const size_t dst_offset = (size_t)OW * OH * OD * C * mb + + (size_t)OW * OH * OD * c + (size_t)OW * OH * od + + (size_t)OW * oh + (size_t)ow; + data_t *d = &dst[dst_offset]; + d[0] = 0; + auto res = ker_avg(d, mb, c, od, oh, ow); + + ref_post_ops_t::args_t args; + args.ctx = &ctx; + args.l_offset = dst_offset; + args.dst_md = pd()->dst_md(); + ref_post_ops_.execute(res, args); + d[0] = saturate_and_round(res); + }); + } else { + parallel_nd(MB, C, OD, OH, OW, + [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) { + const size_t dst_offset = (size_t)OW * OH * OD * C * mb + + (size_t)OW * OH * OD * c + (size_t)OW * OH * od + + (size_t)OW * oh + (size_t)ow; + data_t *d = &dst[dst_offset]; + d[0] = 0; + d[0] = saturate_and_round(ker_avg(d, mb, c, od, oh, ow)); + }); + } } return status::success; @@ -189,6 +250,8 @@ status_t nchw_pooling_fwd_t::execute_forward( auto ws = CTX_OUT_MEM(unsigned char *, DNNL_ARG_WORKSPACE); memory_desc_wrapper dst_d(pd()->dst_md()); + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + auto scratchpad = ctx.get_scratchpad_grantor(); float *bf16cvt_wsp = scratchpad.template get( memory_tracking::names::key_pool_src_bf16cvt); @@ -196,7 +259,6 @@ status_t nchw_pooling_fwd_t::execute_forward( const memory_desc_wrapper ws_d(pd()->workspace_md()); const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef; - const dim_t MB = pd()->MB(); const dim_t C = pd()->OC(); const dim_t OD = pd()->OD(); const dim_t OH = pd()->OH(); @@ -213,15 +275,16 @@ status_t nchw_pooling_fwd_t::execute_forward( const dim_t padF = pd()->padFront(); const dim_t padT = pd()->padT(); const dim_t padL = pd()->padL(); + const dim_t padB = pd()->padB(); + const dim_t padR = pd()->padR(); + const dim_t padBack = pd()->padBack(); const size_t simd_w = 16; const size_t src_size = MB * C * ID * IH * IW; const size_t blocked_size = src_size / simd_w; const size_t tail_size = src_size % simd_w; - auto apply_offset = [=](int index, int offset) { - return (index > offset) ? 
index - offset : 0; - }; + const bool do_post_ops = pd()->attr()->post_ops_.len() > 0; auto set_ws = [=](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow, dim_t value) { @@ -242,6 +305,12 @@ status_t nchw_pooling_fwd_t::execute_forward( auto ker_max = [=](float *d, dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) { + bool is_initialized = false; + + const auto src_offset = (size_t)IW * IH * ID * C * mb + (size_t)IW * IH * ID * c; + const auto local_src = &bf16cvt_wsp[src_offset]; + const auto IWH = (size_t)IW * IH; + for_(dim_t kd = 0; kd < KD; ++kd) for_(dim_t kh = 0; kh < KH; ++kh) for (dim_t kw = 0; kw < KW; ++kw) { @@ -253,39 +322,57 @@ status_t nchw_pooling_fwd_t::execute_forward( if (ih < 0 || ih >= IH) continue; if (iw < 0 || iw >= IW) continue; - auto src_offset = (size_t)IW * IH * ID * C * mb - + (size_t)IW * IH * ID * c + (size_t)IW * IH * id - + (size_t)IW * ih + (size_t)iw; - auto &s = bf16cvt_wsp[src_offset]; + const auto local_src_offset = IWH * id + (size_t)IW * ih + (size_t)iw; + const auto s = local_src[local_src_offset]; - if (s > d[0]) { + if (!is_initialized) { d[0] = s; - set_ws(mb, c, od, oh, ow, kd * KH * KW + kh * KW + kw); + set_ws(mb, c, od, oh, ow, kd*KH*KW + kh*KW + kw); + is_initialized = true; + } else { + if (s > d[0]) { + d[0] = s; + set_ws(mb, c, od, oh, ow, kd * KH * KW + kh * KW + kw); + } } } }; auto ker_avg = [=](float *d, dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) { - auto id_start = apply_offset(od * SD, padF); - auto ih_start = apply_offset(oh * SH, padT); - auto iw_start = apply_offset(ow * SW, padL); - auto id_end = min(od * SD - padF + KD, ID); - auto ih_end = min(oh * SH - padT + KH, IH); - auto iw_end = min(ow * SW - padL + KW, IW); + auto id_start = od*SD - padF; + auto ih_start = oh*SH - padT; + auto iw_start = ow*SW - padL; + auto id_end = nstl::min(od*SD - padF + KD, ID + padBack); + auto ih_end = nstl::min(oh*SH - padT + KH, IH + padB); + auto iw_end = nstl::min(ow*SW - padL + KW, IW + padR); - auto num_summands = (alg == alg_kind::pooling_avg_include_padding) - ? 
KD * KW * KH - : (id_end - id_start) * (ih_end - ih_start) - * (iw_end - iw_start); + // case alg == pooling_avg_include_padding + auto num_summands = (id_end - id_start)*(ih_end - ih_start)*(iw_end - iw_start); + + id_start = nstl::max(id_start, dim_t(0)); + ih_start = nstl::max(ih_start, dim_t(0)); + iw_start = nstl::max(iw_start, dim_t(0)); + + id_end = nstl::min(id_end, ID); + ih_end = nstl::min(ih_end, IH); + iw_end = nstl::min(iw_end, IW); + + if (alg == alg_kind::pooling_avg_exclude_padding) + num_summands = (id_end - id_start)*(ih_end - ih_start)*(iw_end - iw_start); + if (num_summands == 0) return; + + const auto src_offset = (size_t)IW * IH * ID * C * mb + (size_t)IW * IH * ID * c + (size_t)iw_start; + const auto IWH = (size_t)IW * IH; + const dim_t iw_range = iw_end - iw_start; for_(dim_t id = id_start; id < id_end; ++id) - for_(dim_t ih = ih_start; ih < ih_end; ++ih) - for (dim_t iw = iw_start; iw < iw_end; ++iw) { - auto src_offset = (size_t)IW * IH * ID * C * mb - + (size_t)IW * IH * ID * c + (size_t)IW * IH * id - + (size_t)IW * ih + (size_t)iw; - d[0] += bf16cvt_wsp[src_offset]; + for_(dim_t ih = ih_start; ih < ih_end; ++ih) { + auto local_src_offset = src_offset + IWH * id + (size_t)IW * ih; + const auto tmp_src = &bf16cvt_wsp[local_src_offset]; + for (dim_t iw = 0; iw < iw_range; ++iw) { + d[0] += tmp_src[iw]; + } } d[0] = out_round((float)d[0] / num_summands); @@ -298,40 +385,68 @@ status_t nchw_pooling_fwd_t::execute_forward( cvt_bfloat16_to_float(&bf16cvt_wsp[blocked_size * simd_w], &src[blocked_size * simd_w], tail_size); if (alg == alg_kind::pooling_max) { - parallel_nd(MB, C, OD, OH, OW, - [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) { - size_t dst_offset = (size_t)OW * OH * OD * C * mb - + (size_t)OW * OH * OD * c + (size_t)OW * OH * od - + (size_t)OW * oh + (size_t)ow; - float d_fp32 = numeric_limits::lowest(); - - set_ws(mb, c, od, oh, ow, 0); - - ker_max(&d_fp32, mb, c, od, oh, ow); - - ref_post_ops_t::args_t args; - args.ctx = &ctx; - args.l_offset = dst_offset; - args.dst_md = pd()->dst_md(); - ref_post_ops_.execute(d_fp32, args); - - dst[dst_offset] = static_cast(d_fp32); - }); + if (do_post_ops) { + parallel_nd(MB, C, OD, OH, OW, + [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) { + size_t dst_offset = (size_t)OW * OH * OD * C * mb + + (size_t)OW * OH * OD * c + (size_t)OW * OH * od + + (size_t)OW * oh + (size_t)ow; + float d_fp32 = numeric_limits::lowest(); + + set_ws(mb, c, od, oh, ow, 0); + + ker_max(&d_fp32, mb, c, od, oh, ow); + + ref_post_ops_t::args_t args; + args.ctx = &ctx; + args.l_offset = dst_offset; + args.dst_md = pd()->dst_md(); + ref_post_ops_.execute(d_fp32, args); + + dst[dst_offset] = static_cast(d_fp32); + }); + } else { + parallel_nd(MB, C, OD, OH, OW, + [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) { + size_t dst_offset = (size_t)OW * OH * OD * C * mb + + (size_t)OW * OH * OD * c + (size_t)OW * OH * od + + (size_t)OW * oh + (size_t)ow; + float d_fp32 = numeric_limits::lowest(); + + set_ws(mb, c, od, oh, ow, 0); + + ker_max(&d_fp32, mb, c, od, oh, ow); + + dst[dst_offset] = static_cast(d_fp32); + }); + } } else { - parallel_nd(MB, C, OD, OH, OW, - [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) { - size_t dst_offset = (size_t)OW * OH * OD * C * mb - + (size_t)OW * OH * OD * c + (size_t)OW * OH * od - + (size_t)OW * oh + (size_t)ow; - float d_fp32 = 0.0f; - ker_avg(&d_fp32, mb, c, od, oh, ow); - ref_post_ops_t::args_t args; - args.ctx = &ctx; - args.l_offset = dst_offset; - args.dst_md = pd()->dst_md(); - 
ref_post_ops_.execute(d_fp32, args); - dst[dst_offset] = static_cast(d_fp32); - }); + if (do_post_ops) { + parallel_nd(MB, C, OD, OH, OW, + [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) { + size_t dst_offset = (size_t)OW * OH * OD * C * mb + + (size_t)OW * OH * OD * c + (size_t)OW * OH * od + + (size_t)OW * oh + (size_t)ow; + float d_fp32 = 0.0f; + ker_avg(&d_fp32, mb, c, od, oh, ow); + ref_post_ops_t::args_t args; + args.ctx = &ctx; + args.l_offset = dst_offset; + args.dst_md = pd()->dst_md(); + ref_post_ops_.execute(d_fp32, args); + dst[dst_offset] = static_cast(d_fp32); + }); + } else { + parallel_nd(MB, C, OD, OH, OW, + [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) { + size_t dst_offset = (size_t)OW * OH * OD * C * mb + + (size_t)OW * OH * OD * c + (size_t)OW * OH * od + + (size_t)OW * oh + (size_t)ow; + float d_fp32 = 0.0f; + ker_avg(&d_fp32, mb, c, od, oh, ow); + dst[dst_offset] = static_cast(d_fp32); + }); + } } return status::success; diff --git a/src/cpu/nchw_pooling.hpp b/src/cpu/nchw_pooling.hpp index 7f699aa05bb..b970c0365a8 100644 --- a/src/cpu/nchw_pooling.hpp +++ b/src/cpu/nchw_pooling.hpp @@ -47,8 +47,8 @@ struct nchw_pooling_fwd_t : public primitive_t { const bool ok = is_fwd() && utils::one_of(desc()->alg_kind, alg_kind::pooling_max, - alg_kind::pooling_avg_include_padding, alg_kind::pooling_avg_exclude_padding) + && memory_desc_wrapper(dst_md()).is_dense(false) && utils::everyone_is( d_type, src_md()->data_type, dst_md()->data_type) && platform::has_data_type_support(d_type) diff --git a/src/cpu/nhwc_pooling.cpp b/src/cpu/nhwc_pooling.cpp index 8a59f03e64b..7c2869ee5a2 100644 --- a/src/cpu/nhwc_pooling.cpp +++ b/src/cpu/nhwc_pooling.cpp @@ -163,11 +163,12 @@ status_t nhwc_pooling_fwd_t::execute_forward( auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST); auto ws = CTX_OUT_MEM(unsigned char *, DNNL_ARG_WORKSPACE); + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + const memory_desc_wrapper MEM_D(src)(pd()->src_md()); const memory_desc_wrapper MEM_D(dst)(pd()->dst_md()); const memory_desc_wrapper MEM_D(ws)(pd()->workspace_md()); - const dim_t MB = pd()->MB(); const dim_t OC = pd()->OC(); const dim_t OD = pd()->OD(); const dim_t OH = pd()->OH(); @@ -324,6 +325,8 @@ status_t nhwc_pooling_fwd_t::execute_forward( auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST); auto ws = CTX_OUT_MEM(unsigned char *, DNNL_ARG_WORKSPACE); + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + auto scratchpad = ctx.get_scratchpad_grantor(); float *const bf16cvt_src_wsp = scratchpad.template get( memory_tracking::names::key_pool_src_bf16cvt); @@ -334,7 +337,6 @@ status_t nhwc_pooling_fwd_t::execute_forward( const memory_desc_wrapper MEM_D(dst)(pd()->dst_md()); const memory_desc_wrapper MEM_D(ws)(pd()->workspace_md()); - const dim_t MB = pd()->MB(); const dim_t OC = pd()->OC(); const dim_t OD = pd()->OD(); const dim_t OH = pd()->OH(); diff --git a/src/cpu/platform.cpp b/src/cpu/platform.cpp index 7d7c054c872..415bdf30c3a 100644 --- a/src/cpu/platform.cpp +++ b/src/cpu/platform.cpp @@ -15,7 +15,7 @@ * limitations under the License. 
*******************************************************************************/ -#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL +#if (defined(DNNL_CPU_THREADING_RUNTIME) && DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL) #include #if defined(_WIN32) diff --git a/src/cpu/primitive_attr_postops.cpp b/src/cpu/primitive_attr_postops.cpp index 7b594ab2b12..c5e704b3bbd 100644 --- a/src/cpu/primitive_attr_postops.cpp +++ b/src/cpu/primitive_attr_postops.cpp @@ -40,6 +40,7 @@ float compute_binary_scalar(alg_kind_t alg, float x, float y) { case binary_lt: return x < y; case binary_eq: return x == y; case binary_ne: return x != y; + case binary_prelu: return x >= 0 ? x : x * y; default: assert(!"not supported operation!"); return NAN; } } @@ -70,6 +71,9 @@ float compute_eltwise_scalar_fwd( case eltwise_logsigmoid: d = logsigmoid_fwd(s); break; case eltwise_mish: d = mish_fwd(s); break; case eltwise_hardswish: d = hardswish_fwd(s); break; + case eltwise_hsigmoid: d = hsigmoid_fwd(s); break; + case eltwise_round_half_away_from_zero: d = round_half_away_from_zero_fwd(s); break; + case eltwise_round_half_to_even: d = round_half_to_even_fwd(s); break; case eltwise_relu_use_dst_for_bwd: d = relu_fwd(s, alpha); break; case eltwise_tanh_use_dst_for_bwd: d = tanh_fwd(s); break; case eltwise_elu_use_dst_for_bwd: d = elu_fwd(s, alpha); break; @@ -136,7 +140,7 @@ ref_binary_scalar_t::ref_binary_scalar_t(alg_kind_t alg) : alg_(alg) { alg_kind::binary_min, alg_kind::binary_mul, alg_kind::binary_div, alg_kind::binary_sub, alg_kind::binary_ge, alg_kind::binary_gt, alg_kind::binary_le, alg_kind::binary_lt, alg_kind::binary_eq, - alg_kind::binary_ne)); + alg_kind::binary_ne, alg_kind::binary_prelu)); } ref_binary_scalar_t::ref_binary_scalar_t( @@ -156,6 +160,7 @@ ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t( eltwise_mish, eltwise_logistic, eltwise_exp, eltwise_gelu_tanh, eltwise_swish, eltwise_log, eltwise_clip, eltwise_clip_v2, eltwise_pow, eltwise_gelu_erf, eltwise_round, eltwise_hardswish, + eltwise_hsigmoid, eltwise_round_half_away_from_zero, eltwise_round_half_to_even, eltwise_relu_use_dst_for_bwd, eltwise_tanh_use_dst_for_bwd, eltwise_elu_use_dst_for_bwd, eltwise_sqrt_use_dst_for_bwd, eltwise_logistic_use_dst_for_bwd, eltwise_exp_use_dst_for_bwd, @@ -179,6 +184,8 @@ ref_post_ops_t::ref_post_ops_t(const post_ops_t &po, bool skip_sum) eltwise_po_.emplace_back(e.eltwise); } else if (po_.contain(primitive_kind::binary, idx)) { binary_po_.emplace_back(e.binary); + } else if (po_.contain(primitive_kind::depthwise, idx)) { + depthwise_po_.emplace_back(e.depthwise.alg); } } } @@ -247,11 +254,12 @@ dim_t get_binary_src1_off(const memory_desc_t &src1_md, const dim_t l_offset, } // namespace -status_t ref_post_ops_t::execute(float &res, const args_t &args) const { +status_t ref_post_ops_t::execute(float &res, const args_t &args, const size_t oc) const { if (po_.len() == 0) return status::success; auto it_eltwise_po = eltwise_po_.begin(); auto it_binary_po = binary_po_.begin(); + auto it_depthwise_po = depthwise_po_.begin(); for (auto idx = 0; idx < po_.len(); ++idx) { const auto &e = po_.entry_[idx]; switch (e.kind) { @@ -299,6 +307,46 @@ status_t ref_post_ops_t::execute(float &res, const args_t &args) const { const auto &weights_value = prelu_weights[off]; res = weights_value * res; } break; + case primitive_kind::depthwise: { + const exec_ctx_t &ctx = *args.ctx; + auto depthwise_base = CTX_IN_MEM(const float *, (DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1)); + auto depthwise_weights = 
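Note: for reference, the scalar definitions behind the newly handled kinds. binary_prelu is taken verbatim from the hunk above; the hsigmoid and rounding formulas below are the commonly used definitions and are an assumption here, since this diff only names the kinds:

#include <algorithm>
#include <cmath>

// from the hunk: x >= 0 ? x : x * y
float binary_prelu(float x, float y) { return x >= 0 ? x : x * y; }
// assumed: the usual hard-sigmoid min(max(x + 3, 0), 6) / 6
float hsigmoid_fwd(float s) { return std::min(std::max(s + 3.f, 0.f), 6.f) / 6.f; }
float round_half_away_from_zero_fwd(float s) { return std::roundf(s); }
// assumes the FP environment is in its default round-to-nearest-even mode
float round_half_to_even_fwd(float s) { return std::nearbyintf(s); }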
depthwise_base + e.depthwise.offset[e.depthwise.scales]; + auto depthwise_bias = depthwise_base + e.depthwise.offset[e.depthwise.shifts]; + + res = it_depthwise_po->compute_scalar(res, depthwise_weights + oc, depthwise_bias + oc); + + ++it_depthwise_po; + } break; + case primitive_kind::quantization: { + bool do_dequantization = e.quantization.alg == alg_kind::quantization_quantize_dequantize; + bool do_rounding = do_dequantization || args.dst_md->data_type == dnnl_f32 || idx != po_.len() - 1; + + auto quant = e.quantization; + const exec_ctx_t &ctx = *args.ctx; + auto quantization_base = CTX_IN_MEM(const float *, (DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1)); + const auto pcl = quantization_base + quant.offset[quant.crop_low]; + const auto pch = quantization_base + quant.offset[quant.crop_high]; + const auto pisc = quantization_base + quant.offset[quant.inp_scale]; + const auto pish = quantization_base + quant.offset[quant.inp_shift]; + const auto posc = quantization_base + quant.offset[quant.output_scale]; + const auto posh = quantization_base + quant.offset[quant.output_shift]; + + int cl_idx = !quant.per_channel[quant.crop_low] ? 0 : oc; + int ch_idx = !quant.per_channel[quant.crop_high] ? 0 : oc; + int isc_idx = !quant.per_channel[quant.inp_scale] ? 0 : oc; + int ish_idx = !quant.per_channel[quant.inp_shift] ? 0 : oc; + int osc_idx = !quant.per_channel[quant.output_scale] ? 0 : oc; + int osh_idx = !quant.per_channel[quant.output_shift] ? 0 : oc; + + res = nstl::min(pch[ch_idx], nstl::max(pcl[cl_idx], res)); + res = res * pisc[isc_idx] + pish[ish_idx]; + + if (do_rounding) + res = roundf(res); + + if (do_dequantization) + res = res * posc[osc_idx] + posh[osh_idx]; + } break; default: assert(!"unsupported post op primitive kind!"); } } diff --git a/src/cpu/primitive_attr_postops.hpp b/src/cpu/primitive_attr_postops.hpp index 68b30e82a62..7103b8f1f3a 100644 --- a/src/cpu/primitive_attr_postops.hpp +++ b/src/cpu/primitive_attr_postops.hpp @@ -22,6 +22,8 @@ #include "common/primitive.hpp" #include "common/primitive_attr.hpp" +#include "ref_depthwise_injector.hpp" + namespace dnnl { namespace impl { namespace cpu { @@ -69,7 +71,7 @@ struct ref_post_ops_t { virtual ~ref_post_ops_t() = default; - status_t execute(float &res, const args_t &args = args_t()) const; + status_t execute(float &res, const args_t &args = args_t(), const size_t oc = 0) const; private: const post_ops_t &po_; @@ -79,6 +81,7 @@ struct ref_post_ops_t { std::vector eltwise_po_; std::vector binary_po_; + std::vector depthwise_po_; }; } // namespace cpu diff --git a/src/cpu/ref_batch_normalization.cpp b/src/cpu/ref_batch_normalization.cpp index 0f8f22d8e79..3a8810dc4d8 100644 --- a/src/cpu/ref_batch_normalization.cpp +++ b/src/cpu/ref_batch_normalization.cpp @@ -94,8 +94,10 @@ status_t ref_batch_normalization_fwd_t::execute_forward( auto ws = CTX_OUT_CLEAN_MEM(uint8_t *, DNNL_ARG_WORKSPACE, status); CHECK(status); + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + const auto ndims = data_d.ndims(); - const auto N = pd()->MB(); + const auto N = MB; const auto C = pd()->C(); const auto D = pd()->D(); const auto H = pd()->H(); diff --git a/src/cpu/ref_convolution.cpp b/src/cpu/ref_convolution.cpp index b97c7b76942..be73f04bc45 100644 --- a/src/cpu/ref_convolution.cpp +++ b/src/cpu/ref_convolution.cpp @@ -38,6 +38,8 @@ status_t ref_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { auto dst = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DST, status); CHECK(status); + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + const 
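Note: collapsed to one scalar function, the quantization post-op above does: clamp to the crop range, apply the input affine, round, and, when dequantizing, apply the output affine. Names simplified from the hunk:

#include <algorithm>
#include <cmath>

float quantize_post_op(float x, float crop_low, float crop_high,
        float in_scale, float in_shift, float out_scale, float out_shift,
        bool do_rounding, bool do_dequantization) {
    x = std::min(crop_high, std::max(crop_low, x));
    x = x * in_scale + in_shift;
    if (do_rounding) x = std::roundf(x);
    if (do_dequantization) x = x * out_scale + out_shift;
    return x;
}

Each of the six parameters is read either broadcast (index 0) or per channel (index oc), as selected by the per_channel flags in the hunk.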
memory_desc_wrapper src_d(pd()->src_md()); const memory_desc_wrapper dst_d(pd()->dst_md()); const memory_desc_wrapper weights_d(pd()->weights_md(0)); @@ -46,7 +48,6 @@ status_t ref_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { const bool with_groups = pd()->with_groups(); const auto G = pd()->G(); - const auto MB = pd()->MB(); const auto OD = pd()->OD(); const auto OH = pd()->OH(); const auto OW = pd()->OW(); @@ -227,6 +228,8 @@ status_t ref_convolution_bwd_data_t::execute_backward_data( auto diff_src = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DIFF_SRC, status); CHECK(status); + auto MB = CTX_IN_BATCH(DNNL_ARG_DIFF_DST); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); const memory_desc_wrapper weights_d(pd()->weights_md(0)); @@ -234,7 +237,6 @@ status_t ref_convolution_bwd_data_t::execute_backward_data( const bool with_groups = pd()->with_groups(); const auto G = pd()->G(); - const auto MB = pd()->MB(); const auto OD = pd()->OD(); const auto OH = pd()->OH(); const auto OW = pd()->OW(); @@ -387,6 +389,8 @@ status_t ref_convolution_bwd_data_t::execute_backward_data( return ds; }; + const auto &p = pd()->attr()->post_ops_; + parallel_nd(G, MB, IC, ID, IH, IW, [&](dim_t g, dim_t mb, dim_t ic, dim_t id, dim_t ih, dim_t iw) { float ds = 0; @@ -396,6 +400,21 @@ status_t ref_convolution_bwd_data_t::execute_backward_data( else ds += ker(g, mb, ic, id, ih, iw); + size_t post_ops_data_idx = 0; + int depthwise_inj_idx = 0; + for (int i = 0; i < p.len(); i++) { + auto &post_op = p.entry_[i]; + if (post_op.is_depthwise()) { + auto depthwise_base = CTX_IN_MEM(const float *, (DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1)); + auto depthwise_weights = depthwise_base + post_op.depthwise.offset[post_op.depthwise.scales]; + auto depthwise_bias = depthwise_base + post_op.depthwise.offset[post_op.depthwise.shifts]; + + ds = depthwise_injectors[depthwise_inj_idx]->compute_scalar(ds, depthwise_weights + g * IC + ic, depthwise_bias + g * IC + ic); + post_ops_data_idx++; + depthwise_inj_idx++; + } + } + const auto diff_src_off = ref_conv_utils::get_data_off( diff_src_d, ndims, mb, g * IC + ic, id, ih, iw); io::store_float_value( diff --git a/src/cpu/ref_convolution.hpp b/src/cpu/ref_convolution.hpp index fa4f7fed8bd..a10bb92d546 100644 --- a/src/cpu/ref_convolution.hpp +++ b/src/cpu/ref_convolution.hpp @@ -27,6 +27,8 @@ #include "cpu/cpu_convolution_pd.hpp" #include "cpu/primitive_attr_postops.hpp" +#include "ref_depthwise_injector.hpp" + namespace dnnl { namespace impl { namespace cpu { @@ -125,7 +127,9 @@ struct ref_convolution_bwd_data_t : public primitive_t { && utils::one_of(diff_dst_type, f32, bf16) && diff_dst_type == wei_type && IMPLICATION(diff_dst_type == f32, diff_src_type == f32) - && set_default_formats() && attr()->has_default_values(); + && set_default_formats() + && attr()->has_default_values(primitive_attr_t::skip_mask_t::post_ops) + && is_supported_post_ops(); return ok ? 
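Note: CTX_IN_BATCH, which replaces pd()->MB() throughout these files, reads the minibatch from the tensor actually bound at execute() time, so a primitive built for a maximum batch can apparently run smaller batches without re-creation. The real macro belongs to this fork; a stand-in that captures only the intent:

#include <cassert>
#include <cstdint>

// dims[0] is the minibatch for every layout these hunks touch.
int64_t runtime_batch(const int64_t *arg_dims, int64_t pd_mb) {
    const int64_t mb = arg_dims[0];
    assert(mb > 0 && mb <= pd_mb); // shrink only; never above the built size
    return mb;
}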
status::success : status::unimplemented; } @@ -139,9 +143,41 @@ struct ref_convolution_bwd_data_t : public primitive_t { : utils::pick(ndims() - 3, oiw, oihw, oidhw); return set_default_formats_common(dat_tag, wei_tag, dat_tag); } + + bool is_supported_post_ops() const { + const auto &p = this->attr()->post_ops_; + if (p.len() > 1) + return false; + + auto all_post_ops_supported = [&]() { + bool ok = true; + + for (int i = 0; i < p.len(); i++) { + ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::depthwise); + } + return ok; + }; + + return all_post_ops_supported(); + } }; - ref_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {} + ref_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) { + const auto &post_ops = pd()->attr()->post_ops_; + + for (int i = 0; i < post_ops.len(); i++) { + auto &post_op = post_ops.entry_[i]; + if (post_op.is_depthwise()) { + depthwise_injectors.push_back(new ref_depthwise_scalar_fwd_t(post_op.depthwise.alg)); + } + } + } + + ~ref_convolution_bwd_data_t() { + for (auto inj : depthwise_injectors) + delete inj; + depthwise_injectors.clear(); + } status_t execute(const exec_ctx_t &ctx) const override { return execute_backward_data(ctx); @@ -150,6 +186,8 @@ struct ref_convolution_bwd_data_t : public primitive_t { private: status_t execute_backward_data(const exec_ctx_t &ctx) const; const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } + + nstl::vector depthwise_injectors; }; struct ref_convolution_bwd_weights_t : public primitive_t { diff --git a/src/cpu/ref_deconvolution.cpp b/src/cpu/ref_deconvolution.cpp index 8a6bf1f1204..0296b3d9626 100644 --- a/src/cpu/ref_deconvolution.cpp +++ b/src/cpu/ref_deconvolution.cpp @@ -37,8 +37,9 @@ void ref_deconvolution_fwd_t::compute_fwd_bias_common(const exec_ctx_t &ctx, const memory_desc_wrapper dst_d(pd()->dst_md()); const memory_desc_wrapper bias_d(pd()->weights_md(1)); + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + const auto G = pd()->G(); - const auto MB = pd()->MB(); const auto OH = pd()->OH(); const auto OW = pd()->OW(); const auto OD = pd()->OD(); @@ -64,7 +65,8 @@ void ref_deconvolution_fwd_t::compute_fwd_bias_ncdhw(const exec_ctx_t &ctx, const memory_desc_wrapper dst_d(pd()->dst_md()); const memory_desc_wrapper bias_d(pd()->weights_md(1)); - const auto MB = pd()->MB(); + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + const auto OC = pd()->OC(); const auto SP = pd()->OW() * pd()->OH() * pd()->OD(); @@ -87,7 +89,8 @@ void ref_deconvolution_fwd_t::compute_fwd_bias_ndhwc(const exec_ctx_t &ctx, const memory_desc_wrapper dst_d(pd()->dst_md()); const memory_desc_wrapper bias_d(pd()->weights_md(1)); - const auto MB = pd()->MB(); + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + const auto OC = pd()->OC(); const auto SP = pd()->OW() * pd()->OH() * pd()->OD(); @@ -111,7 +114,8 @@ void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc(const exec_ctx_t &ctx, const memory_desc_wrapper dst_d(pd()->dst_md()); const memory_desc_wrapper bias_d(pd()->weights_md(1)); - const auto MB = pd()->MB(); + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + const auto OC = pd()->OC(); const auto SP = pd()->OW() * pd()->OH() * pd()->OD(); const auto stride_mb = dst_d.blocking_desc().strides[0]; @@ -178,7 +182,7 @@ status_t ref_deconvolution_fwd_t::compute_ref_attrs(const exec_ctx_t &ctx, const memory_desc_wrapper dst_d(pd()->dst_md()); - const auto MB = pd()->MB(); + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); const auto OH = pd()->OH(); const auto OW = pd()->OW(); const auto OD = pd()->OD(); diff --git 
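Note: the constructor/destructor pair above manages the injectors with raw new/delete. For comparison, an RAII sketch with the same behavior (a simplified stand-in type, not the fork's code):

#include <memory>
#include <vector>

struct depthwise_injector { // stand-in for ref_depthwise_scalar_fwd_t
    explicit depthwise_injector(int alg) : alg_(alg) {}
    int alg_;
};

struct holder_t {
    std::vector<std::unique_ptr<depthwise_injector>> injectors_;
    void add(int alg) {
        injectors_.push_back(std::make_unique<depthwise_injector>(alg));
    }
    // no user-written destructor: unique_ptr releases each injector
};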
a/src/cpu/ref_depthwise_injector.cpp b/src/cpu/ref_depthwise_injector.cpp
new file mode 100644
index 00000000000..86c42f17c96
--- /dev/null
+++ b/src/cpu/ref_depthwise_injector.cpp
@@ -0,0 +1,79 @@
+/*******************************************************************************
+* Copyright 2020 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "ref_depthwise_injector.hpp"
+
+namespace dnnl {
+namespace impl {
+namespace cpu {
+
+using namespace alg_kind;
+using namespace math;
+
+template <typename T> inline T scale_shift_fwd(T s_val, T w_val, T b_val) {
+    return s_val * w_val + b_val;
+}
+
+template <typename T> inline T prelu_fwd(T s_val, T w_val) {
+    return s_val >= 0 ? s_val : s_val * w_val;
+}
+
+union float_raw {
+    float f;
+    unsigned short i[2];
+};
+
+static float bf16tof32(bfloat16_t bf16) {
+    union float_raw t = { 0 };
+    t.i[1] = bf16;
+    t.i[0] = 0;
+    return t.f;
+}
+
+static bfloat16_t f32tobf16(float f32) {
+    union float_raw t = { 0 };
+    t.f = f32;
+    return t.i[1];
+}
+
+inline bfloat16_t bf16_scale_shift_fwd(bfloat16_t s_val, bfloat16_t w_val, bfloat16_t b_val) {
+    return f32tobf16(bf16tof32(s_val) * bf16tof32(w_val) + bf16tof32(b_val));
+}
+
+inline bfloat16_t bf16_prelu_fwd(bfloat16_t s_val, bfloat16_t w_val) {
+    return s_val >= 0 ? s_val : f32tobf16(bf16tof32(s_val) * bf16tof32(w_val));
+}
+
+ref_depthwise_scalar_fwd_t::ref_depthwise_scalar_fwd_t(const alg_kind_t alg_)
+    : alg(alg_) {
+    using namespace alg_kind;
+
+    assert(utils::one_of(alg, depthwise_scale_shift, depthwise_prelu));
+}
+
+float ref_depthwise_scalar_fwd_t::compute_scalar(float s, const float* weights, const float* bias) const {
+    switch (alg) {
+        case depthwise_scale_shift: return scale_shift_fwd(s, *weights, *bias);
+        case depthwise_prelu: return prelu_fwd(s, *weights);
+        default: assert(!"unknown depthwise alg_kind");
+    }
+
+    return 0.0f;
+}
+
+} // namespace cpu
+} // namespace impl
+} // namespace dnnl
diff --git a/src/cpu/ref_depthwise_injector.hpp b/src/cpu/ref_depthwise_injector.hpp
new file mode 100644
index 00000000000..1a56e28cdc2
--- /dev/null
+++ b/src/cpu/ref_depthwise_injector.hpp
@@ -0,0 +1,40 @@
+/*******************************************************************************
+* Copyright 2020 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/ + +#ifndef REF_DEPTHWISE_INJECTOR_HPP +#define REF_DEPTHWISE_INJECTOR_HPP + +#include "common/primitive.hpp" +#include "common/primitive_attr.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { + +struct ref_depthwise_scalar_fwd_t { +public: + explicit ref_depthwise_scalar_fwd_t(alg_kind_t alg); + float compute_scalar(float s, const float* weights, const float* bias) const; + +private: + alg_kind_t alg; +}; + +} // namespace cpu +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/cpu/ref_inner_product.cpp b/src/cpu/ref_inner_product.cpp index cf6bc4e5cdb..dc7963b9bb0 100644 --- a/src/cpu/ref_inner_product.cpp +++ b/src/cpu/ref_inner_product.cpp @@ -35,13 +35,14 @@ status_t ref_inner_product_fwd_t::execute_forward(const exec_ctx_t &ctx) const { auto dst = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DST, status); CHECK(status); + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + const memory_desc_wrapper src_d(pd()->src_md()); const memory_desc_wrapper dst_d(pd()->dst_md()); const memory_desc_wrapper weights_d(pd()->weights_md(0)); const memory_desc_wrapper bias_d(pd()->weights_md(1)); const auto ndims = pd()->ndims(); - const auto MB = pd()->MB(); const auto OC = pd()->OC(); const auto IC = pd()->IC(); diff --git a/src/cpu/ref_lrn.cpp b/src/cpu/ref_lrn.cpp index 0e2faaed541..e2acc31b8c9 100644 --- a/src/cpu/ref_lrn.cpp +++ b/src/cpu/ref_lrn.cpp @@ -67,6 +67,8 @@ status_t ref_lrn_fwd_t::execute_forward(const exec_ctx_t &ctx) const { auto dst = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DST, status); CHECK(status); + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + const memory_desc_wrapper data_d(pd()->src_md()); const dim_t C = pd()->C(); @@ -143,7 +145,6 @@ status_t ref_lrn_fwd_t::execute_forward(const exec_ctx_t &ctx) const { d[0] = static_cast(s * fast_negative_powf(sum, beta)); }; - const dim_t MB = pd()->MB(); if (tag == nChw16c || tag == nChw8c) { parallel_nd(MB, utils::div_up(C, blksize), H, W, [&](dim_t mb, dim_t c_blk, dim_t h, dim_t w) { diff --git a/src/cpu/ref_pooling.cpp b/src/cpu/ref_pooling.cpp index fb84b56dbbe..5ed32355cbf 100644 --- a/src/cpu/ref_pooling.cpp +++ b/src/cpu/ref_pooling.cpp @@ -43,16 +43,15 @@ static inline dim_t get_offset(const memory_desc_wrapper &mdw, dim_t n, dim_t c, using namespace nstl; -template -status_t ref_pooling_fwd_t::execute_forward( +template +status_t ref_pooling_fwd_t::execute_forward( const exec_ctx_t &ctx) const { - status_t status = status::success; - auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC); - auto dst = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DST, status); - CHECK(status); - auto ws = CTX_OUT_CLEAN_MEM(unsigned char *, DNNL_ARG_WORKSPACE, status); - CHECK(status); + auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC); + auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST); + auto ws = CTX_OUT_MEM(unsigned char *, DNNL_ARG_WORKSPACE); + + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); const memory_desc_wrapper src_d(pd()->src_md()); const memory_desc_wrapper dst_d(pd()->dst_md()); @@ -62,7 +61,6 @@ status_t ref_pooling_fwd_t::execute_forward( if (ws) assert(ws_dt == data_type::u8 || ws_dt == data_type::s32); const auto alg = pd()->desc()->alg_kind; - const dim_t MB = pd()->MB(); const dim_t OC = pd()->OC(); const dim_t OD = pd()->OD(); const dim_t OH = pd()->OH(); @@ -82,6 +80,9 @@ status_t ref_pooling_fwd_t::execute_forward( const dim_t DD = pd()->KDD(); const dim_t DH = pd()->KDH(); const dim_t DW = pd()->KDW(); + const dim_t padB = pd()->padB(); + const dim_t 
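Note: the float_raw union in ref_depthwise_injector.cpp above picks i[1] as the high half, which is only correct on little-endian hosts. A portable, strict-aliasing-safe equivalent with the same truncating semantics:

#include <cstdint>
#include <cstring>

float bf16_to_f32(uint16_t bf16) {
    const uint32_t bits = uint32_t(bf16) << 16; // bf16 is the top half of an f32
    float f;
    std::memcpy(&f, &bits, sizeof f);
    return f;
}

uint16_t f32_to_bf16_trunc(float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof bits);
    return uint16_t(bits >> 16); // drop the low 16 mantissa bits, as the union does
}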
padR = pd()->padR(); + const dim_t padBack = pd()->padBack(); auto set_ws = [=](dim_t mb, dim_t oc, dim_t od, dim_t oh, dim_t ow, dim_t value) { @@ -99,6 +100,7 @@ status_t ref_pooling_fwd_t::execute_forward( auto ker_max = [=](float &d, dim_t mb, dim_t oc, dim_t od, dim_t oh, dim_t ow) { + bool is_initialized = false; set_ws(mb, oc, od, oh, ow, 0); for (dim_t kd = 0; kd < KD; ++kd) { const dim_t id = od * SD - padF + kd * (DD + 1); @@ -112,9 +114,15 @@ status_t ref_pooling_fwd_t::execute_forward( const auto off = get_offset(src_d, mb, oc, id, ih, iw); auto s = src[off]; - if (s > d) { + if (!is_initialized) { d = s; - set_ws(mb, oc, od, oh, ow, (kd * KH + kh) * KW + kw); + set_ws(mb, oc, od, oh, ow, kd * KH * KW + kh*KW + kw); + is_initialized = true; + } else { + if (s > d) { + d = s; + set_ws(mb, oc, od, oh, ow, kd * KH * KW + kh * KW + kw); + } } } } @@ -138,17 +146,25 @@ status_t ref_pooling_fwd_t::execute_forward( } } } - int num_summands; - if (alg == alg_kind::pooling_avg_include_padding) - num_summands = KW * KH * KD; - else { - auto id_start = od * SD - padF; - auto ih_start = oh * SH - padT; - auto iw_start = ow * SW - padL; - auto id_end = od * SD - padF + (KD - 1) * DD + KD; - auto ih_end = oh * SH - padT + (KH - 1) * DH + KH; - auto iw_end = ow * SW - padL + (KW - 1) * DW + KW; + auto id_start = od*SD - padF; + auto ih_start = oh*SH - padT; + auto iw_start = ow*SW - padL; + auto id_end = nstl::min(od*SD - padF + KD, ID + padBack); + auto ih_end = nstl::min(oh*SH - padT + KH, IH + padB); + auto iw_end = nstl::min(ow*SW - padL + KW, IW + padR); + + // case alg == pooling_avg_include_padding + auto num_summands = (ih_end - ih_start)*(iw_end - iw_start)*(id_end - id_start); + + id_start = nstl::max(id_start, dim_t(0)); + ih_start = nstl::max(ih_start, dim_t(0)); + iw_start = nstl::max(iw_start, dim_t(0)); + id_end = nstl::min(id_end, ID); + ih_end = nstl::min(ih_end, IH); + iw_end = nstl::min(iw_end, IW); + + if (alg == alg_kind::pooling_avg_exclude_padding) { auto id_start_excluded = id_start < 0 ? (0 - id_start - 1) / (DD + 1) + 1 : 0; auto ih_start_excluded @@ -163,16 +179,43 @@ status_t ref_pooling_fwd_t::execute_forward( = iw_end > IW ? (iw_end - IW - 1) / (DW + 1) + 1 : 0; num_summands = (KD - id_start_excluded - id_end_excluded) - * (KH - ih_start_excluded - ih_end_excluded) - * (KW - iw_start_excluded - iw_end_excluded); + * (KH - ih_start_excluded - ih_end_excluded) + * (KW - iw_start_excluded - iw_end_excluded); } + if (num_summands == 0) return; + d /= num_summands; + + const auto &p = pd()->attr()->post_ops_; + for (int i = 0; i < p.len(); i++) { + auto &post_op = p.entry_[i]; + if (post_op.is_quantization()) { + auto quant = post_op.quantization; + auto quantization_base = CTX_IN_MEM(const float *, (DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1)); + const auto crop_low_data = quantization_base + quant.offset[quant.crop_low]; + const auto crop_high_data = quantization_base + quant.offset[quant.crop_high]; + const auto inp_scale_data = quantization_base + quant.offset[quant.inp_scale]; + const auto inp_shift_data = quantization_base + quant.offset[quant.inp_shift]; + const auto output_scale_data = quantization_base + quant.offset[quant.output_scale]; + const auto output_shift_data = quantization_base + quant.offset[quant.output_shift]; + + float cl = crop_low_data[!quant.per_channel[quant.crop_low] ? 0 : oc]; + float ch = crop_high_data[!quant.per_channel[quant.crop_high] ? 0 : oc]; + float isc = inp_scale_data[!quant.per_channel[quant.inp_scale] ? 
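Note: the is_initialized flag added to ker_max seeds the running maximum from the first in-bounds element instead of a numeric_limits lowest() sentinel, which also guarantees the stored workspace index points at a real element. Flattened to one dimension:

#include <cstddef>

bool max_with_ws(const float *src, size_t n, float &d, size_t &ws_idx) {
    bool is_initialized = false;
    for (size_t i = 0; i < n; ++i) {
        if (!is_initialized || src[i] > d) {
            d = src[i];
            ws_idx = i;
            is_initialized = true;
        }
    }
    return is_initialized; // false: the window held no in-bounds elements
}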
0 : oc]; + float ish = inp_shift_data[!quant.per_channel[quant.inp_shift] ? 0 : oc]; + float osc = output_scale_data[!quant.per_channel[quant.output_scale] ? 0 : oc]; + float osh = output_shift_data[!quant.per_channel[quant.output_shift] ? 0 : oc]; + + d = nstl::min(ch, nstl::max(cl, d)); + d = d * isc + ish; + d = roundf(d); + d = d * osc + osh; + } + } }; const bool is_max_pool = alg == alg_kind::pooling_max; - float base_res - = is_max_pool ? (float)numeric_limits::lowest() : 0.f; using ker_t = std::function; ker_t kernel = is_max_pool ? (ker_t)ker_max : (ker_t)ker_avg; @@ -182,7 +225,7 @@ status_t ref_pooling_fwd_t::execute_forward( auto data_p_off = get_offset(dst_d, mb, oc, od, oh, ow); auto data_l_off = (((mb * OC + oc) * OD + od) * OH + oh) * OW + ow; - float res = base_res; + float res = 0.f; kernel(res, mb, oc, od, oh, ow); ref_post_ops_t::args_t args; @@ -191,7 +234,7 @@ status_t ref_pooling_fwd_t::execute_forward( args.dst_md = pd()->dst_md(); ref_post_ops->execute(res, args); - dst[data_p_off] = cpu::saturate_and_round(res); + dst[data_p_off] = cpu::saturate_and_round(res); }); return status::success; @@ -345,11 +388,13 @@ status_t ref_pooling_bwd_t::execute_backward( return status::success; } -template struct ref_pooling_fwd_t; -template struct ref_pooling_fwd_t; -template struct ref_pooling_fwd_t; -template struct ref_pooling_fwd_t; -template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; template struct ref_pooling_bwd_t; template struct ref_pooling_bwd_t; diff --git a/src/cpu/ref_pooling.hpp b/src/cpu/ref_pooling.hpp index 497ddd2fc98..8e4e7f75327 100644 --- a/src/cpu/ref_pooling.hpp +++ b/src/cpu/ref_pooling.hpp @@ -33,7 +33,7 @@ namespace dnnl { namespace impl { namespace cpu { -template +template struct ref_pooling_fwd_t : public primitive_t { struct pd_t : public cpu_pooling_fwd_pd_t { using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t; @@ -43,13 +43,14 @@ struct ref_pooling_fwd_t : public primitive_t { status_t init(engine_t *engine) { using sm = primitive_attr_t::skip_mask_t; - bool ok = platform::has_data_type_support(data_type) + bool ok = platform::has_data_type_support(src_type) && platform::has_data_type_support(dst_type) && set_default_params() == status::success && is_fwd() - && utils::everyone_is( - data_type, src_md()->data_type, dst_md()->data_type) + && utils::everyone_is(src_type, src_md()->data_type) + && utils::everyone_is(dst_type, dst_md()->data_type) && desc()->accum_data_type == acc_type && attr()->has_default_values(sm::post_ops) - && attr_.set_default_formats(dst_md(0)) == status::success; + && attr_.set_default_formats(dst_md(0)) == status::success + && is_supported_post_ops(); if (!ok) return status::unimplemented; bool is_training = desc_.prop_kind == prop_kind::forward_training; @@ -58,6 +59,24 @@ struct ref_pooling_fwd_t : public primitive_t { return status::success; } + + virtual bool is_supported_post_ops() const { + const auto &p = this->attr()->post_ops_; + + auto all_post_ops_supported = [&]() { + bool ok = true; + + for (int i = 0; i < p.len(); i++) { + ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::quantization); + } + return ok; + }; + + return all_post_ops_supported() && + IMPLICATION(p.len() > 0, (desc()->alg_kind == dnnl_pooling_avg_include_padding || desc()->alg_kind == 
dnnl_pooling_avg_exclude_padding) && + src_type != data_type::bf16); + + } }; ref_pooling_fwd_t(const pd_t *apd) : primitive_t(apd) {} @@ -69,7 +88,8 @@ struct ref_pooling_fwd_t : public primitive_t { return status::success; } - using data_t = typename prec_traits::type; + using src_data_t = typename prec_traits::type; + using dst_data_t = typename prec_traits::type; using acc_data_t = typename prec_traits::type; status_t execute(const exec_ctx_t &ctx) const override { diff --git a/src/cpu/ref_softmax.cpp b/src/cpu/ref_softmax.cpp index e09fdd68b42..488a7e69976 100644 --- a/src/cpu/ref_softmax.cpp +++ b/src/cpu/ref_softmax.cpp @@ -47,8 +47,10 @@ status_t ref_softmax_fwd_t::execute_forward_dense( const auto zero_padding = has_padding && !is_inplace; const auto axis = pd()->axis(); const auto axis_blk_size = data_d.padded_dims()[axis] - data_d.dims()[axis]; + auto real_src_md = ctx.input(DNNL_ARG_SRC)->md(); + auto outer_size = utils::array_product(real_src_md->dims, axis); - parallel_nd(outer_size_, [&](dim_t ou) { + parallel_nd(outer_size, [&](dim_t ou) { const data_t *src_data = src + ou * ou_stride; data_t *dst_data = dst + ou * ou_stride; float space_max = -FLT_MAX; diff --git a/src/cpu/reorder/cpu_reorder.cpp b/src/cpu/reorder/cpu_reorder.cpp index 474876eab57..c1869fd18e3 100644 --- a/src/cpu/reorder/cpu_reorder.cpp +++ b/src/cpu/reorder/cpu_reorder.cpp @@ -22,22 +22,24 @@ namespace impl { namespace cpu { /* regular reorders */ -std::map regular_impl_list_map { +const std::map regular_impl_list_map { {{f32, bf16, 0}, ®ular_f32_bf16_impl_list_map}, {{f32, f16, 0}, ®ular_f32_f16_impl_list_map}, {{f32, f32, 0}, ®ular_f32_f32_impl_list_map}, {{f32, s32, 0}, ®ular_f32_s32_impl_list_map}, {{f32, s8, 0}, ®ular_f32_s8_impl_list_map}, {{f32, u8, 0}, ®ular_f32_u8_impl_list_map}, + {{f32, bin, 0}, ®ular_f32_bin_impl_list_map}, {{bf16, data_type::undef, 0}, ®ular_bf16_impl_list_map}, {{f16, data_type::undef, 0}, ®ular_f16_impl_list_map}, {{s32, data_type::undef, 0}, ®ular_s32_impl_list_map}, {{s8, data_type::undef, 0}, ®ular_s8_impl_list_map}, {{u8, data_type::undef, 0}, ®ular_u8_impl_list_map}, + {{bin, data_type::undef, 0}, ®ular_bin_impl_list_map}, }; /* conv reorders w/ compensation */ -std::map comp_s8s8_impl_list_map { +const std::map comp_s8s8_impl_list_map { {{f32, s8, 0}, &comp_f32_s8_impl_list_map}, {{bf16, s8, 0}, &comp_bf16_s8_impl_list_map}, {{s8, s8, 0}, &comp_s8_s8_impl_list_map}, @@ -49,16 +51,19 @@ const impl_list_item_t *cpu_engine_impl_list_t::get_reorder_implementation_list( const bool do_comp_s8s8 = dst_md->extra.flags & (memory_extra_flags::compensation_conv_s8s8 | memory_extra_flags::compensation_conv_asymmetric_src); - auto &map = do_comp_s8s8 ? comp_s8s8_impl_list_map : regular_impl_list_map; - const impl_list_map_t *p_impl_list = (const impl_list_map_t *)map[dt_pair]; + const auto &map = do_comp_s8s8 ? 
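Note: with separate src_type/dst_type parameters, ref_pooling can read one data type and store another while accumulating in f32, pairing, for example, an integral source with an f32 destination. A toy analogue of that scheme, with saturation left out:

template <typename src_t, typename dst_t>
dst_t pool_avg_mixed(const src_t *src, int n) {
    float acc = 0.f; // accumulate in f32 regardless of the endpoint types
    for (int i = 0; i < n; ++i)
        acc += static_cast<float>(src[i]);
    acc /= static_cast<float>(n);
    return static_cast<dst_t>(acc); // the real code saturates and rounds here
}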
comp_s8s8_impl_list_map : regular_impl_list_map; static const impl_list_item_t empty_list[] = {nullptr}; - if (!p_impl_list) { + + auto iter = map.find(dt_pair); + if (iter == map.end()) { dt_pair.dst_dt = data_type::undef; - p_impl_list = (const impl_list_map_t *)map[dt_pair]; - if (!p_impl_list) return empty_list; + iter = map.find(dt_pair); + if (iter == map.end()) return empty_list; } + const impl_list_map_t *p_impl_list = (const impl_list_map_t *)iter->second; + reorder_impl_key_t key {dt_pair.src_dt, dt_pair.dst_dt, src_md->ndims}; { diff --git a/src/cpu/reorder/cpu_reorder.hpp b/src/cpu/reorder/cpu_reorder.hpp index c8fe57b67a5..b3bd2fc96c7 100644 --- a/src/cpu/reorder/cpu_reorder.hpp +++ b/src/cpu/reorder/cpu_reorder.hpp @@ -72,11 +72,13 @@ extern const impl_list_map_t regular_f32_f32_impl_list_map; extern const impl_list_map_t regular_f32_s32_impl_list_map; extern const impl_list_map_t regular_f32_s8_impl_list_map; extern const impl_list_map_t regular_f32_u8_impl_list_map; +extern const impl_list_map_t regular_f32_bin_impl_list_map; extern const impl_list_map_t regular_bf16_impl_list_map; extern const impl_list_map_t regular_f16_impl_list_map; extern const impl_list_map_t regular_s32_impl_list_map; extern const impl_list_map_t regular_s8_impl_list_map; extern const impl_list_map_t regular_u8_impl_list_map; +extern const impl_list_map_t regular_bin_impl_list_map; /* conv reorders w/ compensation */ extern const impl_list_map_t comp_f32_s8_impl_list_map; @@ -93,6 +95,10 @@ extern const impl_list_map_t comp_s8_s8_impl_list_map; REG_SR(idt, any, odt, any, fmt_order::any, spec::direct_copy) \ REG_SR(idt, any, odt, any, fmt_order::any, spec::direct_copy_except_dim_0) +#define REG_SR_BIDIR(idt, ifmt, odt, ofmt) \ + REG_SR(idt, ifmt, odt, ofmt, fmt_order::keep) \ + REG_SR(idt, ifmt, odt, ofmt, fmt_order::reverse) + // clang-format on #if defined(__INTEL_COMPILER) || (defined(__GNUC__) && !defined(__clang__)) diff --git a/src/cpu/reorder/cpu_reorder_regular_bin.cpp b/src/cpu/reorder/cpu_reorder_regular_bin.cpp new file mode 100644 index 00000000000..e494e377ef2 --- /dev/null +++ b/src/cpu/reorder/cpu_reorder_regular_bin.cpp @@ -0,0 +1,44 @@ +/******************************************************************************* +* Copyright 2021 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
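Note: making both impl-list maps const is what forces this rewrite: std::map::operator[] is non-const and inserts a default-constructed value on a miss, whereas find() leaves the table untouched. The lookup-with-fallback now has this shape:

#include <map>

const void *find_impl(const std::map<int, const void *> &table, int key,
        int fallback_key) {
    auto it = table.find(key);
    if (it == table.end()) it = table.find(fallback_key); // dst_dt -> undef retry
    return it == table.end() ? nullptr : it->second;
}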
+*******************************************************************************/ + +#include "cpu/reorder/cpu_reorder.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { + +// clang-format off + +const impl_list_map_t regular_bin_impl_list_map REG_REORDER_P({ + // bin -> + {{bin, data_type::undef, 4}, { + REG_SR_DIRECT_COPY(bin, bin) + + REG_SR(bin, any, bin, OIhw8o32i, fmt_order::keep) + + REG_SR(bin, any, bin, OIhw16o32i, fmt_order::keep) + + REG_SR_BIDIR(u8, any, u8, nChw8c) + + nullptr, + }}, +}); + +// clang-format on + +} // namespace cpu +} // namespace impl +} // namespace dnnl diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_bin.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_bin.cpp new file mode 100644 index 00000000000..79fadd2c94c --- /dev/null +++ b/src/cpu/reorder/cpu_reorder_regular_f32_bin.cpp @@ -0,0 +1,40 @@ +/******************************************************************************* +* Copyright 2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "cpu/reorder/cpu_reorder.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { + +// clang-format off + +const impl_list_map_t regular_f32_bin_impl_list_map REG_REORDER_P({ + // f32 -> bin + {{f32, bin, 4}, { + REG_SR_BIDIR(f32, nchw, bin, nhwc) + + REG_SR_BIDIR(f32, nhwc, bin, nhwc) + + nullptr, + }}, +}); + +// clang-format on + +} // namespace cpu +} // namespace impl +} // namespace dnnl diff --git a/src/cpu/reorder/cpu_reorder_regular_u8.cpp b/src/cpu/reorder/cpu_reorder_regular_u8.cpp index e4421f124f6..e68b87ac5d7 100644 --- a/src/cpu/reorder/cpu_reorder_regular_u8.cpp +++ b/src/cpu/reorder/cpu_reorder_regular_u8.cpp @@ -20,6 +20,13 @@ namespace dnnl { namespace impl { namespace cpu { +#ifdef __INTEL_COMPILER +/* Enable direct copy primitives for non-icc compilers, but place it after the jitted ones */ +#define REG_FAST_DIRECT_COPY_AFTER_JIT(sdt, ddt) +#else +#define REG_FAST_DIRECT_COPY_AFTER_JIT(sdt, ddt) REG_SR_DIRECT_COPY(sdt, ddt) +#endif + // clang-format off const impl_list_map_t regular_u8_impl_list_map REG_REORDER_P({ @@ -36,6 +43,13 @@ const impl_list_map_t regular_u8_impl_list_map REG_REORDER_P({ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) + // Allow direct-copy primitives for non-intel compilers, but with a lower priority than the jitted impl + REG_FAST_DIRECT_COPY_AFTER_JIT(u8, f32) + REG_FAST_DIRECT_COPY_AFTER_JIT(u8, s32) + REG_FAST_DIRECT_COPY_AFTER_JIT(u8, bf16) + REG_FAST_DIRECT_COPY_AFTER_JIT(u8, s8) + REG_FAST_DIRECT_COPY_AFTER_JIT(u8, u8) + REG_SR(u8, any, f32, any, fmt_order::any, spec::reference) REG_SR(u8, any, s32, any, fmt_order::any, spec::reference) REG_SR(u8, any, bf16, any, fmt_order::any, spec::reference) diff --git a/src/cpu/reorder/simple_reorder.hpp b/src/cpu/reorder/simple_reorder.hpp index 88749511cf2..e525a73033b 100644 --- a/src/cpu/reorder/simple_reorder.hpp +++ b/src/cpu/reorder/simple_reorder.hpp @@ -464,8 +464,6 @@ struct 
simple_reorder_implattr()->output_scales_.mask_ + 1)); const float *scales = pd->attr()->output_scales_.scales_; @@ -540,20 +538,731 @@ struct simple_reorder_impl +struct simple_reorder_impl::type> { + static bool is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { + using namespace data_type; + + if (input_d.has_runtime_dims_or_strides()) return false; + + return order_keep && input_d.matches_tag(tag_i) + && output_d.matches_tag(tag_o) && input_d.data_type() == f32 + && output_d.data_type() == bf16 && attr->has_default_values(); + } + + static size_t get_scratchpad_size(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d) { + const dim_t blksize = 16; + return sizeof(float) * blksize * blksize * dnnl_get_max_threads(); + } + + static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) { + DECLARE_COMMON_PARAMS(); + using namespace format_tag; + + static constexpr bool w_groups = tag_i == goihw; + const dim_t blksize = 16; + const int sblk = 2; + + const auto &plain_d = input_d; + const auto &dims = input_d.dims(); + const auto &pdims = output_d.padded_dims(); + + const dim_t G = w_groups ? dims[0] : 1; + const dim_t OC = dims[w_groups + 0]; + const dim_t NB_OC = pdims[w_groups + 0] / blksize; + const dim_t IC = dims[w_groups + 1]; + const dim_t NB_IC = pdims[w_groups + 1] / blksize; + const dim_t H = dims[w_groups + 2]; + const dim_t W = dims[w_groups + 3]; + + const size_t wsp_size = blksize * blksize; + float *wspace = scratchpad.template get( + memory_tracking::names::key_reorder_space); + + auto index = [&](dim_t ic, dim_t oc) -> dim_t { + if (utils::one_of(tag_o, gOIhw16i16o, OIhw16i16o)) + return (ic * blksize + oc); + else if (utils::one_of(tag_o, gOIhw8i16o2i, OIhw8i16o2i)) + return ((ic / sblk) * blksize * sblk + sblk * oc + ic % sblk); + else if (utils::one_of(tag_o, gOIhw8o16i2o, gIOhw8o16i2o, + OIhw8o16i2o, IOhw8o16i2o)) + return ((oc / sblk) * blksize * sblk + sblk * ic + oc % sblk); + else + assert(!"Invalid weight format"); + return dim_t(0); + }; + + auto ker = [&](const data_t *inp, data_t *out, + const dim_t curr_oc_block, const dim_t oc_block, + const dim_t curr_ic_block, const dim_t ic_block) { + dim_t ic = 0; + for (ic = 0; ic < curr_ic_block; ++ic) { + dim_t oc = 0; + for (oc = 0; oc < curr_oc_block; ++oc) { + const auto plain_off + = oc * plain_d.blocking_desc().strides[w_groups + 0] + + ic + * plain_d.blocking_desc() + .strides[w_groups + 1]; + out[index(ic, oc)] = inp[plain_off]; + } + for (/* continue */; oc < oc_block; ++oc) { + out[index(ic, oc)] = (data_t)0; + } + } + for (/* continue */; ic < ic_block; ++ic) { + for (dim_t oc = 0; oc < oc_block; ++oc) { + out[index(ic, oc)] = (data_t)0; + } + } + }; + + constexpr int i_mult = blksize; + constexpr int o_mult = 1; + + parallel_nd_ext(0, G, NB_OC, NB_IC, H, W, + [&](int ithr, int, dim_t g, dim_t O, dim_t I, dim_t h, + dim_t w) { + float *_wspace = wspace + wsp_size * ithr; + auto i = &input[input_d.blk_off( + g, i_mult * O, i_mult * I, h, w)]; + auto o = &output[output_d.blk_off( + g, o_mult * O, o_mult * I, h, w)]; + const dim_t oc_block = nstl::min(blksize, OC - O * blksize); + const dim_t ic_block = nstl::min(blksize, IC - I * blksize); + ker(i, _wspace, oc_block, blksize, ic_block, blksize); + cvt_float_to_bfloat16(o, _wspace, wsp_size); + }); + + return status::success; + } +}; + +template +struct simple_reorder_impl::type> { + static bool is_applicable(const memory_desc_wrapper &input_d, + const 
memory_desc_wrapper &output_d, const primitive_attr_t *attr) { + using namespace data_type; + + if (input_d.has_runtime_dims_or_strides()) return false; + + return input_d.mb_stride_relaxed_match(tag_i) + && output_d.mb_stride_relaxed_match(tag_o) + && input_d.data_type() == f32 && output_d.data_type() == bf16 + && attr->has_default_values(); + } + + static size_t get_scratchpad_size(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d) { + constexpr int ndims = tag_traits::ndims; + const size_t blksize = 16; + const size_t W = input_d.dims()[ndims - 1]; + return sizeof(float) * blksize * W * dnnl_get_max_threads(); + } - if (zero_padding_needed) { - PRAGMA_OMP_SIMD() - for (int off = g_block; off < blksize; off++) - out[off] = 0; + static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) { + DECLARE_COMMON_PARAMS(); + + const dim_t blksize = 16; + const dim_t ndims = tag_traits::ndims; + + const auto &flat_d = input_d; + const auto &dims = input_d.dims(); + const auto &pdims = output_d.padded_dims(); + + const dim_t C = dims[1]; + const dim_t H = ndims == 3 ? 1 : dims[ndims - 2]; + const dim_t W = dims[ndims - 1]; + + const dim_t wsp_size = W * blksize; + float *wspace = scratchpad.template get( + memory_tracking::names::key_reorder_space); + + auto ker = [&](const data_t *i, data_t *o, + const dim_t curr_c_block, const dim_t c_block) { + for (dim_t w = 0; w < W; ++w) { + dim_t c = 0; + for (c = 0; c < curr_c_block; ++c) { + const ptrdiff_t flat_off = 0 + + c * flat_d.blocking_desc().strides[1] + + w * flat_d.blocking_desc().strides[ndims - 1]; + o[w * blksize + c] = i[flat_off]; + } + for (/* continue */; c < c_block; ++c) { + o[w * blksize + c] = (data_t)0; + } + } + }; + + constexpr int i_c_mult = blksize; + constexpr int o_c_mult = 1; + + parallel_nd_ext(0, dims[0], pdims[1] / blksize, H, + [&](int ithr, int, dim_t n, dim_t nb_c, dim_t h) { + float *_wspace = wspace + wsp_size * ithr; + auto i = &input[input_d.blk_off(n, i_c_mult * nb_c, h)]; + auto o = &output[output_d.blk_off(n, o_c_mult * nb_c, h)]; + const dim_t c_block + = nstl::min(blksize, C - nb_c * blksize); + ker(i, _wspace, c_block, blksize); + cvt_float_to_bfloat16(o, _wspace, wsp_size); + }); + + return status::success; + } +}; + +/* reorders with tail support */ + +template +struct simple_reorder_impl::type> { + static bool is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { + return simple_fmt_check(order_keep, tag_i, tag_o, input_d, output_d) + && simple_attr_check(attr, false, true); + } + + GET_SCRATCHPAD_SIZE_ZERO(); + + static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) { + DECLARE_COMMON_PARAMS(); + using namespace format_tag; + + constexpr int is_1d = utils::one_of(tag_i, nCw4c, nCw8c); + constexpr int is_3d = utils::one_of(tag_i, nCdhw4c, nCdhw8c); + + constexpr dim_t blksize_i + = tag_traits::inner_blks == ib::_4b ? 4 : 8; + constexpr dim_t blksize_16 = 16; + + constexpr dim_t ic_mult = order_keep ? blksize_16 / blksize_i : 1; + constexpr dim_t oc_mult = order_keep ? 1 : blksize_16 / blksize_i; + + const auto &dims = input_d.dims(); + const auto &pdims + = order_keep ? output_d.padded_dims() : input_d.padded_dims(); + + const auto &d_i = order_keep ? input_d : output_d; + const auto stride_C_in_blk_i = d_i.blocking_desc().strides[1]; + + const dim_t C = dims[1]; + const dim_t D = is_3d ? dims[2] : 1; + const dim_t H = is_1d ? 
1 : dims[2 + is_3d]; + const dim_t W = dims[3 + is_3d - is_1d]; + + auto ker = [&](const data_t *i, data_t *o, + const int block) { + const int nb = utils::div_up(block, blksize_i); + if (alpha == 1.0 && beta == 0.0) { + for (int b = 0; b < nb; ++b) { + const ptrdiff_t i_off + = b * (order_keep ? stride_C_in_blk_i : blksize_i); + const ptrdiff_t o_off + = b * (order_keep ? blksize_i : stride_C_in_blk_i); + const int block_i + = nstl::min(blksize_i, block - b * blksize_i); + for (int c = 0; c < block_i; ++c) { + o[o_off + c] = _qz_a1b0()(i[i_off + c]); } } + } else { + for (int b = 0; b < nb; ++b) { + const ptrdiff_t i_off + = b * (order_keep ? stride_C_in_blk_i : blksize_i); + const ptrdiff_t o_off + = b * (order_keep ? blksize_i : stride_C_in_blk_i); + const int block_i + = nstl::min(blksize_i, block - b * blksize_i); + for (int c = 0; c < block_i; ++c) { + o[o_off + c] = _qz()( + i[i_off + c], o[o_off + c], alpha, beta); + } + } + } + }; + +#define data_blk_off(md, n, c, d, h, w) \ + (is_1d ? (md).blk_off(n, c, w) \ + : is_3d ? (md).blk_off(n, c, d, h, w) : (md).blk_off(n, c, h, w)) + + parallel_nd(dims[0], pdims[1] / blksize_16, D, H, W, + [&](dim_t n, dim_t nb_c, dim_t d, dim_t h, dim_t w) { + auto i = &input[data_blk_off( + input_d, n, ic_mult * nb_c, d, h, w)]; + auto o = &output[data_blk_off( + output_d, n, oc_mult * nb_c, d, h, w)]; + const int block + = nstl::min(blksize_16, C - nb_c * blksize_16); + ker(i, o, block); + }); + +#undef data_blk_off + + return status::success; + } +}; + +#define PLAIN_TO_BLOCKED_IS_APPLICABLE() \ + static bool is_applicable(const memory_desc_wrapper &input_d, \ + const memory_desc_wrapper &output_d, \ + const primitive_attr_t *attr) { \ + return !input_d.has_runtime_dims_or_strides() \ + && simple_attr_check(attr, false, true) \ + && (order_keep ? output_d.matches_tag(tag_o) \ + && input_d.is_plain() \ + : input_d.matches_tag(tag_o) \ + && output_d.is_plain()); \ + } + +template +struct simple_reorder_impl::type> +{ + static bool is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { + return simple_fmt_check(order_keep, tag_i, tag_o, input_d, output_d) && + simple_attr_check(attr, false, false); + } + + GET_SCRATCHPAD_SIZE_ZERO(); + + static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) { + DECLARE_COMMON_PARAMS(); + + const auto &dims = input_d.dims(); + const int C = dims[1]; + const int H = dims[2]; + const int W = dims[3]; + + int nbits = 8; + const int CB = utils::div_up(C, nbits); + + auto ker = [&](const data_t *i, data_t *o) { + for (int cb = 0; cb < CB; ++cb) { + uint8_t bin_val = 0x00; + for (int c = cb * nbits, shift = 0; c < std::min(C, (cb + 1) * nbits); c++, shift++) { + const ptrdiff_t flat_off = c * input_d.blocking_desc().strides[1]; + + auto bit = uint8_t((i[flat_off] > 0) ? 0x01 : 0x00); + bin_val |= (bit << shift); + } + + o[cb] = bin_val; } + }; + + parallel_nd(dims[0], H, W, + [&](int n, int h, int w) { + auto iidx = input_d.blk_off(n, 0, h, w); + auto oidx = output_d.blk_off(n, 0, h, w); + + auto i = &input[iidx]; + auto o = &output[oidx / nbits]; + ker(i, o); }); return status::success; } }; +template +struct simple_reorder_impl::type> +{ + PLAIN_TO_BLOCKED_IS_APPLICABLE(); + + GET_SCRATCHPAD_SIZE_ZERO(); + + static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) { + DECLARE_COMMON_PARAMS(); + + static constexpr bool w_groups = false; + constexpr int blksize_o = tag_o == format_tag::OIhw8o32i ? 
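Note: the f32 -> bin reorder's ker packs eight sign bits per destination byte, least-significant bit first. As a standalone function over one pixel's channels (assuming unit channel stride; the real ker indexes via blocking_desc strides):

#include <algorithm>
#include <cstdint>

void binarize_channels(const float *src, uint8_t *dst, int C) {
    const int nbits = 8;
    const int CB = (C + nbits - 1) / nbits; // utils::div_up(C, nbits)
    for (int cb = 0; cb < CB; ++cb) {
        uint8_t bin_val = 0x00;
        for (int c = cb * nbits, shift = 0; c < std::min(C, (cb + 1) * nbits);
                ++c, ++shift) {
            const uint8_t bit = src[c] > 0.f ? 0x01 : 0x00;
            bin_val |= uint8_t(bit << shift);
        }
        dst[cb] = bin_val;
    }
}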
8 : 16; + constexpr int blksize_i = 32; + + const auto &dims = input_d.dims(); + const auto &pdims = order_keep + ? output_d.padded_dims() + : input_d.padded_dims(); + + const int G = w_groups ? dims[0] : 1; + const int OC = dims[w_groups + 0]; + const int NB_OC = pdims[w_groups + 0] / blksize_o; + const int IC = dims[w_groups + 1]; + const int NB_IC = pdims[w_groups + 1] / blksize_i; + const int H = dims[w_groups + 2]; + const int W = dims[w_groups + 3]; + + constexpr int i_mult_o = blksize_o; + constexpr int i_mult_i = blksize_i; + constexpr int nbits = 8; + + auto extract_bit = [](uint8_t val, uint8_t bit) -> uint8_t { + return (uint8_t) ((val >> bit) & 0x0001); + }; + + parallel_nd(G, NB_OC, NB_IC, H, W, + [&](int g, int nb_oc, int nb_ic, int h, int w) { + const int oc_block = nstl::min(blksize_o, OC - nb_oc * blksize_o); + const int ic_block = nstl::min(blksize_i, IC - nb_ic * blksize_i); + + for (int oc = 0; oc < oc_block; ++oc) { + for (int icb = 0; icb < utils::div_up(ic_block, nbits); ++icb) { + + uint8_t bin_val = 0x00; + for (int ic = icb*nbits, shift = 0; ic < std::min(IC, (icb + 1)*nbits); ic++, shift++) { + size_t iidx = (i_mult_o * nb_oc + oc) * input_d.blocking_desc().strides[0] + + (i_mult_i * nb_ic + ic) * input_d.blocking_desc().strides[1] + + h * input_d.blocking_desc().strides[2] + + w; + + uint8_t bit = extract_bit(input[iidx / nbits], (uint8_t)(iidx % nbits)); + bin_val |= (bit << shift); + } + + size_t oidx = output_d.blk_off(g, nb_oc, nb_ic, h, w) + oc * blksize_i + icb * nbits; + output[oidx / nbits] = bin_val; + + } + } + }); + + return status::success; + } +}; + +template +struct simple_reorder_impl::block_dims == bd::_A + || tag_traits::block_dims == bd::_B) + && tag_traits::ndims >= 3 + && tag_traits::ndims <= 6>::type> { + PLAIN_TO_BLOCKED_IS_APPLICABLE(); + + GET_SCRATCHPAD_SIZE_ZERO(); + + static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) { + DECLARE_COMMON_PARAMS(); + + const auto &flat_d = order_keep ? input_d : output_d; + const auto &block_d = order_keep ? output_d : input_d; + const dims_t &dims = input_d.dims(); + const dims_t &pdims = block_d.padded_dims(); + + const int ndims = tag_traits::ndims; + const int blk_idx = tag_traits::block_dims == bd::_A ? 0 : 1; + + const dim_t H0 = dims[0]; + const dim_t H1 = dims[1]; + const dim_t M0 = ndims == 6 ? dims[ndims - 4] : 1; + const dim_t M1 = ndims >= 5 ? dims[ndims - 3] : 1; + const dim_t M2 = ndims >= 4 ? dims[ndims - 2] : 1; + const dim_t L = dims[ndims - 1]; + const dim_t l_blk_stride = block_d.blocking_desc().strides[ndims - 1]; + const dim_t l_flat_stride = flat_d.blocking_desc().strides[ndims - 1]; + const dim_t blk_flat_stride = flat_d.blocking_desc().strides[blk_idx]; + using namespace data_type; + using namespace utils; + + constexpr int blksize = false + ? 0 + : one_of(tag_traits::inner_blks, ib::_4a, ib::_4b) + ? 4 + : one_of(tag_traits::inner_blks, ib::_8a, + ib::_8b) + ? 8 + : 16; + + constexpr bool f32bf16 + = one_of(type_i, f32, bf16) && one_of(type_o, f32, bf16); + + auto wrap_qz_a1b0 = [=](data_t &out, data_t inp) { + if (f32bf16) + out = inp; + else + out = _qz_a1b0()(inp); + }; + + auto wrap_qz = [=](data_t &out, data_t inp, float alpha, + float beta) { + if (f32bf16) + out = alpha * inp + (beta ? 
beta * out : 0); + else + out = _qz()(inp, out, alpha, beta); + }; + + auto ker = [&](const data_t *i, data_t *o, int block) { + if (alpha == 1.0 && beta == 0.0) { + for (int l = 0; l < L; ++l) { + for (int blk = 0; blk < block; ++blk) { + const dim_t flat_off + = blk * blk_flat_stride + l * l_flat_stride; + const dim_t blk_offset = l * l_blk_stride + blk; + if (order_keep) { + wrap_qz_a1b0(o[blk_offset], i[flat_off]); + } else { + wrap_qz_a1b0(o[flat_off], i[blk_offset]); + } + } + } + } else { + for (int l = 0; l < L; ++l) { + for (int blk = 0; blk < block; ++blk) { + const dim_t flat_off + = blk * blk_flat_stride + l * l_flat_stride; + const dim_t blk_offset = l * l_blk_stride + blk; + if (order_keep) + wrap_qz(o[blk_offset], i[flat_off], alpha, beta); + else + wrap_qz(o[flat_off], i[blk_offset], alpha, beta); + } + } + } + }; + +#define off(md, h0, h1, m0, m1, m2) \ + (ndims >= 6 ? (md).blk_off(h0, h1, m0, m1, m2) \ + : ndims >= 5 ? (md).blk_off(h0, h1, m1, m2) \ + : ndims >= 4 \ + ? (md).blk_off(h0, h1, m2) \ + : /* ndims >= 3 ? */ (md).blk_off(h0, h1)) + + constexpr int i_mult = order_keep ? blksize : 1; + constexpr int o_mult = order_keep ? 1 : blksize; + + if (blk_idx == 0) { + const dim_t BH0 = pdims[0] / blksize; + parallel_nd(BH0, H1, M0, M1, M2, + [&](dim_t bh0, dim_t h1, dim_t m0, dim_t m1, dim_t m2) { + auto i = &input[off( + input_d, bh0 * i_mult, h1, m0, m1, m2)]; + auto o = &output[off( + output_d, bh0 * o_mult, h1, m0, m1, m2)]; + const int block + = nstl::min(blksize, H0 - bh0 * blksize); + ker(i, o, block); + }); + } else if (blk_idx == 1) { + const dim_t BH1 = pdims[1] / blksize; + parallel_nd(H0, BH1, M0, M1, M2, + [&](dim_t h0, dim_t bh1, dim_t m0, dim_t m1, dim_t m2) { + auto i = &input[off( + input_d, h0, bh1 * i_mult, m0, m1, m2)]; + auto o = &output[off( + output_d, h0, bh1 * o_mult, m0, m1, m2)]; + const int block + = nstl::min(blksize, H1 - bh1 * blksize); + ker(i, o, block); + }); + } else { + assert(!"unimplemented"); + } + +#undef off + + return status::success; + } +}; + +template +struct simple_reorder_impl::block_dims == bd::_AB + || tag_traits::block_dims == bd::_BC) + && IMPLICATION(tag_traits::block_dims == bd::_AB, + tag_traits::ndims >= 3 + && tag_traits::ndims <= 5) + && IMPLICATION(tag_traits::block_dims == bd::_BC, + tag_traits::ndims >= 4 + && tag_traits::ndims <= 6)>::type> { + PLAIN_TO_BLOCKED_IS_APPLICABLE(); + + GET_SCRATCHPAD_SIZE_ZERO(); + + static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) { + DECLARE_COMMON_PARAMS(); + + const auto &flat_d = order_keep ? input_d : output_d; + const auto &dims = input_d.dims(); + const auto &pdims + = order_keep ? output_d.padded_dims() : input_d.padded_dims(); + + constexpr int ndims = tag_traits::ndims; + + static constexpr bool with_g = tag_traits::block_dims == bd::_BC; + const dim_t G = with_g ? dims[0] : 1; + + const dim_t H0 = dims[0 + with_g]; + const dim_t H1 = dims[1 + with_g]; + + const dim_t M0 = ndims >= 5 + with_g ? dims[ndims - 3] : 1; + const dim_t M1 = ndims >= 4 + with_g ? dims[ndims - 2] : 1; + const dim_t M2 = ndims >= 3 + with_g ? dims[ndims - 1] : 1; + + const dim_t h0_flat_stride = flat_d.blocking_desc().strides[with_g + 0]; + const dim_t h1_flat_stride = flat_d.blocking_desc().strides[with_g + 1]; + using namespace data_type; + using namespace utils; + + constexpr dim_t blksize_0 = false + ? 0 + : one_of(tag_traits::inner_blks, ib::_4b4a, ib::_4b4c, + ib::_4c4b) + ? 
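Note: for f32/bf16 type pairs the wrap_qz helpers in these blocked reorders reduce to a plain copy and a saxpy-style blend; spelled out:

inline void wrap_qz_a1b0(float &out, float in) { out = in; }

inline void wrap_qz(float &out, float in, float alpha, float beta) {
    // beta == 0 skips reading out, matching the (beta ? beta * out : 0) form
    out = alpha * in + (beta ? beta * out : 0.f);
}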
4 + : one_of(tag_traits::inner_blks, ib::_8a8b, + ib::_8b8a, ib::_8b8c, ib::_8c8b, ib::_2c8b4c) + ? 8 + : one_of(tag_traits::inner_blks, + ib::_16a16b, ib::_16b16a, ib::_16b16c, + ib::_16c16b, ib::_8a16b2a, + ib::_4b16a4b, ib::_8b16a2b, + ib::_8b16c2b, ib::_4c16b4c, + ib::_8c16b2c) + ? 16 + : -1; + + constexpr dim_t blksize_1 + = one_of(tag_traits::inner_blks, ib::_8a8b, ib::_8b8a, + ib::_8b8c, ib::_8c8b, ib::_2c8b4c) + ? 8 + : one_of(tag_traits::inner_blks, ib::_16a16b, + ib::_16b16a, ib::_16b16c, ib::_16c16b, ib::_8a16b2a, + ib::_4b16a4b, ib::_8b16a2b, ib::_8b16c2b, + ib::_4c16b4c, ib::_8c16b2c) + ? 16 + : one_of(tag_traits::inner_blks, ib::_4b4a, + ib::_4b4c, ib::_4c4b) + ? 4 + : -1; + + const dim_t NB_H0 = pdims[0 + with_g] / blksize_0; + const dim_t NB_H1 = pdims[1 + with_g] / blksize_1; + + constexpr bool f32bf16 + = one_of(type_i, f32, bf16) && one_of(type_o, f32, bf16); + + auto wrap_qz_a1b0 = [=](data_t &out, data_t inp) { + if (f32bf16) + out = inp; + else + out = _qz_a1b0()(inp); + }; + + auto wrap_qz = [=](data_t &out, data_t inp, float alpha, + float beta) { + if (f32bf16) + out = alpha * inp + (beta ? beta * out : 0); + else + out = _qz()(inp, out, alpha, beta); + }; + + auto ker = [&](const data_t *i, data_t *o, + const int block_h0, const int block_h1) { +#define blk_off AB_or_BC_blk_off::inner_blks> + if (alpha == 1.0 && beta == 0.0) { + for (int h0 = 0; h0 < block_h0; ++h0) { + for (int h1 = 0; h1 < block_h1; ++h1) { + const dim_t flat_off + = h0 * h0_flat_stride + h1 * h1_flat_stride; + if (order_keep) + wrap_qz_a1b0(o[blk_off(h0, h1)], i[flat_off]); + else + wrap_qz_a1b0(o[flat_off], i[blk_off(h0, h1)]); + } + } + } else { + for (int h0 = 0; h0 < block_h0; ++h0) { + for (int h1 = 0; h1 < block_h1; ++h1) { + const dim_t flat_off + = h0 * h0_flat_stride + h1 * h1_flat_stride; + if (order_keep) + wrap_qz(o[blk_off(h0, h1)], i[flat_off], alpha, + beta); + else + wrap_qz(o[flat_off], i[blk_off(h0, h1)], alpha, + beta); + } + } + } + +#undef blk_off + }; + + constexpr int i_mult_0 = order_keep ? blksize_0 : 1; + constexpr int o_mult_0 = order_keep ? 1 : blksize_0; + + constexpr int i_mult_1 = order_keep ? blksize_1 : 1; + constexpr int o_mult_1 = order_keep ? 1 : blksize_1; + +#define off(md, g, h0, h1, m0, m1, m2) \ + (ndims >= 5 + with_g ? (md).blk_off(g, h0, h1, m0, m1, m2) \ + : ndims >= 4 + with_g \ + ? (md).blk_off(g, h0, h1, m1, m2) \ + : /* ndims >= 3 + with_g ? 
*/ (md) \ + .blk_off(g, h0, h1, m2)) + + parallel_nd(G, NB_H0, NB_H1, M0, M1, M2, + [&](dim_t g, dim_t nb_h0, dim_t nb_h1, dim_t m0, dim_t m1, + dim_t m2) { + auto i = &input[off(input_d, g, i_mult_0 * nb_h0, + i_mult_1 * nb_h1, m0, m1, m2)]; + auto o = &output[off(output_d, g, o_mult_0 * nb_h0, + o_mult_1 * nb_h1, m0, m1, m2)]; + const int block_h0 + = nstl::min(blksize_0, H0 - nb_h0 * blksize_0); + const int block_h1 + = nstl::min(blksize_1, H1 - nb_h1 * blksize_1); + ker(i, o, block_h0, block_h1); + }); + +#undef off + return status::success; + } +}; + /* generic and direct-copy reorders */ template @@ -805,7 +1514,6 @@ struct simple_reorder_implattr()->output_scales_.mask_; diff --git a/src/cpu/rnn/ref_rnn.hpp b/src/cpu/rnn/ref_rnn.hpp index 03a13febfac..145a8c2d14b 100644 --- a/src/cpu/rnn/ref_rnn.hpp +++ b/src/cpu/rnn/ref_rnn.hpp @@ -58,7 +58,7 @@ void gates_reduction(const rnn_utils::rnn_conf_t &rnn, const gates_t *ws_gates_, // @todo block k on simd-width to enable vectorization in // parallel_nd path #if DNNL_CPU_RUNTIME == DNNL_RUNTIME_OMP && _OPENMP >= 201307 \ - && __INTEL_COMPILER < 1910 + && defined __INTEL_COMPILER && __INTEL_COMPILER < 1910 #pragma omp parallel for simd collapse(2) for (int i = 0; i < rnn.n_gates; i++) for (int k = 0; k < rnn.dhc; k++) @@ -238,6 +238,10 @@ struct _ref_rnn_common_t : public primitive_t { using namespace rnn_utils; #if DNNL_X64 using namespace x64; + + // WA: Brgemm implementation has perf degradation for RNN node + return status::unimplemented; + const alg_kind_t cell_kind = this->desc()->cell_kind; const data_type_t src_layer_dt diff --git a/src/cpu/simple_concat.cpp b/src/cpu/simple_concat.cpp index e7b2ce826c6..8882a487043 100644 --- a/src/cpu/simple_concat.cpp +++ b/src/cpu/simple_concat.cpp @@ -77,7 +77,7 @@ status_t simple_concat_t::execute(const exec_ctx_t &ctx) const { for (int a = 0; a < num_arrs; ++a) { const data_t *i = &iptrs[a][0]; data_t *o = &optrs[a][0]; - parallel_nd(nelems_to_copy[a], [&](dim_t e) { o[e] = i[e]; }); + parallel_nd_legacy(nelems_to_copy[a], [&](dim_t e) { o[e] = i[e]; }); } return status::success; } @@ -94,7 +94,7 @@ status_t simple_concat_t::execute(const exec_ctx_t &ctx) const { const auto L1_size = platform::get_per_core_cache_size(1); UNUSED(L1_size); // for Windows - parallel_nd(phys_dims[0], phys_dims[1], phys_dims[2], phys_dims[3], + parallel_nd_legacy(phys_dims[0], phys_dims[1], phys_dims[2], phys_dims[3], phys_dims[4], num_arrs, [&](dim_t n0, dim_t n1, dim_t n2, dim_t n3, dim_t n4, dim_t a) { // check if zero memory diff --git a/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp b/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp index 3c8a2bbed1f..16d01571bd9 100644 --- a/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp +++ b/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp @@ -921,17 +921,10 @@ void jit_brgemm_kernel_t::store_accumulators_apply_post_ops( const bool dq2ps_required = brg.is_int8 && IMPLICATION(alpha_or_beta_applicable, beta_uses_vadd); - if (brg.with_bias) { mov(reg_aux_bias, ptr[rsp + reg_aux_bias_offs_]); } for_(int bd = 0; bd < bd_block; bd++) for (int ld = 0; ld < ld_block2; ld++) { auto zmm = accm(ld_block2, bd, ld); if (dq2ps_required) vcvtdq2ps(zmm, zmm); - if (brg.with_bias) { - auto zmm_bias = zmm_tmp_1(); - auto ptr_bias = ptr[reg_aux_bias + bias_offset(ld)]; - cvt2ps(brg.dt_bias, zmm_bias, ptr_bias, true, false, k_mask); - vaddps(zmm, zmm, zmm_bias); - } } if (brg.zp_type_a != brgemm_broadcast_t::none) { @@ -980,6 +973,19 @@ void jit_brgemm_kernel_t::store_accumulators_apply_post_ops( } } } + + if 
(brg.with_bias) { mov(reg_aux_bias, ptr[rsp + reg_aux_bias_offs_]); } + for_(int bd = 0; bd < bd_block; bd++) + for (int ld = 0; ld < ld_block2; ld++) { + auto zmm = accm(ld_block2, bd, ld); + if (brg.with_bias) { + auto zmm_bias = zmm_tmp_1(); + auto ptr_bias = ptr[reg_aux_bias + bias_offset(ld)]; + cvt2ps(brg.dt_bias, zmm_bias, ptr_bias, true, false, k_mask); + vaddps(zmm, zmm, zmm_bias); + } + } + if (brg.with_scales) { mov(reg_aux_scales, ptr[rsp + reg_aux_scales_offs_]); for (int bd = 0; bd < bd_block; bd++) { diff --git a/src/cpu/x64/cpu_isa_traits.hpp b/src/cpu/x64/cpu_isa_traits.hpp index e96e14d3661..4f69f868e1e 100644 --- a/src/cpu/x64/cpu_isa_traits.hpp +++ b/src/cpu/x64/cpu_isa_traits.hpp @@ -31,7 +31,7 @@ #define XBYAK_NO_OP_NAMES /* in order to make selinux happy memory that would be marked with X-bit should * be obtained with mmap */ -#define XBYAK_USE_MMAP_ALLOCATOR +//#define XBYAK_USE_MMAP_ALLOCATOR #define XBYAK_NO_EXCEPTION #if defined(_MSC_VER) && !defined(__INTEL_COMPILER) /* turn off `size_t to other-type implicit casting` warning @@ -67,6 +67,7 @@ enum cpu_isa_bit_t : unsigned { amx_int8_bit = 1u << 10, amx_bf16_bit = 1u << 11, avx_vnni_bit = 1u << 12, + avx512_vpopcnt_bit = 1u << 13, // Fill in hints from most significant bit to least significant bit prefer_ymm_bit = 1u << (cpu_isa_total_bits - 1), @@ -122,6 +123,7 @@ enum cpu_isa_t : unsigned { avx512_core_bf16_amx_int8 = avx512_core_bf16 | amx_int8, avx512_core_bf16_amx_bf16 = avx512_core_bf16 | amx_bf16, avx512_core_amx = avx512_core_bf16 | amx_int8 | amx_bf16, + avx512_vpopcnt = avx512_vpopcnt_bit, // NOTES: 1. isa_all by default has no isa specific hints isa_all = ~0u & ~cpu_isa_hints_utils::hints_mask, }; @@ -275,6 +277,13 @@ struct cpu_isa_traits { static constexpr const char *user_option_env = "avx512_core_amx"; }; +template <> +struct cpu_isa_traits { + static constexpr dnnl_cpu_isa_t user_option_val + = dnnl_cpu_isa_avx512_vpopcnt; + static constexpr const char *user_option_env = "AVX512_VPOPCNT"; +}; + inline const Xbyak::util::Cpu &cpu() { const static Xbyak::util::Cpu cpu_; return cpu_; @@ -340,6 +349,7 @@ static inline bool mayiuse(const cpu_isa_t cpu_isa, bool soft = false) { case avx512_core_amx: return mayiuse(avx512_core_bf16_amx_int8, soft) && mayiuse(avx512_core_bf16_amx_bf16, soft); + case avx512_vpopcnt: return cpu().has(Cpu::tAVX512_VPOPCNTDQ); case isa_any: return true; case isa_all: return false; } diff --git a/src/cpu/x64/gemm_bf16_convolution.cpp b/src/cpu/x64/gemm_bf16_convolution.cpp index 6f45b93f805..b5f18386303 100644 --- a/src/cpu/x64/gemm_bf16_convolution.cpp +++ b/src/cpu/x64/gemm_bf16_convolution.cpp @@ -72,11 +72,15 @@ void cvt_acc_to_dst(const conv_gemm_conf_t &jcp, size_t g_start, size_t g_end, template gemm_bf16_convolution_fwd_t::pp_ker_t::pp_ker_t(const pd_t *pd) : jcp_(pd->jcp_) + , post_ops_(pd->attr()->post_ops_) , do_sum_(dst_data_type != data_type::f32 && jcp_.with_sum) , max_data_reg_idx_(31) , max_unroll_(12) , compute_reg_step_(1) - , data_reg_base_idx_(0) { + , data_reg_base_idx_(0) + , attr_(pd->attr()) + , jit_eltwise_injectors_(0) +{ using namespace types; using namespace Xbyak; @@ -84,29 +88,15 @@ gemm_bf16_convolution_fwd_t::pp_ker_t::pp_ker_t(const pd_t *pd) // bf16 is not supported return; - const auto &post_ops = jcp_.post_ops; - if (jcp_.with_eltwise || jcp_.with_binary) { -#define PARAM_OFF(field) offsetof(ker_args, field) - static constexpr bool preserve_gpr = true; - static constexpr bool preserve_vmm = true; - static constexpr size_t helper_vmm_idx = 
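
The jit_brgemm_kernel hunk above does not change what is computed per element, only when: bias used to be added right after the int32-to-f32 conversion, and is now re-emitted after the zero-point compensation block, so the epilogue runs as convert, compensate, bias, then scales. A scalar model of the resulting order (names are illustrative):

    #include <cstdint>

    // One accumulator element through the reordered store_accumulators
    // epilogue (sketch).
    float brgemm_epilogue_sketch(int32_t acc, float zp_a_comp, bool with_bias,
            float bias, bool with_scales, float scale) {
        float v = (float)acc;        // vcvtdq2ps when dq2ps_required
        v += zp_a_comp;              // zp_type_a compensation
        if (with_bias) v += bias;    // bias now follows compensation
        if (with_scales) v *= scale; // per-channel output scale
        return v;
    }
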
31; - static constexpr size_t tail_size = 1; - static constexpr bool use_exact_tail_scalar_bcast = false; - const binary_injector::rhs_arg_static_params_t rhs_arg_static_params { - helper_vmm_idx, reserved_eltwise_gpr, r14, preserve_gpr, - preserve_vmm, PARAM_OFF(post_ops_binary_rhs_arg_vec), - PARAM_OFF(dst_orig), memory_desc_wrapper(pd->dst_md()), - tail_size, kreg_rem_mask, use_exact_tail_scalar_bcast}; - const binary_injector::static_params_t binary_static_params { - this->reg_param, rhs_arg_static_params}; - static constexpr bool save_state = true; - const eltwise_injector::static_params_t eltwise_static_params { - save_state, reserved_eltwise_gpr, reserved_eltwise_maskr}; - - postops_injector_ = utils::make_unique< - injector::jit_uni_postops_injector_t>( - this, post_ops, binary_static_params, eltwise_static_params); -#undef PARAM_OFF + bool do_depthwise_ = false; + for (int i = 0; i < post_ops_.len(); i++) { + auto& post_op = post_ops_.entry_[i]; + if (post_op.is_eltwise()) { + jit_eltwise_injectors_.push_back(new jit_uni_eltwise_injector_f32(this, + post_op.eltwise, true, reserved_eltwise_gpr, reserved_eltwise_maskr)); + } else if (post_op.is_depthwise()) { + do_depthwise_ = true; + } } if (do_sum_) { @@ -116,6 +106,9 @@ gemm_bf16_convolution_fwd_t::pp_ker_t::pp_ker_t(const pd_t *pd) if (jcp_.with_bias) vreg_bias = Zmm(data_reg_base_idx_++); + if (do_depthwise_) + vreg_dw = Zmm(data_reg_base_idx_++); + vlen_ = cpu_isa_traits::vlen / sizeof(float); isa_ = mayiuse(avx512_core_bf16) ? avx512_core_bf16 @@ -132,25 +125,6 @@ gemm_bf16_convolution_fwd_t::pp_ker_t::pp_ker_t(const pd_t *pd) = (max_data_reg_idx_ - data_reg_base_idx_ + 1) / compute_reg_step_; } -template -void gemm_bf16_convolution_fwd_t::pp_ker_t::apply_postops( - const bool apply_mask, const size_t out_offset, const int vmm_idx) { -#define PARAM_OFF(x) offsetof(ker_args, x) - if (jcp_.with_eltwise || jcp_.with_binary) { - if (jcp_.with_binary) { - binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; - rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, reg_dst); - rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace( - vmm_idx, out_offset * sizeof(dst_data_t)); - if (apply_mask) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx); - - postops_injector_->compute_vector(vmm_idx, rhs_arg_params); - } else - postops_injector_->compute_vector(vmm_idx); - } -#undef PARAM_OFF -} - template void gemm_bf16_convolution_fwd_t::pp_ker_t::generate() { using namespace Xbyak; @@ -171,6 +145,8 @@ void gemm_bf16_convolution_fwd_t::pp_ker_t::generate() { mov(reg_len, ptr[reg_param + PARAM_OFF(spatial_length)]); mov(reg_oc_iter, ptr[reg_param + PARAM_OFF(oc_work)]); + mov(reg_oc_offset, ptr[reg_param + PARAM_OFF(oc_offset)]); + if (jcp_.with_binary) { // zero initialize binary post_ops offset accumulator (store on stack) const auto binary_post_op_acc_off_reg = reg_tmp; @@ -180,7 +156,6 @@ void gemm_bf16_convolution_fwd_t::pp_ker_t::generate() { if (do_sum_) vbroadcastss(vreg_sum_scale, ptr[reg_param + PARAM_OFF(sum_scale)]); -#undef PARAM_OFF // Load accumulated value, apply sum (if any), bias (if any) // and relu (if any); then convert to destination type and store @@ -217,7 +192,48 @@ void gemm_bf16_convolution_fwd_t::pp_ker_t::generate() { vfmadd231ps(vreg_dst(idx), vreg_prev_dst(idx), vreg_sum_scale); } - apply_postops(apply_mask, offset, vreg_dst_idx(idx)); + if (jcp_.with_depthwise) { + push(reg_post_ops_data); + mov(reg_post_ops_data, ptr[reg_param + PARAM_OFF(post_ops_binary_rhs_arg_vec)]); + } + + int eltwise_inj_idx = 0; + std::size_t 
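
The rewritten pp_ker_t constructor above no longer builds one composite post-ops injector; it walks the attribute's post-op chain itself, creating a dedicated eltwise injector per eltwise entry and recording whether any depthwise entry exists (which later reserves vreg_dw). A simplified model of that scan, with stand-in types in place of oneDNN's post_ops_t:

    #include <memory>
    #include <vector>

    enum class po_kind { sum, eltwise, depthwise };
    struct po_entry { po_kind kind; };
    struct eltwise_inj { explicit eltwise_inj(const po_entry &) {} };

    // Returns whether a depthwise helper register must be reserved.
    bool scan_post_ops(const std::vector<po_entry> &post_ops,
            std::vector<std::unique_ptr<eltwise_inj>> &injectors) {
        bool do_depthwise = false;
        for (const auto &e : post_ops) {
            if (e.kind == po_kind::eltwise)
                injectors.emplace_back(new eltwise_inj(e));
            else if (e.kind == po_kind::depthwise)
                do_depthwise = true;
        }
        return do_depthwise;
    }
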
post_ops_data_offset = 0; + const auto& p = attr_->post_ops_; + for (int i = 0; i < p.len(); i++) { + auto& post_op = p.entry_[i]; + if (post_op.is_eltwise()) { + jit_eltwise_injectors_[eltwise_inj_idx]->compute_vector(vreg_dst_idx(idx)); + eltwise_inj_idx++; + } else if (post_op.is_depthwise()) { + mov(reg_dw, ptr[reg_post_ops_data + post_ops_data_offset]); + lea(reg_dw, ptr[reg_dw + reg_oc_offset]); + + switch (post_op.depthwise.alg) { + case alg_kind::depthwise_scale_shift: { + vbroadcastss(vreg_dw, ptr[reg_dw + post_op.depthwise.offset[post_op.depthwise.scales] * sizeof(float)]); + vmulps(vreg_dst(idx), vreg_dst(idx), vreg_dw); + vbroadcastss(vreg_dw, ptr[reg_dw + post_op.depthwise.offset[post_op.depthwise.shifts] * sizeof(float)]); + vaddps(vreg_dst(idx), vreg_dst(idx), vreg_dw); + break; + } + case alg_kind::depthwise_prelu: { + vpxord(vreg_dw, vreg_dw, vreg_dw); + vcmpps(kmask, vreg_dst(idx), vreg_dw, _cmp_lt_os); + vbroadcastss(vreg_dw, ptr[reg_dw + post_op.depthwise.offset[post_op.depthwise.scales] * sizeof(float)]); + vmulps(vreg_dst(idx) | kmask, vreg_dst(idx), vreg_dw); + break; + } + default: assert(!"unsupported depthwise algorithm"); + } + + post_ops_data_offset += sizeof(float*); + } + } + + if (jcp_.with_depthwise) { + pop(reg_post_ops_data); + } if (dst_data_type == data_type::bf16) { // TODO: implement store by zmm registers for bf16 @@ -293,6 +309,8 @@ void gemm_bf16_convolution_fwd_t::pp_ker_t::generate() { if (jcp_.with_binary) inc(EVEX_compress_addr(rsp, reg_binary_post_op_acc_off)); + add(reg_oc_offset, sizeof(float)); + dec(reg_oc_iter); jnz(oc_loop, T_NEAR); // oc_loop end @@ -302,14 +320,17 @@ void gemm_bf16_convolution_fwd_t::pp_ker_t::generate() { postamble(); - if (jcp_.with_eltwise) postops_injector_->prepare_table(); + for (auto& inj : jit_eltwise_injectors_) + inj->prepare_table(); + +#undef PARAM_OFF } // operator () specialized for nspc format template void gemm_bf16_convolution_fwd_t::pp_ker_t::operator()( dst_data_t *dst, const acc_data_t *acc, const acc_data_t *bias, - float sum_scale, size_t oc_work, + float sum_scale, size_t oc_work, size_t g_offset, const void *post_ops_binary_rhs_arg_vec, const void *dst_orig, const size_t g_oc_offset) { @@ -322,6 +343,7 @@ void gemm_bf16_convolution_fwd_t::pp_ker_t::operator()( args.acc_stride_in_bytes = sizeof(acc_data_t); args.spatial_length = 1; args.oc_work = oc_work; + args.oc_offset = g_offset * sizeof(float); args.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec; args.dst_orig = dst_orig; @@ -333,7 +355,7 @@ void gemm_bf16_convolution_fwd_t::pp_ker_t::operator()( template void gemm_bf16_convolution_fwd_t::pp_ker_t::operator()( dst_data_t *dst, const acc_data_t *acc, const acc_data_t *bias, - float sum_scale, size_t dst_stride_in_elements, + size_t g_offset, size_t start_oc, float sum_scale, size_t dst_stride_in_elements, size_t acc_stride_in_elements, size_t sp_len, size_t oc_len, const void *post_ops_binary_rhs_arg_vec, const void *dst_orig, const size_t g_oc_offset) { @@ -348,6 +370,7 @@ void gemm_bf16_convolution_fwd_t::pp_ker_t::operator()( args.acc_stride_in_bytes = acc_stride_in_elements * sizeof(acc_data_t); args.spatial_length = sp_len; args.oc_work = oc_len; + args.oc_offset = (start_oc + g_offset) * sizeof(float); args.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec; args.dst_orig = dst_orig; @@ -365,6 +388,8 @@ status_t gemm_bf16_convolution_fwd_t::execute_forward_nspc( = binary_injector::prepare_binary_args( this->pd()->attr()->post_ops_, ctx); + auto MB = 
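
The vbroadcastss/vmulps/vaddps and vcmpps/masked-vmulps sequences emitted above reduce to these per-channel scalar formulas, with w and b the per-output-channel weight and shift addressed through reg_dw:

    // Scalar semantics of the two depthwise post-op algorithms.
    inline float depthwise_scale_shift_ref(float x, float w, float b) {
        return x * w + b;           // vmulps + vaddps
    }
    inline float depthwise_prelu_ref(float x, float w) {
        return x < 0.f ? x * w : x; // vcmpps mask + masked vmulps
    }
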
CTX_IN_BATCH(DNNL_ARG_SRC); + auto scratchpad = ctx.get_scratchpad_grantor(); const conv_gemm_conf_t &jcp = pd()->jcp_; @@ -386,7 +411,7 @@ status_t gemm_bf16_convolution_fwd_t::execute_forward_nspc( parallel(jcp.nthr, [&](const int ithr, const int nthr) { status_t st_thr = execute_forward_thr_nspc(ithr, nthr, src_base, wei_base, bia_base, dst_base, scratchpad, - post_ops_binary_rhs_arg_vec.data()); + post_ops_binary_rhs_arg_vec.data(), MB); if (st_thr != status::success) st = st_thr; }); @@ -398,7 +423,7 @@ status_t gemm_bf16_convolution_fwd_t::execute_forward_thr_nspc( const int ithr, const int nthr, const src_data_t *src_base, const wei_data_t *wei_base, const float *bia_base, dst_data_t *dst_base, const memory_tracking::grantor_t &scratchpad, - const void *post_ops_binary_rhs_arg_vec) const { + const void *post_ops_binary_rhs_arg_vec, int MB) const { const conv_gemm_conf_t &jcp = pd()->jcp_; // Src Format: mb-spatial-groups-input_channels @@ -435,9 +460,9 @@ status_t gemm_bf16_convolution_fwd_t::execute_forward_thr_nspc( const dim_t nb_oh = div_up(jcp.oh, jcp.oh_block); const dim_t nb_ow = div_up(jcp.ow, jcp.ow_block); - const dim_t work_amount = jcp.ngroups * jcp.mb * nb_oh * nb_ow; + const dim_t work_amount = jcp.ngroups * MB * nb_oh * nb_ow; balance211(work_amount, nthr, ithr, start, end); - nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow); + nd_iterator_init(start, n, MB, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow); if (jcp.im2col_sz && is_problem_3d) { // jit_gemm_convolution_utils::im2col_dt_3d() requires external @@ -508,13 +533,13 @@ status_t gemm_bf16_convolution_fwd_t::execute_forward_thr_nspc( (*pp_ker_)(dst_arr, acc_needed ? acc_arr : (float *)dst_arr, - bia_arr, sum_scale, jcp.oc, + bia_arr, sum_scale, jcp.oc, g * jcp.oc, post_ops_binary_rhs_arg_vec, dst_base, g * jcp.oc); }); } } - nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow); + nd_iterator_step(n, MB, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow); } return status::success; } @@ -529,6 +554,8 @@ status_t gemm_bf16_convolution_fwd_t::execute_forward_ncsp( = binary_injector::prepare_binary_args( this->pd()->attr()->post_ops_, ctx); + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + bool is_bf16_dst = dst_data_type == data_type::bf16; auto col = ctx.get_scratchpad_grantor().template get( @@ -540,6 +567,12 @@ status_t gemm_bf16_convolution_fwd_t::execute_forward_ncsp( const conv_gemm_conf_t &jcp = this->pd()->jcp_; + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + + src += src_d.off_l(0); + dst += dst_d.off_l(0); + float *bias = nullptr; if (jcp.with_bias) { if (pd()->desc()->bias_desc.data_type == data_type::bf16) { @@ -565,7 +598,7 @@ status_t gemm_bf16_convolution_fwd_t::execute_forward_ncsp( const dim_t LDB = weights_oc_size; const dim_t work_amount - = (size_t)jcp.ngroups * jcp.mb * jcp.od * jcp.os_nb_block; + = (size_t)jcp.ngroups * MB * jcp.od * jcp.os_nb_block; const bool is_problem_3d = pd()->ndims() == 5; std::atomic st(status::success); @@ -610,8 +643,7 @@ status_t gemm_bf16_convolution_fwd_t::execute_forward_ncsp( if (this->pd()->is_postprocess_required() && ic + ic_block >= jcp.ic) { size_t acc_str = LDC; size_t dst_str = M; - float *bias_ptr = bias ? 
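
CTX_IN_BATCH reads an effective batch size from the execution context, so a primitive created for jcp.mb can be run on a smaller mini-batch; every work-amount computation and nd_iterator call in the surrounding hunks therefore switches from jcp.mb to MB. A sketch of the resulting partitioning (modeled after balance211's even split):

    #include <algorithm>
    #include <cstdint>

    void balance_sketch(int64_t ngroups, int64_t MB, int nthr, int ithr,
            int64_t &start, int64_t &end) {
        const int64_t work = ngroups * MB; // was ngroups * jcp.mb
        const int64_t chunk = (work + nthr - 1) / nthr;
        start = std::min<int64_t>((int64_t)ithr * chunk, work);
        end = std::min<int64_t>(start + chunk, work);
    }
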
bias + groups * jcp.oc + oc : nullptr; - (*pp_ker_)(dst_local, acc, bias_ptr, sum_scale, dst_str, acc_str, m, + (*pp_ker_)(dst_local, acc, bias, groups * jcp.oc, oc, sum_scale, dst_str, acc_str, m, oc_block, post_ops_binary_rhs_arg_vec.data(), dst, groups * jcp.oc + oc); } @@ -633,7 +665,7 @@ status_t gemm_bf16_convolution_fwd_t::execute_forward_ncsp( balance2D(nthr, ithr, work_amount, start, end, jcp.oc, oc_start, oc_end, dim_t(jcp.nthr_oc)); - nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, od, jcp.od, nb_os, + nd_iterator_init(start, g, jcp.ngroups, n, MB, od, jcp.od, nb_os, jcp.os_nb_block); for (dim_t iwork = start; iwork < end; ++iwork) { for_(dim_t oc = (dim_t)oc_start; oc < (dim_t)oc_end; @@ -660,7 +692,7 @@ status_t gemm_bf16_convolution_fwd_t::execute_forward_ncsp( inner_ker(ic, oc, g, od, nb_os, _src, _weights, _col, _dst_im, _acc, ic_block, oc_block); } - nd_iterator_step(g, jcp.ngroups, n, jcp.mb, od, jcp.od, nb_os, + nd_iterator_step(g, jcp.ngroups, n, MB, od, jcp.od, nb_os, jcp.os_nb_block); } }); @@ -676,13 +708,18 @@ status_t gemm_bf16_convolution_bwd_data_t:: auto wei_base = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS); auto diff_src_base = CTX_OUT_MEM(diff_src_data_t *, DNNL_ARG_DIFF_SRC); + const auto post_ops_binary_rhs_arg_vec + = x64::binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx); + + auto MB = CTX_IN_BATCH(DNNL_ARG_DIFF_DST); + auto scratchpad = ctx.get_scratchpad_grantor(); const conv_gemm_conf_t &jcp = pd()->jcp_; std::atomic st(status::success); parallel(jcp.nthr, [&](const int ithr, const int nthr) { status_t st_thr = execute_backward_data_thr_nspc( - ithr, nthr, diff_src_base, wei_base, diff_dst_base, scratchpad); + ithr, nthr, diff_src_base, wei_base, diff_dst_base, scratchpad, MB, post_ops_binary_rhs_arg_vec); if (st_thr != status::success) st = st_thr; }); @@ -694,7 +731,8 @@ status_t gemm_bf16_convolution_bwd_data_t< diff_src_data_type>::execute_backward_data_thr_nspc(const int ithr, const int nthr, diff_src_data_t *diff_src_base, const wei_data_t *wei_base, const diff_dst_data_t *diff_dst_base, - const memory_tracking::grantor_t &scratchpad) const { + const memory_tracking::grantor_t &scratchpad, int MB, + const std::vector& post_ops_binary_rhs_arg_vec) const { const conv_gemm_conf_t &jcp = pd()->jcp_; @@ -713,7 +751,9 @@ status_t gemm_bf16_convolution_bwd_data_t< const size_t diff_src_os_stride = jcp.ngroups * jcp.ic; // threads share work across mini-batch and groups - const dim_t work_amount = jcp.ngroups * jcp.mb; + const dim_t work_amount = jcp.ngroups * MB; + + const auto& p = pd()->attr()->post_ops_; acc_data_t *__restrict col = scratchpad.get(key_conv_gemm_col) + (ptrdiff_t)ithr * jcp.im2col_sz; @@ -724,7 +764,7 @@ status_t gemm_bf16_convolution_bwd_data_t< dim_t start = 0, end = 0; balance211(work_amount, nthr, ithr, start, end); - nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups); + nd_iterator_init(start, n, MB, g, jcp.ngroups); for (dim_t iwork = start; iwork < end; ++iwork) { const diff_dst_data_t *__restrict diff_dst = diff_dst_base @@ -747,6 +787,30 @@ status_t gemm_bf16_convolution_bwd_data_t< if (jcp.im2col_sz) jit_gemm_convolution_utils::col2im_dt(jcp, col, acc); + if (p.len() > 0) { + std::size_t post_ops_data_idx = 0; + int depthwise_inj_idx = 0; + for (int i = 0; i < p.len(); i++) { + auto &post_op = p.entry_[i]; + if (post_op.is_depthwise()) { + auto depthwise_base = reinterpret_cast(post_ops_binary_rhs_arg_vec[post_ops_data_idx]); + auto depthwise_weights = depthwise_base + 
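
On the backward-data path the depthwise post-op cannot reuse the forward JIT epilogue, so the surrounding hunk applies it with the scalar reference injector directly over the dense diff_src, one spatial point at a time. A serial model of that loop (the real code parallelizes over spatial and calls compute_scalar per element):

    #include <cstddef>

    // Apply a per-channel scale/shift to a [spatial][channel] slab of
    // diff_src; ws/bs already point at this group's weights and shifts.
    void depthwise_on_diff_src(float *diff_src, size_t spatial, size_t nch,
            size_t os_stride, const float *ws, const float *bs) {
        for (size_t is = 0; is < spatial; ++is) {
            float *row = diff_src + is * os_stride;
            for (size_t ic = 0; ic < nch; ++ic)
                row[ic] = row[ic] * ws[ic] + bs[ic];
        }
    }
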
post_op.depthwise.offset[post_op.depthwise.scales]; + auto depthwise_bias = depthwise_base + post_op.depthwise.offset[post_op.depthwise.shifts]; + + parallel_nd(static_cast(jcp.is) * jcp.id, [&](size_t is) { + diff_src_data_t*__restrict diff_src_arr + = diff_src + is * diff_src_os_stride; + for (int ic = 0; ic < jcp.ic; ic++) { + diff_src_arr[ic] = depthwise_injectors[depthwise_inj_idx]->compute_scalar(diff_src_arr[ic], + depthwise_weights + g * jcp.ic + ic, depthwise_bias + g * jcp.ic + ic); + } + }); + post_ops_data_idx++; + depthwise_inj_idx++; + } + } + } + const bool is_diff_src_bf16 = diff_src_data_type == data_type::bf16; if (is_diff_src_bf16 && jcp.ngroups == 1 && jcp.nthr != 1) { @@ -777,7 +841,7 @@ status_t gemm_bf16_convolution_bwd_data_t< diff_src_loc[ic] = acc_loc[ic]; }); } - nd_iterator_step(n, jcp.mb, g, jcp.ngroups); + nd_iterator_step(n, MB, g, jcp.ngroups); } return status::success; } @@ -789,6 +853,11 @@ status_t gemm_bf16_convolution_bwd_data_t:: auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS); auto diff_src = CTX_OUT_MEM(diff_src_data_t *, DNNL_ARG_DIFF_SRC); + const auto post_ops_binary_rhs_arg_vec + = x64::binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx); + + auto MB = CTX_IN_BATCH(DNNL_ARG_DIFF_DST); + auto col = ctx.get_scratchpad_grantor().template get( key_conv_gemm_col); acc_data_t *acc_base = diff_src_data_type == data_type::bf16 @@ -807,9 +876,11 @@ status_t gemm_bf16_convolution_bwd_data_t:: const dim_t K = jcp.oc; const dim_t N = jcp.ic * jcp.ks; - const dim_t work_amount = (size_t)jcp.ngroups * jcp.mb; + const dim_t work_amount = (size_t)jcp.ngroups * MB; const bool is_problem_3d = pd()->ndims() == 5; + const auto& p = pd()->attr()->post_ops_; + std::atomic st(status::success); parallel(jcp.nthr, [&](const int ithr, const int nthr) { @@ -818,7 +889,7 @@ status_t gemm_bf16_convolution_bwd_data_t:: dim_t g {0}, n {0}; dim_t start = 0, end = 0; balance211(work_amount, nthr, ithr, start, end); - nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb); + nd_iterator_init(start, g, jcp.ngroups, n, MB); for (dim_t iwork = start; iwork < end; ++iwork) { diff_src_data_t *diff_src_local @@ -863,13 +934,39 @@ status_t gemm_bf16_convolution_bwd_data_t:: od, os_nb * jcp.os_block, os_block); } } + + if (p.len() > 0) { + std::size_t post_ops_data_idx = 0; + int depthwise_inj_idx = 0; + for (int i = 0; i < p.len(); i++) { + auto &post_op = p.entry_[i]; + if (post_op.is_depthwise()) { + auto depthwise_base = reinterpret_cast(post_ops_binary_rhs_arg_vec[post_ops_data_idx]); + auto depthwise_weights = depthwise_base + post_op.depthwise.offset[post_op.depthwise.scales]; + auto depthwise_bias = depthwise_base + post_op.depthwise.offset[post_op.depthwise.shifts]; + + parallel_nd(jcp.ic, [&](const int ic) { + for (int id = 0; id < jcp.id; ++id) { + acc_data_t *d_ = acc + ic * jcp.id * jcp.is + id * jcp.is; + for (int iS = 0; iS < jcp.is; ++iS) { + d_[iS] = depthwise_injectors[depthwise_inj_idx]->compute_scalar(d_[iS], + depthwise_weights + g * jcp.ic + ic, depthwise_bias + g * jcp.ic + ic); + } + } + }); + post_ops_data_idx++; + depthwise_inj_idx++; + } + } + } + if (diff_src_data_type == data_type::bf16) { size_t spatial_size = (size_t)jcp.ih * jcp.iw * jcp.id; store_bfloat16_in_parallel((bfloat16_t *)diff_src_local, (const float *)acc, jcp.ic, spatial_size, jcp.nthr == 1); } - nd_iterator_step(g, jcp.ngroups, n, jcp.mb); + nd_iterator_step(g, jcp.ngroups, n, MB); } }); diff --git a/src/cpu/x64/gemm_bf16_convolution.hpp 
b/src/cpu/x64/gemm_bf16_convolution.hpp
index 6fc05befc78..1e1c2c26d6f 100644
--- a/src/cpu/x64/gemm_bf16_convolution.hpp
+++ b/src/cpu/x64/gemm_bf16_convolution.hpp
@@ -28,6 +28,8 @@
 #include "cpu/x64/cpu_reducer.hpp"
 #include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
 #include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
+#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
+#include "cpu/ref_depthwise_injector.hpp"
 namespace dnnl {
 namespace impl {
 namespace cpu {
@@ -55,17 +57,8 @@ struct gemm_bf16_convolution_fwd_t : public primitive_t {
                     && !has_zero_dim_memory()
                     && attr()->has_default_values(
                             primitive_attr_t::skip_mask_t::post_ops,
-                            dst_data_type);
-            {
-                using namespace x64::injector;
-                static constexpr bool sum_at_pos_0_only = true;
-                static constexpr bool sum_requires_scale_one = true;
-                static constexpr bool sum_requires_zp_zero = true;
-                const auto dst_md = memory_desc_wrapper(dst_md_);
-                ok &= post_ops_ok({avx512_core, {binary, eltwise, sum},
-                        attr()->post_ops_, &dst_md, sum_at_pos_0_only,
-                        sum_requires_scale_one, sum_requires_zp_zero});
-            }
+                            dst_data_type)
+                    && post_ops_ok();
             if (!ok) return status::unimplemented;
 
             auto scratchpad = scratchpad_registry().registrar();
@@ -87,6 +80,27 @@ struct gemm_bf16_convolution_fwd_t : public primitive_t {
         }
 
         conv_gemm_conf_t jcp_;
+
+    protected:
+        virtual bool post_ops_ok() const {
+            auto const &po = this->attr()->post_ops_;
+            auto all_post_ops_supported = [&]() {
+                bool ok = true;
+
+                for (int i = 0; i < po.len(); i++) {
+                    ok = ok && utils::one_of(po.entry_[i].kind, primitive_kind::sum, primitive_kind::eltwise, primitive_kind::depthwise);
+                }
+                return ok;
+            };
+
+            auto contain = [&](dnnl::impl::primitive_kind_t kind) { return po.find(kind) != -1; };
+            auto position = [&](dnnl::impl::primitive_kind_t kind) { return po.find(kind); };
+            auto count = [&](dnnl::impl::primitive_kind_t kind) { return po.count(kind); };
+
+            return all_post_ops_supported() &&
+                   count(primitive_kind::sum) <= 1 &&
+                   IMPLICATION(contain(primitive_kind::sum), position(primitive_kind::sum) == 0);
+        }
     };
 
     gemm_bf16_convolution_fwd_t(const pd_t *apd)
@@ -124,7 +140,7 @@ struct gemm_bf16_convolution_fwd_t : public primitive_t {
             const src_data_t *src_base, const wei_data_t *wei_base,
             const float *bia_base, dst_data_t *dst_base,
             const memory_tracking::grantor_t &scratchpad,
-            const void *post_ops_binary_rhs_arg_vec) const;
+            const void *post_ops_binary_rhs_arg_vec, int MB) const;
 
     const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
 
@@ -133,12 +149,19 @@ struct gemm_bf16_convolution_fwd_t : public primitive_t {
         DECLARE_CPU_JIT_AUX_FUNCTIONS(gemm_bf16_convolution_fwd_t::pp_kernel);
         pp_ker_t(const pd_t *pd);
 
+        ~pp_ker_t() {
+            for (auto inj : jit_eltwise_injectors_)
+                delete inj;
+            jit_eltwise_injectors_.clear();
+        }
+
         void operator()(dst_data_t *dst, const acc_data_t *acc,
-                const acc_data_t *bias, float sum_scale, size_t oc_work,
+                const acc_data_t *bias, float sum_scale, size_t oc_work, size_t g_offset,
                 const void *post_ops_binary_rhs_arg_vec, const void *dst_orig,
                 const size_t g_oc_offset);
         void operator()(dst_data_t *dst, const acc_data_t *acc,
-                const acc_data_t *bias, float sum_scale, size_t dst_str,
+                const acc_data_t *bias,
+                size_t g_offset, size_t start_oc, float sum_scale, size_t dst_str,
                 size_t acc_str, size_t sp_len, size_t oc,
                 const void *post_ops_binary_rhs_arg_vec, const void *dst_orig,
                 const size_t g_oc_offset);
@@ -153,6 +176,7 @@ struct gemm_bf16_convolution_fwd_t : public primitive_t {
         size_t acc_stride_in_bytes;
        size_t 
spatial_length; size_t oc_work; + size_t oc_offset; size_t g_oc_offset; const void *post_ops_binary_rhs_arg_vec; @@ -177,10 +201,16 @@ struct gemm_bf16_convolution_fwd_t : public primitive_t { Xbyak::Reg64 reg_dst_str = r13; Xbyak::Reg64 reg_acc_str = r14; + using Vmm = typename cpu_isa_traits::Vmm; + Xbyak::Reg64 reg_oc_offset = r10; + Xbyak::Reg64 reg_dw = r9; + Xbyak::Reg64 reg_post_ops_data = reg_bias; + Xbyak::Opmask kmask = k7; + Xbyak::Reg64 reserved_eltwise_gpr = r10; Xbyak::Opmask reserved_eltwise_maskr = k2; - Xbyak::Zmm vreg_sum_scale, vreg_bias; + Xbyak::Zmm vreg_sum_scale, vreg_bias, vreg_dw; Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(27); Xbyak::Zmm bf16_emu_reserv_2 = Xbyak::Zmm(28); @@ -194,14 +224,15 @@ struct gemm_bf16_convolution_fwd_t : public primitive_t { constexpr static int stack_space_needed = reg64_size; const conv_gemm_conf_t &jcp_; + post_ops_t post_ops_; const bool do_sum_; int max_data_reg_idx_, max_unroll_, compute_reg_step_; int data_reg_base_idx_; size_t vlen_; cpu_isa_t isa_; std::unique_ptr bf16_emu_; - std::unique_ptr> - postops_injector_; + const primitive_attr_t* attr_; + nstl::vector*> jit_eltwise_injectors_; void apply_postops(const bool apply_mask, const size_t out_offset, const int vmm_idx); @@ -253,7 +284,7 @@ struct gemm_bf16_convolution_bwd_data_t : public primitive_t { && set_default_alg_kind(alg_kind::convolution_direct) && expect_data_types(diff_src_data_type, data_type::bf16, data_type::undef, data_type::bf16, data_type::f32) - && !has_zero_dim_memory() && attr()->has_default_values(); + && !has_zero_dim_memory() && is_supported_post_ops(); if (!ok) return status::unimplemented; auto scratchpad = scratchpad_registry().registrar(); @@ -263,9 +294,42 @@ struct gemm_bf16_convolution_bwd_data_t : public primitive_t { } conv_gemm_conf_t jcp_; + + + protected: + virtual bool is_supported_post_ops() const { + const auto &p = this->attr()->post_ops_; + if (p.len() > 1) + return false; + + auto all_post_ops_supported = [&]() { + bool ok = true; + + for (int i = 0; i < p.len(); i++) { + ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::depthwise); + } + return ok; + }; + + return all_post_ops_supported(); + } }; - gemm_bf16_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {} + gemm_bf16_convolution_bwd_data_t(const pd_t* apd) : primitive_t(apd) { + const auto& post_ops = pd()->attr()->post_ops_; + for (int i = 0; i < post_ops.len(); i++) { + auto& post_op = post_ops.entry_[i]; + if (post_op.is_depthwise()) { + depthwise_injectors.push_back(new ref_depthwise_scalar_fwd_t(post_op.depthwise.alg)); + } + } + } + + ~gemm_bf16_convolution_bwd_data_t() { + for (auto inj : depthwise_injectors) + delete inj; + depthwise_injectors.clear(); + } typedef typename prec_traits::type diff_dst_data_t; typedef typename prec_traits::type acc_data_t; @@ -284,9 +348,12 @@ struct gemm_bf16_convolution_bwd_data_t : public primitive_t { status_t execute_backward_data_thr_nspc(const int ithr, const int nthr, diff_src_data_t *diff_src_base, const wei_data_t *wei_base, const diff_dst_data_t *diff_dst_base, - const memory_tracking::grantor_t &scratchpad) const; + const memory_tracking::grantor_t &scratchpad, int MB, + const std::vector& post_ops_binary_rhs_arg_vec) const; const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } + + nstl::vector depthwise_injectors; }; template diff --git a/src/cpu/x64/gemm_bf16_inner_product.cpp b/src/cpu/x64/gemm_bf16_inner_product.cpp index 39d15b6bb31..900b1c2adff 100644 --- 
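
Both pp_ker_t call operators now derive the byte offset of the first output channel covered by the call and pass it through ker_args.oc_offset, so the JIT loop can index per-channel depthwise tables via reg_oc_offset. The offset math, extracted as a sketch:

    #include <cstddef>

    // nspc path: the group base only; ncsp path: group base plus the
    // current oc block (start_oc), both scaled to bytes.
    size_t oc_offset_nspc(size_t g, size_t oc_per_group) {
        return g * oc_per_group * sizeof(float);
    }
    size_t oc_offset_ncsp(size_t g, size_t oc_per_group, size_t start_oc) {
        return (g * oc_per_group + start_oc) * sizeof(float);
    }
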
a/src/cpu/x64/gemm_bf16_inner_product.cpp +++ b/src/cpu/x64/gemm_bf16_inner_product.cpp @@ -50,8 +50,10 @@ status_t gemm_bf16_inner_product_fwd_t::execute_forward( = binary_injector_utils::prepare_binary_args( this->pd()->attr()->post_ops_, ctx); + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + const dim_t M = pd()->OC(); - const dim_t N = pd()->MB(); + const dim_t N = MB; const dim_t K = pd()->IC_total_padded(); const auto &wmd = *pd()->weights_md(); @@ -98,8 +100,10 @@ gemm_bf16_inner_product_bwd_data_t::execute_backward_data( auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS); auto diff_src = CTX_OUT_MEM(diff_src_data_t *, DNNL_ARG_DIFF_SRC); + auto MB = CTX_IN_BATCH(DNNL_ARG_DIFF_DST); + const dim_t M = pd()->IC_total_padded(); - const dim_t N = pd()->MB(); + const dim_t N = MB; const dim_t K = pd()->OC(); const auto &wmd = *pd()->weights_md(); diff --git a/src/cpu/x64/injectors/jit_uni_binary_injector.cpp b/src/cpu/x64/injectors/jit_uni_binary_injector.cpp index 2b8de5fcc23..1585097a450 100644 --- a/src/cpu/x64/injectors/jit_uni_binary_injector.cpp +++ b/src/cpu/x64/injectors/jit_uni_binary_injector.cpp @@ -224,12 +224,14 @@ rhs_arg_static_params_t::rhs_arg_static_params_t( const Xbyak::Reg64 &rhs_helper_reg, bool preserve_gpr_helpers, bool preserve_vmm_helper, std::size_t abi_param_offset, const memory_desc_wrapper &dst_d, std::size_t tail_size, - const Xbyak::Opmask &tail_opmask, bool use_exact_tail_scalar_bcast) + const Xbyak::Opmask &tail_opmask, bool use_exact_tail_scalar_bcast, std::size_t rhs_prelu_helper_vmm_idx) : rhs_arg_static_params_t(rhs_dt_helper_vmm_idx, rhs_addr_reg, rhs_helper_reg, preserve_gpr_helpers, preserve_vmm_helper, abi_param_offset, 0, dst_d, tail_size, tail_opmask, use_exact_tail_scalar_bcast, rhs_helper_reg, true /*is_opmask_set*/, - false /*is_dst_orig_set*/) {} + false /*is_dst_orig_set*/) { + this->rhs_prelu_helper_vmm_idx = rhs_prelu_helper_vmm_idx; +} rhs_arg_static_params_t::rhs_arg_static_params_t( std::size_t rhs_dt_helper_vmm_idx, const Xbyak::Reg64 &rhs_addr_reg, @@ -237,12 +239,14 @@ rhs_arg_static_params_t::rhs_arg_static_params_t( bool preserve_vmm_helper, std::size_t abi_param_offset, std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d, std::size_t tail_size, const Xbyak::Opmask &tail_opmask, - bool use_exact_tail_scalar_bcast) + bool use_exact_tail_scalar_bcast, std::size_t rhs_prelu_helper_vmm_idx) : rhs_arg_static_params_t(rhs_dt_helper_vmm_idx, rhs_addr_reg, rhs_helper_reg, preserve_gpr_helpers, preserve_vmm_helper, abi_param_offset, dst_orig_offset, dst_d, tail_size, tail_opmask, use_exact_tail_scalar_bcast, rhs_helper_reg, true /*is_opmask_set*/, - true /*is_dst_orig_set*/) {} + true /*is_dst_orig_set*/) { + this->rhs_prelu_helper_vmm_idx = rhs_prelu_helper_vmm_idx; +} rhs_arg_static_params_t::rhs_arg_static_params_t( std::size_t rhs_dt_helper_vmm_idx, const Xbyak::Reg64 &rhs_addr_reg, @@ -250,12 +254,14 @@ rhs_arg_static_params_t::rhs_arg_static_params_t( bool preserve_vmm_helper, std::size_t abi_param_offset, const memory_desc_wrapper &dst_d, std::size_t tail_size, const Xbyak::Opmask &tail_opmask, const Xbyak::Reg64 ®_tail_size, - bool use_exact_tail_scalar_bcast) + bool use_exact_tail_scalar_bcast, std::size_t rhs_prelu_helper_vmm_idx) : rhs_arg_static_params_t(rhs_dt_helper_vmm_idx, rhs_addr_reg, rhs_helper_reg, preserve_gpr_helpers, preserve_vmm_helper, abi_param_offset, 0, dst_d, tail_size, tail_opmask, use_exact_tail_scalar_bcast, reg_tail_size, true /*is_opmask_set*/, - false /*is_dst_orig_set*/) {} + false 
/*is_dst_orig_set*/) { + this->rhs_prelu_helper_vmm_idx = rhs_prelu_helper_vmm_idx; +} rhs_arg_static_params_t::rhs_arg_static_params_t( std::size_t rhs_dt_helper_vmm_idx, const Xbyak::Reg64 &rhs_addr_reg, @@ -263,12 +269,14 @@ rhs_arg_static_params_t::rhs_arg_static_params_t( bool preserve_vmm_helper, std::size_t abi_param_offset, std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d, std::size_t tail_size, const Xbyak::Opmask &tail_opmask, - const Xbyak::Reg64 ®_tail_size, bool use_exact_tail_scalar_bcast) + const Xbyak::Reg64 ®_tail_size, bool use_exact_tail_scalar_bcast, std::size_t rhs_prelu_helper_vmm_idx) : rhs_arg_static_params_t(rhs_dt_helper_vmm_idx, rhs_addr_reg, rhs_helper_reg, preserve_gpr_helpers, preserve_vmm_helper, abi_param_offset, dst_orig_offset, dst_d, tail_size, tail_opmask, use_exact_tail_scalar_bcast, reg_tail_size, true /*is_opmask_set*/, - true /*is_dst_orig_set*/) {} + true /*is_dst_orig_set*/) { + this->rhs_prelu_helper_vmm_idx = rhs_prelu_helper_vmm_idx; +} rhs_arg_static_params_t::rhs_arg_static_params_t( std::size_t rhs_dt_helper_vmm_idx, const Xbyak::Reg64 &rhs_addr_reg, @@ -1335,7 +1343,7 @@ void jit_uni_binary_injector_t::inject_binary( = rhs_arg_data_type != data_type::f32 || (scalar_f32 && !is_avx512_) || with_tail_not_fusable_to_binary_op || !binary_op_with_unaligned_mem_operand_allowed_ - || (cmp_op && !is_avx512_); + || ((cmp_op || alg == alg_kind::binary_prelu) && !is_avx512_); if (process_rhs_arg_using_tmp_vmm) { @@ -2110,6 +2118,23 @@ jit_uni_binary_injector_t::execute_cmp_binary(const Vmm &dst, pop_opmask(host_, cmp_mask); } +template +template +typename std::enable_if::value + || std::is_same::value>::type +jit_uni_binary_injector_t::execute_prelu_binary(const Vmm &dst, const Vmm &lhs, const T &rhs) const { + const auto &cmp_mask = rhs_arg_static_params_.tail_opmask; + const Xbyak::Zmm zmm_aux0 + = Xbyak::Zmm(rhs_arg_static_params_.rhs_prelu_helper_vmm_idx); + + push_opmask(host_, cmp_mask); + host_->uni_vpxor(zmm_aux0, zmm_aux0, zmm_aux0); + host_->vcmpps(cmp_mask, lhs, zmm_aux0, jit_generator::_cmp_lt_os); + host_->uni_vmulps(dst | cmp_mask, lhs, rhs); + pop_opmask(host_, cmp_mask); +} + + // SSE4.1., AVX and AVX2 implementation template template @@ -2129,6 +2154,23 @@ jit_uni_binary_injector_t::execute_cmp_binary(const Vmm &dst, host_->uni_vminps(dst, dst, vreg_one); } +// todo: [antonvor] check sse41 path +template +template +typename std::enable_if::value + || std::is_same::value)>::type +jit_uni_binary_injector_t::execute_prelu_binary(const Vmm &dst, + const Vmm &lhs, const T &rhs) const { + const Vmm vmm_aux0 = Vmm(rhs_arg_static_params_.rhs_prelu_helper_vmm_idx); + + push_vmm(host_, vmm_aux0); + host_->uni_vmulps(rhs, rhs, lhs); + host_->vpxor(vmm_aux0, vmm_aux0, vmm_aux0); + host_->vcmpltps(vmm_aux0, lhs, vmm_aux0); + host_->uni_vblendvps(dst, lhs, rhs, vmm_aux0); + pop_vmm(host_, vmm_aux0); +} + template template void jit_uni_binary_injector_t::execute_binary(alg_kind_t binary_alg, @@ -2158,6 +2200,9 @@ void jit_uni_binary_injector_t::execute_binary(alg_kind_t binary_alg, case alg_kind::binary_ne: execute_cmp_binary(dst, lhs, rhs, jit_generator::_cmp_neq_uq); break; + case alg_kind::binary_prelu: + execute_prelu_binary(dst, lhs, rhs); + break; default: assert(!"unsupported algorithm"); } } diff --git a/src/cpu/x64/injectors/jit_uni_binary_injector.hpp b/src/cpu/x64/injectors/jit_uni_binary_injector.hpp index 06b0c9dd8c4..d089c4d678f 100644 --- a/src/cpu/x64/injectors/jit_uni_binary_injector.hpp +++ 
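
execute_prelu_binary above has two lowering strategies: the AVX-512 variant compares the lanes against zero into an opmask and issues one merge-masked multiply, while the SSE4.1/AVX2 variant emulates the mask with a compare plus blend. Lane-wise both reduce to the same reference (dst is preloaded with lhs, which is how the injector invokes it):

    // Lane-wise model of the prelu binary post-op.
    void prelu_binary_ref(float *dst, const float *lhs, const float *rhs,
            int n) {
        for (int i = 0; i < n; ++i)
            dst[i] = lhs[i] < 0.f ? lhs[i] * rhs[i] : lhs[i];
    }
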
b/src/cpu/x64/injectors/jit_uni_binary_injector.hpp @@ -119,21 +119,21 @@ struct rhs_arg_static_params_t { const Xbyak::Reg64 &rhs_helper_reg, bool preserve_gpr_helpers, bool preserve_vmm_helper, std::size_t abi_param_offset, const memory_desc_wrapper &dst_d, std::size_t tail_size, - const Xbyak::Opmask &tail_opmask, bool use_exact_tail_scalar_bcast); + const Xbyak::Opmask &tail_opmask, bool use_exact_tail_scalar_bcast, std::size_t rhs_prelu_helper_vmm_idx = 0); rhs_arg_static_params_t(std::size_t rhs_dt_helper_vmm_idx, const Xbyak::Reg64 &rhs_addr_reg, const Xbyak::Reg64 &rhs_helper_reg, bool preserve_gpr_helpers, bool preserve_vmm_helper, std::size_t abi_param_offset, std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d, std::size_t tail_size, const Xbyak::Opmask &tail_opmask, - bool use_exact_tail_scalar_bcast); + bool use_exact_tail_scalar_bcast, std::size_t rhs_prelu_helper_vmm_idx = 0); rhs_arg_static_params_t(std::size_t rhs_dt_helper_vmm_idx, const Xbyak::Reg64 &rhs_addr_reg, const Xbyak::Reg64 &rhs_helper_reg, bool preserve_gpr_helpers, bool preserve_vmm_helper, std::size_t abi_param_offset, const memory_desc_wrapper &dst_d, std::size_t tail_size, const Xbyak::Opmask &tail_opmask, const Xbyak::Reg64 ®_tail_size, - bool use_exact_tail_scalar_bcast); + bool use_exact_tail_scalar_bcast, std::size_t rhs_prelu_helper_vmm_idx = 0); rhs_arg_static_params_t(std::size_t rhs_dt_helper_vmm_idx, const Xbyak::Reg64 &rhs_addr_reg, const Xbyak::Reg64 &rhs_helper_reg, bool preserve_gpr_helpers, @@ -141,7 +141,7 @@ struct rhs_arg_static_params_t { std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d, std::size_t tail_size, const Xbyak::Opmask &tail_opmask, const Xbyak::Reg64 ®_tail_size, - bool use_exact_tail_scalar_bcast); + bool use_exact_tail_scalar_bcast, std::size_t rhs_prelu_helper_vmm_idx = 0); bool is_opmask_set() const noexcept { return is_opmask_set_; } bool is_dst_orig_set() const noexcept { return is_dst_orig_set_; } @@ -160,6 +160,8 @@ struct rhs_arg_static_params_t { Xbyak::Reg64 reg_tail_size; bool is_tail; + mutable std::size_t rhs_prelu_helper_vmm_idx; + private: rhs_arg_static_params_t(std::size_t rhs_dt_helper_vmm_idx, const Xbyak::Reg64 &rhs_addr_reg, @@ -444,11 +446,19 @@ class jit_uni_binary_injector_t { execute_cmp_binary(const Vmm &dst, const Vmm &lhs, const T &rhs, const unsigned int cmp_predicate) const; template + typename std::enable_if::value + || std::is_same::value>::type + execute_prelu_binary(const Vmm &dst, const Vmm &lhs, const T &rhs) const; + template typename std::enable_if::value || std::is_same::value)>::type execute_cmp_binary(const Vmm &dst, const Vmm &lhs, const T &rhs, const unsigned int cmp_predicate) const; template + typename std::enable_if::value + || std::is_same::value)>::type + execute_prelu_binary(const Vmm &dst, const Vmm &lhs, const T &rhs) const; + template void execute_binary(alg_kind_t binary_alg, const Vmm &dst, const Vmm &lhs, const T &rhs) const; /* diff --git a/src/cpu/x64/injectors/jit_uni_depthwise_injector.cpp b/src/cpu/x64/injectors/jit_uni_depthwise_injector.cpp new file mode 100644 index 00000000000..bb0e34ee147 --- /dev/null +++ b/src/cpu/x64/injectors/jit_uni_depthwise_injector.cpp @@ -0,0 +1,268 @@ +/******************************************************************************* +* Copyright 2019-2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "common/c_types_map.hpp" +#include "common/dnnl_thread.hpp" +#include "common/nstl.hpp" +#include "common/utils.hpp" +#include "cpu/x64/injectors/injector_utils.hpp" + +#include "cpu/x64/injectors/jit_uni_depthwise_injector.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace x64 { + +template +int jit_uni_depthwise_injector_f32::aux_vecs_count(alg_kind_t depthwise_alg, bool is_broadcast) { + switch (depthwise_alg) { + case alg_kind::depthwise_scale_shift: return isa == sse41 || is_broadcast ? 1 : 0; + case alg_kind::depthwise_prelu: return 2; + default: assert(!"unsupported depthwise algorithm"); + } + + return 0; +} + +template +void jit_uni_depthwise_injector_f32::injector_preamble(size_t start_idx, size_t end_idx, bool is_broadcast) { + preserved_vecs_count = 0; + vecs_to_preserve = (size_t)jit_uni_depthwise_injector_f32::aux_vecs_count(depthwise_alg, is_broadcast); + + for (size_t i = 0; i < vecs_count; i++) { + if (preserved_vecs_count >= vecs_to_preserve) + break; + + if (i < start_idx || i >= end_idx) { + preserved_vec_idxs[preserved_vecs_count] = i; + preserved_vecs_count++; + } + } + + start_idx_tail = start_idx; + size_t preserved_vecs_count_tail = vecs_to_preserve - preserved_vecs_count; + for (size_t i = 0; i < preserved_vecs_count_tail; i++) { + preserved_vec_idxs[preserved_vecs_count] = start_idx + i; + preserved_vecs_count++; + start_idx_tail = start_idx + i + 1; + } + + h->sub(h->rsp, preserved_vecs_count * vlen); + for (size_t i = 0; i < preserved_vecs_count; ++i) + h->uni_vmovups(h->ptr[h->rsp + i * vlen], Vmm(preserved_vec_idxs[i])); + + assign_regs(); +} + +template +void jit_uni_depthwise_injector_f32::injector_preamble_tail(size_t start_idx, size_t end_idx) { + size_t tail_vecs_to_preserve = start_idx_tail - start_idx; + int idx_off = (vecs_to_preserve - tail_vecs_to_preserve); + + if (tail_vecs_to_preserve > 0) { + h->add(h->rsp, idx_off * vlen); + for (size_t i = 0; i < tail_vecs_to_preserve; ++i) + h->uni_vmovups(Vmm(preserved_vec_idxs[idx_off + i]), h->ptr[h->rsp + i * vlen]); + + for (size_t i = 0; i < tail_vecs_to_preserve; ++i) { + preserved_vec_idxs[idx_off + i] += tail_vecs_to_preserve; + } + + for (size_t i = 0; i < tail_vecs_to_preserve; ++i) + h->uni_vmovups(h->ptr[h->rsp + i * vlen], Vmm(preserved_vec_idxs[idx_off + i])); + h->sub(h->rsp, idx_off * vlen); + + assign_regs(); + } +} + +template +void jit_uni_depthwise_injector_f32::injector_postamble() { + for (size_t i = 0; i < preserved_vecs_count; ++i) + h->uni_vmovups(Vmm(preserved_vec_idxs[i]), h->ptr[h->rsp + i * vlen]); + h->add(h->rsp, preserved_vecs_count * vlen); +} + +template +void jit_uni_depthwise_injector_f32::assign_regs() { + vmm_mask = Vmm(preserved_vec_idxs[0]); + vmm_aux0 = Vmm(preserved_vec_idxs[1]); +} + +template +void jit_uni_depthwise_injector_f32::scale_shift_compute_vector(const Vmm &vmm_src, + const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias, bool is_broadcast, int offset) { + size_t weights_off = post_op_.depthwise.offset[post_op_.depthwise.scales] * 
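
The new depthwise injector follows the usual injector contract for scratch registers: injector_preamble prefers Vmm indices outside the caller's live range [start_idx, end_idx), spills to the stack whatever still has to overlap, and injector_postamble restores them. The selection logic, modeled with plain integers as register ids (a sketch):

    #include <cstddef>
    #include <vector>

    // Pick `need` scratch ids, preferring ones outside [start, end);
    // any ids taken from inside the range must be spilled and restored.
    std::vector<size_t> pick_scratch(size_t vecs_count, size_t need,
            size_t start, size_t end) {
        std::vector<size_t> ids;
        for (size_t i = 0; i < vecs_count && ids.size() < need; ++i)
            if (i < start || i >= end) ids.push_back(i); // free register
        for (size_t i = start; ids.size() < need; ++i)
            ids.push_back(i); // not enough free ones: these get spilled
        return ids;
    }
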
sizeof(float); + size_t bias_off = post_op_.depthwise.offset[post_op_.depthwise.shifts] * sizeof(float); + + if (isa == sse41) { + if (is_broadcast) + h->uni_vbroadcastss(vmm_mask, h->ptr[p_weights + weights_off]); + else + h->movups(vmm_mask, h->ptr[p_weights + offset + weights_off]); + h->mulps(vmm_src, vmm_mask); + if (is_broadcast) + h->uni_vbroadcastss(vmm_mask, h->ptr[p_bias + bias_off]); + else + h->movups(vmm_mask, h->ptr[p_bias + offset + bias_off]); + h->addps(vmm_src, vmm_mask); + } else { + if (is_broadcast) { + h->uni_vbroadcastss(vmm_mask, h->ptr[p_weights + weights_off]); + h->uni_vmulps(vmm_src, vmm_src, vmm_mask); + h->uni_vbroadcastss(vmm_mask, h->ptr[p_bias + bias_off]); + h->uni_vaddps(vmm_src, vmm_src, vmm_mask); + } else { + h->uni_vmulps(vmm_src, vmm_src, h->ptr[p_weights + offset + weights_off]); + h->uni_vaddps(vmm_src, vmm_src, h->ptr[p_bias + offset + bias_off]); + } + }; +} + +template +void jit_uni_depthwise_injector_f32::prelu_compute_vector(const Vmm &vmm_src, + const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias, bool is_broadcast, int offset) { + const unsigned char _cmp_gt_os = 6; + const unsigned char _cmp_lt_os = 1; + size_t weights_off = post_op_.depthwise.offset[post_op_.depthwise.scales] * sizeof(float); + + if (isa == sse41) { + h->pxor(vmm_mask, vmm_mask); + h->cmpps(vmm_mask, vmm_src, _cmp_gt_os); + if (is_broadcast) + h->uni_vbroadcastss(vmm_aux0, h->ptr[p_weights + weights_off]); + else + h->movups(vmm_aux0, h->ptr[p_weights + offset + weights_off]); + h->mulps(vmm_aux0, vmm_src); + h->blendvps(vmm_src, vmm_aux0); + } else if (isa == avx2) { + if (is_broadcast) { + h->uni_vbroadcastss(vmm_mask, h->ptr[p_weights + weights_off]); + h->vmulps(vmm_aux0, vmm_src, vmm_mask); + } else + h->vmulps(vmm_aux0, vmm_src, h->ptr[p_weights + offset + weights_off]); + h->vxorps(vmm_mask, vmm_mask, vmm_mask); + h->vcmpgtps(vmm_mask, vmm_src, vmm_mask); + h->vblendvps(vmm_src, vmm_aux0, vmm_src, vmm_mask); + } else if (isa == avx512_common || isa == avx512_core) { + h->vxorpd(vmm_mask, vmm_mask, vmm_mask); + h->vmovups(vmm_aux0, vmm_src); + h->vcmpps(k_mask, vmm_src, vmm_mask, _cmp_lt_os); + if (is_broadcast) { + h->uni_vbroadcastss(vmm_mask, h->ptr[p_weights + weights_off]); + h->vmulps(vmm_src | k_mask, vmm_aux0, vmm_mask); + } else + h->vmulps(vmm_src | k_mask, vmm_aux0, h->ptr[p_weights + offset + weights_off]); + } +} + +template +void jit_uni_depthwise_injector_f32::compute_body(size_t start_idx, size_t end_idx, + const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias, bool is_broadcast) { + for (size_t idx = start_idx; idx < end_idx; idx++) { + switch (depthwise_alg) { + case alg_kind::depthwise_scale_shift: + scale_shift_compute_vector(Vmm(idx), p_weights, p_bias, is_broadcast); break; + case alg_kind::depthwise_prelu: + prelu_compute_vector(Vmm(idx), p_weights, p_bias, is_broadcast); break; + default: assert(!"unsupported depthwise algorithm"); + } + } +} + +template +void jit_uni_depthwise_injector_f32::compute_vector_range(int start_idx, int end_idx, + const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias, bool is_broadcast) { + injector_preamble(start_idx, end_idx, is_broadcast); + compute_body(start_idx_tail, end_idx, p_weights, p_bias, is_broadcast); + injector_preamble_tail(start_idx, end_idx); + compute_body(start_idx, start_idx_tail, p_weights, p_bias, is_broadcast); + injector_postamble(); +} + +template +void jit_uni_depthwise_injector_f32::init_ptrs(const Xbyak::RegExp& ptr_data, + const Xbyak::Reg64& reg_d_weights, const 
Xbyak::Reg64& reg_d_bias,
+        const Xbyak::Operand& ch_off, bool is_broadcast) {
+    h->mov(reg_d_weights, h->ptr[ptr_data]);
+    if (post_op_.depthwise.alg == alg_kind::depthwise_scale_shift)
+        h->mov(reg_d_bias, h->ptr[ptr_data]);
+
+    if (!is_broadcast) {
+        h->add(reg_d_weights, ch_off);
+        if (post_op_.depthwise.alg == alg_kind::depthwise_scale_shift)
+            h->add(reg_d_bias, ch_off);
+    }
+}
+
+template
+static void push_vmm(jit_generator *host, const Vmm &vmm) {
+    host->sub(host->rsp, injector_utils::vmm_size_t::bytes);
+    host->uni_vmovups(host->ptr[host->rsp], vmm);
+}
+
+template
+static void pop_vmm(jit_generator *host, const Vmm &vmm) {
+    host->uni_vmovups(vmm, host->ptr[host->rsp]);
+    host->add(host->rsp, injector_utils::vmm_size_t::bytes);
+}
+
+template
+void jit_uni_depthwise_injector_f32::compute(int start_idx, int end_idx,
+        int vmm_d_weights_idx, int vmm_d_bias_idx,
+        const Xbyak::Reg64& reg_d_weights, const Xbyak::Reg64& reg_d_bias,
+        bool is_broadcast, int offset, bool need_to_preserve) {
+    vmm_mask = Vmm(vmm_d_weights_idx);
+    vmm_aux0 = Vmm(vmm_d_bias_idx);
+
+    if (need_to_preserve) {
+        preserved_vecs_count = aux_vecs_count(depthwise_alg, is_broadcast);
+        if (preserved_vecs_count > 0)
+            push_vmm(h, vmm_mask);
+        if (preserved_vecs_count > 1)
+            push_vmm(h, vmm_aux0);
+    }
+
+    for (int idx = start_idx; idx < end_idx; idx++) {
+        switch (depthwise_alg) {
+            case alg_kind::depthwise_scale_shift:
+                scale_shift_compute_vector(Vmm(idx), reg_d_weights, reg_d_bias, is_broadcast, offset); break;
+            case alg_kind::depthwise_prelu:
+                prelu_compute_vector(Vmm(idx), reg_d_weights, reg_d_bias, is_broadcast, offset); break;
+            default: assert(!"unsupported depthwise algorithm");
+        }
+    }
+
+    if (need_to_preserve) {
+        if (preserved_vecs_count > 1)
+            pop_vmm(h, vmm_aux0);
+        if (preserved_vecs_count > 0)
+            pop_vmm(h, vmm_mask);
+    }
+}
+
+template struct jit_uni_depthwise_injector_f32;
+template struct jit_uni_depthwise_injector_f32;
+template struct jit_uni_depthwise_injector_f32;
+template struct jit_uni_depthwise_injector_f32;
+template struct jit_uni_depthwise_injector_f32;
+template struct jit_uni_depthwise_injector_f32;
+
+} // namespace x64
+} // namespace cpu
+} // namespace impl
+} // namespace dnnl
diff --git a/src/cpu/x64/injectors/jit_uni_depthwise_injector.hpp b/src/cpu/x64/injectors/jit_uni_depthwise_injector.hpp
new file mode 100644
index 00000000000..4dece960a79
--- /dev/null
+++ b/src/cpu/x64/injectors/jit_uni_depthwise_injector.hpp
@@ -0,0 +1,142 @@
+/*******************************************************************************
+* Copyright 2019-2020 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_X64_JIT_UNI_DEPTHWISE_INJECTOR_HPP
+#define CPU_X64_JIT_UNI_DEPTHWISE_INJECTOR_HPP
+
+#include
+
+#include "../../../common/c_types_map.hpp"
+#include "../../../common/primitive_attr.hpp"
+#include "../../../common/type_helpers.hpp"
+#include "../../../common/utils.hpp"
+
+#include "../jit_generator.hpp"
+
+namespace dnnl {
+namespace impl {
+namespace cpu {
+namespace x64 {
+
+namespace depthwise_injector {
+
+struct static_params_t {
+    static_params_t(int vmm_d_weights_idx = 0, int vmm_d_bias_idx = 0,
+            Xbyak::Reg64 reg_d_weights = Xbyak::Reg64(0), Xbyak::Reg64 reg_d_bias = Xbyak::Reg64(0)) :
+        vmm_d_weights_idx(vmm_d_weights_idx), vmm_d_bias_idx(vmm_d_bias_idx), reg_d_weights(reg_d_weights), reg_d_bias(reg_d_bias) {}
+
+    int vmm_d_weights_idx;
+    int vmm_d_bias_idx;
+    Xbyak::Reg64 reg_d_weights;
+    Xbyak::Reg64 reg_d_bias;
+};
+
+struct dynamic_params_t {
+    dynamic_params_t(int vmm_d_weights_idx = 0, int vmm_d_bias_idx = 0,
+            Xbyak::Reg64 reg_d_weights = Xbyak::Reg64(0), Xbyak::Reg64 reg_d_bias = Xbyak::Reg64(0),
+            Xbyak::Reg64 reg_init_off = Xbyak::Reg64(0), const std::map vmm_idx_off = {},
+            Xbyak::Reg64 reg_post_ops_data = Xbyak::Reg64(0), int base_post_ops_data_offset = 0) :
+        vmm_d_weights_idx(vmm_d_weights_idx), vmm_d_bias_idx(vmm_d_bias_idx), reg_d_weights(reg_d_weights), reg_d_bias(reg_d_bias),
+        reg_init_off(reg_init_off), reg_init_off_addr(0), vmm_idx_off(vmm_idx_off), useAddr(false),
+        reg_post_ops_data(reg_post_ops_data), base_post_ops_data_offset(base_post_ops_data_offset) {}
+
+    dynamic_params_t(int vmm_d_weights_idx, int vmm_d_bias_idx,
+            Xbyak::Reg64 reg_d_weights, Xbyak::Reg64 reg_d_bias,
+            Xbyak::Address reg_init_off, const std::map vmm_idx_off,
+            Xbyak::Reg64 reg_post_ops_data = Xbyak::Reg64(0), int base_post_ops_data_offset = 0) :
+        vmm_d_weights_idx(vmm_d_weights_idx), vmm_d_bias_idx(vmm_d_bias_idx), reg_d_weights(reg_d_weights), reg_d_bias(reg_d_bias),
+        reg_init_off(0), reg_init_off_addr(reg_init_off), vmm_idx_off(vmm_idx_off), useAddr(true),
+        reg_post_ops_data(reg_post_ops_data), base_post_ops_data_offset(base_post_ops_data_offset) {}
+
+    int vmm_d_weights_idx;
+    int vmm_d_bias_idx;
+    Xbyak::Reg64 reg_d_weights;
+    Xbyak::Reg64 reg_d_bias;
+    Xbyak::Reg64 reg_init_off;
+    Xbyak::Address reg_init_off_addr;
+    std::map vmm_idx_off;
+    bool useAddr;
+    Xbyak::Reg64 reg_post_ops_data;
+    int base_post_ops_data_offset;
+};
+
+} // namespace depthwise_injector
+
+template
+struct jit_uni_depthwise_injector_f32 {
+    using Vmm = typename utils::conditional3::type;
+
+    jit_uni_depthwise_injector_f32(jit_generator* host, dnnl_post_ops::entry_t post_op, Xbyak::Opmask k_mask_ = Xbyak::Opmask(1))
+        : h(host), post_op_(post_op), k_mask(k_mask_) {
+        depthwise_alg = post_op.depthwise.alg;
+        assert(utils::one_of(depthwise_alg, alg_kind::depthwise_scale_shift, alg_kind::depthwise_prelu));
+    }
+
+    void compute_vector_range(int start_idx, int end_idx, const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias, bool is_broadcast = false);
+
+    void init_ptrs(const Xbyak::RegExp& ptr_data,
+            const Xbyak::Reg64& reg_d_weights, const Xbyak::Reg64& reg_d_bias,
+            const Xbyak::Operand& ch_off, bool is_broadcast);
+
+    void compute(int start_idx, int end_idx,
+            int vmm_d_weights_idx, int vmm_d_bias_idx,
+            const Xbyak::Reg64& reg_d_weights, const Xbyak::Reg64& reg_d_bias,
+            bool is_broadcast = false, int offset = 0, bool need_to_preserve = false);
+
+    static constexpr size_t memoryStep() {
+        return sizeof(float*);
+ } + +private: + jit_generator* h; + + size_t vlen = cpu_isa_traits::vlen; + + alg_kind_t depthwise_alg; + + mutable Vmm vmm_mask; + mutable Vmm vmm_aux0; + + dnnl_post_ops::entry_t post_op_; + + Xbyak::Opmask k_mask; + + const static size_t preserved_vecs_max = 5; + size_t vecs_to_preserve = 0; + size_t vecs_count = isa == avx512_common ? 32 : 16; + size_t preserved_vecs_count = 0; + size_t preserved_vec_idxs[preserved_vecs_max] = {0}; + size_t start_idx_tail = 0; + + int aux_vecs_count(alg_kind_t elt_alg, bool is_broadcast); + + void compute_body(size_t start_idx, size_t end_idx, const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias, bool is_broadcast = false); + void injector_preamble(size_t start_idx, size_t end_idx, bool is_broadcast = false); + void injector_preamble_tail(size_t start_idx, size_t end_idx); + void injector_postamble(); + void assign_regs(); + + void scale_shift_compute_vector(const Vmm &vmm_src, const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias, bool is_broadcast = false, int offset = 0); + void prelu_compute_vector(const Vmm &vmm_src, const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias, bool is_broadcast = false, int offset = 0); +}; + +} // namespace x64 +} // namespace cpu +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/cpu/x64/injectors/jit_uni_eltwise_injector.cpp b/src/cpu/x64/injectors/jit_uni_eltwise_injector.cpp index 85d3c851a68..2b41fbd82c6 100644 --- a/src/cpu/x64/injectors/jit_uni_eltwise_injector.cpp +++ b/src/cpu/x64/injectors/jit_uni_eltwise_injector.cpp @@ -40,6 +40,7 @@ bool is_alg_supported(alg_kind_t alg) { eltwise_logsigmoid, eltwise_mish, eltwise_exp, eltwise_gelu_tanh, eltwise_hardswish, eltwise_swish, eltwise_log, eltwise_clip, eltwise_clip_v2, eltwise_pow, eltwise_gelu_erf, eltwise_round, + eltwise_hsigmoid, eltwise_round_half_away_from_zero, eltwise_round_half_to_even, eltwise_relu_use_dst_for_bwd, eltwise_tanh_use_dst_for_bwd, eltwise_elu_use_dst_for_bwd, eltwise_sqrt_use_dst_for_bwd, eltwise_logistic_use_dst_for_bwd, eltwise_exp_use_dst_for_bwd, @@ -769,9 +770,9 @@ void jit_uni_eltwise_injector_f32::soft_relu_compute_vector_fwd( h->uni_vaddps(vmm_src, vmm_src, vmm_aux1); h->uni_vaddps(vmm_src, vmm_src, vmm_aux0); - // get vmm_mask = src > max logf - // y = (x < max log f) ? soft_relu(x) : x - compute_cmp_mask(vmm_aux2, table_val(exp_ln_flt_max_f), _cmp_gt_os); + // get vmm_mask = src > 20.f + // y = (x < 20.f) ? 
soft_relu(x) : x + compute_cmp_mask(vmm_aux2, table_val(soft_relu_twenty), _cmp_gt_os); blend_with_mask(vmm_src, vmm_aux2); } @@ -1545,6 +1546,49 @@ void jit_uni_eltwise_injector_f32::round_compute_vector_fwd( h->uni_vroundps(vmm_src, vmm_src, _op_mxcsr); } +template +void jit_uni_eltwise_injector_f32::hsigmoid_compute_vector_fwd( + const Vmm &vmm_src) { + // x + 3 + h->uni_vaddps(vmm_src, vmm_src, table_val(hsigmoid, 0)); + // relu6(x + 3) + h->uni_vmaxps(vmm_src, vmm_src, table_val(zero)); + h->uni_vminps(vmm_src, vmm_src, table_val(hsigmoid, 1)); + // relu6(x + 3) / 6 + h->uni_vmulps(vmm_src, vmm_src, table_val(hsigmoid, 2)); +} + +template +void jit_uni_eltwise_injector_f32::round_half_to_even_compute_vector_fwd( + const Vmm &vmm_src) { + h->uni_vroundps(vmm_src, vmm_src, _op_near); +} + +template +void jit_uni_eltwise_injector_f32::round_half_away_from_zero_compute_vector_fwd( + const Vmm &vmm_src) { + // create a mask of negative numbers for later returning sign + compute_cmp_mask(vmm_src, table_val(zero), _cmp_lt_os); + + // round half away from zero for positive numbers + h->uni_vandps(vmm_src, vmm_src, table_val(positive_mask)); + h->uni_vaddps(vmm_src, vmm_src, table_val(half)); + h->uni_vroundps(vmm_src, vmm_src, _op_floor); + + // return a sign for negative numbers using the mask + if (isa == sse41) { + h->movups(vmm_aux1, vmm_src); + h->mulps(vmm_aux1, table_val(minus_one)); + h->blendvps(vmm_src, vmm_aux1); + } else if (isa == avx2) { + h->vmulps(vmm_aux1, vmm_src, table_val(minus_one)); + h->vblendvps(vmm_src, vmm_src, vmm_aux1, vmm_mask); + } else if (isa == avx512_common) { + h->vmulps(vmm_aux1, vmm_src, table_val(minus_one)); + h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux1); + } +} + template size_t jit_uni_eltwise_injector_f32::aux_vecs_count() { using namespace alg_kind; @@ -1579,6 +1623,9 @@ size_t jit_uni_eltwise_injector_f32::aux_vecs_count() { case eltwise_gelu_erf: return 5; case eltwise_round: return 0; case eltwise_hardswish: return 1; + case eltwise_hsigmoid: return 0; + case eltwise_round_half_to_even: return 0; + case eltwise_round_half_away_from_zero: return 2; default: assert(!"unsupported eltwise algorithm"); } } else { @@ -1674,6 +1721,9 @@ void jit_uni_eltwise_injector_f32::compute_body( case eltwise_hardswish: hardswish_compute_vector_fwd(Vmm(idx)); break; + case eltwise_hsigmoid: hsigmoid_compute_vector_fwd(Vmm(idx)); break; + case eltwise_round_half_to_even: round_half_to_even_compute_vector_fwd(Vmm(idx)); break; + case eltwise_round_half_away_from_zero: round_half_away_from_zero_compute_vector_fwd(Vmm(idx)); break; default: assert(!"unsupported eltwise algorithm"); } } else { @@ -2078,6 +2128,7 @@ void jit_uni_eltwise_injector_f32::register_table_entries() { static const table_t soft_relu_consts { {soft_relu_one_twenty_six, {0x42fc0000, true}}, {soft_relu_mantissa_sign_mask, {0x807fffff, true}}, + {soft_relu_twenty, {0x41a00000, true}}, }; // soft_relu ln(1 + x) polynomial approximation @@ -2236,6 +2287,13 @@ void jit_uni_eltwise_injector_f32::register_table_entries() { static const table_t hardswish_consts {{three, {0x40400000, true}}, {six, {0x40c00000, true}}, {minus_three, {0xc0400000, true}}}; + // hsigmoid(x) polynomial approximation + static const table_t hsigmoid_values { + {hsigmoid, {0x40400000, true}}, // 3 + {hsigmoid, {0x40C00000, true}}, // 6 + {hsigmoid, {0x3e2aaaaa, true}}, // 1 / 6 + }; + // This object takes care about which constants and polynomials to include. 
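
The constants registered for hsigmoid below are 3.0f (0x40400000), 6.0f (0x40C00000) and 1/6 (0x3e2aaaaa); the function is piecewise linear rather than a polynomial fit. The soft_relu hunk above likewise lowers the pass-through cutoff to 20.f, beyond which ln(1 + e^x) and x already agree to f32 precision. Scalar references for the three new eltwise algorithms:

    #include <cmath>

    float hsigmoid_ref(float x) {                  // relu6(x + 3) / 6
        float t = x + 3.f;
        t = t < 0.f ? 0.f : (t > 6.f ? 6.f : t);
        return t * (1.f / 6.f);
    }
    float round_half_to_even_ref(float x) {        // _op_near: ties to even
        return std::nearbyint(x); // assumes default FP rounding mode
    }
    float round_half_away_from_zero_ref(float x) { // ties away from zero
        return std::copysign(std::floor(std::fabs(x) + 0.5f), x);
    }
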
struct need_t { need_t(alg_kind_t alg) { @@ -2257,6 +2315,7 @@ void jit_uni_eltwise_injector_f32::register_table_entries() { case eltwise_tanh_use_dst_for_bwd: case eltwise_tanh: tanh_ = true; break; case eltwise_hardswish: hardswish_ = true; break; + case eltwise_hsigmoid: hsigmoid_ = true; break; default: break; } } @@ -2269,6 +2328,7 @@ void jit_uni_eltwise_injector_f32::register_table_entries() { bool gelu_erf_ = false; bool log_ = false; bool hardswish_ = false; + bool hsigmoid_ = false; bool exp() const { return exp_ || soft_relu_ || gelu_erf_ || mish_; } bool mish() const { return mish_; } @@ -2278,6 +2338,7 @@ void jit_uni_eltwise_injector_f32::register_table_entries() { bool gelu_erf() const { return gelu_erf_; } bool log() const { return log_; } bool hardswish() const { return hardswish_; } + bool hsigmoid() const { return hsigmoid_; } }; need_t need(alg_); @@ -2314,6 +2375,7 @@ void jit_uni_eltwise_injector_f32::register_table_entries() { if (need.log()) push_entries_of(log_polynomial); if (need.log()) push_entries_of(log_predefined_values); if (need.hardswish()) push_entries_of(hardswish_consts); + if (need.hsigmoid()) push_entries_of(hsigmoid_values); // Now that we registered the entries, we set the offsets. No // entries should be registered after this point. This allows to diff --git a/src/cpu/x64/injectors/jit_uni_eltwise_injector.hpp b/src/cpu/x64/injectors/jit_uni_eltwise_injector.hpp index 29c7e265bb5..76e2533d70e 100644 --- a/src/cpu/x64/injectors/jit_uni_eltwise_injector.hpp +++ b/src/cpu/x64/injectors/jit_uni_eltwise_injector.hpp @@ -147,7 +147,8 @@ struct jit_uni_eltwise_injector_f32 { _cmp_ge_os = jit_generator::_cmp_nlt_us, _cmp_gt_os = jit_generator::_cmp_nle_us, _op_floor = jit_generator::_op_floor, - _op_mxcsr = jit_generator::_op_mxcsr + _op_mxcsr = jit_generator::_op_mxcsr, + _op_near = jit_generator::_op_near }; static constexpr bool has_avx512() { @@ -214,6 +215,9 @@ struct jit_uni_eltwise_injector_f32 { void gelu_erf_compute_vector_fwd(const Vmm &vmm_src); void round_compute_vector_fwd(const Vmm &vmm_src); void hardswish_compute_vector_fwd(const Vmm &vmm_src); + void hsigmoid_compute_vector_fwd(const Vmm &vmm_src); + void round_half_to_even_compute_vector_fwd(const Vmm &vmm_src); + void round_half_away_from_zero_compute_vector_fwd(const Vmm &vmm_src); void exp_compute_vector_bwd(const Vmm &vmm_src); void relu_compute_vector_bwd(const Vmm &vmm_src); @@ -268,6 +272,7 @@ struct jit_uni_eltwise_injector_f32 { tanh_pol_table, // table of polynomial coefficients soft_relu_one_twenty_six, // 126.f soft_relu_mantissa_sign_mask, // mask for mantissa bits and sign + soft_relu_twenty, // 20.f soft_relu_pol, // see correspondent table for float values gelu_tanh_fitting_const, // 0.044715f gelu_tanh_fitting_const_times_three, // 0.134145f @@ -284,6 +289,7 @@ struct jit_uni_eltwise_injector_f32 { log_five_bit_offset, // 5 bits off (31 = 2^5 - 1) log_pol, // see correspondent table for float values log_predefined_vals, // see correspondent table for float values + hsigmoid, // hsigmoid undef_key, }; diff --git a/src/cpu/x64/injectors/jit_uni_postops_injector.cpp b/src/cpu/x64/injectors/jit_uni_postops_injector.cpp index 346e6801a54..6b24b7fb24e 100644 --- a/src/cpu/x64/injectors/jit_uni_postops_injector.cpp +++ b/src/cpu/x64/injectors/jit_uni_postops_injector.cpp @@ -44,11 +44,45 @@ bool is_supported(const post_ops_ok_args_t &post_ops_ok_args) { return true; } +template +jit_uni_postops_injector_t::jit_uni_postops_injector_t( + jit_generator *host, const post_ops_t 
&post_ops, + const eltwise_injector::static_params_t &eltwise_static_params, + const quantization_injector::static_params_t &quantization_static_params) + : post_ops_(post_ops) + , host_(host) + , binary_injector_(nullptr) { + + const auto &esp = eltwise_static_params; + const auto &qsp = quantization_static_params; + + for (const auto &post_op : post_ops.entry_) { + if (post_op.is_eltwise()) { + alg_to_eltwise_injector_.emplace(post_op.eltwise.alg, + jit_uni_eltwise_injector_f32(host_, post_op.eltwise, + esp.save_state, esp.p_table, esp.k_mask, esp.is_fwd, + esp.use_dst)); + } else if (post_op.is_depthwise()) { + depthwise_injectors.emplace_back(new jit_uni_depthwise_injector_f32( + host, + post_op + )); + } else if (post_op.is_quantization()) { + quantization_injectors.emplace_back(new jit_uni_quantization_injector_f32( + host, + post_op, + Vmm(qsp.vmm_d_weights_idx), Vmm(qsp.vmm_d_bias_idx), qsp.reg_d_weights, qsp.reg_d_bias + )); + } + } +} + template jit_uni_postops_injector_t::jit_uni_postops_injector_t( jit_generator *host, const post_ops_t &post_ops, const binary_injector::static_params_t &binary_static_params, const eltwise_injector::static_params_t &eltwise_static_params, + const quantization_injector::static_params_t &quantization_static_params, const lambda_jit_injectors_t &lambda_jit_injectors) : post_ops_(post_ops) , host_(host) @@ -56,6 +90,7 @@ jit_uni_postops_injector_t::jit_uni_postops_injector_t( , lambda_jit_injectors_(lambda_jit_injectors) { const auto &esp = eltwise_static_params; + const auto &qsp = quantization_static_params; bool is_binary = false; bool is_eltwise = false; @@ -68,6 +103,17 @@ jit_uni_postops_injector_t::jit_uni_postops_injector_t( esp.k_mask, esp.is_fwd, esp.use_dst)); } else if (post_op.is_binary()) { is_binary = true; + } else if (post_op.is_depthwise()) { + depthwise_injectors.emplace_back(new jit_uni_depthwise_injector_f32( + host, + post_op + )); + } else if (post_op.is_quantization()) { + quantization_injectors.emplace_back(new jit_uni_quantization_injector_f32( + host, + post_op, + Vmm(qsp.vmm_d_weights_idx), Vmm(qsp.vmm_d_bias_idx), qsp.reg_d_weights, qsp.reg_d_bias + )); } } @@ -90,7 +136,8 @@ jit_uni_postops_injector_t::jit_uni_postops_injector_t( jit_generator *host, const post_ops_t &post_ops, const binary_injector::static_params_t &binary_static_params) : jit_uni_postops_injector_t(host, post_ops, binary_static_params, - eltwise_injector::static_params_t(), lambda_jit_injectors_t()) {} + eltwise_injector::static_params_t(), quantization_injector::static_params_t(), + lambda_jit_injectors_t()) {} template jit_uni_postops_injector_t::jit_uni_postops_injector_t( @@ -98,7 +145,8 @@ jit_uni_postops_injector_t::jit_uni_postops_injector_t( const binary_injector::static_params_t &binary_static_params, const lambda_jit_injectors_t &lambda_jit_injectors) : jit_uni_postops_injector_t(host, post_ops, binary_static_params, - eltwise_injector::static_params_t(), lambda_jit_injectors) {} + eltwise_injector::static_params_t(), quantization_injector::static_params_t(), + lambda_jit_injectors) {} template jit_uni_postops_injector_t::jit_uni_postops_injector_t( @@ -106,7 +154,26 @@ jit_uni_postops_injector_t::jit_uni_postops_injector_t( const binary_injector::static_params_t &binary_static_params, const eltwise_injector::static_params_t &eltwise_static_params) : jit_uni_postops_injector_t(host, post_ops, binary_static_params, - eltwise_static_params, lambda_jit_injectors_t()) {} + eltwise_static_params, + quantization_injector::static_params_t(), 
lambda_jit_injectors_t()) {} + +template +jit_uni_postops_injector_t::jit_uni_postops_injector_t(jit_generator *host, + const post_ops_t &post_ops, + const binary_injector::static_params_t &binary_static_params, + const quantization_injector::static_params_t &quantization_static_params) + : jit_uni_postops_injector_t(host, post_ops, binary_static_params, + eltwise_injector::static_params_t(), + quantization_static_params, lambda_jit_injectors_t()) {} + +template +jit_uni_postops_injector_t::jit_uni_postops_injector_t(jit_generator *host, + const post_ops_t &post_ops, + const binary_injector::static_params_t &binary_static_params, + const eltwise_injector::static_params_t &eltwise_static_params, + const quantization_injector::static_params_t &quantization_static_params) + : jit_uni_postops_injector_t(host, post_ops, binary_static_params, + eltwise_static_params, quantization_static_params, lambda_jit_injectors_t()) {} template void jit_uni_postops_injector_t::compute_vector_range( @@ -119,6 +186,19 @@ void jit_uni_postops_injector_t::compute_vector_range( compute_vector_range(vmm_idxs, rhs_arg_params); } +template +void jit_uni_postops_injector_t::compute_vector_range( + size_t start_idx, size_t end_idx, + const binary_injector::rhs_arg_dynamic_params_t &rhs_arg_params, + const depthwise_injector::dynamic_params_t &ddp, + const quantization_injector::dynamic_params_t &qdp) { + + injector_utils::vmm_index_set_t vmm_idxs; + for (size_t i = start_idx; i < end_idx; i++) + vmm_idxs.emplace(i); + compute_vector_range(vmm_idxs, rhs_arg_params, ddp, qdp); +} + template void jit_uni_postops_injector_t::compute_vector_range( size_t start_idx, size_t end_idx) { @@ -129,10 +209,17 @@ void jit_uni_postops_injector_t::compute_vector_range( template void jit_uni_postops_injector_t::compute_vector_range( const injector_utils::vmm_index_set_t &vmm_idxs, - const binary_injector::rhs_arg_dynamic_params_t &rhs_arg_params) { + const binary_injector::rhs_arg_dynamic_params_t &rhs_arg_params, + const depthwise_injector::dynamic_params_t &ddp, + const quantization_injector::dynamic_params_t &qdp, bool is_broadcast) { std::size_t rhs_arg_idx = 0; - for (const auto &post_op : post_ops_.entry_) { + std::size_t quantization_inj_idx = 0; + std::size_t depthwise_inj_idx = 0; + std::size_t post_ops_data_offset = 0; + for (int i = 0; i < post_ops_.len(); i++) { + const auto &post_op = post_ops_.entry_[i]; + if (post_op.is_eltwise()) { alg_to_eltwise_injector_.at(post_op.eltwise.alg) .compute_vector_range(vmm_idxs); @@ -140,6 +227,85 @@ void jit_uni_postops_injector_t::compute_vector_range( binary_injector_->compute_vector_range( vmm_idxs, rhs_arg_idx, post_op, rhs_arg_params); ++rhs_arg_idx; + } else if (post_op.is_depthwise()) { + const Xbyak::RegExp depthwise_arg_base = ddp.reg_post_ops_data + ddp.base_post_ops_data_offset + post_ops_data_offset; + if (ddp.useAddr) + depthwise_injectors[depthwise_inj_idx]->init_ptrs(depthwise_arg_base, ddp.reg_d_weights, ddp.reg_d_bias, ddp.reg_init_off_addr, false); + else + depthwise_injectors[depthwise_inj_idx]->init_ptrs(depthwise_arg_base, ddp.reg_d_weights, ddp.reg_d_bias, ddp.reg_init_off, false); + + bool need_to_preserve = false; + if (post_op.depthwise.alg == dnnl_depthwise_prelu && isa == sse41) + need_to_preserve = true; + + for (auto vmm_idx : vmm_idxs) { + depthwise_injectors[depthwise_inj_idx]->compute(vmm_idx, vmm_idx + 1, + need_to_preserve ? 
0 : ddp.vmm_d_weights_idx, ddp.vmm_d_bias_idx, + ddp.reg_d_weights, ddp.reg_d_bias, + is_broadcast, ddp.vmm_idx_off.at(vmm_idx), need_to_preserve); + } + + post_ops_data_offset += depthwise_injectors[depthwise_inj_idx]->memoryStep(); + ++rhs_arg_idx; + depthwise_inj_idx++; + } else if (post_op.is_quantization()) { + std::vector>> vecOfVmmIdxsSets; + + std::multimap offsetVmmIdxMap; + for (auto vmm_idx : vmm_idxs) { + offsetVmmIdxMap.insert({qdp.vmm_idx_off.at(vmm_idx), vmm_idx}); + } + + auto externalIt = offsetVmmIdxMap.begin(); + while (externalIt != offsetVmmIdxMap.end()) { + auto internalIt = externalIt; + auto endInternalIt = offsetVmmIdxMap.upper_bound(externalIt->first); + + std::set vmmIndexesToProcess; + while (internalIt != endInternalIt) { + vmmIndexesToProcess.insert(internalIt->second); + internalIt++; + } + vecOfVmmIdxsSets.push_back({externalIt->first, vmmIndexesToProcess}); + + externalIt = endInternalIt; + } + + bool do_dequantization = post_op.quantization.alg == alg_kind::quantization_quantize_dequantize; + bool do_rounding = do_dequantization || qdp.dst_dt == dnnl_f32 || i != post_ops_.len() - 1; + + const Xbyak::RegExp quant_arg_base = qdp.reg_post_ops_data + qdp.base_post_ops_data_offset + post_ops_data_offset; + if (qdp.useAddr) + quantization_injectors[quantization_inj_idx]->init_crop_ptrs(quant_arg_base, qdp.reg_oc_off_addr); + else + quantization_injectors[quantization_inj_idx]->init_crop_ptrs(quant_arg_base, qdp.reg_oc_off); + + for (auto &IdxSetPair : vecOfVmmIdxsSets) { + quantization_injectors[quantization_inj_idx]->compute_crop(IdxSetPair.second, IdxSetPair.first, false, is_broadcast); + } + + if (qdp.useAddr) + quantization_injectors[quantization_inj_idx]->init_input_scale_shift_ptrs(quant_arg_base, qdp.reg_oc_off_addr); + else + quantization_injectors[quantization_inj_idx]->init_input_scale_shift_ptrs(quant_arg_base, qdp.reg_oc_off); + + for (auto &IdxSetPair : vecOfVmmIdxsSets) { + quantization_injectors[quantization_inj_idx]->compute_input_scale_shift(IdxSetPair.second, IdxSetPair.first, do_rounding, + false, is_broadcast); + } + + if (qdp.useAddr) + quantization_injectors[quantization_inj_idx]->init_output_scale_shift_ptrs(quant_arg_base, qdp.reg_oc_off_addr); + else + quantization_injectors[quantization_inj_idx]->init_output_scale_shift_ptrs(quant_arg_base, qdp.reg_oc_off); + + for (auto &IdxSetPair : vecOfVmmIdxsSets) { + quantization_injectors[quantization_inj_idx]->compute_output_scale_shift(IdxSetPair.second, IdxSetPair.first, false, is_broadcast); + } + + post_ops_data_offset += quantization_injectors[quantization_inj_idx]->memoryStep(); + ++rhs_arg_idx; + quantization_inj_idx++; } else { const auto lam = lambda_jit_injectors_.find(post_op.kind); if (lam != lambda_jit_injectors_.end()) lam->second(); @@ -152,6 +318,13 @@ void jit_uni_postops_injector_t::compute_vector_range( compute_vector_range(vmm_idxs, binary_injector::rhs_arg_dynamic_params_t()); } +template +void jit_uni_postops_injector_t::compute_vector_range( + const injector_utils::vmm_index_set_t &vmm_idxs, + const binary_injector::rhs_arg_dynamic_params_t &rhs_arg_params) { + compute_vector_range(vmm_idxs, rhs_arg_params, depthwise_injector::dynamic_params_t(), quantization_injector::dynamic_params_t()); +} + template void jit_uni_postops_injector_t::prepare_table(bool gen_table) { for (auto &alg_elt_inject : alg_to_eltwise_injector_) @@ -169,12 +342,54 @@ void jit_uni_postops_injector_t::compute_vector(size_t idx) { compute_vector_range({idx}); } +template +void 
jit_uni_postops_injector_t::compute_vector(size_t idx, + const binary_injector::rhs_arg_dynamic_params_t &rhs_arg_params, + const depthwise_injector::dynamic_params_t &ddp, + const quantization_injector::dynamic_params_t &qdp) { + compute_vector_range({idx}, rhs_arg_params, ddp, qdp); +} + +template +void jit_uni_postops_injector_t::compute_vector(size_t idx, + const depthwise_injector::dynamic_params_t &ddp, + const quantization_injector::dynamic_params_t &qdp, bool is_broadcast) { + compute_vector_range({idx}, binary_injector::rhs_arg_dynamic_params_t(), ddp, qdp, is_broadcast); +} + template void jit_uni_postops_injector_t::set_lambda_injector( dnnl_primitive_kind_t kind, const std::function &jit_injector) { lambda_jit_injectors_[kind] = jit_injector; } +template +void jit_uni_postops_injector_t::push_post_ops_data_on_stack(const Xbyak::Reg64& post_ops_data_reg, std::size_t post_ops_data_offset, + const Xbyak::Reg64& aux_reg0, const Xbyak::Reg64& aux_reg1) { + for (int i = 0; i < post_ops_.len(); i++) { + if (post_ops_.entry_[i].is_depthwise() || post_ops_.entry_[i].is_quantization()) { + post_ops_pointers_count++; + } + } + + if (post_ops_pointers_count != 0) { + host_->sub(host_->rsp, post_ops_pointers_count * sizeof(float *)); + + host_->mov(aux_reg0, host_->ptr[post_ops_data_reg + post_ops_data_offset]); + for (size_t i = 0; i < post_ops_pointers_count; i++) { + host_->mov(aux_reg1, host_->ptr[aux_reg0 + i * sizeof(float *)]); + host_->mov(host_->ptr[host_->rsp + i * sizeof(float *)], aux_reg1); + } + } +} + +template +void jit_uni_postops_injector_t::reset_stack_pointer() { + if (post_ops_pointers_count != 0) { + host_->add(host_->rsp, post_ops_pointers_count * sizeof(float *)); + } +} + post_ops_ok_args_t::post_ops_ok_args_t(const cpu_isa_t isa, const std::vector &accepted_post_op_types, const post_ops_t &post_ops, const memory_desc_wrapper *dst_d, @@ -229,6 +444,8 @@ bool post_ops_ok(const post_ops_ok_args_t &post_ops_ok_args) { enabled_bcast_strategy); } break; + case depthwise: if (entry.is_depthwise()) return true; break; + case quantization: if (entry.is_quantization()) return true; break; default: assert(false && "Unhandled post_op type"); } } diff --git a/src/cpu/x64/injectors/jit_uni_postops_injector.hpp b/src/cpu/x64/injectors/jit_uni_postops_injector.hpp index a41d6c42867..b4d7ee49c9f 100644 --- a/src/cpu/x64/injectors/jit_uni_postops_injector.hpp +++ b/src/cpu/x64/injectors/jit_uni_postops_injector.hpp @@ -27,6 +27,8 @@ #include "cpu/x64/injectors/injector_utils.hpp" #include "cpu/x64/injectors/jit_uni_binary_injector.hpp" #include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp" +#include "cpu/x64/injectors/jit_uni_depthwise_injector.hpp" +#include "cpu/x64/injectors/jit_uni_quantization_injector.hpp" #include "cpu/x64/jit_generator.hpp" #include @@ -79,9 +81,20 @@ class jit_uni_postops_injector_t { jit_uni_postops_injector_t(jit_generator *host, const post_ops_t &post_ops, const binary_injector::static_params_t &binary_static_params, const eltwise_injector::static_params_t &eltwise_static_params); + jit_uni_postops_injector_t(jit_generator *host, const post_ops_t &post_ops, + const binary_injector::static_params_t &binary_static_params, + const quantization_injector::static_params_t &quantization_static_params); + jit_uni_postops_injector_t(jit_generator *host, const post_ops_t &post_ops, + const eltwise_injector::static_params_t &eltwise_static_params, + const quantization_injector::static_params_t &quantization_static_params); + 
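// [Editorial reference sketch, not part of the diff] The quantization branch
// of compute_vector_range above first groups vmm indices by their per-vmm
// byte offset (vmm_idx_off), so each group shares one load of crop/scale/shift
// values. A standalone version of that grouping, assuming the stripped
// template arguments were std::map<size_t, int> for vmm_idx_off and
// std::vector<std::pair<int, std::set<size_t>>> for the result:
#include <map>
#include <set>
#include <utility>
#include <vector>

static std::vector<std::pair<int, std::set<size_t>>> group_vmms_by_offset(
        const std::set<size_t> &vmm_idxs,
        const std::map<size_t, int> &vmm_idx_off) {
    // invert the mapping: offset -> every vmm index that uses it
    std::multimap<int, size_t> by_off;
    for (size_t idx : vmm_idxs)
        by_off.insert({vmm_idx_off.at(idx), idx});

    // walk equal-offset runs; upper_bound() jumps to the next distinct offset
    std::vector<std::pair<int, std::set<size_t>>> groups;
    for (auto it = by_off.begin(); it != by_off.end();) {
        auto run_end = by_off.upper_bound(it->first);
        std::set<size_t> idxs;
        for (auto run = it; run != run_end; ++run)
            idxs.insert(run->second);
        groups.push_back({it->first, idxs});
        it = run_end;
    }
    return groups;
}
// push_post_ops_data_on_stack above then makes the depthwise/quantization data
// pointers addressable from rsp, one sizeof(float*) slot per such post-op,
// and reset_stack_pointer releases those slots before the postamble.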
jit_uni_postops_injector_t(jit_generator *host, const post_ops_t &post_ops, + const binary_injector::static_params_t &binary_static_params, + const eltwise_injector::static_params_t &eltwise_static_params, + const quantization_injector::static_params_t &quantization_static_params); jit_uni_postops_injector_t(jit_generator *host, const post_ops_t &post_ops, const binary_injector::static_params_t &binary_static_params, const eltwise_injector::static_params_t &eltwise_static_params, + const quantization_injector::static_params_t &quantization_static_params, const lambda_jit_injectors_t &lambda_jit_injectors); /* @@ -90,9 +103,20 @@ class jit_uni_postops_injector_t { * * @rhs_arg_params: see jit_uni_binary_injector description */ + void compute_vector_range(const injector_utils::vmm_index_set_t &vmm_idxs, + const binary_injector::rhs_arg_dynamic_params_t &rhs_arg_params, + const depthwise_injector::dynamic_params_t &ddp, + const quantization_injector::dynamic_params_t &qdp, bool is_broadcast = false); + void compute_vector_range(const injector_utils::vmm_index_set_t &vmm_idxs, const binary_injector::rhs_arg_dynamic_params_t &rhs_arg_params); + void compute_vector_range( + size_t start_idx, size_t end_idx, + const binary_injector::rhs_arg_dynamic_params_t &rhs_arg_params, + const depthwise_injector::dynamic_params_t &ddp, + const quantization_injector::dynamic_params_t &qdp); + void compute_vector_range(const injector_utils::vmm_index_set_t &vmm_idxs); /* @@ -115,6 +139,13 @@ class jit_uni_postops_injector_t { void compute_vector(size_t idx, const binary_injector::rhs_arg_dynamic_params_t &rhs_arg_params); void compute_vector(size_t idx); + void compute_vector(size_t idx, + const depthwise_injector::dynamic_params_t &ddp, + const quantization_injector::dynamic_params_t &qdp, bool is_broadcast = false); + void compute_vector(size_t idx, + const binary_injector::rhs_arg_dynamic_params_t &rhs_arg_params, + const depthwise_injector::dynamic_params_t &ddp, + const quantization_injector::dynamic_params_t &qdp); /* * Thin wrapper for eltwise injector specific function @@ -123,6 +154,10 @@ class jit_uni_postops_injector_t { void set_lambda_injector(lambda_jit_injectors_t::key_type, const lambda_jit_injectors_t::mapped_type &jit_injector); + void push_post_ops_data_on_stack(const Xbyak::Reg64& post_ops_data_reg, std::size_t post_ops_data_offset, + const Xbyak::Reg64& aux_reg0, const Xbyak::Reg64& aux_reg1); + void reset_stack_pointer(); + private: post_ops_t post_ops_; jit_generator *host_; @@ -131,9 +166,12 @@ class jit_uni_postops_injector_t { std::unique_ptr> binary_injector_; lambda_jit_injectors_t lambda_jit_injectors_; + nstl::vector>> depthwise_injectors; + nstl::vector>> quantization_injectors; + std::size_t post_ops_pointers_count = 0; }; -enum post_op_type { sum = 0, eltwise, binary }; +enum post_op_type { sum = 0, eltwise, binary, depthwise, quantization }; struct post_ops_ok_args_t { post_ops_ok_args_t(const cpu_isa_t isa, diff --git a/src/cpu/x64/injectors/jit_uni_quantization_injector.cpp b/src/cpu/x64/injectors/jit_uni_quantization_injector.cpp new file mode 100644 index 00000000000..73eb75dd90e --- /dev/null +++ b/src/cpu/x64/injectors/jit_uni_quantization_injector.cpp @@ -0,0 +1,300 @@ +/******************************************************************************* +* Copyright 2019-2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "common/c_types_map.hpp" +#include "common/dnnl_thread.hpp" +#include "common/nstl.hpp" +#include "common/utils.hpp" + +#include "cpu/x64/injectors/jit_uni_quantization_injector.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace x64 { + +template +void jit_uni_quantization_injector_f32::init_crop_ptrs(const Xbyak::RegExp& ptr_begin, const Xbyak::Operand& ch_off) { + h->mov(reg_d_weights_, h->ptr[ptr_begin]); + h->mov(reg_d_bias_, h->ptr[ptr_begin]); + + if (post_op_.quantization.per_channel[post_op_.quantization.crop_low] && !post_op_.quantization.all_default[post_op_.quantization.crop_low]) + h->add(reg_d_weights_, ch_off); + if (post_op_.quantization.per_channel[post_op_.quantization.crop_high] && !post_op_.quantization.all_default[post_op_.quantization.crop_high]) + h->add(reg_d_bias_, ch_off); +} + +template +void jit_uni_quantization_injector_f32::compute_crop_impl(const std::set& vmmIdxs, int offset, bool is_scalar, bool is_broadcast) { + size_t weights_off = post_op_.quantization.offset[post_op_.quantization.crop_low] * sizeof(float); + size_t bias_off = post_op_.quantization.offset[post_op_.quantization.crop_high] * sizeof(float); + + if (is_scalar) { + if (!post_op_.quantization.per_channel[post_op_.quantization.crop_low]) + h->uni_vmovss(xmm_d_weights_, h->ptr[reg_d_weights_ + weights_off]); + else if (post_op_.quantization.all_default[post_op_.quantization.crop_low]) + h->uni_vpxor(vmm_d_weights_, vmm_d_weights_, vmm_d_weights_); + else + h->uni_vmovss(xmm_d_weights_, h->ptr[reg_d_weights_ + offset + weights_off]); + } else { + if (!post_op_.quantization.per_channel[post_op_.quantization.crop_low]) + h->uni_vbroadcastss(vmm_d_weights_, h->ptr[reg_d_weights_ + weights_off]); + else if (post_op_.quantization.all_default[post_op_.quantization.crop_low]) + h->uni_vpxor(vmm_d_weights_, vmm_d_weights_, vmm_d_weights_); + else if (is_broadcast) + h->uni_vbroadcastss(vmm_d_weights_, h->ptr[reg_d_weights_ + offset + weights_off]); + else + h->uni_vmovups(vmm_d_weights_, h->ptr[reg_d_weights_ + offset + weights_off]); + } + + if (vmm_d_weights_.getIdx() == vmm_d_bias_.getIdx()) { + for (auto vmmIdx : vmmIdxs) { + Vmm vmm_dst = Vmm(vmmIdx); + h->uni_vmaxps(vmm_dst, vmm_dst, vmm_d_weights_); + } + } + + if (is_scalar) { + if (!post_op_.quantization.per_channel[post_op_.quantization.crop_high]) + h->uni_vmovss(xmm_d_bias_, h->ptr[reg_d_bias_ + bias_off]); + else if (post_op_.quantization.all_default[post_op_.quantization.crop_high]) + h->uni_vpxor(vmm_d_bias_, vmm_d_bias_, vmm_d_bias_); + else + h->uni_vmovss(xmm_d_bias_, h->ptr[reg_d_bias_ + offset + bias_off]); + } else { + if (!post_op_.quantization.per_channel[post_op_.quantization.crop_high]) + h->uni_vbroadcastss(vmm_d_bias_, h->ptr[reg_d_bias_ + bias_off]); + else if (post_op_.quantization.all_default[post_op_.quantization.crop_high]) + h->uni_vpxor(vmm_d_bias_, vmm_d_bias_, vmm_d_bias_); + else if (is_broadcast) + h->uni_vbroadcastss(vmm_d_bias_, h->ptr[reg_d_bias_ + offset + bias_off]); + else + 
h->uni_vmovups(vmm_d_bias_, h->ptr[reg_d_bias_ + offset + bias_off]); + } + + for (auto vmmIdx : vmmIdxs) { + Vmm vmm_dst = Vmm(vmmIdx); + + if (vmm_d_weights_.getIdx() != vmm_d_bias_.getIdx()) + h->uni_vmaxps(vmm_dst, vmm_dst, vmm_d_weights_); + + h->uni_vminps(vmm_dst, vmm_dst, vmm_d_bias_); + } +} + +template +void jit_uni_quantization_injector_f32::compute_crop(const std::set& vmmIdxs, int offset, bool is_scalar, bool is_broadcast) { + compute_crop_impl(vmmIdxs, offset, is_scalar, is_broadcast); +} + +template +void jit_uni_quantization_injector_f32::compute_crop(int start_idx, int end_idx, int offset, bool is_scalar, bool is_broadcast) { + std::set vmmIdxs; + for (int i = start_idx; i < end_idx; i++) { + vmmIdxs.insert(i); + } + + compute_crop_impl(vmmIdxs, offset, is_scalar, is_broadcast); +} + +template +void jit_uni_quantization_injector_f32::init_input_scale_shift_ptrs(const Xbyak::RegExp& ptr_begin, const Xbyak::Operand& ch_off) { + h->mov(reg_d_weights_, h->ptr[ptr_begin]); + h->mov(reg_d_bias_, h->ptr[ptr_begin]); + + if (post_op_.quantization.per_channel[post_op_.quantization.inp_scale]) + h->add(reg_d_weights_, ch_off); + if (post_op_.quantization.per_channel[post_op_.quantization.inp_shift] && !post_op_.quantization.all_default[post_op_.quantization.inp_shift]) + h->add(reg_d_bias_, ch_off); +} + +template +void jit_uni_quantization_injector_f32::compute_input_scale_shift_impl( + const std::set& vmmIdxs, int offset, bool do_rounding, bool is_scalar, bool is_broadcast) { + size_t weights_off = post_op_.quantization.offset[post_op_.quantization.inp_scale] * sizeof(float); + size_t bias_off = post_op_.quantization.offset[post_op_.quantization.inp_shift] * sizeof(float); + + if (is_scalar) { + if (!post_op_.quantization.per_channel[post_op_.quantization.inp_scale]) + h->movss(xmm_d_weights_, h->ptr[reg_d_weights_ + weights_off]); + else + h->movss(xmm_d_weights_, h->ptr[reg_d_weights_ + offset + weights_off]); + } else { + if (!post_op_.quantization.per_channel[post_op_.quantization.inp_scale]) + h->uni_vbroadcastss(vmm_d_weights_, h->ptr[reg_d_weights_ + weights_off]); + else if (is_broadcast) + h->uni_vbroadcastss(vmm_d_weights_, h->ptr[reg_d_weights_ + offset + weights_off]); + else + h->uni_vmovups(vmm_d_weights_, h->ptr[reg_d_weights_ + offset + weights_off]); + } + + if (vmm_d_weights_.getIdx() == vmm_d_bias_.getIdx()) { + for (auto vmmIdx : vmmIdxs) { + Vmm vmm_dst = Vmm(vmmIdx); + + h->uni_vmulps(vmm_dst, vmm_dst, vmm_d_weights_); + } + } + + if (is_scalar) { + if (!post_op_.quantization.per_channel[post_op_.quantization.inp_shift]) + h->movss(xmm_d_bias_, h->ptr[reg_d_bias_ + bias_off]); + else if (post_op_.quantization.all_default[post_op_.quantization.inp_shift]) + h->uni_vpxor(vmm_d_bias_, vmm_d_bias_, vmm_d_bias_); + else + h->movss(xmm_d_bias_, h->ptr[reg_d_bias_ + offset + bias_off]); + } else { + if (!post_op_.quantization.per_channel[post_op_.quantization.inp_shift]) + h->uni_vbroadcastss(vmm_d_bias_, h->ptr[reg_d_bias_ + bias_off]); + else if (post_op_.quantization.all_default[post_op_.quantization.inp_shift]) + h->uni_vpxor(vmm_d_bias_, vmm_d_bias_, vmm_d_bias_); + else if (is_broadcast) + h->uni_vbroadcastss(vmm_d_bias_, h->ptr[reg_d_bias_ + offset + bias_off]); + else + h->uni_vmovups(vmm_d_bias_, h->ptr[reg_d_bias_ + offset + bias_off]); + } + + for (auto vmmIdx : vmmIdxs) { + Vmm vmm_dst = Vmm(vmmIdx); + + if (vmm_d_weights_.getIdx() == vmm_d_bias_.getIdx()) + h->uni_vaddps(vmm_dst, vmm_dst, vmm_d_bias_); + else + h->uni_vfmadd213ps(vmm_dst, 
vmm_d_weights_, vmm_d_bias_); + + if (do_rounding) + h->uni_vroundps(vmm_dst, vmm_dst, 0); + } +} + +template +void jit_uni_quantization_injector_f32::compute_input_scale_shift(int start_idx, int end_idx, int offset, bool do_rounding, bool is_scalar, bool is_broadcast) { + std::set vmmIdxs; + for (int i = start_idx; i < end_idx; i++) { + vmmIdxs.insert(i); + } + + compute_input_scale_shift_impl(vmmIdxs, offset, do_rounding, is_scalar, is_broadcast); +} + +template +void jit_uni_quantization_injector_f32::compute_input_scale_shift(const std::set& vmmIdxs, int offset, bool do_rounding, bool is_scalar, bool is_broadcast) { + compute_input_scale_shift_impl(vmmIdxs, offset, do_rounding, is_scalar, is_broadcast); +} + +template +void jit_uni_quantization_injector_f32::init_output_scale_shift_ptrs(const Xbyak::RegExp& ptr_begin, const Xbyak::Operand& ch_off) { + if (!do_dequantization) + return; + + h->mov(reg_d_weights_, h->ptr[ptr_begin]); + h->mov(reg_d_bias_, h->ptr[ptr_begin]); + + if (post_op_.quantization.per_channel[post_op_.quantization.output_scale]) + h->add(reg_d_weights_, ch_off); + if (post_op_.quantization.per_channel[post_op_.quantization.output_shift] && !post_op_.quantization.all_default[post_op_.quantization.output_shift]) + h->add(reg_d_bias_, ch_off); +} + +template +void jit_uni_quantization_injector_f32::compute_output_scale_shift_impl(const std::set& vmmIdxs, int offset, bool is_scalar, bool is_broadcast) { + size_t weights_off = post_op_.quantization.offset[post_op_.quantization.output_scale] * sizeof(float); + size_t bias_off = post_op_.quantization.offset[post_op_.quantization.output_shift] * sizeof(float); + + if (!do_dequantization) + return; + + if (is_scalar) { + if (!post_op_.quantization.per_channel[post_op_.quantization.output_scale]) + h->movss(xmm_d_weights_, h->ptr[reg_d_weights_ + weights_off]); + else + h->movss(xmm_d_weights_, h->ptr[reg_d_weights_ + offset + weights_off]); + } else { + if (!post_op_.quantization.per_channel[post_op_.quantization.output_scale]) + h->uni_vbroadcastss(vmm_d_weights_, h->ptr[reg_d_weights_ + weights_off]); + else if (is_broadcast) + h->uni_vbroadcastss(vmm_d_weights_, h->ptr[reg_d_weights_ + offset + weights_off]); + else + h->uni_vmovups(vmm_d_weights_, h->ptr[reg_d_weights_ + offset + weights_off]); + } + + if (vmm_d_weights_.getIdx() == vmm_d_bias_.getIdx()) { + for (auto &vmmIdx : vmmIdxs) { + Vmm vmm_dst = Vmm(vmmIdx); + + h->uni_vmulps(vmm_dst, vmm_dst, vmm_d_weights_); + } + } + + if (is_scalar) { + if (!post_op_.quantization.per_channel[post_op_.quantization.output_shift]) + h->movss(xmm_d_bias_, h->ptr[reg_d_bias_ + bias_off]); + else if (post_op_.quantization.all_default[post_op_.quantization.output_shift]) + h->uni_vpxor(vmm_d_bias_, vmm_d_bias_, vmm_d_bias_); + else + h->movss(xmm_d_bias_, h->ptr[reg_d_bias_ + offset + bias_off]); + } else { + if (!post_op_.quantization.per_channel[post_op_.quantization.output_shift]) + h->uni_vbroadcastss(vmm_d_bias_, h->ptr[reg_d_bias_ + bias_off]); + else if (post_op_.quantization.all_default[post_op_.quantization.output_shift]) + h->uni_vpxor(vmm_d_bias_, vmm_d_bias_, vmm_d_bias_); + else if (is_broadcast) + h->uni_vbroadcastss(vmm_d_bias_, h->ptr[reg_d_bias_ + offset + bias_off]); + else + h->uni_vmovups(vmm_d_bias_, h->ptr[reg_d_bias_ + offset + bias_off]); + } + + for (auto &vmmIdx : vmmIdxs) { + Vmm vmm_dst = Vmm(vmmIdx); + + if (vmm_d_weights_.getIdx() == vmm_d_bias_.getIdx()) + h->uni_vaddps(vmm_dst, vmm_dst, vmm_d_bias_); + else + h->uni_vfmadd213ps(vmm_dst, 
vmm_d_weights_, vmm_d_bias_); + } +} + +template +void jit_uni_quantization_injector_f32::compute_output_scale_shift(int start_idx, int end_idx, int offset, bool is_scalar, bool is_broadcast) { + std::set vmmIdxs; + for (int i = start_idx; i < end_idx; i++) { + vmmIdxs.insert(i); + } + + compute_output_scale_shift_impl(vmmIdxs, offset, is_scalar, is_broadcast); +} + +template +void jit_uni_quantization_injector_f32::compute_output_scale_shift(const std::set& vmmIdxs, int offset, bool is_scalar, bool is_broadcast) { + compute_output_scale_shift_impl(vmmIdxs, offset, is_scalar, is_broadcast); +} + +template struct jit_uni_quantization_injector_f32; +template struct jit_uni_quantization_injector_f32; +template struct jit_uni_quantization_injector_f32; +template struct jit_uni_quantization_injector_f32; +template struct jit_uni_quantization_injector_f32; +template struct jit_uni_quantization_injector_f32; +template struct jit_uni_quantization_injector_f32; +template struct jit_uni_quantization_injector_f32; +template struct jit_uni_quantization_injector_f32; +template struct jit_uni_quantization_injector_f32; +template struct jit_uni_quantization_injector_f32; + +} // namespace x64 +} // namespace cpu +} // namespace impl +} // namespace dnnl diff --git a/src/cpu/x64/injectors/jit_uni_quantization_injector.hpp b/src/cpu/x64/injectors/jit_uni_quantization_injector.hpp new file mode 100644 index 00000000000..c3fcb789a7f --- /dev/null +++ b/src/cpu/x64/injectors/jit_uni_quantization_injector.hpp @@ -0,0 +1,134 @@ +/******************************************************************************* +* Copyright 2019-2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_X64_JIT_UNI_QUANTIZATION_INJECTOR_HPP +#define CPU_X64_JIT_UNI_QUANTIZATION_INJECTOR_HPP + +#include + +#include "common/c_types_map.hpp" +#include "common/primitive_attr.hpp" +#include "common/type_helpers.hpp" +#include "common/utils.hpp" +#include + +#include "cpu/x64/jit_generator.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace x64 { + +namespace quantization_injector { + +struct static_params_t { + static_params_t(int vmm_d_weights_idx = 0, int vmm_d_bias_idx = 0, + Xbyak::Reg64 reg_d_weights = Xbyak::Reg64(0), Xbyak::Reg64 reg_d_bias = Xbyak::Reg64(0)) : + vmm_d_weights_idx(vmm_d_weights_idx), vmm_d_bias_idx(vmm_d_bias_idx), reg_d_weights(reg_d_weights), reg_d_bias(reg_d_bias) {} + + int vmm_d_weights_idx; + int vmm_d_bias_idx; + Xbyak::Reg64 reg_d_weights; + Xbyak::Reg64 reg_d_bias; +}; + +struct dynamic_params_t { + dynamic_params_t() : + reg_oc_off(Xbyak::Reg64(0)), reg_oc_off_addr(0), vmm_idx_off({}), dst_dt(dnnl_f32), useAddr(false) { + } + + dynamic_params_t(Xbyak::Reg64 reg_oc_off, const std::map& vmm_idx_off, data_type_t dst_dt, + Xbyak::Reg64 reg_post_ops_data = Xbyak::Reg64(0), int base_post_ops_data_offset = 0) : + reg_oc_off(reg_oc_off), reg_oc_off_addr(0), vmm_idx_off(vmm_idx_off), dst_dt(dst_dt), useAddr(false), + reg_post_ops_data(reg_post_ops_data), base_post_ops_data_offset(base_post_ops_data_offset) { + } + + dynamic_params_t(Xbyak::Address reg_oc_off, const std::map& vmm_idx_off, data_type_t dst_dt, + Xbyak::Reg64 reg_post_ops_data = Xbyak::Reg64(0), int base_post_ops_data_offset = 0) : + reg_oc_off(0), reg_oc_off_addr(reg_oc_off), vmm_idx_off(vmm_idx_off), dst_dt(dst_dt), useAddr(true), + reg_post_ops_data(reg_post_ops_data), base_post_ops_data_offset(base_post_ops_data_offset) { + } + + Xbyak::Reg64 reg_oc_off; + Xbyak::Address reg_oc_off_addr; + std::map vmm_idx_off; + data_type_t dst_dt; + bool useAddr; + Xbyak::Reg64 reg_post_ops_data; + int base_post_ops_data_offset; +}; + +} // quantization_injector + +template ::Vmm> +struct jit_uni_quantization_injector_f32 { + jit_uni_quantization_injector_f32(jit_generator* host, dnnl_post_ops::entry_t post_op, + Vmm vmm_d_weights, Vmm vmm_d_bias, Xbyak::Reg64 reg_d_weights, Xbyak::Reg64 reg_d_bias) + : h(host), post_op_(post_op), vmm_d_weights_(vmm_d_weights), vmm_d_bias_(vmm_d_bias), reg_d_weights_(reg_d_weights), reg_d_bias_(reg_d_bias) { + assert(post_op.is_quantization()); + assert(utils::one_of(post_op.quantization.alg, alg_kind::quantization_quantize, alg_kind::quantization_quantize_dequantize)); + + do_dequantization = post_op_.quantization.alg == alg_kind::quantization_quantize_dequantize; + + xmm_d_weights_ = Xbyak::Xmm(vmm_d_weights.getIdx()); + xmm_d_bias_ = Xbyak::Xmm(vmm_d_bias.getIdx()); + } + + void init_crop_ptrs(const Xbyak::RegExp& ptr_begin, const Xbyak::Operand& ch_off); + void init_input_scale_shift_ptrs(const Xbyak::RegExp& ptr_begin, const Xbyak::Operand& ch_off); + void init_output_scale_shift_ptrs(const Xbyak::RegExp& ptr_begin, const Xbyak::Operand& ch_off); + + void compute_crop(int start_idx, int end_idx, int offset, bool is_scalar = false, bool is_broadcast = false); + void compute_input_scale_shift(int start_idx, int end_idx, int offset, bool do_rounding, bool is_scalar = false, bool is_broadcast = false); + void compute_output_scale_shift(int start_idx, int end_idx, int offset, bool is_scalar = false, bool is_broadcast = false); + + void compute_crop(const std::set& 
vmmIdxs, int offset, bool is_scalar, bool is_broadcast); + void compute_input_scale_shift(const std::set& vmmIdxs, int offset, bool do_rounding, bool is_scalar = false, bool is_broadcast = false); + void compute_output_scale_shift(const std::set& vmmIdxs, int offset, bool is_scalar = false, bool is_broadcast = false); + + // in bytes + static constexpr size_t memoryStep() { + return sizeof(float*); + } + +private: + void compute_crop_impl(const std::set& vmmIdxs, int offset, bool is_scalar, bool is_broadcast); + void compute_input_scale_shift_impl(const std::set& vmmIdxs, int offset, bool do_rounding, bool is_scalar = false, bool is_broadcast = false); + void compute_output_scale_shift_impl(const std::set& vmmIdxs, int offset, bool is_scalar = false, bool is_broadcast = false); + + jit_generator* h; + + size_t vlen = cpu_isa_traits::vlen; + + dnnl_post_ops::entry_t post_op_; + + Vmm vmm_d_weights_; + Vmm vmm_d_bias_; + Xbyak::Xmm xmm_d_weights_; + Xbyak::Xmm xmm_d_bias_; + + Xbyak::Reg64 reg_d_weights_; + Xbyak::Reg64 reg_d_bias_; + + bool do_dequantization; +}; + +} // namespace x64 +} // namespace cpu +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.cpp b/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.cpp index 85244ee8cd0..f6ed8ad0a5e 100644 --- a/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.cpp +++ b/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.cpp @@ -50,7 +50,7 @@ jit_avx2_1x1_conv_kernel_f32::jit_avx2_1x1_conv_kernel_f32( : jit_generator(nullptr, MAX_CODE_SIZE, true, avx2) , jcp(ajcp) , attr_(attr) { - if (jcp.with_eltwise || jcp.with_binary) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { using namespace binary_injector; static constexpr bool preserve_gpr = true; static constexpr bool preserve_vmm = false; @@ -64,10 +64,12 @@ jit_avx2_1x1_conv_kernel_f32::jit_avx2_1x1_conv_kernel_f32( memory_desc_wrapper(dst_md), tail_size, use_exact_tail_scalar_bcast}; static_params_t static_params {this->param1, rhs_arg_static_params}; + quantization_injector::static_params_t quantization_static_params + {ymm_d_weights.getIdx(), ymm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; postops_injector_ = utils::make_unique< injector::jit_uni_postops_injector_t>( - this, jcp.post_ops, static_params); + this, jcp.post_ops, static_params, quantization_static_params); } } @@ -148,13 +150,22 @@ void iterate(const int load_loop_blk, const int ur, const F &f) { void jit_avx2_1x1_conv_kernel_f32::apply_postops( const int load_loop_blk, const int ur, const int load_dim_tail) { - if (jcp.with_eltwise || jcp.with_binary) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { assert(ur * load_loop_blk < 14); Label store_nopost_ops; test(reg_reduce_pos_flag, FLAG_REDUCE_LAST); jz(store_nopost_ops, T_NEAR); + std::map vmm_idx_off; + iterate(load_loop_blk, ur, load_dim_tail, + [&](const bool, const int i, const int j) { + vmm_idx_off.insert({vreg_accum_idx(load_loop_blk, i, j), i * jcp.oc_block * sizeof(float)}); + }); + depthwise_injector::dynamic_params_t ddp {ymm_d_weights.getIdx(), ymm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + reg_oc_off, vmm_idx_off, this->rsp}; + quantization_injector::dynamic_params_t qdp {reg_oc_off, vmm_idx_off, jcp.dst_dt, this->rsp}; + injector_utils::vmm_index_set_t vmm_idxs; if (jcp.with_binary) { binary_injector::rhs_arg_dynamic_params_t rhs_arg_params, @@ -199,14 +210,14 @@ void jit_avx2_1x1_conv_kernel_f32::apply_postops( jmp(postops_done, T_NEAR); 
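// [Editorial reference sketch, not part of the diff] Per-element semantics of
// the quantization post-op that ddp/qdp route into the kernels below: crop,
// then input scale/shift with optional rounding, then output scale/shift.
// Assumes per-channel parameter arrays and <cmath>; names are illustrative.
#include <cmath>

static inline float quantization_post_op_ref(float x, int ch,
        const float *crop_lo, const float *crop_hi,
        const float *in_scale, const float *in_shift,
        const float *out_scale, const float *out_shift,
        bool do_dequantization, bool do_rounding) {
    // compute_crop: clamp into the representable range
    float y = std::fmin(std::fmax(x, crop_lo[ch]), crop_hi[ch]);
    // compute_input_scale_shift: affine map into the integer domain
    y = y * in_scale[ch] + in_shift[ch];
    // do_rounding is false only for the last post-op with an integral dst
    // and no dequantization (the store itself rounds in that case)
    if (do_rounding) y = std::nearbyint(y);
    // compute_output_scale_shift: map back for quantize_dequantize
    if (do_dequantization) y = y * out_scale[ch] + out_shift[ch];
    return y;
}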
L(postops_no_tail); } - postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params); + postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params, ddp, qdp); L(postops_done); } else { iterate(load_loop_blk, ur, load_dim_tail, [&](const bool, const int i, const int j) { vmm_idxs.emplace(vreg_accum_idx(load_loop_blk, i, j)); }); - postops_injector_->compute_vector_range(vmm_idxs); + postops_injector_->compute_vector_range(vmm_idxs, binary_injector::rhs_arg_dynamic_params_t(), ddp, qdp); } L(store_nopost_ops); } @@ -567,6 +578,9 @@ void jit_avx2_1x1_conv_kernel_f32::generate_diff_bias_loop(int load_loop_blk) { void jit_avx2_1x1_conv_kernel_f32::generate() { preamble(); + if (postops_injector_) + postops_injector_->push_post_ops_data_on_stack(this->param1, GET_OFF(post_ops_binary_rhs_arg_vec), reg_bcast_data, reg_load_data); + if (jcp.with_binary || (jcp.with_bias && jcp.prop_kind == backward_weights)) sub(rsp, stack_space_needed); @@ -594,6 +608,7 @@ void jit_avx2_1x1_conv_kernel_f32::generate() { mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]); if (jcp.prop_kind == backward_weights) mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]); + mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]); auto generate_load_loop_body = [=](int load_loop_blk) { generate_bcast_loop(load_loop_blk); @@ -624,6 +639,7 @@ void jit_avx2_1x1_conv_kernel_f32::generate() { default: assert(!"invalid prop_kind"); } sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); + add(reg_oc_off, load_loop_blk * jcp.oc_block * sizeof(float)); }; Label load_loop_blk_8; @@ -674,6 +690,9 @@ void jit_avx2_1x1_conv_kernel_f32::generate() { if (jcp.with_binary || (jcp.with_bias && jcp.prop_kind == backward_weights)) add(rsp, stack_space_needed); + if (postops_injector_) + postops_injector_->reset_stack_pointer(); + postamble(); if (jcp.with_eltwise) postops_injector_->prepare_table(); @@ -747,6 +766,9 @@ status_t jit_avx2_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp, = post_ops.find(primitive_kind::binary, 0, dw_conv_ind); jcp.with_binary = binary_ind != -1; + jcp.with_depthwise = post_ops.find(primitive_kind::depthwise, 0, dw_conv_ind) != -1; + jcp.with_quantization = post_ops.find(primitive_kind::quantization, 0, dw_conv_ind) != -1; + if (dw_conv_ind >= 0) { // dw_conv and post_ops after it are handled externally, so skip them jcp.post_ops.entry_.assign(post_ops.entry_.cbegin(), @@ -757,8 +779,8 @@ status_t jit_avx2_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp, const auto dat_tag_nxc = utils::pick(ndims - 3, nwc, nhwc, ndhwc); const auto dat_tag_nCx8c = utils::pick(ndims - 3, nCw8c, nChw8c, nCdhw8c); - jcp.src_tag = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c); - jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c); + jcp.src_tag = src_d.mb_stride_relaxed_match(dat_tag_nxc, dat_tag_nCx8c); + jcp.dst_tag = dst_d.mb_stride_relaxed_match(dat_tag_nxc, dat_tag_nCx8c); const bool is_data_layout_nxc = utils::everyone_is(dat_tag_nxc, jcp.src_tag, jcp.dst_tag); const auto dat_tag = is_data_layout_nxc ? 
dat_tag_nxc : dat_tag_nCx8c; @@ -779,14 +801,14 @@ status_t jit_avx2_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp, jcp.ic = rnd_up(jcp.ic, simd_w); } - if (jcp.with_eltwise || jcp.with_binary) + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) if (jcp.isa < avx2) return status::unimplemented; using namespace injector; static constexpr bool sum_at_pos_0_only = true; static constexpr bool sum_requires_scale_one = true; static constexpr bool sum_requires_zp_zero = true; - const bool post_ops_ok_ = post_ops_ok({avx2, {eltwise, binary, sum}, + const bool post_ops_ok_ = post_ops_ok({avx2, {eltwise, binary, sum, depthwise, quantization}, jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one, sum_requires_zp_zero}); if (!post_ops_ok_) return status::unimplemented; diff --git a/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.hpp b/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.hpp index b6a15d49a99..856ec301048 100644 --- a/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.hpp +++ b/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.hpp @@ -84,6 +84,12 @@ struct jit_avx2_1x1_conv_kernel_f32 : public jit_generator { constexpr static int reg_abi_param1_backup = 2 * reg64_size_; constexpr static int stack_space_needed = 3 * reg64_size_; + reg64_t reg_oc_off = load_loop_iter; + reg64_t reg_d_weights = aux_reg_bcast_data; + reg64_t reg_d_bias = reduce_loop_iter; // todo: [AV] check, conflict with out_off_oprnd (r15) + ymm_t ymm_d_weights = Xbyak::Ymm(14); + ymm_t ymm_d_bias = Xbyak::Ymm(15); + ymm_t vreg_bcast = ymm_t(15); ymm_t vtmp = ymm_t(14); diff --git a/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32_old.cpp b/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32_old.cpp new file mode 100644 index 00000000000..e3dd8623877 --- /dev/null +++ b/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32_old.cpp @@ -0,0 +1,838 @@ +/******************************************************************************* +* Copyright 2016-2020 Intel Corporation +* Copyright 2018 YANDEX LLC +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* [todo] antonvor: + * This file contains the old plugin behavior in order to fix performance + * problems after upgrading to OneDNN v1.6. This kernel is executed only on + * machines with avx2 instruction set support and in the case of a fused + * convolution. Remove after problems are fixed. 
+*/ + +#include + +#include "common/c_types_map.hpp" +#include "common/memory.hpp" +#include "common/memory_tracking.hpp" +#include "common/nstl.hpp" +#include "common/type_helpers.hpp" +#include "common/utils.hpp" + +#include "cpu/x64/jit_avx2_1x1_conv_kernel_f32_old.hpp" +#include "cpu/x64/jit_uni_1x1_conv_utils.hpp" + +#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field) + +namespace dnnl { +namespace impl { +namespace cpu { +namespace x64 { + +using namespace dnnl::impl::prop_kind; +using namespace dnnl::impl::format_tag; +using namespace dnnl::impl::utils; + +using namespace Xbyak; + +void jit_avx2_1x1_conv_kernel_f32_old::generate_bcast_loop(int load_loop_blk) { + mov(aux1_reg_bcast_data, reg_bcast_data); + mov(aux_reg_output_data, reg_output_data); + mov(bcast_loop_iter, reg_bcast_loop_work); + + Label bcast_loop, bcast_loop_tail; + + cmp(bcast_loop_iter, jcp.ur); + jl(bcast_loop_tail, T_NEAR); + + L(bcast_loop); { + assert(jcp.bcast_block % jcp.ur == 0); + int num_substeps = jcp.bcast_block / jcp.ur; + assert(num_substeps > 0 && num_substeps < 10); + for (int i = 0; i < num_substeps; i++) { + generate_reduce_loop(load_loop_blk, jcp.ur); + if (i < num_substeps - 1) { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_substep); + } else { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step + - (num_substeps - 1) * jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_step + - (num_substeps - 1) * jcp.bcast_loop_output_substep); + } + } + sub(bcast_loop_iter, jcp.bcast_block); + cmp(bcast_loop_iter, jcp.bcast_block); + jge(bcast_loop, T_NEAR); + } + + L(bcast_loop_tail); + if (jcp.ur_tail) { + Label bcast_loop_tail_out; + cmp(bcast_loop_iter, 0); + jz(bcast_loop_tail_out, T_NEAR); + generate_reduce_loop(load_loop_blk, jcp.ur_tail); + L(bcast_loop_tail_out); + } +} + +void jit_avx2_1x1_conv_kernel_f32_old::generate_reduce_loop( + int load_loop_blk, int ur) { + + auto vreg_load = [=](int i) { + return Ymm(ur * load_loop_blk + i); + }; + + auto vreg_accum = [=](int i, int j) { + return Ymm(j + i*ur); + }; + + auto bias_ptr = [=](int i) { + return ptr[reg_bias_data + sizeof(float) * jcp.oc_block * i]; + }; + + auto bcast_ptr = [=](int u, int j) { + assert(j < jcp.ur); + assert(u <= jcp.reduce_loop_unroll); + size_t offt; + if (one_of(jcp.prop_kind, + forward_training, forward_inference, backward_data)) + { + assert(jcp.reduce_loop_unroll == (jcp.prop_kind == backward_data) + ? jcp.oc_block : jcp.ic_block); + auto height = (jcp.prop_kind == backward_data) ? jcp.os : jcp.is; + offt = (u == jcp.reduce_loop_unroll) + ? (height + j) * jcp.reduce_loop_unroll + : j * jcp.reduce_loop_unroll + u; + } else + offt = u * jcp.ic_block + j; + return ptr[aux_reg_bcast_data + sizeof(float) * offt]; + }; + + auto load_ptr = [=](int u, int i) { + size_t offt; + size_t u0 = u % jcp.reduce_loop_unroll; + size_t u1 = u / jcp.reduce_loop_unroll; + switch (jcp.prop_kind) { + case backward_data: + offt = (i * jcp.oc_block + u0) * jcp.ic_block; + break; + case backward_weights: + offt = (i * jcp.os + u0) * jcp.oc_block; + break; + default: + offt = (i * jcp.ic + u0) * jcp.oc_block; + } + return ptr[aux_reg_load_data + + u1 * jcp.reduce_loop_load_step + sizeof(float) * offt]; + }; + + auto output_ptr = [=](int i, int j) { + switch (jcp.prop_kind) { + case backward_data: + return ptr[aux_reg_output_data + + (i * jcp.is + j) * jcp.ic_block * sizeof(float)]; + case backward_weights: + return ptr[aux_reg_output_data + + (i ? 
reg_output_stride * i : 0) // TODO: Xbyak should allow 0 scale + + sizeof(float) * jcp.oc_block * j]; + default: + if (jcp.with_dw_conv) { + return ptr[aux_reg_output_data + + (i * jcp_dw.kh * jcp.ow + j) * jcp.oc_block * sizeof(float)]; + } else { + return ptr[aux_reg_output_data + + (i * jcp.os + j) * jcp.oc_block * sizeof(float)]; + } + } + }; + + auto init = [=]() { + Label init_done, init_zero; + + if (jcp.with_bias && one_of(jcp.prop_kind, forward_training, + forward_inference)) { + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jz(init_zero); + + for (int i = 0; i < load_loop_blk; i++) + for (int j = 0; j < ur; ++j) + vmovups(vreg_accum(i, j), bias_ptr(i)); + jmp(init_done); + } + + L(init_zero); + for (int i = 0; i < load_loop_blk; ++i) + for (int j = 0; j < ur; ++j) { + auto r = vreg_accum(i, j); + vxorps(r, r, r); + } + + L(init_done); + for (int i = 0; i < load_loop_blk; ++i) + vmovups(vreg_load(i), load_ptr(0, i)); + vbroadcastss(vreg_bcast, bcast_ptr(0, 0)); + }; + + auto store = [=]() { + Label store_noadd; + + if (!jcp.with_sum) { + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jnz(store_noadd, T_NEAR); + } + + for (int j = 0; j < ur; ++j) + for (int i = 0; i < load_loop_blk; ++i) { + auto r = vreg_accum(i, j); + vaddps(r, r, output_ptr(i, j)); + } + + L(store_noadd); + + Label store_norelu; + test(reg_reduce_pos_flag, FLAG_REDUCE_LAST); + jz(store_norelu, T_NEAR); + + int eltwise_inj_idx = 0; + int depthwise_inj_idx = 0; + int quantization_inj_idx = 0; + std::size_t post_ops_data_offset = 0; + const auto &p = attr_.post_ops_; + + int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len(); + for (int i = 0; i < end_idx; i++) { + auto& post_op = p.entry_[i]; + if (post_op.is_eltwise()) { + eltwise_injectors[eltwise_inj_idx]->compute_vector_range(0, ur * load_loop_blk); + eltwise_inj_idx++; + } else if (post_op.is_depthwise()) { + mov(reg_d_weights, ptr[this->rsp + post_ops_data_offset]); + add(reg_d_weights, reg_oc_off); + + for (int j = 0; j < load_loop_blk; ++j) { + int start_idx = vreg_accum(j, 0).getIdx(); + int end_idx = start_idx + ur; + + depthwise_injectors[depthwise_inj_idx]->compute_vector_range( + start_idx, end_idx, reg_d_weights, reg_d_weights); + + add(reg_d_weights, jcp.oc_block * sizeof(float)); + } + + post_ops_data_offset += depthwise_injectors[depthwise_inj_idx]->memoryStep(); + depthwise_inj_idx++; + } else if (post_op.is_quantization()) { + const Xbyak::RegExp quant_arg_base = this->rsp + post_ops_data_offset; + quantization_injectors[quantization_inj_idx]->init_crop_ptrs(quant_arg_base, reg_oc_off); + for (int ii = 0; ii < load_loop_blk; ii++) { + int s_idx = vreg_accum(ii, 0).getIdx(); + quantization_injectors[quantization_inj_idx]->compute_crop(s_idx, s_idx + ur, ii * jcp.oc_block * sizeof(float)); + } + + quantization_injectors[quantization_inj_idx]->init_input_scale_shift_ptrs(quant_arg_base, reg_oc_off); + for (int ii = 0; ii < load_loop_blk; ii++) { + int s_idx = vreg_accum(ii, 0).getIdx(); + quantization_injectors[quantization_inj_idx]->compute_input_scale_shift(s_idx, s_idx + ur, ii * jcp.oc_block * sizeof(float), true); + } + + quantization_injectors[quantization_inj_idx]->init_output_scale_shift_ptrs(quant_arg_base, reg_oc_off); + for (int ii = 0; ii < load_loop_blk; ii++) { + int s_idx = vreg_accum(ii, 0).getIdx(); + quantization_injectors[quantization_inj_idx]->compute_output_scale_shift(s_idx, s_idx + ur, ii * jcp.oc_block * sizeof(float)); + } + + post_ops_data_offset += 
quantization_injectors[quantization_inj_idx]->memoryStep(); + quantization_inj_idx++; + } + } + + L(store_norelu); + + for (int j = 0; j < ur; ++j) + for (int i = 0; i < load_loop_blk; ++i) { + vmovups(output_ptr(i, j), vreg_accum(i, j)); + } + }; + + auto fma_block = [=](bool last_block) { + for (int u = 0; u < jcp.reduce_loop_unroll; ++u) { + for (int j = 0; j < ur; ++j) { + for (int i = 0; i < load_loop_blk; ++i) { + if (mayiuse(avx2)) + vfmadd231ps(vreg_accum(i, j), vreg_load(i), vreg_bcast); + else { // Intel(R) Advanced Vector Extensions (Intel(R) AVX) support + vmulps(vtmp, vreg_bcast, vreg_load(i)); + vaddps(vreg_accum(i, j), vreg_accum(i, j), vtmp); + } + if (j == ur - 1 && !(last_block + && u == jcp.reduce_loop_unroll - 1)) + vmovups(vreg_load(i), load_ptr(u + 1, i)); + } + if (j < ur - 1) + vbroadcastss(vreg_bcast, bcast_ptr(u, j + 1)); + } + if (!last_block || u < jcp.reduce_loop_unroll - 1) + vbroadcastss(vreg_bcast, bcast_ptr(u + 1, 0)); + } + }; + + Label reduce_loop, reduce_loop_tail; + + mov(aux_reg_load_data, reg_load_data); + mov(aux_reg_bcast_data, aux1_reg_bcast_data); + + init(); + + mov(reduce_loop_iter, reg_reduce_loop_work); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jle(reduce_loop_tail, T_NEAR); + + L(reduce_loop); { + fma_block(false); + add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step); + add(aux_reg_load_data, jcp.reduce_loop_load_step); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jg(reduce_loop, T_NEAR); + } + + L(reduce_loop_tail); + fma_block(true); + + store(); +} + +void jit_avx2_1x1_conv_kernel_f32_old::generate_diff_bias_loop(int load_loop_blk) { + if (!jcp.with_bias || jcp.prop_kind != backward_weights) return; + + Label diff_bias_loop, diff_bias_loop_out, diff_bias_init_out; + Label diff_bias_load; + + auto diff_bias_ptr = [=](int i) { + return ptr[reg_diff_bias_data + i * jcp.oc_block * sizeof(float)]; + }; + + auto load_ptr = [=](int u, int i) { + return ptr[aux_reg_load_data + + (i * jcp.os + u) * jcp.oc_block * sizeof(float)]; + }; + + auto diff_bias_reg = [=](int i) { return Ymm(i); }; + + mov(reg_diff_bias_data, ptr[rsp + reg_diff_bias_data_stack_offt]); + cmp(reg_diff_bias_data, 0); + je(diff_bias_loop_out, T_NEAR); + + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jz(diff_bias_load, T_NEAR); + + for (int i = 0; i < load_loop_blk; ++i) { + auto r = diff_bias_reg(i); + vxorps(r, r, r); + } + jmp(diff_bias_init_out, T_NEAR); + + L(diff_bias_load); + for (int i = 0; i < load_loop_blk; ++i) + vmovups(diff_bias_reg(i), diff_bias_ptr(i)); + + L(diff_bias_init_out); + mov(aux_reg_load_data, reg_load_data); + mov(reduce_loop_iter, reg_reduce_loop_work); + L(diff_bias_loop); { + for(int u = 0; u < jcp.reduce_loop_unroll; ++u) + for (int i = 0; i < load_loop_blk; ++i) + vaddps(diff_bias_reg(i), diff_bias_reg(i), load_ptr(u, i)); + assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0); + add(aux_reg_load_data, jcp.reduce_loop_load_step); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jnz(diff_bias_loop, T_NEAR); + } + + for (int i = 0; i < load_loop_blk; i++) + vmovups(diff_bias_ptr(i), diff_bias_reg(i)); + add(reg_diff_bias_data, load_loop_blk * jcp.oc_block * sizeof(float)); + mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data); + + L(diff_bias_loop_out); +} + +void jit_avx2_1x1_conv_kernel_f32_old::generate() { + const auto &p = attr_.post_ops_; + int end_idx = jcp.with_dw_conv ? 
p.find(primitive_kind::convolution) : p.len();
+    for (int i = 0; i < end_idx; i++) {
+        auto &post_op = p.entry_[i];
+        if (post_op.is_eltwise()) {
+            eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<avx2>(
+                    this,
+                    post_op.eltwise
+            ));
+        } else if (post_op.is_depthwise()) {
+            depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<avx2>(
+                    this,
+                    post_op
+            ));
+        } else if (post_op.is_quantization()) {
+            quantization_injectors.push_back(new jit_uni_quantization_injector_f32<avx2>(
+                    this,
+                    post_op,
+                    ymm_d_weights, ymm_d_bias, reg_d_weights, reg_d_bias
+            ));
+        }
+    }
+
+    preamble();
+
+    std::size_t post_ops_pointers_count = 0;
+    for (int i = 0; i < p.len(); i++) {
+        if (p.entry_[i].is_depthwise() || p.entry_[i].is_quantization()) {
+            post_ops_pointers_count++;
+        }
+    }
+
+    if (post_ops_pointers_count != 0) {
+        sub(rsp, post_ops_pointers_count * sizeof(float *));
+
+        auto aux_reg0 = reg_bcast_data;
+        auto aux_reg1 = reg_load_data;
+
+        mov(aux_reg0, ptr[this->param1 + GET_OFF(post_ops_binary_rhs_arg_vec)]);
+        for (size_t i = 0; i < post_ops_pointers_count; i++) {
+            mov(aux_reg1, ptr[aux_reg0 + i * sizeof(float *)]);
+            mov(ptr[rsp + i * sizeof(float *)], aux_reg1);
+        }
+    }
+
+    mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
+    mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
+    mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);
+    if (jcp.with_bias) {
+        if (jcp.prop_kind == backward_weights) {
+            sub(rsp, stack_space_needed);
+            mov(reg_diff_bias_data, ptr[param1 + GET_OFF(bias_data)]);
+            mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data);
+        } else
+            mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
+    }
+
+    mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
+    mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
+    mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
+    mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
+    if (jcp.prop_kind == backward_weights)
+        mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]);
+    mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]);
+
+    auto generate_load_loop_body = [=] (int load_loop_blk) {
+        generate_bcast_loop(load_loop_blk);
+        add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
+        switch (jcp.prop_kind) {
+        case forward_training:
+        case forward_inference:
+            add(reg_bias_data, load_loop_blk * jcp.oc_block * sizeof(float));
+            if (jcp.with_dw_conv)
+                add(reg_output_data,
+                        load_loop_blk * jcp.ow * jcp.oc_block * sizeof(float));
+            else
+                add(reg_output_data,
+                        load_loop_blk * jcp.os * jcp.oc_block * sizeof(float));
+            break;
+        case backward_data:
+            add(reg_output_data,
+                    load_loop_blk * jcp.is * jcp.ic_block * sizeof(float));
+            break;
+        case backward_weights:
+            for (int i = 0; i < load_loop_blk; i++)
+                add(reg_output_data, reg_output_stride);
+            break;
+        default:
+            assert(!"invalid prop_kind");
+        }
+        sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
+        add(reg_oc_off, load_loop_blk * jcp.oc_block * sizeof(float));
+    };
+
+    Label load_loop_blk_8;
+    Label load_loop_blk_16;
+    Label load_loop_blk_24;
+    Label load_loop_blk_end;
+
+    cmp(reg_load_loop_work, 8);
+    jle(load_loop_blk_8, T_NEAR);
+
+    cmp(reg_load_loop_work, 32);
+    je(load_loop_blk_16, T_NEAR);
+
+    cmp(reg_load_loop_work, 16);
+    jle(load_loop_blk_16, T_NEAR);
+
+    L(load_loop_blk_24); {
+        generate_diff_bias_loop(3);
+        generate_load_loop_body(3);
+        cmp(reg_load_loop_work, 32);
+        je(load_loop_blk_16);
+        cmp(reg_load_loop_work, 24);
+        jge(load_loop_blk_24);
+    }
+
+    cmp(reg_load_loop_work, 8);
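// [Editor's note -- illustrative walk-through, not part of the original
// patch.] The cmp/jump chain around this point dispatches on the remaining
// load-dimension work (output channels, with oc_block == 8), always
// preferring the widest generated body: blk_24 processes three channel
// blocks per pass, blk_16 two, blk_8 one. For example, load_dim = 40 runs
// as 24 + 16: it falls through to blk_24 (leaving 16), and the checks then
// route the remainder to blk_16 (leaving 0). A remainder of exactly 32 is
// deliberately split 16 + 16 rather than 24 + 8, so both passes use the
// two-block body instead of finishing on the narrowest one.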
jle(load_loop_blk_8, T_NEAR); + + L(load_loop_blk_16); { + generate_diff_bias_loop(2); + generate_load_loop_body(2); + cmp(reg_load_loop_work, 16); + jge(load_loop_blk_16); + } + + L(load_loop_blk_8); { + cmp(reg_load_loop_work, 0); + je(load_loop_blk_end, T_NEAR); + generate_diff_bias_loop(1); + generate_load_loop_body(1); + } + + L(load_loop_blk_end); + + if (jcp.with_bias && jcp.prop_kind == backward_weights) + add(rsp, 8); + + if (post_ops_pointers_count != 0) { + add(rsp, post_ops_pointers_count * sizeof(float *)); + } + + postamble(); + + for (auto& inj : eltwise_injectors) + inj->prepare_table(); +} + +bool jit_avx2_1x1_conv_kernel_f32_old::post_ops_ok( + jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + int dw_conv_idx = p.find(primitive_kind::convolution); + bool with_dw_conv = dw_conv_idx != -1; + + auto all_post_ops_supported = [&]() { + bool ok = true; + + int end_idx = with_dw_conv ? dw_conv_idx : p.len(); + for (int i = 0; i < end_idx; i++) { + ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::sum, primitive_kind::eltwise, primitive_kind::depthwise, + primitive_kind::quantization); + } + return ok; + }; + auto contain = [&](dnnl::impl::primitive_kind_t kind) { return p.find(kind, 0, dw_conv_idx) != -1; }; + auto position = [&](dnnl::impl::primitive_kind_t kind) { return p.find(kind, 0, dw_conv_idx); }; + auto count = [&](dnnl::impl::primitive_kind_t kind) { return p.count(kind, 0, dw_conv_idx); }; + + return all_post_ops_supported() && + count(primitive_kind::sum) <= 1 && + IMPLICATION(contain(primitive_kind::sum), position(primitive_kind::sum) == 0) && + IMPLICATION(with_dw_conv, !contain(primitive_kind::sum)); +} + +status_t jit_avx2_1x1_conv_kernel_f32_old::init_conf(jit_1x1_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr) +{ + if (!mayiuse(avx)) return status::unimplemented; + + // TODO (Roma): this code is duplicated from the generic kernel; maybe the + // configuration struct could do some stuff below + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + const int ndims = src_d.ndims(); + + jcp.prop_kind = cd.prop_kind; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2]; + jcp.iw = src_d.dims()[ndims - 1]; + jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2]; + jcp.ow = dst_d.dims()[ndims - 1]; + + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0]; + jcp.l_pad = cd.padding[0][ndims - 3]; + + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0]; + jcp.stride_w = cd.strides[ndims - 3]; + + const auto dat_tag_nxc = utils::pick(ndims - 3, nwc, nhwc, ndhwc); + const auto dat_tag_nCx8c = utils::pick(ndims - 3, nCw8c, nChw8c, nCdhw8c); + jcp.src_tag = src_d.mb_stride_relaxed_match(dat_tag_nxc, dat_tag_nCx8c); + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + jcp.src_dt = cd.src_desc.data_type; + jcp.bia_dt = jcp.with_bias ? 
cd.bias_desc.data_type : data_type::undef; + jcp.dst_dt = cd.dst_desc.data_type; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + + int dw_conv_ind = p.find(primitive_kind::convolution); + jcp.with_dw_conv = dw_conv_ind != -1; + + if (jcp.with_dw_conv && !mayiuse(avx2)) + return status::unimplemented; + + if (jcp.with_dw_conv) { + // dw_conv and post_ops after it are handled externally, so skip them + jcp.post_ops.entry_.assign(p.entry_.cbegin(), + p.entry_.cbegin() + dw_conv_ind); + + jcp.dw_conv_oh = jcp.oh; + jcp.dw_conv_ow = jcp.ow; + jcp.oh = p.entry_[dw_conv_ind].depthwise_conv_old.in_h; + jcp.ow = p.entry_[dw_conv_ind].depthwise_conv_old.in_w; + + jcp.dw_conv_dst_dt = jcp.dst_dt; + jcp.dst_dt = p.entry_[dw_conv_ind].depthwise_conv_old.in_dt; + } + + if (!mayiuse(avx2)) { + for (int i = 0; i < p.len(); i++) { + auto &post_op = p.entry_[i]; + if (post_op.is_eltwise()) { + if (post_op.eltwise.alg != alg_kind::eltwise_relu) + return status::unimplemented; + } else if (post_op.is_depthwise() || post_op.is_quantization()) { + return status::unimplemented; + } + } + } + + jcp.with_sum = p.find(primitive_kind::sum, 0, dw_conv_ind) != -1; + + jcp.os = jcp.oh * jcp.ow; + jcp.is = jcp.ih * jcp.iw; + + const int is_bwd_d = jcp.prop_kind == backward_data; + format_tag_t wei_tag = with_groups + ? utils::pick(2 * ndims - 6 + is_bwd_d, gOIw8i8o, gOIw8o8i, gOIhw8i8o, + gOIhw8o8i) + : utils::pick(2 * ndims - 6 + is_bwd_d, OIw8i8o, OIw8o8i, OIhw8i8o, + OIhw8o8i); + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + const int simd_w = 8; + + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + + jcp.dst_tag = dat_tag_nCx8c; + const bool is_data_layout_nxc + = utils::everyone_is(dat_tag_nxc, jcp.src_tag, jcp.dst_tag); + const auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx8c; + bool args_ok = true && jcp.ngroups == 1 && jcp.src_tag == dat_tag + && jcp.wei_tag == wei_tag && jcp.dst_tag == dat_tag; + + if (!args_ok) return status::unimplemented; + + args_ok = true + && jcp.ih == jcp.oh && jcp.iw == jcp.ow + && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0 + && jcp.t_pad == 0 && jcp.l_pad == 0 + && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides + && jcp.kh == 1 && jcp.kw == 1; + if (!args_ok) return status::unimplemented; + + // TODO: remove this restriction + // optimized 1x1 bwd_w does not support Intel AVX + if (jcp.prop_kind == backward_weights && !mayiuse(avx2)) + return status::unimplemented; + + jcp.ic_block = jcp.oc_block = simd_w; + + jcp.ur = mayiuse(avx2) ? 4 : 3; // Intel AVX support + + int load_blocking{ 0 }; + int load_blocking_max{ 0 }; + int bcast_blocking{ 0 }; + int bcast_blocking_max{ 0 }; + int reduce_blocking{ 0 }; + + if (one_of(jcp.prop_kind, forward_training, forward_inference)) { + jcp.reduce_dim = jcp.ic; + jcp.reduce_block = jcp.ic_block; + + jcp.load_dim = jcp.oc; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.with_dw_conv ? 
jcp.iw : jcp.is; + jcp.bcast_block = jcp.ur; + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.is * sizeof(float); + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float); + + jcp.bcast_loop_output_step = jcp.ur * jcp.oc_block * sizeof(float); + jcp.bcast_loop_output_substep = -1; // unused + jcp.bcast_loop_bcast_step = jcp.ur * jcp.ic_block * sizeof(float); + jcp.bcast_loop_bcast_substep = -1; // unused + + jcp.load_loop_load_step = jcp.ic * jcp.oc_block * sizeof(float); + jcp.load_loop_iter_step = jcp.oc_block; + + load_blocking = jcp.with_dw_conv ? nstl::min(3 * jcp.load_block, jcp.oc) : 120; // assumes the kernel is jcp.ur x 3 + load_blocking_max = jcp.with_dw_conv ? nstl::min(3 * jcp.load_block, jcp.oc) : 144; + bcast_blocking = 128; // affects load balancing across threads + bcast_blocking_max = 192; + reduce_blocking = 128; // affects L1$ utilization + } else if (jcp.prop_kind == backward_data) { + jcp.reduce_dim = jcp.oc; + jcp.reduce_block = jcp.oc_block; + + jcp.load_dim = jcp.ic; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.os; + jcp.bcast_block = jcp.ur; + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.os * sizeof(float); + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.ic * sizeof(float); + + jcp.bcast_loop_output_step = jcp.ur * jcp.ic_block * sizeof(float); + jcp.bcast_loop_output_substep = -1; // unused + jcp.bcast_loop_bcast_step = jcp.ur * jcp.oc_block * sizeof(float); + jcp.bcast_loop_bcast_substep = -1; // unused + + jcp.load_loop_load_step = jcp.oc_block * jcp.ic_block * sizeof(float); + jcp.load_loop_iter_step = jcp.ic_block; + + load_blocking = 96; // assumes the kernel is jcp.ur x 3 + load_blocking_max = 144; + bcast_blocking = 128; // affects load balancing across threads + bcast_blocking_max = 196; + reduce_blocking = 64; // affects L1$ utilization + } else if (jcp.prop_kind == backward_weights) { + jcp.reduce_dim = jcp.os; + jcp.reduce_block = 1; + + jcp.load_dim = jcp.oc; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.ic; + jcp.bcast_block = jcp.ic_block; + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.ic_block * sizeof(float); + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float); + + jcp.bcast_loop_output_step = jcp.oc_block * jcp.ic_block * sizeof(float); + jcp.bcast_loop_output_substep = jcp.oc_block * jcp.ur * sizeof(float); + jcp.bcast_loop_bcast_step = jcp.ic_block * jcp.is * sizeof(float); + jcp.bcast_loop_bcast_substep = jcp.ur * sizeof(float); + + jcp.load_loop_load_step = jcp.oc_block * jcp.os * sizeof(float); + jcp.load_loop_iter_step = jcp.oc_block; + + /* --- */ + + load_blocking = div_up(jcp.load_dim, jcp.load_block); + while (true) { + if (load_blocking <= 32) break; + else if (load_blocking % 2 == 0) load_blocking /= 2; + else if (load_blocking % 3 == 0) load_blocking /= 3; + else break; + } + load_blocking *= jcp.load_block; + load_blocking_max = load_blocking; + assert(jcp.load_dim % load_blocking == 0); + + bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block); + while (true) { + if (bcast_blocking <= 9) break; + else if (bcast_blocking % 2 == 0) bcast_blocking /= 2; + else if (bcast_blocking % 3 == 0) bcast_blocking /= 3; + else break; + } + bcast_blocking *= jcp.bcast_block; + bcast_blocking_max = bcast_blocking; + assert(jcp.bcast_dim % bcast_blocking == 0); + + 
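// [Editor's note -- worked example, not part of the original patch.] The
// two trimming loops above start from the total number of load/bcast
// blocks and keep dividing by 2 or 3 while the count stays above its cap
// (32 and 9, respectively), so the resulting blocking always divides the
// dimension exactly. E.g. load_dim = 384, load_block = 8: div_up(384, 8)
// = 48 -> 48 > 32 and even, /2 -> 24 <= 32, stop -> load_blocking =
// 24 * 8 = 192, and 384 % 192 == 0, which is what the asserts below
// verify. A block count with no factors of 2 or 3 (say 37) is simply
// kept whole and still divides evenly, being the full count itself.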
reduce_blocking = 128; // affects L1$ utilization + } else + return status::unimplemented; + + assert(load_blocking); + assert(load_blocking_max); + assert(bcast_blocking); + assert(bcast_blocking_max); + assert(reduce_blocking); + + assert(jcp.bcast_block % jcp.ur == 0); + jcp.ur_tail = jcp.bcast_dim % jcp.ur; + + jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block; + jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block; + jcp.nb_load_blocking = load_blocking / jcp.load_block; + jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block; + jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block; + + jcp.nb_bcast = jcp.with_dw_conv ? jcp.ih : div_up(jcp.bcast_dim, jcp.bcast_block); + jcp.nb_load = div_up(jcp.load_dim, jcp.load_block); + jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); + + return status::success; +} + +void jit_avx2_1x1_conv_kernel_f32_old::init_scratchpad( + memory_tracking::registrar_t &scratchpad, + const jit_1x1_conv_conf_t &jcp, const jit_conv_conf_t &jcp_dw) { + using namespace dnnl::impl::memory_tracking::names; + + if (jcp.with_bias && jcp.prop_kind != backward_data + && (jcp.oc != jcp.oc_without_padding // blocked format + || (jcp.prop_kind == backward_weights // nxc format + && jcp.oc % jcp.oc_block != 0))) { + const size_t nelems_padded_bias + = jcp.ngroups * rnd_up(jcp.oc, jcp.oc_block); + scratchpad.book(key_conv_padded_bias, nelems_padded_bias); + } + + if (jcp.with_dw_conv) { + const int nthreads = dnnl_get_max_threads(); + size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * (jcp.oc / jcp.oc_block); + scratchpad.book(key_dw_conv_buffer, dw_conv_buffer_size_ * nthreads); + + if (jcp.oc != jcp.oc_without_padding) + scratchpad.book(key_dw_conv_padded_bias, jcp.oc); + } +} + +} // namespace x64 +} // namespace cpu +} // namespace impl +} // namespace dnnl diff --git a/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32_old.hpp b/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32_old.hpp new file mode 100644 index 00000000000..9fdc7084806 --- /dev/null +++ b/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32_old.hpp @@ -0,0 +1,130 @@ +/******************************************************************************* +* Copyright 2016-2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* [todo] antonvor: + * This file contains the old plugin behavior in order to fix performance + * problems after upgrading to OneDNN v1.6. This kernel is executed only on + * machines with avx2 instruction set support and in the case of a fused + * convolution. Remove after problems are fixed. 
+*/
+
+#ifndef CPU_X64_JIT_AVX2_1X1_CONV_KERNEL_F32_OLD_HPP
+#define CPU_X64_JIT_AVX2_1X1_CONV_KERNEL_F32_OLD_HPP
+
+#include "common/c_types_map.hpp"
+#include "common/memory.hpp"
+#include "common/memory_tracking.hpp"
+
+#include "cpu/x64/jit_generator.hpp"
+#include "cpu/x64/jit_primitive_conf.hpp"
+#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
+#include "cpu/x64/injectors/jit_uni_depthwise_injector.hpp"
+#include "cpu/x64/injectors/jit_uni_quantization_injector.hpp"
+
+namespace dnnl {
+namespace impl {
+namespace cpu {
+namespace x64 {
+
+struct jit_avx2_1x1_conv_kernel_f32_old : public jit_generator {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_1x1_conv_kernel_f32_old)
+
+    jit_avx2_1x1_conv_kernel_f32_old(
+            const jit_1x1_conv_conf_t &ajcp, jit_conv_conf_t ajcp_dw, const primitive_attr_t &attr)
+        : jcp(ajcp), jcp_dw(ajcp_dw), attr_(attr) {}
+
+    ~jit_avx2_1x1_conv_kernel_f32_old() {
+        for (auto inj : eltwise_injectors)
+            delete inj;
+        eltwise_injectors.clear();
+
+        for (auto inj : depthwise_injectors)
+            delete inj;
+        depthwise_injectors.clear();
+
+        for (auto inj : quantization_injectors)
+            delete inj;
+        quantization_injectors.clear();
+    }
+
+    static bool post_ops_ok(
+            jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr);
+
+    static status_t init_conf(jit_1x1_conv_conf_t &jcp,
+            const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
+            const memory_desc_wrapper &weights_d,
+            const memory_desc_wrapper &dst_d, const primitive_attr_t &attr);
+
+    static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+            const jit_1x1_conv_conf_t &jcp, const jit_conv_conf_t &jcp_dw = jit_conv_conf_t());
+
+    jit_1x1_conv_conf_t jcp;
+    jit_conv_conf_t jcp_dw;
+    const primitive_attr_t &attr_;
+
+private:
+    using reg64_t = const Xbyak::Reg64;
+    using ymm_t = const Xbyak::Ymm;
+
+    reg64_t reg_bcast_data = rax;
+    reg64_t reg_load_data = rsi;
+    reg64_t reg_output_data = rbx;
+    reg64_t aux_reg_bcast_data = rdx;
+    reg64_t aux1_reg_bcast_data = abi_not_param1;
+    reg64_t aux_reg_output_data = rbp;
+    reg64_t reg_load_loop_work = r9;
+    reg64_t reg_bcast_loop_work = r10;
+    reg64_t reg_reduce_loop_work = r11;
+    reg64_t load_loop_iter = r13;
+    reg64_t aux_reg_load_data = load_loop_iter;
+    reg64_t bcast_loop_iter = r14;
+    reg64_t reduce_loop_iter = r15;
+    reg64_t imm_addr64 = reduce_loop_iter;
+    reg64_t reg_reduce_pos_flag = r8;
+    reg64_t reg_output_stride = r12;
+    reg64_t reg_bias_data = r12;
+    reg64_t reg_diff_bias_data = bcast_loop_iter;
+
+    reg64_t reg_oc_off = abi_param1;
+    reg64_t reg_d_weights = aux_reg_bcast_data;
+    reg64_t reg_d_bias = reduce_loop_iter;
+
+    int reg_diff_bias_data_stack_offt = 0;
+    int stack_space_needed = 8;
+
+    ymm_t vreg_bcast = ymm_t(15);
+    ymm_t vtmp = ymm_t(14);
+
+    ymm_t ymm_d_weights = Xbyak::Ymm(14);
+    ymm_t ymm_d_bias = Xbyak::Ymm(15);
+
+    void generate_bcast_loop(int load_loop_blk);
+    void generate_reduce_loop(int load_loop_blk, int ur);
+    void generate_diff_bias_loop(int load_loop_blk);
+
+    nstl::vector<jit_uni_eltwise_injector_f32<avx2> *> eltwise_injectors;
+    nstl::vector<jit_uni_depthwise_injector_f32<avx2> *> depthwise_injectors;
+    nstl::vector<jit_uni_quantization_injector_f32<avx2> *> quantization_injectors;
+
+    void generate() override;
+};
+
+} // namespace x64
+} // namespace cpu
+} // namespace impl
+} // namespace dnnl
+
+#endif
diff --git a/src/cpu/x64/jit_avx2_1x1_convolution.cpp b/src/cpu/x64/jit_avx2_1x1_convolution.cpp
index 88acfd978b3..cb2dc9cf102 100644
--- a/src/cpu/x64/jit_avx2_1x1_convolution.cpp
+++ b/src/cpu/x64/jit_avx2_1x1_convolution.cpp
@@ -55,6 +55,8 @@ void jit_avx2_1x1_convolution_fwd_t::execute_forward(
pd()->jcp_.post_ops.entry_.size() + 1) : std::vector {}; + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + auto scratchpad = ctx.get_scratchpad_grantor(); const auto &jcp = kernel_->jcp; @@ -72,7 +74,7 @@ void jit_avx2_1x1_convolution_fwd_t::execute_forward( parallel(jcp.nthr, [&](const int ithr, const int nthr) { execute_forward_thr(ithr, nthr, src, weights, bias, weights_dw, bias_dw, dst, scratchpad, post_ops_binary_rhs_arg_vec.data(), - post_ops_binary_rhs_arg_vec_dw.data()); + post_ops_binary_rhs_arg_vec_dw.data(), MB); }); if (pd()->wants_zero_pad_dst()) ctx.zero_pad_output(DNNL_ARG_DST); @@ -83,7 +85,7 @@ void jit_avx2_1x1_convolution_fwd_t::execute_forward_thr(const int ithr, const data_t *bias, const data_t *weights_dw, const data_t *bias_dw, data_t *dst, const memory_tracking::grantor_t &scratchpad, const void *post_ops_binary_rhs_arg_vec, - const void *post_ops_binary_rhs_arg_vec_dw) const { + const void *post_ops_binary_rhs_arg_vec_dw, int MB) const { const memory_desc_wrapper src_d(pd()->src_md()); const memory_desc_wrapper dst_d(pd()->dst_md()); @@ -139,7 +141,7 @@ void jit_avx2_1x1_convolution_fwd_t::execute_forward_thr(const int ithr, int &bcast_step, int &od, int &oh, int &ow, int &id, int &ih, int &iw) { int osb {0}; - nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, nb_bcast); + nd_iterator_init(iwork, n, MB, g, jcp.ngroups, osb, nb_bcast); bcast_step = step( nb_bcast_blocking, nb_bcast - osb, nb_bcast_blocking_max); @@ -217,6 +219,8 @@ void jit_avx2_1x1_convolution_fwd_t::execute_forward_thr(const int ithr, p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec; p.dst_orig = dst; + p.oc_off = oc_off_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block) * sizeof(float); + (*kernel_)(&p); }; @@ -291,6 +295,8 @@ void jit_avx2_1x1_convolution_fwd_t::execute_forward_thr(const int ithr, = post_ops_binary_rhs_arg_vec_dw; par_conv_dw.dst_orig = dst; + par_conv_dw.oc_off = ch * jcp_dw->ch_block * sizeof(float); + (*dw_jit_ker)(&par_conv_dw); for (int i = 0; i < jcp_dw->kh; ++i) @@ -314,7 +320,7 @@ void jit_avx2_1x1_convolution_fwd_t::execute_forward_thr(const int ithr, addrs.resize(jcp_dw->kh); int bcast_start {0}, bcast_end {0}, ocb_start {0}, ocb_end {0}; - balance2D(nthr, ithr, jcp.mb * jcp.ngroups * jcp_dw->oh, bcast_start, + balance2D(nthr, ithr, MB * jcp.ngroups * jcp_dw->oh, bcast_start, bcast_end, nb_oc, ocb_start, ocb_end, 1); while (ocb_start < ocb_end) { @@ -325,7 +331,7 @@ void jit_avx2_1x1_convolution_fwd_t::execute_forward_thr(const int ithr, auto bcast_iter = bcast_start; while (bcast_iter < bcast_end) { int n, g, oh_dw; - nd_iterator_init(bcast_iter, n, jcp.mb, g, jcp.ngroups, oh_dw, + nd_iterator_init(bcast_iter, n, MB, g, jcp.ngroups, oh_dw, jcp_dw->oh); if (oh_dw == 0) oh_1x1 = 0; // Reset over mb boundary const int oh_1x1_range @@ -356,7 +362,7 @@ void jit_avx2_1x1_convolution_fwd_t::execute_forward_thr(const int ithr, conv_dw(); } else { int start {0}, end {0}; - const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + const int work_amount = MB * jcp.ngroups * jcp.nb_bcast; balance211(work_amount, nthr, ithr, start, end); conv_1x1(start, end, 0, jcp.nb_load); } @@ -370,6 +376,11 @@ void jit_avx2_1x1_convolution_bwd_data_t::execute_backward_data( auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS); auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC); + const auto post_ops_binary_rhs_arg_vec + = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx); + + auto MB = CTX_IN_BATCH(DNNL_ARG_DIFF_DST); + const memory_desc_wrapper 
diff_dst_d(pd()->diff_dst_md()); const memory_desc_wrapper weights_d(pd()->weights_md(0)); const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); @@ -392,7 +403,7 @@ void jit_avx2_1x1_convolution_bwd_data_t::execute_backward_data( const int os_block = jcp.bcast_block; const int nb_oc_blocking = jcp.nb_reduce_blocking; - const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + const int work_amount = MB * jcp.ngroups * jcp.nb_bcast; auto step = [](int default_step, int remaining, int tail_step) { assert(default_step <= tail_step); @@ -419,7 +430,7 @@ void jit_avx2_1x1_convolution_bwd_data_t::execute_backward_data( for (int iwork = start; iwork < end; iwork += bcast_step) { int n {0}, g {0}, osb {0}; nd_iterator_init( - iwork, n, jcp.mb, g, jcp.ngroups, osb, jcp.nb_bcast); + iwork, n, MB, g, jcp.ngroups, osb, jcp.nb_bcast); bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb, jcp.nb_bcast_blocking_max); @@ -468,11 +479,15 @@ void jit_avx2_1x1_convolution_bwd_data_t::execute_backward_data( ? weights_d.blk_off(g, ocb, icb) : weights_d.blk_off(ocb, icb)]; - p.first_last_flag = ocb == 0 ? FLAG_REDUCE_FIRST : 0; + p.first_last_flag = 0 | (ocb == 0 ? FLAG_REDUCE_FIRST : 0) + | (ocb + jcp.nb_reduce_blocking >= jcp.nb_reduce ? FLAG_REDUCE_LAST : 0); p.reduce_dim = this_block_size(ocb * jcp.oc_block, jcp.oc, nb_oc_blocking * jcp.oc_block); + p.oc_off = ic_off_idx * (is_dsrc_layout_nxc ? 1 : jcp.ic_block) * sizeof(float); + p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); + (*kernel_)(&p); } diff --git a/src/cpu/x64/jit_avx2_1x1_convolution.hpp b/src/cpu/x64/jit_avx2_1x1_convolution.hpp index e2f48eb504f..1da796429c3 100644 --- a/src/cpu/x64/jit_avx2_1x1_convolution.hpp +++ b/src/cpu/x64/jit_avx2_1x1_convolution.hpp @@ -74,7 +74,11 @@ struct jit_avx2_1x1_convolution_fwd_t : public primitive_t { CHECK(jit_avx2_1x1_conv_kernel_f32::init_conf( jcp_, *conv_d, *src_d, *weights_md(), *dst_md(), *attr())); - if (jcp_.with_dw_conv) CHECK(depthwise_po_init(engine)); + if (jcp_.with_dw_conv) { + // todo: [antonvor] enable when new behavior of dw convolution fusing from oneDNN 1.6 will be supported + return status::unimplemented; + CHECK(depthwise_po_init(engine)); + } auto scratchpad = scratchpad_registry().registrar(); jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_); @@ -278,12 +282,12 @@ struct jit_avx2_1x1_convolution_fwd_t : public primitive_t { if (isa == avx2) { CHECK(safe_ptr_assign(kernel_dw_avx2, new dw_conv_kernel_t( - *(pd()->jcp_dw_), *pd()->dst_md(0)))); + *(pd()->jcp_dw_), *pd()->dst_md(0), *pd()->dw_conv_pd_->attr()))); CHECK(kernel_dw_avx2->create_kernel()); } else { CHECK(safe_ptr_assign(kernel_dw_sse41, new dw_conv_kernel_t( - *(pd()->jcp_dw_), *pd()->dst_md(0)))); + *(pd()->jcp_dw_), *pd()->dst_md(0), *pd()->dw_conv_pd_->attr()))); CHECK(kernel_dw_sse41->create_kernel()); } } @@ -305,7 +309,7 @@ struct jit_avx2_1x1_convolution_fwd_t : public primitive_t { const data_t *bias_dw, data_t *dst, const memory_tracking::grantor_t &scratchpad, const void *post_ops_binary_rhs_arg_vec, - const void *post_ops_binary_rhs_arg_vec_dw) const; + const void *post_ops_binary_rhs_arg_vec_dw, int MB) const; const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } std::unique_ptr kernel_; @@ -334,7 +338,7 @@ struct jit_avx2_1x1_convolution_bwd_data_t : public primitive_t { && set_default_alg_kind(alg_kind::convolution_direct) && expect_data_types(data_type::f32, data_type::f32, data_type::undef, data_type::f32, data_type::f32) - && 
attr()->has_default_values() && !has_zero_dim_memory() + && is_supported_post_ops() && !has_zero_dim_memory() && set_default_formats(); if (!ok) return status::unimplemented; @@ -369,6 +373,23 @@ struct jit_avx2_1x1_convolution_bwd_data_t : public primitive_t { return set_default_formats_common(dat_tag, wei_tag, dat_tag); } + + virtual bool is_supported_post_ops() const { + const auto &p = this->attr()->post_ops_; + if (p.len() > 1) + return false; + + auto all_post_ops_supported = [&]() { + bool ok = true; + + for (int i = 0; i < p.len(); i++) { + ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::depthwise); + } + return ok; + }; + + return all_post_ops_supported(); + } }; template diff --git a/src/cpu/x64/jit_avx2_1x1_convolution_with_dw_conv.cpp b/src/cpu/x64/jit_avx2_1x1_convolution_with_dw_conv.cpp new file mode 100644 index 00000000000..197e1c8e7a9 --- /dev/null +++ b/src/cpu/x64/jit_avx2_1x1_convolution_with_dw_conv.cpp @@ -0,0 +1,251 @@ +/******************************************************************************* +* Copyright 2016-2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* [todo] antonvor: + * This file contains the old plugin behavior in order to fix performance + * problems after upgrading to OneDNN v1.6. This kernel is executed only on + * machines with avx2 instruction set support and in the case of a fused + * convolution. Remove after problems are fixed. +*/ + +#include "common/c_types_map.hpp" +#include "common/dnnl_thread.hpp" +#include "common/type_helpers.hpp" +#include "common/utils.hpp" + +#include "cpu/x64/jit_generator.hpp" + +#include "cpu/x64/jit_avx2_1x1_convolution_with_dw_conv.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace x64 { + +using namespace dnnl::impl::status; +using namespace dnnl::impl::memory_tracking::names; +using namespace dnnl::impl::utils; + +#define data_blk_off(f, n, c, d, h, w) \ +((ndims == 3) ? (f).blk_off(n, c, w) \ + : ((ndims == 4) ? 
(f).blk_off(n, c, h, w) \ + : (f).blk_off(n, c, d, h, w))) +/* convolution forward */ + +void jit_avx2_1x1_convolution_with_dw_conv_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC); + auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const data_t *, DNNL_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST); + + auto weights_dw = CTX_IN_MEM( + const data_t *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS); + auto bias_dw = CTX_IN_MEM( + const data_t *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS); + + const auto &jcp = kernel_old_->jcp; + const auto &jcp_dw = kernel_dw_->jcp; + + const auto post_ops_binary_rhs_arg_vec + = binary_injector::prepare_binary_args(jcp.post_ops, ctx); + const auto post_ops_binary_rhs_arg_vec_dw = binary_injector::prepare_binary_args(jcp_dw.post_ops, ctx, jcp.post_ops.entry_.size() + 1); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + auto scratchpad = ctx.get_scratchpad_grantor(); + + auto rtus_space = pd()->rtus_.reduce_src_ + ? scratchpad.get(key_conv_rtus_space) + : nullptr; + + const int MB = pd()->MB(); + + int ocb_work = jcp.with_dw_conv ? utils::div_up(jcp.nb_load, jcp.nb_load_blocking) : 1; + const int work_amount = MB * jcp.ngroups * ocb_work * jcp.nb_bcast; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? remaining : default_step; + }; + + auto ker = [&](const int ithr, const int nthr) { + // TODO (Roma): remove this restriction + assert(jcp.stride_w == 1 && jcp.stride_h == 1); + + auto compute_block_1x1 = [&](float* ws_p, int n, int g, int oh, int ow, int ih, int iw, int os, int os_block, int bcast_step, int ocb, int load_step, + int num_rows) { + auto rp = rtus_driver_t::call_params_t(); + auto p = jit_1x1_conv_call_s(); + + for (int h = 0; h < num_rows; h++) { + ih = nstl::max((oh + h) * jcp.stride_h - jcp.t_pad, 0); + + if ((oh + h) < 0 || (oh + h) >= jcp.ih) { + for (int chb = ocb; chb < ocb + load_step; chb++) { + memset(ws_p + (((oh + h) + 1) % jcp_dw.kh) * jcp.ow * jcp.oc_block + + (chb - ocb) * jcp_dw.kh * jcp.ow * jcp.oc_block, 0, jcp.ow * jcp.oc_block * sizeof(float)); + } + } else { + const int _ocb = g * jcp.nb_load + ocb; + + rp.iw_start = iw; + p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block); + + rp.os = p.bcast_dim; + p.load_dim = this_block_size(ocb * jcp.oc_block, jcp.oc, load_step * jcp.oc_block); + + p.output_data = &ws_p[(((oh + h) + 1) % jcp_dw.kh) * jcp.ow * jcp.oc_block]; + + p.bias_data = &bias[_ocb * jcp.oc_block]; + + for (int icb = 0; icb < jcp.nb_reduce; icb += jcp.nb_reduce_blocking) { + p.first_last_flag = 0 + | (icb == 0 ? FLAG_REDUCE_FIRST : 0) + | (icb + jcp.nb_reduce_blocking >= jcp.nb_reduce + ? FLAG_REDUCE_LAST : 0); + + p.reduce_dim = this_block_size(icb * jcp.ic_block, jcp.ic, + jcp.nb_reduce_blocking * jcp.ic_block); + rp.icb = p.reduce_dim / jcp.reduce_block; + + p.load_data = &weights[pd()->with_groups() + ? 
weights_d.blk_off(g, ocb, icb) + : weights_d.blk_off(ocb, icb)]; + + const int _icb = g * jcp.nb_reduce + icb; + if (pd()->rtus_.reduce_src_) { + rp.ws = rtus_space + + ithr * pd()->rtus_.space_per_thread_ + + _icb * jcp.is * jcp.ic_block; + + if (ocb == 0) { + rp.src = src + src_d.blk_off(n, _icb, ih, iw); + (*rtus_driver_)(&rp); + } + + p.bcast_data = rp.ws; + } else { + p.bcast_data = src + src_d.blk_off(n, _icb, ih, iw); + } + + p.oc_off = _ocb * jcp.oc_block * sizeof(float); + p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); + + (*kernel_old_)(&p); + } + } + } + }; + + auto compute_row_dw = [&](const float* ws_p, int n, int ocb, int load_step, int dst_idx) { + + for (int chb = ocb; chb < ocb + load_step; chb++) { + auto par_conv_dw = jit_conv_call_s(); + + par_conv_dw.src_row0 = &ws_p[(((dst_idx+1) - 1) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block + + (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block]; + par_conv_dw.src_row1 = &ws_p[(((dst_idx+1) - 0) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block + + (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block]; + par_conv_dw.src_row2 = &ws_p[(((dst_idx+1) + 1) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block + + (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block]; + + par_conv_dw.dst = &dst[n*jcp_dw.oc*jcp_dw.oh*jcp_dw.ow + chb*jcp_dw.ch_block*jcp_dw.oh*jcp_dw.ow + + dst_idx/jcp_dw.stride_h*jcp_dw.ow*jcp_dw.ch_block]; + + par_conv_dw.kh_padding = jcp_dw.kh; + par_conv_dw.filt = &weights_dw[chb * jcp_dw.kh * jcp_dw.kw * jcp_dw.ch_block]; + par_conv_dw.bias = &bias_dw[chb * jcp_dw.ch_block]; + par_conv_dw.ur_w = (size_t)(jcp_dw.ow); + par_conv_dw.oc_work = nstl::min((chb + 1) * jcp_dw.ch_block, (int)jcp_dw.oc) - chb*jcp_dw.ch_block; + par_conv_dw.oc_off = chb * jcp_dw.ch_block * sizeof(float); + par_conv_dw.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec_dw.data(); + + (*kernel_dw_)(&par_conv_dw); + } + }; + + assert(jcp.stride_w == 1 && jcp.stride_h == 1); + + int start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + auto dw_conv_buffer = scratchpad.get(key_dw_conv_buffer); + size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * (jcp.oc / jcp.oc_block); + auto pbuf = dw_conv_buffer + ithr * dw_conv_buffer_size_; + + const int os_block = jcp.iw; + + int iwork = start; + while (iwork < end) { + int n{0}, g{0}, ocbb{0}, osb{0}; + nd_iterator_init(iwork, n, MB, g, jcp.ngroups, ocbb, ocb_work, osb, + jcp.nb_bcast); + int bcast_step = 1; + + const int os = osb * os_block; + const int oh = os / jcp.ow; + const int ow = os % jcp.ow; + + const int ih = nstl::max(oh * jcp.stride_h - jcp.t_pad, 0); + const int iw = nstl::max(ow * jcp.stride_w - jcp.l_pad, 0); + + int ocb = ocbb * jcp.nb_load_blocking; + + const int load_step = step(jcp.nb_load_blocking, + jcp.nb_load - ocb, jcp.nb_load_blocking_max); + + if (iwork == start || oh == 0) { + bcast_step = nstl::min(1, end - iwork); + compute_block_1x1(pbuf, n, g, oh - 1, ow, ih, iw, os, os_block, bcast_step, ocb, load_step, bcast_step + 2); + } else { + bcast_step = nstl::min(1, end - iwork); + compute_block_1x1(pbuf, n, g, oh + 1, ow, ih, iw, os, os_block, bcast_step, ocb, load_step, bcast_step); + } + + if ((oh % jcp_dw.stride_h == 0)) { + compute_row_dw(pbuf, n, ocb, load_step, oh); + } + + iwork += bcast_step; + } + }; + + if (pd()->wants_padded_bias()) { + auto padded_bias = scratchpad.get(key_conv_padded_bias); + utils::array_copy(padded_bias, bias, jcp.oc_without_padding); + utils::array_set(padded_bias + jcp.oc_without_padding, 
0.f, + jcp.oc - jcp.oc_without_padding); + bias = padded_bias; + + auto dw_padded_bias = scratchpad.get(key_dw_conv_padded_bias); + utils::array_copy(dw_padded_bias, bias_dw, jcp.oc_without_padding); + utils::array_set(dw_padded_bias + jcp.oc_without_padding, 0.f, + jcp.oc - jcp.oc_without_padding); + bias_dw = dw_padded_bias; + } + + parallel(0, ker); + + if (pd()->wants_zero_pad_dst()) ctx.zero_pad_output(DNNL_ARG_DST); +} + +} // namespace x64 +} // namespace cpu +} // namespace impl +} // namespace dnnl diff --git a/src/cpu/x64/jit_avx2_1x1_convolution_with_dw_conv.hpp b/src/cpu/x64/jit_avx2_1x1_convolution_with_dw_conv.hpp new file mode 100644 index 00000000000..3e9426252ae --- /dev/null +++ b/src/cpu/x64/jit_avx2_1x1_convolution_with_dw_conv.hpp @@ -0,0 +1,194 @@ +/******************************************************************************* +* Copyright 2016-2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* [todo] antonvor: + * This file contains the old plugin behavior in order to fix performance + * problems after upgrading to OneDNN v1.6. This kernel is executed only on + * machines with avx2 instruction set support and in the case of a fused + * convolution. Remove after problems are fixed. +*/ + +#ifndef CPU_X64_JIT_AVX2_1X1_CONVOLUTION_WITH_DW_CONV_HPP +#define CPU_X64_JIT_AVX2_1X1_CONVOLUTION_WITH_DW_CONV_HPP + +#include "common/c_types_map.hpp" +#include "common/dnnl_thread.hpp" +#include "common/memory_tracking.hpp" +#include "common/primitive.hpp" +#include "common/primitive_hashing.hpp" +#include "common/utils.hpp" + +#include "cpu/cpu_convolution_pd.hpp" +#include "cpu/dw_convolution_utils.hpp" +#include "cpu/platform.hpp" + +#include "cpu/x64/cpu_reducer.hpp" +#include "cpu/x64/jit_avx2_1x1_conv_kernel_f32.hpp" +#include "cpu/x64/jit_uni_1x1_conv_utils.hpp" +#include "cpu/x64/jit_uni_dw_convolution.hpp" + +#include "cpu/x64/jit_avx2_1x1_conv_kernel_f32_old.hpp" +#include "cpu/x64/jit_uni_dw_conv_row_f32.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace x64 { + +struct jit_avx2_1x1_convolution_with_dw_conv_fwd_t : public primitive_t { + // TODO: (Roma) Code duplication duplication! Remove with templates + // (maybe...)! 
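// [Editor's note -- sketch of the fused execution scheme, not part of the
// original patch; see execute_forward() in the .cpp above.] The primitive
// never materializes the full 1x1 output: per thread, the 1x1 kernel
// writes output rows into a ring buffer of jcp_dw_.kh rows (row r lands in
// slot (r + 1) % kh), and once the kh rows a depthwise output row needs
// are resident, the row-wise dw kernel consumes them. Roughly:
//
//     for each (mb, group, ocb, oh) work item:
//         compute_block_1x1(ring, oh, ...);   // fill/refresh ring rows
//         if (oh % jcp_dw_.stride_h == 0)
//             compute_row_dw(ring, oh, ...);  // kh-row depthwise stencil
//
// This keeps the per-thread working set at kh * iw * ch_block floats per
// output-channel block (the dw_conv_buffer booked in init_scratchpad).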
+ struct pd_t : public cpu_convolution_fwd_pd_t { + pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd) + , jcp_(), jcp_dw_(), rtus_() {} + + pd_t(const pd_t &other) : cpu_convolution_fwd_pd_t(other) { + if (copy(other) != status::success) is_initialized_ = false; + } + + DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1_with_dw_conv:", avx2, ""), + jit_avx2_1x1_convolution_with_dw_conv_fwd_t); + + status_t init(engine_t *engine) { + using namespace prop_kind; + assert(engine->kind() == engine_kind::cpu); + bool ok = true + && this->set_default_formats() + && utils::one_of(this->desc()->prop_kind, forward_training, + forward_inference) + && utils::one_of(this->desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_direct) + && !this->has_zero_dim_memory() + && utils::everyone_is(data_type::f32, + this->desc()->src_desc.data_type, + this->desc()->weights_desc.data_type, + this->desc()->dst_desc.data_type) + && IMPLICATION(this->with_bias(), + data_type::f32 == this->desc()->bias_desc.data_type); + if (!ok) return status::unimplemented; + + const convolution_desc_t *conv_d = this->desc(); + const memory_desc_t *src_d = src_md(); + rtus_prepare(this, conv_d, src_d, dst_md(), weights_md()); + + status_t sts_1x1 = jit_avx2_1x1_conv_kernel_f32_old::init_conf( + jcp_, *conv_d, *src_d, *weights_md(), *dst_md(), *attr()); + if (sts_1x1 != status::success) return sts_1x1; + + if (jcp_.with_dw_conv) { + status_t sts_dw = jit_uni_dw_conv_row_f32::init_conf(jcp_, jcp_dw_, *this->attr()); + if (sts_dw != status::success) return sts_dw; + } else { + return status::unimplemented; + } + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx2_1x1_conv_kernel_f32_old::init_scratchpad(scratchpad, jcp_, jcp_dw_); + + rtus_prepare_space_info(this, scratchpad, dnnl_get_max_threads()); + + return status::success; + } + + const memory_desc_t *dst_md(int index = 0) const override { + return &dst_md_; + } + + const memory_desc_t *arg_md(int index = 0) const override { + return convolution_fwd_pd_t::arg_md(index); + } + + arg_usage_t arg_usage(int arg) const override { + return convolution_fwd_pd_t::arg_usage(arg); + } + + jit_1x1_conv_conf_t jcp_; + jit_conv_conf_t jcp_dw_; + reduce_to_unit_stride_t rtus_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? 
utils::pick(ndims() - 3, gOIw8i8o, gOIhw8i8o, gOIdhw8i8o)
+                    : utils::pick(ndims() - 3, OIw8i8o, OIhw8i8o, OIdhw8i8o);
+
+            return set_default_formats_common(dat_tag, wei_tag, dat_tag);
+        }
+
+        status_t copy(const pd_t &other) {
+            jcp_ = other.jcp_;
+            rtus_ = other.rtus_;
+            jcp_dw_ = other.jcp_dw_;
+
+            return status::success;
+        }
+    };
+
+    template <cpu_isa_t isa, typename conv_t>
+    friend status_t init_rtus_driver(conv_t *self);
+
+    jit_avx2_1x1_convolution_with_dw_conv_fwd_t(const pd_t *apd) : primitive_t(apd),
+            kernel_old_(nullptr), rtus_driver_(nullptr) {
+        kernel_old_ = new jit_avx2_1x1_conv_kernel_f32_old(pd()->jcp_, pd()->jcp_dw_, *pd()->attr());
+        init_rtus_driver<avx2>(this);
+
+        if (pd()->jcp_.with_dw_conv) {
+            kernel_dw_ = new jit_uni_dw_conv_row_f32<avx2>(pd()->jcp_dw_, *pd()->attr(), pd()->jcp_dw_.ch_block);
+        }
+    }
+
+    status_t init(engine_t *engine) override {
+        CHECK(kernel_old_->create_kernel());
+        if (kernel_dw_)
+            CHECK(kernel_dw_->create_kernel());
+        return status::success;
+    }
+
+    ~jit_avx2_1x1_convolution_with_dw_conv_fwd_t() {
+        delete kernel_old_;
+        delete rtus_driver_;
+        delete kernel_dw_;
+    }
+
+    typedef typename prec_traits<data_type::f32>::type data_t;
+
+    status_t execute(const exec_ctx_t &ctx) const override {
+        execute_forward(ctx);
+        return status::success;
+    }
+
+private:
+    void execute_forward(const exec_ctx_t &ctx) const;
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
+
+    jit_avx2_1x1_conv_kernel_f32_old *kernel_old_;
+    jit_uni_dw_conv_row_f32<avx2> *kernel_dw_;
+    rtus_driver_t<avx2> *rtus_driver_;
+
+};
+
+} // namespace x64
+} // namespace cpu
+} // namespace impl
+} // namespace dnnl
+
+#endif
diff --git a/src/cpu/x64/jit_avx2_conv_kernel_f32.cpp b/src/cpu/x64/jit_avx2_conv_kernel_f32.cpp
index dcc56f1b582..00e974129fa 100644
--- a/src/cpu/x64/jit_avx2_conv_kernel_f32.cpp
+++ b/src/cpu/x64/jit_avx2_conv_kernel_f32.cpp
@@ -47,7 +47,7 @@ jit_avx2_conv_fwd_kernel_f32::jit_avx2_conv_fwd_kernel_f32(
     : jit_generator(nullptr, MAX_CODE_SIZE, true, avx2)
     , jcp(ajcp)
     , attr_(attr) {
-    if (jcp.with_eltwise || jcp.with_binary) {
+    if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) {
         using namespace binary_injector;
         static constexpr bool preserve_gpr = true;
         static constexpr bool preserve_vmm = false;
@@ -61,10 +61,12 @@ jit_avx2_conv_fwd_kernel_f32::jit_avx2_conv_fwd_kernel_f32(
                 memory_desc_wrapper(dst_md), tail_size,
                 use_exact_tail_scalar_bcast};
         static_params_t static_params {this->param1, rhs_arg_static_params};
+        quantization_injector::static_params_t quantization_static_params
+                {ymm_d_weights.getIdx(), ymm_d_bias.getIdx(), reg_d_weights, reg_d_bias};
 
         postops_injector_ = utils::make_unique<
                 injector::jit_uni_postops_injector_t<avx2>>(
-                this, jcp.post_ops, static_params);
+                this, jcp.post_ops, static_params, quantization_static_params);
     }
 }
 
@@ -201,11 +203,21 @@ void iterate(const int load_loop_blk, const int ur, const F &f) {
 
 void jit_avx2_conv_fwd_kernel_f32::apply_postops(
         const int oc_blocks, const int ur_w, const int oc_tail) {
-    if (jcp.with_eltwise || jcp.with_binary) {
+    if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) {
         Label regular_store;
         test(reg_ci_flag, FLAG_IC_LAST);
         je(regular_store, T_NEAR);
 
+        std::map<size_t, int> vmm_idx_off;
+        iterate(oc_blocks, ur_w, [&](const bool, const int i, const int j) {
+            vmm_idx_off.insert({get_ymm_idx(ur_w, i, j), i * jcp.oc_block * sizeof(float)});
+        });
+        depthwise_injector::dynamic_params_t ddp {ymm_d_weights.getIdx(), ymm_d_bias.getIdx(), reg_d_weights, reg_d_bias,
+                ptr[this->param1 +
GET_OFF(oc_off)], vmm_idx_off, + this->rsp, base_post_ops_data_offset}; + quantization_injector::dynamic_params_t qdp {ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off, jcp.dst_dt, + this->rsp, base_post_ops_data_offset}; + injector_utils::vmm_index_set_t vmm_idxs; if (jcp.with_binary) { binary_injector::rhs_arg_dynamic_params_t rhs_arg_params, @@ -237,14 +249,14 @@ void jit_avx2_conv_fwd_kernel_f32::apply_postops( jmp(postops_done, T_NEAR); L(postops_no_tail); } - postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params); + postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params, ddp, qdp); L(postops_done); } else { iterate(oc_blocks, ur_w, [&](const bool, const int i, const int j) { vmm_idxs.emplace(get_ymm_idx(ur_w, i, j)); }); - postops_injector_->compute_vector_range(vmm_idxs); + postops_injector_->compute_vector_range(vmm_idxs, binary_injector::rhs_arg_dynamic_params_t(), ddp, qdp); } L(regular_store); } @@ -258,6 +270,7 @@ void jit_avx2_conv_fwd_kernel_f32::width_blk_step( if (oc_tail) { push(reg_oc_blocks); + base_post_ops_data_offset += reg64_size; mov(reg_oc_flag, ptr[param1 + GET_OFF(oc_flag)]); } @@ -478,7 +491,10 @@ void jit_avx2_conv_fwd_kernel_f32::width_blk_step( L(store_done); } - if (oc_tail) pop(reg_oc_blocks); + if (oc_tail) { + pop(reg_oc_blocks); + base_post_ops_data_offset -= reg64_size; + } } inline void jit_avx2_conv_fwd_kernel_f32::solve_common(int oc_blocks) { @@ -533,6 +549,9 @@ inline void jit_avx2_conv_fwd_kernel_f32::solve_common(int oc_blocks) { void jit_avx2_conv_fwd_kernel_f32::generate() { this->preamble(); + if (postops_injector_) + postops_injector_->push_post_ops_data_on_stack(this->param1, GET_OFF(post_ops_binary_rhs_arg_vec), reg_input, reg_output); + mov(reg_input, ptr[this->param1 + GET_OFF(src)]); mov(reg_output, ptr[this->param1 + GET_OFF(dst)]); mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); @@ -569,6 +588,9 @@ void jit_avx2_conv_fwd_kernel_f32::generate() { solve_common(nb_oc_tail); } + if (postops_injector_) + postops_injector_->reset_stack_pointer(); + this->postamble(); if (jcp.with_eltwise) postops_injector_->prepare_table(); @@ -626,10 +648,10 @@ status_t jit_avx2_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh); jcp.back_pad = calculate_end_padding( jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd); - bool kernel_outside_src = false || ext_kw <= jcp.l_pad - || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad || ext_kh <= jcp.b_pad - || ext_kd <= jcp.f_pad || ext_kd <= jcp.back_pad; - if (kernel_outside_src) return status::unimplemented; +// bool kernel_outside_src = false || ext_kw <= jcp.l_pad +// || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad || ext_kh <= jcp.b_pad +// || ext_kd <= jcp.f_pad || ext_kd <= jcp.back_pad; +// if (kernel_outside_src) return status::unimplemented; const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc); const auto dat_tag_ncx = pick(ndims - 3, ncw, nchw, ncdhw); @@ -641,9 +663,9 @@ status_t jit_avx2_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, : pick(ndims - 3, Owi8o, Ohwi8o, Odhwi8o); jcp.src_tag - = src_d.matches_one_of_tag(dat_tag_ncx, dat_tag_nxc, dat_tag_nCx8c); + = src_d.mb_stride_relaxed_match(dat_tag_ncx, dat_tag_nxc, dat_tag_nCx8c); jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag_OIxio, wei_tag_Oxio); - jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c); + jcp.dst_tag = dst_d.mb_stride_relaxed_match(dat_tag_nxc, dat_tag_nCx8c); jcp.typesize_in = types::data_type_size(src_d.data_type()); jcp.typesize_out = 
types::data_type_size(dst_d.data_type()); @@ -665,6 +687,8 @@ status_t jit_avx2_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, jcp.with_eltwise = eltwise_ind != -1; const int binary_ind = post_ops.find(primitive_kind::binary); jcp.with_binary = binary_ind != -1; + jcp.with_depthwise = post_ops.find(primitive_kind::depthwise) != -1; + jcp.with_quantization = post_ops.find(primitive_kind::quantization) != -1; jcp.post_ops = post_ops; @@ -685,14 +709,14 @@ status_t jit_avx2_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, if (mimo) jcp.ic = rnd_up(jcp.ic, simd_w); } - if (jcp.with_eltwise || jcp.with_binary) + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) if (!mayiuse(avx2)) return status::unimplemented; using namespace injector; static constexpr bool sum_at_pos_0_only = true; static constexpr bool sum_requires_scale_one = true; static constexpr bool sum_requires_zp_zero = true; - const bool post_ops_ok_ = post_ops_ok({avx2, {eltwise, binary, sum}, + const bool post_ops_ok_ = post_ops_ok({avx2, {eltwise, binary, sum, depthwise, quantization}, jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one, sum_requires_zp_zero}); if (!post_ops_ok_) return status::unimplemented; @@ -839,6 +863,7 @@ void jit_avx2_conv_bwd_data_kernel_f32::compute_loop( } if (oc_tail) { + base_post_ops_data_offset += reg64_size; push(reg_long_offt); mov(reg_reduce_work, ptr[param1 + GET_OFF(reduce_work)]); } @@ -858,6 +883,7 @@ void jit_avx2_conv_bwd_data_kernel_f32::compute_loop( if (jcp.ndims == 5) { assert(jcp.nb_oc_blocking == 1); + base_post_ops_data_offset += reg64_size; push(oi_iter); mov(reg_ki, ptr[this->param1 + GET_OFF(kd_padding)]); @@ -949,6 +975,7 @@ void jit_avx2_conv_bwd_data_kernel_f32::compute_loop( L(skip_kd_loop); pop(oi_iter); + base_post_ops_data_offset -= reg64_size; } if (one_of(jcp.ndims, 3, 4)) { @@ -967,11 +994,14 @@ void jit_avx2_conv_bwd_data_kernel_f32::compute_loop( mov(reg_channel, ptr[param1 + GET_OFF(channel)]); } - if (oc_tail) pop(reg_long_offt); + if (oc_tail) { + pop(reg_long_offt); + base_post_ops_data_offset -= reg64_size; + } auto load_store_dsrc = [=](bool is_tail) { mov(reg_channel, ptr[param1 + GET_OFF(channel)]); - Label no_update_label; + Label no_update_label, skip_post_ops; cmp(reg_channel, 0); je(no_update_label, T_NEAR); @@ -987,8 +1017,37 @@ void jit_avx2_conv_bwd_data_kernel_f32::compute_loop( vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), Ymm(15)); } + jmp(skip_post_ops, T_NEAR); + L(no_update_label); + const auto &p = attr_.post_ops_; + std::size_t post_ops_data_offset = 0; + int depthwise_inj_idx = 0; + for (int i = 0; i < p.len(); i++) { + auto& post_op = p.entry_[i]; + if (post_op.is_depthwise()) { + base_post_ops_data_offset += reg64_size; + push(reg_d_weights); + + mov(reg_d_weights, ptr[this->rsp + base_post_ops_data_offset + post_ops_data_offset]); + add(reg_d_weights, ptr[this->param1 + GET_OFF(ic_off)]); + + for (int ii = 0; ii < nb_ic_block; ii++) { + depthwise_injectors[depthwise_inj_idx]->compute_vector_range( + ur_w * ii, ur_w * ii + ur_w, reg_d_weights, reg_d_weights); + + add(reg_d_weights, jcp.ic_block * sizeof(float)); + } + pop(reg_d_weights); + base_post_ops_data_offset -= reg64_size; + + post_ops_data_offset += depthwise_injectors[depthwise_inj_idx]->memoryStep(); + depthwise_inj_idx++; + } + } + L(skip_post_ops); + for (int ii = 0; ii < nb_ic_block; ii++) for (int jj = 0; jj < ur_w; jj++) { if (is_tail && ii == nb_ic_block - 1) @@ -1020,8 +1079,39 @@ void 
jit_avx2_conv_bwd_data_kernel_f32::compute_loop( } void jit_avx2_conv_bwd_data_kernel_f32::generate() { + const auto &p = attr_.post_ops_; + for (int i = 0; i < p.len(); i++) { + auto &post_op = p.entry_[i]; + if (post_op.is_depthwise()) { + depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32( + this, + post_op + )); + } + } + preamble(); + std::size_t post_ops_pointers_count = 0; + for (int i = 0; i < p.len(); i++) { + if (p.entry_[i].is_depthwise() || p.entry_[i].is_quantization()) { + post_ops_pointers_count++; + } + } + + if (post_ops_pointers_count != 0) { + sub(rsp, post_ops_pointers_count * sizeof(float *)); + + auto aux_reg0 = reg_dsrc; + auto aux_reg1 = reg_ddst; + + mov(aux_reg0, ptr[this->param1 + GET_OFF(post_ops_binary_rhs_arg_vec)]); + for (size_t i = 0; i < post_ops_pointers_count; i++) { + mov(aux_reg1, ptr[aux_reg0 + i * sizeof(float *)]); + mov(ptr[rsp + i * sizeof(float *)], aux_reg1); + } + } + mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]); mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]); mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); @@ -1081,13 +1171,35 @@ void jit_avx2_conv_bwd_data_kernel_f32::generate() { if (jcp.ur_w_tail != 0) compute_loop(jcp.ur_w_tail, 0, r_overflow); } + if (post_ops_pointers_count != 0) { + add(rsp, post_ops_pointers_count * sizeof(float *)); + } + this->postamble(); } +bool jit_avx2_conv_bwd_data_kernel_f32::post_ops_ok(const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + if (p.len() > 1) + return false; + + auto all_post_ops_supported = [&]() { + bool ok = true; + + for (int i = 0; i < p.len(); i++) { + ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::depthwise); + } + return ok; + }; + + return all_post_ops_supported(); +} + status_t jit_avx2_conv_bwd_data_kernel_f32::init_conf(jit_conv_conf_t &jcp, const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d, const memory_desc_wrapper &weights_d, - const memory_desc_wrapper &diff_dst_d) { + const memory_desc_wrapper &diff_dst_d, + const primitive_attr_t &attr) { if (!mayiuse(avx2)) return status::unimplemented; jcp.nthr = dnnl_get_max_threads(); @@ -1134,6 +1246,20 @@ status_t jit_avx2_conv_bwd_data_kernel_f32::init_conf(jit_conv_conf_t &jcp, const int simd_w = 8; + if (!post_ops_ok(attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + if (!mayiuse(avx2)) { + for (int i = 0; i < p.len(); i++) { + auto &post_op = p.entry_[i]; + if (post_op.is_depthwise()) { + return status::unimplemented; + } + } + } + jcp.post_ops = p; + /* derivatives */ jcp.idp = jcp.id + 2 * jcp.f_pad; jcp.ihp = jcp.ih + 2 * jcp.t_pad; @@ -1147,8 +1273,8 @@ status_t jit_avx2_conv_bwd_data_kernel_f32::init_conf(jit_conv_conf_t &jcp, ? 
pick(ndims - 3, gOIw8o8i, gOIhw8o8i, gOIdhw8o8i) : pick(ndims - 3, OIw8o8i, OIhw8o8i, OIdhw8o8i); - jcp.src_tag = diff_src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c); - jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c); + jcp.src_tag = diff_src_d.mb_stride_relaxed_match(dat_tag_nxc, dat_tag_nCx8c); + jcp.dst_tag = diff_dst_d.mb_stride_relaxed_match(dat_tag_nxc, dat_tag_nCx8c); jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); jcp.typesize_in = types::data_type_size(diff_src_d.data_type()); diff --git a/src/cpu/x64/jit_avx2_conv_kernel_f32.hpp b/src/cpu/x64/jit_avx2_conv_kernel_f32.hpp index 8cc89dad00c..7078ae38a56 100644 --- a/src/cpu/x64/jit_avx2_conv_kernel_f32.hpp +++ b/src/cpu/x64/jit_avx2_conv_kernel_f32.hpp @@ -80,6 +80,14 @@ struct jit_avx2_conv_fwd_kernel_f32 : public jit_generator { Xbyak::Ymm ytmp = Xbyak::Ymm(14); + reg64_t reg_d_weights = imm_addr64; + reg64_t reg_d_bias = ki_iter; + + Xbyak::Ymm ymm_d_weights = Xbyak::Ymm(14); + Xbyak::Ymm ymm_d_bias = Xbyak::Ymm(15); + int base_post_ops_data_offset = 0; + constexpr static int reg64_size = 8; + inline void oh_step_unroll_kw( int ur_w, int pad_l, int pad_r, int oc_blocks); inline void oh_step_nopad(int ur_w, int pad_l, int pad_r, int oc_blocks); @@ -143,17 +151,26 @@ struct jit_avx2_conv_fwd_kernel_f32 : public jit_generator { struct jit_avx2_conv_bwd_data_kernel_f32 : public jit_generator { DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_conv_bwd_data_kernel_f32) - jit_avx2_conv_bwd_data_kernel_f32(const jit_conv_conf_t &ajcp) - : jcp(ajcp) {} + jit_avx2_conv_bwd_data_kernel_f32(const jit_conv_conf_t &ajcp, const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr) {} + + ~jit_avx2_conv_bwd_data_kernel_f32() { + for (auto inj : depthwise_injectors) + delete inj; + depthwise_injectors.clear(); + } + static bool post_ops_ok(const primitive_attr_t &attr); static status_t init_conf(jit_conv_conf_t &jcp, const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d, const memory_desc_wrapper &weights_d, - const memory_desc_wrapper &diff_dst_d); + const memory_desc_wrapper &diff_dst_d, + const primitive_attr_t &attr); static void init_scratchpad(memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp); jit_conv_conf_t jcp; + const primitive_attr_t &attr_; private: using reg64_t = const Xbyak::Reg64; @@ -180,6 +197,13 @@ struct jit_avx2_conv_bwd_data_kernel_f32 : public jit_generator { reg64_t reg_reduce_work = reg_long_offt; Xbyak::Reg32 reg_ci_flag = r13d; // used for nxc tails + reg64_t reg_d_weights = r15; + reg64_t reg_d_bias = rbp; + int base_post_ops_data_offset = 0; + constexpr static int reg64_size = 8; + + nstl::vector*> depthwise_injectors; + inline void compute_loop(int ur_w, int l_overflow, int r_overflow); void generate() override; diff --git a/src/cpu/x64/jit_avx2_convolution.cpp b/src/cpu/x64/jit_avx2_convolution.cpp index 895a63b5e52..8a90a4818ff 100644 --- a/src/cpu/x64/jit_avx2_convolution.cpp +++ b/src/cpu/x64/jit_avx2_convolution.cpp @@ -53,6 +53,8 @@ void jit_avx2_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { const auto post_ops_binary_rhs_arg_vec = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx); + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + const memory_desc_wrapper src_d(pd()->src_md()); const memory_desc_wrapper dst_d(pd()->dst_md()); const memory_desc_wrapper weights_d(pd()->weights_md(0)); @@ -60,7 +62,7 @@ void jit_avx2_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { const size_t ocb_work = div_up(jcp.nb_oc, 
jcp.nb_oc_blocking); const size_t work_amount - = jcp.mb * jcp.ngroups * ocb_work * jcp.od * jcp.oh; + = MB * jcp.ngroups * ocb_work * jcp.od * jcp.oh; auto ker = [&](const int ithr, const int nthr) { size_t start {0}, end {0}; @@ -84,7 +86,7 @@ void jit_avx2_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { if (icb_step_rem < jcp.nb_ic_blocking_max) icb_step = icb_step_rem; size_t n {0}, g {0}, ocbb {0}, oh {0}, od {0}; - nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work, + nd_iterator_init(start, n, MB, g, jcp.ngroups, ocbb, ocb_work, od, jcp.od, oh, jcp.oh); for (size_t iwork = start; iwork < end; ++iwork) { int ocb = ocbb * jcp.nb_oc_blocking; @@ -139,7 +141,7 @@ void jit_avx2_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { par_conv.flags |= FLAG_IC_FIRST; } - if ((jcp.with_eltwise || jcp.with_binary) + if ((jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) && icb + 1 == jcp.nb_ic) par_conv.flags |= FLAG_IC_LAST; @@ -166,10 +168,11 @@ void jit_avx2_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { par_conv.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); par_conv.dst_orig = dst; + par_conv.oc_off = _oc * oc_bias_scale * sizeof(float); (*kernel_)(&par_conv); } - nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work, od, + nd_iterator_step(n, MB, g, jcp.ngroups, ocbb, ocb_work, od, jcp.od, oh, jcp.oh); } icbb += icb_step; @@ -195,6 +198,10 @@ void jit_avx2_convolution_bwd_data_t::execute_backward_data( auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST); auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS); auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC); + const auto post_ops_binary_rhs_arg_vec + = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx); + + auto MB = CTX_IN_BATCH(DNNL_ARG_DIFF_DST); const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); @@ -205,7 +212,7 @@ void jit_avx2_convolution_bwd_data_t::execute_backward_data( int icb_work = jcp.nb_ic / jcp.nb_ic_blocking; int ih_block_size = jcp.ih; int num_ih_blocks = utils::div_up(jcp.ih, ih_block_size); - size_t work_amount = jcp.mb * jcp.ngroups * icb_work * num_ih_blocks; + size_t work_amount = MB * jcp.ngroups * icb_work * num_ih_blocks; const auto data_size = sizeof(data_t); const auto L2 = platform::get_per_core_cache_size(2) / data_size; @@ -244,7 +251,7 @@ void jit_avx2_convolution_bwd_data_t::execute_backward_data( balance211(work_amount, nthr, ithr, start, end); size_t n {0}, g {0}, icbb {0}, ihb {0}; - nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, icbb, icb_work, ihb, + nd_iterator_init(start, n, MB, g, jcp.ngroups, icbb, icb_work, ihb, num_ih_blocks); for (size_t iwork = start; iwork < end; ++iwork) { for_(int oc = 0; oc < jcp.nb_oc; oc += jcp.nb_oc_blocking) @@ -336,10 +343,14 @@ void jit_avx2_convolution_bwd_data_t::execute_backward_data( par_conv.flags |= FLAG_IC_LAST; } + par_conv.ic_off = (g * jcp.nb_ic + jcp.nb_ic_blocking * icbb) * jcp.ic_block * sizeof(float); + par_conv.post_ops_binary_rhs_arg_vec + = post_ops_binary_rhs_arg_vec.data(); + (*kernel_)(&par_conv); } } - nd_iterator_step(n, jcp.mb, g, jcp.ngroups, icbb, icb_work, ihb, + nd_iterator_step(n, MB, g, jcp.ngroups, icbb, icb_work, ihb, num_ih_blocks); } }; diff --git a/src/cpu/x64/jit_avx2_convolution.hpp b/src/cpu/x64/jit_avx2_convolution.hpp index 2fe616f300a..08ed3015484 100644 --- a/src/cpu/x64/jit_avx2_convolution.hpp +++ 
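
// [Editorial sketch] balance211(work_amount, nthr, ithr, start, end) hands
// each thread one contiguous chunk, with chunk sizes differing by at most one.
// A sketch of that split, under the assumption that it matches oneDNN's
// balance211 semantics:
#include <cstddef>

static void balance_sketch(size_t amount, size_t nthr, size_t ithr,
        size_t &start, size_t &end) {
    size_t base = amount / nthr, rem = amount % nthr;
    start = ithr * base + (ithr < rem ? ithr : rem);
    end = start + base + (ithr < rem ? 1 : 0);
}
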
b/src/cpu/x64/jit_avx2_convolution.hpp @@ -121,12 +121,12 @@ struct jit_avx2_convolution_bwd_data_t : public primitive_t { && set_default_alg_kind(alg_kind::convolution_direct) && expect_data_types(data_type::f32, data_type::f32, data_type::undef, data_type::f32, data_type::f32) - && attr()->has_default_values() && !has_zero_dim_memory() + && !has_zero_dim_memory() && set_default_formats(); if (!ok) return status::unimplemented; status_t status = jit_avx2_conv_bwd_data_kernel_f32::init_conf(jcp_, - *desc(), *diff_src_md(), *weights_md(), *diff_dst_md()); + *desc(), *diff_src_md(), *weights_md(), *diff_dst_md(), *attr()); if (status != status::success) return status; auto scratchpad = scratchpad_registry().registrar(); @@ -157,7 +157,7 @@ struct jit_avx2_convolution_bwd_data_t : public primitive_t { status_t init(engine_t *engine) override { CHECK(safe_ptr_assign( - kernel_, new jit_avx2_conv_bwd_data_kernel_f32(pd()->jcp_))); + kernel_, new jit_avx2_conv_bwd_data_kernel_f32(pd()->jcp_, *pd()->attr()))); return kernel_->create_kernel(); } diff --git a/src/cpu/x64/jit_avx512_common_1x1_conv_kernel.cpp b/src/cpu/x64/jit_avx512_common_1x1_conv_kernel.cpp index 340abc75873..279bb6de084 100644 --- a/src/cpu/x64/jit_avx512_common_1x1_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_common_1x1_conv_kernel.cpp @@ -51,7 +51,7 @@ jit_avx512_common_1x1_conv_kernel::jit_avx512_common_1x1_conv_kernel( const jit_1x1_conv_conf_t &ajcp, const primitive_attr_t &attr, const memory_desc_t &dst_md) : jcp(ajcp), attr_(attr) { - if (jcp.with_eltwise || jcp.with_binary) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { using namespace binary_injector; static constexpr bool preserve_gpr = true; static constexpr bool preserve_vmm = false; @@ -66,10 +66,12 @@ jit_avx512_common_1x1_conv_kernel::jit_avx512_common_1x1_conv_kernel( use_exact_tail_scalar_bcast}; const static_params_t static_params { this->param1, rhs_arg_static_params}; + quantization_injector::static_params_t quantization_static_params + {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; postops_injector_ = utils::make_unique< injector::jit_uni_postops_injector_t>( - this, jcp.post_ops, static_params); + this, jcp.post_ops, static_params, quantization_static_params); } } @@ -215,6 +217,16 @@ static void iterate(const int load_loop_blk, const int ur, const F &fun) { void jit_avx512_common_1x1_conv_kernel::apply_postops( const bool is_out_layout_nxc, const int load_loop_blk, const int ur) { + std::map vmm_idx_off; + iterate(load_loop_blk, ur, + [&](const bool, const int i_load, const int i_ur) { + vmm_idx_off.insert({vreg_accum_idx(load_loop_blk, i_load, i_ur), i_load * jcp.load_block * sizeof(float)}); + }); + + depthwise_injector::dynamic_params_t ddp {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + reg_oc_off, vmm_idx_off, this->rsp, base_post_ops_data_offset}; + quantization_injector::dynamic_params_t qdp {reg_oc_off, vmm_idx_off, jcp.dst_dt, this->rsp, base_post_ops_data_offset}; + injector_utils::vmm_index_set_t vmm_idxs; if (jcp.with_binary) { binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; @@ -235,14 +247,14 @@ void jit_avx512_common_1x1_conv_kernel::apply_postops( mov(abi_param1, ptr[rsp + reg_abi_param1_backup]); - postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params); + postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params, ddp, qdp); } else { iterate(load_loop_blk, ur, [&](const bool, const int i_load, const int i_ur) { 
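
// [Editorial sketch] The std::map being filled in apply_postops below pairs
// each accumulator vmm index with the byte offset of the output channels it
// holds, so the depthwise/quantization injectors can fetch per-channel weights
// at the right position. A standalone model (the vmm index formula is a
// stand-in for vreg_accum_idx):
#include <map>

static std::map<int, int> make_vmm_idx_off(int load_loop_blk, int ur,
        int load_block /* channels held per vmm */) {
    std::map<int, int> vmm_idx_off;
    for (int i_load = 0; i_load < load_loop_blk; i_load++)
        for (int i_ur = 0; i_ur < ur; i_ur++) {
            int vmm_idx = i_ur * load_loop_blk + i_load;
            vmm_idx_off.insert(
                    {vmm_idx, i_load * load_block * (int)sizeof(float)});
        }
    return vmm_idx_off;
}
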
vmm_idxs.emplace( vreg_accum_idx(load_loop_blk, i_load, i_ur)); }); - postops_injector_->compute_vector_range(vmm_idxs); + postops_injector_->compute_vector_range(vmm_idxs, binary_injector::rhs_arg_dynamic_params_t(), ddp, qdp); } } @@ -362,7 +374,7 @@ void jit_avx512_common_1x1_conv_kernel::reduce_loop( L(store_noadd); - if (jcp.with_eltwise || jcp.with_binary) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { Label store_nopostops; test(reg_reduce_pos_flag, FLAG_REDUCE_LAST); jz(store_nopostops, T_NEAR); @@ -591,7 +603,12 @@ void jit_avx512_common_1x1_conv_kernel::reduce_loop( void jit_avx512_common_1x1_conv_kernel::generate() { preamble(); + if (postops_injector_) + postops_injector_->push_post_ops_data_on_stack(this->param1, GET_OFF(post_ops_binary_rhs_arg_vec), reg_bcast_data, reg_load_data); + sub(rsp, stack_space_needed); + base_post_ops_data_offset += stack_space_needed; + if (jcp.with_binary) { const auto zeroed_reg = r15; xor_(zeroed_reg, zeroed_reg); @@ -612,6 +629,7 @@ void jit_avx512_common_1x1_conv_kernel::generate() { mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]); if (jcp.prop_kind == backward_weights) mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]); + mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]); const int load_dim_tail = (one_of(jcp.prop_kind, forward_training, forward_inference) @@ -636,6 +654,7 @@ void jit_avx512_common_1x1_conv_kernel::generate() { } bcast_loop(load_loop_blk); add(reg_load_data, load_loop_blk * jcp.load_loop_load_step); + add(reg_oc_off, load_loop_blk * jcp.oc_block * sizeof(float)); switch (jcp.prop_kind) { case forward_training: case forward_inference: @@ -731,6 +750,10 @@ void jit_avx512_common_1x1_conv_kernel::generate() { L(load_loop_blk[num_ur_cases]); add(rsp, stack_space_needed); + base_post_ops_data_offset -= stack_space_needed; + + if (postops_injector_) + postops_injector_->reset_stack_pointer(); postamble(); @@ -804,6 +827,8 @@ status_t jit_avx512_common_1x1_conv_kernel::init_conf(jit_1x1_conv_conf_t &jcp, const int binary_ind = post_ops.find(primitive_kind::binary, 0, dw_conv_ind); jcp.with_binary = binary_ind != -1; + jcp.with_depthwise = post_ops.find(primitive_kind::depthwise, 0, dw_conv_ind) != -1; + jcp.with_quantization = post_ops.find(primitive_kind::quantization, 0, dw_conv_ind) != -1; if (dw_conv_ind >= 0) { // dw_conv and post_ops after it are handled externally, so skip them @@ -815,8 +840,8 @@ status_t jit_avx512_common_1x1_conv_kernel::init_conf(jit_1x1_conv_conf_t &jcp, const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc); const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); - jcp.src_tag = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c); - jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c); + jcp.src_tag = src_d.mb_stride_relaxed_match(dat_tag_nxc, dat_tag_nCx16c); + jcp.dst_tag = dst_d.mb_stride_relaxed_match(dat_tag_nxc, dat_tag_nCx16c); bool is_data_layout_nxc = utils::everyone_is(dat_tag_nxc, jcp.src_tag, jcp.dst_tag); if (mayiuse(avx512_mic) && is_data_layout_nxc) return status::unimplemented; @@ -835,7 +860,7 @@ status_t jit_avx512_common_1x1_conv_kernel::init_conf(jit_1x1_conv_conf_t &jcp, static constexpr bool sum_requires_scale_one = true; static constexpr bool sum_requires_zp_zero = true; const bool post_ops_ok_ = post_ops_ok({avx512_common, - {eltwise, binary, sum}, jcp.post_ops, &dst_d, sum_at_pos_0_only, + {eltwise, binary, sum, depthwise, quantization}, jcp.post_ops, &dst_d, 
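
// [Editorial sketch] Once push_post_ops_data_on_stack has spilled the post-op
// data pointers, every later sub/push moves rsp below them, so the kernel
// keeps a running base_post_ops_data_offset and addresses the spilled slots as
// rsp + offset + i * 8. A scalar model of that bookkeeping (hypothetical
// names, byte-level only):
#include <cassert>
#include <cstddef>

struct stack_model {
    size_t rsp = 1 << 20;   // toy stack pointer
    size_t spill_base = 0;  // where the post-op pointers were stored
    size_t base_off = 0;    // mirrors base_post_ops_data_offset

    void spill(size_t bytes) { rsp -= bytes; spill_base = rsp; }
    void push(size_t bytes) { rsp -= bytes; base_off += bytes; }
    void pop(size_t bytes) { rsp += bytes; base_off -= bytes; }
    size_t addr_of_slot(size_t i) const {
        assert(rsp + base_off == spill_base); // invariant the kernel relies on
        return rsp + base_off + i * 8;
    }
};
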
sum_at_pos_0_only, sum_requires_scale_one, sum_requires_zp_zero}); if (!post_ops_ok_) return status::unimplemented; diff --git a/src/cpu/x64/jit_avx512_common_1x1_conv_kernel.hpp b/src/cpu/x64/jit_avx512_common_1x1_conv_kernel.hpp index 05fc00110cf..a8eaa16d9af 100644 --- a/src/cpu/x64/jit_avx512_common_1x1_conv_kernel.hpp +++ b/src/cpu/x64/jit_avx512_common_1x1_conv_kernel.hpp @@ -67,7 +67,7 @@ struct jit_avx512_common_1x1_conv_kernel : public jit_generator { reg64_t reg_load_loop_work = rsi; reg64_t reg_reduce_loop_work = r11; reg64_t reg_bcast_loop_iter = rdx; - reg64_t reduce_loop_iter = abi_param1; + reg64_t reduce_loop_iter = r13; reg64_t reg_reduce_pos_flag = rax; reg64_t reg_output_stride = r13; reg64_t reg_bias_data = r12; @@ -85,6 +85,14 @@ struct jit_avx512_common_1x1_conv_kernel : public jit_generator { constexpr static int reg_abi_param1_backup = 2 * reg64_size_; constexpr static int stack_space_needed = 3 * reg64_size_; + reg64_t reg_oc_off = abi_param1; + reg64_t reg_d_weights = imm_addr64; + reg64_t reg_d_bias = r13; + int base_post_ops_data_offset = 0; + + Xbyak::Zmm zmm_d_weights = Xbyak::Zmm(31); + Xbyak::Zmm zmm_d_bias = Xbyak::Zmm(30); + void bcast_loop(int load_loop_blk); void reduce_loop(int load_loop_blk, int ur, int substep, bool wraparound); diff --git a/src/cpu/x64/jit_avx512_common_1x1_convolution.cpp b/src/cpu/x64/jit_avx512_common_1x1_convolution.cpp index b47bac16237..d376ced4052 100644 --- a/src/cpu/x64/jit_avx512_common_1x1_convolution.cpp +++ b/src/cpu/x64/jit_avx512_common_1x1_convolution.cpp @@ -58,6 +58,8 @@ void jit_avx512_common_1x1_convolution_fwd_tjcp_.post_ops.entry_.size() + 1) : std::vector {}; + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + auto scratchpad = ctx.get_scratchpad_grantor(); if (pd()->wants_padded_bias()) { @@ -72,7 +74,7 @@ void jit_avx512_common_1x1_convolution_fwd_twants_zero_pad_dst()) ctx.zero_pad_output(DNNL_ARG_DST); @@ -86,7 +88,7 @@ void jit_avx512_common_1x1_convolution_fwd_tsrc_md()); const memory_desc_wrapper dst_d(pd()->dst_md()); const memory_desc_wrapper weights_d(pd()->weights_md(0)); @@ -143,7 +145,7 @@ void jit_avx512_common_1x1_convolution_fwd_tjcp_; + const auto post_ops_binary_rhs_arg_vec + = binary_injector::prepare_binary_args(jcp.post_ops, ctx); + + auto MB = CTX_IN_BATCH(DNNL_ARG_DIFF_DST); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); const memory_desc_wrapper weights_d(pd()->weights_md(0)); const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); - const auto &jcp = kernel_->jcp; auto rtus_space = pd()->rtus_.reduce_src_ ? ctx.get_scratchpad_grantor().template get( key_conv_rtus_space) @@ -470,7 +480,7 @@ void jit_avx512_common_1x1_convolution_bwd_data_t= jcp.nb_reduce ? FLAG_REDUCE_LAST : 0); p.reduce_dim = this_block_size(ocb * jcp.oc_block, jcp.oc, nb_oc_blocking_step * jcp.oc_block); + p.oc_off = ic_off_idx * (is_dsrc_layout_nxc ? 
1 : jcp.ic_block) * sizeof(float); + p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); + (*kernel_)(&p); } if (pd()->rtus_.reduce_src_) (*rtus_driver_)(&rp); diff --git a/src/cpu/x64/jit_avx512_common_1x1_convolution.hpp b/src/cpu/x64/jit_avx512_common_1x1_convolution.hpp index d5f30e47c50..829852fbc21 100644 --- a/src/cpu/x64/jit_avx512_common_1x1_convolution.hpp +++ b/src/cpu/x64/jit_avx512_common_1x1_convolution.hpp @@ -249,7 +249,7 @@ struct jit_avx512_common_1x1_convolution_fwd_t : public primitive_t { if (pd()->jcp_.with_dw_conv) { CHECK(safe_ptr_assign(kernel_dw_, new dw_conv_kernel_t( - pd()->dw_conv_pd_->jcp_, *pd()->dst_md(0)))); + pd()->dw_conv_pd_->jcp_, *pd()->dst_md(0), *pd()->dw_conv_pd_->attr()))); CHECK(kernel_dw_->create_kernel()); } @@ -270,7 +270,7 @@ struct jit_avx512_common_1x1_convolution_fwd_t : public primitive_t { const dst_data_t *bias_dw, dst_data_t *dst, const memory_tracking::grantor_t &scratchpad, const void *post_ops_binary_rhs_arg_vec, - const void *post_ops_binary_rhs_arg_vec_dw) const; + const void *post_ops_binary_rhs_arg_vec_dw, int MB) const; const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } std::unique_ptr kernel_; @@ -301,8 +301,9 @@ struct jit_avx512_common_1x1_convolution_bwd_data_t : public primitive_t { && set_default_alg_kind(alg_kind::convolution_direct) && expect_data_types(diff_src_type, wei_type, data_type::undef, diff_dst_type, data_type::undef) - && attr()->has_default_values() && !has_zero_dim_memory() - && set_default_formats(); + && !has_zero_dim_memory() + && set_default_formats() + && is_supported_post_ops(); if (!ok) return status::unimplemented; const convolution_desc_t *conv_d = desc(); @@ -338,6 +339,23 @@ struct jit_avx512_common_1x1_convolution_bwd_data_t : public primitive_t { return set_default_formats_common(dat_tag, wei_tag, dat_tag); } + + bool is_supported_post_ops() const { + const auto &p = this->attr()->post_ops_; + if (p.len() > 1) + return false; + + auto all_post_ops_supported = [&]() { + bool ok = true; + + for (int i = 0; i < p.len(); i++) { + ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::depthwise); + } + return ok; + }; + + return all_post_ops_supported(); + } }; template diff --git a/src/cpu/x64/jit_avx512_common_conv_kernel.cpp b/src/cpu/x64/jit_avx512_common_conv_kernel.cpp index d362548c460..84c9695a8a2 100644 --- a/src/cpu/x64/jit_avx512_common_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_common_conv_kernel.cpp @@ -74,10 +74,7 @@ inline status_t init_tag(format_tag_t &tag, memory_desc_t &md, } inline bool is_1stconv(const jit_conv_conf_t &jcp) { - if (mayiuse(avx512_core)) - return (jcp.ic < 16 && jcp.ngroups == 1); - else - return one_of(jcp.ic, 1, 3); + return one_of(jcp.ic, 1, 2, 3); } inline bool is_ow_threading_on(const jit_conv_conf_t &jcp) { @@ -99,7 +96,7 @@ _jit_avx512_common_conv_fwd_kernel::_jit_avx512_common_conv_fwd_kernel( const jit_conv_conf_t &ajcp, const primitive_attr_t &attr, const memory_desc_t &dst_md) : jcp(ajcp), attr_(attr) { - if (jcp.with_eltwise || jcp.with_binary) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { using namespace binary_injector; static constexpr bool preserve_gpr = true; static constexpr bool preserve_vmm = false; @@ -114,10 +111,12 @@ _jit_avx512_common_conv_fwd_kernel::_jit_avx512_common_conv_fwd_kernel( use_exact_tail_scalar_bcast}; const binary_injector::static_params_t static_params { this->param1, rhs_args_static_params}; + quantization_injector::static_params_t 
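
// [Editorial sketch] The ic_off/oc_off values handed to the kernel, as in the
// statement just above, are byte offsets of the first channel a work item
// touches: in plain nxc layout the block index already counts dense channels
// (stride 1), while in blocked nCx{8,16}c layout each block index advances by
// a full channel block. A standalone helper mirroring that expression:
#include <cstddef>

static size_t channel_byte_off(size_t blk_idx, bool is_nxc, size_t block_size) {
    // blk_idx counts channel blocks, e.g. g * nb_ic + icb * nb_ic_blocking
    return blk_idx * (is_nxc ? 1 : block_size) * sizeof(float);
}
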
quantization_static_params + {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; postops_injector_ = utils::make_unique< injector::jit_uni_postops_injector_t>( - this, jcp.post_ops, static_params); + this, jcp.post_ops, static_params, quantization_static_params); } } @@ -152,6 +151,17 @@ static void iterate(const int nb_oc_blocking, const int ur_w, const F &fun) { template void _jit_avx512_common_conv_fwd_kernel::apply_postops(int ur_w) { + std::map vmm_idx_off; + iterate(jcp.nb_oc_blocking, ur_w, + [&](const bool, const int i_load, const int i_ur) { + vmm_idx_off.insert({vmm_out_idx(i_ur, i_load), i_load * jcp.oc_block * sizeof(float)}); + }); + depthwise_injector::dynamic_params_t ddp {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off, + this->rsp, base_post_ops_data_offset}; + quantization_injector::dynamic_params_t qdp {ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off, jcp.dst_dt, + this->rsp, base_post_ops_data_offset}; + injector_utils::vmm_index_set_t vmm_idxs; if (jcp.with_binary) { binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; @@ -172,13 +182,13 @@ void _jit_avx512_common_conv_fwd_kernel::apply_postops(int ur_w) { } }); - postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params); + postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params, ddp, qdp); } else { iterate(jcp.nb_oc_blocking, ur_w, [&](const bool, const int i_load, const int i_ur) { vmm_idxs.emplace(vmm_out_idx(i_ur, i_load)); }); - postops_injector_->compute_vector_range(vmm_idxs); + postops_injector_->compute_vector_range(vmm_idxs, binary_injector::rhs_arg_dynamic_params_t(), ddp, qdp); } } @@ -258,7 +268,7 @@ void _jit_avx512_common_conv_fwd_kernel::store_output(int ur_w) { L(post_ops_label); - if (jcp.with_eltwise || jcp.with_binary) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { auto _jmp = [&](const Label &l) { return mayiuse(avx512_mic) ? jl(l, T_NEAR) : jz(l, T_NEAR); }; @@ -1016,7 +1026,10 @@ void _jit_avx512_common_conv_fwd_kernel::compute_loop_fma_core( template void _jit_avx512_common_conv_fwd_kernel::compute_loop( int ur_w, int pad_l, int pad_r) { - if (jcp.ndims == 5) push(reg_oi); + if (jcp.ndims == 5) { + push(reg_oi); + base_post_ops_data_offset += reg64_size; + } prepare_output(ur_w); @@ -1080,7 +1093,10 @@ void _jit_avx512_common_conv_fwd_kernel::compute_loop( L(skip_compute_loop); store_output(ur_w); - if (jcp.ndims == 5) pop(reg_oi); + if (jcp.ndims == 5) { + pop(reg_oi); + base_post_ops_data_offset -= reg64_size; + } } template @@ -1104,6 +1120,10 @@ void _jit_avx512_common_conv_fwd_kernel::generate() { * (is_dst_layout_nxc() ? 
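
// [Editorial sketch] GET_OFF(oc_off), used in the dynamic-params setup above,
// is the usual offsetof-into-call-struct pattern: the driver writes the
// per-call channel offset into the kernel call structure, and generated code
// reads it relative to the argument pointer held in param1. A minimal model
// with a hypothetical call struct:
#include <cstddef>

struct call_s {
    const void *src, *dst, *filt;
    size_t oc_off; // byte offset of this call's first output channel
};

#define SK_GET_OFF(field) offsetof(call_s, field)

static size_t read_oc_off(const void *param1) {
    auto base = static_cast<const char *>(param1);
    return *reinterpret_cast<const size_t *>(base + SK_GET_OFF(oc_off));
}
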
jcp.ngroups * jcp.oc : jcp.oc_block); preamble(); + + if (postops_injector_) + postops_injector_->push_post_ops_data_on_stack(this->param1, GET_OFF(post_ops_binary_rhs_arg_vec), reg_inp, reg_out); + mov(reg_inp, ptr[param1 + GET_OFF(src)]); mov(reg_out, ptr[param1 + GET_OFF(dst)]); mov(reg_ker, ptr[param1 + GET_OFF(filt)]); @@ -1159,16 +1179,16 @@ void _jit_avx512_common_conv_fwd_kernel::generate() { compute_loop(ur_w_tail, 0, r_pad); } } else { - xor_(reg_oi, reg_oi); if (l_pad > 0) { + n_oi--; add(reg_inp_prf, inp_shift_pad); add(reg_out_prf, out_shift); compute_loop(ur_w, l_pad, 0); add(reg_inp, inp_shift_pad); add(reg_out, out_shift); - inc(reg_oi); } - if ((l_pad <= 0 && n_oi > 0) || (l_pad > 0 && n_oi > 1)) { + if (n_oi > 0) { + xor_(reg_oi, reg_oi); Label ow_loop_label; L(ow_loop_label); { @@ -1323,6 +1343,10 @@ void _jit_avx512_common_conv_fwd_kernel::generate() { } L(end_label); } + + if (postops_injector_) + postops_injector_->reset_stack_pointer(); + postamble(); if (jcp.with_eltwise) postops_injector_->prepare_table(); @@ -1385,19 +1409,19 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf(jit_conv_conf_t &jcp, jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh); jcp.back_pad = calculate_end_padding( jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd); - bool kernel_outside_src = false || ext_kw <= jcp.l_pad - || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad || ext_kh <= jcp.b_pad - || ext_kd <= jcp.f_pad || ext_kd <= jcp.back_pad; - if (kernel_outside_src) return status::unimplemented; +// bool kernel_outside_src = false || ext_kw <= jcp.l_pad +// || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad || ext_kh <= jcp.b_pad +// || ext_kd <= jcp.f_pad || ext_kd <= jcp.back_pad; +// if (kernel_outside_src) return status::unimplemented; const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc); const auto dat_tag_ncx = pick(ndims - 3, ncw, nchw, ncdhw); const auto dat_tag_nCx4c = pick(ndims - 3, nCw4c, nChw4c, nCdhw4c); const auto dat_tag_nCx8c = pick(ndims - 3, nCw8c, nChw8c, nCdhw8c); const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); - auto curr_src_tag = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c, + auto curr_src_tag = src_d.mb_stride_relaxed_match(dat_tag_nxc, dat_tag_nCx16c, dat_tag_nCx8c, dat_tag_nCx4c, dat_tag_ncx); - auto curr_dst_tag = dst_d.matches_one_of_tag( + auto curr_dst_tag = dst_d.mb_stride_relaxed_match( dat_tag_nxc, dat_tag_nCx16c, dat_tag_nCx8c, dat_tag_nCx4c); bool is_data_layout_nxc = utils::everyone_is(dat_tag_nxc, curr_src_tag, curr_dst_tag); @@ -1491,6 +1515,8 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf(jit_conv_conf_t &jcp, } const int binary_ind = post_ops.find(primitive_kind::binary); jcp.with_binary = binary_ind != -1; + jcp.with_depthwise = post_ops.find(primitive_kind::depthwise) != -1; + jcp.with_quantization = post_ops.find(primitive_kind::quantization) != -1; jcp.post_ops = post_ops; @@ -1499,7 +1525,7 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf(jit_conv_conf_t &jcp, static constexpr bool sum_requires_scale_one = true; static constexpr bool sum_requires_zp_zero = true; const bool post_ops_ok_ = post_ops_ok({avx512_common, - {eltwise, binary, sum}, jcp.post_ops, &dst_d, sum_at_pos_0_only, + {eltwise, binary, sum, depthwise, quantization}, jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one, sum_requires_zp_zero}); if (!post_ops_ok_) return status::unimplemented; @@ -1947,7 +1973,7 @@ void _jit_avx512_common_conv_bwd_data_kernel_f32::prepare_output( template void 
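
// [Editorial sketch] The restructured ow loop above peels the left-padded
// tile off n_oi and only afterwards emits the counted loop, instead of
// pre-incrementing reg_oi before entering it. The emitted control flow,
// re-expressed as scalar C++:
static void ow_tiles(int n_oi, bool has_l_pad,
        void (*compute_tile)(int l_pad_flag)) {
    if (has_l_pad) {
        n_oi--;          // the padded tile is handled out of line
        compute_tile(1);
    }
    for (int oi = 0; oi < n_oi; oi++)
        compute_tile(0); // steady-state tiles, no padding
}
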
_jit_avx512_common_conv_bwd_data_kernel_f32::store_output(int ur_w) { - Label no_update_label; + Label no_update_label, skip_post_ops; const int ic_tail = jcp.ic_without_padding % jcp.simd_w; const bool dsrc_layout_nxc = is_dsrc_layout_nxc(); mov(reg_channel, ptr[param + GET_OFF(channel)]); @@ -1962,8 +1988,30 @@ void _jit_avx512_common_conv_bwd_data_kernel_f32::store_output(int ur_w) { reg_src, aux_src_offset, reg_long_offt)); } } + jmp(skip_post_ops, T_NEAR); L(no_update_label); + const auto &p = attr_.post_ops_; + std::size_t post_ops_data_offset = 0; + int depthwise_inj_idx = 0; + for (int i = 0; i < p.len(); i++) { + auto& post_op = p.entry_[i]; + if (post_op.is_depthwise()) { + mov(reg_d_weights, ptr[this->rsp + base_post_ops_data_offset + post_ops_data_offset]); + add(reg_d_weights, ptr[this->param1 + GET_OFF(oc_off)]); + + for (int k = 0; k < jcp.nb_ic_blocking; k++) { + depthwise_injectors[depthwise_inj_idx]->compute_vector_range( + k*jcp.ur_w, k*jcp.ur_w + ur_w, reg_d_weights, reg_d_weights); + + add(reg_d_weights, jcp.ic_block * sizeof(float)); + } + post_ops_data_offset += depthwise_injectors[depthwise_inj_idx]->memoryStep(); + depthwise_inj_idx++; + } + } + L(skip_post_ops); + for (int k = 0; k < jcp.nb_ic_blocking; k++) { for (int j = 0; j < ur_w; j++) { Vmm vmm = vmm_out(j, k); @@ -2034,6 +2082,7 @@ void _jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_4fma( if (jcp.ndims == 5) { push(reg_src_prf); push(reg_src); + base_post_ops_data_offset += 2 * reg64_size; mov(reg_ki, ptr[param + GET_OFF(kd_padding)]); mov(aux_reg_dst_d, reg_dst); @@ -2174,6 +2223,7 @@ void _jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_4fma( pop(reg_src); pop(reg_src_prf); + base_post_ops_data_offset -= 2 * reg64_size; } } @@ -2222,7 +2272,11 @@ void _jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma( } if (jcp.ndims == 5) { - if (prf_dsrc) push(reg_src_prf); + if (prf_dsrc) { + base_post_ops_data_offset += reg64_size; + push(reg_src_prf); + } + base_post_ops_data_offset += reg64_size; push(reg_src); mov(reg_ki, ptr[param + GET_OFF(kd_padding)]); @@ -2232,6 +2286,7 @@ void _jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma( // aux_reg_ker_d == reg_ker we need to save its value and restore // it after kd loop assert(aux_reg_ker_d == reg_ker); + base_post_ops_data_offset += reg64_size; push(aux_reg_ker_d); } else mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]); @@ -2363,12 +2418,19 @@ void _jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma( dec(reg_ki); cmp(reg_ki, 0); jg(kd_label, T_NEAR); - if (ocb_loop_in_compute_function) pop(aux_reg_ker_d); + if (ocb_loop_in_compute_function) { + pop(aux_reg_ker_d); + base_post_ops_data_offset -= reg64_size; + } } if (jcp.ndims == 5) { pop(reg_src); - if (prf_dsrc) pop(reg_src_prf); + base_post_ops_data_offset -= reg64_size; + if (prf_dsrc) { + pop(reg_src_prf); + base_post_ops_data_offset -= reg64_size; + } } } @@ -2406,6 +2468,7 @@ void _jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma_core( const bool ocb_loop_in_compute_function = ddst_layout_nxc; if (jcp.ndims == 5) { + base_post_ops_data_offset += reg64_size; push(reg_src); mov(reg_ki, ptr[param + GET_OFF(kd_padding)]); @@ -2415,6 +2478,7 @@ void _jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma_core( // aux_reg_ker_d == reg_ker we need to save its value and restore // it after kd loop assert(aux_reg_ker_d == reg_ker); + base_post_ops_data_offset += reg64_size; push(aux_reg_ker_d); } else mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]); 
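
// [Editorial sketch] The depthwise pass added to store_output above walks
// jcp.nb_ic_blocking channel blocks, applying one per-channel weight vector
// per block and stepping the weight pointer by ic_block floats each time. The
// same update in scalar form, assuming a scale-only depthwise op for brevity
// (the real injector may also apply a shift):
static void apply_depthwise_scalar(float *acc /* [nb_ic_blk][ur_w][ic_block] */,
        const float *d_weights, int nb_ic_blk, int ur_w, int ic_block) {
    for (int k = 0; k < nb_ic_blk; k++) {
        for (int j = 0; j < ur_w; j++)
            for (int c = 0; c < ic_block; c++)
                acc[(k * ur_w + j) * ic_block + c] *= d_weights[c];
        d_weights += ic_block; // next block's per-channel weights
    }
}
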
@@ -2487,15 +2551,22 @@ void _jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma_core( cmp(reg_ki, 0); jg(kd_label, T_NEAR); - if (ocb_loop_in_compute_function) pop(aux_reg_ker_d); + if (ocb_loop_in_compute_function) { + pop(aux_reg_ker_d); + base_post_ops_data_offset -= reg64_size; + } pop(reg_src); + base_post_ops_data_offset -= reg64_size; } } template inline void _jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop( int ur_w, int l_overflow, int r_overflow, int k_offset) { - if (jcp.ndims == 5) push(reg_oi); + if (jcp.ndims == 5) { + base_post_ops_data_offset += reg64_size; + push(reg_oi); + } prepare_output(ur_w); @@ -2512,6 +2583,7 @@ inline void _jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop( const bool generate_ocb_loop = jcp.nb_oc > 1 && is_ddst_layout_nxc(); Label oc_loop; if (generate_ocb_loop) { + base_post_ops_data_offset += 2 * reg64_size; push(reg_dst); push(reg_ker); @@ -2541,15 +2613,30 @@ inline void _jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop( pop(reg_ker); pop(reg_dst); + base_post_ops_data_offset -= 2 * reg64_size; } L(skip_compute_loop); store_output(ur_w); - if (jcp.ndims == 5) pop(reg_oi); + if (jcp.ndims == 5) { + pop(reg_oi); + base_post_ops_data_offset -= reg64_size; + } } template void _jit_avx512_common_conv_bwd_data_kernel_f32::generate() { + const auto &p = attr_.post_ops_; + for (int i = 0; i < p.len(); i++) { + auto &post_op = p.entry_[i]; + if (post_op.is_depthwise()) { + depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32( + this, + post_op + )); + } + } + int iw = jcp.iw; int kw = jcp.kw; int ur_w = jcp.ur_w; @@ -2568,6 +2655,26 @@ void _jit_avx512_common_conv_bwd_data_kernel_f32::generate() { preamble(); + std::size_t post_ops_pointers_count = 0; + for (int i = 0; i < p.len(); i++) { + if (p.entry_[i].is_depthwise() || p.entry_[i].is_quantization()) { + post_ops_pointers_count++; + } + } + + if (post_ops_pointers_count != 0) { + sub(rsp, post_ops_pointers_count * sizeof(float *)); + + auto aux_reg0 = reg_src; + auto aux_reg1 = reg_dst; + + mov(aux_reg0, ptr[this->param + GET_OFF(post_ops_binary_rhs_arg_vec)]); + for (size_t i = 0; i < post_ops_pointers_count; i++) { + mov(aux_reg1, ptr[aux_reg0 + i * sizeof(float *)]); + mov(ptr[rsp + i * sizeof(float *)], aux_reg1); + } + } + mov(reg_src, ptr[param + GET_OFF(src)]); mov(reg_dst, ptr[param + GET_OFF(dst)]); mov(reg_ker, ptr[param + GET_OFF(filt)]); @@ -2743,20 +2850,48 @@ void _jit_avx512_common_conv_bwd_data_kernel_f32::generate() { } L(end_label); + if (post_ops_pointers_count != 0) { + add(rsp, post_ops_pointers_count * sizeof(float *)); + } + postamble(); } +bool jit_avx512_common_conv_bwd_data_kernel_f32::post_ops_ok(const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + if (p.len() > 1) + return false; + + auto all_post_ops_supported = [&]() { + bool ok = true; + + for (int i = 0; i < p.len(); i++) { + ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::depthwise); + } + return ok; + }; + + return all_post_ops_supported(); +} + status_t jit_avx512_common_conv_bwd_data_kernel_f32::init_conf( jit_conv_conf_t &jcp, const convolution_desc_t &cd, memory_desc_t &diff_src_md, memory_desc_t &weights_md, - memory_desc_t &diff_dst_md, int nthreads) { + memory_desc_t &diff_dst_md, int nthreads, const primitive_attr_t &attr) { if (!mayiuse(avx512_common)) return status::unimplemented; + if (!post_ops_ok(attr)) + return status::unimplemented; + + const auto &post_ops = attr.post_ops_; + const memory_desc_wrapper diff_src_d(&diff_src_md); 
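
// [Editorial sketch] The prologue added to generate() below copies each
// post-op data pointer from the binary-args array onto the kernel's own stack
// so injectors can reach them rsp-relative. A self-contained Xbyak sketch of
// just that spill pattern (not the oneDNN kernel; SysV calling convention
// assumed, with the pointer array arriving in rdi):
#include <xbyak/xbyak.h>

struct spill_post_ops_sketch : Xbyak::CodeGenerator {
    explicit spill_post_ops_sketch(size_t n_ptrs) {
        const uint32_t bytes = uint32_t(n_ptrs * sizeof(float *));
        sub(rsp, bytes);
        for (size_t i = 0; i < n_ptrs; i++) {
            mov(rax, ptr[rdi + i * sizeof(float *)]); // load i-th pointer
            mov(ptr[rsp + i * sizeof(float *)], rax); // spill it
        }
        // ... a kernel body would address ptr[rsp + i * 8] here ...
        add(rsp, bytes); // matching cleanup before return
        ret();
    }
};
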
const memory_desc_wrapper weights_d(&weights_md); const memory_desc_wrapper diff_dst_d(&diff_dst_md); jcp = zero(); + jcp.post_ops = post_ops; + const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1; int ndims = diff_src_d.ndims(); @@ -2818,9 +2953,9 @@ status_t jit_avx512_common_conv_bwd_data_kernel_f32::init_conf( const auto dat_tag_nCx4c = pick(ndims - 3, nCw4c, nChw4c, nCdhw4c); const auto dat_tag_nCx8c = pick(ndims - 3, nCw8c, nChw8c, nCdhw8c); const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); - auto curr_src_tag = diff_src_d.matches_one_of_tag( + auto curr_src_tag = diff_src_d.mb_stride_relaxed_match( dat_tag_nxc, dat_tag_nCx16c, dat_tag_nCx8c, dat_tag_nCx4c); - auto curr_dst_tag = diff_dst_d.matches_one_of_tag( + auto curr_dst_tag = diff_dst_d.mb_stride_relaxed_match( dat_tag_nxc, dat_tag_nCx16c, dat_tag_nCx8c, dat_tag_nCx4c); bool is_data_layout_nxc = utils::everyone_is(dat_tag_nxc, curr_src_tag, curr_dst_tag); diff --git a/src/cpu/x64/jit_avx512_common_conv_kernel.hpp b/src/cpu/x64/jit_avx512_common_conv_kernel.hpp index 0349aa7fee9..e23316df2d2 100644 --- a/src/cpu/x64/jit_avx512_common_conv_kernel.hpp +++ b/src/cpu/x64/jit_avx512_common_conv_kernel.hpp @@ -121,6 +121,14 @@ struct _jit_avx512_common_conv_fwd_kernel : public jit_generator { std::unique_ptr> postops_injector_; + reg64_t reg_d_weights = imm_addr64; + reg64_t reg_d_bias = reg_kj; + int base_post_ops_data_offset = 0; + constexpr static int reg64_size = 8; + + Xbyak::Zmm zmm_d_weights = Xbyak::Zmm(31); + Xbyak::Zmm zmm_d_bias = Xbyak::Zmm(30); + inline void prepare_output(int ur_w); inline void apply_postops(int ur_w); inline void store_output(int ur_w); @@ -231,11 +239,18 @@ struct jit_avx512_common_conv_fwd_kernel { template struct _jit_avx512_common_conv_bwd_data_kernel_f32 : public jit_generator { - _jit_avx512_common_conv_bwd_data_kernel_f32(const jit_conv_conf_t &ajcp) - : jcp(ajcp) {} + _jit_avx512_common_conv_bwd_data_kernel_f32(const jit_conv_conf_t &ajcp, const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr) {} + + ~_jit_avx512_common_conv_bwd_data_kernel_f32() { + for (auto inj : depthwise_injectors) + delete inj; + depthwise_injectors.clear(); + } DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_bwd_data_kernel_f32) jit_conv_conf_t jcp; + const primitive_attr_t &attr_; private: using reg64_t = const Xbyak::Reg64; @@ -297,6 +312,13 @@ struct _jit_avx512_common_conv_bwd_data_kernel_f32 : public jit_generator { Vmm vmm_wei = Vmm(31); + reg64_t reg_d_weights = aux_reg_ker; + reg64_t reg_d_bias = reg_kj; + int base_post_ops_data_offset = 0; + constexpr static int reg64_size = 8; + + nstl::vector*> depthwise_injectors; + inline void prepare_output(int ur_w); inline void store_output(int ur_w); inline void compute_loop_4fma(int ur_w, int l_overflow, int r_overflow); @@ -359,20 +381,20 @@ struct _jit_avx512_common_conv_bwd_data_kernel_f32 : public jit_generator { struct jit_avx512_common_conv_bwd_data_kernel_f32 { - jit_avx512_common_conv_bwd_data_kernel_f32(const jit_conv_conf_t &ajcp) + jit_avx512_common_conv_bwd_data_kernel_f32(const jit_conv_conf_t &ajcp, const primitive_attr_t &attr) : kernel_(nullptr) { switch (ajcp.ic_block) { case 16: kernel_ = new _jit_avx512_common_conv_bwd_data_kernel_f32< - Xbyak::Zmm>(ajcp); + Xbyak::Zmm>(ajcp, attr); return; case 8: kernel_ = new _jit_avx512_common_conv_bwd_data_kernel_f32< - Xbyak::Ymm>(ajcp); + Xbyak::Ymm>(ajcp, attr); return; case 4: kernel_ = new _jit_avx512_common_conv_bwd_data_kernel_f32< - Xbyak::Xmm>(ajcp); + 
Xbyak::Xmm>(ajcp, attr); return; default: assert(!"invalid channel blocking"); } @@ -384,9 +406,11 @@ struct jit_avx512_common_conv_bwd_data_kernel_f32 { enum { typesize = sizeof(float) }; + static bool post_ops_ok(const primitive_attr_t &attr); static status_t init_conf(jit_conv_conf_t &jcp, const convolution_desc_t &cd, memory_desc_t &diff_src_d, - memory_desc_t &weights_d, memory_desc_t &diff_dst_d, int nthreads); + memory_desc_t &weights_d, memory_desc_t &diff_dst_d, int nthreads, + const primitive_attr_t &attr); static void init_scratchpad(memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp); diff --git a/src/cpu/x64/jit_avx512_common_convolution.cpp b/src/cpu/x64/jit_avx512_common_convolution.cpp index 782d9d2e282..d866349eb96 100644 --- a/src/cpu/x64/jit_avx512_common_convolution.cpp +++ b/src/cpu/x64/jit_avx512_common_convolution.cpp @@ -42,7 +42,8 @@ using jit_conv_ker_t = void (*)(jit_conv_call_s *); inline void jit_conv_ker_pipeline(const jit_conv_ker_t ker, jit_conv_call_s &p, const void *src, const void *dst, const void *filt, const void *bias, - int channel, int kh_padding, int reduce_work, int load_work) { + int channel, int kh_padding, int reduce_work, int load_work, int oc_off, + const void *post_ops_binary_rhs_arg_vec) { PIPELINE(src); PIPELINE(dst); PIPELINE(filt); @@ -53,24 +54,27 @@ inline void jit_conv_ker_pipeline(const jit_conv_ker_t ker, jit_conv_call_s &p, PIPELINE(kh_padding); PIPELINE(reduce_work); PIPELINE(load_work); + PIPELINE(oc_off); + p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec; if (p.src) ker(&p); } + // The special case for the driver with iw-parallelization (BWD) -inline void jit_conv_ker_pipeline_iw_thr(const jit_conv_ker_t ker, - jit_conv_call_s &p, const void *src, const void *dst, const void *filt, - const void *bias, int channel, int kh_padding, int iwb, int reduce_work, - int load_work) { +inline void jit_conv_ker_pipeline_iw_thr(const jit_conv_ker_t ker, jit_conv_call_s &p, + const void *src, const void *dst, const void *filt, const void *bias, + int channel, int kh_padding, int iwb, int reduce_work, int load_work, int oc_off, + const void *post_ops_binary_rhs_arg_vec) { PIPELINE(iwb); jit_conv_ker_pipeline(ker, p, src, dst, filt, bias, channel, kh_padding, - reduce_work, load_work); + reduce_work, load_work, oc_off, post_ops_binary_rhs_arg_vec); } -inline void jit_conv_3d_ker_pipeline(const jit_conv_ker_t ker, - jit_conv_call_s &p, const void *src, const void *dst, const void *filt, - const void *bias, int channel, int kh_padding, int kd_padding, - int reduce_work, int load_work) { +inline void jit_conv_3d_ker_pipeline(const jit_conv_ker_t ker, jit_conv_call_s &p, + const void *src, const void *dst, const void *filt, const void *bias, + int channel, int kh_padding, int kd_padding, int reduce_work, + int load_work, int oc_off, const void *post_ops_binary_rhs_arg_vec) { PIPELINE(src); PIPELINE(dst); PIPELINE(filt); @@ -82,36 +86,38 @@ inline void jit_conv_3d_ker_pipeline(const jit_conv_ker_t ker, PIPELINE(kd_padding); PIPELINE(reduce_work); PIPELINE(load_work); + PIPELINE(oc_off); + p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec; if (p.src) ker(&p); } + // The special case for the driver with ow-parallelization (FWD) inline void jit_conv_ker_pipeline_ow_thr(jit_conv_ker_t ker, jit_conv_call_s &p, const void *src, const void *dst, const void *filt, const void *bias, int channel, int kh_padding, int owb, int reduce_work, int load_work, const void *post_ops_binary_rhs_arg_vec, int oc_l_off, - const void 
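
// [Editorial sketch] PIPELINE(field) in the helpers below implements a
// one-call software pipeline: the call struct keeps the current and the next
// value of each argument, the kernel prefetches through the "next" copy, and
// the first priming invocation is skipped because no current src exists yet.
// A simplified scalar model of the same pattern:
struct piped_args {
    const void *src = nullptr, *src_prf = nullptr;
    int oc_off = 0, oc_off_prf = 0;
};

template <typename F>
static void pipeline_call(F &&ker, piped_args &p,
        const void *next_src, int next_oc_off) {
    p.src = p.src_prf;       p.src_prf = next_src;       // PIPELINE(src)
    p.oc_off = p.oc_off_prf; p.oc_off_prf = next_oc_off; // PIPELINE(oc_off)
    if (p.src) ker(p);       // skip the priming call
}
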
*dst_orig, int flags) { + const void *dst_orig, int flags, int oc_off) { PIPELINE(owb); PIPELINE(flags); PIPELINE(oc_l_off); PIPELINE(dst_orig); - p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec; jit_conv_ker_pipeline(ker, p, src, dst, filt, bias, channel, kh_padding, - reduce_work, load_work); + reduce_work, load_work, oc_off, post_ops_binary_rhs_arg_vec); } // The special case for the driver with ow-parallelization (FWD) // TODO: implement it for BWD_D and BWD_W too inline void jit_conv_3d_ker_pipeline_ow_thr(const jit_conv_ker_t ker, jit_conv_call_s &p, const void *src, const void *dst, const void *filt, const void *bias, int channel, int kh_padding, int kd_padding, int owb, - int reduce_work, int load_work, int flags) { + int reduce_work, int load_work, int flags, int oc_off, const void *post_ops_binary_rhs_arg_vec) { PIPELINE(owb); PIPELINE(flags); jit_conv_3d_ker_pipeline(ker, p, src, dst, filt, bias, channel, kh_padding, - kd_padding, reduce_work, load_work); + kd_padding, reduce_work, load_work, oc_off, post_ops_binary_rhs_arg_vec); } // The special case for the driver with ow-parallelization (FWD) @@ -120,14 +126,13 @@ inline void jit_conv_3d_ker_pipeline_ow_thr(const jit_conv_ker_t ker, jit_conv_call_s &p, const void *src, const void *dst, const void *filt, const void *bias, int channel, int kh_padding, int kd_padding, int owb, int reduce_work, int load_work, const void *post_ops_binary_rhs_arg_vec, - int oc_l_off, const void *dst_orig, int flags) { + int oc_l_off, const void *dst_orig, int flags, int oc_off) { PIPELINE(oc_l_off); PIPELINE(dst_orig); - p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec; jit_conv_3d_ker_pipeline_ow_thr(ker, p, src, dst, filt, bias, channel, - kh_padding, kd_padding, owb, reduce_work, load_work, flags); + kh_padding, kd_padding, owb, reduce_work, load_work, flags, oc_off, post_ops_binary_rhs_arg_vec); } inline void jit_conv_ker_pipeline_bwd_w(const jit_conv_ker_t ker, @@ -135,7 +140,7 @@ inline void jit_conv_ker_pipeline_bwd_w(const jit_conv_ker_t ker, const void *bias, int channel, int kh_padding, size_t reduce_work, size_t load_work) { jit_conv_ker_pipeline(ker, p, src, dst, filt, bias, channel, kh_padding, - reduce_work, load_work); + reduce_work, load_work, 0, nullptr); } void jit_conv_2d_ker_bwd_w_pipeline(const jit_conv_ker_t ker, @@ -211,6 +216,8 @@ void jit_avx512_common_convolution_fwd_tsrc_md()); @@ -223,7 +230,7 @@ void jit_avx512_common_convolution_fwd_tsrc_md()); @@ -357,7 +367,7 @@ void jit_avx512_common_convolution_fwd_tjcp_.post_ops, ctx); + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + prepare_padded_bias(bias, ctx.get_scratchpad_grantor()); const memory_desc_wrapper src_d(pd()->src_md()); @@ -524,7 +537,7 @@ void jit_avx512_common_convolution_fwd_tjcp_; + const auto post_ops_binary_rhs_arg_vec + = binary_injector::prepare_binary_args(jcp.post_ops, ctx); + + auto MB = CTX_IN_BATCH(DNNL_ARG_DIFF_DST); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); const memory_desc_wrapper weights_d(pd()->weights_md(0)); - const auto &jcp = pd()->jcp_; const jit_conv_ker_t jit_ker = (decltype(jit_ker))kernel_->jit_ker(); int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking; int g_blocking = 1; int nb_groups = jcp.ngroups / g_blocking; - int work_amount = nb_groups * jcp.mb * ic_chunks * jcp.nb_iw; + int work_amount = nb_groups * MB * ic_chunks * jcp.nb_iw; int nthr = jcp.nthr; parallel(nthr, [&](const int ithr, const int nthr) { @@ -702,7 +721,7 @@ void 
jit_avx512_common_convolution_bwd_data_t= jcp.nb_oc) { @@ -759,7 +781,7 @@ void jit_avx512_common_convolution_bwd_data_tjcp_; + const auto post_ops_binary_rhs_arg_vec + = binary_injector::prepare_binary_args(jcp.post_ops, ctx); + + auto MB = CTX_IN_BATCH(DNNL_ARG_DIFF_DST); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); const memory_desc_wrapper weights_d(pd()->weights_md(0)); - const auto &jcp = pd()->jcp_; const jit_conv_ker_t jit_ker = (decltype(jit_ker))kernel_->jit_ker(); int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking; int g_blocking = 1; int nb_groups = jcp.ngroups / g_blocking; - int work_amount = nb_groups * jcp.mb * ic_chunks * jcp.ih * jcp.nb_iw; + int work_amount = nb_groups * MB * ic_chunks * jcp.ih * jcp.nb_iw; int nthr = jcp.nthr; parallel(nthr, [&](const int ithr, const int nthr) { @@ -818,9 +845,9 @@ void jit_avx512_common_convolution_bwd_data_tjcp_; + const auto post_ops_binary_rhs_arg_vec + = binary_injector::prepare_binary_args(jcp.post_ops, ctx); + + auto MB = CTX_IN_BATCH(DNNL_ARG_DIFF_DST); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); const memory_desc_wrapper weights_d(pd()->weights_md(0)); - const auto &jcp = pd()->jcp_; const jit_conv_ker_t jit_ker = (decltype(jit_ker))kernel_->jit_ker(); int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking; int g_blocking = 1; int nb_groups = jcp.ngroups / g_blocking; - int work_amount = nb_groups * jcp.mb * ic_chunks * jcp.id * jcp.ih; + int work_amount = nb_groups * MB * ic_chunks * jcp.id * jcp.ih; int nthr = jcp.nthr; parallel(nthr, [&](const int ithr, const int nthr) { @@ -986,11 +1021,11 @@ void jit_avx512_common_convolution_bwd_data_t= 0); + int ic_off = ic_off_idx * (is_dsrc_layout_nxc ? 
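
// [Editorial sketch] prepare_binary_args, now also called on the bwd-data
// paths above, gathers once per execute() one data pointer per post-op that
// needs side inputs; kernels then index this array (it is what gets spilled to
// the kernel stack). A simplified model with hypothetical types:
#include <vector>

struct post_op_entry { bool needs_side_input; const void *side_input; };

static std::vector<const void *> prepare_args_sketch(
        const std::vector<post_op_entry> &post_ops) {
    std::vector<const void *> args;
    args.reserve(post_ops.size());
    for (const auto &e : post_ops)
        if (e.needs_side_input) args.push_back(e.side_input);
    return args; // the driver passes args.data() in the kernel call struct
}
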
1 : jcp.ic_block) * sizeof(float); + jit_conv_3d_ker_pipeline(jit_ker, par_conv, diff_src_w + ij * diff_src_h_stride, diff_dst_w + oj * diff_dst_h_stride, wht_w + k_lo * wht_h_stride, nullptr, ocb, - k_len, d_len, reduce_work, load_work); + k_len, d_len, reduce_work, load_work, ic_off, + post_ops_binary_rhs_arg_vec.data()); } diff_dst_w += diff_dst_c_stride; wht_w += wht_oc_stride; @@ -1151,13 +1189,13 @@ void jit_avx512_common_convolution_bwd_data_thas_default_values() && !has_zero_dim_memory(); + && !has_zero_dim_memory(); if (!ok) return status::unimplemented; status_t status = jit_avx512_common_conv_bwd_data_kernel_f32::init_conf( jcp_, *desc(), diff_src_md_, weights_md_, - diff_dst_md_, dnnl_get_max_threads()); + diff_dst_md_, dnnl_get_max_threads(), *attr()); if (status != status::success) return status; auto scratchpad = scratchpad_registry().registrar(); @@ -153,7 +153,7 @@ struct jit_avx512_common_convolution_bwd_data_t : public primitive_t { status_t init(engine_t *engine) override { CHECK(safe_ptr_assign(kernel_, - new jit_avx512_common_conv_bwd_data_kernel_f32(pd()->jcp_))); + new jit_avx512_common_conv_bwd_data_kernel_f32(pd()->jcp_, *pd()->attr()))); return kernel_->create_kernel(); } diff --git a/src/cpu/x64/jit_avx512_common_convolution_winograd.cpp b/src/cpu/x64/jit_avx512_common_convolution_winograd.cpp index 6794dd1f4be..32486a3b424 100644 --- a/src/cpu/x64/jit_avx512_common_convolution_winograd.cpp +++ b/src/cpu/x64/jit_avx512_common_convolution_winograd.cpp @@ -425,7 +425,7 @@ void trans_O_3x3_4x4_wu(float Mw[6][6][16][16], float M[3][3][16][16]) { template void input_transform_data(int image, const jit_conv_winograd_conf_t &jcp, - float *inp, float *tinp, bool streamout = true) { + float *inp, float *tinp, int MB, bool streamout = true) { const int inpw = is_fwd ? jcp.iw : jcp.ow; const int inph = is_fwd ? jcp.ih : jcp.oh; const int l_pad = is_fwd ? jcp.l_pad : jcp.iw + jcp.r_pad - jcp.ow; @@ -436,7 +436,7 @@ void input_transform_data(int image, const jit_conv_winograd_conf_t &jcp, alignas(64) float I[alpha][alpha][simd_w]; array_offset_calculator input( - inp, jcp.mb, jcp.dimK / simd_w, inph, inpw, simd_w); + inp, MB, jcp.dimK / simd_w, inph, inpw, simd_w); array_offset_calculator output(tinp, jcp.dimN_nb_block, alpha, alpha, jcp.dimN_block, jcp.dimK_nb_block, jcp.dimK_block, jcp.dimN_reg_block, jcp.dimK_reg_block); @@ -892,7 +892,7 @@ void diff_weights_transform_bwd_weights( template void _jit_avx512_common_convolution_winograd_t::_execute_data_W_S_G_D( float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr, - const memory_tracking::grantor_t &scratchpad) const { + const memory_tracking::grantor_t &scratchpad, int MB) const { const auto &jcp = kernel_->jcp; const int inph = is_fwd ? 
jcp.ih : jcp.oh; @@ -926,9 +926,9 @@ void _jit_avx512_common_convolution_winograd_t::_execute_data_W_S_G_D( BWD: dimM:ic, dimN:ntiles, dimK:oc, FWD/BWD: V: src/diff_dst transform, U:weight transform, M:dst/diff_src transform */ - array_offset_calculator input(inp_ptr, jcp.mb, + array_offset_calculator input(inp_ptr, MB, jcp.dimK / jcp.dimK_reg_block, inph, inpw, jcp.dimK_reg_block); - array_offset_calculator output(out_ptr, jcp.mb, + array_offset_calculator output(out_ptr, MB, jcp.dimM / jcp.dimM_simd_block, outh, outw, jcp.dimM_simd_block); array_offset_calculator weights(wei_ptr, jcp.oc / jcp.oc_simd_block, jcp.ic / jcp.ic_simd_block, jcp.kh, @@ -965,12 +965,12 @@ void _jit_avx512_common_convolution_winograd_t::_execute_data_W_S_G_D( last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc); } - parallel_nd(jcp.mb, jcp.dimK_nb_block, jcp.dimK_block, + parallel_nd(MB, jcp.dimK_nb_block, jcp.dimK_block, [&](dim_t img, dim_t K_blk1, dim_t K_blk2) { input_transform_data(img, jcp, &(input(img, K_blk1 * jcp.dimK_block + K_blk2, 0, 0, 0)), - &(V(0, 0, 0, 0, K_blk1, K_blk2, 0, 0)), V_streamout); + &(V(0, 0, 0, 0, K_blk1, K_blk2, 0, 0)), MB, V_streamout); }); parallel_nd(jcp.nb_oc, jcp.nb_ic, jcp.oc_block, jcp.ic_block, @@ -1002,7 +1002,7 @@ void _jit_avx512_common_convolution_winograd_t::_execute_data_W_S_G_D( } }); - parallel_nd(jcp.mb, jcp.dimM_nb_block, jcp.dimM_block, + parallel_nd(MB, jcp.dimM_nb_block, jcp.dimM_block, [&](dim_t img, dim_t M_blk1, dim_t M_blk2) { const dim_t M_blk = M_blk1 * jcp.dimM_block + M_blk2; diff --git a/src/cpu/x64/jit_avx512_common_convolution_winograd.hpp b/src/cpu/x64/jit_avx512_common_convolution_winograd.hpp index 2c4b1b698c7..36cf89e8fea 100644 --- a/src/cpu/x64/jit_avx512_common_convolution_winograd.hpp +++ b/src/cpu/x64/jit_avx512_common_convolution_winograd.hpp @@ -74,7 +74,7 @@ struct _jit_avx512_common_convolution_winograd_t { protected: void _execute_data_W_S_G_D(float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr, - const memory_tracking::grantor_t &scratchpad) const; + const memory_tracking::grantor_t &scratchpad, int MB) const; std::unique_ptr<_jit_avx512_common_conv_winograd_data_kernel_f32> kernel_; private: @@ -149,8 +149,11 @@ struct jit_avx512_common_convolution_winograd_fwd_t auto weights = CTX_IN_MEM(const float *, DNNL_ARG_WEIGHTS); auto bias = CTX_IN_MEM(const float *, DNNL_ARG_BIAS); auto dst = CTX_OUT_MEM(float *, DNNL_ARG_DST); + + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + this->_execute_data_W_S_G_D((float *)src, dst, (float *)weights, - (float *)bias, ctx.get_scratchpad_grantor()); + (float *)bias, ctx.get_scratchpad_grantor(), MB); return status::success; } @@ -224,8 +227,11 @@ struct jit_avx512_common_convolution_winograd_bwd_data_t auto diff_dst = CTX_IN_MEM(const float *, DNNL_ARG_DIFF_DST); auto weights = CTX_IN_MEM(const float *, DNNL_ARG_WEIGHTS); auto diff_src = CTX_OUT_MEM(float *, DNNL_ARG_DIFF_SRC); + + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + this->_execute_data_W_S_G_D((float *)diff_dst, diff_src, - (float *)weights, nullptr, ctx.get_scratchpad_grantor()); + (float *)weights, nullptr, ctx.get_scratchpad_grantor(), MB); return status::success; } diff --git a/src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.cpp index 16fadd82379..6d0af6bfc40 100644 --- a/src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.cpp @@ -42,7 +42,7 @@ jit_avx512_core_amx_1x1_fwd_kernel_t::jit_avx512_core_amx_1x1_fwd_kernel_t( : 
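
// [Editorial sketch] array_offset_calculator, used throughout the Winograd
// path above, is a row-major stride helper, which is why substituting the
// run-time MB for jcp.mb in its leading extent is safe: the leading dimension
// never enters the stride of any lower dimension. A 5-D model:
#include <cstddef>

static size_t offset5(size_t n, size_t c, size_t h, size_t w, size_t v,
        size_t C, size_t H, size_t W, size_t V /* extents below the batch */) {
    return (((n * C + c) * H + h) * W + w) * V + v; // MB itself never appears
}
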
jit_generator(nullptr, MAX_CODE_SIZE, true, avx512_core_amx) , jcp(ajcp) , attr_(attr) { - if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum || jcp.with_depthwise || jcp.with_quantization) { using namespace binary_injector; const auto &rhs_addr_reg = bin_injector_helper_reg_1; const auto &rhs_helper_reg = bin_injector_helper_reg_2; @@ -58,10 +58,12 @@ jit_avx512_core_amx_1x1_fwd_kernel_t::jit_avx512_core_amx_1x1_fwd_kernel_t( use_exact_tail_scalar_bcast}; const static_params_t static_params { this->param1, rhs_arg_static_params}; + quantization_injector::static_params_t quantization_static_params = + {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; postops_injector_ = utils::make_unique< injector::jit_uni_postops_injector_t>( - this, jcp.post_ops, static_params); + this, jcp.post_ops, static_params, quantization_static_params); } } @@ -224,22 +226,27 @@ void jit_avx512_core_amx_1x1_fwd_kernel_t::apply_sum(const Zmm &zmm_out, void jit_avx512_core_amx_1x1_fwd_kernel_t::apply_postops(const Zmm &zmm_out, const float *p_sum_scale, const int32_t *p_sum_zp, - const Xbyak::Address &addr, const size_t off, const bool mask_flag) { + const Xbyak::Address &addr, const size_t off, const bool mask_flag, const int ocb) { if (jcp.with_eltwise || jcp.with_binary - || (jcp.with_sum && p_sum_scale != nullptr)) { + || (jcp.with_sum && p_sum_scale != nullptr) || jcp.with_depthwise || jcp.with_quantization) { + std::map vmm_idx_off; + vmm_idx_off.insert({zmm_out.getIdx(), ocb * jcp.oc_block * sizeof(float)}); + depthwise_injector::dynamic_params_t ddp {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off, this->rsp}; + quantization_injector::dynamic_params_t qdp {ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off, jcp.dst_dt, this->rsp}; + + binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; + apply_sum(zmm_out, p_sum_scale, p_sum_zp, addr, mask_flag); const auto vmm_idx = zmm_out.getIdx(); if (jcp.with_binary) { - binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, out_ptr); rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(vmm_idx, off); if (mask_flag) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx); - - postops_injector_->compute_vector(vmm_idx, rhs_arg_params); - } else { - postops_injector_->compute_vector(vmm_idx); } + + postops_injector_->compute_vector_range({(size_t)vmm_idx}, rhs_arg_params, ddp, qdp); } } @@ -430,7 +437,7 @@ void jit_avx512_core_amx_1x1_fwd_kernel_t::store_output_vector_int8( vmulps(zmm_out_msk, zmm_out, EVEX_compress_addr(reg_ptr_scales, scale_offset)); - apply_postops(zmm_out, p_sum_scale, p_sum_zp, addr, off, mask_flag); + apply_postops(zmm_out, p_sum_scale, p_sum_zp, addr, off, mask_flag, ocb); if (jcp.dst_zero_point) { vaddps(zmm_out, zmm_out, zmm_dst_zp); } @@ -539,7 +546,7 @@ void jit_avx512_core_amx_1x1_fwd_kernel_t::store_output_vector_bf16( static constexpr auto skip_sum_in_injection = nullptr; apply_postops(zmm_out, skip_sum_in_injection, skip_sum_in_injection, addr, - off, mask_flag); + off, mask_flag, ocb); if (jcp.dst_dt == data_type::bf16) { Ymm ymm_out = Ymm(zmm_out.getIdx()); @@ -815,6 +822,9 @@ int jit_avx512_core_amx_1x1_fwd_kernel_t::get_ic_tail() const { void jit_avx512_core_amx_1x1_fwd_kernel_t::generate() { preamble(); + if (postops_injector_) + postops_injector_->push_post_ops_data_on_stack(param1, GET_OFF(post_ops_binary_rhs_arg_vec), 
inp_ptr, wei_ptr); + last_oc_block_flag_ = (jcp.oc_without_padding != jcp.oc); if (last_oc_block_flag_) { Xbyak::Label mask_is_set; @@ -864,6 +874,10 @@ void jit_avx512_core_amx_1x1_fwd_kernel_t::generate() { osb_loop(); L(label_done); + + if (postops_injector_) + postops_injector_->reset_stack_pointer(); + postamble(); if (jcp.with_eltwise) postops_injector_->prepare_table(); @@ -1100,6 +1114,12 @@ status_t jit_avx512_core_amx_1x1_fwd_kernel_t::init_conf(jit_conv_conf_t &jcp, jcp.with_binary = binary_ind != -1; jcp.sum_dt = p.get_sum_dt(jcp.dst_dt); + if (jcp.with_sum) + jcp.sum_dt = p.entry_[sum_ind].sum.dt; + + jcp.with_depthwise = p.find(primitive_kind::depthwise) != -1; + jcp.with_quantization = p.find(primitive_kind::quantization) != -1; + jcp.post_ops = p; jcp.is_fast_postops = is_fast_postops(jcp); @@ -1107,7 +1127,7 @@ status_t jit_avx512_core_amx_1x1_fwd_kernel_t::init_conf(jit_conv_conf_t &jcp, const bool sum_at_pos_0_only = (jcp.src_dt == data_type::bf16); const bool sum_requires_scale_one = sum_at_pos_0_only; const bool sum_requires_zp_zero = sum_at_pos_0_only; - const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum}, + const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum, depthwise, quantization}, jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one, sum_requires_zp_zero}); if (!post_ops_ok_) return status::unimplemented; diff --git a/src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.hpp b/src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.hpp index dbd226f1311..34de933664e 100644 --- a/src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.hpp +++ b/src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.hpp @@ -110,6 +110,12 @@ struct jit_avx512_core_amx_1x1_fwd_kernel_t : public jit_generator { const Xbyak::Opmask &ktail_mask = k2; + const Xbyak::Reg64 reg_d_weights = reg_last_h; + const Xbyak::Reg64 reg_d_bias = reg_oc_blocks; + + Xbyak::Zmm zmm_d_weights = Xbyak::Zmm(31); + Xbyak::Zmm zmm_d_bias = Xbyak::Zmm(30); + bool is_bf16() const; void init_runtime_counters(); @@ -131,7 +137,7 @@ struct jit_avx512_core_amx_1x1_fwd_kernel_t : public jit_generator { Xbyak::Zmm zmm_out(const int idx) { const int upper_limit = is_bf16() ? 
zmm_idx_limit_bf16 : zmm_idx_limit_int8; - assert(upper_limit > idx); +// assert(upper_limit > idx); MAYBE_UNUSED(upper_limit); return Xbyak::Zmm(idx); } @@ -147,7 +153,7 @@ struct jit_avx512_core_amx_1x1_fwd_kernel_t : public jit_generator { const bool mask_flag); void apply_postops(const Xbyak::Zmm &zmm_out, const float *p_sum_scale, const int32_t *p_sum_zp, const Xbyak::Address &addr, - const size_t off, const bool mask_flag); + const size_t off, const bool mask_flag, const int ocb); static bool is_fast_postops(const jit_conv_conf_t &jcp); void store_output_vectors_int8(int ocb, int osb); void store_output_vector_int8( diff --git a/src/cpu/x64/jit_avx512_core_amx_1x1_convolution.cpp b/src/cpu/x64/jit_avx512_core_amx_1x1_convolution.cpp index 2781efd2db4..4b8e242b8cf 100644 --- a/src/cpu/x64/jit_avx512_core_amx_1x1_convolution.cpp +++ b/src/cpu/x64/jit_avx512_core_amx_1x1_convolution.cpp @@ -72,6 +72,8 @@ status_t jit_avx512_core_amx_1x1_convolution_fwd_t::execute_forward( DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC); DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST); + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + const memory_desc_wrapper src_d(pd()->src_md()); const memory_desc_wrapper dst_d(pd()->dst_md()); const memory_desc_wrapper weights_d(pd()->weights_md(0)); @@ -120,7 +122,7 @@ status_t jit_avx512_core_amx_1x1_convolution_fwd_t::execute_forward( int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; const size_t work_amount - = (size_t)jcp.mb * jcp.ngroups * os_chunks * oc_chunks; + = (size_t)MB * jcp.ngroups * os_chunks * oc_chunks; kernel_->tile_configure(tcfg); parallel(0, [&](const int ithr, const int nthr) { @@ -134,7 +136,7 @@ status_t jit_avx512_core_amx_1x1_convolution_fwd_t::execute_forward( amx_tile_configure(tcfg); int mb {0}, g {0}, _osb {0}, _ocb {0}; - nd_iterator_init(start, mb, jcp.mb, g, jcp.ngroups, _osb, os_chunks, + nd_iterator_init(start, mb, MB, g, jcp.ngroups, _osb, os_chunks, _ocb, oc_chunks); while (start < end) { @@ -160,6 +162,7 @@ status_t jit_avx512_core_amx_1x1_convolution_fwd_t::execute_forward( p.dst_zero_point = jcp.dst_zero_point ? 
dst_zero_point : nullptr; p.oc_l_off = oc; + p.oc_off = oc * sizeof(float); p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); p.dst_orig = dst; @@ -207,7 +210,7 @@ status_t jit_avx512_core_amx_1x1_convolution_fwd_t::execute_forward( (*kernel_)(&p); } ++start; - nd_iterator_step(mb, jcp.mb, g, jcp.ngroups, _osb, os_chunks, _ocb, + nd_iterator_step(mb, MB, g, jcp.ngroups, _osb, os_chunks, _ocb, oc_chunks); } diff --git a/src/cpu/x64/jit_avx512_core_amx_1x1_convolution.hpp b/src/cpu/x64/jit_avx512_core_amx_1x1_convolution.hpp index ce6135a1fa8..a829e5f781e 100644 --- a/src/cpu/x64/jit_avx512_core_amx_1x1_convolution.hpp +++ b/src/cpu/x64/jit_avx512_core_amx_1x1_convolution.hpp @@ -51,7 +51,7 @@ struct jit_avx512_core_amx_1x1_convolution_fwd_t : public primitive_t { && utils::one_of(dst_md(0)->data_type, f32, bf16)) && IMPLICATION(with_bias(), utils::one_of(weights_md(1)->data_type, f32, bf16)) - && attr()->has_default_values(smask_t::post_ops); + && attr()->has_default_values(smask_t::post_ops, dst_md(0)->data_type); bool is_int8_convolution = utils::one_of(src_md(0)->data_type, s8, u8) && weights_md(0)->data_type == s8 diff --git a/src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp index 7a3930a6354..7fc554b7d75 100644 --- a/src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp @@ -1053,7 +1053,7 @@ jit_avx512_core_amx_fwd_kernel_t::jit_avx512_core_amx_fwd_kernel_t( : jit_generator(nullptr, MAX_CODE_SIZE, true, avx512_core_amx) , jcp(ajcp) , attr_(attr) { - if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum || jcp.with_depthwise || jcp.with_quantization) { using namespace binary_injector; const auto &rhs_addr_reg = bin_injector_helper_reg_1; const auto &rhs_helper_reg = bin_injector_helper_reg_2; @@ -1070,9 +1070,12 @@ jit_avx512_core_amx_fwd_kernel_t::jit_avx512_core_amx_fwd_kernel_t( const binary_injector::static_params_t static_params { this->param1, rhs_arg_static_params}; + quantization_injector::static_params_t quantization_static_params = + {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; + postops_injector_ = utils::make_unique< injector::jit_uni_postops_injector_t<avx512_core>>( - this, jcp.post_ops, static_params); + this, jcp.post_ops, static_params, quantization_static_params); } copy_to_pbuffer_ = utils::make_unique<jit_avx512_core_amx_copy_to_pbuffer_t>(jcp); @@ -1369,22 +1372,29 @@ void jit_avx512_core_amx_fwd_kernel_t::apply_sum(const Zmm &zmm_out, void jit_avx512_core_amx_fwd_kernel_t::apply_postops(const Zmm &zmm_out, const float *p_sum_scale, const int32_t *p_sum_zp, - const Xbyak::Address &addr, const size_t off, const bool mask_flag) { + const Xbyak::Address &addr, const size_t off, const bool mask_flag, const int ocb) { if (jcp.with_eltwise || jcp.with_binary - || (jcp.with_sum && p_sum_scale != nullptr)) { + || (jcp.with_sum && p_sum_scale != nullptr) || jcp.with_depthwise || jcp.with_quantization) { + std::map<size_t, int> vmm_idx_off; + vmm_idx_off.insert({zmm_out.getIdx(), ocb * jcp.oc_block * sizeof(float)}); + depthwise_injector::dynamic_params_t ddp {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off, + this->rsp, base_post_ops_data_offset}; + quantization_injector::dynamic_params_t qdp {ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off, jcp.dst_dt, + this->rsp, base_post_ops_data_offset}; + + binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; + 
apply_sum(zmm_out, p_sum_scale, p_sum_zp, addr, mask_flag); const auto vmm_idx = zmm_out.getIdx(); if (jcp.with_binary) { - binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, reg_out_ptr); rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(vmm_idx, off); if (mask_flag) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx); - - postops_injector_->compute_vector(vmm_idx, rhs_arg_params); - } else { - postops_injector_->compute_vector(vmm_idx); } + + postops_injector_->compute_vector_range({(size_t)vmm_idx}, rhs_arg_params, ddp, qdp); } } @@ -1422,7 +1432,7 @@ void jit_avx512_core_amx_fwd_kernel_t::store_output_vector_bf16( static constexpr auto skip_sum_injection = nullptr; apply_postops(zmm_out, skip_sum_injection, skip_sum_injection, addr, off, - mask_flag); + mask_flag, ocb); if (jcp.dst_dt == data_type::bf16) { Ymm ymm_out = Ymm(zmm_out.getIdx()); @@ -1493,7 +1503,7 @@ void jit_avx512_core_amx_fwd_kernel_t::store_output_vector_int8( vmulps(zmm_out_msk, zmm_out, EVEX_compress_addr(reg_ptr_scales, scale_offset)); - apply_postops(zmm_out, p_sum_scale, p_sum_zp, addr, off, mask_flag); + apply_postops(zmm_out, p_sum_scale, p_sum_zp, addr, off, mask_flag, ocb); if (jcp.dst_zero_point) { vaddps(zmm_out, zmm_out, zmm_dst_zp); } @@ -1736,6 +1746,7 @@ void jit_avx512_core_amx_fwd_kernel_t::compute_icb_loop(int width, push(reg_inp_ptr); push(reg_wei_ptr); + base_post_ops_data_offset += 2 * reg64_size; for (int ireduce = 0; ireduce < nreduce; ireduce += stride) { for (int ohb = 0; ohb < jcp.nb_oh_blocking; ohb++) { @@ -1761,6 +1772,7 @@ void jit_avx512_core_amx_fwd_kernel_t::compute_icb_loop(int width, } pop(reg_wei_ptr); pop(reg_inp_ptr); + base_post_ops_data_offset -= 2 * reg64_size; store_output(width, tail, do_store, handle_h_blk, t_pad_output, b_pad_output, l_pad_output, r_pad_output, is_last_oh_block); @@ -1804,6 +1816,7 @@ void jit_avx512_core_amx_fwd_kernel_t::compute_icb_loop(int width, dec(reg_kd); jl(kd_skip_compute, T_NEAR); push(reg_kd); + base_post_ops_data_offset += reg64_size; } for (int kh = 0; kh < jcp.kh; kh++) { for (int set_idx = 0; set_idx < jcp.n_stride_sets; @@ -1831,7 +1844,10 @@ void jit_avx512_core_amx_fwd_kernel_t::compute_icb_loop(int width, } } } - if (check_kd_padding) pop(reg_kd); + if (check_kd_padding) { + pop(reg_kd); + base_post_ops_data_offset -= reg64_size; + } } L(kd_skip_compute); } @@ -2028,6 +2044,9 @@ void jit_avx512_core_amx_fwd_kernel_t::compute_ow_loop() { void jit_avx512_core_amx_fwd_kernel_t::generate() { preamble(); + if (postops_injector_) + postops_injector_->push_post_ops_data_on_stack(param1, GET_OFF(post_ops_binary_rhs_arg_vec), reg_inp_ptr, reg_wei_ptr); + mov(reg_inp_ptr, ptr[param1 + GET_OFF(src)]); mov(reg_wei_ptr, ptr[param1 + GET_OFF(filt)]); mov(reg_out_ptr, ptr[param1 + GET_OFF(dst)]); @@ -2069,6 +2088,9 @@ void jit_avx512_core_amx_fwd_kernel_t::generate() { } compute_ow_loop(); + if (postops_injector_) + postops_injector_->reset_stack_pointer(); + postamble(); if (jcp.with_eltwise) postops_injector_->prepare_table(); @@ -2420,7 +2442,11 @@ status_t jit_avx512_core_amx_fwd_kernel_t::init_conf(jit_conv_conf_t &jcp, jcp.with_eltwise = eltwise_ind != -1; const int binary_ind = p.find(primitive_kind::binary); jcp.with_binary = binary_ind != -1; - jcp.sum_dt = p.get_sum_dt(jcp.dst_dt); + if (jcp.with_sum) + jcp.sum_dt = p.entry_[sum_ind].sum.dt; + + jcp.with_depthwise = p.find(primitive_kind::depthwise) != -1; + jcp.with_quantization = p.find(primitive_kind::quantization) != -1; jcp.post_ops = 
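A recurring detail in the hunks above is the base_post_ops_data_offset counter: generate() parks the post-op data near the stack top, and every later push()/pop() (or explicit sub/add of rsp) must adjust the counter so the injector can still form an rsp-relative address for that data. A toy model of the invariant, with plain integers standing in for the JIT state:

    #include <cassert>
    #include <vector>

    struct stack_model {
        std::vector<long> slots; // grows like the real stack
        int base_post_ops_data_offset = 0;
        static constexpr int reg64_size = 8;

        // Each push moves rsp by 8 bytes, so the displacement to the spilled
        // post-op data grows by the same amount; pop undoes it.
        void push(long reg) {
            slots.push_back(reg);
            base_post_ops_data_offset += reg64_size;
        }
        void pop() {
            slots.pop_back();
            base_post_ops_data_offset -= reg64_size;
        }
        // The injector would address its data at
        // [rsp + base_post_ops_data_offset + i * reg64_size].
    };

    int main() {
        stack_model s;
        s.push(1); s.push(2); // push(reg_inp_ptr); push(reg_wei_ptr);
        assert(s.base_post_ops_data_offset == 2 * stack_model::reg64_size);
        s.pop(); s.pop();
        assert(s.base_post_ops_data_offset == 0); // balanced on every path
        return 0;
    }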
p; @@ -2428,7 +2454,7 @@ status_t jit_avx512_core_amx_fwd_kernel_t::init_conf(jit_conv_conf_t &jcp, const bool sum_at_pos_0_only = (jcp.src_dt == data_type::bf16); const bool sum_requires_scale_one = sum_at_pos_0_only; const bool sum_requires_zp_zero = sum_at_pos_0_only; - const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum}, + const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum, depthwise, quantization}, jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one, sum_requires_zp_zero}); if (!post_ops_ok_) return status::unimplemented; @@ -2480,9 +2506,9 @@ status_t jit_avx512_core_amx_fwd_kernel_t::init_conf(jit_conv_conf_t &jcp, jcp.nb_oc_blocking_thr_chunk = 1; - const int max_palette = amx::get_max_palette(); - jcp.max_tiles = amx::get_max_tiles(max_palette); - jcp.full_tile_width = amx::get_max_rows(max_palette); +// const int max_palette = amx::get_max_palette(); + jcp.max_tiles = 8;//amx::get_max_tiles(max_palette); + jcp.full_tile_width = 16;//amx::get_max_rows(max_palette); if (jcp.max_tiles != 8 || jcp.full_tile_width != 16) return status::unimplemented; diff --git a/src/cpu/x64/jit_avx512_core_amx_conv_kernel.hpp b/src/cpu/x64/jit_avx512_core_amx_conv_kernel.hpp index ae17dead16a..829fdfa48e9 100644 --- a/src/cpu/x64/jit_avx512_core_amx_conv_kernel.hpp +++ b/src/cpu/x64/jit_avx512_core_amx_conv_kernel.hpp @@ -339,6 +339,14 @@ struct jit_avx512_core_amx_fwd_kernel_t : public jit_generator { const Xbyak::Reg64 &bin_injector_helper_reg_1 = r14; const Xbyak::Reg64 &bin_injector_helper_reg_2 = r15; + const Xbyak::Reg64 reg_d_weights = reg_zp_compensation; + const Xbyak::Reg64 reg_d_bias = reg_src_zero_point; + int base_post_ops_data_offset = 0; + constexpr static int reg64_size = 8; + + const Xbyak::Zmm zmm_d_weights = Xbyak::Zmm(31); + const Xbyak::Zmm zmm_d_bias = Xbyak::Zmm(30); + // AUX: Steps, shifts and offsets size_t get_inp_icb_step() const; size_t get_wei_icb_step() const; @@ -387,7 +395,7 @@ struct jit_avx512_core_amx_fwd_kernel_t : public jit_generator { const bool mask_flag); void apply_postops(const Xbyak::Zmm &zmm_out, const float *p_sum_scale, const int32_t *p_sum_zp, const Xbyak::Address &addr, - const size_t off, const bool mask_flag); + const size_t off, const bool mask_flag, const int ocb); void store_output_vector_bf16( const Xbyak::Zmm &zmm_out, int ocb, int h, int w); void store_output_vector_int8(const Xbyak::Zmm &zmm_out, int ocb, int h, diff --git a/src/cpu/x64/jit_avx512_core_amx_convolution.cpp b/src/cpu/x64/jit_avx512_core_amx_convolution.cpp index 7aa8c89b911..da170985e43 100644 --- a/src/cpu/x64/jit_avx512_core_amx_convolution.cpp +++ b/src/cpu/x64/jit_avx512_core_amx_convolution.cpp @@ -79,6 +79,8 @@ jit_avx512_core_amx_convolution_fwd_t::execute_forward_reduced_lowering( DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC); DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST); + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + const memory_desc_wrapper src_d(pd()->src_md()); const memory_desc_wrapper dst_d(pd()->dst_md()); const memory_desc_wrapper weights_d(pd()->weights_md(0)); @@ -126,7 +128,7 @@ jit_avx512_core_amx_convolution_fwd_t::execute_forward_reduced_lowering( const int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; const int oh_chunks = utils::div_up(jcp.oh, jcp.oh_blk_size); const int work_amount - = jcp.mb * jcp.ngroups * oh_chunks * jcp.nb_ow * oc_chunks; + = MB * jcp.ngroups * oh_chunks * jcp.nb_ow * oc_chunks; const int zp_pbuff_size = jcp.zp_pbuff_size; // reorder weights from (g)Owhi16o to 
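The init_conf() hunk above replaces the amx::get_max_palette()/get_max_tiles()/get_max_rows() queries with the constants 8 and 16, which makes the following unimplemented guard dead code. The hardcode assumes the palette-1 geometry (8 tiles of 16 rows) that current AMX hardware reports; a sketch of keeping that assumption explicit rather than implicit (names and the conf struct are illustrative):

    #include <cstdlib>

    struct conf { int max_tiles; int full_tile_width; };

    int main() {
        conf jcp {8, 16}; // assumed palette-1 geometry: 8 tiles, 16 rows each
        const bool palette_ok
                = jcp.max_tiles == 8 && jcp.full_tile_width == 16;
        return palette_ok ? EXIT_SUCCESS : EXIT_FAILURE; // unimplemented otherwise
    }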
(g)OR16r16o4r, where r := whi @@ -224,7 +226,7 @@ jit_avx512_core_amx_convolution_fwd_t::execute_forward_reduced_lowering( int mb {0}, g {0}, ohc {0}, owb {0}, occ {0}; // need "inner" oh blocks w.r.t. ow blocks to allow pbuffer reuse - nd_iterator_init(start, mb, jcp.mb, g, jcp.ngroups, owb, jcp.nb_ow, ohc, + nd_iterator_init(start, mb, MB, g, jcp.ngroups, owb, jcp.nb_ow, ohc, oh_chunks, occ, oc_chunks); int last_copied_mb = -1; int last_copied_ohc = -1; @@ -387,6 +389,7 @@ jit_avx512_core_amx_convolution_fwd_t::execute_forward_reduced_lowering( p.oc_blocks = occ * jcp.nb_oc_blocking; p.oc_l_off = oc; + p.oc_off = oc * sizeof(float); p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); p.dst_orig = dst; @@ -406,7 +409,7 @@ jit_avx512_core_amx_convolution_fwd_t::execute_forward_reduced_lowering( last_copied_g = g; ++start; // need "inner" oh blocks w.r.t. ow blocks to allow pbuffer reuse - nd_iterator_step(mb, jcp.mb, g, jcp.ngroups, owb, jcp.nb_ow, ohc, + nd_iterator_step(mb, MB, g, jcp.ngroups, owb, jcp.nb_ow, ohc, oh_chunks, occ, oc_chunks); } @@ -424,6 +427,8 @@ status_t jit_avx512_core_amx_convolution_fwd_t::execute_forward( const auto post_ops_binary_rhs_arg_vec = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx); + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + const memory_desc_wrapper src_d(pd()->src_md()); const memory_desc_wrapper dst_d(pd()->dst_md()); const memory_desc_wrapper weights_d(pd()->weights_md(0)); @@ -488,7 +493,7 @@ status_t jit_avx512_core_amx_convolution_fwd_t::execute_forward( const int ngroups = jcp.ngroups; const int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; const int oh_chunks = utils::div_up(jcp.oh, jcp.oh_blk_size); - const size_t work_amount = (size_t)jcp.mb * jcp.ngroups * jcp.od * oh_chunks + const size_t work_amount = (size_t)MB * jcp.ngroups * jcp.od * oh_chunks * jcp.nb_ow * oc_chunks; const int zp_pbuff_size = jcp.zp_pbuff_size; @@ -588,7 +593,7 @@ status_t jit_avx512_core_amx_convolution_fwd_t::execute_forward( const int owb_limit = jcp.nb_ow - jcp.r_pad_blk - jcp.no_pad_w_blk; int mb {0}, g {0}, odc {0}, ohc {0}, owb {0}, occ {0}; - nd_iterator_init(start, mb, jcp.mb, g, jcp.ngroups, odc, jcp.od, ohc, + nd_iterator_init(start, mb, MB, g, jcp.ngroups, odc, jcp.od, ohc, oh_chunks, owb, jcp.nb_ow, occ, oc_chunks); int last_copied_mb = -1; int last_copied_odc = -1; @@ -772,6 +777,7 @@ status_t jit_avx512_core_amx_convolution_fwd_t::execute_forward( p.oc_blocks = occ * jcp.nb_oc_blocking; p.oc_l_off = oc; + p.oc_off = oc * sizeof(float); p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); p.dst_orig = dst; @@ -791,7 +797,7 @@ status_t jit_avx512_core_amx_convolution_fwd_t::execute_forward( last_copied_owb = owb; last_copied_g = g; ++start; - nd_iterator_step(mb, jcp.mb, g, jcp.ngroups, odc, jcp.od, ohc, + nd_iterator_step(mb, MB, g, jcp.ngroups, odc, jcp.od, ohc, oh_chunks, owb, jcp.nb_ow, occ, oc_chunks); } diff --git a/src/cpu/x64/jit_avx512_core_amx_convolution.hpp b/src/cpu/x64/jit_avx512_core_amx_convolution.hpp index 2e39bb6779d..959402fff1e 100644 --- a/src/cpu/x64/jit_avx512_core_amx_convolution.hpp +++ b/src/cpu/x64/jit_avx512_core_amx_convolution.hpp @@ -54,7 +54,7 @@ struct jit_avx512_core_amx_convolution_fwd_t : public primitive_t { && utils::one_of(dst_md(0)->data_type, f32, bf16)) && IMPLICATION(with_bias(), utils::one_of(weights_md(1)->data_type, f32, bf16)) - && attr()->has_default_values(smask_t::post_ops); + && attr()->has_default_values(smask_t::post_ops, dst_md(0)->data_type); bool 
is_int8_convolution = utils::one_of(src_md(0)->data_type, s8, u8) && weights_md(0)->data_type == s8 diff --git a/src/cpu/x64/jit_avx512_core_bf16_1x1_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_bf16_1x1_conv_kernel.cpp index ead9014acb9..df935178a44 100644 --- a/src/cpu/x64/jit_avx512_core_bf16_1x1_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_core_bf16_1x1_conv_kernel.cpp @@ -47,7 +47,7 @@ jit_avx512_core_bf16_1x1_conv_kernel::jit_avx512_core_bf16_1x1_conv_kernel( : jit_generator(nullptr, ker_code_size, true, avx512_core_bf16) , jcp(ajcp) , attr_(attr) { - if (jcp.with_eltwise || jcp.with_binary) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { using namespace binary_injector; static constexpr bool preserve_gpr = true; static constexpr bool preserve_vmm = false; @@ -62,10 +62,12 @@ jit_avx512_core_bf16_1x1_conv_kernel::jit_avx512_core_bf16_1x1_conv_kernel( use_exact_tail_scalar_bcast}; const static_params_t static_params { this->param1, rhs_arg_static_params}; + quantization_injector::static_params_t quantization_static_params + {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; postops_injector_ = utils::make_unique< injector::jit_uni_postops_injector_t<avx512_core>>( - this, jcp.post_ops, static_params); + this, jcp.post_ops, static_params, quantization_static_params); } if (!isa_has_bf16(jcp.isa)) @@ -184,7 +186,17 @@ static void iterate(const int load_loop_blk, const int ur, const F &f) { void jit_avx512_core_bf16_1x1_conv_kernel::apply_postops( const int load_loop_blk, const int ur) { - if (jcp.with_eltwise || jcp.with_binary) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { + std::map<size_t, int> vmm_idx_off; + iterate(load_loop_blk, ur, + [&](const bool, const int i_load, const int i_ur) { + vmm_idx_off.insert({vreg_accum_idx(load_loop_blk, i_load, i_ur), i_load * jcp.oc_block * sizeof(float)}); + }); + + depthwise_injector::dynamic_params_t ddp {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + reg_oc_off, vmm_idx_off, this->rsp, base_post_ops_data_offset}; + quantization_injector::dynamic_params_t qdp {reg_oc_off, vmm_idx_off, jcp.dst_dt, this->rsp, base_post_ops_data_offset}; + injector_utils::vmm_index_set_t vmm_idxs; if (jcp.with_binary) { binary_injector::rhs_arg_dynamic_params_t rhs_arg_params, @@ -231,7 +243,7 @@ void jit_avx512_core_bf16_1x1_conv_kernel::apply_postops( jmp(postops_done, T_NEAR); L(postops_no_tail); } - postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params); + postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params, ddp, qdp); L(postops_done); } else { @@ -240,7 +252,7 @@ void jit_avx512_core_bf16_1x1_conv_kernel::apply_postops( iterate(load_loop_blk, ur, [&](const bool, const int i_load, const int i_ur) { vmm_idxs.emplace( vreg_accum_idx(load_loop_blk, i_load, i_ur)); }); - postops_injector_->compute_vector_range(vmm_idxs); + postops_injector_->compute_vector_range(vmm_idxs, binary_injector::rhs_arg_dynamic_params_t(), ddp, qdp); } } } @@ -846,6 +858,8 @@ void jit_avx512_core_bf16_1x1_conv_kernel::reduce_loop( mov(aux_reg_bcast_data, aux1_reg_bcast_data); init(); + push(reg_oc_off); + mov(reduce_loop_iter, reg_reduce_loop_work); Label reduce_loop_exit; cmp(reduce_loop_iter, jcp.reduce_loop_unroll); @@ -867,6 +881,9 @@ void jit_avx512_core_bf16_1x1_conv_kernel::reduce_loop( fma_block(true); L(reduce_loop_exit); + + pop(reg_oc_off); + store(); } @@ -982,7 +999,12 @@ void jit_avx512_core_bf16_1x1_conv_kernel::compute_diff_bias( void jit_avx512_core_bf16_1x1_conv_kernel::generate() { preamble(); + if 
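apply_postops() above builds a vmm_idx_off map that tags each accumulator register with the byte offset of the output channels it holds; the depthwise and quantization steps use it to fetch the matching per-channel weights. A standalone model of populating that map from the load/unroll tiling (register numbering mirrors the kernel's vreg_accum_idx(), block sizes are illustrative):

    #include <cstdio>
    #include <map>

    int main() {
        const int load_loop_blk = 2, ur = 3, oc_block = 16;
        // Mirrors the kernel's accumulator numbering scheme.
        auto vreg_accum_idx = [&](int i_load, int i_ur) {
            return i_ur * load_loop_blk + i_load;
        };
        std::map<size_t, int> vmm_idx_off;
        for (int i_load = 0; i_load < load_loop_blk; i_load++)
            for (int i_ur = 0; i_ur < ur; i_ur++)
                // Each load block covers oc_block f32 channels.
                vmm_idx_off.insert({(size_t)vreg_accum_idx(i_load, i_ur),
                        i_load * oc_block * (int)sizeof(float)});
        for (auto &kv : vmm_idx_off)
            std::printf("vmm%zu -> +%d bytes\n", kv.first, kv.second);
        return 0;
    }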
(postops_injector_) + postops_injector_->push_post_ops_data_on_stack(this->param1, GET_OFF(post_ops_binary_rhs_arg_vec), reg_bcast_data, reg_load_data); + sub(rsp, stack_space_needed); + base_post_ops_data_offset += stack_space_needed; + if (jcp.with_binary) { const auto zeroed_reg = r15; xor_(zeroed_reg, zeroed_reg); @@ -1038,6 +1060,7 @@ void jit_avx512_core_bf16_1x1_conv_kernel::generate() { mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]); } + mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]); auto load_loop_body = [=](int load_loop_blk) { Label no_update_mask, update_mask_done; if (load_dim_tail) { @@ -1059,6 +1082,7 @@ void jit_avx512_core_bf16_1x1_conv_kernel::generate() { mov(reg_load_loop_work, ptr[rsp + reg_load_loop_work_off]); add(reg_load_data, load_loop_blk * jcp.load_loop_load_step); + add(reg_oc_off, load_loop_blk * jcp.oc_block * sizeof(float)); switch (jcp.prop_kind) { case forward_training: case forward_inference: @@ -1159,6 +1183,10 @@ void jit_avx512_core_bf16_1x1_conv_kernel::generate() { L(load_loop_blk[num_ur_cases]); add(rsp, stack_space_needed); + base_post_ops_data_offset -= stack_space_needed; + + if (postops_injector_) + postops_injector_->reset_stack_pointer(); postamble(); @@ -1247,6 +1275,8 @@ status_t jit_avx512_core_bf16_1x1_conv_kernel::init_conf( const int binary_ind = post_ops.find(primitive_kind::binary, 0, dw_conv_ind); jcp.with_binary = binary_ind != -1; + jcp.with_depthwise = post_ops.find(primitive_kind::depthwise, 0, dw_conv_ind) != -1; + jcp.with_quantization = post_ops.find(primitive_kind::quantization, 0, dw_conv_ind) != -1; if (dw_conv_ind >= 0) { // dw_conv and post_ops after it are handled externally, so skip them @@ -1260,7 +1290,7 @@ status_t jit_avx512_core_bf16_1x1_conv_kernel::init_conf( static constexpr bool sum_at_pos_0_only = true; static constexpr bool sum_requires_scale_one = true; static constexpr bool sum_requires_zp_zero = true; - const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum}, + const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum, depthwise, quantization}, jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one, sum_requires_zp_zero}); if (!post_ops_ok_) return status::unimplemented; @@ -1268,8 +1298,8 @@ status_t jit_avx512_core_bf16_1x1_conv_kernel::init_conf( using namespace format_tag; const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc); const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); - jcp.src_tag = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c); - jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c); + jcp.src_tag = src_d.mb_stride_relaxed_match(dat_tag_nxc, dat_tag_nCx16c); + jcp.dst_tag = dst_d.mb_stride_relaxed_match(dat_tag_nxc, dat_tag_nCx16c); bool is_data_layout_nxc = utils::everyone_is(dat_tag_nxc, jcp.src_tag, jcp.dst_tag); auto required_dat_tag = is_data_layout_nxc ? 
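The generate() hunk above loads reg_oc_off from the new oc_off kernel argument and then advances it by one load block per load-loop iteration. The offset is kept in float bytes even for bf16 destinations, because the per-channel post-op tables are stored as f32. The arithmetic, modeled on plain integers:

    #include <cassert>
    #include <cstddef>

    int main() {
        const int oc_block = 16, load_loop_blk = 2, n_iters = 3;
        std::size_t reg_oc_off = 0; // mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)])
        for (int i = 0; i < n_iters; i++)
            // add(reg_oc_off, load_loop_blk * jcp.oc_block * sizeof(float))
            reg_oc_off += load_loop_blk * oc_block * sizeof(float);
        assert(reg_oc_off
                == std::size_t(n_iters) * load_loop_blk * oc_block * 4);
        return 0;
    }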
dat_tag_nxc : dat_tag_nCx16c; diff --git a/src/cpu/x64/jit_avx512_core_bf16_1x1_conv_kernel.hpp b/src/cpu/x64/jit_avx512_core_bf16_1x1_conv_kernel.hpp index 5edece3df90..c17f5149d25 100644 --- a/src/cpu/x64/jit_avx512_core_bf16_1x1_conv_kernel.hpp +++ b/src/cpu/x64/jit_avx512_core_bf16_1x1_conv_kernel.hpp @@ -111,6 +111,13 @@ struct jit_avx512_core_bf16_1x1_conv_kernel : public jit_generator { Xbyak::Opmask half_mask = Xbyak::Opmask(6); Xbyak::Opmask half_mask_hi = Xbyak::Opmask(5); Xbyak::Label dst_prm_table; + reg64_t reg_oc_off = abi_param1; + reg64_t reg_d_weights = imm_addr64; + reg64_t reg_d_bias = aux_reg_bcast_data; + int base_post_ops_data_offset = 0; + + Xbyak::Zmm zmm_d_weights = Xbyak::Zmm(31); + Xbyak::Zmm zmm_d_bias = Xbyak::Zmm(30); constexpr static int reg64_size_ = sizeof(int64_t); constexpr static int bcast_loop_work_offt = 0; diff --git a/src/cpu/x64/jit_avx512_core_bf16_1x1_convolution.cpp b/src/cpu/x64/jit_avx512_core_bf16_1x1_convolution.cpp index 002fdd61bd5..641e55f0872 100644 --- a/src/cpu/x64/jit_avx512_core_bf16_1x1_convolution.cpp +++ b/src/cpu/x64/jit_avx512_core_bf16_1x1_convolution.cpp @@ -80,6 +80,8 @@ void jit_avx512_core_bf16_1x1_convolution_fwd_t::execute_forward( pd()->jcp_.post_ops.entry_.size() + 1) : std::vector {}; + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + auto scratchpad = ctx.get_scratchpad_grantor(); const auto &jcp = kernel_->jcp; @@ -115,7 +117,7 @@ void jit_avx512_core_bf16_1x1_convolution_fwd_t::execute_forward( parallel(jcp.nthr, [&](const int ithr, const int nthr) { execute_forward_thr(ithr, nthr, src, weights, bias, weights_dw, bias_dw, dst, scratchpad, post_ops_binary_rhs_arg_vec.data(), - post_ops_binary_rhs_arg_vec_dw.data()); + post_ops_binary_rhs_arg_vec_dw.data(), MB); }); if (pd()->wants_zero_pad_dst()) ctx.zero_pad_output(DNNL_ARG_DST); @@ -128,7 +130,7 @@ void jit_avx512_core_bf16_1x1_convolution_fwd_t::execute_forward_thr( const dw_wei_data_t *weights_dw, const float *bias_dw, const char *dst, const memory_tracking::grantor_t &scratchpad, const void *post_ops_binary_rhs_arg_vec, - const void *post_ops_binary_rhs_arg_vec_dw) const { + const void *post_ops_binary_rhs_arg_vec_dw, int MB) const { const memory_desc_wrapper src_d(pd()->src_md()); const memory_desc_wrapper dst_d(pd()->dst_md()); const memory_desc_wrapper weights_d(pd()->weights_md(0)); @@ -174,6 +176,8 @@ void jit_avx512_core_bf16_1x1_convolution_fwd_t::execute_forward_thr( const int nb_buffer = jcp.nb_load_blocking; std::vector addrs; + auto start_off = dst_d.off_l(0); + auto step = [](int default_step, int remaining, int tail_step) { assert(default_step <= tail_step); return remaining < tail_step ? 
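The start_off = dst_d.off_l(0) captured above is subtracted from data_blk_off() just below, turning an absolute element offset (which includes the descriptor's base) into a pure stride; the same pattern recurs later in this patch for the src/dst h- and d-strides (blk_off(...) - off_l(0)). A simplified two-dimensional descriptor shows the idea; the real memory_desc_wrapper is considerably richer:

    #include <cassert>
    #include <cstddef>

    struct desc2d {
        std::size_t base, stride0, stride1;
        // Absolute offset of element (i0, i1), base included.
        std::size_t blk_off(std::size_t i0, std::size_t i1) const {
            return base + i0 * stride0 + i1 * stride1;
        }
        std::size_t off_l0() const { return base; } // offset of element 0
    };

    int main() {
        desc2d dst {128, 1024, 16};
        // Subtracting off_l0() strips the base: a stride valid for any view.
        std::size_t h_stride = dst.blk_off(0, 1) - dst.off_l0();
        assert(h_stride == 16);
        return 0;
    }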
remaining : default_step; @@ -183,7 +187,7 @@ void jit_avx512_core_bf16_1x1_convolution_fwd_t::execute_forward_thr( int &bcast_step, int &od, int &oh, int &ow, int &id, int &ih, int &iw) { int osb {0}; - nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, nb_bcast); + nd_iterator_init(iwork, n, MB, g, jcp.ngroups, osb, nb_bcast); bcast_step = step( nb_bcast_blocking, nb_bcast - osb, nb_bcast_blocking_max); bcast_step = nstl::min(bcast_step, bcast_end - iwork); @@ -266,12 +270,13 @@ void jit_avx512_core_bf16_1x1_convolution_fwd_t::execute_forward_thr( : rnd_up((jcp.load_dim / grp_count), jcp.load_block); const size_t str_size = jcp.bcast_dim * max_load_per_thread; p.store_buffer = store_buffer + ithr * str_size - + data_blk_off(dst_d, 0, 0, od, oh, ow); + + data_blk_off(dst_d, 0, 0, od, oh, ow) - start_off; p.dst_l_off = dst_off; p.oc_l_off = oc_off_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block); p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec; p.dst_orig = dst; + p.oc_off = oc_off_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block) * sizeof(float); (*kernel_)(&p); }; @@ -376,6 +381,8 @@ void jit_avx512_core_bf16_1x1_convolution_fwd_t::execute_forward_thr( = post_ops_binary_rhs_arg_vec_dw; par_conv_dw.dst_orig = dst; + par_conv_dw.oc_off = ch * jcp_dw->ch_block * sizeof(float); + (*kernel_dw_)(&par_conv_dw); for (int i = 0; i < jcp_dw->kh; ++i) @@ -397,7 +404,7 @@ void jit_avx512_core_bf16_1x1_convolution_fwd_t::execute_forward_thr( addrs.resize(jcp_dw->kh); int bcast_start {0}, bcast_end {0}, ocb_start, ocb_end; - balance2D(nthr, ithr, jcp.mb * jcp.ngroups * jcp_dw->oh, bcast_start, + balance2D(nthr, ithr, MB * jcp.ngroups * jcp_dw->oh, bcast_start, bcast_end, nb_oc, ocb_start, ocb_end, jcp.load_grp_count); while (ocb_start < ocb_end) { @@ -408,7 +415,7 @@ void jit_avx512_core_bf16_1x1_convolution_fwd_t::execute_forward_thr( auto bcast_iter = bcast_start; while (bcast_iter < bcast_end) { int n {0}, g {0}, oh_dw {0}; - nd_iterator_init(bcast_iter, n, jcp.mb, g, jcp.ngroups, oh_dw, + nd_iterator_init(bcast_iter, n, MB, g, jcp.ngroups, oh_dw, jcp_dw->oh); if (oh_dw == 0) oh_1x1 = 0; // Reset over mb boundary const int oh_1x1_range @@ -438,7 +445,7 @@ void jit_avx512_core_bf16_1x1_convolution_fwd_t::execute_forward_thr( if (jcp.with_dw_conv) { conv_dw(); } else { - const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + const int work_amount = MB * jcp.ngroups * jcp.nb_bcast; int bcast_start {0}, bcast_end {0}, ocb_start {0}, ocb_end {0}; balance2D(nthr, ithr, work_amount, bcast_start, bcast_end, jcp.nb_load, ocb_start, ocb_end, jcp.load_grp_count); @@ -458,12 +465,18 @@ void jit_avx512_core_bf16_1x1_convolution_bwd_data_t< auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST); auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS); auto diff_src = CTX_OUT_MEM(diff_src_data_t *, DNNL_ARG_DIFF_SRC); + + const auto post_ops_binary_rhs_arg_vec + = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx); + + auto MB = CTX_IN_BATCH(DNNL_ARG_DIFF_DST); + auto scratchpad = ctx.get_scratchpad_grantor(); const auto &jcp = kernel_->jcp; parallel(jcp.nthr, [&](const int ithr, const int nthr) { assert(nthr == jcp.nthr); execute_backward_data_thr( - ithr, nthr, diff_dst, weights, diff_src, scratchpad); + ithr, nthr, diff_dst, weights, diff_src, scratchpad, MB, post_ops_binary_rhs_arg_vec.data()); }); } @@ -472,7 +485,8 @@ void jit_avx512_core_bf16_1x1_convolution_bwd_data_t< diff_src_type>::execute_backward_data_thr(const int ithr, const int nthr, const 
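The drivers above all compute p.oc_off as oc_idx * (is_dst_layout_nxc ? 1 : oc_block) * sizeof(float): for blocked layouts the kernel-visible channel index counts blocks of oc_block channels, while for nxc it counts single channels, so the byte offset must scale accordingly. A quick check of both cases (sizes illustrative):

    #include <cassert>
    #include <cstddef>

    int main() {
        const int oc_block = 16, oc_idx = 3;
        for (bool is_nxc : {false, true}) {
            std::size_t oc_off = std::size_t(oc_idx)
                    * (is_nxc ? 1 : oc_block) * sizeof(float);
            // 3 channels * 4 bytes for nxc, 3 blocks * 16 * 4 bytes otherwise
            assert(oc_off == std::size_t(oc_idx) * (is_nxc ? 4 : 64));
        }
        return 0;
    }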
diff_dst_data_t *diff_dst, const wei_data_t *weights, diff_src_data_t *diff_src, - const memory_tracking::grantor_t &scratchpad) const { + const memory_tracking::grantor_t &scratchpad, int MB, + const void *post_ops_binary_rhs_arg_vec) const { const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); const memory_desc_wrapper weights_d(pd()->weights_md(0)); @@ -489,7 +503,7 @@ void jit_avx512_core_bf16_1x1_convolution_bwd_data_t< const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[ndims - 4]; const int stride_w = pd()->desc()->strides[ndims - 3]; - const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + const int work_amount = MB * jcp.ngroups * jcp.nb_bcast; auto step = [](int default_step, int remaining, int tail_step) { assert(default_step <= tail_step); @@ -511,7 +525,7 @@ void jit_avx512_core_bf16_1x1_convolution_bwd_data_t< auto init_bcast = [&](int iwork, int &n, int &g, int &bcast_step, int &od, int &oh, int &ow, int &id, int &ih, int &iw) { int osb {0}; - nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, jcp.nb_bcast); + nd_iterator_init(iwork, n, MB, g, jcp.ngroups, osb, jcp.nb_bcast); bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb, jcp.nb_bcast_blocking_max); bcast_step = nstl::min(bcast_step, bcast_end - iwork); @@ -585,6 +599,10 @@ void jit_avx512_core_bf16_1x1_convolution_bwd_data_t< const size_t str_size = jcp.bcast_dim * max_load_per_thread; p.store_buffer = store_buffer + ithr * str_size + data_blk_off(diff_src_d, 0, 0, id, ih, iw); + + p.oc_off = ic_off_idx * (is_dsrc_layout_nxc ? 1 : jcp.ic_block) * sizeof(float); + p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec; + (*kernel_)(&p); if (pd()->rtus_.reduce_src_) (*rtus_driver_)(&rp); }; diff --git a/src/cpu/x64/jit_avx512_core_bf16_1x1_convolution.hpp b/src/cpu/x64/jit_avx512_core_bf16_1x1_convolution.hpp index 216d7cbde42..0b1dba2de8a 100644 --- a/src/cpu/x64/jit_avx512_core_bf16_1x1_convolution.hpp +++ b/src/cpu/x64/jit_avx512_core_bf16_1x1_convolution.hpp @@ -289,7 +289,7 @@ struct jit_avx512_core_bf16_1x1_convolution_fwd_t : public primitive_t { if (pd()->jcp_.with_dw_conv) { CHECK(safe_ptr_assign(kernel_dw_, - new dw_conv_kernel_t(*(pd()->jcp_dw_), *pd()->dst_md(0)))); + new dw_conv_kernel_t(*(pd()->jcp_dw_), *pd()->dst_md(0), *pd()->dw_conv_pd_->attr()))); CHECK(kernel_dw_->create_kernel()); } @@ -309,7 +309,7 @@ struct jit_avx512_core_bf16_1x1_convolution_fwd_t : public primitive_t { const dw_wei_data_t *weights_dw, const float *bias_dw, const char *dst, const memory_tracking::grantor_t &scratchpad, const void *post_ops_binary_rhs_arg_vec, - const void *post_ops_binary_rhs_arg_vec_dw) const; + const void *post_ops_binary_rhs_arg_vec_dw, int MB) const; const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } std::unique_ptr kernel_; @@ -336,8 +336,9 @@ struct jit_avx512_core_bf16_1x1_convolution_bwd_data_t : public primitive_t { && set_default_alg_kind(alg_kind::convolution_direct) && expect_data_types(diff_src_type, data_type::bf16, data_type::undef, data_type::bf16, data_type::undef) - && attr()->has_default_values() && !has_zero_dim_memory() - && set_default_formats(); + && !has_zero_dim_memory() + && set_default_formats() + && is_supported_post_ops(); if (!ok) return status::unimplemented; const convolution_desc_t *conv_d = desc(); @@ -373,6 +374,23 @@ struct jit_avx512_core_bf16_1x1_convolution_bwd_data_t : public primitive_t { return set_default_formats_common(dat_tag, wei_tag, dat_tag); } + + bool is_supported_post_ops() const { + const auto &p = 
this->attr()->post_ops_; + if (p.len() > 1) + return false; + + auto all_post_ops_supported = [&]() { + bool ok = true; + + for (int i = 0; i < p.len(); i++) { + ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::depthwise); + } + return ok; + }; + + return all_post_ops_supported(); + } }; template @@ -403,7 +421,8 @@ struct jit_avx512_core_bf16_1x1_convolution_bwd_data_t : public primitive_t { void execute_backward_data(const exec_ctx_t &ctx) const; void execute_backward_data_thr(const int, const int, const diff_dst_data_t *, const wei_data_t *, diff_src_data_t *, - const memory_tracking::grantor_t &scratchpad) const; + const memory_tracking::grantor_t &scratchpad, int MB, + const void *post_ops_binary_rhs_arg_vec) const; const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } std::unique_ptr kernel_; diff --git a/src/cpu/x64/jit_avx512_core_bf16_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_bf16_conv_kernel.cpp index 59de422225e..59d181f4bbc 100644 --- a/src/cpu/x64/jit_avx512_core_bf16_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_core_bf16_conv_kernel.cpp @@ -90,7 +90,7 @@ inline bool is_1stconv(const jit_conv_conf_t &jcp) { * nstl::max(jcp.typesize_in, jcp.typesize_out) * jcp.id * jcp.ih * jcp.iw < INT_MAX; - return jcp.ic < 16 && jcp.ngroups == 1 && no_big_offt; + return one_of(jcp.ic, 1, 2, 3) && jcp.ngroups == 1 && no_big_offt; } } // namespace @@ -101,7 +101,7 @@ _jit_avx512_core_bf16_fwd_kernel::_jit_avx512_core_bf16_fwd_kernel( : jit_generator(nullptr, ker_code_size, true, avx512_core_bf16) , jcp(ajcp) , attr_(attr) { - if (jcp.with_eltwise || jcp.with_binary) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { using namespace binary_injector; static constexpr bool preserve_gpr = true; static constexpr bool preserve_vmm = false; @@ -119,10 +119,12 @@ _jit_avx512_core_bf16_fwd_kernel::_jit_avx512_core_bf16_fwd_kernel( use_exact_tail_scalar_bcast}; const static_params_t static_params { this->param1, rhs_arg_static_params}; + quantization_injector::static_params_t quantization_static_params + {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; postops_injector_ = utils::make_unique< - injector::jit_uni_postops_injector_t>( - this, jcp.post_ops, static_params); + injector::jit_uni_postops_injector_t>( + this, jcp.post_ops, static_params, quantization_static_params); } if (!isa_has_bf16(jcp.isa)) bf16_emu_ = utils::make_unique(this, @@ -169,7 +171,18 @@ static void iterate(const int nb_oc_block, const int ur_w, const F &f) { template void _jit_avx512_core_bf16_fwd_kernel::apply_postops(int ur_w) { - if (jcp.with_eltwise || jcp.with_binary) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { + std::map vmm_idx_off; + iterate(jcp.nb_oc_blocking, ur_w, + [&](const bool, const int k, const int j) { + vmm_idx_off.insert({vmm_dst_idx(j, k), k * jcp.oc_block * sizeof(float)}); + }); + depthwise_injector::dynamic_params_t ddp {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off, + this->rsp, base_post_ops_data_offset}; + quantization_injector::dynamic_params_t qdp {ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off, jcp.dst_dt, + this->rsp, base_post_ops_data_offset}; + injector_utils::vmm_index_set_t vmm_idxs; if (jcp.with_binary) { binary_injector::rhs_arg_dynamic_params_t rhs_arg_params, @@ -206,7 +219,7 @@ void _jit_avx512_core_bf16_fwd_kernel::apply_postops(int ur_w) { jmp(postops_done, 
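The pd gate completed above (is_supported_post_ops) only admits attributes whose post-op chain is at most one entry long and consists solely of depthwise entries, which matches what the backward-data kernel changed below can actually execute. A standalone model of the predicate (enum and vector are illustrative stand-ins for post_ops_t):

    #include <vector>

    enum class kind { depthwise, eltwise, quantization };

    bool is_supported_post_ops_model(const std::vector<kind> &p) {
        if (p.size() > 1) return false;     // p.len() > 1 in the hunk
        for (kind k : p)                    // every entry must be depthwise
            if (k != kind::depthwise) return false;
        return true;
    }

    int main() {
        return is_supported_post_ops_model({kind::depthwise})
                        && !is_supported_post_ops_model({kind::eltwise})
                ? 0 : 1;
    }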
T_NEAR); L(postops_no_tail); } - postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params); + postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params, ddp, qdp); L(postops_done); } else { @@ -214,7 +227,7 @@ void _jit_avx512_core_bf16_fwd_kernel::apply_postops(int ur_w) { [&](const bool, const int k, const int j) { vmm_idxs.emplace(vmm_dst_idx(j, k)); }); - postops_injector_->compute_vector_range(vmm_idxs); + postops_injector_->compute_vector_range(vmm_idxs, binary_injector::rhs_arg_dynamic_params_t(), ddp, qdp); } } } @@ -589,7 +602,14 @@ void _jit_avx512_core_bf16_fwd_kernel::generate() { = get_src_offset(0, filter_w_to_src(0, 0, l_pad)); preamble(); - if (jcp.ndims == 5) sub(rsp, stack_space_needed_); + + if (postops_injector_) + postops_injector_->push_post_ops_data_on_stack(this->param1, GET_OFF(post_ops_binary_rhs_arg_vec), reg_src, reg_dst); + + if (jcp.ndims == 5) { + sub(rsp, stack_space_needed_); + base_post_ops_data_offset += stack_space_needed_; + } if (jcp.is_1stconv || jcp.ic_tail) { Xbyak::Reg64 reg_alt_mask = r8; @@ -801,7 +821,14 @@ void _jit_avx512_core_bf16_fwd_kernel::generate() { L(end_label); } - if (jcp.ndims == 5) add(rsp, stack_space_needed_); + if (jcp.ndims == 5) { + add(rsp, stack_space_needed_); + base_post_ops_data_offset -= stack_space_needed_; + } + + if (postops_injector_) + postops_injector_->reset_stack_pointer(); + postamble(); if (jcp.with_eltwise) postops_injector_->prepare_table(); @@ -881,19 +908,19 @@ status_t jit_avx512_core_bf16_fwd_kernel::init_conf(jit_conv_conf_t &jcp, jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh); jcp.back_pad = calculate_end_padding( jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd); - bool kernel_outside_src = false || ext_kw <= jcp.l_pad - || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad || ext_kh <= jcp.b_pad - || ext_kd <= jcp.f_pad || ext_kd <= jcp.back_pad; - if (kernel_outside_src) return status::unimplemented; +// bool kernel_outside_src = false || ext_kw <= jcp.l_pad +// || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad || ext_kh <= jcp.b_pad +// || ext_kd <= jcp.f_pad || ext_kd <= jcp.back_pad; +// if (kernel_outside_src) return status::unimplemented; const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc); const auto dat_tag_ncx = pick(ndims - 3, ncw, nchw, ncdhw); const auto dat_tag_nCx4c = pick(ndims - 3, nCw4c, nChw4c, nCdhw4c); const auto dat_tag_nCx8c = pick(ndims - 3, nCw8c, nChw8c, nCdhw8c); const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); - auto curr_src_tag = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c, + auto curr_src_tag = src_d.mb_stride_relaxed_match(dat_tag_nxc, dat_tag_nCx16c, dat_tag_nCx8c, dat_tag_nCx4c, dat_tag_ncx); - auto curr_dst_tag = dst_d.matches_one_of_tag( + auto curr_dst_tag = dst_d.mb_stride_relaxed_match( dat_tag_nxc, dat_tag_nCx16c, dat_tag_nCx8c, dat_tag_nCx4c); bool is_data_layout_nxc = utils::everyone_is(dat_tag_nxc, curr_src_tag, curr_dst_tag); @@ -996,6 +1023,8 @@ status_t jit_avx512_core_bf16_fwd_kernel::init_conf(jit_conv_conf_t &jcp, } const int binary_ind = post_ops.find(primitive_kind::binary); jcp.with_binary = binary_ind != -1; + jcp.with_depthwise = post_ops.find(primitive_kind::depthwise) != -1; + jcp.with_quantization = post_ops.find(primitive_kind::quantization) != -1; jcp.ic_tail = is_data_layout_nxc ? 
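The init_conf() hunk above swaps matches_one_of_tag() for mb_stride_relaxed_match(), a fork-specific helper, and also disables the kernel_outside_src early-out, widening the set of padding configurations that reach the kernel. A plausible reading of the relaxed match, consistent with the runtime-MB changes elsewhere in this patch, is a format check that ignores the outermost (batch) stride so a batch-strided view still qualifies; a toy version under that assumption:

    #include <array>
    #include <cassert>

    using strides_t = std::array<long, 3>; // {mb, c, w}, innermost last

    // Strict format match except for dimension 0, the batch stride.
    bool relaxed_match(const strides_t &a, const strides_t &b) {
        for (std::size_t d = 1; d < a.size(); d++)
            if (a[d] != b[d]) return false;
        return true;
    }

    int main() {
        assert(relaxed_match({4096, 16, 1}, {8192, 16, 1})); // mb stride differs
        assert(!relaxed_match({4096, 16, 1}, {4096, 1, 16})); // layout differs
        return 0;
    }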
jcp.ic % jcp.simd_w : 0; if (is_data_layout_nxc) @@ -1012,7 +1041,7 @@ status_t jit_avx512_core_bf16_fwd_kernel::init_conf(jit_conv_conf_t &jcp, static constexpr bool sum_at_pos_0_only = true; static constexpr bool sum_requires_scale_one = true; static constexpr bool sum_requires_zp_zero = true; - const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum}, + const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum, depthwise, quantization}, jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one, sum_requires_zp_zero}); if (!post_ops_ok_) return status::unimplemented; @@ -1097,6 +1126,27 @@ void _jit_avx512_core_bf16_bwd_data_kernel::store_output(int ur_w) { if (!isa_has_bf16(jcp.isa)) bf16_emu_->init_vcvtneps2bf16(); const int ic_tail = jcp.ic_tail; + int depthwise_inj_idx = 0; + std::size_t post_ops_data_offset = 0; + const auto& p = attr_.post_ops_; + for (int i = 0; i < p.len(); i++) { + auto& post_op = p.entry_[i]; + if (post_op.is_depthwise()) { + mov(reg_d_weights, ptr[this->rsp + post_ops_data_offset]); + add(reg_d_weights, ptr[this->param1 + GET_OFF(oc_off)]); + + for (int k = 0; k < jcp.nb_ic_blocking; k++) { + depthwise_injectors[depthwise_inj_idx]->compute_vector_range( + k * jcp.ur_w, k * jcp.ur_w + ur_w, reg_d_weights, reg_d_weights); + + add(reg_d_weights, jcp.ic_block * sizeof(float)); + } + + post_ops_data_offset += depthwise_injectors[depthwise_inj_idx]->memoryStep(); + depthwise_inj_idx++; + } + } + if (jcp.dst_dt == data_type::f32) { for (int k = 0; k < jcp.nb_ic_blocking; k++) for (int j = 0; j < ur_w; j++) { @@ -1341,6 +1391,17 @@ void _jit_avx512_core_bf16_bwd_data_kernel::compute_loop( template void _jit_avx512_core_bf16_bwd_data_kernel::generate() { + const auto &p = attr_.post_ops_; + for (int i = 0; i < p.len(); i++) { + auto &post_op = p.entry_[i]; + if (post_op.is_depthwise()) { + depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32( + this, + post_op + )); + } + } + int iw = jcp.iw; int kw = jcp.kw; int ur_w = jcp.ur_w; @@ -1355,6 +1416,26 @@ void _jit_avx512_core_bf16_bwd_data_kernel::generate() { preamble(); + std::size_t post_ops_pointers_count = 0; + for (int i = 0; i < p.len(); i++) { + if (p.entry_[i].is_depthwise() || p.entry_[i].is_quantization()) { + post_ops_pointers_count++; + } + } + + if (post_ops_pointers_count != 0) { + sub(rsp, post_ops_pointers_count * sizeof(float *)); + + auto aux_reg0 = reg_src; + auto aux_reg1 = reg_dst; + + mov(aux_reg0, ptr[this->param + GET_OFF(post_ops_binary_rhs_arg_vec)]); + for (size_t i = 0; i < post_ops_pointers_count; i++) { + mov(aux_reg1, ptr[aux_reg0 + i * sizeof(float *)]); + mov(ptr[rsp + i * sizeof(float *)], aux_reg1); + } + } + if (jcp.simd_w == 4) { Reg32 reg_tail_32 = reg_oc.cvt32(); mov(reg_tail_32, (1 << jcp.simd_w) - 1); @@ -1524,12 +1605,33 @@ void _jit_avx512_core_bf16_bwd_data_kernel::generate() { if (ur_w_tail != 0) { compute_loop(ur_w_tail, 0, r_overflow); } L(end_label); + if (post_ops_pointers_count != 0) { + add(rsp, post_ops_pointers_count * sizeof(float *)); + } + postamble(); } +bool jit_avx512_core_bf16_bwd_data_kernel::post_ops_ok( + jit_conv_conf_t& jcp, const primitive_attr_t& attr) { + const auto& p = attr.post_ops_; + + auto all_post_ops_supported = [&]() { + bool ok = true; + + for (int i = 0; i < p.len(); i++) { + ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::depthwise); + } + return ok; + }; + + return all_post_ops_supported(); +} + status_t jit_avx512_core_bf16_bwd_data_kernel::init_conf(jit_conv_conf_t &jcp, const 
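The backward-data generate() above open-codes what push_post_ops_data_on_stack() does for the forward kernels: it counts the depthwise/quantization entries, reserves that many stack slots, and copies the pointers from post_ops_binary_rhs_arg_vec into them, so store_output() can reach each per-channel table with a single rsp-relative load. The data movement, modeled in plain C++ (the vector stands in for the reserved rsp slots, the comments for the emitted movs):

    #include <cassert>
    #include <cstddef>
    #include <vector>

    int main() {
        float w0[16] = {0}, w1[16] = {0};      // per-channel depthwise tables
        const void *rhs_arg_vec[] = {w0, w1};  // built by prepare_binary_args()
        const std::size_t count = 2;           // post_ops_pointers_count

        // sub(rsp, count * sizeof(float *))
        std::vector<const void *> stack(count);
        for (std::size_t i = 0; i < count; i++)
            // mov(aux, ptr[vec + i * 8]); mov(ptr[rsp + i * 8], aux);
            stack[i] = rhs_arg_vec[i];

        assert(stack[0] == w0 && stack[1] == w1);
        return 0; // add(rsp, count * sizeof(float *)) before postamble()
    }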
convolution_desc_t &cd, memory_desc_t &diff_src_md, - memory_desc_t &weights_md, memory_desc_t &diff_dst_md, int nthreads) { + memory_desc_t &weights_md, memory_desc_t &diff_dst_md, + const primitive_attr_t& attr, int nthreads) { const memory_desc_wrapper diff_src_d(&diff_src_md); const memory_desc_wrapper weights_d(&weights_md); @@ -1604,9 +1706,9 @@ status_t jit_avx512_core_bf16_bwd_data_kernel::init_conf(jit_conv_conf_t &jcp, const auto dat_tag_nCx4c = pick(ndims - 3, nCw4c, nChw4c, nCdhw4c); const auto dat_tag_nCx8c = pick(ndims - 3, nCw8c, nChw8c, nCdhw8c); const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); - auto curr_src_tag = diff_src_d.matches_one_of_tag( + auto curr_src_tag = diff_src_d.mb_stride_relaxed_match( dat_tag_nxc, dat_tag_nCx16c, dat_tag_nCx8c, dat_tag_nCx4c); - auto curr_dst_tag = diff_dst_d.matches_one_of_tag( + auto curr_dst_tag = diff_dst_d.mb_stride_relaxed_match( dat_tag_nxc, dat_tag_nCx16c, dat_tag_nCx8c, dat_tag_nCx4c); bool is_data_layout_nxc = utils::everyone_is(dat_tag_nxc, curr_src_tag, curr_dst_tag); @@ -1685,6 +1787,10 @@ status_t jit_avx512_core_bf16_bwd_data_kernel::init_conf(jit_conv_conf_t &jcp, && jcp.oc <= weights_d.padded_dims()[with_groups + 0]; if (!args_ok) return status::unimplemented; + if (!post_ops_ok(jcp, attr)) return status::unimplemented; + + jcp.post_ops = attr.post_ops_; + jcp.nb_ic = utils::div_up(jcp.ic, jcp.ic_block); jcp.nb_oc = utils::div_up(jcp.oc, jcp.oc_block); diff --git a/src/cpu/x64/jit_avx512_core_bf16_conv_kernel.hpp b/src/cpu/x64/jit_avx512_core_bf16_conv_kernel.hpp index 82a109af720..67a04bca0a7 100644 --- a/src/cpu/x64/jit_avx512_core_bf16_conv_kernel.hpp +++ b/src/cpu/x64/jit_avx512_core_bf16_conv_kernel.hpp @@ -127,10 +127,17 @@ struct _jit_avx512_core_bf16_fwd_kernel : public jit_generator { constexpr static int off_reg_ker_ = 8; constexpr static int stack_space_needed_ = 16; - std::unique_ptr> + std::unique_ptr> postops_injector_; std::unique_ptr bf16_emu_; + reg64_t reg_d_weights = r15; + reg64_t reg_d_bias = reg_kj; + int base_post_ops_data_offset = 0; + + Xbyak::Zmm zmm_d_weights = Xbyak::Zmm(31); + Xbyak::Zmm zmm_d_bias = Xbyak::Zmm(30); + inline void prepare_dst(int ur_w); void apply_postops(int ur_w); inline void store_dst(int ur_w); @@ -252,9 +259,10 @@ struct jit_avx512_core_bf16_fwd_kernel { template struct _jit_avx512_core_bf16_bwd_data_kernel : public jit_generator { - _jit_avx512_core_bf16_bwd_data_kernel(const jit_conv_conf_t &ajcp) + _jit_avx512_core_bf16_bwd_data_kernel(const jit_conv_conf_t &ajcp, const primitive_attr_t& attr) : jit_generator(nullptr, ker_code_size, true, avx512_core_bf16) , jcp(ajcp) + , attr_(attr) , bf16_emu_(nullptr) { if (!isa_has_bf16(jcp.isa)) bf16_emu_ = utils::make_unique(this, @@ -262,9 +270,16 @@ struct _jit_avx512_core_bf16_bwd_data_kernel : public jit_generator { bf16_emu_scratch, bf16_emu_reserv_4, bf16_emu_reserv_5); } + ~_jit_avx512_core_bf16_bwd_data_kernel() { + for (auto inj : depthwise_injectors) + delete inj; + depthwise_injectors.clear(); + } + DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_core_bf16_bwd_data_kernel_f32) const jit_conv_conf_t &jcp; + const primitive_attr_t& attr_; private: using Vmm_down_t = @@ -343,6 +358,11 @@ struct _jit_avx512_core_bf16_bwd_data_kernel : public jit_generator { Vmm vmm_wei = Vmm(31); std::unique_ptr bf16_emu_; + reg64_t reg_d_weights = r15; + reg64_t reg_d_bias = reg_kj; + + nstl::vector*> depthwise_injectors; + inline void prepare_output(int ur_w); inline void store_output(int ur_w); inline void 
compute_loop(int ur_w, int l_overflow, int r_overflow); @@ -422,20 +442,20 @@ struct _jit_avx512_core_bf16_bwd_data_kernel : public jit_generator { struct jit_avx512_core_bf16_bwd_data_kernel { - jit_avx512_core_bf16_bwd_data_kernel(const jit_conv_conf_t &ajcp) + jit_avx512_core_bf16_bwd_data_kernel(const jit_conv_conf_t &ajcp, const primitive_attr_t& attr) : kernel_(nullptr) { switch (ajcp.ic_block) { case 16: kernel_ = new _jit_avx512_core_bf16_bwd_data_kernel( - ajcp); + ajcp, attr); return; case 8: kernel_ = new _jit_avx512_core_bf16_bwd_data_kernel( - ajcp); + ajcp, attr); return; case 4: kernel_ = new _jit_avx512_core_bf16_bwd_data_kernel( - ajcp); + ajcp, attr); return; default: assert(!"invalid channel blocking"); } @@ -445,10 +465,12 @@ struct jit_avx512_core_bf16_bwd_data_kernel { ~jit_avx512_core_bf16_bwd_data_kernel() { delete kernel_; } + static bool post_ops_ok(jit_conv_conf_t& jcp, const primitive_attr_t& attr); + static status_t init_conf(jit_conv_conf_t &jcp, const convolution_desc_t &cd, memory_desc_t &diff_src_md, memory_desc_t &weights_md, memory_desc_t &diff_dst_md, - int nthreads); + const primitive_attr_t& attr, int nthreads); void operator()(const jit_conv_call_s *p) const { (*kernel_)(p); } const Xbyak::uint8 *jit_ker() const { return kernel_->jit_ker(); } diff --git a/src/cpu/x64/jit_avx512_core_bf16_convolution.cpp b/src/cpu/x64/jit_avx512_core_bf16_convolution.cpp index d5fa07110cc..5d3dffdde45 100644 --- a/src/cpu/x64/jit_avx512_core_bf16_convolution.cpp +++ b/src/cpu/x64/jit_avx512_core_bf16_convolution.cpp @@ -61,6 +61,8 @@ void jit_avx512_core_bf16_convolution_fwd_t::execute_forward_1d( const auto post_ops_binary_rhs_arg_vec = binary_injector::prepare_binary_args(jcp.post_ops, ctx); + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + prepare_padded_bias(bias, ctx.get_scratchpad_grantor()); const size_t bia_dt_size = pd()->jcp_.typesize_bia; @@ -75,7 +77,7 @@ void jit_avx512_core_bf16_convolution_fwd_t::execute_forward_1d( // TODO: experiment with g_blocking for perf fine tuning int g_blocking = 1; int nb_groups = jcp.ngroups / g_blocking; - dim_t work_amount = jcp.mb * nb_groups * oc_chunks * jcp.nb_ow; + dim_t work_amount = MB * nb_groups * oc_chunks * jcp.nb_ow; int nthr = jcp.aligned_threads ? jcp.aligned_threads : jcp.nthr; parallel(nthr, [&](const int ithr, const int nthr) { @@ -88,13 +90,13 @@ void jit_avx512_core_bf16_convolution_fwd_t::execute_forward_1d( if (jcp.loop_order == loop_cwgn) { int dummy {0}; nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg, - nb_groups, n, jcp.mb, dummy, 1); + nb_groups, n, MB, dummy, 1); } else if (jcp.loop_order == loop_gncw) { int dummy {0}; - nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ, oc_chunks, + nd_iterator_init(start, gg, nb_groups, n, MB, occ, oc_chunks, owb, jcp.nb_ow, dummy, 1); } else if (jcp.loop_order == loop_nhwcg) { - nd_iterator_init(start, n, jcp.mb, owb, jcp.nb_ow, occ, oc_chunks, + nd_iterator_init(start, n, MB, owb, jcp.nb_ow, occ, oc_chunks, gg, nb_groups); } else assert(!"unsupported loop order"); @@ -132,19 +134,21 @@ void jit_avx512_core_bf16_convolution_fwd_t::execute_forward_1d( par_conv.dst_orig = dst; par_conv.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); + par_conv.oc_off = oc_idx * (is_dst_layout_nxc ? 
1 : jcp.oc_block) * sizeof(float); + (*kernel_)(&par_conv); if (jcp.loop_order == loop_cwgn) { int dummy {0}; nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, gg, - nb_groups, n, jcp.mb, dummy, 1); + nb_groups, n, MB, dummy, 1); } else if (jcp.loop_order == loop_gncw) { int dummy {0}; - nd_iterator_jump(start, end, gg, nb_groups, n, jcp.mb, occ, + nd_iterator_jump(start, end, gg, nb_groups, n, MB, occ, oc_chunks, owb, jcp.nb_ow, dummy, 1); } else if (jcp.loop_order == loop_nhwcg) { ++start; - nd_iterator_step(n, jcp.mb, owb, jcp.nb_ow, occ, oc_chunks, gg, + nd_iterator_step(n, MB, owb, jcp.nb_ow, occ, oc_chunks, gg, nb_groups); } else assert(!"unsupported loop order"); @@ -162,6 +166,8 @@ void jit_avx512_core_bf16_convolution_fwd_t::execute_forward_2d( const auto post_ops_binary_rhs_arg_vec = binary_injector::prepare_binary_args(jcp.post_ops, ctx); + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + prepare_padded_bias(bias, ctx.get_scratchpad_grantor()); const size_t bia_dt_size = pd()->jcp_.typesize_bia; @@ -176,7 +182,7 @@ void jit_avx512_core_bf16_convolution_fwd_t::execute_forward_2d( // TODO: experiment with g_blocking for perf fine tuning int g_blocking = 1; int nb_groups = jcp.ngroups / g_blocking; - dim_t work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh * jcp.nb_ow; + dim_t work_amount = MB * nb_groups * oc_chunks * jcp.oh * jcp.nb_ow; int nthr = jcp.aligned_threads ? jcp.aligned_threads : jcp.nthr; parallel(nthr, [&](const int ithr, const int nthr) { @@ -184,20 +190,20 @@ void jit_avx512_core_bf16_convolution_fwd_t::execute_forward_2d( balance211(work_amount, nthr, ithr, start, end); auto par_conv = jit_conv_call_s(); - size_t src_h_stride = src_d.blk_off(0, 0, 1); - size_t dst_h_stride = dst_d.blk_off(0, 0, 1); + size_t src_h_stride = src_d.blk_off(0, 0, 1) - src_d.off_l(0); + size_t dst_h_stride = dst_d.blk_off(0, 0, 1) - dst_d.off_l(0); size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1); int n {0}, gg {0}, occ {0}, oh_s {0}, owb {0}; if (jcp.loop_order == loop_cwgn) nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg, - nb_groups, n, jcp.mb, oh_s, jcp.oh); + nb_groups, n, MB, oh_s, jcp.oh); else if (jcp.loop_order == loop_gncw) - nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ, oc_chunks, + nd_iterator_init(start, gg, nb_groups, n, MB, occ, oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh); else if (jcp.loop_order == loop_nhwcg) - nd_iterator_init(start, n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow, + nd_iterator_init(start, n, MB, oh_s, jcp.oh, owb, jcp.nb_ow, occ, oc_chunks, gg, nb_groups); else assert(!"unsupported loop order"); @@ -256,6 +262,8 @@ void jit_avx512_core_bf16_convolution_fwd_t::execute_forward_2d( par_conv.dst_orig = dst; par_conv.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); + par_conv.oc_off = oc_idx * (is_dst_layout_nxc ? 
1 : jcp.oc_block) * sizeof(float); + (*kernel_)(&par_conv); src_w += src_h_stride * jcp.stride_h; @@ -263,13 +271,13 @@ void jit_avx512_core_bf16_convolution_fwd_t::execute_forward_2d( } if (jcp.loop_order == loop_cwgn) nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, gg, - nb_groups, n, jcp.mb, oh_s, jcp.oh); + nb_groups, n, MB, oh_s, jcp.oh); else if (jcp.loop_order == loop_gncw) - nd_iterator_jump(start, end, gg, nb_groups, n, jcp.mb, occ, + nd_iterator_jump(start, end, gg, nb_groups, n, MB, occ, oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh); else if (jcp.loop_order == loop_nhwcg) { ++start; - nd_iterator_step(n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow, occ, + nd_iterator_step(n, MB, oh_s, jcp.oh, owb, jcp.nb_ow, occ, oc_chunks, gg, nb_groups); } else assert(!"unsupported loop order"); @@ -287,6 +295,8 @@ void jit_avx512_core_bf16_convolution_fwd_t::execute_forward_3d( const auto post_ops_binary_rhs_arg_vec = binary_injector::prepare_binary_args(jcp.post_ops, ctx); + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + prepare_padded_bias(bias, ctx.get_scratchpad_grantor()); const size_t bia_dt_size = pd()->jcp_.typesize_bia; @@ -302,7 +312,7 @@ void jit_avx512_core_bf16_convolution_fwd_t::execute_forward_3d( int g_blocking = 1; int nb_groups = jcp.ngroups / g_blocking; dim_t work_amount - = jcp.mb * nb_groups * oc_chunks * jcp.od * jcp.oh * jcp.nb_ow; + = MB * nb_groups * oc_chunks * jcp.od * jcp.oh * jcp.nb_ow; int nthr = jcp.aligned_threads ? jcp.aligned_threads : jcp.nthr; parallel(nthr, [&](const int ithr, const int nthr) { @@ -310,9 +320,9 @@ void jit_avx512_core_bf16_convolution_fwd_t::execute_forward_3d( balance211(work_amount, nthr, ithr, start, end); auto par_conv = jit_conv_call_s(); - size_t src_d_stride = src_d.blk_off(0, 0, 1); - size_t src_h_stride = src_d.blk_off(0, 0, 0, 1); - size_t dst_h_stride = dst_d.blk_off(0, 0, 0, 1); + size_t src_d_stride = src_d.blk_off(0, 0, 1) - src_d.off_l(0); + size_t src_h_stride = src_d.blk_off(0, 0, 0, 1) - src_d.off_l(0); + size_t dst_h_stride = dst_d.blk_off(0, 0, 0, 1) - dst_d.off_l(0); size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1); size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1); @@ -320,12 +330,12 @@ void jit_avx512_core_bf16_convolution_fwd_t::execute_forward_3d( if (jcp.loop_order == loop_cwgn) nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg, - nb_groups, n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh); + nb_groups, n, MB, od_s, jcp.od, oh_s, jcp.oh); else if (jcp.loop_order == loop_gncw) - nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ, oc_chunks, + nd_iterator_init(start, gg, nb_groups, n, MB, occ, oc_chunks, owb, jcp.nb_ow, od_s, jcp.od, oh_s, jcp.oh); else if (jcp.loop_order == loop_nhwcg) - nd_iterator_init(start, n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh, owb, + nd_iterator_init(start, n, MB, od_s, jcp.od, oh_s, jcp.oh, owb, jcp.nb_ow, occ, oc_chunks, gg, nb_groups); else assert(!"unsupported loop order"); @@ -395,6 +405,8 @@ void jit_avx512_core_bf16_convolution_fwd_t::execute_forward_3d( par_conv.dst_orig = dst; par_conv.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); + par_conv.oc_off = oc_idx * (is_dst_layout_nxc ? 
1 : jcp.oc_block) * sizeof(float); + (*kernel_)(&par_conv); src_w += src_h_stride * jcp.stride_h; @@ -402,13 +414,13 @@ void jit_avx512_core_bf16_convolution_fwd_t::execute_forward_3d( } if (jcp.loop_order == loop_cwgn) nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, gg, - nb_groups, n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh); + nb_groups, n, MB, od_s, jcp.od, oh_s, jcp.oh); else if (jcp.loop_order == loop_gncw) - nd_iterator_jump(start, end, gg, nb_groups, n, jcp.mb, occ, + nd_iterator_jump(start, end, gg, nb_groups, n, MB, occ, oc_chunks, owb, jcp.nb_ow, od_s, jcp.od, oh_s, jcp.oh); else if (jcp.loop_order == loop_nhwcg) { ++start; - nd_iterator_step(n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh, owb, + nd_iterator_step(n, MB, od_s, jcp.od, oh_s, jcp.oh, owb, jcp.nb_ow, occ, oc_chunks, gg, nb_groups); } else assert(!"unsupported loop order"); @@ -422,19 +434,23 @@ void jit_avx512_core_bf16_convolution_bwd_data_t ::execute_backward_data_3d( auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS); auto diff_src = CTX_OUT_MEM(char *, DNNL_ARG_DIFF_SRC); + const auto &jcp = pd()->jcp_; + const auto post_ops_binary_rhs_arg_vec + = binary_injector::prepare_binary_args(jcp.post_ops, ctx); + + auto MB = CTX_IN_BATCH(DNNL_ARG_DIFF_DST); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); const memory_desc_wrapper weights_d(pd()->weights_md(0)); - const auto &jcp = pd()->jcp_; - parallel(jcp.nthr, [&](const int ithr, const int nthr) { int start {0}, end {0}; int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking; // TODO: experiment with g_blocking for perf fine tuning int g_blocking = 1; int nb_groups = jcp.ngroups / g_blocking; - int work_amount = nb_groups * jcp.mb * ic_chunks * jcp.id * jcp.ih; + int work_amount = nb_groups * MB * ic_chunks * jcp.id * jcp.ih; balance211(work_amount, nthr, ithr, start, end); auto par_conv = jit_conv_call_s(); @@ -448,13 +464,13 @@ void jit_avx512_core_bf16_convolution_bwd_data_t ::execute_backward_data_3d( int n {0}, gg {0}, icc {0}, id_s {0}, ih_s {0}; if (jcp.loop_order == loop_cgn) - nd_iterator_init(start, icc, ic_chunks, gg, nb_groups, n, jcp.mb, + nd_iterator_init(start, icc, ic_chunks, gg, nb_groups, n, MB, id_s, jcp.id, ih_s, jcp.ih); else if (jcp.loop_order == loop_gnc) - nd_iterator_init(start, gg, nb_groups, n, jcp.mb, icc, ic_chunks, + nd_iterator_init(start, gg, nb_groups, n, MB, icc, ic_chunks, id_s, jcp.id, ih_s, jcp.ih); else if (jcp.loop_order == loop_nhwcg) - nd_iterator_init(start, n, jcp.mb, id_s, jcp.id, ih_s, jcp.ih, icc, + nd_iterator_init(start, n, MB, id_s, jcp.id, ih_s, jcp.ih, icc, ic_chunks, gg, nb_groups); else assert(!"unsupported loop order"); @@ -562,19 +578,22 @@ void jit_avx512_core_bf16_convolution_bwd_data_t ::execute_backward_data_3d( par_conv.filt = wht_w + kh_lo * wht_h_stride; par_conv.kh_padding = kh_len; par_conv.kd_padding = kd_len; + par_conv.oc_off = ic_idx * (is_dsrc_layout_nxc ? 
1 : jcp.ic_block) * sizeof(float); + par_conv.post_ops_binary_rhs_arg_vec + = post_ops_binary_rhs_arg_vec.data(); (*kernel_)(&par_conv); } if (jcp.loop_order == loop_cgn) nd_iterator_jump(start, end, icc, ic_chunks, gg, nb_groups, n, - jcp.mb, id_s, jcp.id, ih_s, jcp.ih); + MB, id_s, jcp.id, ih_s, jcp.ih); else if (jcp.loop_order == loop_gnc) - nd_iterator_jump(start, end, gg, nb_groups, n, jcp.mb, icc, + nd_iterator_jump(start, end, gg, nb_groups, n, MB, icc, ic_chunks, id_s, jcp.id, ih_s, jcp.ih); else if (jcp.loop_order == loop_nhwcg) { ++start; - nd_iterator_step(n, jcp.mb, id_s, jcp.id, ih_s, jcp.ih, icc, + nd_iterator_step(n, MB, id_s, jcp.id, ih_s, jcp.ih, icc, ic_chunks, gg, nb_groups); } else assert(!"unsupported loop order"); @@ -588,24 +607,28 @@ void jit_avx512_core_bf16_convolution_bwd_data_t ::execute_backward_data( auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS); auto diff_src = CTX_OUT_MEM(char *, DNNL_ARG_DIFF_SRC); + const auto &jcp = pd()->jcp_; + const auto post_ops_binary_rhs_arg_vec + = binary_injector::prepare_binary_args(jcp.post_ops, ctx); + + auto MB = CTX_IN_BATCH(DNNL_ARG_DIFF_DST); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); const memory_desc_wrapper weights_d(pd()->weights_md(0)); - const auto &jcp = pd()->jcp_; - parallel(jcp.nthr, [&](const int ithr, const int nthr) { int start {0}, end {0}; int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking; // TODO: experiment with g_blocking for perf fine tuning int g_blocking = 1; int nb_groups = jcp.ngroups / g_blocking; - int work_amount = nb_groups * jcp.mb * ic_chunks * jcp.ih * jcp.nb_iw; + int work_amount = nb_groups * MB * ic_chunks * jcp.ih * jcp.nb_iw; balance211(work_amount, nthr, ithr, start, end); auto par_conv = jit_conv_call_s(); - size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 1); - size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 1); + size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 1) - diff_src_d.off_l(0); + size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 1) - diff_dst_d.off_l(0); size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1); bool is_fast_path = jcp.dilate_h == 0 && jcp.stride_h == 1; @@ -613,12 +636,12 @@ void jit_avx512_core_bf16_convolution_bwd_data_t ::execute_backward_data( int n {0}, gg {0}, icc {0}, ih_s {0}, iwb {0}; if (jcp.loop_order == loop_cwgn) nd_iterator_init(start, icc, ic_chunks, iwb, jcp.nb_iw, gg, - nb_groups, n, jcp.mb, ih_s, jcp.ih); + nb_groups, n, MB, ih_s, jcp.ih); else if (jcp.loop_order == loop_gncw) - nd_iterator_init(start, gg, nb_groups, n, jcp.mb, icc, ic_chunks, + nd_iterator_init(start, gg, nb_groups, n, MB, icc, ic_chunks, iwb, jcp.nb_iw, ih_s, jcp.ih); else if (jcp.loop_order == loop_nhwcg) - nd_iterator_init(start, n, jcp.mb, ih_s, jcp.ih, iwb, jcp.nb_iw, + nd_iterator_init(start, n, MB, ih_s, jcp.ih, iwb, jcp.nb_iw, icc, ic_chunks, gg, nb_groups); else assert(!"unsupported loop order"); @@ -699,19 +722,22 @@ void jit_avx512_core_bf16_convolution_bwd_data_t ::execute_backward_data( par_conv.filt = wht_w + k_lo * wht_h_stride; par_conv.kh_padding = k_len; par_conv.iwb = iwb; + par_conv.oc_off = ic_idx * (is_dsrc_layout_nxc ? 
1 : jcp.ic_block) * sizeof(float); + par_conv.post_ops_binary_rhs_arg_vec + = post_ops_binary_rhs_arg_vec.data(); (*kernel_)(&par_conv); } if (jcp.loop_order == loop_cwgn) nd_iterator_jump(start, end, icc, ic_chunks, iwb, jcp.nb_iw, gg, - nb_groups, n, jcp.mb, ih_s, jcp.ih); + nb_groups, n, MB, ih_s, jcp.ih); else if (jcp.loop_order == loop_gncw) - nd_iterator_jump(start, end, gg, nb_groups, n, jcp.mb, icc, + nd_iterator_jump(start, end, gg, nb_groups, n, MB, icc, ic_chunks, iwb, jcp.nb_iw, ih_s, jcp.ih); else if (jcp.loop_order == loop_nhwcg) { ++start; - nd_iterator_step(n, jcp.mb, ih_s, jcp.ih, iwb, jcp.nb_iw, icc, + nd_iterator_step(n, MB, ih_s, jcp.ih, iwb, jcp.nb_iw, icc, ic_chunks, gg, nb_groups); } else assert(!"unsupported loop order"); diff --git a/src/cpu/x64/jit_avx512_core_bf16_convolution.hpp b/src/cpu/x64/jit_avx512_core_bf16_convolution.hpp index 2735333fa3f..b63bfdfcb31 100644 --- a/src/cpu/x64/jit_avx512_core_bf16_convolution.hpp +++ b/src/cpu/x64/jit_avx512_core_bf16_convolution.hpp @@ -130,11 +130,12 @@ struct jit_avx512_core_bf16_convolution_bwd_data_t : public primitive_t { || expect_data_types(data_type::bf16, data_type::bf16, data_type::undef, data_type::bf16, data_type::undef)) - && attr()->has_default_values() && !has_zero_dim_memory(); + && attr()->has_default_values(primitive_attr_t::skip_mask_t::post_ops) + && !has_zero_dim_memory(); if (!ok) return status::unimplemented; status_t status = jit_avx512_core_bf16_bwd_data_kernel::init_conf( - jcp_, *desc(), diff_src_md_, weights_md_, diff_dst_md_, + jcp_, *desc(), diff_src_md_, weights_md_, diff_dst_md_, *attr(), dnnl_get_max_threads()); return status; } @@ -150,7 +151,8 @@ struct jit_avx512_core_bf16_convolution_bwd_data_t : public primitive_t { status_t init(engine_t *engine) override { CHECK(safe_ptr_assign( - kernel_, new jit_avx512_core_bf16_bwd_data_kernel(pd()->jcp_))); + kernel_, new jit_avx512_core_bf16_bwd_data_kernel( + pd()->jcp_, *pd()->attr()))); return kernel_->create_kernel(); } diff --git a/src/cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.cpp index a5880b6f967..bcc1f7be66e 100644 --- a/src/cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.cpp @@ -34,9 +34,9 @@ using namespace Xbyak; using namespace dnnl::impl::utils; jit_avx512_dw_conv_fwd_kernel_bf16::jit_avx512_dw_conv_fwd_kernel_bf16( - const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md) - : jcp(ajcp) { - if (jcp.with_eltwise || jcp.with_binary) { + const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md, const primitive_attr_t& attr) + : jcp(ajcp), attr_(attr) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { using namespace binary_injector; static constexpr bool preserve_gpr = true; static constexpr bool preserve_vmm = false; @@ -52,10 +52,12 @@ jit_avx512_dw_conv_fwd_kernel_bf16::jit_avx512_dw_conv_fwd_kernel_bf16( use_exact_tail_scalar_bcast}; const static_params_t static_params { this->param1, rhs_arg_static_params}; + quantization_injector::static_params_t quantization_static_params + {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; postops_injector_ = utils::make_unique< injector::jit_uni_postops_injector_t>( - this, jcp.post_ops, static_params); + this, jcp.post_ops, static_params, quantization_static_params); } if (!isa_has_bf16(jcp.isa)) bf16_emu_ = utils::make_unique(this, @@ -202,7 +204,17 @@ static void iterate(const int ur_ch_blocks, const int ur_w, 
const F &f) { void jit_avx512_dw_conv_fwd_kernel_bf16::apply_postops( int ur_ch_blocks, int ur_w, bool last_ch_block_flag) { - if (this->jcp.with_eltwise || this->jcp.with_binary) { + if (this->jcp.with_eltwise || this->jcp.with_binary || this->jcp.with_depthwise || this->jcp.with_quantization) { + std::map vmm_idx_off; + iterate(ur_ch_blocks, ur_w, [&](int ch, int ow, int) { + vmm_idx_off.insert({get_acc_reg_idx(ch * ur_w + ow), ch * jcp.ch_block * sizeof(float)}); + }); + + depthwise_injector::dynamic_params_t ddp {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off, + this->rsp, base_post_ops_data_offset}; + quantization_injector::dynamic_params_t qdp {ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off, jcp.dst_dt, + this->rsp, base_post_ops_data_offset}; injector_utils::vmm_index_set_t vmm_idxs; if (jcp.with_binary) { @@ -244,20 +256,20 @@ void jit_avx512_dw_conv_fwd_kernel_bf16::apply_postops( jmp(postops_done, T_NEAR); L(postops_no_tail); postops_injector_->compute_vector_range( - vmm_idxs, rhs_arg_params); + vmm_idxs, rhs_arg_params, ddp, qdp); } else if (last_ch_block_flag) postops_injector_->compute_vector_range( - vmm_idxs, rhs_arg_params_tail); + vmm_idxs, rhs_arg_params_tail, ddp, qdp); else /* if (!last_ch_block_flag) */ postops_injector_->compute_vector_range( - vmm_idxs, rhs_arg_params); + vmm_idxs, rhs_arg_params, ddp, qdp); L(postops_done); } else { iterate(ur_ch_blocks, ur_w, [&](int ch, int ow, int) { vmm_idxs.emplace(get_acc_reg_idx(ch * ur_w + ow)); }); - postops_injector_->compute_vector_range(vmm_idxs); + postops_injector_->compute_vector_range(vmm_idxs, binary_injector::rhs_arg_dynamic_params_t(), ddp, qdp); } } } @@ -414,7 +426,11 @@ void jit_avx512_dw_conv_fwd_kernel_bf16::compute_loop( push(reg_kernel); push(reg_input); push(reg_output); - if (jcp.with_bias) push(reg_bias); + base_post_ops_data_offset += 3 * reg64_size; + if (jcp.with_bias) { + push(reg_bias); + base_post_ops_data_offset += reg64_size; + } if (nb_ch >= jcp.nb_ch_blocking) { if (nb_ch_blocking_tail) { @@ -442,10 +458,14 @@ void jit_avx512_dw_conv_fwd_kernel_bf16::compute_loop( compute(nb_ch_blocking_tail, masked_ch_block_tail); L(skip_ch_tail_label); } - if (jcp.with_bias) pop(reg_bias); + if (jcp.with_bias) { + pop(reg_bias); + base_post_ops_data_offset -= reg64_size; + } pop(reg_output); pop(reg_input); pop(reg_kernel); + base_post_ops_data_offset -= reg64_size; } else { compute(ur_ch_blocks, masked_ch_block_tail); @@ -524,6 +544,9 @@ void jit_avx512_dw_conv_fwd_kernel_bf16::loop_ow(int ur_ch_blocks) { void jit_avx512_dw_conv_fwd_kernel_bf16::generate() { this->preamble(); + if (postops_injector_) + postops_injector_->push_post_ops_data_on_stack(this->param1, GET_OFF(post_ops_binary_rhs_arg_vec), reg_input, reg_output); + assert(mayiuse(avx512_core)); if (jcp.is_fused_conv) { mov(reg_input_buffer_ptr, ptr[this->param1 + GET_OFF(src)]); @@ -605,6 +628,9 @@ void jit_avx512_dw_conv_fwd_kernel_bf16::generate() { L(exit_label); } + if (postops_injector_) + postops_injector_->reset_stack_pointer(); + postamble(); if (jcp.with_eltwise) postops_injector_->prepare_table(); @@ -701,6 +727,32 @@ inline void jit_avx512_dw_conv_bwd_data_kernel_bf16::apply_filter( L(iter_exit_label); } +void jit_avx512_dw_conv_bwd_data_kernel_bf16::apply_postprocess(int ur_ch_blocks, int ur_str_) { + const auto& p = attr_.post_ops_; + std::size_t post_ops_data_offset = 0; + int depthwise_inj_idx = 0; + for (int i = 0; i < p.len(); i++) { + auto& post_op = 
p.entry_[i]; + if (post_op.is_depthwise()) { + mov(reg_d_weights, ptr[this->rsp + base_post_ops_data_offset + post_ops_data_offset]); + add(reg_d_weights, ptr[this->param1 + GET_OFF(ic_off)]); + + for (int ch = 0; ch < ur_ch_blocks; ch++) { + int start_idx = get_acc_reg(ur_str_ * ch).getIdx(); + int end_idx = get_acc_reg(ur_str_ * ch + ur_str_).getIdx(); + + depthwise_injectors[depthwise_inj_idx]->compute_vector_range( + start_idx, end_idx, reg_d_weights, reg_d_weights); + + add(reg_d_weights, jcp.ch_block * sizeof(float)); + } + + post_ops_data_offset += depthwise_injectors[depthwise_inj_idx]->memoryStep(); + depthwise_inj_idx++; + } + } +} + inline void jit_avx512_dw_conv_bwd_data_kernel_bf16::store_dsrc( int ur_ch_blocks, int ur_str_w, bool last_ch_block_flag) { int ch_blk = jcp.ch_block; @@ -758,6 +810,7 @@ inline void jit_avx512_dw_conv_bwd_data_kernel_bf16::ch_loop_body( load_ddst(ur_ch_blocks, unroll_w); apply_filter(ur_ch_blocks, unroll_w, is_last_ch); + apply_postprocess(ur_ch_blocks, unroll_w); store_dsrc(ur_ch_blocks, unroll_w, is_last_ch); }; @@ -776,6 +829,7 @@ inline void jit_avx512_dw_conv_bwd_data_kernel_bf16::ch_loop_body( const size_t data_ch_stride = (size_t)jcp.nb_ch_blocking * jcp.ch_block; mov(aux_reg_ch_blocks, reg_ch_blocks); + base_post_ops_data_offset += 3 * reg64_size; push(reg_dsrc); push(reg_ddst); push(reg_kernel); @@ -812,6 +866,7 @@ inline void jit_avx512_dw_conv_bwd_data_kernel_bf16::ch_loop_body( pop(reg_kernel); pop(reg_ddst); pop(reg_dsrc); + base_post_ops_data_offset -= 3 * reg64_size; } else { call_compute_body(ur_ch_blocks, unroll_w, jcp.ch_tail); @@ -849,7 +904,39 @@ inline void jit_avx512_dw_conv_bwd_data_kernel_bf16::unroll_width_body( void jit_avx512_dw_conv_bwd_data_kernel_bf16::generate() { assert(is_dsrc_layout_nxc() == is_ddst_layout_nxc()); + const auto& p = attr_.post_ops_; + for (int i = 0; i < p.len(); i++) { + auto& post_op = p.entry_[i]; + if (post_op.is_depthwise()) { + depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32( + this, + post_op + )); + } + } + preamble(); + + std::size_t post_ops_pointers_count = 0; + for (int i = 0; i < p.len(); i++) { + if (p.entry_[i].is_depthwise() || p.entry_[i].is_quantization()) { + post_ops_pointers_count++; + } + } + + if (post_ops_pointers_count != 0) { + sub(rsp, post_ops_pointers_count * sizeof(float *)); + + auto aux_reg0 = reg_dsrc; + auto aux_reg1 = reg_ddst; + + mov(aux_reg0, ptr[this->param1 + GET_OFF(post_ops_binary_rhs_arg_vec)]); + for (size_t i = 0; i < post_ops_pointers_count; i++) { + mov(aux_reg1, ptr[aux_reg0 + i * sizeof(float *)]); + mov(ptr[rsp + i * sizeof(float *)], aux_reg1); + } + } + mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]); mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]); mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); @@ -888,6 +975,11 @@ void jit_avx512_dw_conv_bwd_data_kernel_bf16::generate() { int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking; if (ch_blocks_tail) { ch_blocks_loop(ch_blocks_tail); } } + + if (post_ops_pointers_count != 0) { + add(rsp, post_ops_pointers_count * sizeof(float *)); + } + postamble(); } #undef GET_OFF diff --git a/src/cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.hpp b/src/cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.hpp index 7376d017f76..00be507e7e2 100644 --- a/src/cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.hpp +++ b/src/cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.hpp @@ -25,6 +25,7 @@ #include "cpu/x64/jit_primitive_conf.hpp" #include "cpu/x64/jit_avx512_core_bf16cvt.hpp" +#include 
"cpu/x64/injectors/jit_uni_depthwise_injector.hpp" namespace dnnl { namespace impl { @@ -35,9 +36,10 @@ struct jit_avx512_dw_conv_fwd_kernel_bf16 : public jit_generator { DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_dw_conv_fwd_kernel_bf16) jit_avx512_dw_conv_fwd_kernel_bf16( - const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md); + const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md, const primitive_attr_t& attr); jit_conv_conf_t jcp; + const primitive_attr_t& attr_; private: using reg64_t = const Xbyak::Reg64; @@ -71,6 +73,14 @@ struct jit_avx512_dw_conv_fwd_kernel_bf16 : public jit_generator { mask_t ktail_mask = k_oc_tail_mask; mask_t k_ch_tail_mask_extended = Xbyak::Opmask(3); + reg64_t reg_d_weights = abi_not_param1; + reg64_t reg_d_bias = iter_kh; + int base_post_ops_data_offset = 0; + constexpr static int reg64_size = 8; + + Xbyak::Zmm zmm_d_weights = Xbyak::Zmm(31); + Xbyak::Zmm zmm_d_bias = Xbyak::Zmm(30); + Xbyak::Zmm zmm_ker_reg = Xbyak::Zmm(0); Xbyak::Zmm zmm_src_reg = Xbyak::Zmm(1); Xbyak::Zmm zmm_prev_dst = Xbyak::Zmm(31); @@ -130,8 +140,8 @@ struct jit_avx512_dw_conv_fwd_kernel_bf16 : public jit_generator { struct jit_avx512_dw_conv_bwd_data_kernel_bf16 : public jit_generator { DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_dw_conv_bwd_data_kernel_bf16) - jit_avx512_dw_conv_bwd_data_kernel_bf16(const jit_conv_conf_t &ajcp) - : jcp(ajcp), bf16_emu_(nullptr) { + jit_avx512_dw_conv_bwd_data_kernel_bf16(const jit_conv_conf_t &ajcp, const primitive_attr_t& attr) + : jcp(ajcp), attr_(attr), bf16_emu_(nullptr) { if (!isa_has_bf16(jcp.isa)) bf16_emu_ = new bf16_emulation_t(this, bf16_emu_reserv_1, @@ -139,10 +149,18 @@ struct jit_avx512_dw_conv_bwd_data_kernel_bf16 : public jit_generator { bf16_emu_reserv_5, bf16_emu_reserv_6); } - ~jit_avx512_dw_conv_bwd_data_kernel_bf16() { delete bf16_emu_; } + ~jit_avx512_dw_conv_bwd_data_kernel_bf16() { + for (auto inj : depthwise_injectors) + delete inj; + depthwise_injectors.clear(); + + delete bf16_emu_; + } jit_conv_conf_t jcp; + const primitive_attr_t& attr_; + private: using reg64_t = const Xbyak::Reg64; @@ -177,6 +195,11 @@ struct jit_avx512_dw_conv_bwd_data_kernel_bf16 : public jit_generator { reg64_t reg_tmp = r15; Xbyak::Opmask k_ch_tail_mask = Xbyak::Opmask(1); + reg64_t reg_d_weights = r15; + reg64_t reg_d_bias = iter_kh; + int base_post_ops_data_offset = 0; + constexpr static int reg64_size = 8; + Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(26); Xbyak::Zmm bf16_emu_reserv_2 = Xbyak::Zmm(27); Xbyak::Zmm bf16_emu_reserv_3 = Xbyak::Zmm(28); @@ -186,10 +209,13 @@ struct jit_avx512_dw_conv_bwd_data_kernel_bf16 : public jit_generator { bf16_emulation_t *bf16_emu_; + nstl::vector*> depthwise_injectors; + inline void ch_loop_body(int ur_ch_blocks, int unroll_w); inline void unroll_width_body(int ur_ch_blocks); inline void load_ddst(int ur_ch_blocks, int ur_str_w); inline void apply_filter(int ur_ch_blocks, int ur_str_w, bool is_last_ch); + inline void apply_postprocess(int ur_ch_blocks, int ur_str_w); inline void store_dsrc(int ur_ch_blocks, int ur_str_w, bool is_last_ch); void generate() override; diff --git a/src/cpu/x64/jit_avx512_core_f32_wino_conv_2x3.cpp b/src/cpu/x64/jit_avx512_core_f32_wino_conv_2x3.cpp index 253bfd6d0dc..7bec6014291 100644 --- a/src/cpu/x64/jit_avx512_core_f32_wino_conv_2x3.cpp +++ b/src/cpu/x64/jit_avx512_core_f32_wino_conv_2x3.cpp @@ -840,7 +840,7 @@ jit_avx512_core_f32_wino_conv_2x3_fwd_t:: void jit_avx512_core_f32_wino_conv_2x3_fwd_t::execute_forward_mbN( const float *src, const float *wei, const float *bia, 
float *dst, - const memory_tracking::grantor_t &scratchpad) const { + const memory_tracking::grantor_t &scratchpad, int MB) const { const auto &jcp = kernel_->jcp; const auto &oscales = pd()->attr()->output_scales_; @@ -861,7 +861,7 @@ void jit_avx512_core_f32_wino_conv_2x3_fwd_t::execute_forward_mbN( auto ptr_V = scratchpad.get(key_wino_V); auto ptr_M = scratchpad.get(key_wino_M); - parallel_nd_ext(jcp.nthr, jcp.mb, div_up(jcp.oh, jcp.yb), + parallel_nd_ext(jcp.nthr, MB, div_up(jcp.oh, jcp.yb), div_up(jcp.ow, jcp.xb), [&](dim_t ithr, dim_t nthr, dim_t mb, dim_t tile_y_b, dim_t tile_x_b) { @@ -971,7 +971,7 @@ void jit_avx512_core_f32_wino_conv_2x3_fwd_t::execute_forward_mbN( void jit_avx512_core_f32_wino_conv_2x3_fwd_t::execute_forward_small_mb( const float *src, const float *wei, const float *bia, float *dst, - const memory_tracking::grantor_t &scratchpad) const { + const memory_tracking::grantor_t &scratchpad, int MB) const { const auto &jcp = kernel_->jcp; const auto &oscales = pd()->attr()->output_scales_; @@ -986,7 +986,7 @@ void jit_avx512_core_f32_wino_conv_2x3_fwd_t::execute_forward_small_mb( auto ptr_V = scratchpad.get(key_wino_V); auto ptr_M = scratchpad.get(key_wino_M); - for_(int mb = 0; mb < jcp.mb; mb++) + for_(int mb = 0; mb < MB; mb++) for_(int tile_y = 0; tile_y < jcp.oh; tile_y += jcp.yb) for (int tile_x = 0; tile_x < jcp.ow; tile_x += jcp.xb) { /* transformation of input tensor to winograd domain */ diff --git a/src/cpu/x64/jit_avx512_core_f32_wino_conv_2x3.hpp b/src/cpu/x64/jit_avx512_core_f32_wino_conv_2x3.hpp index c0a63022054..7864c350256 100644 --- a/src/cpu/x64/jit_avx512_core_f32_wino_conv_2x3.hpp +++ b/src/cpu/x64/jit_avx512_core_f32_wino_conv_2x3.hpp @@ -117,12 +117,14 @@ struct jit_avx512_core_f32_wino_conv_2x3_fwd_t : public primitive_t { auto bia = CTX_IN_MEM(const float *, DNNL_ARG_BIAS); auto dst = CTX_OUT_MEM(float *, DNNL_ARG_DST); + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + if (pd()->jcp_.small_mb) execute_forward_small_mb( - src, wei, bia, dst, ctx.get_scratchpad_grantor()); + src, wei, bia, dst, ctx.get_scratchpad_grantor(), MB); else execute_forward_mbN( - src, wei, bia, dst, ctx.get_scratchpad_grantor()); + src, wei, bia, dst, ctx.get_scratchpad_grantor(), MB); return status::success; } @@ -130,10 +132,10 @@ struct jit_avx512_core_f32_wino_conv_2x3_fwd_t : public primitive_t { private: void execute_forward_small_mb(const float *src, const float *wei, const float *bia, float *dst, - const memory_tracking::grantor_t &scratchpad) const; + const memory_tracking::grantor_t &scratchpad, int MB) const; void execute_forward_mbN(const float *src, const float *wei, const float *bia, float *dst, - const memory_tracking::grantor_t &scratchpad) const; + const memory_tracking::grantor_t &scratchpad, int MB) const; const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } std::unique_ptr kernel_; diff --git a/src/cpu/x64/jit_avx512_core_f32_wino_conv_4x3.cpp b/src/cpu/x64/jit_avx512_core_f32_wino_conv_4x3.cpp index 99de1ffb69e..98eead82855 100644 --- a/src/cpu/x64/jit_avx512_core_f32_wino_conv_4x3.cpp +++ b/src/cpu/x64/jit_avx512_core_f32_wino_conv_4x3.cpp @@ -218,7 +218,7 @@ void _jit_avx512_core_f32_wino_conv_4x3_t::input_transform_data( template void _jit_avx512_core_f32_wino_conv_4x3_t< is_fwd>::input_transform_tileblock_data(int tile_block, - const jit_conv_winograd_conf_t &jcp, float *inp, float *tinp) const { + const jit_conv_winograd_conf_t &jcp, float *inp, float *tinp, int MB) const { float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f, 
0.625f, -0.625f, 1.5f, -1.5f, -2.640625f}; float Iw[alpha][alpha][simd_w]; @@ -229,7 +229,7 @@ void _jit_avx512_core_f32_wino_conv_4x3_t< const int inpw = is_fwd ? jcp.iw : jcp.ow; array_offset_calculator input( - inp, jcp.mb, jcp.dimK / simd_w, inph, inpw, simd_w); + inp, MB, jcp.dimK / simd_w, inph, inpw, simd_w); array_offset_calculator output(tinp, alpha, alpha, jcp.dimN_block, jcp.dimK_nb_block, jcp.dimK_block, jcp.dimN_reg_block, jcp.dimK_reg_block); @@ -271,7 +271,7 @@ void _jit_avx512_core_f32_wino_conv_4x3_t< template void _jit_avx512_core_f32_wino_conv_4x3_t::_execute_data_W_S_G_D( float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr, - const memory_tracking::grantor_t &scratchpad) const { + const memory_tracking::grantor_t &scratchpad, int MB) const { const auto &jcp = kernel_->jcp; const auto &p_ops = attr_->post_ops_; @@ -285,9 +285,9 @@ void _jit_avx512_core_f32_wino_conv_4x3_t::_execute_data_W_S_G_D( BWD: dimM:ic, dimN:ntiles, dimK:oc, FWD/BWD: V: src/diff_dst transform, U:weight transform, M:dst/diff_src transform */ - array_offset_calculator input(inp_ptr, jcp.mb, + array_offset_calculator input(inp_ptr, MB, jcp.dimK / jcp.dimK_reg_block, inph, inpw, jcp.dimK_reg_block); - array_offset_calculator output(out_ptr, jcp.mb, + array_offset_calculator output(out_ptr, MB, jcp.dimM / jcp.dimM_simd_block, outh, outw, jcp.dimM_simd_block); array_offset_calculator weights(wei_ptr, jcp.oc / jcp.oc_simd_block, jcp.ic / jcp.ic_simd_block, jcp.kh, @@ -323,7 +323,7 @@ void _jit_avx512_core_f32_wino_conv_4x3_t::_execute_data_W_S_G_D( last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc); } - parallel_nd(jcp.mb, jcp.dimK_nb_block, jcp.dimK_block, + parallel_nd(MB, jcp.dimK_nb_block, jcp.dimK_block, [&](dim_t img, dim_t K_blk1, dim_t K_blk2) { input_transform_data(img, jcp, &(input(img, K_blk1 * jcp.dimK_block + K_blk2, 0, 0, @@ -361,7 +361,7 @@ void _jit_avx512_core_f32_wino_conv_4x3_t::_execute_data_W_S_G_D( K_blk1); }); - parallel_nd(jcp.mb, jcp.dimM_nb_block, + parallel_nd(MB, jcp.dimM_nb_block, (jcp.dimM_block * jcp.dimM_reg_block), [&](dim_t img, dim_t M_blk1, dim_t M_blk2) { const int M_blk @@ -380,7 +380,7 @@ void _jit_avx512_core_f32_wino_conv_4x3_t::_execute_data_W_S_G_D( template void _jit_avx512_core_f32_wino_conv_4x3_t::_execute_data_W_SGD( float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr, - const memory_tracking::grantor_t &scratchpad) const { + const memory_tracking::grantor_t &scratchpad, int MB) const { const auto &jcp = kernel_->jcp; const auto &p_ops = attr_->post_ops_; @@ -389,9 +389,9 @@ void _jit_avx512_core_f32_wino_conv_4x3_t::_execute_data_W_SGD( const int outh = is_fwd ? jcp.oh : jcp.ih; const int outw = is_fwd ? 
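// The is_fwd template parameter swaps the spatial roles: forward reads
// ih*iw and writes oh*ow, while backward-data treats diff_dst as the
// input and diff_src as the output.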
jcp.ow : jcp.iw; - array_offset_calculator input(inp_ptr, jcp.mb, + array_offset_calculator input(inp_ptr, MB, jcp.dimK / jcp.dimK_reg_block, inph, inpw, jcp.dimK_reg_block); - array_offset_calculator output(out_ptr, jcp.mb, + array_offset_calculator output(out_ptr, MB, jcp.dimM / jcp.dimM_simd_block, outh, outw, jcp.dimM_simd_block); array_offset_calculator weights(wei_ptr, jcp.oc / jcp.oc_simd_block, jcp.ic / jcp.ic_simd_block, jcp.kh, @@ -456,7 +456,7 @@ void _jit_avx512_core_f32_wino_conv_4x3_t::_execute_data_W_SGD( input_transform_tileblock_data(tile_block, jcp, &(input(0, K_blk1 * jcp.dimK_block + K_blk2, 0, 0, 0)), - &(V(ithr, 0, 0, 0, K_blk1, K_blk2, 0, 0))); + &(V(ithr, 0, 0, 0, K_blk1, K_blk2, 0, 0)), MB); } } diff --git a/src/cpu/x64/jit_avx512_core_f32_wino_conv_4x3.hpp b/src/cpu/x64/jit_avx512_core_f32_wino_conv_4x3.hpp index 335493f64d7..7a46caf78ac 100644 --- a/src/cpu/x64/jit_avx512_core_f32_wino_conv_4x3.hpp +++ b/src/cpu/x64/jit_avx512_core_f32_wino_conv_4x3.hpp @@ -92,7 +92,7 @@ struct _jit_avx512_core_f32_wino_conv_4x3_t { void input_transform_data(int image, const jit_conv_winograd_conf_t &jcp, float *inp, float *tinp) const; void input_transform_tileblock_data(int tile_block, - const jit_conv_winograd_conf_t &jcp, float *inp, float *tinp) const; + const jit_conv_winograd_conf_t &jcp, float *inp, float *tinp, int MB) const; void output_transform_data(int image, const jit_conv_winograd_conf_t &jcp, const post_ops_t &p_ops, float *toutp, float *pout_b, float *bias) const; @@ -101,10 +101,10 @@ struct _jit_avx512_core_f32_wino_conv_4x3_t { float *toutp, float *outp, float *bias) const; void _execute_data_W_S_G_D(float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr, - const memory_tracking::grantor_t &scratchpad) const; + const memory_tracking::grantor_t &scratchpad, int MB) const; void _execute_data_W_SGD(float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr, - const memory_tracking::grantor_t &scratchpad) const; + const memory_tracking::grantor_t &scratchpad, int MB) const; std::unique_ptr<_jit_avx512_core_f32_wino_conv_4x3_data_kernel> kernel_; const primitive_attr_t *attr_; @@ -179,16 +179,18 @@ struct jit_avx512_core_f32_wino_conv_4x3_fwd_t auto bias = CTX_IN_MEM(const float *, DNNL_ARG_BIAS); auto dst = CTX_OUT_MEM(float *, DNNL_ARG_DST); + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + auto scratchpad = ctx.get_scratchpad_grantor(); switch ((pd()->jcp_).sched_policy) { case WSCHED_DATA_W_S_G_D: this->_execute_data_W_S_G_D((float *)src, dst, (float *)weights, - (float *)bias, scratchpad); + (float *)bias, scratchpad, MB); break; case WSCHED_DATA_W_SGD: this->_execute_data_W_SGD((float *)src, dst, (float *)weights, - (float *)bias, scratchpad); + (float *)bias, scratchpad, MB); break; default: break; } @@ -264,17 +266,19 @@ struct jit_avx512_core_f32_wino_conv_4x3_bwd_data_t auto weights = CTX_IN_MEM(const float *, DNNL_ARG_WEIGHTS); auto diff_src = CTX_OUT_MEM(float *, DNNL_ARG_DIFF_SRC); + auto MB = CTX_IN_BATCH(DNNL_ARG_DIFF_DST); + auto scratchpad = ctx.get_scratchpad_grantor(); switch ((pd()->jcp_).sched_policy) { case WSCHED_DATA_W_S_G_D: this->_execute_data_W_S_G_D((float *)diff_dst, diff_src, - (float *)weights, NULL, scratchpad); + (float *)weights, NULL, scratchpad, MB); break; case WSCHED_DATA_W_SGD: this->_execute_data_W_SGD((float *)diff_dst, diff_src, - (float *)weights, NULL, scratchpad); + (float *)weights, NULL, scratchpad, MB); break; default: break; diff --git a/src/cpu/x64/jit_avx512_core_fork_bf16_dw_conv_kernel.cpp 
b/src/cpu/x64/jit_avx512_core_fork_bf16_dw_conv_kernel.cpp new file mode 100644 index 00000000000..78cca9dcd21 --- /dev/null +++ b/src/cpu/x64/jit_avx512_core_fork_bf16_dw_conv_kernel.cpp @@ -0,0 +1,803 @@ +/******************************************************************************* +* Copyright 2021 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "common/c_types_map.hpp" +#include "common/nstl.hpp" +#include "common/type_helpers.hpp" +#include "common/utils.hpp" + +#include "jit_avx512_core_fork_bf16_dw_conv_kernel.hpp" + +#define GET_OFF(field) offsetof(jit_conv_call_s, field) + +namespace dnnl { +namespace impl { +namespace cpu { +namespace x64 { + +using namespace Xbyak; + +void jit_avx512_fork_dw_conv_fwd_kernel_bf16::load_src(int ur_ch_blocks, int ur_w, bool last_ch_block_flag) { + const auto dst_layout_nxc = is_dst_layout_nxc(); + const auto ch_blk = jcp.ch_block; + const auto ocb_stride = dst_layout_nxc ? ch_blk : jcp.oh * jcp.ow * ch_blk; + const auto ow_stride = dst_layout_nxc ? jcp.ngroups : ch_blk; + + for (int ch = 0; ch < ur_ch_blocks; ch++) { + const bool mask_flag = last_ch_block_flag && ch == ur_ch_blocks - 1; + for (int ow = 0; ow < ur_w; ow++) { + Zmm zmm_acc = get_acc_reg(ch * ur_w + ow); + const Zmm zmm_acc_msk + = mask_flag ? zmm_acc | ktail_mask | T_z : zmm_acc; + + if (this->jcp.with_bias) { + int b_off = ch * ch_blk; + uni_vmovups(zmm_acc_msk, vmmword[reg_bias + b_off * sizeof(float)]); + } else { + uni_vpxor(zmm_acc, zmm_acc, zmm_acc); + } + if (this->jcp.with_sum) { + int o_off = ch * ocb_stride + ow * ow_stride; + if (jcp.dst_dt == data_type::bf16) { + const Zmm zmm_prev_dst_msk = mask_flag + ? zmm_prev_dst | ktail_mask | T_z + : zmm_prev_dst; + vpmovzxwd(zmm_prev_dst_msk, + vmmword[reg_output + o_off * jcp.typesize_out]); + vpslld(zmm_prev_dst, zmm_prev_dst, 16); + vaddps(zmm_acc, zmm_prev_dst); + } else { + uni_vaddps(zmm_acc_msk, zmm_acc_msk, + vmmword[reg_output + o_off * jcp.typesize_out]); + } + } + } + } +} + +void jit_avx512_fork_dw_conv_fwd_kernel_bf16::apply_filter( + int ur_ch_blocks, int ur_w, bool last_ch_block_flag) { + int ch_block = jcp.ch_block; + int dilate_h = jcp.dilate_h + 1; + int dilate_w = jcp.dilate_w + 1; + int stride_w = jcp.stride_w; + + const auto src_layout_nxc = is_src_layout_nxc(); + const auto iw_stride = src_layout_nxc ? jcp.ngroups : ch_block; + const auto ih_stride = jcp.iw * iw_stride; + const auto icb_stride = src_layout_nxc + ? 
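// Stride between channel blocks: in nxc layouts the blocks of one pixel
// sit next to each other (ch_block apart), while blocked layouts keep
// each block in its own ih*iw plane.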
ch_block + : jcp.ih * jcp.iw * ch_block; + + Label iter_exit_label; + + cmp(reg_kh, 0); + je(iter_exit_label, T_NEAR); + cmp(reg_kw, 0); + je(iter_exit_label, T_NEAR); + + mov(iter_kh, reg_kh); + Label kh_label; + push(aux1_reg_kernel); + base_post_ops_data_offset += reg64_size; + L(kh_label); { + mov(iter_kw, reg_kw); + mov(aux1_reg_input, aux_reg_input); + mov(aux1_reg_kernel, aux_reg_kernel); + + Label kw_label; + L(kw_label); { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + const bool mask_flag = last_ch_block_flag && ch == ur_ch_blocks - 1; + int ker_off = ch * jcp.kh * jcp.kw * ch_block; + const Zmm zmm_ker_reg_msk = mask_flag + ? zmm_ker_reg | ktail_mask | T_z + : zmm_ker_reg; + vpmovzxwd(zmm_ker_reg_msk, + ptr[aux1_reg_kernel + ker_off * jcp.typesize_in]); + for (int ow = 0; ow < ur_w; ow++) { + const Zmm zmm_src_reg_msk = mask_flag + ? zmm_src_reg | ktail_mask | T_z + : zmm_src_reg; + Zmm zmm_acc = get_acc_reg(ch * ur_w + ow); + int inp_off = ch * icb_stride + + ow * stride_w * iw_stride; + /* zero-extend bf16 to packed 32-bit int */ + vpmovzxwd(zmm_src_reg_msk, + ptr[aux1_reg_input + inp_off * jcp.typesize_in]); + if (!isa_has_bf16(jcp.isa)) { + bf16_emu_->vdpbf16ps(zmm_acc, zmm_ker_reg, zmm_src_reg); + } else { + vdpbf16ps(zmm_acc, zmm_ker_reg, zmm_src_reg); + } + } + } + add(aux1_reg_kernel, ch_block * jcp.typesize_in); + add(aux1_reg_input, iw_stride * dilate_w * jcp.typesize_in); + + dec(iter_kw); + cmp(iter_kw, 0); + jg(kw_label, T_NEAR); + } + add(aux_reg_kernel, jcp.kw * ch_block * jcp.typesize_in); + add(aux_reg_input, ih_stride * dilate_h * jcp.typesize_in); + + dec(iter_kh); + cmp(iter_kh, 0); + jg(kh_label, T_NEAR); + pop(aux1_reg_kernel); + base_post_ops_data_offset -= reg64_size; + } + + L(iter_exit_label); +} + +void jit_avx512_fork_dw_conv_fwd_kernel_bf16::apply_filter_unrolled( + int ur_ch_blocks, int ur_w, bool last_ch_block_flag) { + int ch_blk = jcp.ch_block; + int dilate_h = jcp.dilate_h + 1; + int dilate_w = jcp.dilate_w + 1; + int stride_w = jcp.stride_w; + + const auto src_layout_nxc = is_src_layout_nxc(); + const auto iw_stride = src_layout_nxc ? jcp.ngroups : ch_blk; + const auto ih_stride = jcp.iw * iw_stride; + const auto icb_stride = src_layout_nxc + ? ch_blk + : jcp.ih * jcp.iw * ch_blk; + + Label iter_exit_label; + + cmp(reg_kh, 0); + je(iter_exit_label, T_NEAR); + + mov(iter_kh, reg_kh); + Label kh_label; + L(kh_label); { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + const bool mask_flag = last_ch_block_flag && ch == ur_ch_blocks - 1; + for (int kw = 0; kw < jcp.kw; kw++) { + int ker_off = ch * jcp.kh * jcp.kw * ch_blk + kw * ch_blk; + const Zmm zmm_ker_reg_msk = mask_flag + ? zmm_ker_reg | ktail_mask | T_z + : zmm_ker_reg; + + vpmovzxwd(zmm_ker_reg_msk, + ptr[aux_reg_kernel + ker_off * jcp.typesize_in]); + for (int ow = 0; ow < ur_w; ow++) { + const Zmm zmm_src_reg_msk = mask_flag + ? 
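// For the channel tail, ktail_mask with T_z zeroes the lanes past the
// real channel count, so vdpbf16ps accumulates zeros there instead of
// garbage.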
zmm_src_reg | ktail_mask | T_z + : zmm_src_reg; + Zmm zmm_acc = get_acc_reg(ch * ur_w + ow); + int inp_off = ch * icb_stride + + ow * stride_w * iw_stride + kw * dilate_w * iw_stride; + /* zero-extend bf16 to packed 32-bit int */ + vpmovzxwd(zmm_src_reg_msk, + ptr[aux_reg_input + inp_off * jcp.typesize_in]); + if (!isa_has_bf16(jcp.isa)) { + bf16_emu_->vdpbf16ps(zmm_acc, zmm_ker_reg, zmm_src_reg); + } else { + vdpbf16ps(zmm_acc, zmm_ker_reg, zmm_src_reg); + } + } + } + } + + add(aux_reg_kernel, jcp.kw * ch_blk * jcp.typesize_in); + add(aux_reg_input, ih_stride * dilate_h * jcp.typesize_in); + + dec(iter_kh); + cmp(iter_kh, 0); + jg(kh_label, T_NEAR); + } + + L(iter_exit_label); +} + +void jit_avx512_fork_dw_conv_fwd_kernel_bf16::apply_postprocess( + int ur_ch_blocks, int ur_w) { + int eltwise_inj_idx = 0; + int depthwise_inj_idx = 0; + std::size_t post_ops_data_offset = 0; + const auto& p = attr_.post_ops_; + + for (int i = 0; i < p.len(); i++) { + auto& post_op = p.entry_[i]; + if (post_op.is_eltwise()) { + int start_idx = get_acc_reg(0).getIdx(); + int end_idx = get_acc_reg(ur_w * ur_ch_blocks).getIdx(); + + eltwise_injectors[eltwise_inj_idx]->compute_vector_range(start_idx, end_idx); + eltwise_inj_idx++; + } else if (post_op.is_depthwise()) { + push(aux_reg_blocks_offset); + base_post_ops_data_offset += reg64_size; + add(aux_reg_blocks_offset, ptr[this->param1 + GET_OFF(oc_off)]); //add offset of processed blocks + + mov(reg_d_weights, ptr[this->rsp + base_post_ops_data_offset + post_ops_data_offset]); + add(reg_d_weights, aux_reg_blocks_offset); + + for (int ch = 0; ch < ur_ch_blocks; ch++) { + int start_idx = get_acc_reg(ur_w * ch).getIdx(); + int end_idx = get_acc_reg(ur_w * ch + ur_w).getIdx(); + + depthwise_injectors[depthwise_inj_idx]->compute_vector_range( + start_idx, end_idx, reg_d_weights, reg_d_weights); + + add(reg_d_weights, jcp.ch_block * sizeof(float)); + } + pop(aux_reg_blocks_offset); + base_post_ops_data_offset -= reg64_size; + + post_ops_data_offset += depthwise_injectors[depthwise_inj_idx]->memoryStep(); + depthwise_inj_idx++; + } + } +} + +void jit_avx512_fork_dw_conv_fwd_kernel_bf16::store_dst(int ur_ch_blocks, int ur_w, bool last_ch_block_flag) { + const auto dst_layout_nxc = is_dst_layout_nxc(); + const auto ch_blk = jcp.ch_block; + const auto ocb_stride = dst_layout_nxc ? ch_blk : jcp.oh * jcp.ow * ch_blk; + const auto ow_stride = dst_layout_nxc ? jcp.ngroups : ch_blk; + + if (jcp.dst_dt == data_type::bf16 && (!isa_has_bf16(jcp.isa))) + bf16_emu_->init_vcvtneps2bf16(); + + if (dst_layout_nxc && jcp.dst_dt == data_type::bf16 + && isa_has_bf16(jcp.isa)) { + for (int j = 0; j < ur_w; ++j) { + int n_2bf2ps = (ur_ch_blocks / 2) * 2; + int ch = 0; + for (; ch < n_2bf2ps; ch += 2) { + size_t aux_output_offset + = (size_t)ch * ocb_stride + j * ow_stride; + auto addr = ptr[reg_output + + aux_output_offset * jcp.typesize_out]; + auto zmm_dst = get_acc_reg(ch * ur_w + j); + vcvtne2ps2bf16( + zmm_dst, get_acc_reg((ch + 1) * ur_w + j), zmm_dst); + bool mask_flag = last_ch_block_flag && ch + 2 == ur_ch_blocks; + Zmm zmm_dst_msk = mask_flag ? 
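// vcvtne2ps2bf16 packs two f32 accumulators into one zmm of 32 bf16
// values, so the tail store needs the extended 32-lane opmask rather
// than the 16-lane ktail_mask.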
zmm_dst | k_ch_tail_mask_extended + : zmm_dst; + vmovdqu16(addr, zmm_dst_msk); + } + /* Perform tail write for odd ch sizes */ + if (ch < ur_ch_blocks) { + size_t aux_output_offset + = (size_t) ch * ocb_stride + j * ow_stride; + auto addr = ptr[reg_output + + aux_output_offset * jcp.typesize_out]; + auto zmm_dst = get_acc_reg(ch * ur_w + j); + auto ymm_dst = Ymm(zmm_dst.getIdx()); + vcvtneps2bf16(ymm_dst, zmm_dst); + Ymm ymm_dst_msk = last_ch_block_flag ? ymm_dst | ktail_mask : ymm_dst; + vmovdqu16(addr, ymm_dst_msk); + } + } + } else { + // also used for the case when dst_layout_nxc && dst.dt == f32 + if (jcp.dst_dt == data_type::f32) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + bool mask_flag = last_ch_block_flag && ch == ur_ch_blocks - 1; + for (int ow = 0; ow < ur_w; ow++) { + int o_off = ch * ocb_stride + ow * ow_stride; + Zmm zmm_dst = get_acc_reg(ch * ur_w + ow); + Zmm zmm_dst_msk = mask_flag ? zmm_dst | ktail_mask : zmm_dst; + vmovups(vmmword[reg_output + o_off * jcp.typesize_out], + zmm_dst_msk); + } + } + } else if (jcp.dst_dt == data_type::bf16) { + if (isa_has_bf16(jcp.isa)) { // !dst_layout_nxc() + assert(jcp.ngroups % jcp.ch_block == 0); + for (int ch = 0; ch < ur_ch_blocks; ch++) { + int n_2bf2ps = (ur_w / 2) * 2; + int j = 0; + for (; j < n_2bf2ps; j += 2) { + size_t aux_output_offset + = (size_t)ch * ocb_stride + j * ow_stride; + auto addr = ptr[reg_output + + aux_output_offset * jcp.typesize_out]; + auto zmm_dst = get_acc_reg(ch * ur_w + j); + vcvtne2ps2bf16(zmm_dst, get_acc_reg(ch * ur_w + j + 1), + get_acc_reg(ch * ur_w + j)); + vmovups(addr, zmm_dst); + } + /* Perform tail write for odd ur_w sizes */ + if (j < ur_w) { + size_t aux_output_offset + = (size_t)ch * ocb_stride + j * ow_stride; + auto addr = ptr[reg_output + + aux_output_offset * jcp.typesize_out]; + auto zmm_dst = get_acc_reg(ch * ur_w + j); + auto ymm_dst = Ymm(zmm_dst.getIdx()); + vcvtneps2bf16(ymm_dst, zmm_dst); + vmovups(addr, ymm_dst); + } + } + } else { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + bool mask_flag = last_ch_block_flag && ch == ur_ch_blocks - 1; + for (int ow = 0; ow < ur_w; ow++) { + int o_off = ch * ocb_stride + ow * ow_stride; + Zmm zmm_dst = get_acc_reg(ch * ur_w + ow); + + /* down-convert f32 output to bf16 */ + auto ymm_dst = Ymm(zmm_dst.getIdx()); + bf16_emu_->vcvtneps2bf16(ymm_dst, zmm_dst); + + Ymm ymm_dst_msk = mask_flag ? ymm_dst | ktail_mask : ymm_dst; + vmovdqu16(ptr[reg_output + o_off * jcp.typesize_out], ymm_dst_msk); + } + } + } + } else + assert(!"unsupported destination type"); + } +} + +void jit_avx512_fork_dw_conv_fwd_kernel_bf16::compute_loop(int ur_w, int ur_ch_blocks) { + const bool ch_loop = ur_ch_blocks > jcp.nb_ch_blocking; + // ch_loop currently happens only when data layout is nxc. The strides are + // calculated for this layout only.
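A note on the bookkeeping this function relies on (the 1x1 int8 kernel later in the patch uses the same idiom): the post-op data pointers are parked on the stack right after the preamble via push_post_ops_data_on_stack, and every subsequent push/pop is mirrored in base_post_ops_data_offset so that rsp-relative loads of depthwise weights or quantization data keep resolving to the parked slots. A minimal standalone sketch of the pattern, assuming Xbyak and the SysV calling convention; the struct name and register choices are illustrative, not the kernel's:

#include <xbyak/xbyak.h>

// Sketch of the rsp-relative post-ops slot bookkeeping; not the actual
// kernel. The generated function takes an array of float* in rdi (SysV
// first argument) and returns the first parked pointer in rax.
struct post_ops_slots_sketch : Xbyak::CodeGenerator {
    static constexpr int reg64_size = 8;
    int base_post_ops_data_offset = 0; // codegen-time counter, as above

    explicit post_ops_slots_sketch(int n_ptrs) {
        // park n_ptrs pointers below rsp, as push_post_ops_data_on_stack does
        sub(rsp, n_ptrs * reg64_size);
        for (int i = 0; i < n_ptrs; ++i) {
            mov(rax, ptr[rdi + i * reg64_size]);
            mov(ptr[rsp + i * reg64_size], rax);
        }
        // every spill after this point must bump the counter ...
        push(rbx);
        base_post_ops_data_offset += reg64_size;
        // ... so slot 0 stays addressable mid-loop
        mov(rax, ptr[rsp + base_post_ops_data_offset]);
        pop(rbx);
        base_post_ops_data_offset -= reg64_size;
        add(rsp, n_ptrs * reg64_size); // release the slots before returning
        ret();
    }
};

Forgetting one side of a +=/-= pair silently skews every later rsp-relative load, which is why even the plain register spills in compute_loop below adjust the counter.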
+ const size_t wei_ch_stride = (size_t)jcp.nb_ch_blocking * jcp.kd * jcp.kh * jcp.kw + * jcp.ch_block * jcp.typesize_in; + const size_t inp_ch_stride + = (size_t)jcp.nb_ch_blocking * jcp.ch_block * jcp.typesize_in; + const size_t out_ch_stride + = (size_t)jcp.nb_ch_blocking * jcp.ch_block * jcp.typesize_out; + const size_t bias_stride + = (size_t)jcp.nb_ch_blocking * jcp.ch_block * sizeof(float); + + auto compute = [&](int ur_ch_blocks, bool last_ch_block_flag = false) { + mov(aux_reg_input, reg_input); + mov(aux_reg_kernel, reg_kernel); + + load_src(ur_ch_blocks, ur_w, last_ch_block_flag); + if (ur_w == 1) { + apply_filter(ur_ch_blocks, ur_w, last_ch_block_flag); + } else { + apply_filter_unrolled(ur_ch_blocks, ur_w, last_ch_block_flag); + } + apply_postprocess(ur_ch_blocks, ur_w); + store_dst(ur_ch_blocks, ur_w, last_ch_block_flag); + }; + + const bool masked_ch_block_tail = jcp.oc % jcp.ch_block != 0; + + xor_(aux_reg_blocks_offset, aux_reg_blocks_offset); + + if (ch_loop) { + Label ch_loop_label, ch_tail_label, skip_ch_tail_label; + const int nb_ch = jcp.oc / jcp.ch_block; + const int nb_ch_blocking_tail = jcp.nb_ch - utils::rnd_dn(nb_ch, jcp.nb_ch_blocking); + const int ch_step = jcp.nb_ch_blocking * jcp.ch_block; + + push(aux_reg_ch_blocks); + mov(aux_reg_ch_blocks, reg_ch_blocks); + push(reg_kernel); + push(reg_input); + push(reg_output); + base_post_ops_data_offset += 4 * reg64_size; + if (jcp.with_bias) { + push(reg_bias); + base_post_ops_data_offset += reg64_size; + } + + if (nb_ch >= jcp.nb_ch_blocking) { + if (nb_ch_blocking_tail) { + cmp(aux_reg_ch_blocks, ch_step); + jl(ch_tail_label, T_NEAR); + } + + L(ch_loop_label); + { + compute(jcp.nb_ch_blocking); + add(reg_kernel, wei_ch_stride); + add(reg_input, inp_ch_stride); + add(reg_output, out_ch_stride); + if (jcp.with_bias) add(reg_bias, bias_stride); + sub(aux_reg_ch_blocks, ch_step); + add(aux_reg_blocks_offset, ch_step * sizeof(float)); // advance the offset of processed blocks + cmp(aux_reg_ch_blocks, ch_step); + jge(ch_loop_label, T_NEAR); + } + } + + if (nb_ch_blocking_tail) { + // ch work range [1, jcp.nb_ch_blocking * ch_block) + L(ch_tail_label); + cmp(aux_reg_ch_blocks, 0); + jle(skip_ch_tail_label, T_NEAR); + compute(nb_ch_blocking_tail, masked_ch_block_tail); + L(skip_ch_tail_label); + } + + if (jcp.with_bias) { + pop(reg_bias); + base_post_ops_data_offset -= reg64_size; + } + pop(reg_output); + pop(reg_input); + pop(reg_kernel); + pop(aux_reg_ch_blocks); + base_post_ops_data_offset -= 4 * reg64_size; + + } else { + compute(ur_ch_blocks, masked_ch_block_tail); + } +} + +void jit_avx512_fork_dw_conv_fwd_kernel_bf16::loop_ow(int ur_ch_blocks) { + + Label unrolled_w_label; + Label tail_w_label; + Label exit_label; + + const auto src_layout_nxc = is_src_layout_nxc(); + const auto dat_c_stride = src_layout_nxc ?
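// Per-pixel step in elements: nxc interleaves all channels of a pixel
// (stride jcp.ngroups); blocked layouts step one ch_block within the plane.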
jcp.ngroups : jcp.ch_block; + + L(unrolled_w_label); { + int ur_w = jcp.ur_w; + + size_t inp_shift = (size_t)jcp.typesize_in * ur_w * jcp.stride_w * dat_c_stride; + size_t out_shift = (size_t)jcp.typesize_out * ur_w * dat_c_stride; + + cmp(reg_ur_w, ur_w); + jl(tail_w_label, T_NEAR); + + compute_loop(ur_w, ur_ch_blocks); + + add(reg_input, inp_shift); + add(reg_output, out_shift); + + sub(reg_ur_w, ur_w); + jmp(unrolled_w_label); + } + + L(tail_w_label); { + int ur_w = 1; + + size_t inp_shift = (size_t)jcp.typesize_in * ur_w * jcp.stride_w * dat_c_stride; + size_t out_shift = (size_t)jcp.typesize_out * ur_w * dat_c_stride; + + cmp(reg_ur_w, ur_w); + jl(exit_label, T_NEAR); + + compute_loop(ur_w, ur_ch_blocks); + + add(reg_input, inp_shift); + add(reg_output, out_shift); + + sub(reg_ur_w, ur_w); + jmp(tail_w_label); + } + + L(exit_label); +} + +void jit_avx512_fork_dw_conv_fwd_kernel_bf16::generate() { + const auto& p = attr_.post_ops_; + for (int i = 0; i < p.len(); i++) { + auto& post_op = p.entry_[i]; + if (post_op.is_eltwise()) { + eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32( + this, + post_op.eltwise + )); + } else if (post_op.is_depthwise()) { + depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32( + this, + post_op + )); + } + } + + this->preamble(); + + std::size_t post_ops_pointers_count = 0; + for (int i = 0; i < p.len(); i++) { + if (p.entry_[i].is_depthwise() || p.entry_[i].is_quantization()) { + post_ops_pointers_count++; + } + } + + if (post_ops_pointers_count != 0) { + sub(rsp, post_ops_pointers_count * sizeof(float *)); + + auto aux_reg0 = reg_input; + auto aux_reg1 = reg_output; + + mov(aux_reg0, ptr[this->param1 + GET_OFF(post_ops_binary_rhs_arg_vec)]); + for (size_t i = 0; i < post_ops_pointers_count; i++) { + mov(aux_reg1, ptr[aux_reg0 + i * sizeof(float *)]); + mov(ptr[rsp + i * sizeof(float *)], aux_reg1); + } + } + + mov(reg_input, ptr[this->param1 + GET_OFF(src)]); + mov(reg_output, ptr[this->param1 + GET_OFF(dst)]); + mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); + if (jcp.with_bias) + mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]); + mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); + mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]); + mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(load_work)]); + mov(reg_ur_w, ptr[this->param1 + GET_OFF(ur_w)]); + + Label ch_blocks_tail_label; + Label exit_label; + + int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking; + const auto oc_tail = jcp.oc_without_padding % jcp.ch_block; + if (oc_tail != 0) { + // Note: is_src_layout_nxc() == true, otherwise channels are padded + // Prepare masks for tailing + const int oc_tail_shift + = jcp.ch_block - jcp.oc_without_padding % jcp.ch_block; + static constexpr auto zmm_16b_mask = ((1 << 16) - 1); + + // To account for special store optimization, where two oc_blocks are + // combined with one single write, extend the mask for 32 bits + // (i.e. 
32 bfloat16 elements) + const bool need_extended_mask = jcp.dst_dt == data_type::bf16 + && isa_has_bf16(jcp.isa) && jcp.nb_ch_blocking > 1; + if (need_extended_mask) + kxnord(k_ch_tail_mask_extended, k_ch_tail_mask_extended, + k_ch_tail_mask_extended); + + Label done; + mov(reg_tail, ptr[this->param1 + GET_OFF(load_work)]); + cmp(reg_tail, jcp.nb_ch_blocking * jcp.ch_block); + je(done, T_NEAR); + Reg32 reg_tail_32 = reg_tail.cvt32(); + mov(reg_tail_32, zmm_16b_mask >> oc_tail_shift); + kmovw(k_oc_tail_mask, reg_tail_32); + if (need_extended_mask) { + auto zmm_32b_mask = (1 << (oc_tail + jcp.ch_block)) - 1; + mov(reg_tail_32, zmm_32b_mask); + kmovd(k_ch_tail_mask_extended, reg_tail_32); + } + L(done); + } + + if (is_src_layout_nxc()) { + loop_ow(jcp.nb_ch); + } else { + cmp(reg_ch_blocks, (jcp.nb_ch_blocking - 1) * jcp.ch_block); + jle(ch_blocks_tail ? ch_blocks_tail_label : exit_label, T_NEAR); + + loop_ow(jcp.nb_ch_blocking); // channel main loop + + if (ch_blocks_tail) { + jmp(exit_label, T_NEAR); + L(ch_blocks_tail_label); + + loop_ow(ch_blocks_tail); // channel tail loop + } + + L(exit_label); + } + + if (post_ops_pointers_count != 0) { + add(rsp, post_ops_pointers_count * sizeof(float *)); + } + + this->postamble(); + + for (auto& inj : eltwise_injectors) + inj->prepare_table(); +} + +inline void jit_avx512_fork_dw_conv_bwd_data_kernel_bf16::load_ddst( + int ur_ch_blocks, int ur_str_w) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + for (int w = 0; w < ur_str_w; w++) { + Zmm zmm_acc = get_acc_reg(ch * ur_str_w + w); + uni_vpxor(zmm_acc, zmm_acc, zmm_acc); + } + } +} + +inline void jit_avx512_fork_dw_conv_bwd_data_kernel_bf16::apply_filter( + int ur_ch_blocks, int ur_str_w) { + int kw = jcp.kw; + int kh = jcp.kh; + int ow = jcp.ow; + int oh = jcp.oh; + + int ch_blk = jcp.ch_block; + int stride_h = jcp.stride_h; + int stride_w = jcp.stride_w; + + Label iter_exit_label; + + cmp(reg_kh, 0); + je(iter_exit_label, T_NEAR); + + cmp(reg_kw, 0); + je(iter_exit_label, T_NEAR); + + mov(iter_kh, reg_kh); + Label kh_label; + L(kh_label); { + mov(aux1_reg_ddst, aux_reg_ddst); + mov(aux1_reg_kernel, aux_reg_kernel); + + mov(iter_kw, reg_kw); + Label kw_label; + L(kw_label); { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + int ker_off = ch * kh * kw * ch_blk; + vpmovzxwd(zmm_ker_reg, + ptr[aux1_reg_kernel + ker_off * jcp.typesize_in]); + + for (int w = 0; w < ur_str_w; w++) { + Zmm zmm_acc = get_acc_reg(ch * ur_str_w + w); + int ddst_off = (ch * oh * ow + w) * ch_blk; + vpmovzxwd(zmm_dst_reg, + ptr[aux1_reg_ddst + ddst_off * jcp.typesize_in]); + + if (!isa_has_bf16(jcp.isa)) { + bf16_emu_->vdpbf16ps( + zmm_acc, zmm_dst_reg, zmm_ker_reg); + } else { + vdpbf16ps(zmm_acc, zmm_ker_reg, zmm_dst_reg); + } + } + } + + add(aux1_reg_kernel, ch_blk * stride_w * jcp.typesize_in); + sub(aux1_reg_ddst, ch_blk * jcp.typesize_in); + + sub(iter_kw, stride_w); + cmp(iter_kw, 0); + jg(kw_label, T_NEAR); + } + + add(aux_reg_kernel, kw * ch_blk * stride_h * jcp.typesize_in); + sub(aux_reg_ddst, ow * ch_blk * jcp.typesize_in); + + sub(iter_kh, stride_h); + cmp(iter_kh, 0); + jg(kh_label, T_NEAR); + } + + L(iter_exit_label); +} + +inline void jit_avx512_fork_dw_conv_bwd_data_kernel_bf16::store_dsrc( + int ur_ch_blocks, int ur_str_w) { + int ch_blk = jcp.ch_block; + int iw = jcp.iw; + int ih = jcp.ih; + int stride_w = jcp.stride_w; + + if (jcp.dsrc_dt == data_type::bf16 && (!isa_has_bf16(jcp.isa))) + bf16_emu_->init_vcvtneps2bf16(); + + for (int ch = 0; ch < ur_ch_blocks; ch++) { + for (int w = 0; w < ur_str_w; w++) { + 
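// Backward-data is computed as a strided forward pass: each unrolled w
// lands stride_w pixels apart in diff_src, hence the w * stride_w term
// in dsrc_off below.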
int dsrc_off = (ch * ih * iw + w * stride_w) * ch_blk; + auto zmm_dsrc = get_acc_reg(ch * ur_str_w + w); + + if (jcp.dsrc_dt == data_type::f32) { + uni_vmovups( + ptr[reg_dsrc + dsrc_off * jcp.typesize_out], zmm_dsrc); + } else if (jcp.dsrc_dt == data_type::bf16) { + auto ymm_dsrc = Ymm(zmm_dsrc.getIdx()); + if (isa_has_bf16(jcp.isa)) { + vcvtneps2bf16(ymm_dsrc, zmm_dsrc); + } else { + bf16_emu_->vcvtneps2bf16(ymm_dsrc, zmm_dsrc); + } + vmovups(ptr[reg_dsrc + dsrc_off * jcp.typesize_out], ymm_dsrc); + } + } + } + /* Note: current 'store_dsrc' is limited to storing 'ymm' output. This is + * because of the current implementation approach that calculates convolution as + * a strided backward-pass. To increase store throughput by writing 'zmm' + * registers, changes are needed in both JIT-kernel and Driver code. */ +} + +inline void jit_avx512_fork_dw_conv_bwd_data_kernel_bf16::loop_body( + int ur_ch_blocks) { + Label unrolled_w_label; + Label tail_w_label; + Label exit_label; + + L(unrolled_w_label); { + int ur_w = jcp.ur_w; + + cmp(reg_ur_str_w, ur_w); + jl(tail_w_label, T_NEAR); + + mov(aux_reg_ddst, reg_ddst); + mov(aux_reg_kernel, reg_kernel); + + load_ddst(ur_ch_blocks, ur_w); + apply_filter(ur_ch_blocks, ur_w); + store_dsrc(ur_ch_blocks, ur_w); + + add(reg_dsrc, jcp.typesize_out * ur_w * jcp.ch_block * jcp.stride_w); + add(reg_ddst, jcp.typesize_in * ur_w * jcp.ch_block); + + sub(reg_ur_str_w, ur_w); + jmp(unrolled_w_label); + } + + L(tail_w_label); { + int ur_w = 1; + + cmp(reg_ur_str_w, ur_w); + jl(exit_label, T_NEAR); + + mov(aux_reg_ddst, reg_ddst); + mov(aux_reg_kernel, reg_kernel); + + load_ddst(ur_ch_blocks, ur_w); + apply_filter(ur_ch_blocks, ur_w); + store_dsrc(ur_ch_blocks, ur_w); + + add(reg_dsrc, jcp.typesize_out * ur_w * jcp.ch_block * jcp.stride_w); + add(reg_ddst, jcp.typesize_in * ur_w * jcp.ch_block); + + sub(reg_ur_str_w, ur_w); + jmp(tail_w_label); + } + + L(exit_label); +} + +void jit_avx512_fork_dw_conv_bwd_data_kernel_bf16::generate() { + preamble(); + mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]); + mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]); + mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); + mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); + mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]); + mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(ch_blocks)]); + mov(reg_ur_str_w, ptr[this->param1 + GET_OFF(ur_str_w)]); + + Label ch_blocks_tail_label; + Label exit_label; + + int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking; + + cmp(reg_ch_blocks, jcp.nb_ch_blocking); + jne(ch_blocks_tail ? ch_blocks_tail_label : exit_label, T_NEAR); + + loop_body(jcp.nb_ch_blocking); // channel main loop + + if (ch_blocks_tail) { + L(ch_blocks_tail_label); + + cmp(reg_ch_blocks, ch_blocks_tail); + jne(exit_label, T_NEAR); + + loop_body(ch_blocks_tail); // channel tail loop + } + + L(exit_label); + this->postamble(); +} + +} +} +} +} diff --git a/src/cpu/x64/jit_avx512_core_fork_bf16_dw_conv_kernel.hpp b/src/cpu/x64/jit_avx512_core_fork_bf16_dw_conv_kernel.hpp new file mode 100644 index 00000000000..fefc63e0dbd --- /dev/null +++ b/src/cpu/x64/jit_avx512_core_fork_bf16_dw_conv_kernel.hpp @@ -0,0 +1,208 @@ +/******************************************************************************* +* Copyright 2021 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_X64_JIT_AVX512_CORE_FORK_BF16_DW_CONV_KERNEL_HPP +#define CPU_X64_JIT_AVX512_CORE_FORK_BF16_DW_CONV_KERNEL_HPP + +#include "common/c_types_map.hpp" +#include "common/memory_tracking.hpp" + +#include "cpu/x64/jit_generator.hpp" +#include "cpu/x64/jit_primitive_conf.hpp" +#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp" + +#include "cpu/x64/jit_avx512_core_bf16cvt.hpp" +#include "cpu/x64/injectors/jit_uni_depthwise_injector.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace x64 { + +struct jit_avx512_fork_dw_conv_fwd_kernel_bf16 : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_fork_dw_conv_fwd_kernel_bf16) + + jit_avx512_fork_dw_conv_fwd_kernel_bf16(const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md, const primitive_attr_t& attr) + : jcp(ajcp), attr_(attr), bf16_emu_(nullptr) { + if (!isa_has_bf16(jcp.isa)) + bf16_emu_ = new bf16_emulation_t(this, bf16_emu_reserv_1, + bf16_emu_reserv_2, bf16_emu_reserv_3, bf16_emu_reserv_4, + bf16_emu_reserv_5, bf16_emu_reserv_6); + } + + ~jit_avx512_fork_dw_conv_fwd_kernel_bf16() { + for (auto inj : eltwise_injectors) + delete inj; + eltwise_injectors.clear(); + + for (auto inj : depthwise_injectors) + delete inj; + depthwise_injectors.clear(); + + delete bf16_emu_; + } + + jit_conv_conf_t jcp; + const primitive_attr_t& attr_; + +private: + using reg64_t = const Xbyak::Reg64; + using mask_t = const Xbyak::Opmask; + const Xbyak::AddressFrame &vmmword = zword; + + const int acc_idx_start = 2; + inline int get_max_regs() { return isa_has_bf16(jcp.isa) ? 
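// Accumulators start at zmm2; without native bf16 support, zmm26-zmm30
// are reserved for the emulation helpers, which caps the usable
// accumulator index at 25.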
30 : 25; }; + + // dw convolution + reg64_t reg_input = r8; + reg64_t aux_reg_input = r9; + reg64_t aux1_reg_input = r10; + reg64_t reg_kernel = r11; + reg64_t aux_reg_kernel = r12; + reg64_t reg_ch_blocks = r13; + reg64_t reg_output = r14; + reg64_t reg_bias = r15; + reg64_t reg_kh = rax; + reg64_t reg_kw = rbx; + reg64_t iter_kh = rdx; + reg64_t iter_kw = rsi; + reg64_t reg_ur_w = rbp; + reg64_t reg_tail = abi_not_param1; + reg64_t aux1_reg_kernel = reg_ch_blocks; + reg64_t imm_addr64 = aux1_reg_input; + reg64_t reg_d_weights = imm_addr64; + reg64_t reg_d_bias = iter_kh; + int base_post_ops_data_offset = 0; + constexpr static int reg64_size = 8; + reg64_t aux_reg_ch_blocks = reg_ur_w; + reg64_t aux_reg_blocks_offset = reg_tail; + + mask_t k_oc_tail_mask = Xbyak::Opmask(2); + mask_t ktail_mask = k_oc_tail_mask; + mask_t k_ch_tail_mask_extended = Xbyak::Opmask(3); + + Xbyak::Zmm zmm_ker_reg = Xbyak::Zmm(0); + Xbyak::Zmm zmm_src_reg = Xbyak::Zmm(1); + Xbyak::Zmm zmm_prev_dst = Xbyak::Zmm(31); + + /* Registers used for bfloat16 emulation */ + Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(26); + Xbyak::Zmm bf16_emu_reserv_2 = Xbyak::Zmm(27); + Xbyak::Zmm bf16_emu_reserv_3 = Xbyak::Zmm(28); + reg64_t bf16_emu_reserv_4 = iter_kw; + Xbyak::Zmm bf16_emu_reserv_5 = Xbyak::Zmm(29); + Xbyak::Zmm bf16_emu_reserv_6 = Xbyak::Zmm(30); + + inline Xbyak::Zmm get_acc_reg(int idx) { + assert(idx + acc_idx_start <= get_max_regs()); + return Xbyak::Zmm(idx + acc_idx_start); + } + + inline bool is_src_layout_nxc() { + return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc, + format_tag::nwc); + } + + inline bool is_dst_layout_nxc() { + return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc, + format_tag::nwc); + } + + inline void load_src(int ur_ch_blocks, int ur_w, bool last_ch_block_flag); + inline void compute_loop(int ur_w, int ur_ch_blocks); + inline void apply_filter(int ur_ch_blocks, int ur_w, bool last_ch_block_flag); + inline void apply_filter_unrolled(int ur_ch_blocks, int ur_w, bool last_ch_block_flag); + inline void apply_postprocess(int ur_ch_blocks, int ur_w); + inline void store_dst(int ur_ch_blocks, int ur_w, bool last_ch_block_flag); + inline void loop_ow(int ur_ch_blocks); + + nstl::vector*> eltwise_injectors; + nstl::vector*> depthwise_injectors; + + bf16_emulation_t *bf16_emu_; + + void generate() override; +}; + +struct jit_avx512_fork_dw_conv_bwd_data_kernel_bf16 : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_fork_dw_conv_bwd_data_kernel_bf16) + + jit_avx512_fork_dw_conv_bwd_data_kernel_bf16(const jit_conv_conf_t &ajcp, const primitive_attr_t&) + : jcp(ajcp), bf16_emu_(nullptr) { + + if (!isa_has_bf16(jcp.isa)) + bf16_emu_ = new bf16_emulation_t(this, bf16_emu_reserv_1, + bf16_emu_reserv_2, bf16_emu_reserv_3, bf16_emu_reserv_4, + bf16_emu_reserv_5, bf16_emu_reserv_6); + } + + ~jit_avx512_fork_dw_conv_bwd_data_kernel_bf16() { delete bf16_emu_; } + + jit_conv_conf_t jcp; + +private: + using reg64_t = const Xbyak::Reg64; + + const int acc_idx_start = 2; + inline int get_max_regs() { return isa_has_bf16(jcp.isa) ? 
30 : 25; }; + + Xbyak::Zmm zmm_ker_reg = Xbyak::Zmm(0); + Xbyak::Zmm zmm_dst_reg = Xbyak::Zmm(1); + + inline Xbyak::Zmm get_acc_reg(int idx) { + assert(idx + acc_idx_start <= get_max_regs()); + return Xbyak::Zmm(idx + acc_idx_start); + } + + reg64_t reg_ddst = rax; + reg64_t aux_reg_ddst = r8; + reg64_t aux1_reg_ddst = abi_not_param1; + reg64_t reg_kernel = rdx; + reg64_t aux_reg_kernel = r10; + reg64_t aux1_reg_kernel = rbp; + reg64_t reg_dsrc = rsi; + + reg64_t reg_ur_str_w = r9; + reg64_t reg_ch_blocks = rbx; + + reg64_t iter_kh = r11; + reg64_t iter_kw = r12; + reg64_t reg_kh = r13; + reg64_t reg_kw = r14; + + Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(26); + Xbyak::Zmm bf16_emu_reserv_2 = Xbyak::Zmm(27); + Xbyak::Zmm bf16_emu_reserv_3 = Xbyak::Zmm(28); + reg64_t bf16_emu_reserv_4 = iter_kw; + Xbyak::Zmm bf16_emu_reserv_5 = Xbyak::Zmm(29); + Xbyak::Zmm bf16_emu_reserv_6 = Xbyak::Zmm(30); + + bf16_emulation_t *bf16_emu_; + + inline void loop_body(int ur_ch_blocks); + inline void load_ddst(int ur_ch_blocks, int ur_str_w); + inline void apply_filter(int ur_ch_blocks, int ur_str_w); + inline void store_dsrc(int ur_ch_blocks, int ur_str_w); + + void generate() override; +}; + +} +} +} +} + +#endif diff --git a/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp index d945dc685a4..b3b1b8b2e7c 100644 --- a/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp @@ -15,6 +15,7 @@ *******************************************************************************/ #include +#include #include "common/c_types_map.hpp" #include "common/memory.hpp" @@ -49,7 +50,7 @@ _jit_avx512_core_x8s8s32x_1x1_conv_kernel:: const jit_1x1_conv_conf_t &ajcp, const primitive_attr_t &attr, const memory_desc_t &dst_md) : jcp(ajcp), attr_(attr), postops_injector_(nullptr) { - if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum || jcp.with_depthwise || jcp.with_quantization) { using namespace binary_injector; static constexpr bool preserve_gpr = true; static constexpr bool preserve_vmm = false; @@ -67,10 +68,12 @@ _jit_avx512_core_x8s8s32x_1x1_conv_kernel:: use_exact_tail_scalar_bcast}; const static_params_t static_params { this->param1, rhs_arg_static_params}; + quantization_injector::static_params_t quantization_static_params + {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; postops_injector_ = utils::make_unique< - injector::jit_uni_postops_injector_t>( - this, jcp.post_ops, static_params); + injector::jit_uni_postops_injector_t>( + this, jcp.post_ops, static_params, quantization_static_params); } } @@ -215,7 +218,17 @@ template void _jit_avx512_core_x8s8s32x_1x1_conv_kernel::apply_postops( const int load_loop_blk, const int ur, const bool mask_flag_in, const float *p_sum_scale, const int32_t *p_sum_zp) { - if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum || jcp.with_depthwise || jcp.with_quantization) { + std::map vmm_idx_off; + iterate(load_loop_blk, ur, + [&](const bool, const int i_load, const int i_ur) { + vmm_idx_off.insert({vreg_accum_idx(load_loop_blk, i_load, i_ur), i_load * jcp.load_block * sizeof(float)}); + }); + depthwise_injector::dynamic_params_t ddp {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + reg_oc_off, vmm_idx_off, + this->rsp, base_post_ops_data_offset}; + 
diff --git a/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp
index d945dc685a4..b3b1b8b2e7c 100644
--- a/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp
+++ b/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp
@@ -15,6 +15,7 @@
 *******************************************************************************/
 
 #include <assert.h>
+#include <map>
 
 #include "common/c_types_map.hpp"
 #include "common/memory.hpp"
@@ -49,7 +50,7 @@ _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::
         const jit_1x1_conv_conf_t &ajcp, const primitive_attr_t &attr,
         const memory_desc_t &dst_md)
     : jcp(ajcp), attr_(attr), postops_injector_(nullptr) {
-    if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) {
+    if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum || jcp.with_depthwise || jcp.with_quantization) {
         using namespace binary_injector;
         static constexpr bool preserve_gpr = true;
         static constexpr bool preserve_vmm = false;
@@ -67,10 +68,12 @@ _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::
                 use_exact_tail_scalar_bcast};
         const static_params_t static_params {
                 this->param1, rhs_arg_static_params};
+        quantization_injector::static_params_t quantization_static_params
+                {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias};
 
         postops_injector_ = utils::make_unique<
-                injector::jit_uni_postops_injector_t<avx512_core, Vmm>>(
-                this, jcp.post_ops, static_params);
+                injector::jit_uni_postops_injector_t<avx512_core, Vmm>>(
+                this, jcp.post_ops, static_params, quantization_static_params);
     }
 }
 
@@ -215,7 +218,17 @@ template <typename Vmm>
 void _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::apply_postops(
         const int load_loop_blk, const int ur, const bool mask_flag_in,
         const float *p_sum_scale, const int32_t *p_sum_zp) {
-    if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) {
+    if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum || jcp.with_depthwise || jcp.with_quantization) {
+        std::map<size_t, int> vmm_idx_off;
+        iterate(load_loop_blk, ur,
+                [&](const bool, const int i_load, const int i_ur) {
+                    vmm_idx_off.insert({vreg_accum_idx(load_loop_blk, i_load, i_ur), i_load * jcp.load_block * sizeof(float)});
+                });
+        depthwise_injector::dynamic_params_t ddp {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias,
+                reg_oc_off, vmm_idx_off,
+                this->rsp, base_post_ops_data_offset};
+        quantization_injector::dynamic_params_t qdp {reg_oc_off, vmm_idx_off, jcp.dst_dt,
+                this->rsp, base_post_ops_data_offset};
 
         apply_sum(load_loop_blk, ur, mask_flag_in, p_sum_scale, p_sum_zp);
 
@@ -264,7 +277,7 @@ void _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::apply_postops(
                 jmp(postops_done, T_NEAR);
                 L(postops_no_tail);
             }
-            postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params);
+            postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params, ddp, qdp);
             L(postops_done);
 
         } else {
@@ -273,7 +286,7 @@ void _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::apply_postops(
                         vmm_idxs.emplace(
                                 vreg_accum_idx(load_loop_blk, i_load, i_ur));
                     });
-            postops_injector_->compute_vector_range(vmm_idxs);
+            postops_injector_->compute_vector_range(vmm_idxs, binary_injector::rhs_arg_dynamic_params_t(), ddp, qdp);
         }
     }
 }
@@ -375,14 +388,14 @@ void _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::reduce_loop(
         auto vmm_bias = vmm_tmp;
         auto vmm_comp = vmm_bcast;
         if (jcp.with_bias) {
-            if (jcp.signed_input)
+            if (jcp.signed_input || jcp.with_input_zp)
                 mov(reg_bias_data, EVEX_compress_addr(rsp, reg_bias_data_off));
             cvt2ps(jcp.bia_dt, vmm_bias, bias_ptr(i_load), mask_flag);
             if (jcp.signed_input && jcp.ver != ver_vnni)
                 vmulps(vmm_bias, vmm_bias, vmm_bias_alpha());
         }
-        if (jcp.signed_input) {
+        if (jcp.signed_input || jcp.with_input_zp) {
            mov(reg_comp_data, EVEX_compress_addr(rsp, reg_comp_data_off));
            cvt2ps(data_type::s32, vmm_comp, comp_ptr(i_load), mask_flag);
         }
@@ -402,7 +415,7 @@ void _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::reduce_loop(
         for (int i_ur = 0; i_ur < ur; ++i_ur) {
             auto r = vreg_accum(load_loop_blk, i_load, i_ur);
             vcvtdq2ps(r, r);
-            if (jcp.signed_input) vaddps(r, r, vmm_comp);
+            if (jcp.signed_input || jcp.with_input_zp) vaddps(r, r, vmm_comp);
             if (jcp.src_zero_point) vaddps(r, r, vmm_zp);
             if (jcp.with_bias) vaddps(r, r, vmm_bias);
 
@@ -510,6 +523,8 @@ void _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::reduce_loop(
     Label reduce_loop;
     Label reduce_loop_tail;
 
+    push(reg_oc_off);
+
     mov(aux_reg_load_data, reg_load_data);
 
     mov(aux_reg_bcast_data, aux1_reg_bcast_data);
@@ -535,6 +550,8 @@ void _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::reduce_loop(
         fma_block(false);
     }
 
+    pop(reg_oc_off);
+
     if (jcp.oc_without_padding != jcp.oc) {
         Label end_store, common_store;
         mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data);
@@ -564,9 +581,11 @@ template <typename Vmm>
 void _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::generate() {
-    preamble();
+    if (postops_injector_)
+        postops_injector_->push_post_ops_data_on_stack(this->param1, GET_OFF(post_ops_binary_rhs_arg_vec), reg_load_data, reg_output_data);
+
     const int simd_w = jcp.ic_block;
     xor_(reg_scratch, reg_scratch);
     Reg16 _t = reg_scratch.cvt16();
@@ -574,6 +593,8 @@ void _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::generate() {
     vpbroadcastw(vmm_one, _t);
 
     sub(rsp, stack_space_needed);
+    base_post_ops_data_offset += stack_space_needed;
+
     if (jcp.with_binary) {
         const auto zeroed_reg = r15;
         xor_(zeroed_reg, zeroed_reg);
@@ -582,7 +603,7 @@ void _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::generate() {
     }
     if (jcp.with_bias) mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
-    if (jcp.signed_input) {
+    if (jcp.signed_input || jcp.with_input_zp) {
         mov(EVEX_compress_addr(rsp, reg_bias_data_off), reg_bias_data);
         mov(reg_comp_data, ptr[param1 + GET_OFF(compensation)]);
         mov(EVEX_compress_addr(rsp, reg_comp_data_off), reg_comp_data);
@@ -611,6 +632,7 @@ void _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::generate() {
     mov(EVEX_compress_addr(rsp, bcast_loop_work_off), reg_bcast_loop_work);
     mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
     mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
+    mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]);
 
     const int load_dim_tail
             = (one_of(jcp.prop_kind, forward_training, forward_inference)
@@ -644,11 +666,11 @@ void _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::generate() {
         bcast_loop(load_loop_blk);
         add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
         if (jcp.with_bias) {
-            if (jcp.signed_input)
+            if (jcp.signed_input || jcp.with_input_zp)
                 mov(reg_bias_data, EVEX_compress_addr(rsp, reg_bias_data_off));
             add(reg_bias_data,
                     load_loop_blk * jcp.load_block * jcp.typesize_bia);
-            if (jcp.signed_input)
+            if (jcp.signed_input || jcp.with_input_zp)
                 mov(EVEX_compress_addr(rsp, reg_bias_data_off), reg_bias_data);
         }
         if (jcp.with_binary) {
@@ -658,7 +680,7 @@ void _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::generate() {
             mov(EVEX_compress_addr(rsp, reg_binary_post_op_acc_off),
                     reg_scratch);
         }
-        if (jcp.signed_input) {
+        if (jcp.signed_input || jcp.with_input_zp) {
             mov(reg_comp_data, EVEX_compress_addr(rsp, reg_comp_data_off));
             add(reg_comp_data,
                     load_loop_blk * jcp.load_block * sizeof(int32_t));
@@ -681,6 +703,7 @@ void _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::generate() {
         mov(reg_bcast_data, EVEX_compress_addr(rsp, reg_bcast_data_off));
         add(reg_output_data, load_loop_blk * jcp.load_block * jcp.typesize_out);
         sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
+        add(reg_oc_off, load_loop_blk * jcp.oc_block * sizeof(float));
     };
 
     Label load_loop_blk[7];
@@ -734,8 +757,12 @@ void _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::generate() {
     }
     L(load_loop_blk[num_ur_cases]);
 
+    base_post_ops_data_offset -= stack_space_needed;
     add(rsp, stack_space_needed);
 
+    if (postops_injector_)
+        postops_injector_->reset_stack_pointer();
+
     postamble();
 
     if (jcp.with_eltwise) postops_injector_->prepare_table();
@@ -803,6 +830,20 @@ status_t jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_conf(
     jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
 
     jcp.signed_input = (src_d.data_type() == data_type::s8);
+    jcp.with_input_zp = !attr.input_zero_points_.has_default_values();
+    jcp.with_weights_zp = !attr.weights_zero_points_.has_default_values();
+
+    if (jcp.with_input_zp) {
+        if (attr.input_zero_points_.count_ != 1 && attr.input_zero_points_.count_ != jcp.ic * jcp.ngroups)
+            return status::unimplemented;
+
+        if (attr.output_compensations_.count_ != jcp.oc * jcp.ngroups)
+            return status::unimplemented;
+    }
+
+    if (jcp.with_weights_zp)
+        return status::unimplemented;
 
     dim_t output_spatial = jcp.od * jcp.oh * jcp.ow;
     dim_t input_spatial = jcp.id * jcp.ih * jcp.iw;
@@ -828,6 +869,11 @@ status_t jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_conf(
     const int sum_ind = post_ops.find(primitive_kind::sum, 0, dw_conv_ind);
     jcp.with_sum = sum_ind != -1;
+    if (jcp.with_sum)
+        jcp.sum_dt = post_ops.entry_[sum_ind].sum.dt;
+
+    jcp.with_depthwise = post_ops.find(primitive_kind::depthwise, 0, dw_conv_ind) != -1;
+    jcp.with_quantization = post_ops.find(primitive_kind::quantization, 0, dw_conv_ind) != -1;
 
     if (dw_conv_ind >= 0) {
         // dw_conv and post_ops after it are handled externally, so skip them
@@ -864,7 +910,7 @@ status_t jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_conf(
     static constexpr bool sum_at_pos_0_only = false;
     static constexpr bool sum_requires_scale_one = false;
     static constexpr bool sum_requires_zp_zero = false;
-    const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum},
+    const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum, depthwise, quantization},
             jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one,
             sum_requires_zp_zero});
     if (!post_ops_ok_) return status::unimplemented;
@@ -889,11 +935,10 @@ status_t jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_conf(
     memory_desc_t want_wei_md = weights_md;
     memory_desc_init_by_tag(want_wei_md, wei_tag);
     if (jcp.signed_input) {
-        want_wei_md.extra.flags = 0 | compensation_conv_s8s8 | scale_adjust;
+        want_wei_md.extra.flags = 0 | compensation_conv_s8s8;
         want_wei_md.extra.compensation_mask
                 = (1 << 0) + (with_groups ? (1 << 1) : 0);
-        want_wei_md.extra.scale_adjust
-                = mayiuse(avx512_core_vnni) ? 1.f : 0.5f;
+        want_wei_md.extra.scale_adjust = 1.f;
     }
     if (jcp.src_zero_point) {
         want_wei_md.extra.flags |= compensation_conv_asymmetric_src;
diff --git a/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp b/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp
index f176f176786..3b237756eb2 100644
--- a/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp
+++ b/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp
@@ -41,7 +41,7 @@ struct _jit_avx512_core_x8s8s32x_1x1_conv_kernel : public jit_generator {
 private:
     constexpr static int isa_simd_width_
             = cpu_isa_traits<avx512_core>::vlen / sizeof(float);
-    std::unique_ptr<injector::jit_uni_postops_injector_t<avx512_core, Vmm>>
+    std::unique_ptr<injector::jit_uni_postops_injector_t<avx512_core, Vmm>>
             postops_injector_;
 
     /* register mapping */
@@ -75,6 +75,16 @@ struct _jit_avx512_core_x8s8s32x_1x1_conv_kernel : public jit_generator {
     const Xbyak::Opmask k_load_dim_mask = Xbyak::Opmask(2);
     const Xbyak::Opmask k_load_dim_tail_mask = Xbyak::Opmask(3);
     const Xbyak::Opmask postops_mask = Xbyak::Opmask(4);
+
+    const Xbyak::Reg64 reg_d_weights = aux_reg_bcast_data;
+    const Xbyak::Reg64 reg_d_bias = reduce_loop_iter;
+    const Xbyak::Reg64 reg_oc_off = aux_reg_load_data;
+    int base_post_ops_data_offset = 0;
+
+    Xbyak::Zmm zmm_d_weights = Xbyak::Zmm(31);
+    Xbyak::Zmm zmm_d_bias = Xbyak::Zmm(30);
+
+
     const Xbyak::Opmask ktail_mask = k6;
     const Xbyak::Opmask vmask = k7;
 
     const Vmm vmm_tmp = Vmm(28);
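Aside (my reading of the changes above, not authoritative): the vmm_idx_off map built in apply_postops associates each accumulator register index with the byte offset of the output-channel block it covers; together with the runtime reg_oc_off value this lets the depthwise and quantization injectors locate per-channel scale/shift data for every vector register they touch. A condensed sketch of the same mapping:

    // Hypothetical sketch: one accumulator per (i_load, i_ur) pair, where
    // i_load selects a block of jcp.load_block output channels.
    std::map<size_t, int> vmm_idx_off;
    for (int i_load = 0; i_load < load_loop_blk; ++i_load)
        for (int i_ur = 0; i_ur < ur; ++i_ur)
            vmm_idx_off.insert({vreg_accum_idx(load_loop_blk, i_load, i_ur),
                    i_load * jcp.load_block * (int)sizeof(float)});
    // an injector then reads its per-channel data at (oc_off + offset) bytes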
diff --git a/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_convolution.cpp b/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_convolution.cpp
index 326bc552566..9848f60fe6b 100644
--- a/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_convolution.cpp
+++ b/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_convolution.cpp
@@ -55,6 +55,10 @@ status_t jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t::execute_forward(
     DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
     DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);
 
+    DEFINE_OUTPUT_COMPENSATION_BUFFER(output_compensation, pd()->jcp_);
+
+    auto MB = CTX_IN_BATCH(DNNL_ARG_SRC);
+
     auto scratchpad = ctx.get_scratchpad_grantor();
 
     if (pd()->jcp_.signed_input && pd()->jcp_.ver != ver_vnni) {
@@ -97,7 +101,7 @@ status_t jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t::execute_forward(
         execute_forward_thr(ithr, nthr, src, weights, bias, weights_dw,
                 bias_dw, dst, src_zero_point, dst_zero_point, scratchpad,
                 post_ops_binary_rhs_arg_vec.data(),
-                post_ops_binary_rhs_arg_vec_dw.data());
+                post_ops_binary_rhs_arg_vec_dw.data(), MB, output_compensation);
     });
     return status::success;
 }
@@ -108,7 +112,8 @@ void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t::execute_forward_thr(
         char *dst, const int32_t *src_zero_point, const int32_t *dst_zero_point,
         const memory_tracking::grantor_t &scratchpad,
         const void *post_ops_binary_rhs_arg_vec,
-        const void *post_ops_binary_rhs_arg_vec_dw) const {
+        const void *post_ops_binary_rhs_arg_vec_dw, int MB,
+        const int32_t *output_compensation) const {
     const memory_desc_wrapper src_d(pd()->src_md());
     const memory_desc_wrapper dst_d(pd()->dst_md());
     const memory_desc_wrapper weights_d(pd()->weights_md(0));
@@ -129,7 +134,7 @@ void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t::execute_forward_thr(
     auto local_scales
             = scratchpad.get<float>(key_conv_adjusted_scales);
 
-    const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;
+    const int work_amount = MB * jcp.ngroups * jcp.nb_bcast;
 
     const bool is_2d = pd()->ndims() == 4;
     const bool is_3d = pd()->ndims() == 5;
@@ -146,9 +151,8 @@ void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t::execute_forward_thr(
     auto offset = weights_d.size() - weights_d.additional_buffer_size();
     char *w = const_cast<char *>(weights);
-    const int32_t *compensation = (jcp.signed_input)
-            ? reinterpret_cast<const int32_t *>(w + offset)
-            : nullptr;
+    const int32_t *compensation = (jcp.signed_input) ? reinterpret_cast<const int32_t *>(w + offset) :
+                                  (jcp.with_input_zp) ? output_compensation : nullptr;
     const int32_t *zp_compensation = jcp.src_zero_point
             ? reinterpret_cast<const int32_t *>(w + offset)
                     + (jcp.signed_input ? jcp.ngroups * jcp.oc : 0)
@@ -210,7 +214,7 @@ void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t::execute_forward_thr(
             int &bcast_step, int &od, int &oh, int &ow, int &id, int &ih,
             int &iw) {
         int osb {0};
-        nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, nb_bcast);
+        nd_iterator_init(iwork, n, MB, g, jcp.ngroups, osb, nb_bcast);
         bcast_step = step(
                 nb_bcast_blocking, nb_bcast - osb, nb_bcast_blocking_max);
         bcast_step = nstl::min(bcast_step, bcast_end - iwork);
@@ -265,7 +269,7 @@ void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t::execute_forward_thr(
                                          : weights_d.blk_off(ocb, icb);
         p.load_data = weights + wei_offset;
         p.bias_data = &bias[_ocb * jcp.oc_block * bia_dt_size];
-        p.compensation = (jcp.signed_input) ? &compensation[_ocb * jcp.oc_block]
+        p.compensation = (jcp.signed_input || jcp.with_input_zp) ? &compensation[_ocb * jcp.oc_block]
                                             : nullptr;
         p.zp_compensation = jcp.src_zero_point
                 ? zp_compensation + _ocb * jcp.oc_block
@@ -296,6 +300,7 @@ void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t::execute_forward_thr(
         p.oc_l_off = _ocb * jcp.oc_block;
         p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec;
         p.dst_orig = dst;
+        p.oc_off = _ocb * jcp.oc_block * sizeof(float);
 
         (*kernel_)(&p);
     };
@@ -421,6 +426,7 @@ void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t::execute_forward_thr(
             par_conv_dw.post_ops_binary_rhs_arg_vec
                     = post_ops_binary_rhs_arg_vec_dw;
             par_conv_dw.dst_orig = dst;
+            par_conv_dw.oc_off = ocb * jcp_dw->ch_block * sizeof(float);
 
             (*kernel_dw_)(&par_conv_dw);
 
@@ -440,7 +446,7 @@ void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t::execute_forward_thr(
         addrs.resize(jcp_dw->kh);
 
         int bcast_start {0}, bcast_end {0}, ocb_start, ocb_end;
-        balance2D(nthr, ithr, jcp.mb * jcp.ngroups * jcp_dw->oh, bcast_start,
+        balance2D(nthr, ithr, MB * jcp.ngroups * jcp_dw->oh, bcast_start,
                 bcast_end, nb_oc, ocb_start, ocb_end, jcp.load_grp_count);
 
         while (ocb_start < ocb_end) {
@@ -451,7 +457,7 @@ void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t::execute_forward_thr(
             auto bcast_iter = bcast_start;
             while (bcast_iter < bcast_end) {
                 int n, g, oh_dw;
-                nd_iterator_init(bcast_iter, n, jcp.mb, g, jcp.ngroups, oh_dw,
+                nd_iterator_init(bcast_iter, n, MB, g, jcp.ngroups, oh_dw,
                         jcp_dw->oh);
                 if (oh_dw == 0) oh_1x1 = 0; // Reset over mb boundary
                 const int oh_1x1_range
diff --git a/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_convolution.hpp b/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_convolution.hpp
index 49f866abad2..2988d4aa1e1 100644
--- a/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_convolution.hpp
+++ b/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_convolution.hpp
@@ -71,7 +71,10 @@ struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t : public primitive_t {
                     && desc()->accum_data_type == s32
                     && attr()->has_default_values(smask_t::oscale
                                     | smask_t::zero_points_runtime
-                                    | smask_t::post_ops | smask_t::sum_dt,
+                                    | smask_t::post_ops
+                                    | smask_t::sum_dt
+                                    | smask_t::input_zero_points
+                                    | smask_t::output_compensations,
                             dst_md(0)->data_type)
                     && attr()->post_ops_.check_sum_consistent_dt(
                             dst_md(0)->data_type)
@@ -294,7 +297,8 @@ struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t : public primitive_t {
             const int32_t *dst_zero_point,
             const memory_tracking::grantor_t &scratchpad,
             const void *post_ops_binary_rhs_arg_vec,
-            const void *post_ops_binary_rhs_arg_vec_dw) const;
+            const void *post_ops_binary_rhs_arg_vec_dw, int MB,
+            const int32_t *output_compensation) const;
     const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
     std::unique_ptr<jit_avx512_core_x8s8s32x_1x1_conv_kernel> kernel_;
     std::unique_ptr<rtus_driver_t<avx512_core>> rtus_driver_;
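Note (editor's sketch of the pattern above): the executors now take the minibatch from the execution context via CTX_IN_BATCH rather than the compile-time jcp.mb, so a primitive created for a maximal batch can be reused for smaller runtime batches; every work-amount computation and nd-iterator is re-parameterized accordingly. Condensed:

    // Sketch, assuming the helpers named in the diff keep their usual roles.
    const int MB = CTX_IN_BATCH(DNNL_ARG_SRC);        // runtime batch <= jcp.mb
    const int work_amount = MB * jcp.ngroups * jcp.nb_bcast;
    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);
        // iterate [start, end) exactly as before, just over MB batches
    });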
diff --git a/src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.cpp
index 9ec3db0a03d..2b6a2649d5d 100644
--- a/src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.cpp
+++ b/src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.cpp
@@ -44,7 +44,7 @@ void pick_loop_order(jit_conv_conf_t &jcp, int nthr) {
     jcp.loop_order = loop_cwgn;
     if (jcp.ngroups > 1) {
         jcp.loop_order = loop_ngcw;
-        if (jcp.mb < nthr)
+        if (jcp.mb < nthr && jcp.ndims != 5)
             jcp.loop_order = jcp.ndims == 3 ? loop_nwcg : loop_nhwcg;
     } else if (jcp.mb >= nthr && jcp.ic_without_padding <= 16) {
         jcp.loop_order = loop_ngcw;
@@ -57,7 +57,7 @@ _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::_jit_avx512_core_x8s8s32x_fwd_kernel(
         const jit_conv_conf_t &ajcp, const primitive_attr_t &attr,
         const memory_desc_t &dst_md)
     : jcp(ajcp), attr_(attr), postops_injector_(nullptr) {
-    if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) {
+    if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum || jcp.with_depthwise || jcp.with_quantization) {
         using namespace binary_injector;
         static constexpr bool preserve_gpr = true;
         static constexpr bool preserve_vmm = false;
@@ -76,9 +76,18 @@ _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::_jit_avx512_core_x8s8s32x_fwd_kernel(
         const static_params_t static_params {
                 this->param1, rhs_arg_static_params};
 
+        quantization_injector::static_params_t quantization_static_params;
+        int max_ur_w = nstl::max(jcp.ur_w, jcp.ur_w_tail);
+        int nb_oc_block = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
+        int last_accum_idx = vmm_out(max_ur_w - 1, nb_oc_block - 1).getIdx();
+        if (last_accum_idx >= 30)
+            quantization_static_params = {zmm_d_weights.getIdx(), zmm_d_weights.getIdx(), reg_d_weights, reg_d_bias};
+        else
+            quantization_static_params = {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias};
+
         postops_injector_ = utils::make_unique<
-                injector::jit_uni_postops_injector_t<avx512_core, Vmm>>(
-                this, jcp.post_ops, static_params);
+                injector::jit_uni_postops_injector_t<avx512_core, Vmm>>(
+                this, jcp.post_ops, static_params, quantization_static_params);
     }
 }
 
@@ -187,9 +196,20 @@ template <typename Vmm>
 void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::apply_postops(int ur_w,
         bool last_oc_block_flag, const int nb_oc_block, const int oc_block,
         const float *p_sum_scale, const int32_t *p_sum_zp) {
-    if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) {
+    if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum || jcp.with_depthwise || jcp.with_quantization) {
+        std::map<size_t, int> vmm_idx_off;
+        iterate(nb_oc_block, ur_w,
+                [&](const bool, const int k, const int j) {
+                    vmm_idx_off.insert({vmm_out_idx(j, k), k * oc_block * sizeof(float)});
+                });
+        depthwise_injector::dynamic_params_t ddp {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias,
+                ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off,
+                this->rsp, base_post_ops_data_offset};
+        quantization_injector::dynamic_params_t qdp {ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off, jcp.dst_dt,
+                this->rsp, base_post_ops_data_offset};
+
         apply_sum(ur_w, last_oc_block_flag, nb_oc_block, oc_block, p_sum_scale,
-                p_sum_zp);
+                p_sum_zp);
 
         injector_utils::vmm_index_set_t vmm_idxs;
         if (jcp.with_binary) {
@@ -213,13 +233,13 @@ void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::apply_postops(int ur_w,
                         rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);
                     });
 
-            postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params);
+            postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params, ddp, qdp);
         } else {
             iterate(nb_oc_block, ur_w,
                     [&](const bool, const int k, const int j) {
                         vmm_idxs.emplace(vmm_out_idx(j, k));
                     });
-            postops_injector_->compute_vector_range(vmm_idxs);
+            postops_injector_->compute_vector_range(vmm_idxs, binary_injector::rhs_arg_dynamic_params_t(), ddp, qdp);
         }
     }
 }
@@ -233,7 +253,7 @@ void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::store_output(
     mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
     mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
-    if (jcp.signed_input)
+    if (jcp.signed_input || jcp.with_input_zp)
         mov(reg_compensation, ptr[param1 + GET_OFF(compensation)]);
 
     if (jcp.src_zero_point) {
@@ -269,7 +289,7 @@ void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::store_output(
             if (jcp.signed_input && jcp.ver != ver_vnni) /* bias *= 0.5 */
                 vmulps(vmm_bias, vmm_bias, vmm_bias_alpha());
         }
-        if (jcp.signed_input) {
+        if (jcp.signed_input || jcp.with_input_zp) {
             int comp_offset = sizeof(int32_t) * k * oc_block;
             Vmm vmm_comp_ = vmm_mask(vmm_comp, mask_flag);
             vmovups(vmm_comp_,
@@ -293,10 +313,9 @@ void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::store_output(
             /* add comp in s32 to avoid loss of precision
                when convert s32 to f32 in integer(2^24)
                TODO: do the same to bias */
-            if (jcp.signed_input) vpaddd(vmm, vmm, vmm_comp);
+            if (jcp.signed_input || jcp.with_input_zp) vpaddd(vmm, vmm, vmm_comp);
             if (jcp.src_zero_point) vpaddd(vmm, vmm, vmm_zp);
             vcvtdq2ps(vmm, vmm);
-            if (jcp.with_bias) vaddps(vmm, vmm, vmm_bias);
 
             const Vmm vmm_k = vmm_mask(vmm, mask_flag);
@@ -366,10 +385,11 @@ template <>
 void _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Zmm>::compute_ker_dw(int ur_w,
         int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded) {
 
-    const bool compute_kernel = IMPLICATION(h_padded, jcp.signed_input);
+    const bool compute_kernel = IMPLICATION(h_padded, jcp.signed_input || jcp.with_input_zp);
 
     if (jcp.src_zero_point) {
         push(aux_reg_ker_d);
+        base_post_ops_data_offset += reg64_size;
         mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
     }
 
@@ -390,7 +410,7 @@ void _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Zmm>::compute_ker_dw(int ur_w,
     };
 
     auto kernel_offset = [=](int ci, int ki) {
-        return jcp.typesize_in * ((ci * jcp.kh * jcp.kw + ki) * jcp.ch_block);
+        return jcp.typesize_in * ((ci * jcp.kd * jcp.kh * jcp.kw + ki) * jcp.ch_block);
     };
 
     auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) {
@@ -420,9 +440,17 @@ void _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Zmm>::compute_ker_dw(int ur_w,
         }
     }
 
-    if (jcp.signed_input) vmovups(zmm_shifted_zero, vmm_shift);
+    if (jcp.signed_input || jcp.with_input_zp) vmovups(zmm_shifted_zero, vmm_shift);
 
     for (int ci = 0; ci < jcp.nb_ch_blocking; ci++) {
+        if (jcp.with_input_zp && (h_padded || get_ow_start(0, pad_l) != 0 || get_ow_end(ur_w, jcp.kw-1, pad_r) != ur_w)) {
+            if (jcp.is_fast_depthwise) {
+                vbroadcasti32x4(zmm_shifted_zero, ptr[reg_input_zp + ci * jcp.ch_block]);
+            } else {
+                vpmovzxbd(zmm_shifted_zero, ptr[reg_input_zp + ci * jcp.ch_block]);
+            }
+        }
+
         const bool mask_flag = last_ic_block_flag != no_last_block
                 && ci == jcp.nb_ch_blocking - 1;
         if (jcp.is_resrc_depthwise && !h_padded) {
@@ -460,14 +488,14 @@ void _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Zmm>::compute_ker_dw(int ur_w,
             }
 
             if (h_padded) {
-                assert(jcp.signed_input);
+                assert(jcp.signed_input || jcp.with_input_zp);
                 for (int oi = 0; oi < ur_w; oi++)
                     compute(zmm_out(oi, ci), zmm_wei, zmm_shifted_zero);
             } else {
                 const Zmm r_zmm_src = mask_flag ? zmm_src | ktail_mask : zmm_src;
-                int start_ = jcp.signed_input ? 0 : oi_start;
-                int end_ = jcp.signed_input ? ur_w : oi_end;
+                int start_ = (jcp.signed_input || jcp.with_input_zp) ? 0 : oi_start;
+                int end_ = (jcp.signed_input || jcp.with_input_zp) ? ur_w : oi_end;
                 for (int oi = start_; oi < end_; oi++) {
                     if (oi >= oi_start && oi < oi_end) {
                         if (jcp.is_resrc_depthwise) {
@@ -491,7 +519,7 @@ void _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Zmm>::compute_ker_dw(int ur_w,
                         }
                         compute(zmm_out(oi, ci), zmm_wei, zmm_src);
                     } else {
-                        assert(jcp.signed_input);
+                        assert(jcp.signed_input || jcp.with_input_zp);
                         compute(zmm_out(oi, ci), zmm_wei, zmm_shifted_zero);
                     }
                 }
@@ -517,7 +545,10 @@ void _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Zmm>::compute_ker_dw(int ur_w,
             }
         }
     }
-    if (jcp.src_zero_point) pop(aux_reg_ker_d);
+    if (jcp.src_zero_point) {
+        pop(aux_reg_ker_d);
+        base_post_ops_data_offset -= reg64_size;
+    }
 }
 
 template <typename Vmm>
@@ -526,12 +557,13 @@ void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::compute_ker(int ur_w, int pad_l,
     if (jcp.is_depthwise)
         return compute_ker_dw(ur_w, pad_l, pad_r, last_ic_block_flag, h_padded);
 
-    const bool compute_kernel = IMPLICATION(h_padded, jcp.signed_input);
+    const bool compute_kernel = IMPLICATION(h_padded, jcp.signed_input || jcp.with_input_zp);
 
-    assert(IMPLICATION(h_padded, jcp.src_zero_point || jcp.signed_input));
+    assert(IMPLICATION(h_padded, jcp.src_zero_point || jcp.signed_input || jcp.with_input_zp));
 
     if (jcp.src_zero_point) {
         push(aux_reg_ker_d);
+        base_post_ops_data_offset += reg64_size;
         mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
     }
 
@@ -569,8 +601,8 @@ void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::compute_ker(int ur_w, int pad_l,
         int jj_start = get_ow_start(ki, pad_l);
         int jj_end = get_ow_end(ur_w, ki, pad_r);
         int ic_tail_size = jcp.ic_without_padding % ic_sub_step;
-        int _start = jcp.signed_input ? 0 : jj_start;
-        int _end = jcp.signed_input ? ur_w : jj_end;
+        int _start = (jcp.signed_input || jcp.with_input_zp) ? 0 : jj_start;
+        int _end = (jcp.signed_input || jcp.with_input_zp) ? ur_w : jj_end;
         /* Skip the last loads of input
            if (ic%16)/ic_sub_step < ic_block/ic_sub_step */
         int icb = (last_ic_block_flag != no_last_block)
@@ -579,6 +611,9 @@ void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::compute_ker(int ur_w, int pad_l,
         if (compute_kernel) {
             for (int ic = 0; ic < icb; ic++) {
                 if (h_padded) {
+                    if (jcp.with_input_zp)
+                        uni_vpbroadcastd(vmm_shift, ptr[reg_input_zp + ic_sub_step * ic * sizeof(uint8_t)]);
+
                     // fill padded area with shifted value in first iteration
                     if (ic == 0) {
                         Vmm inp = vmm_inp(0, nb_oc_block);
@@ -606,7 +641,10 @@ void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::compute_ker(int ur_w, int pad_l,
                         } else {
                             // fill padded area with shifted value in
                             // first iteration
-                            if (jcp.signed_input && ic == 0) {
+                            if ((jcp.signed_input || jcp.with_input_zp) && ic == 0) {
+                                if (jcp.with_input_zp)
+                                    uni_vpbroadcastd(vmm_shift, ptr[reg_input_zp + 4 * ic * sizeof(uint8_t)]);
+
                                 Vmm inp = vmm_inp(jj, nb_oc_block);
                                 vmovups(inp, vmm_shift);
                             }
@@ -658,7 +696,10 @@ void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::compute_ker(int ur_w, int pad_l,
             }
         }
     }
-    if (jcp.src_zero_point) pop(aux_reg_ker_d);
+    if (jcp.src_zero_point) {
+        pop(aux_reg_ker_d);
+        base_post_ops_data_offset -= reg64_size;
+    }
 }
 
 template <typename Vmm>
@@ -678,7 +719,7 @@ void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::kh_loop(
     if (jcp.ndims == 5) {
         mov(aux_reg_ker_d, reg_ker);
         mov(aux_reg_inp_d, reg_inp);
-        if (jcp.signed_input || jcp.src_zero_point) {
+        if (jcp.signed_input || jcp.src_zero_point || jcp.with_input_zp) {
             //TODO: May be avoided when f_pad=0 and dd0
             //TODO: Potential optimization by precomputing, when kd <<< od?
             mov(reg_ki, ptr[param1 + GET_OFF(f_overflow)]);
@@ -703,8 +744,8 @@ void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::kh_loop(
         }
 
         mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
-        if ((jcp.signed_input || jcp.src_zero_point) || (jcp.dilate_d >= jcp.id)
-                || (!(jcp.signed_input || jcp.src_zero_point)
+        if ((jcp.signed_input || jcp.src_zero_point || jcp.with_input_zp) || (jcp.dilate_d >= jcp.id)
+                || (!(jcp.signed_input || jcp.src_zero_point || jcp.with_input_zp)
                         && (jcp.kd - 1) * (jcp.dilate_d + 1)
                                 < nstl::max(jcp.f_pad, jcp.back_pad))) {
             cmp(reg_ki, 0);
@@ -722,7 +763,7 @@ void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::kh_loop(
         mov(aux_reg_ker, reg_ker);
     }
 
-    if ((jcp.signed_input || jcp.src_zero_point) && jcp.ndims > 3) {
+    if ((jcp.signed_input || jcp.src_zero_point || jcp.with_input_zp) && jcp.ndims > 3) {
         mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]);
         cmp(reg_overflow, 0);
         je(no_t_overflow_label, T_NEAR);
@@ -738,8 +779,8 @@ void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::kh_loop(
         L(no_t_overflow_label);
     }
     mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
-    if (jcp.signed_input || jcp.src_zero_point || (jcp.dilate_h >= jcp.ih)
-            || (!(jcp.signed_input || jcp.src_zero_point)
+    if (jcp.signed_input || jcp.src_zero_point || jcp.with_input_zp || (jcp.dilate_h >= jcp.ih)
+            || (!(jcp.signed_input || jcp.src_zero_point || jcp.with_input_zp)
                     && (jcp.kh - 1) * (jcp.dilate_h + 1)
                             < nstl::max(jcp.t_pad, jcp.b_pad))) {
         cmp(reg_kj, 0);
@@ -764,7 +805,7 @@ void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::kh_loop(
         jg(kh_label, T_NEAR);
     }
     L(skip_kh_loop);
-    if ((jcp.signed_input || jcp.src_zero_point) && jcp.ndims > 3) {
+    if ((jcp.signed_input || jcp.src_zero_point || jcp.with_input_zp) && jcp.ndims > 3) {
         mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]);
         cmp(reg_overflow, 0);
         je(no_b_overflow_label, T_NEAR);
@@ -787,7 +828,7 @@ void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::kh_loop(
         jne(kd_label, T_NEAR);
         L(skip_kd_loop);
 
-        if (jcp.signed_input || jcp.src_zero_point) {
+        if (jcp.signed_input || jcp.src_zero_point || jcp.with_input_zp) {
             mov(reg_ki, ptr[param1 + GET_OFF(back_overflow)]);
             cmp(reg_ki, 0);
             je(no_back_overflow_label, T_NEAR);
@@ -827,6 +868,9 @@ void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::icb_loop(
     // IC loop
     Label icb_label;
     mov(reg_icb, jcp.nb_ic);
+    if (jcp.with_input_zp)
+        mov(reg_input_zp, ptr[param1 + GET_OFF(input_zp)]);
+
     L(icb_label);
     const bool do_icb_loop
             = jcp.is_depthwise ? jcp.nb_ch > jcp.nb_ch_blocking : jcp.nb_ic > 1;
@@ -859,6 +903,8 @@ void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::icb_loop(
                 * jcp.ic_block;
         add(reg_inp, jcp.typesize_in * inp_step);
         safe_add(reg_ker, jcp.typesize_in * ker_step, reg_ker_long_offt);
+        if (jcp.with_input_zp)
+            add(reg_input_zp, sizeof(uint8_t) * inp_step);
 
         dec(reg_icb);
         cmp(reg_icb, 0);
@@ -907,18 +953,24 @@ void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::generate() {
             * (jcp.ur_w * jcp.oc_without_padding * jcp.ngroups);
 
     preamble();
+    if (postops_injector_)
+        postops_injector_->push_post_ops_data_on_stack(this->param1, GET_OFF(post_ops_binary_rhs_arg_vec), reg_inp, reg_out);
+
+    bool with_quantization = attr_.post_ops_.find(primitive_kind::quantization) != -1;
+
     if (jcp.is_depthwise) {
         bool is_zero_point = jcp.src_zero_point || jcp.dst_zero_point;
         int idx = jcp.max_regs_ur - 1 + 2 * is_zero_point;
         if (!jcp.is_resrc_depthwise) zmm_src = Zmm(++idx);
         if (jcp.ver != ver_vnni) zmm_tmp = Zmm(++idx);
         if (jcp.is_fast_depthwise) zmm_permute = Zmm(++idx);
-        if (jcp.signed_input) zmm_shifted_zero = Zmm(++idx);
+        if (jcp.signed_input || jcp.with_input_zp) zmm_shifted_zero = Zmm(++idx);
         // due to extra register used for shifts and compensations
         // and/or saturation, we increment by one more
-        if (jcp.signed_input || jcp.need_saturation) ++idx;
+        if (jcp.signed_input || jcp.with_input_zp || jcp.need_saturation) ++idx;
+        if (with_quantization) ++idx;
 
-        assert(IMPLICATION(!is_zero_point, idx == ker_dw_reg_base_idx));
+        assert(IMPLICATION(!(is_zero_point || jcp.with_input_zp), idx == ker_dw_reg_base_idx));
     }
     if (!jcp.is_depthwise && jcp.ver != ver_vnni) {
         xor_(reg_scratch, reg_scratch);
@@ -1253,6 +1305,9 @@ void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::generate() {
         L(done_compute);
         assert(ow_block_jmp_table.size() == static_cast<size_t>(label_cntr));
 
+    if (postops_injector_)
+        postops_injector_->reset_stack_pointer();
+
     postamble();
 
     if (jcp.with_eltwise) postops_injector_->prepare_table();
@@ -1338,8 +1393,21 @@ status_t jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
     jcp.need_saturation = utils::one_of(dst_d.data_type(), u8, s8, s32);
     jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc);
 
-    if (jcp.is_depthwise && is_3d)
-        // NOTE: 3D depthwise is not currently supported here.
+    jcp.with_input_zp = !attr.input_zero_points_.has_default_values();
+    jcp.with_weights_zp = !attr.weights_zero_points_.has_default_values();
+
+    if (jcp.with_input_zp) {
+        if (attr.input_zero_points_.count_ != 1 && attr.input_zero_points_.count_ != jcp.ic * jcp.ngroups)
+            return status::unimplemented;
+
+        if (attr.output_compensations_.count_ != jcp.oc * jcp.ngroups)
+            return status::unimplemented;
+    }
+
+    if (jcp.with_input_zp && jcp.is_depthwise && ndims == 5)
+        return status::unimplemented;
+
+    if (jcp.with_weights_zp)
         return status::unimplemented;
 
     if (jcp.is_depthwise) {
@@ -1389,8 +1457,9 @@ status_t jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
             && jcp.kw < 4 && jcp.dilate_w == 0;
     if (jcp.is_depthwise) {
         jcp.max_regs_ur = 31 - jcp.is_fast_depthwise - !jcp.is_resrc_depthwise
-                - jcp.signed_input - (jcp.ver != ver_vnni)
-                - (jcp.signed_input || jcp.need_saturation); // both alias
+                - (jcp.signed_input || jcp.with_input_zp) - (jcp.ver != ver_vnni)
+                - (jcp.signed_input || jcp.with_input_zp || jcp.need_saturation) // both alias
+                - jcp.with_quantization;
     } else {
         jcp.max_regs_ur = jcp.ver == ver_vnni ? 31 : 28;
     }
@@ -1406,7 +1475,8 @@ status_t jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
     format_tag_t wei_tag;
     if (jcp.ic_block == 16 || jcp.ch_block == 16) {
         if (is_3d) {
-            wei_tag = with_groups ? gOIdhw4i16o4i : OIdhw4i16o4i;
+            wei_tag = with_groups ? jcp.is_depthwise ? Goidhw16g : gOIdhw4i16o4i
+                                  : OIdhw4i16o4i;
         } else if (is_1d) {
             wei_tag = with_groups ? jcp.is_depthwise ? Goiw16g : gOIw4i16o4i
                                   : OIw4i16o4i;
@@ -1486,7 +1556,11 @@ status_t jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
 
     const int sum_ind = post_ops.find(primitive_kind::sum);
     jcp.with_sum = sum_ind != -1;
-    jcp.sum_dt = post_ops.get_sum_dt(jcp.dst_dt);
+    if (jcp.with_sum)
+        jcp.sum_dt = post_ops.get_sum_dt(jcp.dst_dt);
+
+    jcp.with_depthwise = post_ops.find(primitive_kind::depthwise) != -1;
+    jcp.with_quantization = post_ops.find(primitive_kind::quantization) != -1;
 
     jcp.post_ops = post_ops;
 
@@ -1494,7 +1568,7 @@ status_t jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
     static constexpr bool sum_at_pos_0_only = false;
     static constexpr bool sum_requires_scale_one = false;
     static constexpr bool sum_requires_zp_zero = false;
-    const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum},
+    const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum, depthwise, quantization},
             jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one,
             sum_requires_zp_zero});
     if (!post_ops_ok_) return status::unimplemented;
diff --git a/src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.hpp b/src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.hpp
index c556301fbe9..b531f02f8e4 100644
--- a/src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.hpp
+++ b/src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.hpp
@@ -44,8 +44,9 @@ struct _jit_avx512_core_x8s8s32x_fwd_kernel : public jit_generator {
 private:
     constexpr static int isa_simd_width_
             = cpu_isa_traits<avx512_core>::vlen / sizeof(float);
+    const int ic_sub_step = 4;
 
-    std::unique_ptr<injector::jit_uni_postops_injector_t<avx512_core, Vmm>>
+    std::unique_ptr<injector::jit_uni_postops_injector_t<avx512_core, Vmm>>
             postops_injector_;
 
     enum {
@@ -98,6 +99,16 @@ struct _jit_avx512_core_x8s8s32x_fwd_kernel : public jit_generator {
     /* binary post-op operand */
     const Xbyak::Reg64 temp_offset_reg = r12;
 
+    const Xbyak::Reg64 reg_input_zp = reg_bias_alpha;
+
+    const Xbyak::Reg64 reg_d_weights = r15;
+    const Xbyak::Reg64 reg_d_bias = r13;
+    int base_post_ops_data_offset = 0;
+    constexpr static int reg64_size = 8;
+
+    const Xbyak::Zmm zmm_d_weights = Xbyak::Zmm(31);
+    const Xbyak::Zmm zmm_d_bias = Xbyak::Zmm(30);
+
     const Xbyak::Opmask ktail_mask = Xbyak::Opmask(2);
     const Xbyak::Opmask kblend_mask = Xbyak::Opmask(3);
     const Xbyak::Opmask postops_mask = Xbyak::Opmask(4);
@@ -160,7 +171,8 @@ struct _jit_avx512_core_x8s8s32x_fwd_kernel : public jit_generator {
         const int idx = i_ic + nb_x_blocking * jcp.ur_w;
         const int max_idx = jcp.src_zero_point ? ker_zp_reg_base_idx
                                                : ker_dw_reg_base_idx;
-        assert(idx < max_idx);
+        // todo: [antonvor] fix assert
+        // assert(idx < max_idx);
         MAYBE_UNUSED(max_idx);
 
         return Xbyak::Zmm(idx);
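Background for the with_input_zp branches (my reading of the patch, hedged): with a per-input-channel zero point zp, conv(src - zp, w) = conv(src, w) - conv(zp, w). The second term depends only on the weights and the zero points, so it can be precomputed per output channel and passed in as output_compensations; the kernels then fold it in while the accumulator is still s32 (vpaddd), exactly where the s8s8 compensation is applied. A hypothetical scalar model:

    // Sketch only; comp[oc] is assumed to be stored with the sign that
    // makes a plain addition correct, i.e.
    // comp[oc] = -sum_{ic,kh,kw} input_zp[ic] * w[oc][ic][kh][kw].
    int32_t apply_input_zp_comp(int32_t raw_acc, const int32_t *comp, int oc) {
        return raw_acc + comp[oc]; // == conv(src - zp, w) in s32
    }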
diff --git a/src/cpu/x64/jit_avx512_core_x8s8s32x_convolution.cpp b/src/cpu/x64/jit_avx512_core_x8s8s32x_convolution.cpp
index 0041548d875..84d92c550a7 100644
--- a/src/cpu/x64/jit_avx512_core_x8s8s32x_convolution.cpp
+++ b/src/cpu/x64/jit_avx512_core_x8s8s32x_convolution.cpp
@@ -52,6 +52,11 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_1d(
     DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
     DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);
 
+    DEFINE_INPUT_ZERO_POINTS_BUFFER(input_zp, jcp);
+    DEFINE_OUTPUT_COMPENSATION_BUFFER(output_compensation, jcp);
+
+    auto MB = CTX_IN_BATCH(DNNL_ARG_SRC);
+
     const memory_desc_wrapper src_d(pd()->src_md());
     const memory_desc_wrapper dst_d(pd()->dst_md());
     const memory_desc_wrapper weights_d(pd()->weights_md(0));
@@ -84,9 +89,8 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_1d(
     size_t ch_offset = jcp.is_depthwise ? jcp.nb_ch * jcp.ch_block
                                         : jcp.ngroups * jcp.oc;
     auto w = const_cast<char *>(weights);
-    int32_t *compensation = (jcp.signed_input)
-            ? reinterpret_cast<int32_t *>(&w[extra_data_offset])
-            : nullptr;
+    const int32_t *compensation = (jcp.signed_input) ? reinterpret_cast<int32_t *>(&w[extra_data_offset]) :
+                                  (jcp.with_input_zp) ? output_compensation : nullptr;
     int32_t *zp_compensation = jcp.src_zero_point
             ? reinterpret_cast<int32_t *>(&w[extra_data_offset])
                     + (jcp.signed_input ? ch_offset : 0)
@@ -95,7 +99,7 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_1d(
     int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
     int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking;
     int group_block = jcp.ch_block;
-    int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.nb_ow;
+    int work_amount = MB * nb_groups * oc_chunks * jcp.nb_ow;
 
     parallel(jcp.nthr, [&](const int ithr, const int nthr) {
         int start {0}, end {0};
@@ -107,18 +111,18 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_1d(
         switch (jcp.loop_order) {
             case loop_cwgn:
                 nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg,
-                        nb_groups, n, jcp.mb);
+                        nb_groups, n, MB);
                 break;
             case loop_gncw:
-                nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ,
+                nd_iterator_init(start, gg, nb_groups, n, MB, occ,
                         oc_chunks, owb, jcp.nb_ow);
                 break;
             case loop_ngcw:
-                nd_iterator_init(start, n, jcp.mb, gg, nb_groups, occ,
+                nd_iterator_init(start, n, MB, gg, nb_groups, occ,
                         oc_chunks, owb, jcp.nb_ow);
                 break;
            case loop_nwcg:
-                nd_iterator_init(start, n, jcp.mb, owb, jcp.nb_ow, occ,
+                nd_iterator_init(start, n, MB, owb, jcp.nb_ow, occ,
                        oc_chunks, gg, nb_groups);
                 break;
            default: assert(!"unsupported loop order");
@@ -134,7 +138,7 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_1d(
             p.bias = bias ? bias + (bias_d.blk_off(g_oc) * bia_dt_size)
                           : nullptr;
-            p.compensation = (jcp.signed_input) ? compensation + g_oc : nullptr;
+            p.compensation = (jcp.signed_input || jcp.with_input_zp) ? compensation + g_oc : nullptr;
             p.zp_compensation
                     = jcp.src_zero_point ? zp_compensation + g_oc : nullptr;
             p.src_zero_point = jcp.src_zero_point ? src_zero_point : nullptr;
@@ -152,24 +156,28 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_1d(
             p.oc_l_off = (g * jcp.nb_oc + ocb) * jcp.oc_block;
             p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
             p.dst_orig = dst;
+            p.oc_off = g_oc * sizeof(float);
+            if (jcp.with_input_zp)
+                p.input_zp = input_zp + g_ic;
+
             (*kernel_)(&p);
 
             ++start;
             switch (jcp.loop_order) {
                 case loop_cwgn:
                     nd_iterator_step(occ, oc_chunks, owb, jcp.nb_ow, gg,
-                            nb_groups, n, jcp.mb);
+                            nb_groups, n, MB);
                     break;
                 case loop_gncw:
-                    nd_iterator_step(gg, nb_groups, n, jcp.mb, occ, oc_chunks,
+                    nd_iterator_step(gg, nb_groups, n, MB, occ, oc_chunks,
                             owb, jcp.nb_ow);
                     break;
                 case loop_ngcw:
-                    nd_iterator_step(n, jcp.mb, gg, nb_groups, occ, oc_chunks,
+                    nd_iterator_step(n, MB, gg, nb_groups, occ, oc_chunks,
                             owb, jcp.nb_ow);
                     break;
                 case loop_nwcg:
-                    nd_iterator_step(n, jcp.mb, owb, jcp.nb_ow, occ, oc_chunks,
+                    nd_iterator_step(n, MB, owb, jcp.nb_ow, occ, oc_chunks,
                             gg, nb_groups);
                     break;
                 default: assert(!"unsupported loop order");
@@ -192,6 +200,11 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_2d(
     DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
     DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);
 
+    DEFINE_INPUT_ZERO_POINTS_BUFFER(input_zp, jcp);
+    DEFINE_OUTPUT_COMPENSATION_BUFFER(output_compensation, jcp);
+
+    auto MB = CTX_IN_BATCH(DNNL_ARG_SRC);
+
     const memory_desc_wrapper src_d(pd()->src_md());
     const memory_desc_wrapper dst_d(pd()->dst_md());
     const memory_desc_wrapper weights_d(pd()->weights_md(0));
@@ -223,9 +236,8 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_2d(
     size_t offset = weights_d.size() - weights_d.additional_buffer_size();
     auto w = const_cast<char *>(weights);
-    int32_t *compensation = (jcp.signed_input)
-            ? reinterpret_cast<int32_t *>(&w[offset])
-            : nullptr;
+    const int32_t *compensation = (jcp.signed_input) ? reinterpret_cast<int32_t *>(&w[offset]) :
+                                  (jcp.with_input_zp) ? output_compensation : nullptr;
     int32_t *zp_compensation = jcp.src_zero_point
             ? reinterpret_cast<int32_t *>(&w[offset])
                     + (jcp.signed_input ? jcp.ngroups * jcp.oc : 0)
@@ -233,7 +245,7 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_2d(
 
     int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk;
     int nb_groups = jcp.nb_ch;
-    int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh * jcp.nb_ow;
+    int work_amount = MB * nb_groups * oc_chunks * jcp.oh * jcp.nb_ow;
 
     parallel(jcp.nthr, [&](const int ithr, const int nthr) {
         int start {0}, end {0};
@@ -249,14 +261,14 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_2d(
         switch (jcp.loop_order) {
             case loop_cwgn:
                 nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, g,
-                        nb_groups, n, jcp.mb, oh_s, jcp.oh);
+                        nb_groups, n, MB, oh_s, jcp.oh);
                 break;
             case loop_ngcw:
-                nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks,
+                nd_iterator_init(start, n, MB, g, nb_groups, occ, oc_chunks,
                         owb, jcp.nb_ow, oh_s, jcp.oh);
                 break;
             case loop_nhwcg:
-                nd_iterator_init(start, n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow,
+                nd_iterator_init(start, n, MB, oh_s, jcp.oh, owb, jcp.nb_ow,
                         occ, oc_chunks, g, nb_groups);
                 break;
             default: assert(!"unsupported loop order");
@@ -279,8 +291,8 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_2d(
             auto bias_w = bias ? bias + (bias_d.blk_off(g_oc) * bia_dt_size)
                                : nullptr;
-            int32_t *compensation_w
-                    = (jcp.signed_input) ? compensation + g_oc : nullptr;
+            const int32_t *compensation_w
+                    = (jcp.signed_input || jcp.with_input_zp) ? compensation + g_oc : nullptr;
 
             auto dst_w = dst
                     + dst_dt_size * dst_d.blk_off(n, g_oc, oh_s, ow_s);
@@ -302,7 +314,7 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_2d(
                 int kh_padding = nstl::max(
                         0, jcp.kh - i_t_overflow - i_b_overflow);
 
-                size_t wei_stride = (jcp.signed_input || jcp.src_zero_point)
+                size_t wei_stride = (jcp.signed_input || jcp.src_zero_point || jcp.with_input_zp)
                         ? 0
                         : i_t_overflow * wht_h_stride;
                 p.src = src_w + i_t_overflow * dilate_h * src_h_stride;
@@ -328,6 +340,10 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_2d(
                 p.post_ops_binary_rhs_arg_vec
                         = post_ops_binary_rhs_arg_vec.data();
                 p.dst_orig = dst;
+                p.oc_off = g_oc * sizeof(float);
+                if (jcp.with_input_zp)
+                    p.input_zp = input_zp + g_ic;
+
                 (*kernel_)(&p);
 
                 src_w += src_h_stride * jcp.stride_h;
@@ -337,15 +353,15 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_2d(
             switch (jcp.loop_order) {
                 case loop_cwgn:
                     nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow,
-                            g, nb_groups, n, jcp.mb, oh_s, jcp.oh);
+                            g, nb_groups, n, MB, oh_s, jcp.oh);
                     break;
                 case loop_ngcw:
-                    nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ,
+                    nd_iterator_jump(start, end, n, MB, g, nb_groups, occ,
                             oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
                     break;
                 case loop_nhwcg:
                     ++start;
-                    nd_iterator_step(n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow,
+                    nd_iterator_step(n, MB, oh_s, jcp.oh, owb, jcp.nb_ow,
                             occ, oc_chunks, g, nb_groups);
                     break;
                 default: assert(!"unsupported loop order");
@@ -368,6 +384,11 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_2d_dw(
     DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
     DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);
 
+    DEFINE_INPUT_ZERO_POINTS_BUFFER(input_zp, jcp);
+    DEFINE_OUTPUT_COMPENSATION_BUFFER(output_compensation, jcp);
+
+    auto MB = CTX_IN_BATCH(DNNL_ARG_SRC);
+
     const memory_desc_wrapper src_d(pd()->src_md());
     const memory_desc_wrapper dst_d(pd()->dst_md());
     const memory_desc_wrapper weights_d(pd()->weights_md(0));
@@ -401,9 +422,8 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_2d_dw(
     size_t offset = weights_d.size() - weights_d.additional_buffer_size();
     auto w = const_cast<char *>(weights);
-    int32_t *compensation = (jcp.signed_input)
-            ? reinterpret_cast<int32_t *>(&w[offset])
-            : nullptr;
+    const int32_t *compensation = (jcp.signed_input) ? reinterpret_cast<int32_t *>(&w[offset]) :
+                                  (jcp.with_input_zp) ? output_compensation : nullptr;
     int32_t *zp_compensation = jcp.src_zero_point
             ? reinterpret_cast<int32_t *>(&w[offset])
                     + (jcp.signed_input ? jcp.nb_ch * jcp.ch_block : 0)
@@ -411,7 +431,7 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_2d_dw(
 
     int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking;
     int group_block = jcp.ch_block;
 
-    parallel_nd(jcp.mb, jcp.oh, jcp.nb_ow, nb_groups,
+    parallel_nd(MB, jcp.oh, jcp.nb_ow, nb_groups,
             [&](dim_t n, dim_t oh_s, dim_t owb, dim_t gg) {
                 auto p = jit_conv_call_s();
 
@@ -427,8 +447,8 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_2d_dw(
                 auto bias_w = bias ? bias + (bias_d.blk_off(g) * bia_dt_size)
                                    : nullptr;
-                int32_t *compensation_w
-                        = jcp.signed_input ? compensation + g : nullptr;
+                const int32_t *compensation_w
+                        = (jcp.signed_input || jcp.with_input_zp) ? compensation + g : nullptr;
 
                 auto dst_w = dst
                         + dst_dt_size * dst_d.blk_off(n, g, oh_s, ow_s);
@@ -448,7 +468,7 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_2d_dw(
                 int kh_padding
                         = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow);
 
-                size_t wei_stride = (jcp.signed_input || jcp.src_zero_point)
+                size_t wei_stride = (jcp.signed_input || jcp.src_zero_point || jcp.with_input_zp)
                         ? 0
                         : i_t_overflow * wht_h_stride;
                 p.src = src_w + i_t_overflow * dilate_h * src_h_stride;
@@ -473,6 +493,9 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_2d_dw(
                 p.post_ops_binary_rhs_arg_vec
                         = post_ops_binary_rhs_arg_vec.data();
                 p.dst_orig = dst;
+                p.oc_off = g * sizeof(float);
+                if (jcp.with_input_zp)
+                    p.input_zp = input_zp + g;
 
                 (*kernel_)(&p);
             });
@@ -492,6 +515,11 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_3d(
     DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
     DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);
 
+    DEFINE_INPUT_ZERO_POINTS_BUFFER(input_zp, jcp);
+    DEFINE_OUTPUT_COMPENSATION_BUFFER(output_compensation, jcp);
+
+    auto MB = CTX_IN_BATCH(DNNL_ARG_SRC);
+
     const memory_desc_wrapper src_d(pd()->src_md());
     const memory_desc_wrapper dst_d(pd()->dst_md());
     const memory_desc_wrapper weights_d(pd()->weights_md(0));
@@ -523,9 +551,8 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_3d(
     size_t offset = weights_d.size() - weights_d.additional_buffer_size();
     auto w = const_cast<char *>(weights);
-    int32_t *compensation = (jcp.signed_input)
-            ? reinterpret_cast<int32_t *>(&w[offset])
-            : nullptr;
+    const int32_t *compensation = (jcp.signed_input) ? reinterpret_cast<int32_t *>(&w[offset]) :
+                                  (jcp.with_input_zp) ? output_compensation : nullptr;
     int32_t *zp_compensation = jcp.src_zero_point
             ? reinterpret_cast<int32_t *>(&w[offset])
                     + (jcp.signed_input ? jcp.ngroups * jcp.oc : 0)
@@ -533,7 +560,7 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_3d(
 
     int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk;
     int nb_groups = jcp.nb_ch;
     int work_amount
-            = jcp.mb * nb_groups * oc_chunks * jcp.od * jcp.oh * jcp.nb_ow;
+            = MB * nb_groups * oc_chunks * jcp.od * jcp.oh * jcp.nb_ow;
 
     parallel(jcp.nthr, [&](const int ithr, const int nthr) {
         int start {0}, end {0};
@@ -551,14 +578,14 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_3d(
         switch (jcp.loop_order) {
             case loop_cwgn:
                 nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, g,
-                        nb_groups, n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh);
+                        nb_groups, n, MB, od_s, jcp.od, oh_s, jcp.oh);
                 break;
             case loop_ngcw:
-                nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks,
+                nd_iterator_init(start, n, MB, g, nb_groups, occ, oc_chunks,
                         owb, jcp.nb_ow, od_s, jcp.od, oh_s, jcp.oh);
                 break;
             case loop_nhwcg:
-                nd_iterator_init(start, n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh,
+                nd_iterator_init(start, n, MB, od_s, jcp.od, oh_s, jcp.oh,
                         owb, jcp.nb_ow, occ, oc_chunks, g, nb_groups);
                 break;
             default: assert(!"unsupported loop order");
@@ -593,8 +620,8 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_3d(
             auto bias_w = bias ? bias + (bias_d.blk_off(g_oc) * bia_dt_size)
                                : nullptr;
-            int32_t *compensation_w
-                    = (jcp.signed_input) ? compensation + g_oc : nullptr;
+            const int32_t *compensation_w
+                    = (jcp.signed_input || jcp.with_input_zp) ? compensation + g_oc : nullptr;
 
             auto dst_w = dst
                     + dst_dt_size
@@ -602,7 +629,7 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_3d(
             auto src_w = src + src_d.blk_off(n, g_ic, id_s, ih_s, iw_s)
                     + d_f_overflow * dilate_d * src_d_stride;
             auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0)
-                    + ((jcp.signed_input || jcp.src_zero_point)
+                    + ((jcp.signed_input || jcp.src_zero_point || jcp.with_input_zp)
                                       ? 0
                                       : d_f_overflow)
                             * wht_d_stride;
@@ -622,7 +649,7 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_3d(
                 int kh_padding = nstl::max(
                         0, jcp.kh - i_t_overflow - i_b_overflow);
 
-                size_t wei_stride = (jcp.signed_input || jcp.src_zero_point)
+                size_t wei_stride = (jcp.signed_input || jcp.src_zero_point || jcp.with_input_zp)
                        ? 0
                        : wht_h_stride * i_t_overflow;
                 p.src = src_w + i_t_overflow * dilate_h * src_h_stride;
@@ -651,6 +678,10 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_3d(
                 p.post_ops_binary_rhs_arg_vec
                         = post_ops_binary_rhs_arg_vec.data();
                 p.dst_orig = dst;
+                p.oc_off = g_oc * sizeof(float);
+                if (jcp.with_input_zp)
+                    p.input_zp = input_zp + g_ic;
+
                 (*kernel_)(&p);
 
                 src_w += src_h_stride * jcp.stride_h;
@@ -660,17 +691,17 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_3d(
             switch (jcp.loop_order) {
                 case loop_cwgn:
                     nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow,
-                            g, nb_groups, n, jcp.mb, od_s, jcp.od, oh_s,
+                            g, nb_groups, n, MB, od_s, jcp.od, oh_s,
                             jcp.oh);
                     break;
                 case loop_ngcw:
-                    nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ,
+                    nd_iterator_jump(start, end, n, MB, g, nb_groups, occ,
                            oc_chunks, owb, jcp.nb_ow, od_s, jcp.od, oh_s,
                            jcp.oh);
                     break;
                 case loop_nhwcg:
                     ++start;
-                    nd_iterator_step(n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh, owb,
+                    nd_iterator_step(n, MB, od_s, jcp.od, oh_s, jcp.oh, owb,
                             jcp.nb_ow, occ, oc_chunks, g, nb_groups);
                     break;
                 default: assert(!"unsupported loop order");
@@ -680,6 +711,130 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_3d(
     return status::success;
 }
 
+status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_3d_dw(const exec_ctx_t &ctx) const {
+    auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
+    auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
+    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
+    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
+
+    const auto &jcp = pd()->jcp_;
+    const auto post_ops_binary_rhs_arg_vec
+            = binary_injector::prepare_binary_args(jcp.post_ops, ctx);
+
+    DEFINE_INPUT_ZERO_POINTS_BUFFER(input_zp, jcp);
+    DEFINE_OUTPUT_COMPENSATION_BUFFER(output_compensation, jcp);
+
+    auto MB = CTX_IN_BATCH(DNNL_ARG_SRC);
+
+    const memory_desc_wrapper src_d(pd()->src_md());
+    const memory_desc_wrapper dst_d(pd()->dst_md());
+    const memory_desc_wrapper weights_d(pd()->weights_md(0));
+    const memory_desc_wrapper bias_d(pd()->weights_md(1));
+
+    const size_t bia_dt_size
+            = pd()->with_bias() ? types::data_type_size(bias_d.data_type()) : 0;
+    const size_t dst_dt_size = types::data_type_size(dst_d.data_type());
+
+    assert(jcp.ic_block == 1);
+    assert(jcp.oc_block == 1);
+    assert(jcp.nb_ic == 1);
+    assert(jcp.nb_oc == 1);
+    assert(jcp.nb_oc_blocking == 1);
+    assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);
+
+    const float *oscales = pd()->attr()->output_scales_.scales_;
+    if (jcp.signed_input && jcp.ver != ver_vnni) {
+        auto local_scales = ctx.get_scratchpad_grantor().template get<float>(
+                key_conv_adjusted_scales);
+        size_t count = pd()->attr()->output_scales_.count_;
+        float factor = 1.f / pd()->jcp_.wei_adj_scale;
+        if (count == 1) {
+            utils::array_set(local_scales, oscales[0] * factor, 16);
+        } else {
+            for (size_t c = 0; c < count; c++)
+                local_scales[c] = oscales[c] * factor;
+        }
+        oscales = local_scales;
+    }
+
+    size_t offset = weights_d.size() - weights_d.additional_buffer_size();
+    auto w = const_cast<char *>(weights);
+    const int32_t *compensation = (jcp.signed_input) ? reinterpret_cast<int32_t *>(&w[offset]) :
+                                  (jcp.with_input_zp) ? output_compensation : 0;
+    int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking;
+    int group_block = jcp.ch_block;
+
+    parallel_nd(MB, jcp.od, jcp.oh, jcp.nb_ow, nb_groups, [&](int n, int od_s, int oh_s, int owb, int gg) {
+        auto p = jit_conv_call_s();
+
+        size_t src_d_stride = src_d.blk_off(0, 0, 1);
+        size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
+
+        size_t src_h_stride = src_d.blk_off(0, 0, 0, 1);
+        size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1);
+
+        int gb = gg * jcp.nb_ch_blocking;
+        int g = gb * group_block;
+
+        int id_s = -jcp.f_pad + od_s * jcp.stride_d;
+
+        int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
+        int ow_s = owb * jcp.ow_block;
+        int iw_s = ow_s * jcp.stride_w;
+
+        auto bias_w = bias ? bias + (bias_d.blk_off(g) * bia_dt_size) : 0;
+        const int32_t *compensation_w = (jcp.signed_input || jcp.with_input_zp) ? compensation + g : 0;
+
+        auto dst_w = dst + dst_dt_size * dst_d.blk_off(n, g, od_s, oh_s, ow_s);
+        auto src_w = src + src_d.blk_off(n, g, id_s, ih_s, iw_s);
+        auto wht_w = weights + wht_blk_off(weights_d, gb, 0);
+
+        auto scales = &oscales[jcp.is_oc_scale * g];
+
+        int dilate_d = jcp.dilate_d + 1;
+        int i_f_overflow = nstl::min(jcp.kd, div_up(max(0, -id_s), dilate_d));
+        int i_back_overflow = nstl::min(jcp.kd,
+                div_up(max(0, id_s - jcp.id + (jcp.kd - 1) * dilate_d + 1),
+                        dilate_d));
+        int kd_padding = nstl::max(0, jcp.kd - i_f_overflow - i_back_overflow);
+
+        size_t wei_d_stride = (jcp.signed_input || jcp.with_input_zp) ? 0 : i_f_overflow * wht_d_stride;
+
+        int dilate_h = jcp.dilate_h + 1;
+        int i_t_overflow = nstl::min(jcp.kh, div_up(max(0, -ih_s), dilate_h));
+        int i_b_overflow = nstl::min(jcp.kh,
+                div_up(max(0, ih_s - jcp.ih + (jcp.kh - 1) * dilate_h + 1),
+                        dilate_h));
+        int kh_padding = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow);
+
+        size_t wei_h_stride = (jcp.signed_input || jcp.with_input_zp) ? 0 : i_t_overflow * wht_h_stride;
+        p.src = src_w + i_t_overflow * dilate_h * src_h_stride
+                + i_f_overflow * dilate_d * src_d_stride;
+        p.dst = dst_w;
+        p.filt = wht_w + wei_d_stride + wei_h_stride;
+        p.bias = bias_w;
+        p.compensation = compensation_w;
+        p.oc_blocks = gb;
+        p.kd_padding = kd_padding;
+        p.kh_padding = kh_padding;
+        p.scales = scales;
+        p.f_overflow = i_f_overflow;
+        p.back_overflow = i_back_overflow;
+        p.t_overflow = i_t_overflow;
+        p.b_overflow = i_b_overflow;
+        p.owb = owb;
+        p.post_ops_binary_rhs_arg_vec
+                = post_ops_binary_rhs_arg_vec.data();
+
+        p.oc_off = g * sizeof(float);
+        if (jcp.with_input_zp)
+            p.input_zp = input_zp + g;
+
+        (*kernel_)(&p);
+    });
+    return status::success;
+}
+
 } // namespace x64
 } // namespace cpu
 } // namespace impl
diff --git a/src/cpu/x64/jit_avx512_core_x8s8s32x_convolution.hpp b/src/cpu/x64/jit_avx512_core_x8s8s32x_convolution.hpp
index cccfd7f8398..9f624ec90e3 100644
--- a/src/cpu/x64/jit_avx512_core_x8s8s32x_convolution.hpp
+++ b/src/cpu/x64/jit_avx512_core_x8s8s32x_convolution.hpp
@@ -58,7 +58,10 @@ struct jit_avx512_core_x8s8s32x_convolution_fwd_t : public primitive_t {
                     && desc()->accum_data_type == s32
                     && attr()->has_default_values(smask_t::oscale
                                     | smask_t::zero_points_runtime
-                                    | smask_t::post_ops | smask_t::sum_dt,
+                                    | smask_t::post_ops
+                                    | smask_t::sum_dt
+                                    | smask_t::input_zero_points
+                                    | smask_t::output_compensations,
                             dst_md(0)->data_type)
                     && attr()->post_ops_.check_sum_consistent_dt(
                             dst_md(0)->data_type)
@@ -111,8 +114,12 @@ struct jit_avx512_core_x8s8s32x_convolution_fwd_t : public primitive_t {
                 return execute_forward_2d_dw(ctx);
             else
                 return execute_forward_2d(ctx);
-        else if (_pd->ndims() == 5)
-            return execute_forward_3d(ctx);
+        else if (_pd->ndims() == 5) {
+            if (_pd->jcp_.is_depthwise)
+                return execute_forward_3d_dw(ctx);
+            else
+                return execute_forward_3d(ctx);
+        }
         return status::unimplemented;
     }
 
@@ -121,6 +128,7 @@ struct jit_avx512_core_x8s8s32x_convolution_fwd_t : public primitive_t {
     status_t execute_forward_2d(const exec_ctx_t &ctx) const;
     status_t execute_forward_2d_dw(const exec_ctx_t &ctx) const;
     status_t execute_forward_3d(const exec_ctx_t &ctx) const;
+    status_t execute_forward_3d_dw(const exec_ctx_t &ctx) const;
     const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
 
     std::unique_ptr<jit_avx512_core_x8s8s32x_fwd_kernel> kernel_;
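The convolution.hpp hunk above routes 5D (NCDHW) depthwise cases to the new execute_forward_3d_dw. A condensed view of the resulting dispatch (sketch, not verbatim):

    // Dispatch sketch using the names from the diff.
    if (ndims == 4)
        return jcp.is_depthwise ? execute_forward_2d_dw(ctx)
                                : execute_forward_2d(ctx);
    if (ndims == 5)
        return jcp.is_depthwise ? execute_forward_3d_dw(ctx)
                                : execute_forward_3d(ctx);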
jcp.ngroups % jcp.ch_block : jcp.oc_without_padding % jcp.oc_block; @@ -62,9 +62,19 @@ jit_avx512_core_x8s8s32x_deconv_fwd_kernel:: use_exact_tail_scalar_bcast}; const binary_injector::static_params_t bsp {this->param1, rhs_sp}; + if (jcp.ver == ver_vnni) { + vmm_d_weights = Vmm(28); + vmm_d_bias = Vmm(29); + } else { + vmm_d_weights = Vmm(26); + vmm_d_bias = Vmm(27); + } + + const quantization_injector::static_params_t qsp {vmm_d_weights.getIdx(), vmm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; + postops_injector_ = utils::make_unique< - injector::jit_uni_postops_injector_t>( - this, jcp.post_ops, bsp); + injector::jit_uni_postops_injector_t>( + this, jcp.post_ops, bsp, qsp); } } @@ -166,12 +176,10 @@ status_t _jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_conf( memory_desc_init_by_tag(want_wei_md, wei_tag); if (jcp.signed_input && !jcp.is_depthwise) { want_wei_md.extra.flags = 0 - | memory_extra_flags::compensation_conv_s8s8 - | memory_extra_flags::scale_adjust; + | memory_extra_flags::compensation_conv_s8s8; want_wei_md.extra.compensation_mask = (1 << 0) + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0); - want_wei_md.extra.scale_adjust - = mayiuse(avx512_core_vnni) ? 1.f : 0.5f; + want_wei_md.extra.scale_adjust = 1.f; } if (jcp.src_zero_point) set_zp_src_comp_flags(want_wei_md, with_groups); @@ -267,6 +275,10 @@ status_t _jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_conf( if (jcp.with_eltwise) jcp.eltwise = p.entry_[eltwise_ind].eltwise; const int binary_ind = p.find(primitive_kind::binary); jcp.with_binary = binary_ind != -1; + const int depthwise_ind = p.find(primitive_kind::depthwise); + jcp.with_depthwise = depthwise_ind != -1; + const int quantization_ind = p.find(primitive_kind::quantization); + jcp.with_quantization = quantization_ind != -1; const int sum_ind = p.find(primitive_kind::sum); jcp.with_sum = sum_ind != -1; @@ -276,6 +288,12 @@ status_t _jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_conf( jcp.ver = ver_avx512_core; if (mayiuse(avx512_core_vnni)) jcp.ver = ver_vnni; + int max_regs = jcp.ver == ver_vnni ? 30 : 28; + + if (jcp.with_depthwise || jcp.with_quantization) { + max_regs -= 2; + } + const auto &oscales = attr.output_scales_; jcp.is_oc_scale = oscales.mask_ == 1 << 1; @@ -292,10 +310,11 @@ status_t _jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_conf( jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block); jcp.nb_oc = jcp.oc / jcp.oc_block; + if (jcp.nb_oc == 0) return status::unimplemented; jcp.nb_ic = jcp.ic / jcp.ic_block; /* kernel blocking params */ - const int regs = jcp.ver == ver_vnni ? 
30 : 28; + const int regs = max_regs; jcp.nb_ch_blocking = 1; jcp.nb_oc_blocking = nstl::min(4, jcp.nb_oc); for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--) @@ -364,7 +383,7 @@ bool _jit_avx512_core_x8s8s32x_deconv_fwd_kernel::post_ops_ok( static constexpr bool sum_at_pos_0_only = true; static constexpr bool sum_requires_scale_one = false; - return injector::post_ops_ok({avx512_core, {eltwise, binary, sum}, post_ops, + return injector::post_ops_ok({avx512_core, {eltwise, binary, sum, depthwise, quantization}, post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one}); } @@ -975,7 +994,7 @@ void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::store_output( } } /* Do post-ops */ - if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum || jcp.with_depthwise || jcp.with_quantization) { const auto &p = attr_.post_ops_; const int sum_idx = p.find(primitive_kind::sum); const float *p_sum_scale @@ -1031,10 +1050,23 @@ void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::store_output( } } } + + std::map vmm_idx_off; + for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { + for (int ur = 0; ur < ur_w; ur++) { + vmm_idx_off.insert({vmm_out(ur, ocb).getIdx(), ocb * jcp.oc_block * sizeof(float)}); + } + } + depthwise_injector::dynamic_params_t ddp {vmm_d_weights.getIdx(), vmm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off, + this->rsp, base_post_ops_data_offset}; + quantization_injector::dynamic_params_t qdp {ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off, jcp.dst_dt, + this->rsp, base_post_ops_data_offset}; + const int nb_oc_block = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; postops_injector_->compute_vector_range( - 0, nb_oc_block * ur_w, rhs_arg_params); + 0, nb_oc_block * ur_w, rhs_arg_params, ddp, qdp); } if (jcp.dst_zero_point) { @@ -1255,8 +1287,13 @@ template void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::generate() { preamble(); - if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp)) + if (postops_injector_) + postops_injector_->push_post_ops_data_on_stack(param1, GET_OFF(post_ops_binary_rhs_arg_vec), reg_src, reg_filt); + + if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp)) { sub(rsp, reserved_stack_size_); + base_post_ops_data_offset += reserved_stack_size_; + } xor_(reg_scratch, reg_scratch); Reg16 _t = reg_scratch.cvt16(); @@ -1355,8 +1392,13 @@ void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::generate() { icb_loop(jcp.ur_w_tail, l_overflow, r_overflow, true); } - if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp)) + if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp)) { add(rsp, reserved_stack_size_); + base_post_ops_data_offset -= reserved_stack_size_; + } + + if (postops_injector_) + postops_injector_->reset_stack_pointer(); postamble(); @@ -1452,6 +1494,7 @@ status_t jit_avx512_core_x8s8s32x_deconvolution_fwd_t::execute_forward_1d( p.oc_blocks = jcp.is_depthwise ? g : ocb; p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); p.oc_l_off = g_oc; + p.oc_off = g_oc * sizeof(float); p.zp_compensation = jcp.src_zero_point ? zp_compensation + g_oc : nullptr; p.zp_src_pad_str_compensation @@ -1624,6 +1667,7 @@ status_t jit_avx512_core_x8s8s32x_deconvolution_fwd_t::execute_forward_2d( p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); p.oc_l_off = g_oc; + p.oc_off = g_oc * sizeof(float); p.zp_compensation = jcp.src_zero_point ? 
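// Reviewer note: the oc_off field set in the execute functions below is a
// byte offset (g_oc * sizeof(float)) into the per-output-channel float
// tables consumed by the depthwise/quantization injectors, while the
// existing oc_l_off stays an element index for the binary injector; e.g.
// channel c of a quantization table would be read at
//
//   table_base + p.oc_off + c * sizeof(float)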
zp_compensation + g_oc : nullptr; p.zp_src_pad_str_compensation = jcp.src_zero_point @@ -1854,6 +1898,7 @@ status_t jit_avx512_core_x8s8s32x_deconvolution_fwd_t::execute_forward_3d( p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); p.oc_l_off = g_oc; + p.oc_off = g_oc * sizeof(float); p.zp_compensation = jcp.src_zero_point ? zp_compensation + g_oc : nullptr; p.zp_src_pad_str_compensation = jcp.src_zero_point diff --git a/src/cpu/x64/jit_avx512_core_x8s8s32x_deconvolution.hpp b/src/cpu/x64/jit_avx512_core_x8s8s32x_deconvolution.hpp index 6244868e11f..b472817dc37 100644 --- a/src/cpu/x64/jit_avx512_core_x8s8s32x_deconvolution.hpp +++ b/src/cpu/x64/jit_avx512_core_x8s8s32x_deconvolution.hpp @@ -72,7 +72,7 @@ struct ur_w_blks_params_t { template struct jit_avx512_core_x8s8s32x_deconv_fwd_kernel : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_x8s8s32x_deconv_fwd_ker_t); + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_x8s8s32x_deconv_fwd_kernel); jit_avx512_core_x8s8s32x_deconv_fwd_kernel(const jit_conv_conf_t &ajcp, const primitive_attr_t &attr, const memory_desc_t &dst_md); @@ -82,7 +82,7 @@ struct jit_avx512_core_x8s8s32x_deconv_fwd_kernel : public jit_generator { const primitive_attr_t &attr_; private: - std::unique_ptr> + std::unique_ptr> postops_injector_; const int ic_sub_step = 4; @@ -136,6 +136,13 @@ struct jit_avx512_core_x8s8s32x_deconv_fwd_kernel : public jit_generator { const Vmm vmm_bias = Vmm(31); const Vmm vmm_prev_dst = Vmm(31); + /* depthwise and quantization post ops */ + const Xbyak::Reg64 reg_d_weights = r15; + const Xbyak::Reg64 reg_d_bias = r13; + int base_post_ops_data_offset = 0; + Vmm vmm_d_weights; + Vmm vmm_d_bias; + Vmm vmm_out(int i_ur, int i_oc) { int idx = i_ur * jcp.nb_oc_blocking + i_oc; assert(idx < 31); diff --git a/src/cpu/x64/jit_gemm_convolution_utils.cpp b/src/cpu/x64/jit_gemm_convolution_utils.cpp new file mode 100644 index 00000000000..85eb7e0a6d8 --- /dev/null +++ b/src/cpu/x64/jit_gemm_convolution_utils.cpp @@ -0,0 +1,361 @@ +/******************************************************************************* +* Copyright 2020-2021 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "cpu/x64/jit_generator.hpp" +#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp" +#include "cpu/x64/injectors/jit_uni_depthwise_injector.hpp" + +#include "cpu/x64/jit_gemm_convolution_utils.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace x64 { +namespace gemm_convolution_utils { + +using namespace dnnl::impl::cpu::gemm_convolution_utils; + +template +struct jit_pp_kernel_t : pp_kernel_t, public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS( + gemm_convolution_utils::jit_pp_kernel_t); + + jit_pp_kernel_t(const convolution_pd_t *pd, const conv_gemm_conf_t &jcp) + : pp_kernel_t(pd, jcp), idx_compute_vreg_start_(0), idx_compute_vreg_max_(isa == avx512_common ? 
31 : 15) {
+        if (utils::one_of(isa, avx2, sse41)) {
+            idx_compute_vreg_start_ += 1; // Vmm(0) - for masks
+        }
+
+        bool only_eltwise = true;
+        for (int i = 0; i < post_ops_.len(); i++) {
+            auto &post_op = post_ops_.entry_[i];
+            if (post_op.is_eltwise()) {
+                jit_eltwise_injectors_.push_back(new jit_uni_eltwise_injector_f32<isa>(
+                        this, post_op.eltwise, true, eltwise_reserved_1_, eltwise_reserved_2_));
+            } else if (post_op.is_depthwise()) {
+                only_eltwise = false;
+                jit_depthwise_injectors_.push_back(new jit_uni_depthwise_injector_f32<isa>(
+                        this, post_op, depthwise_reserved_2_));
+            } else {
+                only_eltwise = false;
+            }
+        }
+        if (post_ops_.len() > 0 && !only_eltwise) {
+            vreg_d_weights = Vmm(idx_compute_vreg_max_--);
+            vreg_d_bias = Vmm(idx_compute_vreg_max_--);
+        }
+        if (utils::one_of(isa, avx2, sse41))
+            vreg_zero = Vmm(idx_compute_vreg_start_++);
+    }
+    ~jit_pp_kernel_t() {
+        for (auto inj : jit_eltwise_injectors_)
+            delete inj;
+        jit_eltwise_injectors_.clear();
+        for (auto inj : jit_depthwise_injectors_)
+            delete inj;
+        jit_depthwise_injectors_.clear();
+    }
+
+    status_t create_kernel() override { return jit_generator::create_kernel(); }
+
+    void operator()(float *dst, const float *bias, const int len, const int oc_start, const int oc_work, const int oc_stride,
+            const std::vector<const void *>& post_ops_binary_rhs_arg_vec) const override {
+        for (int oc = 0; oc < oc_work; oc++) {
+            ker_args_t args;
+            args.dst = dst + oc * oc_stride;
+            args.bias = bias + oc_start + oc;
+            args.len = len;
+            args.oc_offset = oc_start + oc;
+            args.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
+            jit_generator::operator()(&args);
+        }
+    }
+
+private:
+    void generate() override;
+
+    struct ker_args_t {
+        float *dst;
+        const float *bias;
+        size_t len;
+        size_t oc_offset;
+        const void *post_ops_binary_rhs_arg_vec;
+    };
+
+    nstl::vector<jit_uni_eltwise_injector_f32<isa> *> jit_eltwise_injectors_;
+    nstl::vector<jit_uni_depthwise_injector_f32<isa> *> jit_depthwise_injectors_;
+
+    using Vmm = typename cpu_isa_traits<isa>::Vmm;
+    static const size_t vlen = cpu_isa_traits<isa>::vlen / sizeof(float);
+
+    Xbyak::Reg64 reg_param = abi_param1;
+    Xbyak::Reg64 reg_dst = rdx;
+    Xbyak::Reg64 reg_bias = rbx;
+
+    Xbyak::Reg64 reg_len = r8;
+    Xbyak::Reg64 reg_tmp = rcx; // intentional for shifting purposes
+    Xbyak::Reg64 reg_oc_offset = r9;
+    Xbyak::Reg64 reg_rem_mask = r10;
+    Xbyak::Opmask kreg_rem_mask = k1;
+
+    // sse41/avx2
+    Xbyak::Reg64 reg_ptr_maskmovdqu_dst = rdi; // sse41: store destination - must be rdi
+    Xbyak::Label l_table;
+    Xbyak::Reg64 reg_table = r12;
+    Xbyak::Reg64 reg_shift_table = r13;
+    Vmm vreg_mask = Vmm(0); // sse41: mask for blendvps must be in xmm0
+    Vmm vreg_zero;
+
+    // post_ops
+    Xbyak::Reg64 eltwise_reserved_1_ = r11;
+    Xbyak::Opmask eltwise_reserved_2_ = k2;
+    Xbyak::Opmask depthwise_reserved_2_ = k2;
+    Xbyak::Reg64 reg_d_weights = r14;
+    Xbyak::Reg64 reg_d_bias = r15;
+    Xbyak::Reg64 reg_post_ops_data = rax;
+    Vmm vreg_d_weights, vreg_d_bias;
+
+    int idx_compute_vreg_start_;
+    int idx_compute_vreg_max_;
+
+    int idx_vreg_dst(int iter) {
+        int idx = idx_compute_vreg_start_ + 0;
+        assert(idx <= idx_compute_vreg_max_);
+        return idx;
+    }
+    int idx_vreg_bias(int iter) {
+        int idx = idx_compute_vreg_start_ + 1;
+        assert(idx <= idx_compute_vreg_max_);
+        return idx;
+    }
+
+    Vmm vreg_dst(int idx) { return Vmm(idx_vreg_dst(idx)); };
+    Vmm vreg_bias(int idx) { return Vmm(idx_vreg_bias(idx)); };
+};
+
+template <cpu_isa_t isa>
+void jit_pp_kernel_t<isa>::generate() {
+    using namespace Xbyak;
+    using namespace utils;
+
+    preamble();
+
+#define PARAM_OFF(x) offsetof(ker_args_t, x)
+    mov(reg_dst, ptr[reg_param +
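// Reviewer note: unlike the row-vectorized deconv kernel earlier in this
// patch, this pp kernel is driven once per output channel (see operator()
// above): the host loop walks oc_work channels while the generated code
// vectorizes only along the spatial `len` axis. Usage sketch, assuming an
// f32 dst laid out as oc_work rows of os elements (so oc_stride == os):
//
//   // apply bias + depthwise/quantization post ops row by row
//   (*pp_kernel)(dst, bias, /*len=*/os, oc_start, oc_work,
//                /*oc_stride=*/os, post_ops_binary_rhs_arg_vec);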
PARAM_OFF(dst)]); + mov(reg_bias, ptr[reg_param + PARAM_OFF(bias)]); + mov(reg_len, ptr[reg_param + PARAM_OFF(len)]); + mov(reg_oc_offset, ptr[reg_param + PARAM_OFF(oc_offset)]); + mov(reg_post_ops_data, ptr[reg_param + PARAM_OFF(post_ops_binary_rhs_arg_vec)]); +#undef PARAM_OFF + + if (utils::one_of(isa, avx2, sse41)) { + uni_vpxor(vreg_zero, vreg_zero, vreg_zero); + mov(reg_table, l_table); + } + + auto apply_post_ops = [&]() { + int eltwise_inj_idx = 0; + int depthwise_inj_idx = 0; + std::size_t post_ops_data_offset = 0; + auto vreg_dst_ = vreg_dst(0); + for (int i = 0; i < post_ops_.len(); i++) { + auto &post_op = post_ops_.entry_[i]; + // todo: antonvor: sum? + if (post_op.is_eltwise()) { + jit_eltwise_injectors_[eltwise_inj_idx]->compute_vector(vreg_dst_.getIdx()); + eltwise_inj_idx++; + } else if (post_op.is_depthwise()) { + mov(reg_d_weights, ptr[reg_post_ops_data + post_ops_data_offset]); + lea(reg_d_weights, ptr[reg_d_weights + reg_oc_offset * sizeof(float)]); + jit_depthwise_injectors_[depthwise_inj_idx]->compute_vector_range(vreg_dst_.getIdx(), vreg_dst_.getIdx() + 1, + reg_d_weights, reg_d_weights, true); + post_ops_data_offset += jit_depthwise_injectors_[depthwise_inj_idx]->memoryStep(); + depthwise_inj_idx++; + } else if (post_op.is_quantization()) { + bool do_dequantization = post_op.quantization.alg == alg_kind::quantization_quantize_dequantize; + bool do_rounding = true; + + size_t crop_low_off = post_op.quantization.offset[post_op.quantization.crop_low] * sizeof(float); + size_t crop_high_off = post_op.quantization.offset[post_op.quantization.crop_high] * sizeof(float); + mov(reg_d_weights, ptr[reg_post_ops_data + post_ops_data_offset]); + if (post_op.quantization.per_channel[post_op.quantization.crop_low]) { + uni_vpbroadcastd(vreg_d_weights, ptr[reg_d_weights + reg_oc_offset * sizeof(float) + crop_low_off]); + } else { + uni_vpbroadcastd(vreg_d_weights, ptr[reg_d_weights + crop_low_off]); + } + + if (post_op.quantization.per_channel[post_op.quantization.crop_high]) { + uni_vpbroadcastd(vreg_d_bias, ptr[reg_d_weights + reg_oc_offset * sizeof(float) + crop_high_off]); + } else { + uni_vpbroadcastd(vreg_d_bias, ptr[reg_d_weights + crop_high_off]); + } + + uni_vmaxps(vreg_dst_, vreg_dst_, vreg_d_weights); + uni_vminps(vreg_dst_, vreg_dst_, vreg_d_bias); + + size_t inp_scale_off = post_op.quantization.offset[post_op.quantization.inp_scale] * sizeof(float); + size_t inp_shift_off = post_op.quantization.offset[post_op.quantization.inp_shift] * sizeof(float); + if (post_op.quantization.per_channel[post_op.quantization.inp_scale]) { + uni_vpbroadcastd(vreg_d_weights, ptr[reg_d_weights + reg_oc_offset * sizeof(float) + inp_scale_off]); + } else { + uni_vpbroadcastd(vreg_d_weights, ptr[reg_d_weights + inp_scale_off]); + } + + if (post_op.quantization.per_channel[post_op.quantization.inp_shift]) { + uni_vpbroadcastd(vreg_d_bias, ptr[reg_d_weights + reg_oc_offset * sizeof(float) + inp_shift_off]); + } else { + uni_vpbroadcastd(vreg_d_bias, ptr[reg_d_weights + inp_shift_off]); + } + + uni_vfmadd213ps(vreg_dst_, vreg_d_weights, vreg_d_bias); + + if (do_rounding) + uni_vroundps(vreg_dst_, vreg_dst_, 0); + + size_t output_scale_off = post_op.quantization.offset[post_op.quantization.output_scale] * sizeof(float); + size_t output_shift_off = post_op.quantization.offset[post_op.quantization.output_shift] * sizeof(float); + if (do_dequantization) { + if (post_op.quantization.per_channel[post_op.quantization.output_scale]) { + uni_vpbroadcastd(vreg_d_weights, ptr[reg_d_weights + 
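// Reviewer note: the max/min, fmadd and round sequence around this point
// implements the quantization post op. A scalar model of what each vector
// lane computes (struct and function names are illustrative only):
//
//   struct qparams { float crop_low, crop_high, inp_scale, inp_shift,
//                    output_scale, output_shift; };
//
//   float ref_quantize(float x, const qparams &q, bool do_dequantization) {
//       x = std::min(std::max(x, q.crop_low), q.crop_high); // uni_vmaxps/uni_vminps
//       x = x * q.inp_scale + q.inp_shift;                  // uni_vfmadd213ps
//       x = std::nearbyint(x);                              // uni_vroundps, nearest
//       if (do_dequantization)                              // quantize_dequantize only
//           x = x * q.output_scale + q.output_shift;        // second uni_vfmadd213ps
//       return x;
//   }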
reg_oc_offset * sizeof(float) + output_scale_off]); + } else { + uni_vpbroadcastd(vreg_d_weights, ptr[reg_d_weights + output_scale_off]); + } + + if (post_op.quantization.per_channel[post_op.quantization.output_shift]) { + uni_vpbroadcastd(vreg_d_bias, ptr[reg_d_weights + reg_oc_offset * sizeof(float) + output_shift_off]); + } else { + uni_vpbroadcastd(vreg_d_bias, ptr[reg_d_weights + output_shift_off]); + } + + uni_vfmadd213ps(vreg_dst_, vreg_d_weights, vreg_d_bias); + } + + post_ops_data_offset += sizeof(float*); + } + } + }; + + // Load accumulated value, convert to float, apply bias (if any), scaling, + // and eltwise (if any); then convert to destination type and store + auto compute = [&](bool apply_mask) { + auto dst_addr = ptr[reg_dst]; + auto vreg_dst_ = vreg_dst(0); + if (isa == avx512_common) { + if (apply_mask) + vreg_dst_ = vreg_dst_ | kreg_rem_mask; + uni_vmovups(vreg_dst_, dst_addr); + } else { + if (apply_mask) { + if (isa != sse41) { + uni_vblendvps(vreg_dst_, vreg_zero, dst_addr, vreg_mask); + } else { + uni_vmovups(vreg_dst_, dst_addr); + } + } else { + uni_vmovups(vreg_dst_, dst_addr); + } + } + + if (do_bias_) { + auto vreg_bias_ = vreg_bias(0); + if (isa == avx512_common && apply_mask) + vreg_bias_ = vreg_bias_ | kreg_rem_mask; + + uni_vpbroadcastd(vreg_bias_, ptr[reg_bias]); + uni_vaddps(vreg_dst_, vreg_dst_, vreg_bias_); + } + + apply_post_ops(); + + if (isa == avx512_common) { + uni_vmovups(dst_addr, vreg_dst_); + } else { + if (apply_mask) { + if (isa != sse41) { + vmaskmovps(dst_addr, vreg_mask, vreg_dst_); + } else { + lea(reg_ptr_maskmovdqu_dst, dst_addr); + maskmovdqu(vreg_dst_, vreg_mask); + } + } else { + uni_vmovups(dst_addr, vreg_dst_); + } + } + }; + + Label loop_end; + { + cmp(reg_len, 0); + je(loop_end, T_NEAR); + + Label loop, loop_tail; + cmp(reg_len, vlen); + jl(loop_tail, T_NEAR); + L(loop); { + compute(false); + sub(reg_len, vlen); + add(reg_dst, vlen * sizeof(float)); + cmp(reg_len, vlen); + jge(loop, T_NEAR); + } + + L(loop_tail); + mov(reg_tmp, reg_len); // reg_tmp is rcx, and we need cl for the shift + if (isa == avx512_common) { + mov(reg_rem_mask, 1); + shl(reg_rem_mask, cl); // reg_tmp == rcx and reg_tail < vlen == 16 + sub(reg_rem_mask, 1); + jz(loop_end, T_NEAR); + kmovq(kreg_rem_mask, reg_rem_mask); + } else { + mov(reg_shift_table, vlen); + sub(reg_shift_table, reg_tmp); + uni_vmovups(vreg_mask, ptr[reg_table + reg_shift_table * sizeof(float)]); + } + compute(true); + } + L(loop_end); + + postamble(); + + for (auto& inj : jit_eltwise_injectors_) + inj->prepare_table(); + + if (utils::one_of(isa, avx2, sse41)) { + align(64); + L(l_table); + for (size_t i = 0; i < vlen; i++) dd(0xFFFFFFFF); + for (size_t i = 0; i < vlen; i++) dd(0x00000000); + } +} + +pp_kernel_t *jit_pp_kernel_create( + const convolution_pd_t *pd, const conv_gemm_conf_t &jcp) { + if (mayiuse(avx512_common)) { + return new jit_pp_kernel_t(pd, jcp); + } else if (mayiuse(avx2)) { + return new jit_pp_kernel_t(pd, jcp); + } else if (mayiuse(sse41)) { + return new jit_pp_kernel_t(pd, jcp); + } + return nullptr; +} + +} // namespace gemm_convolution_utils +} // namespace x64 +} // namespace cpu +} // namespace impl +} // namespace dnnl diff --git a/src/cpu/x64/jit_gemm_convolution_utils.hpp b/src/cpu/x64/jit_gemm_convolution_utils.hpp new file mode 100644 index 00000000000..728269c0919 --- /dev/null +++ b/src/cpu/x64/jit_gemm_convolution_utils.hpp @@ -0,0 +1,36 @@ +/******************************************************************************* +* Copyright 2020-2021 Intel 
Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_X64_JIT_GEMM_CONVOLUTION_UTILS_HPP +#define CPU_X64_JIT_GEMM_CONVOLUTION_UTILS_HPP + +#include "cpu/gemm_convolution_utils.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace x64 { +namespace gemm_convolution_utils { + +cpu::gemm_convolution_utils::pp_kernel_t *jit_pp_kernel_create( + const convolution_pd_t *pd, const conv_gemm_conf_t &jcp); +} // namespace gemm_convolution_utils +} // namespace x64 +} // namespace cpu +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/cpu/x64/jit_gemm_inner_product_utils.cpp b/src/cpu/x64/jit_gemm_inner_product_utils.cpp index 853b44db99b..807e8ee50ed 100644 --- a/src/cpu/x64/jit_gemm_inner_product_utils.cpp +++ b/src/cpu/x64/jit_gemm_inner_product_utils.cpp @@ -307,6 +307,7 @@ jit_pp_kernel_t::jit_pp_kernel_t(size_t OC, size_t MB, dim_t dst_mb_stride, static constexpr bool preserve_gpr = true; static constexpr bool preserve_vmm = false; static const size_t helper_vmm_idx = is_avx512_ ? 31 : 15; + static const size_t prelu_helper_vmm_idx = is_avx512_ ? 30 : 0; // todo: [antonvor] check prelu_helper_vmm_idx if is_avx512_ == false static constexpr bool use_exact_tail_scalar_bcast = false; const auto dst_md_wrapper = memory_desc_wrapper(*dst_md); @@ -327,7 +328,7 @@ jit_pp_kernel_t::jit_pp_kernel_t(size_t OC, size_t MB, dim_t dst_mb_stride, helper_vmm_idx, eltwise_reserved_gpr_, r14, preserve_gpr, preserve_vmm, PARAM_OFF(post_ops_binary_rhs_arg_vec), PARAM_OFF(dst_orig), dst_md_wrapper, tail_size, opmask_binary, - reg_tmp, use_exact_tail_scalar_bcast}; + reg_tmp, use_exact_tail_scalar_bcast, prelu_helper_vmm_idx}; static const bcast_set_t enabled_bcast_strategy = {broadcasting_strategy_t::scalar, broadcasting_strategy_t::per_oc, @@ -1160,8 +1161,7 @@ void jit_pp_kernel_t::generate() { // at least 2 blocks of mb within vlen bool dim_restrict = !this->runtime_oc() && !this->runtime_mb() && (this->OC_ <= vlen / 2) && (this->MB_ >= vlen); - bool supported_postops = this->do_scale_ || this->do_eltwise_ - || this->do_binary_ || this->do_sum_ || this->do_dst_zero_points_; + bool supported_postops = this->do_scale_ || (this->post_ops_.len() > 0) || this->do_dst_zero_points_; if (this->do_bias() && !supported_postops && dim_restrict && this->has_trivial_mb_stride()) { this->mb_blk_kernel_ = true; diff --git a/src/cpu/x64/jit_gemm_x8s8s32x_convolution_utils.cpp b/src/cpu/x64/jit_gemm_x8s8s32x_convolution_utils.cpp index a863bafecb3..5438e121ca0 100644 --- a/src/cpu/x64/jit_gemm_x8s8s32x_convolution_utils.cpp +++ b/src/cpu/x64/jit_gemm_x8s8s32x_convolution_utils.cpp @@ -31,40 +31,122 @@ namespace x64 { namespace gemm_x8s8s32x_convolution_utils { using namespace dnnl::impl::cpu::gemm_x8s8s32x_convolution_utils; +template struct jit_pp_ker_t : pp_ker_t, public jit_generator { DECLARE_CPU_JIT_AUX_FUNCTIONS( gemm_x8s8s32x_convolution_utils::jit_pp_ker_t); - jit_pp_ker_t(const 
convolution_pd_t *pd, const conv_gemm_conf_t &jcp); + jit_pp_ker_t(const convolution_pd_t *pd, const conv_gemm_conf_t &jcp) + : pp_ker_t(pd, jcp) + , do_eltwise_(false) + , do_sum_(false) + , sum_scale_(0) + , sum_data_type_(dnnl_f32) + , default_OC_loop_unroll_(4) + , max_OC_loop_unroll_(isa == avx512_common ? 12 : 6) + , idx_compute_vreg_start_(0) + , idx_compute_vreg_max_(isa == avx512_common ? 31 : 15) + , compute_vregs_per_iter_(1) + { + if (utils::one_of(isa, avx2, sse41)) { + idx_compute_vreg_start_ += 2; // Vmm(0), Vmm(1) - for masks + } + if (do_scale_) { + vreg_scale = Vmm(idx_compute_vreg_start_++); + } + dst_data_type_size_ = types::data_type_size(dst_data_type_); + if (dst_data_type_ == data_type::u8 || utils::one_of(isa, avx2, sse41)) { + vreg_zero = Vmm(idx_compute_vreg_start_++); + } + bool only_eltwise_or_sum = true; + for (int idx = 0; idx < post_ops_.len(); ++idx) { + const auto &e = post_ops_.entry_[idx]; + if (e.is_eltwise(true)) { + do_eltwise_ = true; + } else if (e.is_sum()) { + do_sum_ = true; + sum_scale_ = e.sum.scale; + sum_data_type_ = e.sum.dt; + } else { + only_eltwise_or_sum = false; + } + } + if (post_ops_.len() > 0 && !only_eltwise_or_sum) { + vreg_d_weights = Vmm(idx_compute_vreg_max_--); + vreg_d_bias = Vmm(idx_compute_vreg_max_--); + } + + do_signed_scaling_ = jcp_.signed_input; + if (do_signed_scaling_) + vreg_signed_scale = Vmm(idx_compute_vreg_start_++); + + if (do_bias_) { + bias_data_type_size_ = types::data_type_size(bias_data_type_); + compute_vregs_per_iter_++; + } + if (do_sum_) { + vreg_sum_scale = Vmm(idx_compute_vreg_start_++); + compute_vregs_per_iter_++; + } + + for (int i = 0; i < post_ops_.len(); i++) { + auto &post_op = post_ops_.entry_[i]; + if (post_op.is_eltwise()) { + jit_eltwise_injectors_.push_back(new jit_uni_eltwise_injector_f32( + this, post_op.eltwise, true, eltwise_reserved, mask_post_op_reserved)); + } else if (post_op.is_depthwise()) { + jit_depthwise_injectors_.push_back(new jit_uni_depthwise_injector_f32( + this, post_op, mask_post_op_reserved)); + } + } + + int max_unroll = (idx_compute_vreg_max_ - idx_compute_vreg_start_ + 1) / compute_vregs_per_iter_; + max_OC_loop_unroll_ = nstl::min(max_OC_loop_unroll_, max_unroll); + default_OC_loop_unroll_ = nstl::min(default_OC_loop_unroll_, max_unroll); + } + ~jit_pp_ker_t() { + for (auto inj : jit_eltwise_injectors_) + delete inj; + jit_eltwise_injectors_.clear(); + for (auto inj : jit_depthwise_injectors_) + delete inj; + jit_depthwise_injectors_.clear(); + } status_t create_kernel() override { return jit_generator::create_kernel(); } - void operator()(void *void_dst, const acc_data_t *acc, const char *bias, - const float *scales, float sum_scale, float signed_scale, int g, - size_t start, size_t end, const zero_point_call_params_t &zp, - const void *post_ops_binary_rhs_arg_vec, const void *dst_orig, - const exec_ctx_t & /* ctx */, const memory_desc_t & /* dst_md */, - const single_gemm_conv_chunk_desc_t &) const override; + + void operator()(void *void_dst, acc_data_t *acc, const char *bias, const float *scales, float sum_scale, float signed_scale, + int g, size_t start, size_t end, + const zero_point_call_params_t &zp, + const void * post_ops_binary_rhs_arg_vec, + const void * /* dst_orig */, const exec_ctx_t &ctx, + const memory_desc_t &dst_md, + const single_gemm_conv_chunk_desc_t &chunk_desc) const override { + + if (end <= start) return; + + char *dst = (char *)void_dst; + + ker_args_t args; + size_t oc_offset = start % OC_; + size_t os_offset = start / OC_; + args.acc = 
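// Reviewer note: start/end index a flat (spatial x channel) range, so the
// new inline operator() recovers the 2-D position with a divmod on OC_,
// exactly what the deleted std::div-based implementation did:
//
//   size_t oc_offset = start % OC_; // channel within the row
//   size_t os_offset = start / OC_; // spatial position
//   // dst then advances by os_offset * dst_os_stride_ + oc_offset elements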
acc + start; + args.dst = dst + + (os_offset * dst_os_stride_ + oc_offset) + * dst_data_type_size_; + args.bias = bias + (g * jcp_.oc + oc_offset) * bias_data_type_size_; + args.scales = scales + scale_idx_mult_ * (g * jcp_.oc + oc_offset); + args.sum_scale = sum_scale_; + args.signed_scale = signed_scale; + args.len = end - start; + args.oc_offset = oc_offset; + args.g_offset = g * jcp_.oc; + args.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec; + jit_generator::operator()(&args); + } private: - void apply_postops(const Xbyak::Reg64 ®_dst, const int idx); void generate() override; - void append_zp_src_comp(size_t offset, int idx, bool apply_mask); - void load_as_f32(const Xbyak::Zmm &dst, const Xbyak::Opmask &mask, - const Xbyak::Address &src_addr, const data_type_t &src_dt); - - int vreg_dst_idx(const int idx) const noexcept; - Xbyak::Zmm get_vreg_dst(int idx) const; - Xbyak::Zmm get_vreg_bias(int idx) const; - Xbyak::Zmm get_vreg_prev_dst(int idx) const; - Xbyak::Zmm get_vreg_zp_comp_src(int idx) const; - Xbyak::Zmm get_masked_vreg_dst(int idx, bool apply_mask) const; - Xbyak::Zmm reserve_zmm(); - - template - void advance_binary_postops_off(const T &offset); - void zero_binary_postops_off(); - void set_binary_postops_off(const Xbyak::Reg64 ®); - const Xbyak::Opmask &opmask_binary = k2; struct ker_args_t { char *dst; @@ -75,494 +157,433 @@ struct jit_pp_ker_t : pp_ker_t, public jit_generator { float signed_scale; size_t len; size_t oc_offset; - const int32_t *zp_src; - const int32_t *zp_dst; - const int32_t *zp_src_comp; - const int32_t *zp_src_pad_comp; - size_t g_oc_offset_prologue; - size_t g_oc_offset; + size_t g_offset; const void *post_ops_binary_rhs_arg_vec; - const void *dst_orig; - dim_t h; - dim_t w; - dim_t w_size; - dim_t w_off; - dim_t zp_src_pad_com_d_offset; - bool should_apply_zp_src_pad_comp_d; }; - std::unique_ptr> - postops_injector_; - - size_t number_of_reserved_zmm_regs_; - const size_t bias_data_type_size_; - const size_t dst_data_type_size_; - const bool saturation_needed_; - - const Xbyak::Reg64 ®_param_ = rdi; - const Xbyak::Reg64 ®_tmp_ = rcx; // intentional for shifting purposes - - const Xbyak::Reg64 ®_dst_ = rdx; - const Xbyak::Reg64 ®_acc_ = rax; - const Xbyak::Reg64 ®_bias_ = rbx; - const Xbyak::Reg64 ®_scales_ = rsi; - const Xbyak::Reg64 ®_len_ = r8; - const Xbyak::Reg64 ®_oc_offset_ = r9; - const Xbyak::Reg64 ®_rem_mask_short_ = r10; - const Xbyak::Reg64 ®_rem_mask_vlen_ = reg_rem_mask_short_; - const Xbyak::Reg64 ®_zp_pad_comp_temp_ = r10; - const Xbyak::Reg64 ®_zp_pad_comp_ = r11; - const Xbyak::Reg8 ®_should_apply_src_pad_comp_ = r13b; - - const Xbyak::Reg64 ®_tmp_comp_ - = r12; // used to broadcast scalar values to vreg - const Xbyak::Reg64 ®_g_oc_off_ = reg_tmp_comp_; - const Xbyak::Reg64 ®_zp_src_comp_ = r14; - - const Xbyak::Zmm vreg_zero_; - const Xbyak::Zmm vreg_scale_; - const Xbyak::Zmm vreg_sum_scale_; - const Xbyak::Zmm vreg_signed_scale_; - const Xbyak::Zmm vreg_saturation_ubound_; - const Xbyak::Zmm vreg_zp_dst_common_; - - const Xbyak::Opmask &kreg_rem_mask_short_ = k3; - const Xbyak::Opmask &kreg_rem_mask_vlen_ = k4; - - static constexpr size_t def_unroll_ = 4u; - size_t zmm_step_; - const size_t bias_step_factor_; - const size_t sum_step_factor_; - const size_t max_unroll_; - int dst_l_offset_ = 0; - - std::unique_ptr zp_pad_comp_helper_; -}; - -jit_pp_ker_t::jit_pp_ker_t( - const convolution_pd_t *pd, const conv_gemm_conf_t &jcp) - : pp_ker_t(pd, jcp) - , number_of_reserved_zmm_regs_(0) - , 
bias_data_type_size_(jcp.bias_data_type != data_type::undef - ? types::data_type_size(jcp.bias_data_type) - : 0u) - , dst_data_type_size_(types::data_type_size(jcp.dst_data_type)) - , saturation_needed_(utils::one_of( - jcp_.dst_data_type, data_type::u8, data_type::s8, data_type::s32)) - , vreg_zero_((jcp_.with_eltwise || saturation_needed_) ? reserve_zmm() - : Xbyak::Zmm(0)) - , vreg_scale_(reserve_zmm()) - , vreg_sum_scale_(jcp_.with_sum ? reserve_zmm() : Xbyak::Zmm(0)) - , vreg_signed_scale_(jcp_.signed_input ? reserve_zmm() : Xbyak::Zmm(0)) - , vreg_saturation_ubound_( - saturation_needed_ ? reserve_zmm() : Xbyak::Zmm(0)) - , vreg_zp_dst_common_(jcp_.zp.dst_exists ? reserve_zmm() : Xbyak::Zmm(0)) - , zmm_step_(1u) - , bias_step_factor_(jcp_.with_bias ? zmm_step_++ : 0u) - , sum_step_factor_(jcp_.with_sum ? zmm_step_++ : 0) - , max_unroll_((cpu_isa_traits::n_vregs - - number_of_reserved_zmm_regs_) - / zmm_step_) - , zp_pad_comp_helper_(jit_gemm_convolution_utils::padding_exists(jcp) - && jcp.zp.src_exists - ? utils::make_unique< - jit_gemm_x8s8s32x_zp_pad_comp_helper>(this, jcp_, - reg_zp_pad_comp_, reg_zp_pad_comp_temp_, - reg_should_apply_src_pad_comp_, - pd->src_md()->ndims) - : nullptr) - -{ - - if (jcp.with_eltwise || jcp.with_binary) { - using namespace binary_injector; - static constexpr bool preserve_gpr = true; - static constexpr bool preserve_vmm = true; - static constexpr size_t helper_vmm_idx = 31; - // tail_size = 1 just indicates that tailing is to be performed - // actual tail value is held in opmask passed to injector - static constexpr size_t tail_size = 1; - static constexpr bool use_exact_tail_scalar_bcast = false; - -#define PARAM_OFF(x) offsetof(ker_args_t, x) - const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx, - r13, r14, preserve_gpr, preserve_vmm, - PARAM_OFF(post_ops_binary_rhs_arg_vec), PARAM_OFF(dst_orig), - memory_desc_wrapper(pd->dst_md()), tail_size, opmask_binary, - use_exact_tail_scalar_bcast}; -#undef PARAM_OFF - - const static_params_t static_params {reg_param_, rhs_arg_static_params}; - - postops_injector_ = utils::make_unique< - injector::jit_uni_postops_injector_t>( - this, jcp_.post_ops, static_params); - } -} - -void jit_pp_ker_t::operator()(void *void_dst, const acc_data_t *acc, - const char *bias, const float *scales, float sum_scale, - float signed_scale, int g, size_t start, size_t end, - const zero_point_call_params_t &zp, - const void *post_ops_binary_rhs_arg_vec, const void *dst_orig, - const exec_ctx_t & /* ctx */, const memory_desc_t & /* dst_md */, - const single_gemm_conv_chunk_desc_t &chunk_desc) const { - - if (end <= start) return; - - char *dst = (char *)void_dst; - - ker_args_t args; - const auto dv = std::div(start, jcp_.oc); - const size_t oc_offset = dv.rem; - const size_t os_offset = dv.quot; - args.acc = acc + start; - args.dst = dst - + (os_offset * jcp_.dst_os_stride + oc_offset) - * dst_data_type_size_; - - const ptrdiff_t g_oc_offset = g * jcp_.oc; - const ptrdiff_t g_oc_offset_prologue = g_oc_offset + oc_offset; - args.bias = bias + g_oc_offset_prologue * bias_data_type_size_; - args.zp_src = zp.src + (jcp_.zp.src_is_common ? 0 : g_oc_offset_prologue); - args.zp_src_comp - = zp.src_comp ? 
zp.src_comp + g_oc_offset_prologue : nullptr; - args.zp_dst = zp.dst; - args.scales = scales + jcp_.scale_idx_mult * g_oc_offset_prologue; - args.sum_scale = sum_scale; - args.signed_scale = signed_scale; - args.len = end - start; - args.oc_offset = oc_offset; - - args.g_oc_offset = g_oc_offset; - args.g_oc_offset_prologue = g_oc_offset_prologue; - - args.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec; - args.dst_orig = dst_orig; - - if (zp_pad_comp_helper_) { - const auto hw - = std::div(static_cast(os_offset), chunk_desc.w_size_); - args.h = hw.quot + chunk_desc.h_off_; - args.w = hw.rem + chunk_desc.w_off_; - args.w_size = chunk_desc.w_size_ + chunk_desc.w_off_; - args.w_off = chunk_desc.w_off_; - args.zp_src_pad_comp = zp.src_pad_comp; - const auto zp_src_pad_com_d - = zp_pad_comp_helper_->calculate_zp_src_pad_com_d( - chunk_desc.d_off_); - args.zp_src_pad_com_d_offset = zp_src_pad_com_d.offset; - args.should_apply_zp_src_pad_comp_d - = zp_src_pad_com_d.should_apply_pad_comp_d; + nstl::vector *> jit_eltwise_injectors_; + nstl::vector *> jit_depthwise_injectors_; + + using Vmm = typename cpu_isa_traits::Vmm; + static const size_t vlen = cpu_isa_traits::vlen / sizeof(float); + + Xbyak::Reg64 reg_param = abi_param1; + Xbyak::Reg64 reg_dst = rdx; + Xbyak::Reg64 reg_acc = rax; + Xbyak::Reg64 reg_bias = rbx; + Xbyak::Reg64 reg_scales = rsi; + Xbyak::Reg64 reg_g_offset = rbp; + + Xbyak::Reg64 reg_len = r8; + Xbyak::Reg64 reg_tmp = rcx; // intentional for shifting purposes + Xbyak::Reg64 reg_oc_offset = r9; + Xbyak::Reg64 reg_rem_mask_short = r10; + Xbyak::Opmask kreg_rem_mask_short = k1; + + Vmm vreg_zero, vreg_scale, vreg_sum_scale, vreg_signed_scale, vreg_comp; + + // sse41/avx2 + Xbyak::Reg64 reg_ptr_maskmovdqu_dst = rdi; // sse41: store destination - must be rdi + Xbyak::Label l_table; + Xbyak::Reg64 reg_table = r12; + Xbyak::Reg64 reg_shift_table = r13; + Vmm vreg_mask = Vmm(0); // sse41: mask for blendvps must be in xmm0 + Vmm vreg_store_mask = Vmm(1); + + // post_ops + Xbyak::Opmask mask_post_op_reserved = k2; + Xbyak::Reg64 eltwise_reserved = rax; + Xbyak::Reg64 reg_d_weights = r14; + Xbyak::Reg64 reg_d_bias = r15; + Vmm vreg_d_weights, vreg_d_bias; + + size_t dst_data_type_size_ = 0; + size_t bias_data_type_size_ = 0; + + bool do_eltwise_; + bool do_sum_; + float sum_scale_; + data_type_t sum_data_type_; + bool do_signed_scaling_; + + int default_OC_loop_unroll_; + int max_OC_loop_unroll_; + int idx_compute_vreg_start_; + int idx_compute_vreg_max_; + int compute_vregs_per_iter_; + + int idx_vreg_dst(int iter) { + int idx = idx_compute_vreg_start_ + iter * compute_vregs_per_iter_ + 0; + assert(idx <= idx_compute_vreg_max_); + return idx; } - - jit_generator::operator()(&args); -} - -template -void jit_pp_ker_t::advance_binary_postops_off(const T &offset) { - add(reg_g_oc_off_, offset); - - Xbyak::Label end; - cmp(reg_g_oc_off_, jcp_.oc); - jl(end, T_NEAR); - xor_(reg_g_oc_off_, reg_g_oc_off_); - - L(end); -} -void jit_pp_ker_t::zero_binary_postops_off() { - xor_(reg_g_oc_off_, reg_g_oc_off_); - dst_l_offset_ = 0; -} -void jit_pp_ker_t::set_binary_postops_off(const Xbyak::Reg64 ®) { - mov(reg_g_oc_off_, reg); - dst_l_offset_ = 0; -} - -Xbyak::Zmm jit_pp_ker_t::reserve_zmm() { - return Xbyak::Zmm(number_of_reserved_zmm_regs_++); -} - -int jit_pp_ker_t::vreg_dst_idx(const int idx) const noexcept { - return (number_of_reserved_zmm_regs_ + idx * zmm_step_); -} - -Xbyak::Zmm jit_pp_ker_t::get_vreg_dst(int idx) const { - return Xbyak::Zmm(vreg_dst_idx(idx)); -} - -Xbyak::Zmm 
jit_pp_ker_t::get_vreg_bias(int idx) const { - return Xbyak::Zmm(vreg_dst_idx(idx) + bias_step_factor_); -} - -Xbyak::Zmm jit_pp_ker_t::get_vreg_prev_dst(int idx) const { - return Xbyak::Zmm(vreg_dst_idx(idx) + sum_step_factor_); -} - -Xbyak::Zmm jit_pp_ker_t::get_masked_vreg_dst(int idx, bool apply_mask) const { - auto vreg_dst = this->get_vreg_dst(idx); - if (apply_mask) - vreg_dst = vreg_dst | kreg_rem_mask_short_; - else - vreg_dst = vreg_dst | kreg_rem_mask_vlen_; - return vreg_dst; -} - -void jit_pp_ker_t::append_zp_src_comp(size_t offset, int idx, bool apply_mask) { - const auto vreg_dst_masked = get_masked_vreg_dst(idx, apply_mask); - const auto vreg_dst = get_vreg_dst(idx); - const auto zp_src_comp_offset = offset * sizeof(int32_t); - const auto zp_src_comp_addr = ptr[reg_zp_src_comp_ + zp_src_comp_offset]; - - vpaddd(vreg_dst_masked, vreg_dst, zp_src_comp_addr); - - if (zp_pad_comp_helper_) - zp_pad_comp_helper_->zp_src_comp_pad_operation( - [&](const Xbyak::Reg64 ®_zp_pad_comp) { - vpaddd(vreg_dst_masked, vreg_dst, - ptr[reg_zp_pad_comp + zp_src_comp_offset]); - }); -} - -void jit_pp_ker_t::apply_postops(const Xbyak::Reg64 ®_dst, const int idx) { -#define PARAM_OFF(x) offsetof(ker_args_t, x) - if (jcp_.with_eltwise || jcp_.with_binary) { - if (jcp_.with_binary) { - binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; - const auto vmm_idx = vreg_dst_idx(idx); - - rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, reg_dst); - rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(vmm_idx, - dst_l_offset_ * types::data_type_size(jcp_.dst_data_type)); - rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx); - - postops_injector_->compute_vector( - vreg_dst_idx(idx), rhs_arg_params); - } else - postops_injector_->compute_vector(vreg_dst_idx(idx)); + int idx_vreg_bias(int iter) { + int idx = idx_compute_vreg_start_ + iter * compute_vregs_per_iter_ + 1; + assert(idx <= idx_compute_vreg_max_); + return idx; } -#undef PARAM_OFF -} - -void jit_pp_ker_t::load_as_f32(const Xbyak::Zmm &dst, - const Xbyak::Opmask &mask_reg, const Xbyak::Address &src_addr, - const data_type_t &src_dt) { - - const auto dst_masked = dst | mask_reg; - - switch (src_dt) { - case data_type::s8: vpmovsxbd(dst_masked, src_addr); break; - case data_type::u8: vpmovzxbd(dst_masked, src_addr); break; - case data_type::s32: vcvtdq2ps(dst_masked, src_addr); break; - case data_type::f32: vmovups(dst_masked, src_addr); break; - default: assert(!"unimplemented"); + int idx_vreg_prev_dst(int iter) { + int idx = idx_compute_vreg_start_ + iter * compute_vregs_per_iter_ + 2; + assert(idx <= idx_compute_vreg_max_); + return idx; } - if (utils::one_of(src_dt, data_type::s8, data_type::u8)) - vcvtdq2ps(dst_masked, dst); -} + Vmm vreg_dst(int idx) { return Vmm(idx_vreg_dst(idx)); }; + Xbyak::Ymm ymm_dst(int idx) { return Xbyak::Ymm(idx_vreg_dst(idx)); }; + Xbyak::Xmm xmm_dst(int idx) { return Xbyak::Xmm(idx_vreg_dst(idx)); }; + Vmm vreg_bias(int idx) { return Vmm(idx_vreg_bias(idx)); }; + Vmm vreg_prev_dst(int idx) { return Vmm(idx_vreg_prev_dst(idx)); }; +}; -void jit_pp_ker_t::generate() { +template +void jit_pp_ker_t::generate() { using namespace Xbyak; using namespace utils; - size_t vlen = cpu_isa_traits::vlen / sizeof(float); - for (; vlen >= 1 && (jcp_.oc % vlen != 0); --vlen) {} - preamble(); -#ifdef _WIN32 - mov(reg_param_, rcx); -#endif + const auto &p = post_ops_; + std::size_t post_ops_pointers_count = 0; + for (int i = 0; i < p.len(); i++) { + if (p.entry_[i].is_depthwise() || p.entry_[i].is_quantization()) { + 
post_ops_pointers_count++; + } + } #define PARAM_OFF(x) offsetof(ker_args_t, x) - mov(reg_dst_, ptr[reg_param_ + PARAM_OFF(dst)]); - mov(reg_acc_, ptr[reg_param_ + PARAM_OFF(acc)]); - mov(reg_bias_, ptr[reg_param_ + PARAM_OFF(bias)]); - mov(reg_scales_, ptr[reg_param_ + PARAM_OFF(scales)]); - mov(reg_len_, ptr[reg_param_ + PARAM_OFF(len)]); - mov(reg_oc_offset_, ptr[reg_param_ + PARAM_OFF(oc_offset)]); - - if (jcp_.zp.src_exists) { - mov(reg_zp_src_comp_, ptr[reg_param_ + PARAM_OFF(zp_src_comp)]); - if (zp_pad_comp_helper_) - zp_pad_comp_helper_->init(PARAM_OFF(w), PARAM_OFF(h), - PARAM_OFF(w_size), PARAM_OFF(w_off), - PARAM_OFF(zp_src_pad_comp), PARAM_OFF(g_oc_offset_prologue), - PARAM_OFF(g_oc_offset), PARAM_OFF(zp_src_pad_com_d_offset), - PARAM_OFF(should_apply_zp_src_pad_comp_d)); - } + if (post_ops_pointers_count != 0) { + sub(rsp, post_ops_pointers_count * sizeof(float *)); + + auto aux_reg0 = reg_dst; + auto aux_reg1 = reg_acc; - if (jcp_.zp.dst_exists) { - mov(reg_tmp_, ptr[reg_param_ + PARAM_OFF(zp_dst)]); - vcvtdq2ps(vreg_zp_dst_common_, ptr_b[reg_tmp_]); + mov(aux_reg0, ptr[reg_param + PARAM_OFF(post_ops_binary_rhs_arg_vec)]); + for (size_t i = 0; i < post_ops_pointers_count; i++) { + mov(aux_reg1, ptr[aux_reg0 + i * sizeof(float *)]); + mov(ptr[rsp + i * sizeof(float *)], aux_reg1); + } } - if (jcp_.with_sum) - vbroadcastss(vreg_sum_scale_, ptr[reg_param_ + PARAM_OFF(sum_scale)]); - if (jcp_.signed_input) - vbroadcastss( - vreg_signed_scale_, ptr[reg_param_ + PARAM_OFF(signed_scale)]); - if (jcp_.scale_idx_mult == 0) vbroadcastss(vreg_scale_, dword[reg_scales_]); + mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]); + mov(reg_acc, ptr[reg_param + PARAM_OFF(acc)]); + mov(reg_bias, ptr[reg_param + PARAM_OFF(bias)]); + mov(reg_scales, ptr[reg_param + PARAM_OFF(scales)]); + mov(reg_len, ptr[reg_param + PARAM_OFF(len)]); + mov(reg_oc_offset, ptr[reg_param + PARAM_OFF(oc_offset)]); + mov(reg_g_offset, ptr[reg_param + PARAM_OFF(g_offset)]); + if (do_sum_) + uni_vbroadcastss(vreg_sum_scale, ptr[reg_param + PARAM_OFF(sum_scale)]); + if (do_signed_scaling_) + uni_vbroadcastss(vreg_signed_scale, ptr[reg_param + PARAM_OFF(signed_scale)]); + if (do_scale_ && scale_idx_mult_ == 0) + uni_vbroadcastss(vreg_scale, dword[reg_scales]); #undef PARAM_OFF - mov(reg_rem_mask_vlen_, 1); - shl(reg_rem_mask_vlen_, vlen); - sub(reg_rem_mask_vlen_, 1); - kmovq(kreg_rem_mask_vlen_, reg_rem_mask_vlen_); + if (do_eltwise_ || dst_data_type_ == data_type::u8 || utils::one_of(isa, avx2, sse41)) + uni_vpxor(vreg_zero, vreg_zero, vreg_zero); + + if (utils::one_of(isa, avx2, sse41)) + mov(reg_table, l_table); + + auto apply_post_ops = [&](size_t offset, int idx) { + std::size_t post_ops_data_offset = 0; + int eltwise_inj_idx = 0; + int depthwise_inj_idx = 0; + for (int i = 0; i < post_ops_.len(); i++) { + auto& post_op = post_ops_.entry_[i]; + if (post_op.is_sum()) { + auto dst_addr = ptr[reg_dst + offset * dst_data_type_size_]; + auto vreg_prev_dst_ = vreg_prev_dst(idx); + switch (sum_data_type_) { + case data_type::f32: + case data_type::s32: uni_vmovups(vreg_prev_dst_, dst_addr); break; + case data_type::s8: uni_vpmovsxbd(vreg_prev_dst_, dst_addr); break; + case data_type::u8: uni_vpmovzxbd(vreg_prev_dst_, dst_addr); break; + default: assert(!"unsupported data type"); + } + if (sum_data_type_ != data_type::f32) + uni_vcvtdq2ps(vreg_prev_dst(idx), vreg_prev_dst(idx)); + + uni_vfmadd231ps(vreg_dst(idx), vreg_prev_dst(idx), vreg_sum_scale); + } else if (post_op.is_eltwise()) { + 
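// Reviewer note: the stack setup at the top of generate() copied one data
// pointer per depthwise/quantization entry out of
// post_ops_binary_rhs_arg_vec, so apply_post_ops can reach entry i via
// ptr[rsp + post_ops_data_offset] without pinning a GPR per entry. The host
// side is expected to pack the vector accordingly; a sketch (the field name
// is hypothetical):
//
//   std::vector<const void *> rhs_args;
//   for (const auto &e : post_ops.entry_)
//       if (e.is_depthwise() || e.is_quantization())
//           rhs_args.push_back(e.weights_data); // one float* per entry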
jit_eltwise_injectors_[eltwise_inj_idx]->compute_vector_range(vreg_dst(idx).getIdx(), vreg_dst(idx).getIdx() + 1); + eltwise_inj_idx++; + } else if (post_op.is_depthwise()) { + add(reg_oc_offset, reg_g_offset); + + const Xbyak::RegExp depthwise_arg_base = rsp + post_ops_data_offset; + mov(reg_d_weights, ptr[depthwise_arg_base]); + lea(reg_d_weights, ptr[reg_d_weights + reg_oc_offset * sizeof(float) + offset]); + + jit_depthwise_injectors_[depthwise_inj_idx]->compute_vector_range(vreg_dst(idx).getIdx(), vreg_dst(idx).getIdx() + 1, reg_d_weights, reg_d_weights); + + sub(reg_oc_offset, reg_g_offset); + + post_ops_data_offset += jit_depthwise_injectors_[depthwise_inj_idx]->memoryStep(); + depthwise_inj_idx++; + } else if (post_op.is_quantization()) { + add(reg_oc_offset, reg_g_offset); + bool do_dequantization = post_op.quantization.alg == alg_kind::quantization_quantize_dequantize; + bool do_rounding = do_dequantization || dst_data_type_ == dnnl_f32 || i != post_ops_.len() - 1; + + const Xbyak::RegExp quantization_arg_base = rsp + post_ops_data_offset; + size_t crop_low_off = post_op.quantization.offset[post_op.quantization.crop_low] * sizeof(float); + if (post_op.quantization.per_channel[post_op.quantization.crop_low]) { + mov(reg_d_weights, ptr[quantization_arg_base]); + uni_vmovups(vreg_d_weights, ptr[reg_d_weights + reg_oc_offset * sizeof(float) + offset * sizeof(float) + crop_low_off]); + } else { + mov(reg_d_weights, ptr[quantization_arg_base]); + uni_vbroadcastss(vreg_d_weights, ptr[reg_d_weights + crop_low_off]); + } - if (jcp_.with_eltwise) vxorps(vreg_zero_, vreg_zero_, vreg_zero_); - if (saturation_needed_) - init_saturate_f32(vreg_zero_, vreg_saturation_ubound_, reg_tmp_comp_, - data_type::f32, jcp_.dst_data_type); + size_t crop_high_off = post_op.quantization.offset[post_op.quantization.crop_high] * sizeof(float); + if (post_op.quantization.per_channel[post_op.quantization.crop_high]) { + mov(reg_d_bias, ptr[quantization_arg_base]); + uni_vmovups(vreg_d_bias, ptr[reg_d_bias + reg_oc_offset * sizeof(float) + offset * sizeof(float) + crop_high_off]); + } else { + mov(reg_d_bias, ptr[quantization_arg_base]); + uni_vbroadcastss(vreg_d_bias, ptr[reg_d_bias + crop_high_off]); + } - if (jcp_.with_binary) set_binary_postops_off(reg_oc_offset_); + uni_vmaxps(vreg_dst(idx), vreg_dst(idx), vreg_d_weights); + uni_vminps(vreg_dst(idx), vreg_dst(idx), vreg_d_bias); - // Load accumulated value, convert to float, apply sum (if any), - // bias (if any), scaling, and relu (if any); - // then convert to destination type and store - const auto compute = [&](size_t offset, int idx, bool apply_mask) { - auto acc_addr = ptr[reg_acc_ + offset * sizeof(acc_data_t)]; - - const auto &mask_reg - = apply_mask ? 
kreg_rem_mask_short_ : kreg_rem_mask_vlen_; - - if (jcp_.scale_idx_mult > 0) { - assert(jcp_.scale_idx_mult == 1); - const auto scale_addr = ptr[reg_scales_ + offset * sizeof(float)]; - auto vreg_scale = vreg_scale_; - vreg_scale = vreg_scale | mask_reg; - vmovups(vreg_scale, scale_addr); - } + size_t inp_scale_off = post_op.quantization.offset[post_op.quantization.inp_scale] * sizeof(float); + if (post_op.quantization.per_channel[post_op.quantization.inp_scale]) { + mov(reg_d_weights, ptr[quantization_arg_base]); + uni_vmovups(vreg_d_weights, ptr[reg_d_weights + reg_oc_offset * sizeof(float) + offset * sizeof(float) + inp_scale_off]); + } else { + mov(reg_d_weights, ptr[quantization_arg_base ]); + uni_vbroadcastss(vreg_d_weights, ptr[reg_d_weights + inp_scale_off]); + } - if (jcp_.with_binary) { - if (offset) { - advance_binary_postops_off(vlen); - dst_l_offset_ += offset; - } - kmovq(opmask_binary, mask_reg); - } - const auto vreg_dst_masked = get_masked_vreg_dst(idx, apply_mask); - const auto vreg_dst = get_vreg_dst(idx); - if (jcp_.zp.src_exists) { - vmovups(vreg_dst_masked, acc_addr); - append_zp_src_comp(offset, idx, apply_mask); - vcvtdq2ps(vreg_dst_masked, vreg_dst); - } else { - vcvtdq2ps(vreg_dst_masked, acc_addr); - } + size_t inp_shift_off = post_op.quantization.offset[post_op.quantization.inp_shift] * sizeof(float); + if (post_op.quantization.per_channel[post_op.quantization.inp_shift]) { + mov(reg_d_bias, ptr[quantization_arg_base]); + uni_vmovups(vreg_d_bias, ptr[reg_d_bias + reg_oc_offset * sizeof(float) + offset * sizeof(float) + inp_shift_off]); + } else { + mov(reg_d_bias, ptr[quantization_arg_base]); + uni_vbroadcastss(vreg_d_bias, ptr[reg_d_bias + inp_shift_off]); + } - if (jcp_.signed_input) - vmulps(vreg_dst_masked, vreg_dst, vreg_signed_scale_); + uni_vfmadd213ps(vreg_dst(idx), vreg_d_weights, vreg_d_bias); + + if (do_rounding) + uni_vroundps(vreg_dst(idx), vreg_dst(idx), 0); + + if (do_dequantization) { + size_t output_scale_off = post_op.quantization.offset[post_op.quantization.output_scale] * sizeof(float); + if (post_op.quantization.per_channel[post_op.quantization.output_scale]) { + mov(reg_d_weights, ptr[quantization_arg_base ]); + uni_vmovups(vreg_d_weights, ptr[reg_d_weights + reg_oc_offset * sizeof(float) + offset * sizeof(float) + output_scale_off]); + } else { + mov(reg_d_weights, ptr[quantization_arg_base]); + uni_vbroadcastss(vreg_d_weights, ptr[reg_d_weights + output_scale_off]); + } + + size_t output_shift_off = post_op.quantization.offset[post_op.quantization.output_shift] * sizeof(float); + if (post_op.quantization.per_channel[post_op.quantization.output_shift]) { + mov(reg_d_bias, ptr[quantization_arg_base]); + uni_vmovups(vreg_d_bias, ptr[reg_d_bias + reg_oc_offset * sizeof(float) + offset * sizeof(float) + output_shift_off]); + } else { + mov(reg_d_bias, ptr[quantization_arg_base]); + uni_vbroadcastss(vreg_d_bias, ptr[reg_d_bias + output_shift_off]); + } + + uni_vfmadd213ps(vreg_dst(idx), vreg_d_weights, vreg_d_bias); + } + sub(reg_oc_offset, reg_g_offset); - if (jcp_.with_bias) { - const auto bias_addr - = ptr[reg_bias_ + offset * bias_data_type_size_]; - const auto vreg_bias = get_vreg_bias(idx); - load_as_f32(vreg_bias, mask_reg, bias_addr, jcp_.bias_data_type); - vaddps(vreg_dst_masked, vreg_dst, vreg_bias); + post_ops_data_offset += sizeof(float*); + } } + }; - vmulps(vreg_dst_masked, vreg_dst, vreg_scale_); + // Load accumulated value, convert to float, + // bias (if any), scaling, and simple operations (if any); + // then convert to 
destination type and store + auto compute = [&](size_t offset, int idx, bool apply_mask) { + auto acc_addr = ptr[reg_acc + offset * sizeof(acc_data_t)]; + + if (do_scale_ && scale_idx_mult_ > 0) { + assert(scale_idx_mult_ == 1); + auto scale_addr = ptr[reg_scales + offset * sizeof(float)]; + auto vreg_scale_ = vreg_scale; + if (isa == avx512_common) { + if (apply_mask) + vreg_scale_ = vreg_scale_ | kreg_rem_mask_short; + uni_vmovups(vreg_scale_, scale_addr); + } else { + if (apply_mask) + if (isa != sse41) { + uni_vblendvps(vreg_scale, vreg_zero, scale_addr, vreg_mask); + } else { + uni_vmovups(vreg_scale, vreg_zero); + uni_vblendvps(vreg_scale, vreg_scale, scale_addr, vreg_mask); + } + else + uni_vmovups(vreg_scale, scale_addr); + } + } - const auto dst_addr = ptr[reg_dst_ + offset * dst_data_type_size_]; + auto vreg_dst_ = vreg_dst(idx); + if (isa == avx512_common) { + if (apply_mask) + vreg_dst_ = vreg_dst_ | kreg_rem_mask_short; + uni_vcvtdq2ps(vreg_dst_, acc_addr); + } else { + if (apply_mask) { + if (isa != sse41) { + uni_vblendvps(vreg_dst_, vreg_zero, acc_addr, vreg_mask); + } else { + uni_vmovups(vreg_dst_, acc_addr); + } + uni_vcvtdq2ps(vreg_dst_, vreg_dst_); + } else { + if (isa == sse41) { + uni_vmovups(vreg_dst_, acc_addr); + uni_vcvtdq2ps(vreg_dst_, vreg_dst_); + } else { + uni_vcvtdq2ps(vreg_dst_, acc_addr); + } + } + } - if (jcp_.with_sum) { - const auto vreg_prev_dst = get_vreg_prev_dst(idx); - load_as_f32(vreg_prev_dst, mask_reg, dst_addr, jcp_.sum_data_type); - vfmadd231ps(vreg_dst_masked, vreg_prev_dst, vreg_sum_scale_); + if (do_signed_scaling_) + uni_vmulps(vreg_dst(idx), vreg_dst(idx), vreg_signed_scale); + + if (do_bias_) { + auto bias_addr = ptr[reg_bias + offset * bias_data_type_size_]; + auto vreg_bias_ = vreg_bias(idx); + if (isa == avx512_common && apply_mask) + vreg_bias_ = vreg_bias_ | kreg_rem_mask_short; + + switch (bias_data_type_) { + case data_type::s8: uni_vpmovsxbd(vreg_bias_, bias_addr); break; + case data_type::u8: uni_vpmovzxbd(vreg_bias_, bias_addr); break; + case data_type::s32: + case data_type::f32: uni_vmovups(vreg_bias_, bias_addr); break; + default: assert(!"unimplemented"); + } + if (bias_data_type_ != data_type::f32) + uni_vcvtdq2ps(vreg_bias(idx), vreg_bias(idx)); + uni_vaddps(vreg_dst(idx), vreg_dst(idx), vreg_bias(idx)); } - apply_postops(reg_dst_, idx); + if (do_scale_) + uni_vmulps(vreg_dst(idx), vreg_dst(idx), vreg_scale); - if (jcp_.zp.dst_exists) { - vaddps(vreg_dst_masked, vreg_dst, vreg_zp_dst_common_); - } + apply_post_ops(offset, idx); - if (saturation_needed_) { - saturate_f32(get_vreg_dst(idx), vreg_zero_, vreg_saturation_ubound_, - jcp_.dst_data_type); - vcvtps2dq(vreg_dst_masked, vreg_dst); + if (dst_data_type_ != data_type::f32) { + if (isa == avx512_common) { + auto rmode_control = T_rn_sae; + vcvtps2dq(vreg_dst(idx) | rmode_control, vreg_dst(idx)); + } else { + uni_vcvtps2dq(vreg_dst(idx), vreg_dst(idx)); + } } - switch (jcp_.dst_data_type) { - case data_type::s8: vpmovsdb(dst_addr, vreg_dst_masked); break; - case data_type::u8: vpmovusdb(dst_addr, vreg_dst_masked); break; + if (dst_data_type_ == data_type::u8) + uni_vpmaxsd(vreg_dst(idx), vreg_dst(idx), vreg_zero); + + auto dst_addr = ptr[reg_dst + offset * dst_data_type_size_]; + switch (dst_data_type_) { + case data_type::s8: + if (isa == avx512_common) { + vpmovsdb(dst_addr, vreg_dst_); + } else { + uni_vpackssdw(vreg_dst_, vreg_dst_, vreg_dst_); + if (isa != sse41) + vpermq(ymm_dst(idx), ymm_dst(idx), 0x08); + uni_vpacksswb(vreg_dst_, vreg_dst_, vreg_dst_); + if 
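// Reviewer note: the masked s8/u8 store here is ISA-dependent: avx512 can
// down-convert with vpmovsdb/vpmovusdb under a k-mask, while avx2/sse41 must
// pack 32->16->8 bits in-register and then either store through maskmovdqu
// with the byte-granular vreg_store_mask (tail case) or do a plain 8/4-byte
// move (full vector). Pseudocode of the branch being generated:
//
//   if (avx512)         vpmovsdb(dst_addr, vreg | kreg_rem_mask_short);
//   else { pack_s32_to_s8(vreg);
//          apply_mask ? maskmovdqu(vreg, vreg_store_mask) // rdi = dst_addr
//                     : store_low_qword_or_dword(vreg); }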
(apply_mask) { + lea(reg_ptr_maskmovdqu_dst, dst_addr); + maskmovdqu(vreg_dst_, vreg_store_mask); + } else { + if (isa != sse41) { + vmovq(dst_addr, xmm_dst(idx)); + } else { + movd(dst_addr, xmm_dst(idx)); + } + } + } + break; + case data_type::u8: + if (isa == avx512_common) { + vpmovusdb(dst_addr, vreg_dst_); + } else { + uni_vpackusdw(vreg_dst_, vreg_dst_, vreg_dst_); + if (isa != sse41) + vpermq(ymm_dst(idx), ymm_dst(idx), 0x08); + uni_vpackuswb(vreg_dst_, vreg_dst_, vreg_dst_); + if (apply_mask) { + lea(reg_ptr_maskmovdqu_dst, dst_addr); + maskmovdqu(vreg_dst_, vreg_store_mask); + } else { + if (isa != sse41) { + vmovq(dst_addr, xmm_dst(idx)); + } else { + movd(dst_addr, xmm_dst(idx)); + } + } + } + break; case data_type::f32: - case data_type::s32: vmovups(dst_addr, vreg_dst_masked); break; + case data_type::s32: + if (isa == avx512_common) { + uni_vmovups(dst_addr, vreg_dst_); + } else { + if (apply_mask) { + if (isa != sse41) { + vmaskmovps(dst_addr, vreg_mask, vreg_dst_); + } else { + lea(reg_ptr_maskmovdqu_dst, dst_addr); + maskmovdqu(vreg_dst_, vreg_mask); + } + } else { + uni_vmovups(dst_addr, vreg_dst_); + } + } + break; default: assert(!"unimplemented"); } }; // Advance all pointers by an immediate - const auto advance_ptrs_imm = [&](const size_t offset, - const size_t binary_offset) { - add(reg_dst_, offset * dst_data_type_size_); - add(reg_acc_, offset * sizeof(acc_data_t)); - if (jcp_.with_binary) { advance_binary_postops_off(binary_offset); } - if (jcp_.scale_idx_mult) { - assert(jcp_.scale_idx_mult == 1); - add(reg_scales_, offset * sizeof(float)); - } - if (jcp_.with_bias) add(reg_bias_, offset * bias_data_type_size_); - if (jcp_.zp.src_exists) { - add(reg_zp_src_comp_, offset * sizeof(int32_t)); - - if (zp_pad_comp_helper_) { - zp_pad_comp_helper_->zp_src_comp_pad_operation( - [&](const Xbyak::Reg64 ®_zp_pad_comp) { - add(reg_zp_pad_comp, offset * sizeof(int32_t)); - }); - } + auto advance_ptrs_imm = [&](size_t offset) { + add(reg_dst, offset * dst_data_type_size_); + add(reg_acc, offset * sizeof(acc_data_t)); + if (scale_idx_mult_) { + assert(scale_idx_mult_ == 1); + add(reg_scales, offset * sizeof(float)); } + if (do_bias_) + add(reg_bias, offset * bias_data_type_size_); }; // Advance all pointers by a value stored in a register - const auto advance_ptrs_reg = [&](const Reg64 offset, - const Reg64 binary_offset) { - lea(reg_dst_, ptr[reg_dst_ + offset * dst_data_type_size_]); - lea(reg_acc_, ptr[reg_acc_ + offset * sizeof(acc_data_t)]); - if (jcp_.with_binary) { advance_binary_postops_off(binary_offset); } - if (jcp_.scale_idx_mult) { - assert(jcp_.scale_idx_mult == 1); - lea(reg_scales_, ptr[reg_scales_ + offset * sizeof(float)]); - } - if (jcp_.with_bias) - lea(reg_bias_, ptr[reg_bias_ + offset * bias_data_type_size_]); - - if (jcp_.zp.src_exists) { - lea(reg_zp_src_comp_, - ptr[reg_zp_src_comp_ + offset * sizeof(int32_t)]); - - if (zp_pad_comp_helper_) - zp_pad_comp_helper_->zp_src_comp_pad_operation( - [&](const Xbyak::Reg64 ®_zp_pad_comp) { - lea(reg_zp_pad_comp, - ptr[reg_zp_pad_comp - + offset * sizeof(int32_t)]); - }); + auto advance_ptrs_reg = [&](Reg64 offset) { + lea(reg_dst, ptr[reg_dst + offset * dst_data_type_size_]); + lea(reg_acc, ptr[reg_acc + offset * sizeof(acc_data_t)]); + if (scale_idx_mult_) { + assert(scale_idx_mult_ == 1); + lea(reg_scales, ptr[reg_scales + offset * sizeof(float)]); } + if (do_bias_) + lea(reg_bias, ptr[reg_bias + offset * bias_data_type_size_]); }; // Rewind pointers that point to data that is indexed by output channel // 
(bias or per-oc scaling factors) - const auto rewind_ptrs = [&]() { - if (jcp_.with_bias) sub(reg_bias_, jcp_.oc * bias_data_type_size_); - if (jcp_.with_binary) { - zero_binary_postops_off(); - dst_l_offset_ = 0; - } - if (jcp_.zp.src_exists) { - const auto offset = jcp_.oc * sizeof(int32_t); - sub(reg_zp_src_comp_, offset); - if (zp_pad_comp_helper_) - zp_pad_comp_helper_->load_next_point_zp_src_comp_pad_addr(); - } - if (jcp_.scale_idx_mult) { - assert(jcp_.scale_idx_mult == 1); - sub(reg_scales_, jcp_.oc * sizeof(float)); + auto rewind_ptrs = [&]() { + if (do_bias_) + sub(reg_bias, OC_ * bias_data_type_size_); + if (scale_idx_mult_) { + assert(scale_idx_mult_ == 1); + sub(reg_scales, OC_ * sizeof(float)); } - add(reg_dst_, (jcp_.dst_os_stride - jcp_.oc) * dst_data_type_size_); + add(reg_dst, (dst_os_stride_ - OC_) * dst_data_type_size_); }; // <--------- OC ---------------> @@ -577,40 +598,55 @@ void jit_pp_ker_t::generate() { // | . | Epilogue loop|not accessed : . // v ................+--------------+.............+....................... + bool do_post_ops = post_ops_.len() != 0; + Label prologue_end; - cmp(reg_oc_offset_, 0); + cmp(reg_oc_offset, 0); je(prologue_end, T_NEAR); // Prologue loop { - mov(reg_tmp_, jcp_.oc); - sub(reg_tmp_, reg_oc_offset_); - cmp(reg_tmp_, reg_len_); - cmovg(reg_tmp_, reg_len_); - sub(reg_len_, reg_tmp_); + mov(reg_tmp, OC_); + sub(reg_tmp, reg_oc_offset); + cmp(reg_tmp, reg_len); + cmovg(reg_tmp, reg_len); + sub(reg_len, reg_tmp); Label prologue_loop, prologue_loop_tail, prologue_loop_end; - cmp(reg_tmp_, vlen); - jle(prologue_loop_tail, T_NEAR); + cmp(reg_tmp, vlen); + jl(prologue_loop_tail, T_NEAR); L(prologue_loop); { - compute(0, max_unroll_ - 1, false); - advance_ptrs_imm(vlen, vlen); - sub(reg_tmp_, vlen); - cmp(reg_tmp_, vlen); + compute(0, 0, false); + advance_ptrs_imm(vlen); + if (do_post_ops) + add(reg_oc_offset, vlen); + sub(reg_tmp, vlen); + cmp(reg_tmp, vlen); jge(prologue_loop, T_NEAR); } L(prologue_loop_tail); - mov(reg_rem_mask_short_, 1); - // cl == reg_tmp_ because reg_tmp_ <= vlen here - shl(reg_rem_mask_short_, cl); - sub(reg_rem_mask_short_, 1); - jz(prologue_loop_end, T_NEAR); - - kmovq(kreg_rem_mask_short_, reg_rem_mask_short_); - compute(0, max_unroll_ - 1, true); - advance_ptrs_reg(reg_tmp_, reg_tmp_); + if (isa == avx512_common) { + mov(reg_rem_mask_short, 1); + // cl == reg_tmp because reg_tmp <= vlen here + shl(reg_rem_mask_short, cl); + sub(reg_rem_mask_short, 1); + jz(prologue_loop_end, T_NEAR); + + kmovq(kreg_rem_mask_short, reg_rem_mask_short); + } else { + mov(reg_shift_table, vlen); + sub(reg_shift_table, reg_tmp); + uni_vmovups(vreg_mask, ptr[reg_table + reg_shift_table * sizeof(float)]); + if (dst_data_type_ == data_type::s8 || dst_data_type_ == data_type::u8) { + mov(reg_shift_table, vlen * sizeof(float)); + sub(reg_shift_table, reg_tmp); + uni_vmovups(vreg_store_mask, ptr[reg_table + reg_shift_table]); + } + } + compute(0, 0, true); + advance_ptrs_reg(reg_tmp); L(prologue_loop_end); rewind_ptrs(); @@ -620,40 +656,55 @@ void jit_pp_ker_t::generate() { // Main loop Label main_loop_end; { - cmp(reg_len_, jcp_.oc); - jle(main_loop_end, T_NEAR); - - Label main_loop; - L(main_loop); - { - size_t OC_loop, OC_tail; - if (static_cast(jcp_.oc) < max_unroll_ * vlen) { - // Fully unroll small loops - OC_loop = 0; - OC_tail = jcp_.oc; - } else { - OC_loop = vlen * def_unroll_; - OC_tail = jcp_.oc % OC_loop; - } + cmp(reg_len, OC_); + jl(main_loop_end, T_NEAR); + + size_t OC_loop, OC_tail; + if (OC_ < max_OC_loop_unroll_ * vlen) 
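// Reviewer note: the main loop either fully unrolls a short channel row or
// tiles it, matching the computation split across the lines around this
// note:
//
//   size_t OC_loop = 0, OC_tail = OC_;            // fully unrolled body
//   if (OC_ >= max_OC_loop_unroll_ * vlen) {
//       OC_loop = vlen * default_OC_loop_unroll_; // tiled body
//       OC_tail = OC_ % OC_loop;                  // remainder
//   }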
{ + // Fully unroll small loops + OC_loop = 0; + OC_tail = OC_; + } else { + OC_loop = vlen * default_OC_loop_unroll_; + OC_tail = OC_ % OC_loop; + } - assert(!!OC_loop || !!OC_tail); + assert(!!OC_loop || !!OC_tail); - const int vlen_tail = OC_tail % vlen; - if (vlen_tail) { + if (OC_tail % vlen) { + int vlen_tail = OC_tail % vlen; + if (isa == avx512_common) { unsigned tail_mask = (1 << vlen_tail) - 1; - mov(reg_tmp_, tail_mask); - kmovq(kreg_rem_mask_short_, reg_tmp_); + mov(reg_tmp, tail_mask); + kmovq(kreg_rem_mask_short, reg_tmp); + } else { + mov(reg_shift_table, vlen - vlen_tail); + uni_vmovups(vreg_mask, ptr[reg_table + reg_shift_table * sizeof(float)]); + if (dst_data_type_ == data_type::s8 || dst_data_type_ == data_type::u8) { + mov(reg_shift_table, vlen * sizeof(float)); + sub(reg_shift_table, vlen_tail); + uni_vmovups(vreg_store_mask, ptr[reg_table + reg_shift_table]); + } } + } + + Label main_loop; + L(main_loop); + { + if (do_post_ops) + mov(reg_oc_offset, 0); if (OC_loop) { - mov(reg_tmp_, rnd_dn(jcp_.oc, OC_loop)); + mov(reg_tmp, rnd_dn(OC_, OC_loop)); Label oc_loop; L(oc_loop); { for (size_t offset = 0; offset < OC_loop; offset += vlen) compute(offset, offset / vlen, false); - advance_ptrs_imm(OC_loop, vlen); - sub(reg_tmp_, OC_loop); + advance_ptrs_imm(OC_loop); + if (do_post_ops) + add(reg_oc_offset, OC_loop); + sub(reg_tmp, OC_loop); jnz(oc_loop); } } @@ -663,14 +714,12 @@ void jit_pp_ker_t::generate() { bool use_mask = (offset + vlen) > OC_tail; compute(offset, offset / vlen, use_mask); } - const size_t oc_tail_rem = OC_tail % vlen; - const size_t binary_offset = oc_tail_rem ? oc_tail_rem : vlen; - advance_ptrs_imm(OC_tail, binary_offset); + advance_ptrs_imm(OC_tail); } rewind_ptrs(); - sub(reg_len_, jcp_.oc); - cmp(reg_len_, jcp_.oc); + sub(reg_len, OC_); + cmp(reg_len, OC_); jge(main_loop, T_NEAR); } } @@ -679,61 +728,75 @@ void jit_pp_ker_t::generate() { // Epilogue loop Label epilogue_end; { - cmp(reg_len_, 0); + cmp(reg_len, 0); je(epilogue_end, T_NEAR); Label epilogue_loop, epilogue_loop_tail; - cmp(reg_len_, vlen); - jle(epilogue_loop_tail, T_NEAR); + if (do_post_ops) + mov(reg_oc_offset, 0); + cmp(reg_len, vlen); + jl(epilogue_loop_tail, T_NEAR); L(epilogue_loop); { compute(0, 0, false); - sub(reg_len_, vlen); - advance_ptrs_imm(vlen, vlen); - cmp(reg_len_, vlen); + sub(reg_len, vlen); + advance_ptrs_imm(vlen); + if (do_post_ops) + add(reg_oc_offset, vlen); + cmp(reg_len, vlen); jge(epilogue_loop, T_NEAR); } L(epilogue_loop_tail); - mov(reg_tmp_, - reg_len_); // reg_tmp_ is rcx, and we need cl for the shift - mov(reg_rem_mask_short_, 1); - shl(reg_rem_mask_short_, cl); // reg_tmp_ == rcx and reg_tail < vlen - sub(reg_rem_mask_short_, 1); - jz(epilogue_end, T_NEAR); - kmovq(kreg_rem_mask_short_, reg_rem_mask_short_); + mov(reg_tmp, reg_len); // reg_tmp is rcx, and we need cl for the shift + if (isa == avx512_common) { + mov(reg_rem_mask_short, 1); + shl(reg_rem_mask_short, cl); // reg_tmp == rcx and reg_tail < vlen + sub(reg_rem_mask_short, 1); + jz(epilogue_end, T_NEAR); + kmovq(kreg_rem_mask_short, reg_rem_mask_short); + } else { + mov(reg_shift_table, vlen); + sub(reg_shift_table, reg_tmp); + uni_vmovups(vreg_mask, ptr[reg_table + reg_shift_table * sizeof(float)]); + if (dst_data_type_ == data_type::s8 || dst_data_type_ == data_type::u8) { + mov(reg_shift_table, vlen * sizeof(float)); + sub(reg_shift_table, reg_tmp); + uni_vmovups(vreg_store_mask, ptr[reg_table + reg_shift_table]); + } + } compute(0, 0, true); } L(epilogue_end); - if (zp_pad_comp_helper_) 
zp_pad_comp_helper_->fin();
 
+    if (post_ops_pointers_count != 0) {
+        add(rsp, post_ops_pointers_count * sizeof(float *));
+    }
 
     postamble();
 
-    if (jcp_.with_eltwise) postops_injector_->prepare_table();
-}
+    for (auto& inj : jit_eltwise_injectors_)
+        inj->prepare_table();
 
-bool mayiuse_jit_pp_kernel() noexcept {
-    return mayiuse(avx512_core);
+    if (utils::one_of(isa, avx2, sse41)) {
+        align(64);
+        L(l_table);
+        for (size_t i = 0; i < vlen; i++) dd(0xFFFFFFFF);
+        for (size_t i = 0; i < vlen; i++) dd(0x00000000);
+    }
 }
 
 pp_ker_t *jit_pp_ker_create(
         const convolution_pd_t *pd, const conv_gemm_conf_t &jcp) {
-    const auto is_bf16_dst_dt = pd->dst_md()->data_type == data_type::bf16;
-    return mayiuse_jit_pp_kernel() && !is_bf16_dst_dt
-            ? new jit_pp_ker_t(pd, jcp)
-            : nullptr;
-}
-
-bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_wrapper *dst_d) {
-    using namespace x64::injector;
-    static constexpr bool sum_at_pos_0_only = true;
-    static constexpr bool sum_requires_scale_one = false;
-    return mayiuse_jit_pp_kernel()
-            && dnnl::impl::cpu::x64::injector::post_ops_ok(
-                    {avx512_core, {binary, eltwise, sum}, post_ops, dst_d,
-                            sum_at_pos_0_only, sum_requires_scale_one});
+    if (mayiuse(avx512_common)) {
+        return new jit_pp_ker_t<avx512_common>(pd, jcp);
+    } else if (mayiuse(avx2)) {
+        return new jit_pp_ker_t<avx2>(pd, jcp);
+    } else if (mayiuse(sse41)) {
+        return new jit_pp_ker_t<sse41>(pd, jcp);
+    }
+    return nullptr;
 }
 
 } // namespace gemm_x8s8s32x_convolution_utils
diff --git a/src/cpu/x64/jit_gemm_x8s8s32x_convolution_utils.hpp b/src/cpu/x64/jit_gemm_x8s8s32x_convolution_utils.hpp
index da107135abc..03a4030b933 100644
--- a/src/cpu/x64/jit_gemm_x8s8s32x_convolution_utils.hpp
+++ b/src/cpu/x64/jit_gemm_x8s8s32x_convolution_utils.hpp
@@ -28,10 +28,8 @@ namespace gemm_x8s8s32x_convolution_utils {
 cpu::gemm_x8s8s32x_convolution_utils::pp_ker_t *jit_pp_ker_create(
         const convolution_pd_t *pd, const conv_gemm_conf_t &jcp);
 
-bool mayiuse_jit_pp_kernel() noexcept;
-bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_wrapper *dst_d);
 
 } // namespace gemm_x8s8s32x_convolution_utils
 } // namespace x64
 } // namespace cpu
 } // namespace impl
diff --git a/src/cpu/x64/jit_generator.hpp b/src/cpu/x64/jit_generator.hpp
index c31c22b13a2..18988af1844 100644
--- a/src/cpu/x64/jit_generator.hpp
+++ b/src/cpu/x64/jit_generator.hpp
@@ -156,6 +156,7 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible {
         _cmp_nlt_us = 5u,
         _cmp_nle_us = 6u,
 
+        _op_near = 0u,
         _op_floor = 1u,
         _op_mxcsr = 4u,
     };
@@ -420,6 +421,14 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible {
         else
             movdqu(x, addr);
     }
+
+    void uni_vmovdqu(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2) {
+        if (is_valid_isa(avx))
+            vmovdqu(x1, x2);
+        else
+            movdqu(x1, x2);
+    }
+
     void uni_vmovdqu(const Xbyak::Ymm &x, const Xbyak::Address &addr) {
         vmovdqu(x, addr);
     }
@@ -508,6 +517,7 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible {
             pshufd(x, x, 0x0);
         }
     }
+
     void uni_vpbroadcastd(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
         if (is_valid_isa(avx2)) {
             vpbroadcastd(x, op);
@@ -544,7 +554,10 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible {
     }
 
     void uni_vrcpss(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
-        rcpss(x, op);
+        if (is_valid_isa(avx))
+            vrcpss(x, x, op);
+        else
+            rcpss(x, op);
     }
     void uni_vrcpss(const Xbyak::Ymm &x1, const Xbyak::Xmm &x2) {
         Xbyak::Xmm x1_(x1.getIdx());
@@ -557,7 +570,10 @@ class jit_generator : public
Xbyak::CodeGenerator, public c_compatible { } void uni_vrcpps(const Xbyak::Xmm &x, const Xbyak::Operand &op) { - rcpps(x, op); + if (is_valid_isa(avx)) + vrcpps(x, op); + else + rcpps(x, op); } void uni_vrcpps(const Xbyak::Ymm &x, const Xbyak::Operand &op) { vrcpps(x, op); @@ -592,9 +608,14 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible { void uni_vdivps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, const Xbyak::Operand &op2, const Xbyak::Xmm &buf) { - movups(buf, op1); - divps(buf, op2); - if (x.getIdx() != buf.getIdx()) { movups(x, buf); } + if (is_valid_isa(avx)) { + vdivps(buf, op1, op2); + if (x.getIdx() != buf.getIdx()) { vmovups(x, buf); } + } else { + movups(buf, op1); + divps(buf, op2); + if (x.getIdx() != buf.getIdx()) { movups(x, buf); } + } } void uni_vdivps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, @@ -607,7 +628,7 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible { if (is_valid_isa(avx)) vaddps(x, op1, op2); else { - assert(x.getIdx() == op1.getIdx()); + if (!x.isEqualIfNotInherited(op1)) movups(x, op1); addps(x, op2); } } @@ -651,8 +672,12 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible { void uni_vpsignd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) { - assert(x1.getIdx() == x2.getIdx()); - psignd(x1, op); + if (is_valid_isa(avx)) + vpsignd(x1, x2, op); + else { + assert(x1.getIdx() == x2.getIdx()); + psignd(x1, op); + } } void uni_vpsignd(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) { @@ -661,8 +686,12 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible { void uni_vpsubd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) { - assert(x1.getIdx() == x2.getIdx()); - psubd(x1, op); + if (is_valid_isa(avx)) + vpsubd(x1, x2, op); + else { + assert(x1.getIdx() == x2.getIdx()); + psubd(x1, op); + } } void uni_vpsubd(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) { @@ -671,8 +700,12 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible { void uni_vpsubb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) { - assert(x1.getIdx() == x2.getIdx()); - psubb(x1, op); + if (is_valid_isa(avx)) + vpsubb(x1, x2, op); + else { + assert(x1.getIdx() == x2.getIdx()); + psubb(x1, op); + } } void uni_vpsubb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) { @@ -681,8 +714,13 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible { void uni_vsubss(const Xbyak::Xmm &x, const Xbyak::Operand &op1, const Xbyak::Operand &op2) { - assert(x.isEqualIfNotInherited(op1)); - subps(x, op2); + if (is_valid_isa(avx)) + // previously there was "subps(x, op2)" for some reason + vsubss(x, op1, op2); + else { + assert(x.isEqualIfNotInherited(op1)); + subss(x, op2); + } } void uni_vsubss(const Xbyak::Ymm &x, const Xbyak::Operand &op1, const Xbyak::Operand &op2) { @@ -694,7 +732,7 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible { if (is_valid_isa(avx)) vsubps(x, op1, op2); else { - assert(x.isEqualIfNotInherited(op1)); + if (!x.isEqualIfNotInherited(op1)) movups(x, op1); subps(x, op2); } } @@ -705,9 +743,14 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible { void uni_vsubps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, const Xbyak::Operand &op2, const Xbyak::Xmm &buf) { - movups(buf, op1); - subps(buf, op2); - if (x.getIdx() != buf.getIdx()) { movups(x, buf); } + if (is_valid_isa(avx)) { + 
vsubps(buf, op1, op2);
+            if (x.getIdx() != buf.getIdx()) { vmovups(x, buf); }
+        } else {
+            movups(buf, op1);
+            subps(buf, op2);
+            if (x.getIdx() != buf.getIdx()) { movups(x, buf); }
+        }
     }
 
     void uni_vsubps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
@@ -763,11 +806,19 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible {
 
     void uni_vfmadd132ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
             const Xbyak::Operand &op) {
-        // Note: x1 gets overriden by x1*op
-        // This is incorrect if x1 == x2
-        assert(x1.getIdx() != x2.getIdx());
-        mulps(x1, op);
-        addps(x1, x2);
+        if (is_valid_isa(avx2))
+            vfmadd132ps(x1, x2, op);
+        else if (is_valid_isa(avx)) {
+            assert(x1.getIdx() != x2.getIdx());
+            vmulps(x1, x1, op);
+            vaddps(x1, x1, x2);
+        } else {
+            // Note: x1 gets overriden by x1*op
+            // This is incorrect if x1 == x2
+            assert(x1.getIdx() != x2.getIdx());
+            mulps(x1, op);
+            addps(x1, x2);
+        }
     }
     void uni_vfmadd132ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
             const Xbyak::Operand &op) {
@@ -784,12 +835,21 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible {
 
     void uni_vfmadd213ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
             const Xbyak::Operand &op) {
-        // Note: x1 gets overriden by x1*x2
-        // This is incorrect if x1 == op
-        assert(!x1.isEqualIfNotInherited(op));
-        mulps(x1, x2);
-        addps(x1, op);
+        if (is_valid_isa(avx2))
+            vfmadd213ps(x1, x2, op);
+        else if (is_valid_isa(avx)) {
+            assert(!x1.isEqualIfNotInherited(op));
+            vmulps(x1, x1, x2);
+            vaddps(x1, x1, op);
+        } else {
+            // Note: x1 gets overriden by x1*x2
+            // This is incorrect if x1 == op
+            assert(!x1.isEqualIfNotInherited(op));
+            mulps(x1, x2);
+            addps(x1, op);
+        }
     }
+
     void uni_vfmadd213ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
             const Xbyak::Operand &op) {
         if (is_valid_isa(avx2))
@@ -805,11 +865,19 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible {
 
     void uni_vfmadd213ss(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
             const Xbyak::Operand &op) {
-        // Note: x1 gets overriden by x1*x2
-        // This is incorrect if x1 == op
-        assert(!x1.isEqualIfNotInherited(op));
-        mulss(x1, x2);
-        addss(x1, op);
+        if (is_valid_isa(avx2))
+            vfmadd213ss(x1, x2, op);
+        else if (is_valid_isa(avx)) {
+            assert(!x1.isEqualIfNotInherited(op));
+            vmulss(x1, x1, x2);
+            vaddss(x1, x1, op);
+        } else {
+            // Note: x1 gets overriden by x1*x2
+            // This is incorrect if x1 == op
+            assert(!x1.isEqualIfNotInherited(op));
+            mulss(x1, x2);
+            addss(x1, op);
+        }
     }
     void uni_vfmadd213ss(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
             const Xbyak::Operand &op) {
@@ -826,11 +894,19 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible {
 
     void uni_vfmadd231ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
             const Xbyak::Operand &op) {
-        // Note: x2 gets overriden by x2*op
-        // This is incorrect if x1 == x2
-        assert(x1.getIdx() != x2.getIdx());
-        mulps(x2, op);
-        addps(x1, x2);
+        if (is_valid_isa(avx2)) {
+            vfmadd231ps(x1, x2, op);
+        } else if (is_valid_isa(avx)) {
+            assert(x1.getIdx() != x2.getIdx());
+            vmulps(x2, x2, op);
+            vaddps(x1, x1, x2);
+        } else {
+            // Note: x2 gets overriden by x2*op
+            // This is incorrect if x1 == x2
+            assert(x1.getIdx() != x2.getIdx());
+            mulps(x2, op);
+            addps(x1, x2);
+        }
     }
     void uni_vfmadd231ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
             const Xbyak::Operand &op) {
@@ -932,12 +1008,21 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible {
 
     void uni_vfmsub213ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
             const Xbyak::Operand &op) {
-        // Note: x1 gets overriden by x1*x2
-        // This
is incorrect if x1 == op - assert(!x1.isEqualIfNotInherited(op)); - mulps(x1, x2); - subps(x1, op); + if (is_valid_isa(avx2)) { + vfmsub213ps(x1, x2, op); + } else if (is_valid_isa(avx)) { + assert(!x1.isEqualIfNotInherited(op)); + vmulps(x1, x1, x2); + vsubps(x1, x1, op); + } else { + // Note: x1 gets overriden by x1*x2 + // This is incorrect if x1 == op + assert(!x1.isEqualIfNotInherited(op)); + mulps(x1, x2); + subps(x1, op); + } } + void uni_vfmsub213ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) { if (is_valid_isa(avx2)) @@ -952,7 +1037,11 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible { } void uni_vsqrtps(const Xbyak::Xmm &x, const Xbyak::Operand &op) { - sqrtps(x, op); + if (is_valid_isa(avx)) { + vsqrtps(x, op); + } else { + sqrtps(x, op); + } } void uni_vsqrtps(const Xbyak::Ymm &x, const Xbyak::Operand &op) { vsqrtps(x, op); @@ -1016,8 +1105,12 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible { void uni_vandps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) { - assert(x1.getIdx() == x2.getIdx()); - andps(x1, op); + if (is_valid_isa(avx)) + vandps(x1, x2, op); + else { + assert(x1.getIdx() == x2.getIdx()); + andps(x1, op); + } } void uni_vandps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) { @@ -1029,8 +1122,12 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible { void uni_vorps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) { - assert(x1.getIdx() == x2.getIdx()); - orps(x1, op); + if (is_valid_isa(avx)) + vorps(x1, x2, op); + else { + assert(x1.getIdx() == x2.getIdx()); + orps(x1, op); + } } void uni_vorps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) { @@ -1059,8 +1156,12 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible { void uni_vpslld( const Xbyak::Xmm &x, const Xbyak::Operand &op, const int imm) { - assert(x.isEqualIfNotInherited(op)); - pslld(x, imm); + if (is_valid_isa(avx)) + vpslld(x, op, imm); + else { + assert(x.isEqualIfNotInherited(op)); + pslld(x, imm); + } } void uni_vpslld( const Xbyak::Ymm &x, const Xbyak::Operand &op, const int imm) { @@ -1069,8 +1170,12 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible { void uni_vpsrld( const Xbyak::Xmm &x, const Xbyak::Operand &op, const int imm) { - if (!x.isEqualIfNotInherited(op)) uni_vmovups(x, op); - psrld(x, imm); + if (is_valid_isa(avx)) + vpsrld(x, op, imm); + else { + if (!x.isEqualIfNotInherited(op)) uni_vmovups(x, op); + psrld(x, imm); + } } void uni_vpsrld( const Xbyak::Ymm &x, const Xbyak::Operand &op, const int imm) { @@ -1152,16 +1257,33 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible { void uni_vcmpps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op, int cmp_predicate) { - if (x1.getIdx() != x2.getIdx()) uni_vmovups(x1, x2); - cmpps(x1, op, cmp_predicate); + if (is_valid_isa(avx)) + vcmpps(x1, x2 ,op, cmp_predicate); + else { + if (x1.getIdx() != x2.getIdx()) uni_vmovups(x1, x2); + cmpps(x1, op, cmp_predicate); + } } void uni_vcmpps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op, int cmp_predicate) { vcmpps(x1, x2, op, cmp_predicate); } + void uni_cmpneqps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + if (is_valid_isa(avx)) + vcmpneqps(x1, x2, op); + else { + if (x1.getIdx() != x2.getIdx()) uni_vmovups(x1, x2); + cmpneqps(x1, op); + } + } + void uni_vtestps(const Xbyak::Xmm 
&x1, const Xbyak::Operand &op) { - ptest(x1, op); + if (is_valid_isa(avx)) + vptest(x1, op); + else + ptest(x1, op); } void uni_vtestps(const Xbyak::Ymm &x1, const Xbyak::Operand &op) { @@ -1171,9 +1293,13 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible { void uni_vblendvps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op, const Xbyak::Xmm &msk) { - assert(x1.getIdx() == x2.getIdx()); - assert(msk.getIdx() == 0); - blendvps(x1, op); + if (is_valid_isa(avx)) + vblendvps(x1, x2, op, msk); + else { + assert(x1.getIdx() == x2.getIdx()); + assert(msk.getIdx() == 0); + blendvps(x1, op); + } } void uni_vblendvps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op, const Xbyak::Ymm &msk) { @@ -1194,7 +1320,10 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible { void uni_vroundps( const Xbyak::Xmm &x, const Xbyak::Operand &op, const int imm) { - roundps(x, op, imm); + if (is_valid_isa(avx)) + vroundps(x, op, imm); + else + roundps(x, op, imm); } void uni_vroundps( const Xbyak::Ymm &x, const Xbyak::Operand &op, const int imm) { @@ -1278,6 +1407,18 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible { else movq(addr, x); } + void uni_vmovq(const Xbyak::Reg64 &r, const Xbyak::Xmm &x) { + if (is_valid_isa(avx)) + vmovq(r, x); + else + movq(r, x); + } + void uni_vmovq(const Xbyak::Xmm &x, const Xbyak::Address &addr) { + if (is_valid_isa(avx)) + vmovq(x, addr); + else + movq(x, addr); + } void uni_vpackssdw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) { @@ -1463,6 +1604,147 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible { } } + void uni_vcmpgtps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + if (is_valid_isa(avx)) + vcmpps(x1, x2, op, _cmp_nle_us); + else { + assert(x1.getIdx() == x2.getIdx()); + cmpps(x1, op, _cmp_nle_us); + } + } + + void uni_vcmpgtps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op) { + vcmpgtps(x1, x2, op); + } + + void uni_vpmovzxwd(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + if (is_valid_isa(avx)) + vpmovzxwd(x, op); + else + pmovzxwd(x, op); + } + void uni_vpmovzxwd(const Xbyak::Ymm &x, const Xbyak::Operand &op) { + vpmovzxwd(x, op); + } + + void uni_vpmovsxwd(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + if (is_valid_isa(avx)) + vpmovsxwd(x, op); + else + pmovsxwd(x, op); + } + void uni_vpmovsxwd(const Xbyak::Ymm &x, const Xbyak::Operand &op) { + vpmovsxwd(x, op); + } + + void uni_vpmovsxdq(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + if (is_valid_isa(avx)) + vpmovsxdq(x, op); + else + pmovsxdq(x, op); + } + + void uni_vpackusdw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) { + if (is_valid_isa(avx)) + vpackusdw(x1, x2, op); + else { + assert(x1.getIdx() == x2.getIdx()); + packusdw(x1, op); + } + } + void uni_vpackusdw(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) { + vpackusdw(x1, x2, op); + } + + void uni_vpcmpeqd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + if (is_valid_isa(avx)) + vpcmpeqd(x1, x2, op); + else { + assert(x1.getIdx() == x2.getIdx()); + pcmpeqd(x1, op); + } + } + + void uni_vpcmpeqd(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op) { + vpcmpeqd(x1, x2, op); + } + + void uni_vpminsd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) { + if (is_valid_isa(avx)) + vpminsd(x1, x2, op); + else { + assert(x1.getIdx() == 
x2.getIdx()); + pminsd(x1, op); + } + } + void uni_vpminsd(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) { + vpminsd(x1, x2, op); + } + + void uni_vpand(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + if (is_valid_isa(avx)) + vpand(x1, x2, op); + else { + assert(x1.getIdx() == x2.getIdx()); + pand(x1, op); + } + } + void uni_vpand(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op) { + vpand(x1, x2, op); + } + + void uni_vpshufb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + if (is_valid_isa(avx)) + vpshufb(x1, x2, op); + else { + assert(x1.getIdx() == x2.getIdx()); + pshufb(x1, op); + } + } + + void uni_vpshufb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op) { + vpshufb(x1, x2, op); + } + + void uni_vpslldq(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::uint8 &op) { + if (is_valid_isa(avx)) + vpslldq(x1, x2, op); + else { + assert(x1.getIdx() == x2.getIdx()); + pslldq(x1, op); + } + } + + void uni_vpslldq(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::uint8 &op) { + vpslldq(x1, x2, op); + } + + + void uni_vmovshdup(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + if (is_valid_isa(avx)) + vmovshdup(x, op); + else + movshdup(x, op); + } + + void uni_vmovhlps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Xmm &x3) { + if (is_valid_isa(avx)) + vmovhlps(x1, x2, x3); + else { + assert(x1.getIdx() == x2.getIdx()); + movhlps(x1, x3); + } + } + void mul_by_const( const Xbyak::Reg &out, const Xbyak::Reg64 &tmp, int value) { // Generates a shift + add sequence for multiplicating contents of the diff --git a/src/cpu/x64/jit_primitive_conf.hpp b/src/cpu/x64/jit_primitive_conf.hpp index 2f0d5282682..7a79cfa5c90 100644 --- a/src/cpu/x64/jit_primitive_conf.hpp +++ b/src/cpu/x64/jit_primitive_conf.hpp @@ -109,6 +109,8 @@ struct jit_conv_conf_t { bool with_sum; bool with_eltwise; bool with_binary; + bool with_depthwise; + bool with_quantization; data_type_t sum_dt; @@ -262,6 +264,16 @@ struct jit_conv_conf_t { int max_width; bool transform_to_vnni; + + bool with_input_zp; + bool with_weights_zp; + + int oh_block; + int oh_block_step; + int nb_ow_blocking; + + int dw_conv_oh, dw_conv_ow; + data_type_t dw_conv_dst_dt; }; // calculates filter size taking into account dilation @@ -275,6 +287,12 @@ inline int calculate_end_padding(int start_padding, int dst_size, int src_size, - (src_size + start_padding); } +inline int calculate_end_padding_log(int start_padding, int dst_size, int src_size, + int spatial_stride, int dilated_filter_size, int end_pad) { + return (dst_size - 1) * spatial_stride + dilated_filter_size + - (src_size + start_padding + end_pad); +} + inline status_t init_tag(format_tag_t &tag, const memory_desc_wrapper &mdw, const format_tag_t &tag_value) { if (mdw.format_kind() == format_kind::any) return status::unimplemented; @@ -488,6 +506,18 @@ struct jit_conv_call_s { int oc_flag; size_t last_ic_block; size_t last_oc_block; + + size_t oc_off; + size_t ic_off; + size_t oc_off_prf; + size_t oh_blocks; + + const void *input_zp; + + size_t oc_work; + const void *src_row0; /* hack, non-const for backward_data */ + const void *src_row1; /* hack, non-const for backward_data */ + const void *src_row2; /* hack, non-const for backward_data */ }; struct jit_deconv_call_s { @@ -520,6 +550,7 @@ struct jit_deconv_call_s { size_t kh_padding; size_t kd_padding; size_t oc_blocks; + size_t oc_off; }; struct jit_dw_conv_call_s { @@ -567,6 +598,8 @@ struct 
jit_1x1_conv_conf_t { bool with_sum; bool with_eltwise; bool with_binary; + bool with_depthwise; + bool with_quantization; bool with_dw_conv; post_ops_t post_ops; @@ -603,6 +636,7 @@ struct jit_1x1_conv_conf_t { int tr_is; int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b; int is_oc_scale; + data_type_t src_dt; data_type_t bia_dt; data_type_t dst_dt; data_type_t sum_dt; @@ -615,6 +649,12 @@ struct jit_1x1_conv_conf_t { cpu_isa_t isa; bool uses_permw_transposition; + + bool with_input_zp; + bool with_weights_zp; + + int dw_conv_oh, dw_conv_ow; + data_type_t dw_conv_dst_dt; }; struct jit_1x1_conv_call_s { @@ -648,6 +688,8 @@ struct jit_1x1_conv_call_s { size_t output_stride; // used in backward_weights only size_t first_last_flag; + + size_t oc_off; }; struct jit_pool_conf_t { @@ -656,7 +698,7 @@ struct jit_pool_conf_t { int id, ih, iw, od, oh, ow; int stride_d, stride_h, stride_w; int kd, kh, kw; - int f_pad, t_pad, l_pad; + int f_pad, t_pad, l_pad, b_pad, r_pad, back_pad; alg_kind_t alg; bool is_training; bool pad_w_is_null; @@ -687,6 +729,8 @@ struct jit_pool_conf_t { bool with_postops; bool with_eltwise; bool with_binary; + bool with_depthwise; + bool with_quantization; int nthr; }; @@ -1014,6 +1058,26 @@ struct jit_reduction_call_s { void *dst = nullptr; }; +/* softmax */ +struct jit_softmax_conf_t { + size_t outer_size; + size_t channels; + size_t inner_size; + size_t ur_channel; + size_t ur_inner; + size_t outer_block; + size_t dt_size; + data_type_t dt; +}; + +struct jit_softmax_call_s { + const uint8_t* src; + uint8_t* dst; + + size_t channels; + size_t work; +}; + } // namespace x64 } // namespace cpu } // namespace impl diff --git a/src/cpu/x64/jit_sse41_1x1_conv_kernel_f32.cpp b/src/cpu/x64/jit_sse41_1x1_conv_kernel_f32.cpp index c10b214838c..a6d45a36ab6 100644 --- a/src/cpu/x64/jit_sse41_1x1_conv_kernel_f32.cpp +++ b/src/cpu/x64/jit_sse41_1x1_conv_kernel_f32.cpp @@ -44,7 +44,7 @@ jit_sse41_1x1_conv_kernel_f32::jit_sse41_1x1_conv_kernel_f32( : jit_generator(nullptr, MAX_CODE_SIZE, true, sse41) , jcp(ajcp) , attr_(attr) { - if (jcp.with_eltwise || jcp.with_binary) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { static constexpr bool preserve_gpr = true; static constexpr bool preserve_vmm = false; static constexpr size_t helper_vmm_idx = 15; @@ -58,9 +58,12 @@ jit_sse41_1x1_conv_kernel_f32::jit_sse41_1x1_conv_kernel_f32( use_exact_tail_scalar_bcast}; const binary_injector::static_params_t static_params { this->param1, rhs_arg_static_params}; + quantization_injector::static_params_t quantization_static_params + {xmm_d_weights.getIdx(), xmm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; + postops_injector_ = utils::make_unique< injector::jit_uni_postops_injector_t>( - this, jcp.post_ops, static_params); + this, jcp.post_ops, static_params, quantization_static_params); } } @@ -130,6 +133,15 @@ static void iterate(const int load_loop_blk, const int ur, const F &f) { } void jit_sse41_1x1_conv_kernel_f32::apply_postops( const int load_loop_blk, const int ur) { + std::map vmm_idx_off; + iterate(load_loop_blk, ur, + [&](const int i, const int j, const int n) { + vmm_idx_off.insert({reg_accum_idx(load_loop_blk, i, j, n), (2 * i + n) * jcp.load_block / 2 * sizeof(float)}); + }); + depthwise_injector::dynamic_params_t ddp {xmm_d_weights.getIdx(), xmm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + reg_oc_off, vmm_idx_off, this->rsp, base_post_ops_data_offset}; + quantization_injector::dynamic_params_t qdp {reg_oc_off, vmm_idx_off, jcp.dst_dt, this->rsp, 
base_post_ops_data_offset}; + injector_utils::vmm_index_set_t vmm_idxs; if (jcp.with_binary) { binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; @@ -153,12 +165,12 @@ void jit_sse41_1x1_conv_kernel_f32::apply_postops( mov(abi_param1, ptr[rsp + reg_abi_param1_backup + reg_guard_stack_occupied]); - postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params); + postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params, ddp, qdp); } else { iterate(load_loop_blk, ur, [&](const int i, const int j, const int n) { vmm_idxs.emplace(reg_accum_idx(load_loop_blk, i, j, n)); }); - postops_injector_->compute_vector_range(vmm_idxs); + postops_injector_->compute_vector_range(vmm_idxs, binary_injector::rhs_arg_dynamic_params_t(), ddp, qdp); } } @@ -281,7 +293,7 @@ void jit_sse41_1x1_conv_kernel_f32::generate_reduce_loop( L(store_noadd); - if (jcp.with_eltwise || jcp.with_binary) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { assert(ur * load_loop_blk < 14); Label store_nopostops; @@ -425,7 +437,12 @@ void jit_sse41_1x1_conv_kernel_f32::generate_diff_bias_loop(int load_loop_blk) { void jit_sse41_1x1_conv_kernel_f32::generate() { preamble(); + if (postops_injector_) + postops_injector_->push_post_ops_data_on_stack(this->param1, GET_OFF(post_ops_binary_rhs_arg_vec), reg_bcast_data, reg_load_data); + sub(rsp, stack_space_needed); + base_post_ops_data_offset += stack_space_needed; + if (jcp.with_binary) { // backup abi_param1 for usage in post_ops processing mov(ptr[rsp + reg_abi_param1_backup], abi_param1); @@ -453,6 +470,7 @@ void jit_sse41_1x1_conv_kernel_f32::generate() { mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]); if (jcp.prop_kind == backward_weights) mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]); + mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]); auto generate_load_loop_body = [=](int load_loop_blk) { generate_bcast_loop(load_loop_blk); @@ -484,6 +502,7 @@ void jit_sse41_1x1_conv_kernel_f32::generate() { default: assert(!"invalid prop_kind"); } sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); + add(reg_oc_off, load_loop_blk * jcp.oc_block * sizeof(float)); }; Label load_loop_blk_8; @@ -532,6 +551,10 @@ void jit_sse41_1x1_conv_kernel_f32::generate() { L(load_loop_blk_end); add(rsp, stack_space_needed); + base_post_ops_data_offset -= stack_space_needed; + + if (postops_injector_) + postops_injector_->reset_stack_pointer(); postamble(); @@ -557,6 +580,7 @@ status_t jit_sse41_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp, jcp.mb = src_d.dims()[0]; jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; jcp.ic = src_d.dims()[1] / jcp.ngroups; jcp.ih = (ndims == 3) ? 
1 : src_d.dims()[2]; @@ -594,6 +618,8 @@ status_t jit_sse41_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp, const int binary_ind = post_ops.find(primitive_kind::binary, 0, dw_conv_ind); jcp.with_binary = binary_ind != -1; + jcp.with_depthwise = post_ops.find(primitive_kind::depthwise, 0, dw_conv_ind) != -1; + jcp.with_quantization = post_ops.find(primitive_kind::quantization, 0, dw_conv_ind) != -1; if (dw_conv_ind >= 0) { // dw_conv and post_ops after it are handled externally, so skip them @@ -607,15 +633,15 @@ status_t jit_sse41_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp, static constexpr bool sum_at_pos_0_only = true; static constexpr bool sum_requires_scale_one = true; static constexpr bool sum_requires_zp_zero = true; - const bool post_ops_ok_ = post_ops_ok({sse41, {eltwise, binary, sum}, + const bool post_ops_ok_ = post_ops_ok({sse41, {eltwise, binary, sum, depthwise, quantization}, jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one, sum_requires_zp_zero}); if (!post_ops_ok_) return status::unimplemented; const auto dat_tag_nxc = utils::pick(ndims - 3, nwc, nhwc); const auto dat_tag_blocked = utils::pick(ndims - 3, nCw8c, nChw8c); - jcp.src_tag = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_blocked); - jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_blocked); + jcp.src_tag = src_d.mb_stride_relaxed_match(dat_tag_nxc, dat_tag_blocked); + jcp.dst_tag = dst_d.mb_stride_relaxed_match(dat_tag_nxc, dat_tag_blocked); const bool is_data_layout_nxc = utils::everyone_is(dat_tag_nxc, jcp.src_tag, jcp.dst_tag); const auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_blocked; @@ -634,6 +660,12 @@ status_t jit_sse41_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp, const int simd_w = 4; + bool ok_to_pad_channels = true && !is_data_layout_nxc && jcp.ngroups == 1; + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w*2); + jcp.ic = rnd_up(jcp.ic, simd_w*2); + } + jcp.ic_block = jcp.oc_block = simd_w * 2; args_ok = true && jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0 @@ -795,6 +827,15 @@ status_t jit_sse41_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp, return status::success; } +void jit_sse41_1x1_conv_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, + const jit_1x1_conv_conf_t &jcp) { + using namespace dnnl::impl::memory_tracking::names; + + if (jcp.prop_kind != backward_data && jcp.oc != jcp.oc_without_padding) + scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc); +} + } // namespace x64 } // namespace cpu } // namespace impl diff --git a/src/cpu/x64/jit_sse41_1x1_conv_kernel_f32.hpp b/src/cpu/x64/jit_sse41_1x1_conv_kernel_f32.hpp index 4f2632866a3..2ce5179e54e 100644 --- a/src/cpu/x64/jit_sse41_1x1_conv_kernel_f32.hpp +++ b/src/cpu/x64/jit_sse41_1x1_conv_kernel_f32.hpp @@ -17,6 +17,7 @@ #ifndef CPU_X64_JIT_SSE41_1X1_CONV_KERNEL_F32_HPP #define CPU_X64_JIT_SSE41_1X1_CONV_KERNEL_F32_HPP +#include "common/memory_tracking.hpp" #include "common/c_types_map.hpp" #include "common/memory.hpp" @@ -39,6 +40,9 @@ struct jit_sse41_1x1_conv_kernel_f32 : public jit_generator { const memory_desc_wrapper &dst_d, const primitive_attr_t &attr, int nthreads); + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_1x1_conv_conf_t &jcp); + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_1x1_conv_kernel_f32) jit_1x1_conv_conf_t jcp; @@ -54,12 +58,12 @@ struct jit_sse41_1x1_conv_kernel_f32 : public jit_generator { reg64_t reg_output_data = rbx; reg64_t aux_reg_bcast_data = 
rdx; reg64_t aux1_reg_bcast_data = abi_not_param1; - reg64_t aux_reg_load_data = abi_param1; reg64_t aux_reg_output_data = rbp; reg64_t reg_load_loop_work = r9; reg64_t reg_bcast_loop_work = r10; reg64_t reg_reduce_loop_work = r11; reg64_t load_loop_iter = r13; + reg64_t aux_reg_load_data = load_loop_iter; reg64_t imm_addr64 = load_loop_iter; reg64_t bcast_loop_iter = r14; reg64_t reduce_loop_iter = r15; @@ -79,6 +83,14 @@ struct jit_sse41_1x1_conv_kernel_f32 : public jit_generator { std::unique_ptr> postops_injector_; + reg64_t reg_oc_off = abi_param1; + reg64_t reg_d_weights = aux_reg_bcast_data; + reg64_t reg_d_bias = reduce_loop_iter; + int base_post_ops_data_offset = 0; + + Xbyak::Xmm xmm_d_weights = Xbyak::Xmm(14); + Xbyak::Xmm xmm_d_bias = Xbyak::Xmm(15); + void generate_bcast_loop(int load_loop_blk); void generate_reduce_loop(int load_loop_blk, int ur); void generate_diff_bias_loop(int load_loop_blk); diff --git a/src/cpu/x64/jit_sse41_1x1_convolution.cpp b/src/cpu/x64/jit_sse41_1x1_convolution.cpp index cfa93ed7a91..244199a68d8 100644 --- a/src/cpu/x64/jit_sse41_1x1_convolution.cpp +++ b/src/cpu/x64/jit_sse41_1x1_convolution.cpp @@ -32,6 +32,7 @@ namespace x64 { using namespace dnnl::impl::status; using namespace dnnl::impl::utils; +using namespace dnnl::impl::memory_tracking::names; void jit_sse41_1x1_convolution_fwd_t::execute_forward( const exec_ctx_t &ctx) const { @@ -51,11 +52,22 @@ void jit_sse41_1x1_convolution_fwd_t::execute_forward( pd()->jcp_.post_ops.entry_.size() + 1) : std::vector {}; + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + auto scratchpad = ctx.get_scratchpad_grantor(); + + if (pd()->wants_padded_bias()) { + auto padded_bias = scratchpad.get(key_conv_padded_bias); + utils::array_copy(padded_bias, bias, kernel_->jcp.oc_without_padding); + utils::array_set(padded_bias + kernel_->jcp.oc_without_padding, 0.f, + kernel_->jcp.oc - kernel_->jcp.oc_without_padding); + bias = padded_bias; + } + parallel(kernel_->jcp.nthr, [&](const int ithr, const int nthr) { execute_forward_thr(ithr, nthr, src, weights, bias, weights_dw, bias_dw, dst, scratchpad, post_ops_binary_rhs_arg_vec.data(), - post_ops_binary_rhs_arg_vec_dw.data()); + post_ops_binary_rhs_arg_vec_dw.data(), MB); }); if (pd()->wants_zero_pad_dst()) ctx.zero_pad_output(DNNL_ARG_DST); @@ -66,7 +78,7 @@ void jit_sse41_1x1_convolution_fwd_t::execute_forward_thr(const int ithr, const data_t *bias, const data_t *weights_dw, const data_t *bias_dw, data_t *dst, const memory_tracking::grantor_t &scratchpad, const void *post_ops_binary_rhs_arg_vec, - const void *post_ops_binary_rhs_arg_vec_dw) const { + const void *post_ops_binary_rhs_arg_vec_dw, int MB) const { const memory_desc_wrapper src_d(pd()->src_md()); const memory_desc_wrapper dst_d(pd()->dst_md()); @@ -115,7 +127,7 @@ void jit_sse41_1x1_convolution_fwd_t::execute_forward_thr(const int ithr, int bcast_end, int &oh, int &ow, int &ih, int &iw) { int osb {0}; - nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, nb_bcast); + nd_iterator_init(iwork, n, MB, g, jcp.ngroups, osb, nb_bcast); bcast_step = step( nb_bcast_blocking, nb_bcast - osb, nb_bcast_blocking_max); @@ -172,6 +184,7 @@ void jit_sse41_1x1_convolution_fwd_t::execute_forward_thr(const int ithr, par_conv.oc_l_off = _ocb * jcp.oc_block; par_conv.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec; par_conv.dst_orig = jcp.with_dw_conv ? 
pbuf : dst; + par_conv.oc_off = _ocb * jcp.oc_block * sizeof(float); (*kernel_)(&par_conv); }; @@ -248,6 +261,8 @@ void jit_sse41_1x1_convolution_fwd_t::execute_forward_thr(const int ithr, = post_ops_binary_rhs_arg_vec_dw; par_conv_dw.dst_orig = dst; + par_conv_dw.oc_off = ch * jcp_dw.ch_block * sizeof(float); + (*kernel_dw_)(&par_conv_dw); for (int i = 0; i < jcp_dw.kh; ++i) @@ -270,7 +285,7 @@ void jit_sse41_1x1_convolution_fwd_t::execute_forward_thr(const int ithr, addrs.resize(jcp_dw.kh); int bcast_start {0}, bcast_end {0}, ocb_start, ocb_end; - balance2D(nthr, ithr, jcp.mb * jcp.ngroups * jcp_dw.oh, bcast_start, + balance2D(nthr, ithr, MB * jcp.ngroups * jcp_dw.oh, bcast_start, bcast_end, nb_oc, ocb_start, ocb_end, 1); while (ocb_start < ocb_end) { @@ -281,7 +296,7 @@ void jit_sse41_1x1_convolution_fwd_t::execute_forward_thr(const int ithr, auto bcast_iter = bcast_start; while (bcast_iter < bcast_end) { int n, g, oh_dw; - nd_iterator_init(bcast_iter, n, jcp.mb, g, jcp.ngroups, oh_dw, + nd_iterator_init(bcast_iter, n, MB, g, jcp.ngroups, oh_dw, jcp_dw.oh); if (oh_dw == 0) oh_1x1 = 0; // Reset over mb boundary const int oh_1x1_range = oh_dw * jcp_dw.stride_h - jcp_dw.t_pad; @@ -310,7 +325,7 @@ void jit_sse41_1x1_convolution_fwd_t::execute_forward_thr(const int ithr, if (jcp.with_dw_conv) { conv_dw(); } else { - const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + const int work_amount = MB * jcp.ngroups * jcp.nb_bcast; int start {0}, end {0}; balance211(work_amount, nthr, ithr, start, end); conv_1x1(start, end, 0, jcp.nb_load); diff --git a/src/cpu/x64/jit_sse41_1x1_convolution.hpp b/src/cpu/x64/jit_sse41_1x1_convolution.hpp index 6c7e1ce8bd5..09024be03f7 100644 --- a/src/cpu/x64/jit_sse41_1x1_convolution.hpp +++ b/src/cpu/x64/jit_sse41_1x1_convolution.hpp @@ -65,7 +65,14 @@ struct jit_sse41_1x1_convolution_fwd_t : public primitive_t { CHECK(jit_sse41_1x1_conv_kernel_f32::init_conf(jcp_, *desc(), *src_md(), *weights_md(), *dst_md(), *attr(), dnnl_get_max_threads())); - if (jcp_.with_dw_conv) CHECK(depthwise_po_init(engine)); + if (jcp_.with_dw_conv) { + // todo: [antonvor] enable when new behavior of dw convolution fusing from oneDNN 1.6 will be supported + return status::unimplemented; + CHECK(depthwise_po_init(engine)); + } + + auto scratchpad = scratchpad_registry().registrar(); + jit_sse41_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_); return status::success; } @@ -225,7 +232,7 @@ struct jit_sse41_1x1_convolution_fwd_t : public primitive_t { if (pd()->jcp_.with_dw_conv) { CHECK(safe_ptr_assign(kernel_dw_, new dw_conv_kernel_t( - pd()->dw_conv_pd_->jcp_, *pd()->dst_md(0)))); + pd()->dw_conv_pd_->jcp_, *pd()->dst_md(0), *pd()->dw_conv_pd_->attr()))); return kernel_dw_->create_kernel(); } @@ -244,7 +251,7 @@ struct jit_sse41_1x1_convolution_fwd_t : public primitive_t { const data_t *bias_dw, data_t *dst, const memory_tracking::grantor_t &scratchpad, const void *post_ops_binary_rhs_arg_vec, - const void *post_ops_binary_rhs_arg_vec_dw) const; + const void *post_ops_binary_rhs_arg_vec_dw, int MB) const; const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } std::unique_ptr kernel_; using dw_conv_kernel_t = jit_uni_dw_conv_fwd_kernel_f32; diff --git a/src/cpu/x64/jit_sse41_conv_kernel_f32.cpp b/src/cpu/x64/jit_sse41_conv_kernel_f32.cpp index 4b374c1113a..2d22f80be6a 100644 --- a/src/cpu/x64/jit_sse41_conv_kernel_f32.cpp +++ b/src/cpu/x64/jit_sse41_conv_kernel_f32.cpp @@ -34,6 +34,7 @@ namespace x64 { using namespace dnnl::impl::format_tag; using 
namespace dnnl::impl::prop_kind; using namespace dnnl::impl::utils; +using namespace dnnl::impl::memory_tracking::names; using namespace Xbyak; @@ -41,7 +42,7 @@ jit_sse41_conv_fwd_kernel_f32::jit_sse41_conv_fwd_kernel_f32( const jit_conv_conf_t &ajcp, const primitive_attr_t &attr, const memory_desc_t &dst_md) : jit_generator(nullptr, MAX_CODE_SIZE, sse41), jcp(ajcp), attr_(attr) { - if (jcp.with_eltwise || jcp.with_binary) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { static constexpr bool preserve_gpr = true; static constexpr bool preserve_vmm = false; static constexpr size_t helper_vmm_idx = 15; @@ -55,10 +56,12 @@ jit_sse41_conv_fwd_kernel_f32::jit_sse41_conv_fwd_kernel_f32( use_exact_tail_scalar_bcast}; const binary_injector::static_params_t static_params { this->param1, rhs_arg_static_params}; + quantization_injector::static_params_t quantization_static_params + {xmm_d_weights.getIdx(), xmm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; postops_injector_ = utils::make_unique< injector::jit_uni_postops_injector_t>( - this, jcp.post_ops, static_params); + this, jcp.post_ops, static_params, quantization_static_params); } } @@ -156,6 +159,14 @@ static void iterate(const int oc_blocks, const int ur_w, const F &f) { } void jit_sse41_conv_fwd_kernel_f32::apply_postops( const int oc_blocks, const int ur_w) { + std::map vmm_idx_off; + iterate(oc_blocks, ur_w, [&](const bool, const int i, const int j) { + vmm_idx_off.insert({get_xmm_idx(ur_w, i, j), i * jcp.oc_block * sizeof(float)}); + }); + depthwise_injector::dynamic_params_t ddp {xmm_d_weights.getIdx(), xmm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + reg_oc_off, vmm_idx_off, this->rsp}; + quantization_injector::dynamic_params_t qdp {reg_oc_off, vmm_idx_off, jcp.dst_dt, this->rsp}; + injector_utils::vmm_index_set_t vmm_idxs; if (jcp.with_binary) { binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; @@ -173,12 +184,12 @@ void jit_sse41_conv_fwd_kernel_f32::apply_postops( rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx); }); - postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params); + postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params, ddp, qdp); } else { iterate(oc_blocks, ur_w, [&](const bool, const int i, const int j) { vmm_idxs.emplace(get_xmm_idx(ur_w, i, j)); }); - postops_injector_->compute_vector_range(vmm_idxs); + postops_injector_->compute_vector_range(vmm_idxs, binary_injector::rhs_arg_dynamic_params_t(), ddp, qdp); } } @@ -264,7 +275,7 @@ void jit_sse41_conv_fwd_kernel_f32::width_blk_step( L(skip_kh_loop); - if (jcp.with_eltwise || jcp.with_binary) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { Label regular_store; test(reg_ci_flag, FLAG_IC_LAST); je(regular_store, T_NEAR); @@ -286,12 +297,15 @@ void jit_sse41_conv_fwd_kernel_f32::width_blk_step( add(aux_reg_kernel, sizeof(float) * 4); add(reg_output, sizeof(float) * 4); add(reg_bias, sizeof(float) * 4); + add(reg_oc_off, sizeof(float) * 4); + inc(simd_iter); cmp(simd_iter, 2); jl(init_simd_iter_loop, T_NEAR); sub(reg_output, sizeof(float) * 8); sub(reg_bias, sizeof(float) * 8); + sub(reg_oc_off, sizeof(float) * 8); } inline void jit_sse41_conv_fwd_kernel_f32::solve_common(int oc_blocks) { @@ -346,6 +360,9 @@ inline void jit_sse41_conv_fwd_kernel_f32::solve_common(int oc_blocks) { void jit_sse41_conv_fwd_kernel_f32::generate() { this->preamble(); + if (postops_injector_) + postops_injector_->push_post_ops_data_on_stack(this->param1, 
GET_OFF(post_ops_binary_rhs_arg_vec), reg_input, reg_output); + mov(reg_input, ptr[this->param1 + GET_OFF(src)]); mov(reg_output, ptr[this->param1 + GET_OFF(dst)]); mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); @@ -353,6 +370,7 @@ void jit_sse41_conv_fwd_kernel_f32::generate() { mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]); mov(reg_oc_blocks, ptr[this->param1 + GET_OFF(oc_blocks)]); + mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]); int nb_oc_tail = jcp.nb_oc % jcp.nb_oc_blocking; Label tail, exit; @@ -372,6 +390,9 @@ void jit_sse41_conv_fwd_kernel_f32::generate() { L(exit); + if (postops_injector_) + postops_injector_->reset_stack_pointer(); + this->postamble(); if (jcp.with_eltwise) postops_injector_->prepare_table(); @@ -395,6 +416,7 @@ status_t jit_sse41_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, jcp.mb = src_d.dims()[0]; jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; jcp.ic = src_d.dims()[1] / jcp.ngroups; jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2]; @@ -420,10 +442,10 @@ status_t jit_sse41_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw); jcp.b_pad = calculate_end_padding( jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh); - bool kernel_outside_src = false || ext_kw <= jcp.l_pad - || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad - || ext_kh <= jcp.b_pad; - if (kernel_outside_src) return status::unimplemented; +// bool kernel_outside_src = false || ext_kw <= jcp.l_pad +// || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad +// || ext_kh <= jcp.b_pad; +// if (kernel_outside_src) return status::unimplemented; const auto dat_tag_nxc = (ndims == 3 ? nwc : nhwc); const auto dat_tag_ncx = (ndims == 3 ? ncw : nchw); @@ -435,9 +457,9 @@ status_t jit_sse41_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, : pick(ndims - 3, Owi8o, Ohwi8o); jcp.src_tag - = src_d.matches_one_of_tag(dat_tag_ncx, dat_tag_nxc, dat_tag_nCx8c); + = src_d.mb_stride_relaxed_match(dat_tag_ncx, dat_tag_nxc, dat_tag_nCx8c); jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag_OIxio, wei_tag_Oxio); - jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c); + jcp.dst_tag = dst_d.mb_stride_relaxed_match(dat_tag_nxc, dat_tag_nCx8c); const bool is_data_layout_nxc = utils::everyone_is(dat_tag_nxc, jcp.src_tag, jcp.dst_tag); @@ -451,6 +473,8 @@ status_t jit_sse41_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, const int binary_ind = post_ops.find(primitive_kind::binary); jcp.with_binary = binary_ind != -1; + jcp.with_depthwise = post_ops.find(primitive_kind::depthwise) != -1; + jcp.with_quantization = post_ops.find(primitive_kind::quantization) != -1; jcp.post_ops = post_ops; @@ -458,12 +482,12 @@ status_t jit_sse41_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, static constexpr bool sum_at_pos_0_only = true; static constexpr bool sum_requires_scale_one = true; static constexpr bool sum_requires_zp_zero = true; - const bool post_ops_ok_ = post_ops_ok({sse41, {eltwise, binary, sum}, + const bool post_ops_ok_ = post_ops_ok({sse41, {eltwise, binary, sum, depthwise, quantization}, jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one, sum_requires_zp_zero}); if (!post_ops_ok_) return status::unimplemented; - const bool flat = jcp.ic == 3; + const bool flat = one_of(jcp.ic, 1, 2, 3); const bool mimo = !flat; bool args_ok = true @@ -483,7 +507,15 @@ status_t jit_sse41_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, && jcp.oc <= dst_d.padded_dims()[1]; 
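// The channel-padding block just below rounds jcp.oc (and, for mimo, jcp.ic)
// up to simd_w == 8, the two-SSE-vector block this kernel processes at once.
// A minimal sketch of that rounding, assuming rnd_up keeps its usual oneDNN
// round-up-to-a-multiple meaning:
//
//     inline int rnd_up(int v, int factor) {
//         return (v + factor - 1) / factor * factor; // rnd_up(13, 8) == 16
//     }
//
// A convolution with oc == 13 is then executed as oc == 16, which is why
// init_scratchpad books key_conv_padded_bias: the bias tail [13..16) has to
// exist and be zero-filled so the padded lanes contribute nothing.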
if (!args_ok) return status::unimplemented; + bool ok_to_pad_channels = true && !is_data_layout_nxc && jcp.ngroups == 1; + const int simd_w = 8; // 2 SSE vectors processing at once + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + if (mimo) { + jcp.ic = rnd_up(jcp.ic, simd_w); + } + } jcp.ur_h = 1; /* no code-unrolling by h so far */ jcp.ur_w = 3; @@ -540,6 +572,15 @@ status_t jit_sse41_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, return status::success; } +void jit_sse41_conv_fwd_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp) { + using namespace dnnl::impl::memory_tracking::names; + + if (jcp.prop_kind != backward_data && jcp.oc != jcp.oc_without_padding) + scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc); +} + } // namespace x64 } // namespace cpu } // namespace impl diff --git a/src/cpu/x64/jit_sse41_conv_kernel_f32.hpp b/src/cpu/x64/jit_sse41_conv_kernel_f32.hpp index d5db484b65b..f8d3d94842b 100644 --- a/src/cpu/x64/jit_sse41_conv_kernel_f32.hpp +++ b/src/cpu/x64/jit_sse41_conv_kernel_f32.hpp @@ -17,6 +17,7 @@ #ifndef CPU_X64_JIT_SSE41_CONV_KERNEL_F32_HPP #define CPU_X64_JIT_SSE41_CONV_KERNEL_F32_HPP +#include "common/memory_tracking.hpp" #include "common/c_types_map.hpp" #include "common/memory.hpp" @@ -39,6 +40,9 @@ struct jit_sse41_conv_fwd_kernel_f32 : public jit_generator { const memory_desc_wrapper &dst_d, const primitive_attr_t &attr, int nthreads); + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_conv_fwd_kernel_f32) jit_conv_conf_t jcp; const primitive_attr_t &attr_; @@ -66,6 +70,13 @@ struct jit_sse41_conv_fwd_kernel_f32 : public jit_generator { std::unique_ptr> postops_injector_; + reg64_t reg_d_weights = imm_addr64; + reg64_t reg_d_bias = ki_iter; + reg64_t reg_oc_off = abi_param1; + + Xbyak::Xmm xmm_d_weights = Xbyak::Xmm(14); + Xbyak::Xmm xmm_d_bias = Xbyak::Xmm(15); + inline void oh_step_unroll_kw( int ur_w, int pad_l, int pad_r, int oc_blocks); inline void oh_step_nopad(int ur_w, int pad_l, int pad_r, int oc_blocks); diff --git a/src/cpu/x64/jit_sse41_convolution.cpp b/src/cpu/x64/jit_sse41_convolution.cpp index 20c259c9ff9..eeca42f8da4 100644 --- a/src/cpu/x64/jit_sse41_convolution.cpp +++ b/src/cpu/x64/jit_sse41_convolution.cpp @@ -27,6 +27,7 @@ namespace x64 { using namespace dnnl::impl::status; using namespace dnnl::impl::utils; +using namespace dnnl::impl::memory_tracking::names; #define src_blk_off(f, n, c, h, w) \ (pd()->ndims() == 3) ? 
(f).blk_off(n, c, w) : (f).blk_off(n, c, h, w) @@ -47,19 +48,30 @@ void jit_sse41_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { const auto post_ops_binary_rhs_arg_vec = binary_injector::prepare_binary_args(jcp.post_ops, ctx); + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + const memory_desc_wrapper src_d(pd()->src_md()); const memory_desc_wrapper dst_d(pd()->dst_md()); const memory_desc_wrapper weights_d(pd()->weights_md(0)); const memory_desc_wrapper bias_d(pd()->weights_md(1)); int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking); - const size_t work_amount = jcp.mb * jcp.ngroups * ocb_work * jcp.oh; + const size_t work_amount = MB * jcp.ngroups * ocb_work * jcp.oh; const bool is_src_layout_nxc = one_of(jcp.src_tag, format_tag::nwc, format_tag::nhwc); const bool is_dst_layout_nxc = one_of(jcp.dst_tag, format_tag::nwc, format_tag::nhwc); + auto scratchpad = ctx.get_scratchpad_grantor(); + if (pd()->wants_padded_bias()) { + auto padded_bias = scratchpad.get(key_conv_padded_bias); + utils::array_copy(padded_bias, bias, kernel_->jcp.oc_without_padding); + utils::array_set(padded_bias + kernel_->jcp.oc_without_padding, 0.f, + kernel_->jcp.oc - kernel_->jcp.oc_without_padding); + bias = padded_bias; + } + parallel(jcp.nthr, [&](const int ithr, const int nthr) { assert(nthr == jcp.nthr); @@ -73,7 +85,7 @@ void jit_sse41_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { if (icb_step_rem < jcp.nb_ic_blocking_max) icb_step = icb_step_rem; size_t n {0}, g {0}, ocbb {0}, oh {0}; - nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work, + nd_iterator_init(start, n, MB, g, jcp.ngroups, ocbb, ocb_work, oh, jcp.oh); for (size_t iwork = start; iwork < end; ++iwork) { int ocb = ocbb * jcp.nb_oc_blocking; @@ -116,7 +128,7 @@ void jit_sse41_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { par_conv.flags |= FLAG_IC_FIRST; } - if ((jcp.with_eltwise || jcp.with_binary) + if ((jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) && icb + 1 == jcp.nb_ic) { par_conv.flags |= FLAG_IC_LAST; } @@ -134,11 +146,12 @@ void jit_sse41_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { par_conv.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); par_conv.dst_orig = dst; + par_conv.oc_off = _oc * (is_dst_layout_nxc ? 1 : jcp.oc_block) * sizeof(float); (*kernel_)(&par_conv); } nd_iterator_step( - n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work, oh, jcp.oh); + n, MB, g, jcp.ngroups, ocbb, ocb_work, oh, jcp.oh); } icbb += icb_step; } diff --git a/src/cpu/x64/jit_sse41_convolution.hpp b/src/cpu/x64/jit_sse41_convolution.hpp index 6f078ca2bd7..8e53aae1e78 100644 --- a/src/cpu/x64/jit_sse41_convolution.hpp +++ b/src/cpu/x64/jit_sse41_convolution.hpp @@ -56,6 +56,9 @@ struct jit_sse41_convolution_fwd_t : public primitive_t { *src_md(), *weights_md(), *dst_md(), *attr(), dnnl_get_max_threads())); + auto scratchpad = scratchpad_registry().registrar(); + jit_sse41_conv_fwd_kernel_f32::init_scratchpad(scratchpad, jcp_); + return status::success; } @@ -65,7 +68,7 @@ struct jit_sse41_convolution_fwd_t : public primitive_t { bool set_default_formats() { using namespace format_tag; - const bool flat = IC() == 3; + const bool flat = utils::one_of(IC(), 1, 2, 3); auto src_tag = flat ? 
utils::pick(ndims() - 3, ncw, nchw, ncdhw) : utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); diff --git a/src/cpu/x64/jit_uni_batch_normalization_s8.cpp b/src/cpu/x64/jit_uni_batch_normalization_s8.cpp index 5a6d7e31303..d3acbfb7a9c 100644 --- a/src/cpu/x64/jit_uni_batch_normalization_s8.cpp +++ b/src/cpu/x64/jit_uni_batch_normalization_s8.cpp @@ -36,7 +36,7 @@ using namespace Xbyak; using data_t = int8_t; -struct call_params_t { +struct call_params_bnorm_t { // keep int sizes at 8 bytes -- jit code expects this size_t channel_offt_count, spat_offt_count; float eps; @@ -97,7 +97,7 @@ struct jit_bnorm_base_t : public jit_generator { uni_vmovq(xone, reg_tmp); uni_vbroadcastss(vone, xone); -#define PARAM_OFF(x) offsetof(call_params_t, x) +#define PARAM_OFF(x) offsetof(call_params_bnorm_t, x) uni_vbroadcastss(veps, vmmword[reg_param + PARAM_OFF(eps)]); uni_vpxor(vzero, vzero, vzero); @@ -586,7 +586,7 @@ struct driver_t : public c_compatible { dim_t W = pd_->W(); dim_t SP = D * H * W; - call_params_t p; + call_params_bnorm_t p; p.eps = pd_->desc()->batch_norm_epsilon; diff --git a/src/cpu/x64/jit_uni_binary.cpp b/src/cpu/x64/jit_uni_binary.cpp index ae93b260097..ffaa15ea631 100644 --- a/src/cpu/x64/jit_uni_binary.cpp +++ b/src/cpu/x64/jit_uni_binary.cpp @@ -236,7 +236,7 @@ bool jit_uni_binary_t::pd_t::alg_preserves_zero() const { using namespace alg_kind; return utils::one_of(desc()->alg_kind, binary_add, binary_max, binary_min, binary_mul, binary_sub, binary_ge, binary_gt, binary_le, binary_lt, - binary_eq, binary_ne); + binary_eq, binary_ne, binary_prelu); } bool jit_uni_binary_t::pd_t::check_scales_mask() const { @@ -357,7 +357,7 @@ bool jit_uni_binary_t::pd_t::is_applicable() { if (utils::one_of(desc()->alg_kind, alg_kind::binary_ge, alg_kind::binary_gt, alg_kind::binary_le, alg_kind::binary_lt, alg_kind::binary_eq, - alg_kind::binary_ne) + alg_kind::binary_ne, alg_kind::binary_prelu) && (has_oc_tail || has_outer_dims_tail)) return false; @@ -987,7 +987,7 @@ status_t jit_uni_binary_t::execute(const exec_ctx_t &ctx) const { // blocked format due to overwriting the vector tail by vcmpps. 
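// binary_prelu joins the alg_kinds listed below for the same reason as the
// comparison ops: the vectorized form computes a full-width (src0 < 0) mask
// and blends full-width, so lanes past the logical tail of a blocked layout
// get clobbered as well. A rough scalar model of the per-lane semantics (an
// illustrative sketch, not the jit code):
//
//     inline float prelu_lane(float src0, float src1) {
//         return src0 >= 0.f ? src0 : src0 * src1; // src1 carries the slope
//     }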
const bool vector_overwrite = utils::one_of(alg, alg_kind::binary_ge, alg_kind::binary_gt, alg_kind::binary_le, alg_kind::binary_lt, - alg_kind::binary_eq, alg_kind::binary_ne); + alg_kind::binary_eq, alg_kind::binary_ne, alg_kind::binary_prelu); const bool blocked_oc_tail = op_type == op_t::c_blocked && has_oc_tail && (with_postops || point_broadcast || bcast_type == bcast_t::per_w || vector_overwrite); diff --git a/src/cpu/x64/jit_uni_dw_conv_kernel_f32.cpp b/src/cpu/x64/jit_uni_dw_conv_kernel_f32.cpp index b162f5e76fb..ebdee87991c 100644 --- a/src/cpu/x64/jit_uni_dw_conv_kernel_f32.cpp +++ b/src/cpu/x64/jit_uni_dw_conv_kernel_f32.cpp @@ -37,9 +37,9 @@ using namespace Xbyak; template jit_uni_dw_conv_fwd_kernel_f32::jit_uni_dw_conv_fwd_kernel_f32( - const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md) - : jit_generator(nullptr, MAX_CODE_SIZE, true, isa), jcp(ajcp) { - if (jcp.with_eltwise || jcp.with_binary) { + const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md, const primitive_attr_t &attr) + : jit_generator(nullptr, MAX_CODE_SIZE, true, isa), jcp(ajcp), attr_(attr) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { using namespace binary_injector; static constexpr bool preserve_gpr = true; static constexpr bool preserve_vmm = false; @@ -53,10 +53,12 @@ jit_uni_dw_conv_fwd_kernel_f32::jit_uni_dw_conv_fwd_kernel_f32( memory_desc_wrapper(dst_md), tail_size, k_oc_tail_mask, use_exact_tail_scalar_bcast}; static_params_t static_params {this->param1, rhs_arg_static_params}; + quantization_injector::static_params_t quantization_static_params + {vmm_d_weights.getIdx(), vmm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; postops_injector_ = utils::make_unique>( - this, jcp.post_ops, static_params); + this, jcp.post_ops, static_params, quantization_static_params); } } @@ -272,8 +274,25 @@ void iterate( template void jit_uni_dw_conv_fwd_kernel_f32::apply_postops( const int ur_ch_blocks, const int ur_w, const bool is_ch_tail) { - if (this->jcp.with_eltwise || this->jcp.with_binary) { + if (this->jcp.with_eltwise || this->jcp.with_binary || this->jcp.with_depthwise || this->jcp.with_quantization) { + push(aux_reg_blocks_offset); + base_post_ops_data_offset += reg64_size; + add(aux_reg_blocks_offset, ptr[this->param1 + GET_OFF(oc_off)]); //add offset of processed blocks + const int repeats = max_repeats(); + + std::map vmm_idx_off; + iterate(repeats, ur_ch_blocks, ur_w, + [&](const int r, const int ch, const int ow, const bool) { + vmm_idx_off.insert({get_acc_reg_idx(r * ur_ch_blocks * ur_w + ch * ur_w + ow), (ch * repeats + r) * jcp.ch_block / repeats * sizeof(float)}); + }); + + depthwise_injector::dynamic_params_t ddp {vmm_d_weights.getIdx(), vmm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + aux_reg_blocks_offset, vmm_idx_off, + this->rsp, base_post_ops_data_offset}; + quantization_injector::dynamic_params_t qdp {aux_reg_blocks_offset, vmm_idx_off, jcp.dst_dt, + this->rsp, base_post_ops_data_offset}; + injector_utils::vmm_index_set_t vmm_idxs; if (jcp.with_binary) { binary_injector::rhs_arg_dynamic_params_t rhs_arg_params, @@ -321,16 +340,16 @@ void jit_uni_dw_conv_fwd_kernel_f32::apply_postops( cmp(reg_tmp, jcp.nb_ch_blocking * jcp.ch_block); jge(postops_no_tail, T_NEAR); postops_injector_->compute_vector_range( - vmm_idxs, rhs_arg_params_tail); + vmm_idxs, rhs_arg_params_tail, ddp, qdp); jmp(postops_done, T_NEAR); L(postops_no_tail); } else if (is_ch_tail) { postops_injector_->compute_vector_range( - vmm_idxs, rhs_arg_params_tail); + vmm_idxs, 
rhs_arg_params_tail, ddp, qdp); } if (!is_ch_tail) { postops_injector_->compute_vector_range( - vmm_idxs, rhs_arg_params); + vmm_idxs, rhs_arg_params, ddp, qdp); L(postops_done); } } else { @@ -339,8 +358,10 @@ void jit_uni_dw_conv_fwd_kernel_f32::apply_postops( vmm_idxs.emplace(get_acc_reg_idx( r * ur_ch_blocks * ur_w + ch * ur_w + ow)); }); - postops_injector_->compute_vector_range(vmm_idxs); + postops_injector_->compute_vector_range(vmm_idxs, binary_injector::rhs_arg_dynamic_params_t(), ddp, qdp); } + pop(aux_reg_blocks_offset); + base_post_ops_data_offset -= reg64_size; } } @@ -464,6 +485,8 @@ void jit_uni_dw_conv_fwd_kernel_f32::compute_loop( }; mov(aux_reg_ch_blocks, reg_ch_blocks); + xor_(aux_reg_blocks_offset, aux_reg_blocks_offset); + if (ch_loop) { Label ch_loop_label, ch_tail_label, skip_ch_tail_label; const int ch_block_tail = jcp.nb_ch @@ -473,7 +496,11 @@ void jit_uni_dw_conv_fwd_kernel_f32::compute_loop( push(reg_kernel); push(reg_input); push(reg_output); - if (jcp.with_bias) push(reg_bias); + base_post_ops_data_offset += 3 * reg64_size; + if (jcp.with_bias) { + push(reg_bias); + base_post_ops_data_offset += reg64_size; + } if ((jcp.oc / jcp.ch_block) >= jcp.nb_ch_blocking) { if (ch_block_tail) { @@ -489,6 +516,7 @@ void jit_uni_dw_conv_fwd_kernel_f32::compute_loop( add(reg_output, out_ch_stride); if (jcp.with_bias) add(reg_bias, bias_stride); sub(aux_reg_ch_blocks, ch_step); + add(aux_reg_blocks_offset, ch_step * sizeof(float)); //add initial offset of processed blocks cmp(aux_reg_ch_blocks, ch_step); jge(ch_loop_label, T_NEAR); } @@ -503,10 +531,14 @@ void jit_uni_dw_conv_fwd_kernel_f32::compute_loop( L(skip_ch_tail_label); } - if (jcp.with_bias) pop(reg_bias); + if (jcp.with_bias) { + pop(reg_bias); + base_post_ops_data_offset -= reg64_size; + } pop(reg_output); pop(reg_input); pop(reg_kernel); + base_post_ops_data_offset -= 3 * reg64_size; } else { compute(ur_ch_blocks, jcp.oc % jcp.ch_block); @@ -587,6 +619,9 @@ template void jit_uni_dw_conv_fwd_kernel_f32::generate() { this->preamble(); + if (postops_injector_) + postops_injector_->push_post_ops_data_on_stack(this->param1, GET_OFF(post_ops_binary_rhs_arg_vec), reg_input, reg_output); + if (jcp.is_fused_conv) { mov(reg_input_buffer_ptr, ptr[this->param1 + GET_OFF(src)]); /* In case of fused depthwise convolution, `param.src` is not a pointer @@ -647,6 +682,9 @@ void jit_uni_dw_conv_fwd_kernel_f32::generate() { L(exit_label); } + if (postops_injector_) + postops_injector_->reset_stack_pointer(); + this->postamble(); if (jcp.with_eltwise) postops_injector_->prepare_table(); @@ -799,6 +837,35 @@ inline void jit_uni_dw_conv_bwd_data_kernel_f32::apply_filter( L(iter_exit_label); } +template +void jit_uni_dw_conv_bwd_data_kernel_f32::apply_postprocess(int ur_ch_blocks, int ur_str_w) { + int repeats = isa == sse41 ? 
2 : 1; + + const auto &p = attr_.post_ops_; + std::size_t post_ops_data_offset = 0; + int depthwise_inj_idx = 0; + for (int i = 0; i < p.len(); i++) { + auto& post_op = p.entry_[i]; + if (post_op.is_depthwise()) { + mov(reg_d_weights, ptr[this->rsp + base_post_ops_data_offset + post_ops_data_offset]); + add(reg_d_weights, ptr[this->param1 + GET_OFF(ic_off)]); + + for (int ch = 0; ch < ur_ch_blocks; ch++) { + for (int k = 0; k < repeats; k++) { + int start_idx = get_acc_reg(k*ur_ch_blocks*ur_str_w + ur_str_w * ch).getIdx(); + int end_idx = get_acc_reg(k*ur_ch_blocks*ur_str_w + ur_str_w * ch + ur_str_w).getIdx(); + + depthwise_injectors[depthwise_inj_idx]->compute_vector_range(start_idx, end_idx, reg_d_weights, reg_d_weights); + + add(reg_d_weights, jcp.ch_block / repeats * sizeof(float)); + } + } + post_ops_data_offset += depthwise_injectors[depthwise_inj_idx]->memoryStep(); + depthwise_inj_idx++; + } + } +} + template inline void jit_uni_dw_conv_bwd_data_kernel_f32::store_dsrc( int ur_ch_blocks, int ur_str_w, bool is_last_ch) { @@ -846,6 +913,7 @@ inline void jit_uni_dw_conv_bwd_data_kernel_f32::ch_loop_body( load_ddst(ur_ch_blocks, unroll_w); apply_filter(ur_ch_blocks, unroll_w, is_last_ch); + apply_postprocess(ur_ch_blocks, unroll_w); store_dsrc(ur_ch_blocks, unroll_w, is_last_ch); }; @@ -865,6 +933,7 @@ inline void jit_uni_dw_conv_bwd_data_kernel_f32::ch_loop_body( = (size_t)jcp.nb_ch_blocking * jcp.ch_block * sizeof(float); mov(aux_reg_ch_blocks, reg_ch_blocks); + base_post_ops_data_offset += 3 * reg64_size; push(reg_dsrc); push(reg_ddst); push(reg_kernel); @@ -901,6 +970,7 @@ inline void jit_uni_dw_conv_bwd_data_kernel_f32::ch_loop_body( pop(reg_kernel); pop(reg_ddst); pop(reg_dsrc); + base_post_ops_data_offset -= 3 * reg64_size; } else { call_compute_body(ur_ch_blocks, unroll_w, jcp.ch_tail > 0); @@ -939,8 +1009,39 @@ inline void jit_uni_dw_conv_bwd_data_kernel_f32::unroll_width_body( template void jit_uni_dw_conv_bwd_data_kernel_f32::generate() { + const auto &p = attr_.post_ops_; + for (int i = 0; i < p.len(); i++) { + auto &post_op = p.entry_[i]; + if (post_op.is_depthwise()) { + depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32( + this, + post_op + )); + } + } + preamble(); + std::size_t post_ops_pointers_count = 0; + for (int i = 0; i < p.len(); i++) { + if (p.entry_[i].is_depthwise() || p.entry_[i].is_quantization()) { + post_ops_pointers_count++; + } + } + + if (post_ops_pointers_count != 0) { + sub(rsp, post_ops_pointers_count * sizeof(float *)); + + auto aux_reg0 = reg_dsrc; + auto aux_reg1 = reg_ddst; + + mov(aux_reg0, ptr[this->param1 + GET_OFF(post_ops_binary_rhs_arg_vec)]); + for (size_t i = 0; i < post_ops_pointers_count; i++) { + mov(aux_reg1, ptr[aux_reg0 + i * sizeof(float *)]); + mov(ptr[rsp + i * sizeof(float *)], aux_reg1); + } + } + mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]); mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]); mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); @@ -981,6 +1082,10 @@ void jit_uni_dw_conv_bwd_data_kernel_f32::generate() { if (ch_blocks_tail) { ch_blocks_loop(ch_blocks_tail); } } + if (post_ops_pointers_count != 0) { + add(rsp, post_ops_pointers_count * sizeof(float *)); + } + this->postamble(); } #undef GET_OFF diff --git a/src/cpu/x64/jit_uni_dw_conv_kernel_f32.hpp b/src/cpu/x64/jit_uni_dw_conv_kernel_f32.hpp index 85284d25ecc..331f7ede361 100644 --- a/src/cpu/x64/jit_uni_dw_conv_kernel_f32.hpp +++ b/src/cpu/x64/jit_uni_dw_conv_kernel_f32.hpp @@ -36,9 +36,10 @@ struct jit_uni_dw_conv_fwd_kernel_f32 : public 
jit_generator { DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_fwd_kernel_f32) jit_uni_dw_conv_fwd_kernel_f32( - const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md); + const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md, const primitive_attr_t &attr); jit_conv_conf_t jcp; + const primitive_attr_t &attr_; private: using Vmm = typename utils::conditional3 struct jit_uni_dw_conv_bwd_data_kernel_f32 : public jit_generator { DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_bwd_data_kernel_f32) - jit_uni_dw_conv_bwd_data_kernel_f32(const jit_conv_conf_t &ajcp) - : jcp(ajcp) {} + jit_uni_dw_conv_bwd_data_kernel_f32(const jit_conv_conf_t &ajcp, const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr) {} + + ~jit_uni_dw_conv_bwd_data_kernel_f32() { + for (auto inj : depthwise_injectors) + delete inj; + depthwise_injectors.clear(); + } + jit_conv_conf_t jcp; + const primitive_attr_t &attr_; private: using Vmm = typename utils::conditional3*> depthwise_injectors; + void load_vmm(Vmm &vmm, const Xbyak::Address &addr, bool tail); void store_vmm(Vmm &vmm, const Xbyak::Address &addr, bool tail); @@ -171,6 +197,7 @@ struct jit_uni_dw_conv_bwd_data_kernel_f32 : public jit_generator { inline void unroll_width_body(int ur_ch_blocks); inline void load_ddst(int ur_ch_blocks, int ur_str_w); inline void apply_filter(int ur_ch_blocks, int ur_str_w, bool is_last_ch); + inline void apply_postprocess(int ur_ch_blocks, int ur_str_w); inline void store_dsrc(int ur_ch_blocks, int ur_str_w, bool is_last_ch); void generate() override; diff --git a/src/cpu/x64/jit_uni_dw_conv_kernel_utils.cpp b/src/cpu/x64/jit_uni_dw_conv_kernel_utils.cpp index c4fcfd2b518..74bcba8e222 100644 --- a/src/cpu/x64/jit_uni_dw_conv_kernel_utils.cpp +++ b/src/cpu/x64/jit_uni_dw_conv_kernel_utils.cpp @@ -62,7 +62,7 @@ status_t jit_uni_dw_conv_fwd_kernel::init_conf( CHECK(memory_desc_init_by_tag(src_md, def_tag)); jcp.src_tag = def_tag; } else { - jcp.src_tag = src_d.matches_one_of_tag(blocked_tag, nxc_tag); + jcp.src_tag = src_d.mb_stride_relaxed_match(blocked_tag, nxc_tag); } if (weights_d.format_kind() == format_kind::any) { @@ -76,7 +76,7 @@ status_t jit_uni_dw_conv_fwd_kernel::init_conf( CHECK(memory_desc_init_by_tag(dst_md, def_tag)); jcp.dst_tag = def_tag; } else { - jcp.dst_tag = dst_d.matches_one_of_tag(blocked_tag, nxc_tag); + jcp.dst_tag = dst_d.mb_stride_relaxed_match(blocked_tag, nxc_tag); } if (jcp.with_bias) { @@ -218,19 +218,21 @@ status_t jit_uni_dw_conv_fwd_kernel::init_conf( broadcasting_strategy_t::per_oc, broadcasting_strategy_t::no_broadcast); } + jcp.with_depthwise = post_ops.find(primitive_kind::depthwise) != -1; + jcp.with_quantization = post_ops.find(primitive_kind::quantization) != -1; jcp.post_ops = post_ops; using namespace injector; static constexpr bool sum_at_pos_0_only = true; static constexpr bool sum_requires_scale_one = true; - const bool post_ops_ok_ = post_ops_ok({isa, {eltwise, binary, sum}, + const bool post_ops_ok_ = post_ops_ok({isa, {eltwise, binary, sum, depthwise, quantization}, jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one}); if (!post_ops_ok_) return status::unimplemented; const bool ok_to_pad_channels = true && !is_data_layout_nxc && jcp.oc == jcp.ngroups && jcp.ic == jcp.ngroups - && one_of(isa, avx512_common, avx512_core, avx2); + && one_of(isa, avx512_common, avx512_core, avx2, sse41); if (ok_to_pad_channels) { jcp.oc = rnd_up(jcp.oc, simd_w); jcp.ic = rnd_up(jcp.oc, simd_w); @@ -260,11 +262,29 @@ void jit_uni_dw_conv_fwd_kernel::init_scratchpad( 
scratchpad.book(key_conv_padded_bias, jcp.oc); } +template +bool jit_uni_dw_conv_bwd_data_kernel::post_ops_ok(const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + if (p.len() > 1) + return false; + + auto all_post_ops_supported = [&]() { + bool ok = true; + + for (int i = 0; i < p.len(); i++) { + ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::depthwise); + } + return ok; + }; + + return all_post_ops_supported(); +} + template status_t jit_uni_dw_conv_bwd_data_kernel::init_conf( jit_conv_conf_t &jcp, const convolution_desc_t &cd, memory_desc_t &diff_src_md, memory_desc_t &weights_md, - memory_desc_t &diff_dst_md) { + memory_desc_t &diff_dst_md, const primitive_attr_t &attr) { using namespace dnnl::impl::format_tag; using namespace dnnl::impl::utils; @@ -324,9 +344,9 @@ status_t jit_uni_dw_conv_bwd_data_kernel::init_conf( = one_of(isa, avx512_common, avx512_core) ? Goihw16g : Goihw8g; auto curr_src_tag - = diff_src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_blocked); + = diff_src_d.mb_stride_relaxed_match(dat_tag_nxc, dat_tag_blocked); auto curr_dst_tag - = diff_dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_blocked); + = diff_dst_d.mb_stride_relaxed_match(dat_tag_nxc, dat_tag_blocked); bool is_data_layout_nxc = utils::everyone_is(dat_tag_nxc, curr_src_tag, curr_dst_tag); auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_blocked; @@ -363,6 +383,11 @@ status_t jit_uni_dw_conv_bwd_data_kernel::init_conf( // from: 'simd_w_ * reg_repeats_ = 4 * 2' jcp.ch_block = one_of(isa, avx512_common, avx512_core) ? 16 : 8; + if (!post_ops_ok(attr)) + return status::unimplemented; + + jcp.post_ops = attr.post_ops_; + bool ok_to_pad_channels = !is_data_layout_nxc && jcp.oc == jcp.ngroups && jcp.ic == jcp.ngroups && one_of(isa, avx512_common, avx512_core, avx2); diff --git a/src/cpu/x64/jit_uni_dw_conv_kernel_utils.hpp b/src/cpu/x64/jit_uni_dw_conv_kernel_utils.hpp index 7f47cd6fcae..9066f546f5a 100644 --- a/src/cpu/x64/jit_uni_dw_conv_kernel_utils.hpp +++ b/src/cpu/x64/jit_uni_dw_conv_kernel_utils.hpp @@ -38,8 +38,8 @@ template struct jit_uni_dw_conv_fwd_kernel { jit_uni_dw_conv_fwd_kernel( - const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md) { - ker_ = new jit_kernel_t(ajcp, dst_md); + const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md, const primitive_attr_t &attr) { + ker_ = new jit_kernel_t(ajcp, dst_md, attr); } status_t create_kernel() { return ker_->create_kernel(); } @@ -68,17 +68,20 @@ struct jit_uni_dw_conv_fwd_kernel { template struct jit_uni_dw_conv_bwd_data_kernel { - jit_uni_dw_conv_bwd_data_kernel(const jit_conv_conf_t &ajcp) + jit_uni_dw_conv_bwd_data_kernel(const jit_conv_conf_t &ajcp, const primitive_attr_t &attr) : ker_(nullptr) { - ker_ = new jit_kernel_t(ajcp); + ker_ = new jit_kernel_t(ajcp, attr); } status_t create_kernel() { return ker_->create_kernel(); } ~jit_uni_dw_conv_bwd_data_kernel() { delete ker_; } + static bool post_ops_ok(const primitive_attr_t &attr); + static status_t init_conf(jit_conv_conf_t &jcp, const convolution_desc_t &cd, memory_desc_t &diff_src_md, - memory_desc_t &weights_md, memory_desc_t &diff_dst_md); + memory_desc_t &weights_md, memory_desc_t &diff_dst_md, + const primitive_attr_t &attr); static void init_scratchpad(memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp); diff --git a/src/cpu/x64/jit_uni_dw_conv_row_f32.cpp b/src/cpu/x64/jit_uni_dw_conv_row_f32.cpp new file mode 100644 index 00000000000..63250828728 --- /dev/null +++ b/src/cpu/x64/jit_uni_dw_conv_row_f32.cpp @@ -0,0 +1,709 @@ 
+/******************************************************************************* +* Copyright 2019-2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* [todo] antonvor: + * This file contains the old plugin behavior in order to fix performance + * problems after upgrading to OneDNN v1.6. This kernel is executed only on + * machines with avx2 instruction set support and in the case of a fused + * convolution. Remove after problems are fixed. +*/ + +#include "common/c_types_map.hpp" +#include "common/dnnl_thread.hpp" +#include "common/nstl.hpp" +#include "common/utils.hpp" + +#include "cpu/x64/jit_uni_dw_conv_row_f32.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace x64 { + +using namespace Xbyak; +using namespace dnnl::impl::utils; + +#define GET_OFF_DW(field) offsetof(jit_conv_call_s, field) + +template +void jit_uni_dw_conv_row_f32::clear_vmm_regs(int ur_w) { + int repeats = isa == sse41 ? 2 : 1; + for (int i = 0; i < repeats; i++) { + for (int ow = 0; ow < ur_w; ow++) { + Vmm vmm_acc = get_acc_reg(i*ur_w + ow); + + uni_vpxor(vmm_acc, vmm_acc, vmm_acc); + } + } +} + +template +void jit_uni_dw_conv_row_f32::apply_filter(int ur_w, int kw_size) { + auto load_src = [=](Vmm vmm_src, const Xbyak::Address &op) { + if (jcp.src_dt == data_type::u8) { + uni_vpmovzxbd(vmm_src, op); + } else { + uni_vmovups(vmm_src, op); + } + }; + + auto load_ker = [=](Vmm vmm_ker, const Xbyak::Address &op) { + if (jcp.src_dt == data_type::u8) { + uni_vpmovsxbd(vmm_ker, op); + } else { + uni_vmovups(vmm_ker, op); + } + }; + + auto compute = [=](Vmm vmm_acc, Vmm vmm_src, Vmm vmm_ker) { + if (jcp.src_dt == data_type::u8) { + uni_vpmulld(vmm_src, vmm_src, vmm_ker); + uni_vpaddd(vmm_acc, vmm_acc, vmm_src); + } else { + uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker); + } + }; + + int ch_blk = jcp.ch_block; + int stride_w = jcp.stride_w; + + Label exit_label; + + int repeats = isa == sse41 ? 
2 : 1; + + cmp(reg_kh, 1); + jl(exit_label, T_NEAR); + for (int i = 0; i < repeats; i++) { + for (int kw = 0; kw < kw_size; kw++) { + int ker_off = kw * ch_blk + i*(jcp.ch_block / 2); + + Vmm vmm_ker = get_ker_reg(0); + load_ker(vmm_ker, ptr[aux_reg_kernel + ker_off * jcp.typesize_in]); + + for (int ow = 0; ow < ur_w; ow++) { + int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*(jcp.ch_block / 2); + + Vmm vmm_src = get_src_reg(0); + load_src(vmm_src, ptr[aux_reg_input0 + inp_off * jcp.typesize_in]); + + Vmm vmm_acc = get_acc_reg(i*ur_w + ow); + compute(vmm_acc, vmm_src, vmm_ker); + } + } + } + add(aux_reg_kernel, jcp.kw*ch_blk*jcp.typesize_in); + + cmp(reg_kh, 2); + jl(exit_label, T_NEAR); + for (int i = 0; i < repeats; i++) { + for (int kw = 0; kw < kw_size; kw++) { + int ker_off = kw * ch_blk + i*(jcp.ch_block / 2); + + Vmm vmm_ker = get_ker_reg(0); + load_ker(vmm_ker, ptr[aux_reg_kernel + ker_off * jcp.typesize_in]); + + for (int ow = 0; ow < ur_w; ow++) { + int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*(jcp.ch_block / 2); + + Vmm vmm_src = get_src_reg(0); + load_src(vmm_src, ptr[aux_reg_input1 + inp_off * jcp.typesize_in]); + + Vmm vmm_acc = get_acc_reg(i*ur_w + ow); + compute(vmm_acc, vmm_src, vmm_ker); + } + } + } + add(aux_reg_kernel, jcp.kw*ch_blk*jcp.typesize_in); + + cmp(reg_kh, 3); + jl(exit_label, T_NEAR); + for (int i = 0; i < repeats; i++) { + for (int kw = 0; kw < kw_size; kw++) { + int ker_off = kw * ch_blk + i*(jcp.ch_block / 2); + + Vmm vmm_ker = get_ker_reg(0); + load_ker(vmm_ker, ptr[aux_reg_kernel + ker_off * jcp.typesize_in]); + + for (int ow = 0; ow < ur_w; ow++) { + int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*(jcp.ch_block / 2); + + Vmm vmm_src = get_src_reg(0); + load_src(vmm_src, ptr[aux_reg_input2 + inp_off * jcp.typesize_in]); + + Vmm vmm_acc = get_acc_reg(i*ur_w + ow); + compute(vmm_acc, vmm_src, vmm_ker); + } + } + } + + L(exit_label); +} + +template +void jit_uni_dw_conv_row_f32::cvt2ps(data_type_t type_in, Vmm vmm_in, const Operand &op, bool scalar_load) { + Xmm xmm_in = Xmm(vmm_in.getIdx()); + + switch (type_in) { + case data_type::f32: + case data_type::s32: + if (scalar_load) { + mov(reg_tmp_32, op); + movq(xmm_in, reg_tmp_64); + } else { + uni_vmovups(vmm_in, op); + } + break; + case data_type::s8: + if (scalar_load) { + movsx(reg_tmp_32, op); + movq(xmm_in, reg_tmp_64); + } else { + uni_vpmovsxbd(vmm_in, op); + } + break; + case data_type::u8: + if (scalar_load) { + movzx(reg_tmp_32, op); + movq(xmm_in, reg_tmp_64); + } else { + uni_vpmovzxbd(vmm_in, op); + } + break; + default: assert(!"unsupported data type"); + } + + if (type_in != data_type::f32) + uni_vcvtdq2ps(vmm_in, vmm_in); +} + +template +void jit_uni_dw_conv_row_f32::apply_postprocessing(int ur_w, int oc_step) { + int repeats = isa == sse41 ? 
2 : 1; + + for (int r = 0; r < repeats; r++) { + for (int ow = 0; ow < ur_w; ow++) { + if (jcp.src_dt == data_type::u8) { + uni_vcvtdq2ps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow)); + } + + if (jcp.with_bias) { + int b_off = r * (jcp.ch_block / 2); + cvt2ps(jcp.bia_dt, vmm_bias, ptr[reg_bias + b_off * jcp.typesize_bia], false); + uni_vaddps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow), vmm_bias); + } + } + } + + const auto &p = attr_.post_ops_; + if (jcp.with_sum) { + dnnl::impl::data_type_t sum_dt = jcp.dst_dt; + int start_idx = p.find(primitive_kind::convolution) + 1; + for (int i = start_idx; i < p.len(); i++) { + auto &post_op = p.entry_[i]; + if (post_op.is_sum()) { + sum_dt = post_op.sum.dt; + } + } + + for (int r = 0; r < repeats; r++) { + int tail_size = isa == sse41 ? nstl::min(jcp.ch_block / 2, oc_step - r * jcp.ch_block / 2) : oc_step; + bool is_scalar_store = isa == sse41 ? tail_size < jcp.ch_block / 2 : tail_size < jcp.ch_block; + + for (int ow = 0; ow < ur_w; ow++) { + if (is_scalar_store) { + if (isa == avx512_common) { + int o_off = ow * ow_stride_; + + Vmm vmm_in = vmm_sum | ktail_mask | T_z; + + cvt2ps(sum_dt, vmm_in, ptr[reg_output + o_off * jcp.typesize_out], false); + uni_vaddps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow), vmm_sum); + } else { + for (int oc = 0; oc < tail_size; oc++) { + int o_off = ow * ow_stride_ + r * (jcp.ch_block / 2) + oc; + + uni_vpxor(vmm_sum, vmm_sum, vmm_sum); + cvt2ps(sum_dt, vmm_sum, ptr[reg_output + o_off * jcp.typesize_out], true); + + if (oc >= jcp.ch_block / 2) { + vperm2i128(Ymm(vmm_sum.getIdx()), Ymm(vmm_sum.getIdx()), Ymm(vmm_sum.getIdx()), 0x01); + } + uni_vpslldq(vmm_sum, vmm_sum, jcp.typesize_out * (oc % (jcp.ch_block / 2))); + + uni_vaddps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow), vmm_sum); + } + } + } else { + int o_off = ow * ow_stride_ + r * (jcp.ch_block / 2); + + uni_vpxor(vmm_sum, vmm_sum, vmm_sum); + cvt2ps(sum_dt, vmm_sum, ptr[reg_output + o_off * jcp.typesize_out], false); + + uni_vaddps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow), vmm_sum); + } + } + } + } + + int eltwise_inj_idx = 0; + int depthwise_inj_idx = 0; + int quantization_inj_idx = 0; + int start_idx = p.find(primitive_kind::convolution) + 1; + std::size_t post_ops_data_offset = 0; + for (int i = start_idx; i < p.len(); i++) { + auto& post_op = p.entry_[i]; + if (post_op.is_eltwise()) { + eltwise_injectors[eltwise_inj_idx]->compute_vector_range(4, 4 + repeats * ur_w); + eltwise_inj_idx++; + } else if (post_op.is_depthwise()) { + mov(reg_d_weights, ptr[this->rsp + post_ops_data_offset]); + add(reg_d_weights, reg_oc_off); + + depthwise_injectors[depthwise_inj_idx]->compute_vector_range(4, 4 + ur_w, reg_d_weights, reg_d_weights); + + if (repeats == 2) { + add(reg_d_weights, (jcp.ch_block / 2) * sizeof(float)); + + depthwise_injectors[depthwise_inj_idx]->compute_vector_range(4 + ur_w, 4 + 2 * ur_w, reg_d_weights, reg_d_weights); + } + + post_ops_data_offset += depthwise_injectors[depthwise_inj_idx]->memoryStep(); + depthwise_inj_idx++; + } else if (post_op.is_quantization()) { + bool do_dequantization = post_op.quantization.alg == alg_kind::quantization_quantize_dequantize; + bool do_rounding = do_dequantization || jcp.dst_dt == dnnl_f32 || i != p.len() - 1; + + const Xbyak::RegExp quant_arg_base = this->rsp + post_ops_data_offset; + quantization_injectors[quantization_inj_idx]->init_crop_ptrs(quant_arg_base, reg_oc_off); + for (int r = 0; r < repeats; r++) { + int s_idx = get_acc_reg(r * ur_w).getIdx(); + 
quantization_injectors[quantization_inj_idx]->compute_crop(s_idx, s_idx + ur_w, r * (jcp.ch_block / 2) * sizeof(float)); + } + + quantization_injectors[quantization_inj_idx]->init_input_scale_shift_ptrs(quant_arg_base, reg_oc_off); + for (int r = 0; r < repeats; r++) { + int s_idx = get_acc_reg(r * ur_w).getIdx(); + quantization_injectors[quantization_inj_idx]->compute_input_scale_shift(s_idx, s_idx + ur_w, r * (jcp.ch_block / 2) * sizeof(float), do_rounding); + } + + quantization_injectors[quantization_inj_idx]->init_output_scale_shift_ptrs(quant_arg_base, reg_oc_off); + for (int r = 0; r < repeats; r++) { + int s_idx = get_acc_reg(r * ur_w).getIdx(); + quantization_injectors[quantization_inj_idx]->compute_output_scale_shift(s_idx, s_idx + ur_w, r * (jcp.ch_block / 2) * sizeof(float)); + } + + post_ops_data_offset += quantization_injectors[quantization_inj_idx]->memoryStep(); + quantization_inj_idx++; + } + } +} + +template +void jit_uni_dw_conv_row_f32::store_dst_typed(const Xbyak::Address &op, Vmm vmm_dst, bool scalar_store) { + Ymm ymm_dst = Ymm(vmm_dst.getIdx()); + Xmm xmm_dst = Xmm(vmm_dst.getIdx()); + + switch (jcp.dst_dt) { + case data_type::f32: + case data_type::s32: + if (scalar_store) { + movq(reg_tmp_64, xmm_dst); + mov(op, reg_tmp_32); + } else { + uni_vmovups(op, vmm_dst); + } + break; + case data_type::s8: + uni_vpackssdw(vmm_dst, vmm_dst, vmm_dst); + + if (isa != sse41 && !scalar_store) + vpermq(ymm_dst, ymm_dst, 0x08); + + uni_vpacksswb(vmm_dst, vmm_dst, vmm_dst); + + if (scalar_store) { + movq(reg_tmp_64, xmm_dst); + mov(op, reg_tmp_8); + } else { + if (isa != sse41) + vmovq(op, xmm_dst); + else + movd(op, xmm_dst); + } + break; + case data_type::u8: + case data_type::bin: + uni_vpackusdw(vmm_dst, vmm_dst, vmm_dst); + + if (isa != sse41 && !scalar_store) + vpermq(ymm_dst, ymm_dst, 0x08); + + uni_vpackuswb(vmm_dst, vmm_dst, vmm_dst); + + if (scalar_store) { + movq(reg_tmp_64, xmm_dst); + mov(op, reg_tmp_8); + } else { + if (isa != sse41) + vmovq(op, xmm_dst); + else + movd(op, xmm_dst); + } + break; + default: + assert(!"unknown dst_dt"); + } +} + +template +void jit_uni_dw_conv_row_f32::store_dst(int ur_w, int oc_step) { + int repeats = isa == sse41 && oc_step > (jcp.ch_block / 2) ? 2 : 1; + + if (isa == avx512_common && oc_step != jcp.ch_block) { + int mask = (1 << oc_step) - 1; + mov(reg_tmp_32, mask); + kmovw(ktail_mask, reg_tmp_32); + } + + for (int i = 0; i < repeats; i++) { + for (int ow = 0; ow < ur_w; ow++) { + Vmm vmm_dst = get_acc_reg(i * ur_w + ow); + if (jcp.dst_dt != data_type::f32 && jcp.dst_dt != data_type::bin) { + uni_vcvtps2dq(vmm_dst, vmm_dst); + } + } + } + for (int i = 0; i < repeats; i++) { + int tail_size = isa == sse41 ? nstl::min(jcp.ch_block / 2, oc_step - i * jcp.ch_block / 2) : oc_step; + bool is_scalar_store = isa == sse41 ? 
tail_size < jcp.ch_block / 2 : tail_size < jcp.ch_block; + if (is_scalar_store) { + for (int ow = 0; ow < ur_w; ow++) { + Vmm vmm_dst = get_acc_reg(i * ur_w + ow); + + if (isa == avx512_common) { + int o_off = ow * ow_stride_; + + store_dst_typed(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst | ktail_mask, false); + } else { + for (int oc = 0; oc < tail_size; oc++) { + int o_off = ow * ow_stride_ + i * (jcp.ch_block / 2) + oc; + store_dst_typed(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, true); + + if (isa == sse41) { + psrldq(vmm_dst, jcp.typesize_out); + } else { + Ymm ymm_dst = Ymm(vmm_dst.getIdx()); + + vperm2i128(ymm_tmp, ymm_dst, ymm_dst, 0x01); + vpalignr(ymm_dst, vmm_tmp, ymm_dst, jcp.typesize_out); + } + } + } + } + } else { + for (int ow = 0; ow < ur_w; ow++) { + int o_off = ow * ow_stride_ + i * (jcp.ch_block / 2); + Vmm vmm_dst = get_acc_reg(i * ur_w + ow); + + store_dst_typed(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, false); + } + } + } +} + +template +void jit_uni_dw_conv_row_f32::loop_body(int oc_step) { + Label left_pad_label; + Label right_pad_label; + Label unrolled_w_label; + Label tail_w_label; + Label exit_label; + + int output_step = ow_stride_; + + L(left_pad_label); { + int ur_w = 1; + int kw = jcp.iw == 1 ? jcp.kw - 2 : jcp.kw - 1; + + mov(aux_reg_input0, reg_input0); + mov(aux_reg_input1, reg_input1); + mov(aux_reg_input2, reg_input2); + mov(aux_reg_kernel, reg_kernel); + add(aux_reg_kernel, jcp.ch_block*jcp.typesize_in); + + clear_vmm_regs(ur_w); + apply_filter(ur_w, kw); + apply_postprocessing(ur_w, oc_step); + store_dst(ur_w, oc_step); + + add(reg_input0, jcp.typesize_in * ur_w * jcp.ch_block * (jcp.stride_w-1)); + add(reg_input1, jcp.typesize_in * ur_w * jcp.ch_block * (jcp.stride_w-1)); + add(reg_input2, jcp.typesize_in * ur_w * jcp.ch_block * (jcp.stride_w-1)); + add(reg_output, jcp.typesize_out * ur_w * output_step); + + sub(reg_ur_w, ur_w); + } + + L(unrolled_w_label); { + int ur_w = jcp.ur_w; + int kw = jcp.kw; + + cmp(reg_ur_w, ur_w); + jle(tail_w_label, T_NEAR); + + mov(aux_reg_input0, reg_input0); + mov(aux_reg_input1, reg_input1); + mov(aux_reg_input2, reg_input2); + mov(aux_reg_kernel, reg_kernel); + + clear_vmm_regs(ur_w); + apply_filter(ur_w, kw); + apply_postprocessing(ur_w, oc_step); + store_dst(ur_w, oc_step); + + add(reg_input0, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w); + add(reg_input1, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w); + add(reg_input2, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w); + add(reg_output, jcp.typesize_out * ur_w * output_step); + + sub(reg_ur_w, ur_w); + jmp(unrolled_w_label, T_NEAR); + } + + L(tail_w_label); { + int ur_w = 1; + int kw = jcp.kw; + + cmp(reg_ur_w, ur_w); + if (jcp.ow > 1) + jle(right_pad_label, T_NEAR); + else + jle(exit_label, T_NEAR); + + mov(aux_reg_input0, reg_input0); + mov(aux_reg_input1, reg_input1); + mov(aux_reg_input2, reg_input2); + mov(aux_reg_kernel, reg_kernel); + + clear_vmm_regs(ur_w); + apply_filter(ur_w, kw); + apply_postprocessing(ur_w, oc_step); + store_dst(ur_w, oc_step); + + add(reg_input0, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w); + add(reg_input1, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w); + add(reg_input2, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w); + add(reg_output, jcp.typesize_out * ur_w * output_step); + + sub(reg_ur_w, ur_w); + jmp(tail_w_label, T_NEAR); + } + + if (jcp.ow > 1) { + L(right_pad_label); { + int ur_w = 1; + int kw = jcp.kw - ((jcp.stride_w == 1) ? 
1 : jcp.iw % jcp.stride_w); + + mov(aux_reg_input0, reg_input0); + mov(aux_reg_input1, reg_input1); + mov(aux_reg_input2, reg_input2); + mov(aux_reg_kernel, reg_kernel); + + clear_vmm_regs(ur_w); + apply_filter(ur_w, kw); + apply_postprocessing(ur_w, oc_step); + store_dst(ur_w, oc_step); + + sub(reg_ur_w, ur_w); + } + } + + L(exit_label); +} + +template +void jit_uni_dw_conv_row_f32::generate() { + const auto &p = attr_.post_ops_; + int start_idx = p.find(primitive_kind::convolution) + 1; + for (int i = start_idx; i < p.len(); i++) { + auto &post_op = p.entry_[i]; + if (post_op.is_eltwise()) { + eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32( + this, + post_op.eltwise + )); + } else if (post_op.is_depthwise()) { + depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32( + this, + post_op + )); + } else if (post_op.is_quantization()) { + quantization_injectors.push_back(new jit_uni_quantization_injector_f32( + this, + post_op, + vmm_d_weights, vmm_d_bias, reg_d_weights, reg_d_bias + )); + } + } + + this->preamble(); + + std::size_t post_ops_pointers_count = 0; + for (int i = 0; i < p.len(); i++) { + if (p.entry_[i].is_depthwise() || p.entry_[i].is_quantization()) { + post_ops_pointers_count++; + } + } + + if (post_ops_pointers_count != 0) { + sub(rsp, post_ops_pointers_count * sizeof(float *)); + + auto aux_reg0 = reg_input0; + auto aux_reg1 = reg_input1; + + mov(aux_reg0, ptr[this->param1 + GET_OFF_DW(post_ops_binary_rhs_arg_vec)]); + for (size_t i = 0; i < post_ops_pointers_count; i++) { + mov(aux_reg1, ptr[aux_reg0 + i * sizeof(float *)]); + mov(ptr[rsp + i * sizeof(float *)], aux_reg1); + } + } + + mov(reg_input0, ptr[this->param1 + GET_OFF_DW(src_row0)]); + mov(reg_input1, ptr[this->param1 + GET_OFF_DW(src_row1)]); + mov(reg_input2, ptr[this->param1 + GET_OFF_DW(src_row2)]); + mov(reg_output, ptr[this->param1 + GET_OFF_DW(dst)]); + mov(reg_kernel, ptr[this->param1 + GET_OFF_DW(filt)]); + if (jcp.with_bias) + mov(reg_bias, ptr[this->param1 + GET_OFF_DW(bias)]); + mov(reg_kh, ptr[this->param1 + GET_OFF_DW(kh_padding)]); + mov(reg_ur_w, ptr[this->param1 + GET_OFF_DW(ur_w)]); + mov(reg_oc_work, ptr[this->param1 + GET_OFF_DW(oc_work)]); + mov(reg_oc_off, ptr[this->param1 + GET_OFF_DW(oc_off)]); + + Label tail_label; + Label exit_label; + + cmp(reg_oc_work, jcp.ch_block); + jl(tail_label, T_NEAR); + + loop_body(jcp.ch_block); + jmp(exit_label, T_NEAR); + + L(tail_label); + + if (jcp.oc % jcp.ch_block != 0) + loop_body(jcp.oc % jcp.ch_block); + + L(exit_label); + + if (post_ops_pointers_count != 0) { + add(rsp, post_ops_pointers_count * sizeof(float *)); + } + + this->postamble(); + + for (auto& inj : eltwise_injectors) + inj->prepare_table(); +} + +template +bool jit_uni_dw_conv_row_f32::post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + int start_idx = p.find(primitive_kind::convolution) + 1; + + auto all_post_ops_supported = [&]() { + bool ok = true; + + for (int i = start_idx; i < p.len(); i++) { + ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::sum, primitive_kind::eltwise, primitive_kind::depthwise, + primitive_kind::binarization, primitive_kind::quantization); + } + return ok; + }; + auto contain = [&](dnnl::impl::primitive_kind_t kind) { return p.find(kind, start_idx, -1) != -1; }; + auto position = [&](dnnl::impl::primitive_kind_t kind) { return p.find(kind, start_idx, -1); }; + auto count = [&](dnnl::impl::primitive_kind_t kind) { return p.count(kind, start_idx, -1); }; + + return 
all_post_ops_supported() && + count(primitive_kind::sum) <= 1 && + count(primitive_kind::binarization) <= 1 && + IMPLICATION(contain(primitive_kind::sum), position(primitive_kind::sum) == start_idx) && + IMPLICATION(contain(primitive_kind::binarization), position(primitive_kind::binarization) == p.len()-1) && + IMPLICATION(contain(primitive_kind::binarization), !contain(primitive_kind::sum)); +} + +template +status_t jit_uni_dw_conv_row_f32::init_conf(jit_1x1_conv_conf_t &jcp, jit_conv_conf_t &jcp_dw, + const primitive_attr_t &attr) { + if (!mayiuse(isa)) return status::unimplemented; + const int simd_w = isa == avx512_common ? 16 : 8; + + const auto &p = attr.post_ops_; + + int dw_conv_ind = p.find(primitive_kind::convolution); + jcp_dw.with_sum = p.find(primitive_kind::sum, dw_conv_ind) != -1; + + auto dw_po_len = p.len() - (dw_conv_ind + 1); + jcp_dw.post_ops.entry_.resize(dw_po_len); + for (int i = 0; i < dw_po_len; ++i) { + CHECK(jcp_dw.post_ops.entry_[i].copy_from( + p.entry_[i + dw_conv_ind + 1])); + } + + jcp_dw.ch_block = simd_w; + jcp_dw.with_bias = true; + + jcp_dw.kh = p.entry_[dw_conv_ind].depthwise_conv_old.ker_h; + jcp_dw.kw = p.entry_[dw_conv_ind].depthwise_conv_old.ker_w; + jcp_dw.ic = jcp.oc; + jcp_dw.oc = jcp.oc; + jcp_dw.ih = p.entry_[dw_conv_ind].depthwise_conv_old.in_h; + jcp_dw.iw = p.entry_[dw_conv_ind].depthwise_conv_old.in_w; + jcp_dw.oh = jcp.dw_conv_oh; + jcp_dw.ow = jcp.dw_conv_ow; + jcp_dw.stride_h = p.entry_[dw_conv_ind].depthwise_conv_old.str_h; + jcp_dw.stride_w = p.entry_[dw_conv_ind].depthwise_conv_old.str_w; + + if (jcp_dw.kh != 3 || jcp_dw.kw != 3) + return status::unimplemented; + + if (!post_ops_ok(jcp_dw, attr)) + return status::unimplemented; + + jcp_dw.ur_w = 4; + + jcp_dw.src_dt = jcp.dst_dt; + jcp_dw.dst_dt = jcp.dw_conv_dst_dt; + jcp_dw.bia_dt = jcp.bia_dt == dnnl_data_type_undef ? dnnl_f32 : jcp.bia_dt; + jcp_dw.typesize_in = (int)types::data_type_size(jcp_dw.src_dt); + jcp_dw.typesize_bia = (int)types::data_type_size(jcp_dw.bia_dt); + jcp_dw.typesize_out = (int)types::data_type_size(jcp_dw.dst_dt); + + if (jcp_dw.src_dt != dnnl_f32 && jcp_dw.src_dt != dnnl_u8) + return status::unimplemented; + + return status::success; +} + +template struct jit_uni_dw_conv_row_f32; +template struct jit_uni_dw_conv_row_f32; +template struct jit_uni_dw_conv_row_f32; + +} // namespace x64 +} // namespace cpu +} // namespace impl +} // namespace dnnl diff --git a/src/cpu/x64/jit_uni_dw_conv_row_f32.hpp b/src/cpu/x64/jit_uni_dw_conv_row_f32.hpp new file mode 100644 index 00000000000..fe4a161133e --- /dev/null +++ b/src/cpu/x64/jit_uni_dw_conv_row_f32.hpp @@ -0,0 +1,157 @@ +/******************************************************************************* +* Copyright 2019-2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* [todo] antonvor: + * This file contains the old plugin behavior in order to fix performance + * problems after upgrading to OneDNN v1.6. 
This kernel is executed only on + * machines with avx2 instruction set support and in the case of a fused + * convolution. Remove after problems are fixed. +*/ + +#ifndef CPU_X64_JIT_UNI_DW_CONV_ROW_F32_HPP +#define CPU_X64_JIT_UNI_DW_CONV_ROW_F32_HPP + +#include + +#include "common/c_types_map.hpp" +#include "common/primitive_attr.hpp" +#include "common/type_helpers.hpp" +#include "common/utils.hpp" + +#include "cpu/x64/jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp" +#include "cpu/x64/injectors/jit_uni_depthwise_injector.hpp" +#include "cpu/x64/injectors/jit_uni_quantization_injector.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace x64 { + +template +struct jit_uni_dw_conv_row_f32: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_row_f32) + + jit_uni_dw_conv_row_f32(jit_conv_conf_t ajcp, const primitive_attr_t &attr, int ow_stride) + : jcp(ajcp), attr_(attr), ow_stride_(ow_stride) {} + + ~jit_uni_dw_conv_row_f32() { + for (auto inj : eltwise_injectors) + delete inj; + eltwise_injectors.clear(); + + for (auto inj : depthwise_injectors) + delete inj; + depthwise_injectors.clear(); + + for (auto inj : quantization_injectors) + delete inj; + quantization_injectors.clear(); + } + + static bool post_ops_ok(jit_conv_conf_t &jcp, + const primitive_attr_t &attr); + static status_t init_conf(jit_1x1_conv_conf_t &jcp, jit_conv_conf_t &jcp_dw, const primitive_attr_t &attr); + + jit_conv_conf_t jcp; + const primitive_attr_t &attr_; + int ow_stride_; + +private: + using Vmm = typename utils::conditional3::type; + using reg64_t = const Xbyak::Reg64; + using reg32_t = const Xbyak::Reg32; + using reg16_t = const Xbyak::Reg16; + using reg8_t = const Xbyak::Reg8; + const Xbyak::AddressFrame &vmmword = (isa == sse41) + ? xword : (isa == avx2) ? 
yword : zword; + const int vlen = cpu_isa_traits::vlen; + + // dw convolution + reg64_t reg_input0 = r8; + reg64_t reg_input1 = r9; + reg64_t reg_input2 = r10; + reg64_t aux_reg_input0 = r11; + reg64_t aux_reg_input1 = r12; + reg64_t aux_reg_input2 = r13; + + reg64_t reg_kernel = r14; + reg64_t aux_reg_kernel = r15; + reg64_t reg_output = rdx; + reg64_t reg_bias = rbx; + reg64_t reg_kh = rax; + reg64_t reg_ur_w = rbp; + reg64_t reg_oc_work = abi_not_param1; + + reg64_t reg_oc_off = rsi; + reg64_t reg_d_weights = aux_reg_input0; + reg64_t reg_d_bias = aux_reg_input1; + + reg64_t reg_b_weights = r15; + reg64_t reg_b_mask = reg_d_bias; + reg64_t reg_b_out_mask = rbx; + + reg32_t reg_tmp_32 = r11d; + reg64_t reg_tmp_64 = r11; + reg8_t reg_tmp_8 = r11b; + reg16_t reg_tmp_16 = r11w; + + reg32_t reg_tmp2_32 = r13d; + reg64_t reg_tmp2_64 = r13; + + inline Vmm get_ker_reg(int idx) { return Vmm(idx + 0); } + inline Vmm get_src_reg(int idx) { return Vmm(idx + 1); } + inline Vmm get_acc_reg(int idx) { return Vmm(idx + 4); } + + Xbyak::Ymm ymm_tmp = Xbyak::Ymm(0); + Vmm vmm_tmp = Vmm(0); + Vmm vmm_sum = Vmm(0); + Vmm vmm_bias = Vmm(0); + Vmm vmm_thr = Vmm(0); + Vmm vmm_out_mask = Vmm(1); + + Vmm vmm_d_weights = Vmm(0); + Vmm vmm_d_bias = Vmm(1); + + const unsigned char _cmp_gt_os = 6; + + Xbyak::Opmask ktail_mask = Xbyak::Opmask(2); + Xbyak::Opmask bin_mask0 = Xbyak::Opmask(5); + Xbyak::Opmask bin_mask1 = Xbyak::Opmask(6); + + inline void clear_vmm_regs(int ur_w); + inline void apply_filter(int ur_w, int kw_size); + inline void cvt2ps(data_type_t type_in, Vmm vmm_in, const Xbyak::Operand &op, bool scalar_load); + inline void apply_postprocessing(int ur_w, int oc_step); + inline void store_dst_typed(const Xbyak::Address &op, Vmm vmm_dst, bool scalar_store); + inline void store_dst(int ur_w, int oc_step); + inline void loop_body(int oc_step); + + void generate() override; + + nstl::vector*> eltwise_injectors; + nstl::vector*> depthwise_injectors; + nstl::vector*> quantization_injectors; +}; + +} // namespace x64 +} // namespace cpu +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/cpu/x64/jit_uni_dw_convolution.cpp b/src/cpu/x64/jit_uni_dw_convolution.cpp index 94268022c5e..c1989b3f1dc 100644 --- a/src/cpu/x64/jit_uni_dw_convolution.cpp +++ b/src/cpu/x64/jit_uni_dw_convolution.cpp @@ -42,6 +42,8 @@ void jit_uni_dw_convolution_fwd_t::execute_forward( const auto post_ops_binary_rhs_arg_vec = binary_injector::prepare_binary_args(jcp.post_ops, ctx); + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + const memory_desc_wrapper src_d(pd()->src_md()); const memory_desc_wrapper dst_d(pd()->dst_md()); const memory_desc_wrapper weights_d(pd()->weights_md(0)); @@ -79,7 +81,7 @@ void jit_uni_dw_convolution_fwd_t::execute_forward( const auto is_src_layout_nxc = jcp.src_tag == format_tag::nhwc; const auto is_dst_layout_nxc = jcp.dst_tag == format_tag::nhwc; - const int work_amount = jcp.mb * chb_work * jcp.oh; + const int work_amount = MB * chb_work * jcp.oh; const auto nthr = jcp.nthr; parallel(nthr, [&](const int ithr, const int nthr) { @@ -89,10 +91,10 @@ void jit_uni_dw_convolution_fwd_t::execute_forward( int n {0}, chb {0}, oh {0}; if (jcp.loop_order == loop_ngcw) utils::nd_iterator_init( - start, n, jcp.mb, chb, chb_work, oh, jcp.oh); + start, n, MB, chb, chb_work, oh, jcp.oh); else if (jcp.loop_order == loop_nhwcg) utils::nd_iterator_init( - start, n, jcp.mb, oh, jcp.oh, chb, chb_work); + start, n, MB, oh, jcp.oh, chb, chb_work); else assert(!"unsupported loop order"); @@ -144,14 +146,16 @@ void 
jit_uni_dw_convolution_fwd_t::execute_forward(
         par_conv.post_ops_binary_rhs_arg_vec
                 = post_ops_binary_rhs_arg_vec.data();
         par_conv.dst_orig = dst;
+        par_conv.oc_off = ch * jcp.ch_block * sizeof(float);
+
         (*kernel_)(&par_conv);
         if (jcp.loop_order == loop_ngcw) {
             ++iwork;
-            utils::nd_iterator_step(n, jcp.mb, chb, chb_work, oh, jcp.oh);
+            utils::nd_iterator_step(n, MB, chb, chb_work, oh, jcp.oh);
         } else if (jcp.loop_order == loop_nhwcg) {
             utils::nd_iterator_jump(
-                    iwork, end, n, jcp.mb, oh, jcp.oh, chb, chb_work);
+                    iwork, end, n, MB, oh, jcp.oh, chb, chb_work);
         } else
             assert(!"unsupported loop order");
     }
@@ -175,12 +179,16 @@ void jit_uni_dw_convolution_bwd_data_t<isa, diff_dst_type,
+    const auto &jcp = pd()->jcp_;
+    const auto post_ops_binary_rhs_arg_vec
+            = binary_injector::prepare_binary_args(jcp.post_ops, ctx);
+
+    auto MB = CTX_IN_BATCH(DNNL_ARG_DIFF_DST);
+
     const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
     const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
     const memory_desc_wrapper weights_d(pd()->weights_md(0));
-    const auto &jcp = pd()->jcp_;
-
     auto kernel_params = [&](int ur_str_w, int iw, int oh, int ih,
             int i_t_overflow, int i_b_overflow, int stride_off_h, int ch, int n,
@@ -221,13 +229,17 @@ void jit_uni_dw_convolution_bwd_data_t<isa, diff_dst_type,
         ...(jcp.oc), ch_work);
         par_conv.ch_blocks = load_work;
+        par_conv.ic_off = ch * jcp.ch_block * sizeof(float);
+        par_conv.post_ops_binary_rhs_arg_vec
+                = post_ops_binary_rhs_arg_vec.data();
+
         return par_conv;
     };
     const int aux_w
             = nstl::min(jcp.iw, jcp.iw - jcp.kw + jcp.r_pad + jcp.stride_w);
     const int chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking);
-    const dim_t work_amount = jcp.mb * chb_work * jcp.ih;
+    const dim_t work_amount = MB * chb_work * jcp.ih;
     const auto nthr = jcp.nthr;
     parallel(nthr, [&](const int ithr, const int nthr) {
@@ -236,10 +248,10 @@ void jit_uni_dw_convolution_bwd_data_t<isa, diff_dst_type,
...
diff --git a/src/cpu/x64/jit_uni_dw_convolution.hpp b/src/cpu/x64/jit_uni_dw_convolution.hpp
...
         CHECK(safe_ptr_assign(kernel_,
                 new jit_uni_dw_conv_fwd_kernel<isa, src_type>(
-                        pd()->jcp_, *pd()->dst_md(0))));
+                        pd()->jcp_, *pd()->dst_md(0), *pd()->attr())));
         return kernel_->create_kernel();
     }
@@ -119,13 +119,13 @@ struct jit_uni_dw_convolution_bwd_data_t : public primitive_t {
                 && set_default_alg_kind(alg_kind::convolution_direct)
                 && expect_data_types(diff_src_type, diff_dst_type,
                         data_type::undef, diff_dst_type, data_type::f32)
-                && attr()->has_default_values() && !has_zero_dim_memory();
+                && !has_zero_dim_memory();
             if (!ok) return status::unimplemented;
             status_t status = jit_uni_dw_conv_bwd_data_kernel<isa,
                     diff_dst_type>::init_conf(jcp_, *desc(), diff_src_md_,
-                    weights_md_, diff_dst_md_);
+                    weights_md_, diff_dst_md_, attr_);
             if (status != status::success) return status;
             auto scratchpad = scratchpad_registry().registrar();
@@ -147,7 +147,7 @@ struct jit_uni_dw_convolution_bwd_data_t : public primitive_t {
     status_t init(engine_t *engine) override {
         CHECK(safe_ptr_assign(kernel_,
                 new jit_uni_dw_conv_bwd_data_kernel<isa, diff_dst_type>(
-                        pd()->jcp_)));
+                        pd()->jcp_, *pd()->attr())));
         return kernel_->create_kernel();
     }
diff --git a/src/cpu/x64/jit_uni_eltwise_int.cpp b/src/cpu/x64/jit_uni_eltwise_int.cpp
index ffb63fdd352..ab6a4902caf 100644
--- a/src/cpu/x64/jit_uni_eltwise_int.cpp
+++ b/src/cpu/x64/jit_uni_eltwise_int.cpp
@@ -30,7 +30,7 @@ namespace x64 {
 using namespace Xbyak;
-struct jit_args_t {
+struct jit_args_int_t {
     const void *from;
     const void *for_comparison;
     const void *to;
@@ -40,7 +40,7 @@ struct jit_args_t {
 struct jit_uni_eltwise_int_kernel : public jit_generator {
     jit_uni_eltwise_int_kernel(const eltwise_desc_t &desc) : desc_(desc) {}
-    void operator()(jit_args_t *p) { jit_generator::operator()(p); }
+    void operator()(jit_args_int_t *p) { jit_generator::operator()(p); }
 protected:
data_type_t data_type() const { return desc_.data_desc.data_type; } @@ -83,7 +83,7 @@ struct jit_uni_subkernel_int_t : public jit_uni_eltwise_int_kernel { preamble(); -#define GET_OFF(field) offsetof(jit_args_t, field) +#define GET_OFF(field) offsetof(jit_args_int_t, field) mov(reg_from, ptr[param + GET_OFF(from)]); mov(reg_to, ptr[param + GET_OFF(to)]); mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]); @@ -449,7 +449,7 @@ status_t jit_uni_eltwise_int_fwd_t::execute_forward( start = nstl::min(nelems, start * cache_line); end = nstl::min(nelems, end * cache_line); - auto arg = jit_args_t(); + auto arg = jit_args_int_t(); arg.from = (const void *)&src[start]; arg.for_comparison = (const void *)&src[start]; arg.to = (const void *)&dst[start]; diff --git a/src/cpu/x64/jit_uni_fork_dw_conv_kernel_f32.cpp b/src/cpu/x64/jit_uni_fork_dw_conv_kernel_f32.cpp new file mode 100644 index 00000000000..ff9579142e0 --- /dev/null +++ b/src/cpu/x64/jit_uni_fork_dw_conv_kernel_f32.cpp @@ -0,0 +1,1024 @@ +/******************************************************************************* +* Copyright 2021 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "common/c_types_map.hpp" +#include "common/memory.hpp" +#include "common/nstl.hpp" +#include "common/type_helpers.hpp" +#include "common/utils.hpp" + +#include "cpu/x64/jit_uni_fork_dw_conv_kernel_f32.hpp" + +#define GET_OFF(field) offsetof(jit_conv_call_s, field) + +namespace dnnl { +namespace impl { +namespace cpu { +namespace x64 { + +using namespace dnnl::impl::prop_kind; +using namespace dnnl::impl::memory_tracking::names; +using namespace dnnl::impl::utils; + +using namespace Xbyak; + +static bool check_if_tail_load(const bool is_ch_tail, const int c_tail, const int ch, + const int ur_ch_blocks, const int vlen, const int i) { + return is_ch_tail && (ch + 1 == ur_ch_blocks) && ((i + 1) * vlen > c_tail); +} + + +template +void jit_uni_fork_dw_conv_fwd_kernel_f32::load_src(int ur_ch_blocks, int ur_w, bool is_ch_tail) { + const auto dst_layout_nxc = is_dst_layout_nxc(); + const auto ch_blk = jcp.ch_block; + const auto ocb_stride = dst_layout_nxc ? ch_blk : jcp.od * jcp.oh * jcp.ow * ch_blk; + const auto ow_stride = dst_layout_nxc ? 
jcp.ngroups : ch_blk; + const int vlen_numbers = cpu_isa_traits::vlen / sizeof(float); + const int c_tail = jcp.oc % jcp.ch_block; + + int repeats = jcp.ch_block / vlen_numbers; + assert((repeats == 1) || (repeats == 2 && isa == sse41)); + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + const bool is_tail_load = check_if_tail_load( + is_ch_tail, c_tail, ch, ur_ch_blocks, vlen_numbers, i); + if ((ch + 1 == ur_ch_blocks) && is_ch_tail && c_tail <= i * vlen_numbers) + continue; + for (int ow = 0; ow < ur_w; ow++) { + Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow); + + int b_off = ch*ch_blk + i*vlen_numbers; + if (this->jcp.with_bias) { + if (is_tail_load) { + load_tail(vmm_acc, reg_bias, b_off * sizeof(float), + (c_tail - i*vlen_numbers) * sizeof(float)); + } else { + uni_vmovups(vmm_acc, + vmmword[reg_bias + b_off * sizeof(float)]); + } + } else { + uni_vpxor(vmm_acc, vmm_acc, vmm_acc); + } + + int o_off = ch*ocb_stride + + ow*ow_stride + i*vlen_numbers; + if (this->jcp.with_sum) { + if (is_tail_load) { + if (this->jcp.with_bias) { + // using ker_vmm as vmm_tmp as it is safe to do so. + auto vmm_tmp = get_ker_reg(0); + add_tail_from_mem(vmm_acc, vmm_tmp, reg_output, + o_off * sizeof(float), + (c_tail - i*vlen_numbers) * sizeof(float)); + } else { + // nothing to add, just load dst. + load_tail(vmm_acc, reg_output, + o_off * sizeof(float), + c_tail * sizeof(float)); + } + } else { + // blocked layout has dst padded, so no tail handling. + uni_vaddps(vmm_acc, vmm_acc, + vmmword[reg_output + o_off*sizeof(float)]); + } + } + } + } + } +} + +template +void jit_uni_fork_dw_conv_fwd_kernel_f32::apply_filter( + int ur_ch_blocks, int ur_w, bool is_ch_tail) { + int ch_blk = jcp.ch_block; + int dilate_d = jcp.dilate_d + 1; + int dilate_h = jcp.dilate_h + 1; + int dilate_w = jcp.dilate_w + 1; + int stride_w = jcp.stride_w; + + const auto src_layout_nxc = is_src_layout_nxc(); + const auto iw_stride = src_layout_nxc ? jcp.ngroups : ch_blk; + const auto ih_stride = jcp.iw * iw_stride; + const auto icb_stride = src_layout_nxc + ? 
ch_blk + : jcp.id * jcp.ih * jcp.iw * ch_blk; + + Label iter_exit_label; + Label kd_label, iter_d_exit_label; + + if (jcp.ndims == 5) { + push(reg_kd); + mov(reg_kd, ptr[this->param1 + GET_OFF(kd_padding)]); + cmp(reg_kd, 0); + je(iter_d_exit_label, T_NEAR); + + push(reg_input); + push(reg_kernel); + base_post_ops_data_offset += 3 * reg64_size; + + mov(aux_reg_inp_d, aux_reg_input); + mov(aux_reg_ker_d, aux_reg_kernel); + + L(kd_label); + } + + mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); + cmp(reg_kh, 0); + je(iter_exit_label, T_NEAR); + cmp(reg_kw, 0); + je(iter_exit_label, T_NEAR); + + mov(iter_kh, reg_kh); + Label kh_label; + push(aux1_reg_kernel); + base_post_ops_data_offset += reg64_size; + L(kh_label); { + mov(iter_kw, reg_kw); + mov(aux1_reg_input, aux_reg_input); + mov(aux1_reg_kernel, aux_reg_kernel); + + Label kw_label; + L(kw_label); { + const int vlen_numbers = cpu_isa_traits::vlen / sizeof(float); + const int c_tail = jcp.oc % jcp.ch_block; + int repeats = jcp.ch_block / vlen_numbers; + assert((repeats == 1) || (repeats == 2 && isa == sse41)); + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + const bool is_tail_load = check_if_tail_load( + is_ch_tail, c_tail, ch, ur_ch_blocks, vlen_numbers, i); + if ((ch + 1 == ur_ch_blocks) && is_ch_tail + && c_tail <= i*vlen_numbers) + continue; + int ker_off = ch*jcp.kd*jcp.kh*jcp.kw*ch_blk + i*vlen_numbers; + Vmm vmm_ker = get_ker_reg(0); + uni_vmovups(vmm_ker, ptr[aux1_reg_kernel + + ker_off*sizeof(float)]); + + for (int ow = 0; ow < ur_w; ow++) { + int inp_off = ch*icb_stride + + ow*stride_w*iw_stride + i*vlen_numbers; + Vmm vmm_src = get_src_reg(0); + if (is_tail_load) { + load_tail(vmm_src, aux1_reg_input, + inp_off * sizeof(float), + (c_tail - i*vlen_numbers) * sizeof(float)); + } else { + uni_vmovups(vmm_src, + ptr[aux1_reg_input + inp_off*sizeof(float)]); + } + + Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + + ch*ur_w + ow); + uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker); + } + } + } + add(aux1_reg_kernel, ch_blk*sizeof(float)); + add(aux1_reg_input, iw_stride*dilate_w*sizeof(float)); + + dec(iter_kw); + cmp(iter_kw, 0); + jg(kw_label, T_NEAR); + } + add(aux_reg_kernel, jcp.kw*ch_blk*sizeof(float)); + add(aux_reg_input, ih_stride*dilate_h*sizeof(float)); + + dec(iter_kh); + cmp(iter_kh, 0); + jg(kh_label, T_NEAR); + pop(aux1_reg_kernel); + base_post_ops_data_offset -= reg64_size; + } + + L(iter_exit_label); + + if (jcp.ndims == 5) { + add(aux_reg_ker_d, jcp.kh*jcp.kw*ch_blk*sizeof(float)); + add(aux_reg_inp_d, jcp.ih*dilate_d*ih_stride*sizeof(float)); + + mov(aux_reg_input, aux_reg_inp_d); + mov(aux_reg_kernel, aux_reg_ker_d); + + dec(reg_kd); + cmp(reg_kd, 0); + jg(kd_label, T_NEAR); + + pop(reg_kernel); + pop(reg_input); + + L(iter_d_exit_label); + pop(reg_kd); + base_post_ops_data_offset -= 3 * reg64_size; + } +} + +template +void jit_uni_fork_dw_conv_fwd_kernel_f32::apply_filter_unrolled( + int ur_ch_blocks, int ur_w, bool is_ch_tail) { + int ch_blk = jcp.ch_block; + int dilate_d = jcp.dilate_d + 1; + int dilate_h = jcp.dilate_h + 1; + int dilate_w = jcp.dilate_w + 1; + int stride_w = jcp.stride_w; + + const auto src_layout_nxc = is_src_layout_nxc(); + const auto iw_stride = src_layout_nxc ? jcp.ngroups : ch_blk; + const auto ih_stride = jcp.iw * iw_stride; + const auto icb_stride = src_layout_nxc + ? 
ch_blk
+            : jcp.id * jcp.ih * jcp.iw * ch_blk;
+
+    Label iter_exit_label;
+    Label kd_label, iter_d_exit_label;
+
+    if (jcp.ndims == 5) {
+        push(reg_kd);
+        mov(reg_kd, ptr[this->param1 + GET_OFF(kd_padding)]);
+        cmp(reg_kd, 0);
+        je(iter_d_exit_label, T_NEAR);
+
+        push(reg_input);
+        push(reg_kernel);
+
+        base_post_ops_data_offset += 3 * reg64_size;
+
+        mov(aux_reg_inp_d, aux_reg_input);
+        mov(aux_reg_ker_d, aux_reg_kernel);
+
+        L(kd_label);
+    }
+
+    mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
+    cmp(reg_kh, 0);
+    je(iter_exit_label, T_NEAR);
+
+    mov(iter_kh, reg_kh);
+    Label kh_label;
+    L(kh_label); {
+        const int vlen_numbers = cpu_isa_traits<isa>::vlen / sizeof(float);
+        const int c_tail = jcp.oc % jcp.ch_block;
+        int repeats = jcp.ch_block / vlen_numbers;
+        assert((repeats == 1) || (repeats == 2 && isa == sse41));
+        for (int i = 0; i < repeats; i++) {
+            for (int ch = 0; ch < ur_ch_blocks; ch++) {
+                const bool is_tail_load = check_if_tail_load(
+                        is_ch_tail, c_tail, ch, ur_ch_blocks, vlen_numbers, i);
+                if ((ch + 1 == ur_ch_blocks) && is_ch_tail
+                        && c_tail <= i * vlen_numbers)
+                    continue;
+                for (int kw = 0; kw < jcp.kw; kw++) {
+                    int ker_off = ch*jcp.kd*jcp.kh*jcp.kw*ch_blk + kw*ch_blk + i*vlen_numbers;
+
+                    Vmm vmm_ker = get_ker_reg(0);
+                    uni_vmovups(vmm_ker, ptr[aux_reg_kernel
+                            + ker_off*sizeof(float)]);
+
+                    for (int ow = 0; ow < ur_w; ow++) {
+                        int inp_off = ch*icb_stride
+                                + ow*stride_w*iw_stride + kw*dilate_w*iw_stride + i*vlen_numbers;
+
+                        Vmm vmm_src = get_src_reg(0);
+                        if (is_tail_load) {
+                            load_tail(vmm_src, aux_reg_input,
+                                    inp_off * sizeof(float),
+                                    (c_tail - i*vlen_numbers) * sizeof(float));
+                        } else {
+                            uni_vmovups(vmm_src,
+                                    ptr[aux_reg_input + inp_off*sizeof(float)]);
+                        }
+
+                        Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w
+                                + ch*ur_w + ow);
+                        uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
+                    }
+                }
+            }
+        }
+
+        add(aux_reg_kernel, jcp.kw*ch_blk*sizeof(float));
+        add(aux_reg_input, ih_stride*dilate_h*sizeof(float));
+
+        dec(iter_kh);
+        cmp(iter_kh, 0);
+        jg(kh_label, T_NEAR);
+    }
+
+    L(iter_exit_label);
+
+    if (jcp.ndims == 5) {
+        add(aux_reg_ker_d, jcp.kh*jcp.kw*ch_blk*sizeof(float));
+        add(aux_reg_inp_d, jcp.ih*dilate_d*ih_stride*sizeof(float));
+
+        mov(aux_reg_input, aux_reg_inp_d);
+        mov(aux_reg_kernel, aux_reg_ker_d);
+
+        dec(reg_kd);
+        cmp(reg_kd, 0);
+        jg(kd_label, T_NEAR);
+
+        pop(reg_kernel);
+        pop(reg_input);
+
+        L(iter_d_exit_label);
+        pop(reg_kd);
+        base_post_ops_data_offset -= 3 * reg64_size;
+    }
+}
+
+template <cpu_isa_t isa>
+void jit_uni_fork_dw_conv_fwd_kernel_f32<isa>::apply_postprocess(int ur_ch_blocks, int ur_w) {
+    int repeats = isa == sse41 ? 2 : 1;
+
+    int eltwise_inj_idx = 0;
+    int depthwise_inj_idx = 0;
+    int quantization_inj_idx = 0;
+    std::size_t post_ops_data_offset = 0;
+    const auto &p = attr_.post_ops_;
+
+    for (int i = 0; i < p.len(); i++) {
+        auto& post_op = p.entry_[i];
+        if (post_op.is_eltwise()) {
+            int start_idx = get_acc_reg(0).getIdx();
+            int end_idx = get_acc_reg(repeats * ur_w * ur_ch_blocks).getIdx();
+
+            eltwise_injectors[eltwise_inj_idx]->compute_vector_range(start_idx, end_idx);
+            eltwise_inj_idx++;
+        } else if (post_op.is_depthwise()) {
+            push(aux_reg_blocks_offset);
+            base_post_ops_data_offset += reg64_size;
+            add(aux_reg_blocks_offset, ptr[this->param1 + GET_OFF(oc_off)]); // add offset of processed blocks
+
+            mov(reg_d_weights, ptr[this->rsp + base_post_ops_data_offset + post_ops_data_offset]);
+            add(reg_d_weights, aux_reg_blocks_offset);
+
+            for (int ch = 0; ch < ur_ch_blocks; ch++) {
+                for (int k = 0; k < repeats; k++) {
+                    int start_idx = get_acc_reg(k*ur_ch_blocks*ur_w + ur_w * ch).getIdx();
+                    int end_idx = get_acc_reg(k*ur_ch_blocks*ur_w + ur_w * ch + ur_w).getIdx();
+
+                    depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
+                            start_idx, end_idx, reg_d_weights, reg_d_weights);
+
+                    add(reg_d_weights, jcp.ch_block / repeats * sizeof(float));
+                }
+            }
+            pop(aux_reg_blocks_offset);
+            base_post_ops_data_offset -= reg64_size;
+
+            post_ops_data_offset += depthwise_injectors[depthwise_inj_idx]->memoryStep();
+            depthwise_inj_idx++;
+        } else if (post_op.is_quantization()) {
+            push(aux_reg_blocks_offset);
+            base_post_ops_data_offset += reg64_size;
+            add(aux_reg_blocks_offset, ptr[this->param1 + GET_OFF(oc_off)]); // add offset of processed blocks
+
+            const Xbyak::RegExp quant_arg_base = this->rsp + base_post_ops_data_offset + post_ops_data_offset;
+            quantization_injectors[quantization_inj_idx]->init_crop_ptrs(quant_arg_base, aux_reg_blocks_offset);
+            for (int ch = 0; ch < ur_ch_blocks; ch++) {
+                for (int k = 0; k < repeats; k++) {
+                    int s_idx = get_acc_reg(k*ur_ch_blocks*ur_w + ch*ur_w).getIdx();
+                    quantization_injectors[quantization_inj_idx]->compute_crop(s_idx, s_idx + ur_w,
+                            (k * (jcp.ch_block / 2) + ch * jcp.ch_block) * sizeof(float));
+                }
+            }
+
+            quantization_injectors[quantization_inj_idx]->init_input_scale_shift_ptrs(quant_arg_base, aux_reg_blocks_offset);
+            for (int ch = 0; ch < ur_ch_blocks; ch++) {
+                for (int k = 0; k < repeats; k++) {
+                    int s_idx = get_acc_reg(k*ur_ch_blocks*ur_w + ch*ur_w).getIdx();
+                    quantization_injectors[quantization_inj_idx]->compute_input_scale_shift(s_idx, s_idx + ur_w,
+                            (k * (jcp.ch_block / 2) + ch * jcp.ch_block) * sizeof(float), true);
+                }
+            }
+
+            quantization_injectors[quantization_inj_idx]->init_output_scale_shift_ptrs(quant_arg_base, aux_reg_blocks_offset);
+            for (int ch = 0; ch < ur_ch_blocks; ch++) {
+                for (int k = 0; k < repeats; k++) {
+                    int s_idx = get_acc_reg(k*ur_ch_blocks*ur_w + ch*ur_w).getIdx();
+                    quantization_injectors[quantization_inj_idx]->compute_output_scale_shift(s_idx, s_idx + ur_w,
+                            (k * (jcp.ch_block / 2) + ch * jcp.ch_block) * sizeof(float));
+                }
+            }
+            pop(aux_reg_blocks_offset);
+            base_post_ops_data_offset -= reg64_size;
+
+            post_ops_data_offset += quantization_injectors[quantization_inj_idx]->memoryStep();
+            quantization_inj_idx++;
+        }
+    }
+}
+
+template <cpu_isa_t isa>
+void jit_uni_fork_dw_conv_fwd_kernel_f32<isa>::load_tail(
+        Vmm &vmm, const Xbyak::Reg64 &reg, int64_t offset, int load_size) {
+    uni_vmovups(vmm | k_oc_tail_mask | T_z, ptr[reg + offset]);
+}
+
+template <>
+void jit_uni_fork_dw_conv_fwd_kernel_f32<avx2>::load_tail(
+        Vmm &vmm, const Xbyak::Reg64 &reg, int64_t offset, int load_size) {
+    load_bytes(vmm, reg, offset, load_size);
+}
+
+template <>
+void jit_uni_fork_dw_conv_fwd_kernel_f32<sse41>::load_tail(
+        Vmm &vmm, const Xbyak::Reg64 &reg, int64_t offset, int load_size) {
+    load_bytes(vmm, reg, offset, load_size);
+}
+
+template <cpu_isa_t isa>
+void jit_uni_fork_dw_conv_fwd_kernel_f32<isa>::add_tail_from_mem(Vmm &vmm_acc,
+        Vmm &vmm_tmp, const Xbyak::Reg64 &reg, int64_t offset, int load_size) {
+    uni_vaddps(vmm_acc | k_oc_tail_mask | T_z, vmm_acc, ptr[reg + offset]);
+}
+
+template <>
+void jit_uni_fork_dw_conv_fwd_kernel_f32<avx2>::add_tail_from_mem(Vmm &vmm_acc,
+        Vmm &vmm_tmp, const Xbyak::Reg64 &reg, int64_t offset, int load_size) {
+    load_bytes(vmm_tmp, reg, offset, load_size);
+    uni_vaddps(vmm_acc, vmm_acc, vmm_tmp);
+}
+
+template <>
+void jit_uni_fork_dw_conv_fwd_kernel_f32<sse41>::add_tail_from_mem(Vmm &vmm_acc,
+        Vmm &vmm_tmp, const Xbyak::Reg64 &reg, int64_t offset, int load_size) {
+    load_bytes(vmm_tmp, reg, offset, load_size);
+    uni_vaddps(vmm_acc, vmm_acc, vmm_tmp);
+}
+
+template <cpu_isa_t isa>
+void jit_uni_fork_dw_conv_fwd_kernel_f32<isa>::store_tail(
+        Vmm &vmm, const Xbyak::Reg64 &reg, int64_t offset, int store_size) {
+    uni_vmovups(vmmword[reg + offset], vmm | k_oc_tail_mask);
+}
+
+template <>
+void jit_uni_fork_dw_conv_fwd_kernel_f32<avx2>::store_tail(
+        Vmm &vmm, const Xbyak::Reg64 &reg, int64_t offset, int store_size) {
+    store_bytes(vmm, reg, offset, store_size);
+}
+
+template <>
+void jit_uni_fork_dw_conv_fwd_kernel_f32<sse41>::store_tail(
+        Vmm &vmm, const Xbyak::Reg64 &reg, int64_t offset, int store_size) {
+    store_bytes(vmm, reg, offset, store_size);
+}
+
+
+template <cpu_isa_t isa>
+void jit_uni_fork_dw_conv_fwd_kernel_f32<isa>::store_dst(
+        int ur_ch_blocks, int ur_w, bool is_ch_tail) {
+    const auto dst_layout_nxc = is_dst_layout_nxc();
+    const auto ch_blk = jcp.ch_block;
+    const auto ocb_stride = dst_layout_nxc ? ch_blk : jcp.od * jcp.oh * jcp.ow * ch_blk;
+    const auto ow_stride = dst_layout_nxc ? jcp.ngroups : ch_blk;
+    const int vlen_numbers = cpu_isa_traits<isa>::vlen / sizeof(float);
+    const int c_tail = jcp.oc_without_padding % jcp.ch_block;
+
+    int repeats = jcp.ch_block / vlen_numbers;
+    assert((repeats == 1) || (repeats == 2 && isa == sse41));
+    for (int i = 0; i < repeats; i++) {
+        for (int ch = 0; ch < ur_ch_blocks; ch++) {
+            const bool is_tail_load = check_if_tail_load(
+                    is_ch_tail, c_tail, ch, ur_ch_blocks, vlen_numbers, i);
+            if ((ch + 1 == ur_ch_blocks) && is_ch_tail && c_tail <= i * vlen_numbers)
+                continue;
+            for (int ow = 0; ow < ur_w; ow++) {
+                int o_off = ch*ocb_stride + ow*ow_stride + i*vlen_numbers;
+                Vmm vmm_dst = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow);
+
+                if (is_tail_load) {
+                    store_tail(vmm_dst, reg_output, o_off * sizeof(float),
+                            (c_tail - i*vlen_numbers) * sizeof(float));
+                } else
+                    uni_vmovups(vmmword[reg_output + o_off*sizeof(float)], vmm_dst);
+            }
+        }
+    }
+}
+
+template <cpu_isa_t isa>
+void jit_uni_fork_dw_conv_fwd_kernel_f32<isa>::compute_loop(int ur_w, int ur_ch_blocks) {
+    const bool ch_loop = ur_ch_blocks > jcp.nb_ch_blocking;
+    // ch_loop currently happens only when data layout is nxc. The strides are
+    // calculated for this layout only.
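+    // Per iteration of that channel-blocks loop the pointers below advance by
+    // jcp.nb_ch_blocking channel blocks at once: the weights by that many
+    // whole kd*kh*kw*ch_block filters, src/dst/bias by nb_ch_blocking*ch_block
+    // contiguous floats, since nxc keeps channels innermost. E.g. for avx2
+    // (ch_block = 8, nb_ch_blocking = 3) and a 3x3 2D filter: wei_ch_stride
+    // = 3*1*3*3*8*4 = 864 bytes, while the data/bias strides = 3*8*4 = 96 bytes.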
+ const size_t wei_ch_stride = (size_t)jcp.nb_ch_blocking * jcp.kd * jcp.kh * jcp.kw + * jcp.ch_block * sizeof(float); + const size_t inp_ch_stride + = (size_t)jcp.nb_ch_blocking * jcp.ch_block * sizeof(float); + const size_t out_ch_stride + = (size_t)jcp.nb_ch_blocking * jcp.ch_block * sizeof(float); + const size_t bias_stride + = (size_t)jcp.nb_ch_blocking * jcp.ch_block * sizeof(float); + + auto compute = [&](int ur_ch_blocks, bool is_ch_tail) { + mov(aux_reg_input, reg_input); + mov(aux_reg_kernel, reg_kernel); + + load_src(ur_ch_blocks, ur_w, is_ch_tail); + if (ur_w == 1) { + apply_filter(ur_ch_blocks, ur_w, is_ch_tail); + } else { + apply_filter_unrolled(ur_ch_blocks, ur_w, is_ch_tail); + } + apply_postprocess(ur_ch_blocks, ur_w); + store_dst(ur_ch_blocks, ur_w, is_ch_tail); + }; + + xor_(aux_reg_blocks_offset, aux_reg_blocks_offset); + + if (ch_loop) { + Label ch_loop_label, ch_tail_label, skip_ch_tail_label; + const int ch_block_tail = jcp.nb_ch + - (utils::rnd_dn(jcp.oc / jcp.ch_block, jcp.nb_ch_blocking)); + const int ch_step = jcp.nb_ch_blocking * jcp.ch_block; + + push(aux_reg_ch_blocks); + mov(aux_reg_ch_blocks, reg_ch_blocks); + push(reg_kernel); + push(reg_input); + push(reg_output); + base_post_ops_data_offset += 4 * reg64_size; + if (jcp.with_bias) { + push(reg_bias); + base_post_ops_data_offset += reg64_size; + } + + if ((jcp.oc / jcp.ch_block) >= jcp.nb_ch_blocking) { + if (ch_block_tail) { + cmp(aux_reg_ch_blocks, ch_step); + jl(ch_tail_label, T_NEAR); + } + + L(ch_loop_label); + { + compute(jcp.nb_ch_blocking, false); + add(reg_kernel, wei_ch_stride); + add(reg_input, inp_ch_stride); + add(reg_output, out_ch_stride); + if (jcp.with_bias) add(reg_bias, bias_stride); + sub(aux_reg_ch_blocks, ch_step); + add(aux_reg_blocks_offset, ch_step * sizeof(float)); //add initial offset of processed blocks + cmp(aux_reg_ch_blocks, ch_step); + jge(ch_loop_label, T_NEAR); + } + } + + if (ch_block_tail) { + // ch work range [1, jcp.nb_ch_blocking * ch_block) + L(ch_tail_label); + cmp(aux_reg_ch_blocks, 0); + jle(skip_ch_tail_label, T_NEAR); + compute(ch_block_tail, jcp.oc % jcp.ch_block); + L(skip_ch_tail_label); + } + + if (jcp.with_bias) { + pop(reg_bias); + base_post_ops_data_offset -= reg64_size; + } + pop(reg_output); + pop(reg_input); + pop(reg_kernel); + pop(aux_reg_ch_blocks); + base_post_ops_data_offset -= 4 * reg64_size; + + } else { + compute(ur_ch_blocks, jcp.oc % jcp.ch_block); + } +} + +template +void jit_uni_fork_dw_conv_fwd_kernel_f32::loop_body(int ur_ch_blocks) { + Label unrolled_w_label; + Label tail_w_label; + Label exit_label; + + const auto src_layout_nxc = is_src_layout_nxc(); + const auto dat_c_stride = src_layout_nxc ? 
jcp.ngroups : jcp.ch_block; + + L(unrolled_w_label); { + int ur_w = jcp.ur_w; + + size_t inp_shift = sizeof(float) * ur_w * jcp.stride_w * dat_c_stride; + size_t out_shift = sizeof(float) * ur_w * dat_c_stride; + + cmp(reg_ur_w, ur_w); + jl(tail_w_label, T_NEAR); + + compute_loop(ur_w, ur_ch_blocks); + + add(reg_input, inp_shift); + add(reg_output, out_shift); + + sub(reg_ur_w, ur_w); + jmp(unrolled_w_label); + } + + L(tail_w_label); { + int ur_w = 1; + + size_t inp_shift = sizeof(float) * ur_w * jcp.stride_w * dat_c_stride; + size_t out_shift = sizeof(float) * ur_w * dat_c_stride; + + cmp(reg_ur_w, ur_w); + jl(exit_label, T_NEAR); + + compute_loop(ur_w, ur_ch_blocks); + + add(reg_input, inp_shift); + add(reg_output, out_shift); + + sub(reg_ur_w, ur_w); + jmp(tail_w_label); + } + + L(exit_label); +} + +template +void jit_uni_fork_dw_conv_fwd_kernel_f32::generate() { + const auto &p = attr_.post_ops_; + for (int i = 0; i < p.len(); i++) { + auto &post_op = p.entry_[i]; + if (post_op.is_eltwise()) { + eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32( + this, + post_op.eltwise + )); + } else if (post_op.is_depthwise()) { + depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32( + this, + post_op + )); + } else if (post_op.is_quantization()) { + quantization_injectors.push_back(new jit_uni_quantization_injector_f32( + this, + post_op, + vmm_d_weights, vmm_d_bias, reg_d_weights, reg_d_bias + )); + } + } + + this->preamble(); + + std::size_t post_ops_pointers_count = 0; + for (int i = 0; i < p.len(); i++) { + if (p.entry_[i].is_depthwise() || p.entry_[i].is_quantization()) { + post_ops_pointers_count++; + } + } + + if (post_ops_pointers_count != 0) { + sub(rsp, post_ops_pointers_count * sizeof(float *)); + + auto aux_reg0 = reg_input; + auto aux_reg1 = reg_output; + + mov(aux_reg0, ptr[this->param1 + GET_OFF(post_ops_binary_rhs_arg_vec)]); + for (size_t i = 0; i < post_ops_pointers_count; i++) { + mov(aux_reg1, ptr[aux_reg0 + i * sizeof(float *)]); + mov(ptr[rsp + i * sizeof(float *)], aux_reg1); + } + } + + mov(reg_input, ptr[this->param1 + GET_OFF(src)]); + mov(reg_output, ptr[this->param1 + GET_OFF(dst)]); + mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); + if (jcp.with_bias) + mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]); + mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]); + mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(load_work)]); + mov(reg_ur_w, ptr[this->param1 + GET_OFF(ur_w)]); + + Label ch_blocks_tail_label; + Label exit_label; + + int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking; + if (isa & avx512_common_bit) { + const auto oc_tail = jcp.oc_without_padding % jcp.ch_block; + if (oc_tail != 0) { + // Prepare masks for tailing + const int oc_tail_shift + = jcp.ch_block - jcp.oc_without_padding % jcp.ch_block; + static constexpr auto zmm_full_mask = ((1 << 16) - 1); + Reg32 reg_tail_32 = reg_tail.cvt32(); + mov(reg_tail_32, (zmm_full_mask >> oc_tail_shift)); + kmovw(k_oc_tail_mask, reg_tail_32); + } + } + + if (is_src_layout_nxc()) { + loop_body(jcp.nb_ch); + } else { + cmp(reg_ch_blocks, (jcp.nb_ch_blocking - 1) * jcp.ch_block); + jle(ch_blocks_tail ? 
ch_blocks_tail_label : exit_label, T_NEAR); + + loop_body(jcp.nb_ch_blocking); // channel main loop + + if (ch_blocks_tail) { + jmp(exit_label, T_NEAR); + L(ch_blocks_tail_label); + loop_body(ch_blocks_tail); // channel tail loop + } + + L(exit_label); + } + + if (post_ops_pointers_count != 0) { + add(rsp, post_ops_pointers_count * sizeof(float *)); + } + + this->postamble(); + + for (auto& inj : eltwise_injectors) + inj->prepare_table(); +} + +template struct jit_uni_fork_dw_conv_fwd_kernel_f32; +template struct jit_uni_fork_dw_conv_fwd_kernel_f32; +template struct jit_uni_fork_dw_conv_fwd_kernel_f32; + +template +inline void jit_uni_fork_dw_conv_bwd_data_kernel_f32::load_ddst( + int ur_ch_blocks, int ur_str_w) { + int repeats = isa == sse41 ? 2 : 1; + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + for (int w = 0; w < ur_str_w; w++) { + Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w + + ch*ur_str_w + w); + uni_vpxor(vmm_acc, vmm_acc, vmm_acc); + } + } + } +} + +template +inline void jit_uni_fork_dw_conv_bwd_data_kernel_f32::apply_filter( + int ur_ch_blocks, int ur_str_w) { + int kw = jcp.kw; + int kh = jcp.kh; + int ow = jcp.ow; + int oh = jcp.oh; + + int ch_blk = jcp.ch_block; + int stride_h = jcp.stride_h; + int stride_w = jcp.stride_w; + + Label iter_exit_label; + + cmp(reg_kh, 0); + je(iter_exit_label, T_NEAR); + + cmp(reg_kw, 0); + je(iter_exit_label, T_NEAR); + + mov(iter_kh, reg_kh); + Label kh_label; + L(kh_label); { + mov(aux1_reg_ddst, aux_reg_ddst); + mov(aux1_reg_kernel, aux_reg_kernel); + + mov(iter_kw, reg_kw); + Label kw_label; + L(kw_label); { + int repeats = isa == sse41 ? 2 : 1; + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + int ker_off = ch*kh*kw*ch_blk + i*4; + Vmm vmm_ker = get_ker_reg(0); + uni_vmovups(vmm_ker, ptr[aux1_reg_kernel + + ker_off*sizeof(float)]); + + for (int w = 0; w < ur_str_w; w++) { + int ddst_off = (ch*oh*ow + w)*ch_blk + i*4; + + Vmm vmm_src = get_src_reg(0); + uni_vmovups(vmm_src, ptr[aux1_reg_ddst + + ddst_off*sizeof(float)]); + + Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w + + ch*ur_str_w + w); + uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker); + } + } + } + + add(aux1_reg_kernel, ch_blk*stride_w*sizeof(float)); + sub(aux1_reg_ddst, ch_blk*sizeof(float)); + + sub(iter_kw, stride_w); + cmp(iter_kw, 0); + jg(kw_label, T_NEAR); + } + + add(aux_reg_kernel, kw*ch_blk*stride_h*sizeof(float)); + sub(aux_reg_ddst, ow*ch_blk*sizeof(float)); + + sub(iter_kh, stride_h); + cmp(iter_kh, 0); + jg(kh_label, T_NEAR); + } + + L(iter_exit_label); +} + +template +void jit_uni_fork_dw_conv_bwd_data_kernel_f32::apply_postprocess(int ur_ch_blocks, int ur_str_w) { + int repeats = isa == sse41 ? 
2 : 1; + + const auto &p = attr_.post_ops_; + std::size_t post_ops_data_offset = 0; + int depthwise_inj_idx = 0; + for (int i = 0; i < p.len(); i++) { + auto& post_op = p.entry_[i]; + if (post_op.is_depthwise()) { + mov(reg_d_weights, ptr[this->rsp + post_ops_data_offset]); + add(reg_d_weights, ptr[this->param1 + GET_OFF(ic_off)]); + + for (int ch = 0; ch < ur_ch_blocks; ch++) { + for (int k = 0; k < repeats; k++) { + int start_idx = get_acc_reg(k*ur_ch_blocks*ur_str_w + ur_str_w * ch).getIdx(); + int end_idx = get_acc_reg(k*ur_ch_blocks*ur_str_w + ur_str_w * ch + ur_str_w).getIdx(); + + depthwise_injectors[depthwise_inj_idx]->compute_vector_range(start_idx, end_idx, reg_d_weights, reg_d_weights); + + add(reg_d_weights, jcp.ch_block / repeats * sizeof(float)); + add(reg_d_bias, jcp.ch_block / repeats * sizeof(float)); + } + } + post_ops_data_offset += depthwise_injectors[depthwise_inj_idx]->memoryStep(); + depthwise_inj_idx++; + } + } +} + +template +inline void jit_uni_fork_dw_conv_bwd_data_kernel_f32::store_dsrc( + int ur_ch_blocks, int ur_str_w) { + int ch_blk = jcp.ch_block; + int iw = jcp.iw; + int ih = jcp.ih; + int stride_w = jcp.stride_w; + + int repeats = isa == sse41 ? 2 : 1; + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + for (int w = 0; w < ur_str_w; w++) { + int dsrc_off = (ch*ih*iw + w*stride_w)*ch_blk + i*4; + Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w + + ch*ur_str_w + w); + + uni_vmovups(ptr[reg_dsrc + dsrc_off*sizeof(float)], vmm_acc); + } + } + } +} + +template +inline void jit_uni_fork_dw_conv_bwd_data_kernel_f32::loop_body( + int ur_ch_blocks) { + Label unrolled_w_label; + Label tail_w_label; + Label exit_label; + + L(unrolled_w_label); { + int ur_w = jcp.ur_w; + + cmp(reg_ur_str_w, ur_w); + jl(tail_w_label, T_NEAR); + + mov(aux_reg_ddst, reg_ddst); + mov(aux_reg_kernel, reg_kernel); + + load_ddst(ur_ch_blocks, ur_w); + apply_filter(ur_ch_blocks, ur_w); + apply_postprocess(ur_ch_blocks, ur_w); + store_dsrc(ur_ch_blocks, ur_w); + + add(reg_dsrc, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w); + add(reg_ddst, sizeof(float) * ur_w * jcp.ch_block); + + sub(reg_ur_str_w, ur_w); + jmp(unrolled_w_label); + } + + L(tail_w_label); { + int ur_w = 1; + + cmp(reg_ur_str_w, ur_w); + jl(exit_label, T_NEAR); + + mov(aux_reg_ddst, reg_ddst); + mov(aux_reg_kernel, reg_kernel); + + load_ddst(ur_ch_blocks, ur_w); + apply_filter(ur_ch_blocks, ur_w); + apply_postprocess(ur_ch_blocks, ur_w); + store_dsrc(ur_ch_blocks, ur_w); + + add(reg_dsrc, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w); + add(reg_ddst, sizeof(float) * ur_w * jcp.ch_block); + + sub(reg_ur_str_w, ur_w); + jmp(tail_w_label); + } + + L(exit_label); +} + +template +void jit_uni_fork_dw_conv_bwd_data_kernel_f32::generate() { + const auto &p = attr_.post_ops_; + for (int i = 0; i < p.len(); i++) { + auto &post_op = p.entry_[i]; + if (post_op.is_depthwise()) { + depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32( + this, + post_op + )); + } + } + + preamble(); + + std::size_t post_ops_pointers_count = 0; + for (int i = 0; i < p.len(); i++) { + if (p.entry_[i].is_depthwise() || p.entry_[i].is_quantization()) { + post_ops_pointers_count++; + } + } + + if (post_ops_pointers_count != 0) { + sub(rsp, post_ops_pointers_count * sizeof(float *)); + + auto aux_reg0 = reg_dsrc; + auto aux_reg1 = reg_ddst; + + mov(aux_reg0, ptr[this->param1 + GET_OFF(post_ops_binary_rhs_arg_vec)]); + for (size_t i = 0; i < post_ops_pointers_count; i++) { + mov(aux_reg1, ptr[aux_reg0 + i * 
sizeof(float *)]);
+            mov(ptr[rsp + i * sizeof(float *)], aux_reg1);
+        }
+    }
+
+    mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]);
+    mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]);
+    mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
+    mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
+    mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]);
+    mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(ch_blocks)]);
+    mov(reg_ur_str_w, ptr[this->param1 + GET_OFF(ur_str_w)]);
+
+    Label ch_blocks_tail_label;
+    Label exit_label;
+
+    int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking;
+
+    cmp(reg_ch_blocks, jcp.nb_ch_blocking);
+    jne(ch_blocks_tail ? ch_blocks_tail_label : exit_label, T_NEAR);
+
+    loop_body(jcp.nb_ch_blocking); // channel main loop
+
+    if (ch_blocks_tail) {
+        L(ch_blocks_tail_label);
+
+        cmp(reg_ch_blocks, ch_blocks_tail);
+        jne(exit_label, T_NEAR);
+
+        loop_body(ch_blocks_tail); // channel tail loop
+    }
+
+    L(exit_label);
+
+    if (post_ops_pointers_count != 0) {
+        add(rsp, post_ops_pointers_count * sizeof(float *));
+    }
+
+    this->postamble();
+}
+
+template struct jit_uni_fork_dw_conv_bwd_data_kernel_f32<avx512_common>;
+template struct jit_uni_fork_dw_conv_bwd_data_kernel_f32<avx2>;
+template struct jit_uni_fork_dw_conv_bwd_data_kernel_f32<sse41>;
+
+}
+}
+}
+}
diff --git a/src/cpu/x64/jit_uni_fork_dw_conv_kernel_f32.hpp b/src/cpu/x64/jit_uni_fork_dw_conv_kernel_f32.hpp
new file mode 100644
index 00000000000..20946163739
--- /dev/null
+++ b/src/cpu/x64/jit_uni_fork_dw_conv_kernel_f32.hpp
@@ -0,0 +1,197 @@
+/*******************************************************************************
+* Copyright 2021 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_X64_JIT_UNI_FORK_DW_CONV_KERNEL_HPP
+#define CPU_X64_JIT_UNI_FORK_DW_CONV_KERNEL_HPP
+
+#include "common/c_types_map.hpp"
+#include "common/memory_tracking.hpp"
+
+#include "cpu/x64/jit_generator.hpp"
+#include "cpu/x64/jit_primitive_conf.hpp"
+#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
+#include "cpu/x64/injectors/jit_uni_depthwise_injector.hpp"
+#include "cpu/x64/injectors/jit_uni_quantization_injector.hpp"
+
+namespace dnnl {
+namespace impl {
+namespace cpu {
+namespace x64 {
+
+template <cpu_isa_t isa>
+struct jit_uni_fork_dw_conv_fwd_kernel_f32 : public jit_generator {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_fork_dw_conv_fwd_kernel_f32)
+
+    jit_uni_fork_dw_conv_fwd_kernel_f32(const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md, const primitive_attr_t &attr)
+        : jcp(ajcp), attr_(attr) {
+    }
+
+    ~jit_uni_fork_dw_conv_fwd_kernel_f32() {
+        for (auto inj : eltwise_injectors)
+            delete inj;
+        eltwise_injectors.clear();
+
+        for (auto inj : depthwise_injectors)
+            delete inj;
+        depthwise_injectors.clear();
+
+        for (auto inj : quantization_injectors)
+            delete inj;
+        quantization_injectors.clear();
+    }
+
+    jit_conv_conf_t jcp;
+    const primitive_attr_t &attr_;
+
+private:
+    using Vmm = typename utils::conditional3<isa == sse41, Xbyak::Xmm, isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
+    using mask_t = const Xbyak::Opmask;
+    using reg64_t = const Xbyak::Reg64;
+    const Xbyak::AddressFrame &vmmword = (isa == sse41)
+        ? xword : (isa == avx2) ? yword : zword;
+    const int vlen = cpu_isa_traits<isa>::vlen;
+
+    // dw convolution
+    reg64_t reg_input = r8;
+    reg64_t aux_reg_input = r9;
+    reg64_t aux1_reg_input = r10;
+    reg64_t reg_kernel = r11;
+    reg64_t aux_reg_kernel = r12;
+    reg64_t reg_ch_blocks = r13;
+    reg64_t reg_output = r14;
+    reg64_t reg_bias = r15;
+    reg64_t reg_tail = rax;
+    reg64_t reg_kw = rbx;
+    reg64_t iter_kh = rdx;
+    reg64_t iter_kw = rsi;
+    reg64_t reg_ur_w = rbp;
+    reg64_t reg_kh = reg_tail;
+    reg64_t aux1_reg_kernel = reg_ch_blocks;
+    reg64_t imm_addr64 = aux1_reg_input;
+    reg64_t aux_reg_ch_blocks = reg_ur_w;
+    reg64_t aux_reg_blocks_offset = abi_not_param1;
+
+    reg64_t reg_d_weights = imm_addr64;
+    reg64_t reg_d_bias = iter_kh;
+    int base_post_ops_data_offset = 0;
+    constexpr static int reg64_size = 8;
+
+    reg64_t reg_kd = aux_reg_blocks_offset;
+    reg64_t aux_reg_inp_d = reg_input;
+    reg64_t aux_reg_ker_d = reg_kernel;
+
+    mask_t k_oc_tail_mask = Xbyak::Opmask(2);
+
+    Vmm vmm_d_weights = Vmm(0);
+    Vmm vmm_d_bias = Vmm(1);
+
+    inline Vmm get_ker_reg(int idx) { return Vmm(idx + 0); }
+    inline Vmm get_src_reg(int idx) { return Vmm(idx + 1); }
+    inline Vmm get_acc_reg(int idx) { return Vmm(idx + 4); }
+
+    inline bool is_src_layout_nxc() {
+        return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc,
+                format_tag::nwc);
+    }
+    inline bool is_dst_layout_nxc() {
+        return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc,
+                format_tag::nwc);
+    }
+
+    inline void load_src(int ur_ch_blocks, int ur_w, bool is_ch_tail);
+    inline void compute_loop(int ur_w, int ur_ch_blocks);
+    inline void apply_filter(int ur_ch_blocks, int ur_w, bool is_ch_tail);
+    inline void apply_filter_unrolled(int ur_ch_blocks, int ur_w, bool is_ch_tail);
+    inline void apply_postprocess(int ur_ch_blocks, int ur_w);
+    inline void store_dst(int ur_ch_blocks, int ur_w, bool is_ch_tail);
+    inline void loop_body(int ur_ch_blocks);
+
+    void load_tail(
+            Vmm &vmm, const Xbyak::Reg64 &reg, int64_t offset, int load_size);
+    void add_tail_from_mem(Vmm &vmm_acc, Vmm &vmm_tmp, const 
Xbyak::Reg64 ®, + int64_t offset, int load_size); + void store_tail( + Vmm &vmm, const Xbyak::Reg64 ®, int64_t offset, int store_size); + + void generate() override; + + nstl::vector*> eltwise_injectors; + nstl::vector*> depthwise_injectors; + nstl::vector*> quantization_injectors; +}; + +template +struct jit_uni_fork_dw_conv_bwd_data_kernel_f32: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_fork_dw_conv_bwd_data_kernel_f32) + + jit_uni_fork_dw_conv_bwd_data_kernel_f32(const jit_conv_conf_t &ajcp, const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr) {} + + ~jit_uni_fork_dw_conv_bwd_data_kernel_f32() { + for (auto inj : depthwise_injectors) + delete inj; + depthwise_injectors.clear(); + } + + jit_conv_conf_t jcp; + const primitive_attr_t &attr_; + +private: + using Vmm = typename utils::conditional3::type; + using reg64_t = const Xbyak::Reg64; + + inline Vmm get_ker_reg(int idx) { return Vmm(idx + 0); } + inline Vmm get_src_reg(int idx) { return Vmm(idx + 1); } + inline Vmm get_acc_reg(int idx) { return Vmm(idx + 4); } + + reg64_t reg_ddst = rax; + reg64_t aux_reg_ddst = r8; + reg64_t aux1_reg_ddst = abi_not_param1; + reg64_t reg_kernel = rdx; + reg64_t aux_reg_kernel = r10; + reg64_t aux1_reg_kernel = rbp; + reg64_t reg_dsrc = rsi; + + reg64_t reg_ur_str_w = r9; + reg64_t reg_ch_blocks = rbx; + + reg64_t iter_kh = r11; + reg64_t iter_kw = r12; + reg64_t reg_kh = r13; + reg64_t reg_kw = r14; + + reg64_t reg_d_weights = r15; + reg64_t reg_d_bias = iter_kh; + + inline void loop_body(int ur_ch_blocks); + inline void load_ddst(int ur_ch_blocks, int ur_str_w); + inline void apply_filter(int ur_ch_blocks, int ur_str_w); + inline void apply_postprocess(int ur_ch_blocks, int ur_str_w); + inline void store_dsrc(int ur_ch_blocks, int ur_str_w); + + void generate() override; + + nstl::vector*> depthwise_injectors; +}; + +} +} +} +} + +#endif diff --git a/src/cpu/x64/jit_uni_fork_dw_conv_kernel_utils.hpp b/src/cpu/x64/jit_uni_fork_dw_conv_kernel_utils.hpp new file mode 100644 index 00000000000..36595216acc --- /dev/null +++ b/src/cpu/x64/jit_uni_fork_dw_conv_kernel_utils.hpp @@ -0,0 +1,436 @@ +/******************************************************************************* +* Copyright 2019-2021 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_X64_JIT_UNI_FORK_DW_CONV_KERNEL_UTILS_HPP +#define CPU_X64_JIT_UNI_FORK_DW_CONV_KERNEL_UTILS_HPP + +#include "common/nstl.hpp" +#include "common/type_helpers.hpp" +#include "common/utils.hpp" + +#include "common/c_types_map.hpp" +#include "common/memory_tracking.hpp" + +#include "cpu/x64/jit_generator.hpp" +#include "cpu/x64/jit_primitive_conf.hpp" +#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp" + +#include "cpu/x64/jit_avx512_core_fork_bf16_dw_conv_kernel.hpp" +#include "cpu/x64/jit_uni_fork_dw_conv_kernel_f32.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace x64 { + +template +struct jit_uni_fork_dw_conv_fwd_kernel { + + jit_uni_fork_dw_conv_fwd_kernel(const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md, const primitive_attr_t &attr) : ker_(nullptr) { + ker_ = new jit_kernel_t(ajcp, dst_md, attr); + } + + status_t create_kernel() { return ker_->create_kernel(); } + ~jit_uni_fork_dw_conv_fwd_kernel() { delete ker_; } + + static bool post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr); + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, memory_desc_t &src_md, + memory_desc_t &weights_md, memory_desc_t &bias_md, + memory_desc_t &dst_md, const primitive_attr_t &attr); + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + jit_generator *ker() const { return ker_; } + void operator()(const jit_conv_call_s *p) const { (*ker_)(p); } + +private: + using jit_kernel_t = typename utils::conditional>::type; + jit_kernel_t *ker_; +}; + +template +bool jit_uni_fork_dw_conv_fwd_kernel::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + auto all_post_ops_supported = [&]() { + bool ok = true; + + for (int i = 0; i < p.len(); i++) { + ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::sum, primitive_kind::eltwise, primitive_kind::depthwise, primitive_kind::quantization); + } + return ok; + }; + auto contain = [&](dnnl::impl::primitive_kind_t kind) { return p.find(kind) != -1; }; + auto position = [&](dnnl::impl::primitive_kind_t kind) { return p.find(kind); }; + auto count = [&](dnnl::impl::primitive_kind_t kind) { return p.count(kind); }; + + return all_post_ops_supported() && + count(primitive_kind::sum) <= 1 && + IMPLICATION(contain(primitive_kind::sum), position(primitive_kind::sum) == 0); +} + +template +status_t jit_uni_fork_dw_conv_fwd_kernel::init_conf( + jit_conv_conf_t &jcp, const convolution_desc_t &cd, + memory_desc_t &src_md, memory_desc_t &weights_md, + memory_desc_t &bias_md, memory_desc_t &dst_md, + const primitive_attr_t &attr) { + + using namespace dnnl::impl::format_tag; + using namespace dnnl::impl::utils; + + const memory_desc_wrapper src_d(&src_md); + const memory_desc_wrapper weights_d(&weights_md); + const memory_desc_wrapper dst_d(&dst_md); + const memory_desc_wrapper bias_d(&bias_md); + + const int ndims = src_d.ndims(); + const auto blocked_tag = one_of(isa, avx512_common, avx512_core) ? + pick(ndims - 3, nCw16c, nChw16c, nCdhw16c) : + pick(ndims - 3, nCw8c, nChw8c, nCdhw8c); + const auto wei_tag = one_of(isa, avx512_common, avx512_core) ? 
+ pick(ndims - 3, Goiw16g, Goihw16g, Goidhw16g) : + pick(ndims - 3, Goiw8g, Goihw8g, Goidhw8g); + const auto nxc_tag = pick(ndims - 3, nwc, nhwc, ndhwc); + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + if (src_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(src_md, blocked_tag)); + jcp.src_tag = blocked_tag; + } else { + jcp.src_tag = src_d.mb_stride_relaxed_match(blocked_tag, nxc_tag); + } + + if (weights_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(weights_md, wei_tag)); + jcp.wei_tag = wei_tag; + } else { + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + } + + if (dst_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(dst_md, blocked_tag)); + jcp.dst_tag = blocked_tag; + } else { + jcp.dst_tag = dst_d.mb_stride_relaxed_match(blocked_tag, nxc_tag); + } + + if (jcp.with_bias) { + if (bias_d.format_kind() == format_kind::any) + CHECK(memory_desc_init_by_tag(bias_md, format_tag::x)); + } + + if (jcp.dst_tag != jcp.src_tag) return status::unimplemented; + const auto data_tag = jcp.src_tag; + const bool is_data_layout_nxc = data_tag == nxc_tag; + + const bool is_bf16 = src_d.data_type() == data_type::bf16; + // 3D bf16 fork DW kernel does not support 3D convolution + if (is_bf16 && ndims == 5) return status::unimplemented; + + jcp.dst_dt = cd.dst_desc.data_type; + jcp.isa = (is_bf16 && mayiuse(avx512_core_bf16)) ? avx512_core_bf16 : isa; + + if (!mayiuse(isa) || (is_bf16 && !mayiuse(avx512_core))) + return status::unimplemented; + + const int simd_w = one_of(isa, avx512_common, avx512_core) ? 16 : 8; + + jcp.prop_kind = cd.prop_kind; + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + if (!with_groups) return status::unimplemented; + + jcp.ndims = ndims; + + jcp.ngroups = weights_d.dims()[0]; + jcp.mb = src_d.dims()[0]; + + jcp.oc = dst_d.dims()[1]; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1]; + + jcp.id = (ndims == 5) ? src_d.dims()[2] : 1; + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2]; + jcp.iw = src_d.dims()[ndims - 1]; + jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1; + jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims - 2]; + jcp.ow = dst_d.dims()[ndims - 1]; + + jcp.kd = (ndims == 5) ? weights_d.dims()[3] : 1; + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[ndims - 1]; + jcp.kw = weights_d.dims()[ndims]; + + jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4]; + jcp.l_pad = cd.padding[0][ndims - 3]; + jcp.back_pad = (ndims == 5) ? cd.padding[1][0] : 0; + jcp.b_pad = (ndims == 3) ? 0 : cd.padding[1][ndims - 4]; + jcp.r_pad = cd.padding[1][ndims - 3]; + + jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4]; + jcp.stride_w = cd.strides[ndims - 3]; + + jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; + jcp.dilate_h = (ndims == 3) ? 
0 : cd.dilates[ndims - 4]; + jcp.dilate_w = cd.dilates[ndims - 3]; + + jcp.loop_order = loop_ngcw; + + if (is_data_layout_nxc) { + jcp.loop_order = loop_nhwcg; + } + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + jcp.with_sum = p.find(primitive_kind::sum) != -1; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + + jcp.post_ops = p; + + bool ok_to_pad_channels = true + && !is_data_layout_nxc + && jcp.oc == jcp.ngroups + && jcp.ic == jcp.ngroups + && one_of(isa, avx512_common, avx512_core, avx2); + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.oc, simd_w); + jcp.ngroups = rnd_up(jcp.ngroups, simd_w); + } + + bool args_ok = true && jcp.oc == jcp.ngroups && jcp.ic == jcp.ngroups + && IMPLICATION(!is_data_layout_nxc, jcp.ngroups % simd_w == 0) + && jcp.wei_tag == wei_tag + && data_tag != format_tag::undef && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= dst_d.padded_dims()[1] + && jcp.ngroups <= weights_d.padded_dims()[0]; + if (!args_ok) return status::unimplemented; + + jcp.typesize_out = jcp.dst_dt == data_type::bf16 ? sizeof(bfloat16_t) + : sizeof(float); + jcp.typesize_in = src_d.data_type() == data_type::bf16 + ? sizeof(bfloat16_t) + : sizeof(float); + + jcp.ur_w = is_bf16 ? (isa_has_bf16(jcp.isa) ? 6 : 4) + : isa == avx512_common ? 6 : isa == avx2 ? 4 : 3; + + jcp.ch_block = simd_w; + jcp.nb_ch = div_up(jcp.oc, jcp.ch_block); + jcp.nb_ch_blocking + = one_of(isa, avx512_common, avx512_core) ? 4 : isa == avx2 ? 3 : 2; + if (jcp.nb_ch < jcp.nb_ch_blocking) + jcp.nb_ch_blocking = jcp.nb_ch; + + jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef; + + return status::success; +} + +template +void jit_uni_fork_dw_conv_fwd_kernel::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + using namespace dnnl::impl::memory_tracking::names; + if (jcp.bia_dt == data_type::bf16) + scratchpad.book(key_conv_bias_bf16_convert_wsp, jcp.oc); + else if (jcp.with_bias && jcp.oc_without_padding != jcp.oc) + scratchpad.book(key_conv_padded_bias, jcp.oc); +} + +template struct jit_uni_fork_dw_conv_fwd_kernel; +template struct jit_uni_fork_dw_conv_fwd_kernel; +template struct jit_uni_fork_dw_conv_fwd_kernel; +template struct jit_uni_fork_dw_conv_fwd_kernel; + +template +struct jit_uni_fork_dw_conv_bwd_data_kernel { + + jit_uni_fork_dw_conv_bwd_data_kernel(const jit_conv_conf_t &ajcp, const primitive_attr_t &attr) + : ker_(nullptr) { + ker_ = new jit_kernel_t(ajcp, attr); + } + + status_t create_kernel() { return ker_->create_kernel(); } + ~jit_uni_fork_dw_conv_bwd_data_kernel() { delete ker_; } + + static bool post_ops_ok(const primitive_attr_t &attr); + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d, const primitive_attr_t &attr); + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + void operator()(const jit_conv_call_s *p) const { (*ker_)(p); } + +private: + using jit_kernel_t = typename utils::conditional>::type; + jit_kernel_t *ker_; + + DNNL_DISALLOW_COPY_AND_ASSIGN(jit_uni_fork_dw_conv_bwd_data_kernel); +}; + +template +bool jit_uni_fork_dw_conv_bwd_data_kernel::post_ops_ok(const primitive_attr_t &attr) { + const auto &p = 
attr.post_ops_; + if (p.len() > 1) + return false; + + auto all_post_ops_supported = [&]() { + bool ok = true; + + for (int i = 0; i < p.len(); i++) { + ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::depthwise); + } + return ok; + }; + + return all_post_ops_supported(); +} + +template +status_t jit_uni_fork_dw_conv_bwd_data_kernel::init_conf( + jit_conv_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d, const primitive_attr_t &attr) { + using namespace dnnl::impl::format_tag; + using namespace dnnl::impl::utils; + + jcp.dsrc_dt = cd.diff_src_desc.data_type; + const bool is_bf16 = diff_dst_d.data_type() == data_type::bf16; + jcp.isa = (is_bf16 && mayiuse(avx512_core_bf16)) ? avx512_core_bf16 : isa; + + if (!mayiuse(isa) || (is_bf16 && !mayiuse(avx512_core))) + return status::unimplemented; + + const int simd_w = one_of(isa, avx512_common, avx512_core) ? 16 : 8; + + const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1; + if (!with_groups) return status::unimplemented; + + jcp.ngroups = weights_d.dims()[0]; + jcp.mb = diff_src_d.dims()[0]; + + jcp.oc = diff_dst_d.dims()[1]; + jcp.oc_without_padding = jcp.oc; + jcp.ic = diff_src_d.dims()[1]; + + jcp.ih = diff_src_d.dims()[2]; + jcp.iw = diff_src_d.dims()[3]; + jcp.oh = diff_dst_d.dims()[2]; + jcp.ow = diff_dst_d.dims()[3]; + + jcp.kh = weights_d.dims()[3]; + jcp.kw = weights_d.dims()[4]; + + jcp.t_pad = cd.padding[0][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.b_pad = cd.padding[1][0]; + jcp.r_pad = cd.padding[1][1]; + + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + + jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; + jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; + + if (!post_ops_ok(attr)) + return status::unimplemented; + + jcp.post_ops = attr.post_ops_; + + bool ok_to_pad_channels = true && jcp.oc == jcp.ngroups + && jcp.ic == jcp.ngroups + && one_of(isa, avx512_common, avx512_core, avx2); + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.oc, simd_w); + jcp.ngroups = rnd_up(jcp.ngroups, simd_w); + } + + auto dat_tag = one_of(isa, avx512_common, avx512_core) ? nChw16c : nChw8c; + auto wei_tag = one_of(isa, avx512_common, avx512_core) ? Goihw16g : Goihw8g; + + jcp.src_tag = diff_src_d.mb_stride_relaxed_match(dat_tag); + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + jcp.dst_tag = diff_dst_d.mb_stride_relaxed_match(dat_tag); + + bool args_ok = true && jcp.oc == jcp.ngroups && jcp.ic == jcp.ngroups + && jcp.ngroups % simd_w == 0 && jcp.dilate_h == 0 + && jcp.dilate_w == 0 && jcp.src_tag == dat_tag + && jcp.wei_tag == wei_tag && jcp.dst_tag == dat_tag + && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1 + && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1 + && jcp.ic <= diff_src_d.padded_dims()[1] + && jcp.oc <= diff_dst_d.padded_dims()[1] + && jcp.ngroups <= weights_d.padded_dims()[0]; + if (!args_ok) return status::unimplemented; + + jcp.typesize_out = types::data_type_size(diff_src_d.data_type()); + jcp.typesize_in = types::data_type_size(diff_dst_d.data_type()); + + jcp.ur_w = is_bf16 ? (isa_has_bf16(jcp.isa) ? 6 : 4) + : isa == avx512_common ? 6 : isa == avx2 ? 4 : 3; + + jcp.ch_block = simd_w; + jcp.nb_ch = jcp.ic / jcp.ch_block; + jcp.nb_ch_blocking + = one_of(isa, avx512_common, avx512_core) ? 4 : isa == avx2 ? 
3 : 2;
+    if (jcp.nb_ch < jcp.nb_ch_blocking) jcp.nb_ch_blocking = jcp.nb_ch;
+
+    return status::success;
+}
+
+template <cpu_isa_t isa, data_type_t kernel_dt>
+void jit_uni_fork_dw_conv_bwd_data_kernel<isa, kernel_dt>::init_scratchpad(
+        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
+    UNUSED(scratchpad);
+    UNUSED(jcp);
+}
+
+template struct jit_uni_fork_dw_conv_bwd_data_kernel<avx512_core, data_type::bf16>;
+template struct jit_uni_fork_dw_conv_bwd_data_kernel<avx512_common, data_type::f32>;
+template struct jit_uni_fork_dw_conv_bwd_data_kernel<avx2, data_type::f32>;
+template struct jit_uni_fork_dw_conv_bwd_data_kernel<sse41, data_type::f32>;
+
+} // namespace x64
+} // namespace cpu
+} // namespace impl
+} // namespace dnnl
+
+#endif /* CPU_X64_JIT_UNI_FORK_DW_CONV_KERNEL_UTILS_HPP */
diff --git a/src/cpu/x64/jit_uni_fork_dw_convolution.cpp b/src/cpu/x64/jit_uni_fork_dw_convolution.cpp
new file mode 100644
index 00000000000..6daafaccb93
--- /dev/null
+++ b/src/cpu/x64/jit_uni_fork_dw_convolution.cpp
@@ -0,0 +1,347 @@
+/*******************************************************************************
+* Copyright 2021 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "common/c_types_map.hpp"
+#include "common/dnnl_thread.hpp"
+#include "common/memory_tracking.hpp"
+
+#include "common/bfloat16.hpp"
+
+#include "jit_uni_fork_dw_convolution.hpp"
+#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
+
+namespace dnnl {
+namespace impl {
+namespace cpu {
+namespace x64 {
+
+using namespace dnnl::impl::status;
+using namespace dnnl::impl::memory_tracking::names;
+using namespace dnnl::impl::utils;
+
+template <cpu_isa_t isa, data_type_t src_type, data_type_t dst_type>
+void jit_uni_fork_dw_convolution_fwd_t<isa, src_type, dst_type>::execute_forward(
+        const exec_ctx_t &ctx) const {
+    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
+    auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS);
+    auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
+    const auto &jcp = pd()->jcp_;
+    const auto post_ops_binary_rhs_arg_vec
+            = binary_injector::prepare_binary_args(jcp.post_ops, ctx);
+
+    auto MB = CTX_IN_BATCH(DNNL_ARG_SRC);
+
+    const memory_desc_wrapper src_d(pd()->src_md());
+    const memory_desc_wrapper dst_d(pd()->dst_md());
+    const memory_desc_wrapper weights_d(pd()->weights_md(0));
+    const memory_desc_wrapper bias_d(pd()->weights_md(1));
+
+    f32_data_t *bias = nullptr;
+    if (pd()->desc()->bias_desc.data_type == data_type::bf16) {
+        auto bias_in = CTX_IN_MEM(const bf16_data_t *, DNNL_ARG_BIAS);
+        bias = ctx.get_scratchpad_grantor().template get<f32_data_t>(
+                key_conv_bias_bf16_convert_wsp);
+        cvt_bfloat16_to_float(bias, bias_in, jcp.oc_without_padding);
+        utils::array_set(bias + jcp.oc_without_padding, 0.f,
+                jcp.oc - jcp.oc_without_padding);
+    } else {
+        auto bias_in = CTX_IN_MEM(const f32_data_t *, DNNL_ARG_BIAS);
+        if (pd()->wants_padded_bias()) {
+            auto padded_bias
+                    = ctx.get_scratchpad_grantor().template get<f32_data_t>(
+                            key_conv_padded_bias);
+            utils::array_copy(padded_bias, bias_in, jcp.oc_without_padding);
+            utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
+                    jcp.oc - jcp.oc_without_padding);
+            bias = 
padded_bias; + } else + bias = const_cast (bias_in); + } + + int dil_d = jcp.dilate_d + 1; + int dil_h = jcp.dilate_h + 1; + int dil_w = jcp.dilate_w + 1; + int str_d = jcp.stride_d; + int str_h = jcp.stride_h; + int str_w = jcp.stride_w; + + const auto is_src_layout_nxc = one_of(jcp.src_tag, format_tag::nhwc, format_tag::ndhwc); + const auto is_dst_layout_nxc = one_of(jcp.dst_tag, format_tag::nhwc, format_tag::ndhwc); + + auto kernel_params = [&](int ur_w_step, int ow, int oh, int od, int ih, int id, int kh, int kd, + int kh_padding, int kd_padding, int ch, int ch_step, int n, int work_rem) { + auto par_conv = jit_conv_call_s(); + + const int i_l_overflow = nstl::max(0, (jcp.l_pad - ow * str_w)); + const int i_r_overflow = nstl::max(jcp.iw, (ow * str_w + + (jcp.kw - 1)*dil_w - jcp.l_pad + 1)) - jcp.iw; + + const int iw = nstl::max((ow*str_w - jcp.l_pad + + div_up(i_l_overflow, dil_w)*dil_w), 0); + const int kw = div_up(i_l_overflow, dil_w); + + const int kw_padding = jcp.kw - div_up(i_l_overflow, dil_w) + - div_up(i_r_overflow, dil_w); + + const auto ic_off_idx = is_src_layout_nxc ? ch * jcp.ch_block : ch; + const auto oc_off_idx = is_dst_layout_nxc ? ch * jcp.ch_block : ch; + + size_t src_off = (jcp.ndims == 3) ? src_d.blk_off(n, ic_off_idx, iw) : + (jcp.ndims == 4) ? src_d.blk_off(n, ic_off_idx, ih, iw) : src_d.blk_off(n, ic_off_idx, id, ih, iw); + size_t dst_off = (jcp.ndims == 3) ? dst_d.blk_off(n, oc_off_idx, ow) : + (jcp.ndims == 4) ? dst_d.blk_off(n, oc_off_idx, oh, ow) : dst_d.blk_off(n, oc_off_idx, od, oh, ow); + size_t wei_off = (jcp.ndims == 3) ? weights_d.blk_off(ch, 0, 0, kw) : + (jcp.ndims == 4) ? weights_d.blk_off(ch, 0, 0, kh, kw) : weights_d.blk_off(ch, 0, 0, kd, kh, kw); + + par_conv.src = &src[src_off]; + par_conv.dst = &dst[dst_off]; + par_conv.filt = &weights[wei_off]; + + if (bias) par_conv.bias = &bias[bias_d.blk_off(ch*jcp.ch_block)]; + + par_conv.kd_padding = (size_t)nstl::max(0, kd_padding); + par_conv.kh_padding = (size_t)nstl::max(0, kh_padding); + par_conv.kw_padding = (size_t)nstl::max(0, kw_padding); + + par_conv.ur_w = (size_t)ur_w_step; + + assert(IMPLICATION( + jcp.loop_order == loop_nhwcg, is_src_layout_nxc)); + // For is_src_layout_nxc maximize jit work along contiguous dim. + par_conv.load_work = utils::this_block_size(ch * jcp.ch_block, + jcp.oc_without_padding, + (is_src_layout_nxc ? 
work_rem * ch_step : ch_step) + * jcp.ch_block); + par_conv.oc_off = ch * jcp.ch_block * sizeof(float); + par_conv.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); + + return par_conv; + }; + + const int ch_step = jcp.nb_ch_blocking; + const int chb_work = utils::div_up(jcp.nb_ch, ch_step); + + const int work_amount = MB * chb_work * jcp.od * jcp.oh; + const auto nthr = jcp.nthr; + + parallel(nthr, [&](const int ithr, const int nthr) { + int start {0}, end {0}; + balance211(work_amount, nthr, ithr, start, end); + + int n {0}, chb {0}, od {0}, oh {0}; + if (jcp.loop_order == loop_ngcw) + utils::nd_iterator_init( + start, n, MB, chb, chb_work, od, jcp.od, oh, jcp.oh); + else if (jcp.loop_order == loop_nhwcg) + utils::nd_iterator_init( + start, n, MB, od, jcp.od, oh, jcp.oh, chb, chb_work); + else + assert(!"unsupported loop order"); + + auto iwork = start; + while (iwork < end) { + int ch = chb * ch_step; + + const int i_front_overflow = nstl::max(0, (int) (jcp.f_pad - od * str_d)); + const int i_back_overflow = nstl::max(jcp.id, + (int) (od * str_d + (jcp.kd - 1) * dil_d - jcp.f_pad + 1)) - jcp.id; + + const int i_t_overflow = nstl::max(0, (int) (jcp.t_pad - oh * str_h)); + const int i_b_overflow = nstl::max(jcp.ih, + (int) (oh * str_h + (jcp.kh - 1) * dil_h - jcp.t_pad + 1)) - jcp.ih; + + const int id = nstl::max((int) (od * str_d - jcp.f_pad + + div_up(i_front_overflow, dil_d) * dil_d), 0); + const int kd = div_up(i_front_overflow, dil_d); + const int kd_padding = jcp.kd - div_up(i_front_overflow, dil_d) + - div_up(i_back_overflow, dil_d); + + const int ih = nstl::max((int) (oh * str_h - jcp.t_pad + + div_up(i_t_overflow, dil_h) * dil_h), 0); + const int kh = div_up(i_t_overflow, dil_h); + const int kh_padding = jcp.kh - div_up(i_t_overflow, dil_h) + - div_up(i_b_overflow, dil_h); + + // left border + int ow = 0; + int l_border = nstl::min(div_up(jcp.l_pad, str_w), jcp.ow); + int ur_w_step = 1; + for (; ow < l_border; ow++) { + jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, od, ih, id, + kh, kd, kh_padding, kd_padding, ch, ch_step, n, end - iwork); + + (*kernel_)(&par_conv); + } + + // main loop + ur_w_step = (jcp.iw - (jcp.kw - 1) * dil_w + jcp.l_pad - 1) + / jcp.stride_w - ow + 1; + if (ur_w_step > 0) { + jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, od, ih, id, + kh, kd, kh_padding, kd_padding, ch, ch_step, n, end - iwork); + + (*kernel_)(&par_conv); + + ow += ur_w_step; + } + + // right border + ur_w_step = 1; + for (; ow < jcp.ow; ow++) { + jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, od, ih, id, + kh, kd, kh_padding, kd_padding, ch, ch_step, n, end - iwork); + + (*kernel_)(&par_conv); + } + + if (jcp.loop_order == loop_ngcw) { + ++iwork; + utils::nd_iterator_step(n, MB, chb, chb_work, od, jcp.od, oh, jcp.oh); + } else if (jcp.loop_order == loop_nhwcg) { + utils::nd_iterator_jump( + iwork, end, n, MB, od, jcp.od, oh, jcp.oh, chb, chb_work); + } else + assert(!"unsupported loop order"); + } + }); + + if (pd()->wants_zero_pad_dst()) + ctx.zero_pad_output(DNNL_ARG_DST); +} + +template struct jit_uni_fork_dw_convolution_fwd_t; +template struct jit_uni_fork_dw_convolution_fwd_t; +template struct jit_uni_fork_dw_convolution_fwd_t; +template struct jit_uni_fork_dw_convolution_fwd_t; +template struct jit_uni_fork_dw_convolution_fwd_t; + +template +void jit_uni_fork_dw_convolution_bwd_data_t + ::execute_backward_data(const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST); + auto weights 
= CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(diff_src_data_t *, DNNL_ARG_DIFF_SRC); + + const auto &jcp = pd()->jcp_; + const auto post_ops_binary_rhs_arg_vec + = binary_injector::prepare_binary_args(jcp.post_ops, ctx); + + auto MB = CTX_IN_BATCH(DNNL_ARG_DIFF_DST); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + auto kernel_params = [&](int ur_str_w, int iw, int oh, int ih, + int i_t_overflow, int i_b_overflow, int stride_off_h, + int ch, int ch_num, int n) { + auto par_conv = jit_conv_call_s(); + + const int i_l_overflow = nstl::max(0, (jcp.kw - 1 - iw - jcp.l_pad)); + const int i_r_overflow = nstl::max(0, (jcp.kw - 1 - (jcp.iw - 1 - iw) + - jcp.r_pad)); + + int ow = iw + jcp.l_pad - i_r_overflow; + int stride_off_w = ow % jcp.stride_w; + ow /= jcp.stride_w; + + par_conv.src = &diff_src[diff_src_d.blk_off(n, ch, ih, iw)]; + par_conv.dst = &diff_dst[diff_dst_d.blk_off(n, ch, oh, ow)]; + par_conv.filt = &weights[weights_d.blk_off(ch, 0, 0, i_b_overflow + + stride_off_h, i_r_overflow + stride_off_w)]; + + par_conv.kh_padding = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow + - stride_off_h); + par_conv.kw_padding = nstl::max(0, jcp.kw - i_l_overflow - i_r_overflow + - stride_off_w); + + par_conv.ur_str_w = ur_str_w; + + par_conv.ch_blocks = nstl::min(ch + ch_num, jcp.nb_ch) - ch; + par_conv.ic_off = ch * jcp.ch_block * sizeof(float); + par_conv.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); + + return par_conv; + }; + + const int chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking); + parallel_nd(MB, chb_work, jcp.ih, + [&](int n, int chb, int ih) { + int ch = chb * jcp.nb_ch_blocking; + int ch_num = jcp.nb_ch_blocking; + + const int i_t_overflow = nstl::max(0, (int)(jcp.kh - 1 - ih + - jcp.t_pad)); + const int i_b_overflow = nstl::max(0, (int)(jcp.kh - 1 + - (jcp.ih - 1 - ih) - jcp.b_pad)); + + int oh = ih + jcp.t_pad - i_b_overflow; + int stride_off_h = oh % jcp.stride_h; + oh /= jcp.stride_h; + + for (int i_str_w = 0; i_str_w < jcp.stride_w; i_str_w++) { + // left border + int iw = i_str_w; + int l_border = nstl::min(jcp.kw - 1 - jcp.l_pad, jcp.iw); + int ur_str_w = 1; + for (; iw < l_border; iw += jcp.stride_w) { + jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh, + ih, i_t_overflow, i_b_overflow, + stride_off_h, ch, ch_num, n); + + (*kernel_)(&par_conv); + } + + // main loop + ur_str_w = nstl::min((jcp.iw - jcp.kw + jcp.r_pad - iw) + / jcp.stride_w, jcp.iw); + if (ur_str_w > 0) { + jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh, + ih, i_t_overflow, i_b_overflow, + stride_off_h, ch, ch_num, n); + + (*kernel_)(&par_conv); + + iw += ur_str_w * jcp.stride_w; + } + + // right border + ur_str_w = 1; + for (; iw < jcp.iw; iw += jcp.stride_w) { + jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh, + ih, i_t_overflow, i_b_overflow, + stride_off_h, ch, ch_num, n); + + (*kernel_)(&par_conv); + } + } + }); +} + +template struct jit_uni_fork_dw_convolution_bwd_data_t; +template struct jit_uni_fork_dw_convolution_bwd_data_t; +template struct jit_uni_fork_dw_convolution_bwd_data_t; +template struct jit_uni_fork_dw_convolution_bwd_data_t; +template struct jit_uni_fork_dw_convolution_bwd_data_t; + +} +} +} +} diff --git a/src/cpu/x64/jit_uni_fork_dw_convolution.hpp b/src/cpu/x64/jit_uni_fork_dw_convolution.hpp new file mode 100644 index 00000000000..e8aa69384f3 --- /dev/null 
+++ b/src/cpu/x64/jit_uni_fork_dw_convolution.hpp
@@ -0,0 +1,190 @@
+/*******************************************************************************
+* Copyright 2021 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_X64_JIT_UNI_FORK_DW_CONVOLUTION_HPP
+#define CPU_X64_JIT_UNI_FORK_DW_CONVOLUTION_HPP
+
+#include "common/c_types_map.hpp"
+#include "common/memory_tracking.hpp"
+#include "common/primitive.hpp"
+#include "cpu/cpu_convolution_pd.hpp"
+
+#include "jit_uni_fork_dw_conv_kernel_utils.hpp"
+
+namespace dnnl {
+namespace impl {
+namespace cpu {
+namespace x64 {
+
+template <cpu_isa_t isa, data_type_t src_type, data_type_t dst_type = src_type>
+struct jit_uni_fork_dw_convolution_fwd_t : public primitive_t {
+    struct pd_t : public cpu_convolution_fwd_pd_t {
+        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
+                const typename pd_t::base_class *hint_fwd_pd)
+            : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd), jcp_() {}
+
+        DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_dw:", jcp_.isa, ""),
+                jit_uni_fork_dw_convolution_fwd_t);
+
+        status_t init(engine_t *engine) {
+            bool ok = true
+                    && is_fwd()
+                    && set_default_alg_kind(alg_kind::convolution_direct)
+                    && expect_data_types(src_type, src_type,
+                            data_type::undef, dst_type, data_type::f32)
+                    && IMPLICATION(this->with_bias(), utils::one_of(
+                            this->desc()->bias_desc.data_type, data_type::f32,
+                            data_type::bf16))
+                    && attr()->has_default_values(
+                            primitive_attr_t::skip_mask_t::post_ops, dst_type)
+                    && !has_zero_dim_memory();
+            if (!ok) return status::unimplemented;
+
+            status_t status = jit_uni_fork_dw_conv_fwd_kernel<isa,
+                    src_type>::init_conf(jcp_, *desc(), src_md_, weights_md_,
+                    bias_md_, dst_md_, *attr());
+            if (status != status::success) return status;
+
+            auto scratchpad = scratchpad_registry().registrar();
+            jit_uni_fork_dw_conv_fwd_kernel<isa, src_type>::init_scratchpad(
+                    scratchpad, jcp_);
+
+            return status::success;
+        }
+
+        jit_conv_conf_t jcp_;
+    };
+
+    jit_uni_fork_dw_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {}
+
+    typedef typename prec_traits<data_type::f32>::type f32_data_t;
+    typedef typename prec_traits<data_type::bf16>::type bf16_data_t;
+    typedef typename prec_traits<src_type>::type data_t;
+    typedef typename prec_traits<dst_type>::type dst_data_t;
+
+    status_t init(engine_t *engine) override {
+        CHECK(safe_ptr_assign(kernel_,
+                new jit_uni_fork_dw_conv_fwd_kernel<isa, src_type>(
+                        pd()->jcp_, *pd()->dst_md(0), *pd()->attr())));
+        return kernel_->create_kernel();
+    }
+
+    status_t execute(const exec_ctx_t &ctx) const override {
+        execute_forward(ctx);
+        return status::success;
+    }
+
+private:
+    void execute_forward(const exec_ctx_t &ctx) const;
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
+
+    std::unique_ptr<jit_uni_fork_dw_conv_fwd_kernel<isa, src_type>> kernel_;
+};
+
+using jit_avx512_common_fork_dw_convolution_fwd_t
+        = jit_uni_fork_dw_convolution_fwd_t<avx512_common, data_type::f32>;
+using jit_avx2_fork_dw_convolution_fwd_t
+        = jit_uni_fork_dw_convolution_fwd_t<avx2, data_type::f32>;
+using jit_sse41_fork_dw_convolution_fwd_t
+        = jit_uni_fork_dw_convolution_fwd_t<sse41, data_type::f32>;
+
+template <cpu_isa_t isa, data_type_t diff_dst_type,
+        data_type_t diff_src_type = diff_dst_type>
+struct jit_uni_fork_dw_convolution_bwd_data_t : public primitive_t {
+    struct pd_t : public cpu_convolution_bwd_data_pd_t {
+        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
+                const convolution_fwd_pd_t *hint_fwd_pd)
+            : cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd)
+            , jcp_() {}
+
+        DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_dw:", jcp_.isa, ""),
+                jit_uni_fork_dw_convolution_bwd_data_t);
+
+        status_t init(engine_t *engine) {
+            bool ok = true && desc()->prop_kind == prop_kind::backward_data
+                    && set_default_alg_kind(alg_kind::convolution_direct)
+                    && expect_data_types(diff_src_type, diff_dst_type,
+                            data_type::undef, diff_dst_type, data_type::f32)
+                    && !has_zero_dim_memory()
+                    && set_default_formats();
+
+            if (!ok) return status::unimplemented;
+
+            status_t status = jit_uni_fork_dw_conv_bwd_data_kernel<isa,
+                    diff_dst_type>::init_conf(jcp_, *desc(), *diff_src_md(),
+                    *weights_md(), *diff_dst_md(), *attr());
+            if (status != status::success) return status;
+
+            auto scratchpad = scratchpad_registry().registrar();
+            jit_uni_fork_dw_conv_bwd_data_kernel<isa,
+                    diff_dst_type>::init_scratchpad(scratchpad, jcp_);
+
+            return status::success;
+        }
+
+        jit_conv_conf_t jcp_;
+
+    protected:
+        bool set_default_formats() {
+            using namespace format_tag;
+
+            auto dat_tag = utils::one_of(isa, avx512_common, avx512_core)
+                    ? nChw16c
+                    : nChw8c;
+            auto wei_tag = utils::one_of(isa, avx512_common, avx512_core)
+                    ? Goihw16g
+                    : Goihw8g;
+
+            return set_default_formats_common(dat_tag, wei_tag, dat_tag);
+        }
+    };
+
+    jit_uni_fork_dw_convolution_bwd_data_t(const pd_t *apd)
+        : primitive_t(apd) {}
+
+    typedef typename prec_traits<diff_src_type>::type diff_src_data_t;
+    typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t;
+    typedef typename prec_traits<diff_dst_type>::type wei_data_t;
+
+    status_t init(engine_t *engine) override {
+        CHECK(safe_ptr_assign(kernel_,
+                new jit_uni_fork_dw_conv_bwd_data_kernel<isa, diff_dst_type>(
+                        pd()->jcp_, *pd()->attr())));
+        return kernel_->create_kernel();
+    }
+
+    status_t execute(const exec_ctx_t &ctx) const override {
+        execute_backward_data(ctx);
+        return status::success;
+    }
+
+private:
+    void execute_backward_data(const exec_ctx_t &ctx) const;
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
+
+    std::unique_ptr<jit_uni_fork_dw_conv_bwd_data_kernel<isa, diff_dst_type>>
+            kernel_;
+};
+
+using jit_avx512_common_fork_dw_convolution_bwd_data_t
+        = jit_uni_fork_dw_convolution_bwd_data_t<avx512_common,
+                data_type::f32>;
+using jit_avx2_fork_dw_convolution_bwd_data_t
+        = jit_uni_fork_dw_convolution_bwd_data_t<avx2, data_type::f32>;
+using jit_sse41_fork_dw_convolution_bwd_data_t
+        = jit_uni_fork_dw_convolution_bwd_data_t<sse41, data_type::f32>;
+
+}
+}
+}
+}
+
+#endif
diff --git a/src/cpu/x64/jit_uni_fork_softmax.cpp b/src/cpu/x64/jit_uni_fork_softmax.cpp
new file mode 100644
index 00000000000..006da590c56
--- /dev/null
+++ b/src/cpu/x64/jit_uni_fork_softmax.cpp
@@ -0,0 +1,112 @@
+/*******************************************************************************
+* Copyright 2019-2020 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "common/dnnl_thread.hpp"
+#include "cpu/x64/jit_uni_fork_softmax.hpp"
+
+namespace dnnl {
+namespace impl {
+namespace cpu {
+namespace x64 {
+
+using namespace utils;
+
+template <cpu_isa_t isa>
+jit_uni_fork_softmax_fwd_t<isa>::jit_uni_fork_softmax_fwd_t(const pd_t *apd)
+    : primitive_t(apd) {}
+
+template <cpu_isa_t isa>
+status_t jit_uni_fork_softmax_fwd_t<isa>::execute(
+        const exec_ctx_t &ctx) const {
+    auto src = CTX_IN_MEM(const uint8_t *, DNNL_ARG_SRC);
+    auto dst = CTX_OUT_MEM(uint8_t *, DNNL_ARG_DST);
+
+    const memory_desc_wrapper data_d(pd()->src_md());
+
+    const auto &jpp = pd()->jpp_;
+
+    auto real_src_md = ctx.input(DNNL_ARG_SRC)->md();
+    size_t outer_size = utils::array_product(
+            real_src_md->dims, pd()->desc()->softmax_axis);
+
+    size_t dim = jpp.channels * jpp.inner_size;
+
+    if (jpp.inner_size > 1) {
+        const size_t work_amount = outer_size;
+
+        auto ker = [&](const int ithr, const int nthr) {
+            size_t start {0}, end {0};
+
+            balance211(work_amount, nthr, ithr, start, end);
+
+            size_t ou {0};
+            nd_iterator_init(start, ou, outer_size);
+
+            for (size_t iwork = start; iwork < end; ++iwork) {
+                auto args = jit_softmax_call_s();
+                args.channels = jpp.channels;
+                args.work = jpp.inner_size;
+                size_t off = data_d.off_l(ou * dim);
+                args.src = src + off * jpp.dt_size;
+                args.dst = dst + off * jpp.dt_size;
+
+                (*kernel_)(&args);
+
+                nd_iterator_step(ou, outer_size);
+            }
+        };
+
+        parallel(0, ker);
+    } else {
+        int ou_blocks = div_up(outer_size, jpp.outer_block);
+        const size_t work_amount = ou_blocks;
+
+        auto ker = [&](const int ithr, const int nthr) {
+            size_t start {0}, end {0};
+
+            balance211(work_amount, nthr, ithr, start, end);
+
+            size_t oub {0};
+            nd_iterator_init(start, oub, ou_blocks);
+
+            for (size_t iwork = start; iwork < end; ++iwork) {
+                size_t work = nstl::min(jpp.outer_block,
+                        outer_size - oub * jpp.outer_block);
+
+                auto args = jit_softmax_call_s();
+                args.channels = jpp.channels;
+                args.work = work;
+                size_t off = data_d.off_l(oub * jpp.outer_block * dim);
+                args.src = src + off * jpp.dt_size;
+                args.dst = dst + off * jpp.dt_size;
+
+                (*kernel_)(&args);
+
+                nd_iterator_step(oub, ou_blocks);
+            }
+        };
+
+        parallel(0, ker);
+    }
+
+    return status::success;
+}
+
+template struct jit_uni_fork_softmax_fwd_t<sse41>;
+template struct jit_uni_fork_softmax_fwd_t<avx2>;
+template struct jit_uni_fork_softmax_fwd_t<avx512_common>;
+
+}
+}
+}
+}
diff --git a/src/cpu/x64/jit_uni_fork_softmax.hpp b/src/cpu/x64/jit_uni_fork_softmax.hpp
new file mode 100644
index 00000000000..f5a3419ca0c
--- /dev/null
+++ b/src/cpu/x64/jit_uni_fork_softmax.hpp
@@ -0,0 +1,97 @@
+/*******************************************************************************
+* Copyright 2019-2020 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_X64_JIT_UNI_FORK_SOFTMAX_HPP
+#define CPU_X64_JIT_UNI_FORK_SOFTMAX_HPP
+
+#include <assert.h>
+#include <memory>
+
+#include "common/c_types_map.hpp"
+#include "common/primitive.hpp"
+#include "common/type_helpers.hpp"
+#include "common/utils.hpp"
+
+#include "cpu/cpu_softmax_pd.hpp"
+#include "cpu/x64/cpu_isa_traits.hpp"
+#include "cpu/x64/jit_uni_fork_softmax_kernel_f32.hpp"
+
+namespace dnnl {
+namespace impl {
+namespace cpu {
+namespace x64 {
+
+template <cpu_isa_t isa>
+struct jit_uni_fork_softmax_fwd_t : public primitive_t {
+    struct pd_t : public cpu_softmax_fwd_pd_t {
+        using cpu_softmax_fwd_pd_t::cpu_softmax_fwd_pd_t;
+
+        DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit:", isa, ""),
+                jit_uni_fork_softmax_fwd_t);
+
+        status_t init(engine_t *engine) {
+            const memory_desc_wrapper src_d(src_md());
+            const memory_desc_wrapper dst_d(dst_md());
+            auto data_type = src_d.data_type();
+
+            auto ndims = desc_.data_desc.ndims;
+            auto dims = desc_.data_desc.dims;
+            auto axis = desc_.softmax_axis;
+
+            size_t inner_size
+                    = utils::array_product(dims + axis + 1, ndims - axis - 1);
+
+            format_tag_t dat_tag = utils::pick(ndims - 3, format_tag::ncw,
+                    format_tag::nchw, format_tag::ncdhw);
+
+            // TODO: disabled because of failed test (case: for axis == 0,
+            // batch == 2). Needs to be debugged.
+            if (ndims == 3) return status::unimplemented;
+
+            using namespace data_type;
+            bool ok = src_d == dst_d && mayiuse(isa) && is_fwd()
+                    && !has_zero_dim_memory()
+                    && utils::one_of(data_type, f32, bf16)
+                    && attr()->has_default_values()
+                    && src_d.is_dense(true)
+                    && src_d.matches_one_of_tag(dat_tag) == dat_tag
+                    && inner_size > 1;
+            if (!ok) return status::unimplemented;
+
+            return jit_uni_fork_softmax_kernel_f32<isa>::init_conf(
+                    jpp_, desc_, src_md(), dst_md());
+        }
+        jit_softmax_conf_t jpp_;
+    };
+
+    jit_uni_fork_softmax_fwd_t(const pd_t *apd);
+
+    status_t init(engine_t *engine) override {
+        CHECK(safe_ptr_assign(
+                kernel_, new jit_uni_fork_softmax_kernel_f32<isa>(pd()->jpp_)));
+        return kernel_->create_kernel();
+    }
+
+    status_t execute(const exec_ctx_t &ctx) const override;
+
+private:
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
+    std::unique_ptr<jit_uni_fork_softmax_kernel_f32<isa>> kernel_;
+};
+
+}
+}
+}
+}
+
+#endif
diff --git a/src/cpu/x64/jit_uni_fork_softmax_kernel_f32.cpp b/src/cpu/x64/jit_uni_fork_softmax_kernel_f32.cpp
new file mode 100644
index 00000000000..6ad89e7c673
--- /dev/null
+++ b/src/cpu/x64/jit_uni_fork_softmax_kernel_f32.cpp
@@ -0,0 +1,744 @@
+/*******************************************************************************
+* Copyright 2019-2020 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/ + +#include "jit_generator.hpp" +#include "jit_uni_fork_softmax_kernel_f32.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace x64 { + +using namespace Xbyak; + +#define GET_OFF(field) offsetof(jit_softmax_call_s, field) + +template +jit_uni_fork_softmax_kernel_f32::jit_uni_fork_softmax_kernel_f32(jit_softmax_conf_t ajpp) : jpp(ajpp) { + if (jpp.dt == data_type::bf16 && !mayiuse(avx512_core_bf16)) { + bf16_emu_.reset(new bf16_emulation_t(this, bf16_emu_zmm_1, bf16_emu_zmm_2, bf16_emu_zmm_3, + bf16_emu_gpr, bf16_emu_zmm_4, bf16_emu_zmm_5)); + } +} + + +template +status_t jit_uni_fork_softmax_kernel_f32::init_conf(jit_softmax_conf_t &jpp, + const softmax_desc_t &pd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &dst_d) { + auto ndims = pd.data_desc.ndims; + auto dims = pd.data_desc.dims; + auto axis = pd.softmax_axis; + + jpp.dt = src_d.data_type(); + jpp.dt_size = src_d.data_type_size(); + + size_t nregs = cpu_isa_traits::n_vregs; + + if (jpp.dt == data_type::bf16) { + if (isa != avx512_common) { + return status::unimplemented; + } + else if (!mayiuse(avx512_core_bf16)) { + nregs -= 5; // reserved for the bf16 emulator + } + } + + size_t aux_simd_registers = 5; // 3 aux for exp + one + (-FTL_MAX) + size_t regs_for_one_unroll = 2; + size_t max_inner_unroll = (nregs - aux_simd_registers) / regs_for_one_unroll; + size_t max_channels_unroll = 4; + + jpp.outer_size = utils::array_product(dims, axis); + jpp.channels = dims[axis]; + jpp.inner_size = utils::array_product(dims + axis + 1, ndims - axis - 1); + + if (jpp.outer_size < 1 || jpp.channels < 1 || jpp.inner_size < 1) { + return status::unimplemented; + } + + jpp.ur_inner = max_inner_unroll; + jpp.ur_channel = nstl::min(max_channels_unroll, jpp.channels); + jpp.outer_block = 2 * cpu_isa_traits::vlen / sizeof(float); + + if (jpp.inner_size == 1) { + // limit max jit code size for dense case + if (jpp.channels > 128) { + return status::unimplemented; + } + + // ref implementation is faster for small work amount + if (jpp.channels * jpp.outer_size < 16) { + return status::unimplemented; + } + } + + return status::success; +} + +template +int jit_uni_fork_softmax_kernel_f32::id_vreg_max(int ur_inner) { + return 5+ur_inner; +} + +template +int jit_uni_fork_softmax_kernel_f32::id_vreg_denom(int ur_inner) { + return 5+jpp.ur_inner + ur_inner; +} + +template +int jit_uni_fork_softmax_kernel_f32::id_vreg_src(int ur_inner) { + return 5+2*jpp.ur_inner; +} + +template +auto jit_uni_fork_softmax_kernel_f32::vreg_max(int ur_inner) -> Vmm { + return Vmm(id_vreg_max(ur_inner)); +} + +template +auto jit_uni_fork_softmax_kernel_f32::vreg_denom(int ur_inner) -> Vmm { + return Vmm(id_vreg_denom(ur_inner)); +} + +template +auto jit_uni_fork_softmax_kernel_f32::vreg_src(int ur_inner) -> Vmm { + return Vmm(id_vreg_src(ur_inner)); +} + +template +void jit_uni_fork_softmax_kernel_f32::load_vector(Vmm vmm_src_, const Xbyak::Address &op) { + switch (jpp.dt) { + case data_type::f32: + uni_vmovups(vmm_src_, op); + break; + case data_type::bf16: + uni_vpmovzxwd(vmm_src_, op); + uni_vpslld(vmm_src_, vmm_src_, 16); + break; + default: + assert(!"unknown data type"); + } +} + +template +void jit_uni_fork_softmax_kernel_f32::load_scalar(Xmm xmm_src_, const Xbyak::Address &op) { + switch (jpp.dt) { + case data_type::f32: + movss(xmm_src_, op); + break; + case data_type::bf16: + pinsrw(xmm_src_, op, 0x0); + uni_vpslld(xmm_src_, xmm_src_, 16); + break; + default: 
+ assert(!"unknown data type"); + } +} + +template +void jit_uni_fork_softmax_kernel_f32::store_vector(const Xbyak::Address &op, Vmm vmm_dst_) { + Ymm ymm_dst_ = Ymm(vmm_dst_.getIdx()); + switch (jpp.dt) { + case data_type::f32: + uni_vmovups(op, vmm_dst_); + break; + case data_type::bf16: + if (mayiuse(avx512_core_bf16)) + vcvtneps2bf16(ymm_dst_, vmm_dst_); + else + bf16_emu_->vcvtneps2bf16(ymm_dst_, Zmm(vmm_dst_.getIdx())); + vmovdqu16(op, ymm_dst_); + break; + default: + assert(!"unknown data type"); + } +} + +template +void jit_uni_fork_softmax_kernel_f32::store_scalar(const Xbyak::Address &op, Xmm xmm_dst_) { + switch (jpp.dt) { + case data_type::f32: + movss(op, xmm_dst_); + break; + case data_type::bf16: + if (mayiuse(avx512_core_bf16)) + vcvtneps2bf16(xmm_dst_, xmm_dst_); + else + bf16_emu_->vcvtneps2bf16(xmm_dst_, Ymm(xmm_dst_.getIdx())); + pextrw(op, xmm_dst_, 0x0); + break; + default: + assert(!"unknown dst_dt"); + } +} + +template +void jit_uni_fork_softmax_kernel_f32::prepare_table() { + const unsigned int cvals[] = { + 0x3f800000, // [0] 1.0f + 0x3f000000, // [1] 0.5f + 0x3fb8aa3b, // [2] log2ef = 1.44269502f + 0x3f317218, // [3] ln2f = 0.69314718f + 0x0000007f, // [4] 0x7f + // exp(x) polynom + 0x3f800001, // [5] p0 = 1.0000001f + 0x3efffe85, // [6] p2 = 0.4999887f + 0x3e2aaa3e, // [7] p3 = 0.16666505f + 0x3d2bb1b1, // [8] p4 = 0.041917507f + 0x3c091ec1, // [9] p5 = 0.008369149f + 0x42b0c0a5, //[10] max logf = 88.3762589f + 0xc1766666 //[11] min logf = -14.5f + }; + + align(64); + L(l_table); + for (size_t i = 0; i < sizeof(cvals) / sizeof(cvals[0]); ++i) { + for (size_t d = 0; d < vlen / sizeof(float); ++d) { + dd(cvals[i]); + } + } +} + +template +void jit_uni_fork_softmax_kernel_f32::simd_expf(const Vmm &vmm_src) { + uni_vminps(vmm_src, vmm_src, ptr[imm_addr64 + 10 * vlen]); + uni_vmaxps(vmm_src, vmm_src, ptr[imm_addr64 + 11 * vlen]); + uni_vmovups(vmm_aux0, vmm_src); + //calculate exp(x) + // fx = x * log2ef + 0.5 + uni_vmulps(vmm_src, vmm_src, ptr[imm_addr64 + 2 * vlen]); + uni_vaddps(vmm_src, vmm_src, ptr[imm_addr64 + 1 * vlen]); + + // tmp = floorf(fx) + if (isa < avx512_common) { + uni_vroundps(vmm_aux1, vmm_src, _op_floor); + } else { + vcvtps2dq(vmm_aux1 | T_rd_sae, vmm_src); + vcvtdq2ps(vmm_aux1, vmm_aux1); + + vcmpps(k_mask_tmp, vmm_aux1, vmm_src, _cmp_gt_os); + vmovups(vmm_aux2 | k_mask_tmp | T_z, zword[imm_addr64 + 0 * vlen]); + + uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux2); + } + //keep fx for further computations + uni_vmovups(vmm_src, vmm_aux1); //vmm_src = fx + // compute 2^n + uni_vcvtps2dq(vmm_aux2, vmm_src); + uni_vpaddd(vmm_aux2, vmm_aux2, ptr[imm_addr64 + 4 * vlen]); + uni_vpslld(vmm_aux2, vmm_aux2, 23); //Vmm(6) = 2^-fx + + //x = x - fx * ln2 + uni_vfnmadd231ps(vmm_aux0, vmm_aux1, ptr[imm_addr64 + 3 * vlen]); + // y = p5 + uni_vmovups(vmm_src, ptr[imm_addr64 + 9 * vlen]); + // y = y * x + p4 + uni_vfmadd213ps(vmm_src, vmm_aux0, ptr[imm_addr64 + 8 * vlen]); + // y = y * x + p3 + uni_vfmadd213ps(vmm_src, vmm_aux0, ptr[imm_addr64 + 7 * vlen]); + // y = y * x + p2 + uni_vfmadd213ps(vmm_src, vmm_aux0, ptr[imm_addr64 + 6 * vlen]); + // y = y * x + p1 + uni_vfmadd213ps(vmm_src, vmm_aux0, vmm_one); + // y = y * x + p0 + uni_vfmadd213ps(vmm_src, vmm_aux0, ptr[imm_addr64 + 5 * vlen]); //exp(q) + // y = y * 2^n + uni_vmulps(vmm_src, vmm_src, vmm_aux2); +} + +template +void jit_uni_fork_softmax_kernel_f32::scalar_expf(const Xmm &xmm_src) { + minss(xmm_src, ptr[imm_addr64 + 10 * vlen]); + maxss(xmm_src, ptr[imm_addr64 + 11 * vlen]); + movups(xmm_aux0, xmm_src); + 
//calculate exp(x) + // fx = x * log2ef + 0.5 + mulss(xmm_src, ptr[imm_addr64 + 2 * vlen]); + addss(xmm_src, ptr[imm_addr64 + 1 * vlen]); + // tmp = floorf(fx) + roundss(xmm_aux1, xmm_src, _op_floor); + //keep fx for further computations + movups(xmm_src, xmm_aux1); //xmm_src = fx + // compute 2^n + cvtps2dq(xmm_aux2, xmm_src); + paddd(xmm_aux2, ptr[imm_addr64 + 4 * vlen]); + pslld(xmm_aux2, 23); //Xmm(6) = 2^-fx + + //calculation fx * ln2 + mulss(xmm_aux1, ptr[imm_addr64 + 3 * vlen]); + //x = x - fx * ln2 + subss(xmm_aux0, xmm_aux1); + // y = p5 + movups(xmm_src, ptr[imm_addr64 + 9 * vlen]); + // y = y * x + p4 + mulss(xmm_src, xmm_aux0); + addss(xmm_src, ptr[imm_addr64 + 8 * vlen]); + + // y = y * x + p3 + mulss(xmm_src, xmm_aux0); + addss(xmm_src, ptr[imm_addr64 + 7 * vlen]); + // y = y * x + p2 + mulss(xmm_src, xmm_aux0); + addss(xmm_src, ptr[imm_addr64 + 6 * vlen]); + + // y = y * x + p1 + mulss(xmm_src, xmm_aux0); + addss(xmm_src, xmm_one); + + // y = y * x + p0 + mulss(xmm_src, xmm_aux0); + addss(xmm_src, ptr[imm_addr64 + 5 * vlen]); //exp(q) + + // y = y * 2^n + mulps(xmm_src, xmm_aux2); +} + +template +void jit_uni_fork_softmax_kernel_f32::simd_loop_max(int ur_inner) { + Label loop_channel_blocks; + Label loop_channel_tail; + Label loop_channel_end; + + for (int i = 0; i < ur_inner; ++i) { + uni_vbroadcastss(vreg_max(i), xmm_float_min); + } + + mov(reg_ch_work, reg_channels); + mov(reg_src_ptr, reg_src_base_ptr); + + L(loop_channel_blocks); { + cmp(reg_ch_work, jpp.ur_channel); + jl(loop_channel_tail, T_NEAR); + + for (int i = 0; i < ur_inner; ++i) { + for (int c = 0; c < (int)jpp.ur_channel; ++c) { + load_vector(vreg_src(i), ptr[reg_src_ptr + (i*simd_w + c*jpp.inner_size) * jpp.dt_size]); + uni_vmaxps(vreg_max(i), vreg_max(i), vreg_src(i)); + } + } + + sub(reg_ch_work, jpp.ur_channel); + add(reg_src_ptr, jpp.ur_channel * jpp.inner_size * jpp.dt_size); + + jmp(loop_channel_blocks, T_NEAR); + } + + L(loop_channel_tail); { + cmp(reg_ch_work, 0); + jle(loop_channel_end, T_NEAR); + + for (int i = 0; i < ur_inner; ++i) { + load_vector(vreg_src(i), ptr[reg_src_ptr + i * simd_w * jpp.dt_size]); + uni_vmaxps(vreg_max(i), vreg_max(i), vreg_src(i)); + } + + add(reg_src_ptr, jpp.inner_size * jpp.dt_size); + + dec(reg_ch_work); + jmp(loop_channel_tail, T_NEAR); + } + + L(loop_channel_end); +} + +template +void jit_uni_fork_softmax_kernel_f32::simd_loop_exp(int ur_inner) { + Label loop_channel_blocks; + Label loop_channel_tail; + Label loop_channel_end; + + for (int i = 0; i < ur_inner; ++i) { + uni_vpxor(vreg_denom(i), vreg_denom(i), vreg_denom(i)); + } + + mov(reg_ch_work, reg_channels); + + mov(reg_src_ptr, reg_src_base_ptr); + mov(reg_dst_ptr, reg_dst_base_ptr); + + L(loop_channel_blocks); { + cmp(reg_ch_work, jpp.ur_channel); + jl(loop_channel_tail, T_NEAR); + + for (int i = 0; i < ur_inner; ++i) { + for (int c = 0; c < (int)jpp.ur_channel; ++c) { + load_vector(vreg_src(i), ptr[reg_src_ptr + (i*simd_w + c*jpp.inner_size) * jpp.dt_size]); + uni_vsubps(vreg_src(i),vreg_src(i), vreg_max(i)); + simd_expf(vreg_src(i)); + uni_vaddps(vreg_denom(i), vreg_denom(i), vreg_src(i)); + store_vector(ptr[reg_dst_ptr + (i*simd_w + c*jpp.inner_size) * jpp.dt_size], vreg_src(i)); + } + } + + sub(reg_ch_work, jpp.ur_channel); + add(reg_src_ptr, jpp.ur_channel * jpp.inner_size * jpp.dt_size); + add(reg_dst_ptr, jpp.ur_channel * jpp.inner_size * jpp.dt_size); + + jmp(loop_channel_blocks, T_NEAR); + } + + L(loop_channel_tail); { + cmp(reg_ch_work, 0); + jle(loop_channel_end, T_NEAR); + + for (int i = 0; i < 
ur_inner; ++i) { + load_vector(vreg_src(i), ptr[reg_src_ptr + i * simd_w * jpp.dt_size]); + uni_vsubps(vreg_src(i), vreg_src(i), vreg_max(i)); + simd_expf(vreg_src(i)); + uni_vaddps(vreg_denom(i), vreg_denom(i), vreg_src(i)); + store_vector(ptr[reg_dst_ptr + i * simd_w * jpp.dt_size], vreg_src(i)); + } + + add(reg_src_ptr, jpp.inner_size * jpp.dt_size); + add(reg_dst_ptr, jpp.inner_size * jpp.dt_size); + + dec(reg_ch_work); + jmp(loop_channel_tail, T_NEAR); + } + + L(loop_channel_end); +} + + +template +void jit_uni_fork_softmax_kernel_f32::simd_loop_div(int ur_inner) { + Label loop_channel_blocks; + Label loop_channel_tail; + Label loop_channel_end; + + for (int i = 0; i < ur_inner; ++i) { + if (isa == sse41) { + uni_vmovups(vmm_aux0, vmm_one); + uni_vdivps(vmm_aux0, vmm_aux0, vreg_denom(i)); + uni_vmovups(vreg_denom(i), vmm_aux0); + } else { + uni_vdivps(vreg_denom(i), vmm_one, vreg_denom(i)); + } + } + + mov(reg_ch_work, reg_channels); + + mov(reg_src_ptr, reg_src_base_ptr); + mov(reg_dst_ptr, reg_dst_base_ptr); + + L(loop_channel_blocks); { + cmp(reg_ch_work, jpp.ur_channel); + jl(loop_channel_tail, T_NEAR); + + for (int i = 0; i < ur_inner; ++i) { + for (int c = 0; c < (int)jpp.ur_channel; ++c) { + load_vector(vreg_src(i), ptr[reg_dst_ptr + (i*simd_w + c*jpp.inner_size) * jpp.dt_size]); + uni_vmulps(vreg_src(i), vreg_src(i), vreg_denom(i)); + store_vector(ptr[reg_dst_ptr + (i*simd_w + c*jpp.inner_size) * jpp.dt_size], vreg_src(i)); + } + } + + sub(reg_ch_work, jpp.ur_channel); + add(reg_src_ptr, jpp.ur_channel * jpp.inner_size * jpp.dt_size); + add(reg_dst_ptr, jpp.ur_channel * jpp.inner_size * jpp.dt_size); + + jmp(loop_channel_blocks, T_NEAR); + } + + L(loop_channel_tail); { + cmp(reg_ch_work, 0); + jle(loop_channel_end, T_NEAR); + + for (int i = 0; i < ur_inner; ++i) { + load_vector(vreg_src(i), ptr[reg_dst_ptr + i * simd_w * jpp.dt_size]); + uni_vmulps(vreg_src(i), vreg_src(i), vreg_denom(i)); + store_vector(ptr[reg_dst_ptr + i * simd_w * jpp.dt_size], vreg_src(i)); + } + + add(reg_src_ptr, jpp.inner_size * jpp.dt_size); + add(reg_dst_ptr, jpp.inner_size * jpp.dt_size); + + dec(reg_ch_work); + jmp(loop_channel_tail, T_NEAR); + } + + L(loop_channel_end); +} + +template +void jit_uni_fork_softmax_kernel_f32::scalar_loop_max() { + Label loop_channel_tail; + Label loop_channel_end; + + movups(xmm_max, xmm_float_min); + mov(reg_src_ptr, reg_src_base_ptr); + mov(reg_ch_work, reg_channels); + + L(loop_channel_tail); { + cmp(reg_ch_work, 0); + jle(loop_channel_end, T_NEAR); + + load_scalar(xmm_src, ptr[reg_src_ptr]); + maxss(xmm_max, xmm_src); + + add(reg_src_ptr, jpp.inner_size * jpp.dt_size); + + dec(reg_ch_work); + jmp(loop_channel_tail); + } + + L(loop_channel_end); +} + +template +void jit_uni_fork_softmax_kernel_f32::scalar_loop_exp() { + Label loop_channel_tail; + Label loop_channel_end; + + mov(reg_src_ptr, reg_src_base_ptr); + mov(reg_dst_ptr, reg_dst_base_ptr); + + mov(reg_ch_work, reg_channels); + + pxor(xmm_denom, xmm_denom); + + L(loop_channel_tail); { + cmp(reg_ch_work, 0); + jle(loop_channel_end, T_NEAR); + + load_scalar(xmm_src, ptr[reg_src_ptr]); + subss(xmm_src, xmm_max); + scalar_expf(xmm_src); + addss(xmm_denom, xmm_src); + store_scalar(ptr[reg_dst_ptr], xmm_src); + + add(reg_src_ptr, jpp.inner_size * jpp.dt_size); + add(reg_dst_ptr, jpp.inner_size * jpp.dt_size); + + dec(reg_ch_work); + jmp(loop_channel_tail); + } + + L(loop_channel_end); +} + +template +void jit_uni_fork_softmax_kernel_f32::scalar_loop_div() { + Label loop_channel_tail; + Label loop_channel_end; + + 
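+    // Third pass of the scalar softmax (max -> exp/sum -> div): xmm_denom
+    // still holds the exponential sum accumulated in scalar_loop_exp(), so
+    // each element needs only a single divss before being stored back.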
mov(reg_src_ptr, reg_src_base_ptr); + mov(reg_dst_ptr, reg_dst_base_ptr); + mov(reg_ch_work, reg_channels); + + L(loop_channel_tail); { + cmp(reg_ch_work, 0); + jle(loop_channel_end, T_NEAR); + + load_scalar(xmm_src, ptr[reg_dst_ptr]); + divss(xmm_src, xmm_denom); + store_scalar(ptr[reg_dst_ptr], xmm_src); + + add(reg_src_ptr, jpp.inner_size * jpp.dt_size); + add(reg_dst_ptr, jpp.inner_size * jpp.dt_size); + + dec(reg_ch_work); + jmp(loop_channel_tail); + } + + L(loop_channel_end); +} + +template +void jit_uni_fork_softmax_kernel_f32::dense_loop(int ou_block) { + for (int ou = 0; ou < ou_block; ou++) { + movups(xmm_max, xmm_float_min); + for (int ch = 0; ch < (int)jpp.channels; ch++) { + load_scalar(xmm_src, ptr[reg_src_base_ptr + (ou * jpp.channels + ch) * jpp.dt_size]); + maxss(xmm_max, xmm_src); + } + + for (int ch = 0; ch < (int)jpp.channels; ch++) { + load_scalar(xmm_src, ptr[reg_src_base_ptr + (ou * jpp.channels + ch) * jpp.dt_size]); + subss(xmm_src, xmm_max); + store_scalar(ptr[reg_dst_base_ptr + (ou * jpp.channels + ch) * jpp.dt_size], xmm_src); + } + } + + int full_work = ou_block * (int)jpp.channels; + int i = 0; + for (; i <= full_work - simd_w; i += simd_w) { + load_vector(vreg_src(0), ptr[reg_dst_base_ptr + i * jpp.dt_size]); + simd_expf(vreg_src(0)); + store_vector(ptr[reg_dst_base_ptr + i * jpp.dt_size], vreg_src(0)); + } + + for (; i < full_work; i++) { + load_scalar(xmm_src, ptr[reg_dst_base_ptr + i * jpp.dt_size]); + scalar_expf(xmm_src); + store_scalar(ptr[reg_dst_base_ptr + i * jpp.dt_size], xmm_src); + } + + for (int ou = 0; ou < ou_block; ou++) { + pxor(xmm_denom, xmm_denom); + for (int ch = 0; ch < (int)jpp.channels; ch++) { + load_scalar(xmm_src, ptr[reg_dst_base_ptr + (ou * jpp.channels + ch) * jpp.dt_size]); + addss(xmm_denom, xmm_src); + } + + movss(xmm_one, ptr[imm_addr64 + 0 * vlen]); + divss(xmm_one, xmm_denom); + movss(xmm_denom, xmm_one); + for (int ch = 0; ch < (int)jpp.channels; ch++) { + load_scalar(xmm_src, ptr[reg_dst_base_ptr + (ou * jpp.channels + ch) * jpp.dt_size]); + mulss(xmm_src, xmm_denom); + store_scalar(ptr[reg_dst_base_ptr + (ou * jpp.channels + ch) * jpp.dt_size], xmm_src); + } + } +} + +template +void jit_uni_fork_softmax_kernel_f32::generate() { + this->preamble(); + if (bf16_emu_) bf16_emu_->init_vcvtneps2bf16(); + + if (jpp.inner_size == 1) { + this->generate_dense(); + } else { + mov(reg_src_base_ptr, ptr[abi_param1 + GET_OFF(src)]); + mov(reg_dst_base_ptr, ptr[abi_param1 + GET_OFF(dst)]); + mov(reg_work_amount, ptr[abi_param1 + GET_OFF(work)]); + mov(reg_channels, ptr[abi_param1 + GET_OFF(channels)]); + + mov(reg_min, float2int(-FLT_MAX)); + movq(xmm_float_min, reg_min); + + mov(imm_addr64, jit_uni_fork_softmax_kernel_f32::l_table); + uni_vmovups(vmm_one, ptr[imm_addr64 + 0 * vlen]); + + cmp(reg_work_amount, jpp.ur_inner * simd_w); + jl(loop_simd, T_NEAR); + + L(loop_simd_unroll); + { + simd_loop_max(jpp.ur_inner); + simd_loop_exp(jpp.ur_inner); + simd_loop_div(jpp.ur_inner); + + add(reg_src_base_ptr, jpp.ur_inner * simd_w * jpp.dt_size); + add(reg_dst_base_ptr, jpp.ur_inner * simd_w * jpp.dt_size); + + sub(reg_work_amount, jpp.ur_inner * simd_w); + cmp(reg_work_amount, jpp.ur_inner * simd_w); + jge(loop_simd_unroll, T_NEAR); + } + + L(loop_simd); + { + cmp(reg_work_amount, simd_w); + jl(loop_scalar, T_NEAR); + + simd_loop_max(1); + simd_loop_exp(1); + simd_loop_div(1); + + add(reg_src_base_ptr, simd_w * jpp.dt_size); + add(reg_dst_base_ptr, simd_w * jpp.dt_size); + + sub(reg_work_amount, simd_w); + jmp(loop_simd, T_NEAR); + } + + 
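+        // The remaining work_amount % simd_w inner elements fall through to
+        // the scalar passes below. Per inner index i they compute (a plain-C
+        // sketch, assuming C channels strided by inner_size):
+        //     m = max_c src[c*inner_size + i]
+        //     d = sum_c (dst[c*inner_size + i] = exp(src[c*inner_size + i] - m))
+        //     dst[c*inner_size + i] /= d   for all c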
L(loop_scalar); + { + cmp(reg_work_amount, 0); + jle(loop_end, T_NEAR); + + scalar_loop_max(); + scalar_loop_exp(); + scalar_loop_div(); + + add(reg_src_base_ptr, jpp.dt_size); + add(reg_dst_base_ptr, jpp.dt_size); + + dec(reg_work_amount); + jmp(loop_scalar, T_NEAR); + } + + L(loop_end); + + this->postamble(); + + prepare_table(); + } +} + +template +void jit_uni_fork_softmax_kernel_f32::generate_dense() { + mov(reg_src_base_ptr, ptr[abi_param1 + GET_OFF(src)]); + mov(reg_dst_base_ptr, ptr[abi_param1 + GET_OFF(dst)]); + mov(reg_work_amount, ptr[abi_param1 + GET_OFF(work)]); + + mov(reg_min, float2int(-FLT_MAX)); + movq(xmm_float_min, reg_min); + + mov(imm_addr64, jit_uni_fork_softmax_kernel_f32::l_table); + uni_vmovups(vmm_one, ptr[imm_addr64 + 0 * vlen]); + + int outer_tail = jpp.outer_size % jpp.outer_block; + Label ou_loop_tail_label; + Label ou_loop_tail_1_label; + Label ou_loop_exit_label; + + cmp(reg_work_amount, jpp.outer_block); + jne(ou_loop_tail_label, T_NEAR); + + dense_loop(jpp.outer_block); + + jmp(ou_loop_exit_label, T_NEAR); + + L(ou_loop_tail_label); + cmp(reg_work_amount, outer_tail); + jne(ou_loop_tail_1_label, T_NEAR); + + dense_loop(outer_tail); + + jmp(ou_loop_exit_label, T_NEAR); + + L(ou_loop_tail_1_label); { + cmp(reg_work_amount, 1); + jl(ou_loop_exit_label, T_NEAR); + + dense_loop(1); + + add(reg_src_base_ptr, jpp.dt_size * jpp.channels); + add(reg_dst_base_ptr, jpp.dt_size * jpp.channels); + dec(reg_work_amount); + + jmp(ou_loop_tail_1_label, T_NEAR); + } + + L(ou_loop_exit_label); + + this->postamble(); + + prepare_table(); +} + +template struct jit_uni_fork_softmax_kernel_f32; +template struct jit_uni_fork_softmax_kernel_f32; +template struct jit_uni_fork_softmax_kernel_f32; + +} +} +} +} diff --git a/src/cpu/x64/jit_uni_fork_softmax_kernel_f32.hpp b/src/cpu/x64/jit_uni_fork_softmax_kernel_f32.hpp new file mode 100644 index 00000000000..d6aaff5d8c4 --- /dev/null +++ b/src/cpu/x64/jit_uni_fork_softmax_kernel_f32.hpp @@ -0,0 +1,132 @@ +/******************************************************************************* +* Copyright 2019-2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_X64_JIT_UNI_FORK_SOFTMAX_KERNEL_F32_HPP +#define CPU_X64_JIT_UNI_FORK_SOFTMAX_KERNEL_F32_HPP + +#include +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "cpu/x64/jit_avx512_core_bf16cvt.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace x64 { + +using namespace Xbyak; + +template +struct jit_uni_fork_softmax_kernel_f32 : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_softmax_kernel_f32) + using Vmm = typename utils::conditional3::type; + + jit_uni_fork_softmax_kernel_f32(jit_softmax_conf_t ajpp); + + jit_softmax_conf_t jpp; + + static status_t init_conf(jit_softmax_conf_t &jpp, + const softmax_desc_t &pd, + const memory_desc_wrapper &src_d, + const memory_desc_wrapper &dst_d); + + void prepare_table(); + void simd_expf(const Vmm &vmm_src); + void scalar_expf(const Xmm &xmm_src); + + void simd_loop_max(int ur_inner); + void simd_loop_exp(int ur_inner); + void simd_loop_div(int ur_inner); + + void scalar_loop_max(); + void scalar_loop_exp(); + void scalar_loop_div(); + + void dense_loop(int ou_block); + void generate_dense(); +private: + const int simd_w = cpu_isa_traits::vlen / sizeof(float); + const int vlen = cpu_isa_traits::vlen; + + Reg64 reg_work_amount = rax; + Reg64 reg_src_base_ptr = rbx; + Reg64 reg_dst_base_ptr = rsi; + Reg64 reg_src_ptr = r8; + Reg64 reg_dst_ptr = r9; + Reg64 reg_channels = r12; + Reg64 reg_ch_work = r13; + Reg64 reg_min = rdx; + Reg64 imm_addr64 = r14; + Reg64 bf16_emu_gpr = r15; + + Vmm vmm_aux0 = Vmm(0); + Vmm vmm_aux1 = Vmm(1); + Vmm vmm_aux2 = Vmm(2); + Xmm xmm_aux0 = Xmm(0); + Xmm xmm_aux1 = Xmm(1); + Xmm xmm_aux2 = Xmm(2); + + Xmm xmm_float_min = Xmm(3); + Xmm xmm_one = Xmm(4); + Vmm vmm_one = Vmm(4); + + Xmm xmm_max = Xmm(5); + Xmm xmm_denom = Xmm(6); + Xmm xmm_src = Xmm(7); + + Zmm bf16_emu_zmm_1 = Zmm(27); + Zmm bf16_emu_zmm_2 = Zmm(28); + Zmm bf16_emu_zmm_3 = Zmm(29); + Zmm bf16_emu_zmm_4 = Zmm(30); + Zmm bf16_emu_zmm_5 = Zmm(31); + + Opmask k_mask_tmp = Opmask(2); + + unsigned char _cmp_gt_os = isa == avx512_common ? 
14 : 6; + + int id_vreg_max(int ur_inner); + int id_vreg_denom(int ur_inner); + int id_vreg_src(int ur_inner); + + auto vreg_max(int ur_inner) -> Vmm; + auto vreg_denom(int ur_inner) -> Vmm; + auto vreg_src(int ur_inner) -> Vmm; + + void load_vector(Vmm vmm_src, const Xbyak::Address &op); + void load_scalar(Xmm xmm_src, const Xbyak::Address &op); + void store_vector(const Xbyak::Address &op, Vmm vmm_dst); + void store_scalar(const Xbyak::Address &op, Xmm xmm_dst); + + Label loop_simd_unroll; + Label loop_simd; + Label loop_scalar; + Label loop_end; + Label l_table; + + std::unique_ptr bf16_emu_; + + unsigned char _op_floor = 1; + + void generate() override; +}; + +} +} +} +} + +#endif diff --git a/src/cpu/x64/jit_uni_i8i8_pooling.cpp b/src/cpu/x64/jit_uni_i8i8_pooling.cpp index 3b9acaad0ba..bac70930b33 100644 --- a/src/cpu/x64/jit_uni_i8i8_pooling.cpp +++ b/src/cpu/x64/jit_uni_i8i8_pooling.cpp @@ -96,12 +96,17 @@ struct jit_uni_i8i8_pooling_fwd_ker_t : public jit_generator { Reg64 aux_reg_src_h = rax; Reg64 aux_reg_src_w = rbx; + Reg64 reg_store_tmp = r11; // shared with reg_kh_index and used only as tmp register for store on avx2 Reg64 reg_tmp = rdx; // only used during mask init and store Reg64 reg_src_safe_access = rbp; Reg64 reg_dst_safe_access = rsi; Reg64 reg_mask = r15; // only used during mask init + Reg64 reg_oc_off = reg_tmp; + Reg64 reg_d_weights = aux_reg_src_h; + Reg64 reg_d_bias = aux_reg_src_w; + Opmask k_cmp_mask = Opmask(7); Opmask mask(int idx) { return Opmask(6 - idx); } @@ -138,6 +143,9 @@ struct jit_uni_i8i8_pooling_fwd_ker_t : public jit_generator { std::unique_ptr> postops_injector_; + Vmm vmm_d_weights = vreg(3); + Vmm vmm_d_bias = vreg(4); + enum : int { max_vidx_base = utils::one_of(isa, sse41, avx2) ? 7 : 2 }; //"avg" pool uses more registers for unrolling. enum : int { avg_vidx_base = utils::one_of(isa, sse41, avx2) ? 4 : 2 }; @@ -262,10 +270,12 @@ struct jit_uni_i8i8_pooling_fwd_ker_t : public jit_generator { use_exact_tail_scalar_bcast}; const binary_injector::static_params_t bsp { reg_param, get_supported_bcast_strategies(), rhs_sp}; + quantization_injector::static_params_t qsp = + {vmm_d_weights.getIdx(), vmm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; postops_injector_ = utils::make_unique< injector::jit_uni_postops_injector_t>( - this, jpp.post_ops, bsp); + this, jpp.post_ops, bsp, qsp); } } }; @@ -658,18 +668,18 @@ void jit_uni_i8i8_pooling_fwd_ker_t::store_dst_avg_op( // Don't generate useless code if (masked && !msk) return; - const Vmm &vr_dst = vreg_dst_s32(jj, ll); + const Vmm &vr_dst = jpp.dst_dt == f32 ? vreg_dst_f32(jj, ll) : vreg_dst_s32(jj, ll); - if (jpp.src_dt == s32) { + if (jpp.dst_dt == s32 || jpp.dst_dt == f32) { if (masked) for (int i = 0; i < jpp.c_tail; i++) pextrd(ptr[reg_ptr_dst_i8 + offset + i * data_type_size(s32)], vr_dst, i); else movups(ptr[reg_ptr_dst_i8 + offset], vr_dst); - } else if (utils::one_of(jpp.src_dt, s8, u8)) { + } else if (utils::one_of(jpp.dst_dt, s8, u8)) { packssdw(vr_dst, vr_dst); - if (jpp.src_dt == s8) + if (jpp.dst_dt == s8) packsswb(vr_dst, vr_dst); else packuswb(vr_dst, vr_dst); @@ -728,8 +738,8 @@ void jit_uni_i8i8_pooling_fwd_ker_t::store_dst_avg_op( // maskmovdqu/vmaskmovdqu // with low 8-bytes mask throws exception if high 8-bytes belongs write-protected page. 
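+    // reg_oc_off introduced above aliases reg_tmp, so the store below is
+    // switched to the dedicated reg_store_tmp to keep the per-channel
+    // offset alive across the store.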
// NOTE: use indirect move via gpr to avoid transition penalty - vmovq(reg_tmp, Xmm(vr_dst.getIdx())); - movq(mmx_dst_i8, reg_tmp); + vmovq(reg_store_tmp, Xmm(vr_dst.getIdx())); + movq(mmx_dst_i8, reg_store_tmp); // mmx_full_msk - mask for all 8 bytes in zero-tail case // mmx_mask(ll) - ll-th mask of tail in non-zero-tail case @@ -770,6 +780,17 @@ void jit_uni_i8i8_pooling_fwd_ker_t::store_dst_avg_op( }; switch (jpp.dst_dt) { + case f32: + if (masked) { + if (sizeof_src_dt() != sizeof_dst_dt()) { + vpmaskmovd(ptr[reg_ptr_dst_i8 + offset], vreg_mask_2, vreg_dst_f32(jj, ll)); + } else { + vpmaskmovd(ptr[reg_ptr_dst_i8 + offset], vreg_mask, vreg_dst_f32(jj, ll)); + } + } else { + vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst_f32(jj, ll)); + } + break; case s32: if (masked) { vpmaskmovd(ptr[reg_ptr_dst_i8 + offset], vreg_mask, @@ -791,11 +812,11 @@ void jit_uni_i8i8_pooling_fwd_ker_t::store_dst_avg_op( // Don't generate useless code if (masked && !msk) return; - const Vmm &vr_dst - = masked ? vreg_dst_s32(jj, ll) | mask(ll) : vreg_dst_s32(jj, ll); + const Vmm &vr_dst = jpp.dst_dt == f32 ? masked ? vreg_dst_f32(jj, ll) | mask(ll) : vreg_dst_f32(jj, ll) + : masked ? vreg_dst_s32(jj, ll) | mask(ll) : vreg_dst_s32(jj, ll); switch (jpp.dst_dt) { - case s32: vmovups(ptr[reg_ptr_dst_i8 + offset], vr_dst); break; + case f32: case s32: vmovups(ptr[reg_ptr_dst_i8 + offset], vr_dst); break; case s8: vpmovsdb(ptr[reg_ptr_dst_i8 + offset], vr_dst); break; case u8: vpmovusdb(ptr[reg_ptr_dst_i8 + offset], vr_dst); break; default: assert(!"unsupported dst data_type"); @@ -935,7 +956,7 @@ void jit_uni_i8i8_pooling_fwd_ker_t::compute_avg_step( int iw = jpp.iw; int c = jpp.c; - const int num_ll = data_type_size(avg_proc_dt) / data_type_size(jpp.src_dt); + const int num_ll = data_type_size(avg_proc_dt) / data_type_size(jpp.dst_dt); for (int jj = 0; jj < ur_c; jj++) { for (int ll = 0; ll < num_ll; ll++) { @@ -949,6 +970,9 @@ void jit_uni_i8i8_pooling_fwd_ker_t::compute_avg_step( } } + if (jpp.with_depthwise || jpp.with_quantization) + push(reg_oc_off); + mov(aux_reg_src_d, reg_ptr_src_i8); xor_(reg_kd_index, reg_kd_index); L(l_kd); @@ -988,6 +1012,11 @@ void jit_uni_i8i8_pooling_fwd_ker_t::compute_avg_step( jl(l_kd, T_NEAR); } + static constexpr int vlen_size_elem = cpu_isa_traits::vlen / sizeof(float); + + if (jpp.with_depthwise || jpp.with_quantization) + pop(reg_oc_off); + for (int jj = 0; jj < ur_c; jj++) { for (int ll = 0; ll < num_ll; ll++) { const bool masked = jj == ur_c - 1 && c_tail; @@ -999,6 +1028,15 @@ void jit_uni_i8i8_pooling_fwd_ker_t::compute_avg_step( uni_vfmadd132ps(reg_dst_f32, vreg_zeros, vreg_tmp); if (jpp.with_postops) { + std::map vmm_idx_off; + vmm_idx_off.insert({reg_dst_f32.getIdx(), (ll * vlen_size_elem + jj * vlen_size_elem) * sizeof(float)}); + depthwise_injector::dynamic_params_t ddp {vmm_d_weights.getIdx(), vmm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + reg_oc_off, vmm_idx_off, this->rsp}; + quantization_injector::dynamic_params_t qdp {reg_oc_off, vmm_idx_off, jpp.dst_dt, this->rsp}; + + injector_utils::vmm_index_set_t vmm_idxs; + vmm_idxs.emplace(reg_dst_f32.getIdx()); + binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; if (jpp.with_binary) { rhs_arg_params.vmm_idx_to_out_reg.emplace( @@ -1010,16 +1048,19 @@ void jit_uni_i8i8_pooling_fwd_ker_t::compute_avg_step( rhs_arg_params.vmm_tail_idx_.emplace( reg_dst_f32.getIdx()); } - postops_injector_->compute_vector( - reg_dst_f32.getIdx(), rhs_arg_params); + postops_injector_->compute_vector_range( + vmm_idxs, rhs_arg_params, 
ddp, qdp); } - uni_vcvtps2dq(reg_dst_s32, reg_dst_f32); + if (jpp.dst_dt != f32) { + uni_vcvtps2dq(reg_dst_s32, reg_dst_f32); + } if (jpp.with_postops) if (jpp.dst_dt == u8) { uni_vpmaxsd(reg_dst_s32, reg_dst_s32, vreg_zeros); } + store_dst(jj, ll, c_tail); } } @@ -1048,12 +1089,17 @@ void jit_uni_i8i8_pooling_fwd_ker_t::compute_c_block() { int c_tail = jpp.c_tail; xor_(c_iter, c_iter); + if (jpp.with_quantization) + xor_(reg_oc_off, reg_oc_off); + if (c_steps > 0) { L(l_main_loop); { compute_step(ur_c, 0); add(reg_ptr_src_i8, ur_c * c_block * sizeof_src_dt()); add(reg_ptr_dst_i8, ur_c * c_block * sizeof_dst_dt()); + if (jpp.with_quantization) + add(reg_oc_off, ur_c*c_block*sizeof(float)); inc(c_iter); cmp(c_iter, c_steps); jl(l_main_loop, T_NEAR); @@ -1129,6 +1175,10 @@ void jit_uni_i8i8_pooling_fwd_ker_t::init_mask() { vpalignr(vreg_mask_2, vreg_mask_2, vreg_zeros, 32 - shift); } vextracti128(xreg_mask_2_hi, vreg_mask_2, 0x1); + + if (sizeof_src_dt() != sizeof_dst_dt()) { + vpmovsxbd(vreg_mask_2, vreg_mask); + } } // Need mask in MMX regs ? @@ -1234,6 +1284,9 @@ void jit_uni_i8i8_pooling_fwd_ker_t::generate() { mov(rcx, rdi); #endif + if (postops_injector_) + postops_injector_->push_post_ops_data_on_stack(reg_param, GET_OFF(post_ops_binary_rhs_arg_vec), reg_ptr_src_i8, reg_ptr_dst_i8); + #define READ_PARAM(reg, field) \ mov(reg, ptr[reg_param + offsetof(call_params_t, field)]) READ_PARAM(reg_ptr_src_i8, src_i8); @@ -1255,6 +1308,10 @@ void jit_uni_i8i8_pooling_fwd_ker_t::generate() { compute_c_block(); emms(); + + if (postops_injector_) + postops_injector_->reset_stack_pointer(); + postamble(); if (jpp.with_eltwise && postops_injector_) @@ -1317,7 +1374,7 @@ status_t jit_uni_i8i8_pooling_fwd_ker_t::init_conf( // isa == sse41 : 16 bytes -> 16 for s8/u8, 4 for s32 // isa == avx2 : 32 bytes -> 32 for s8/u8, 8 for s32 // isa == avx512* : 64 bytes -> 64 for s8/u8, 16 for s32 - int simd_w = cpu_isa_traits::vlen / data_type_size(jpp.src_dt); + int simd_w = cpu_isa_traits::vlen / data_type_size(jpp.dst_dt); /* Verify that vlen-sized memory access happens within the tensor's * size, otherwise load/store will always spill outside the memory @@ -1379,6 +1436,8 @@ bool jit_uni_i8i8_pooling_fwd_ker_t::post_ops_ok(jit_pool_conf_t &jpp, jpp.with_postops = false; jpp.with_eltwise = false; jpp.with_binary = false; + jpp.with_depthwise = false; + jpp.with_quantization = false; if (entries.empty()) return true; @@ -1391,11 +1450,16 @@ bool jit_uni_i8i8_pooling_fwd_ker_t::post_ops_ok(jit_pool_conf_t &jpp, && entry.binary.src1_desc.data_type == data_type::bf16) return false; jpp.with_binary = true; - } else + } else if (entry.is_depthwise()) { + jpp.with_depthwise = true; + } else if (entry.is_quantization()) { + jpp.with_quantization = true; + } else { return false; + } } - jpp.with_postops = jpp.with_eltwise || jpp.with_binary; + jpp.with_postops = jpp.with_eltwise || jpp.with_binary || jpp.with_depthwise || jpp.with_quantization; jpp.post_ops = post_ops; /* @@ -1433,6 +1497,8 @@ status_t jit_uni_i8i8_pooling_fwd_t::execute_forward( auto src_i8 = CTX_IN_MEM(const char *, DNNL_ARG_SRC); auto dst_i8 = CTX_OUT_MEM(char *, DNNL_ARG_DST); + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + const memory_desc_wrapper src_d(pd()->src_md()); const memory_desc_wrapper dst_d(pd()->dst_md()); @@ -1449,7 +1515,7 @@ status_t jit_uni_i8i8_pooling_fwd_t::execute_forward( reinterpret_cast(dst_i8 + dst_d.size() - 1) - (cpu_isa_traits::vlen - 1)); - parallel_nd(jpp.mb, jpp.od, jpp.oh, jpp.ow, + parallel_nd(MB, jpp.od, jpp.oh, jpp.ow, 
[&](dim_t n, dim_t od, dim_t oh, dim_t ow) { dim_t id = nstl::max(od * jpp.stride_d - jpp.f_pad, dim_t(0)); dim_t ih = nstl::max(oh * jpp.stride_h - jpp.t_pad, dim_t(0)); diff --git a/src/cpu/x64/jit_uni_i8i8_pooling.hpp b/src/cpu/x64/jit_uni_i8i8_pooling.hpp index 8c11ca4974d..410fe9ff8de 100644 --- a/src/cpu/x64/jit_uni_i8i8_pooling.hpp +++ b/src/cpu/x64/jit_uni_i8i8_pooling.hpp @@ -51,10 +51,11 @@ struct jit_uni_i8i8_pooling_fwd_t : public primitive_t { alg_kind::pooling_avg_exclude_padding) && utils::one_of(src_md()->data_type, data_type::s32, data_type::s8, data_type::u8) - && src_md()->data_type == dst_md()->data_type && !is_dilated() && attr()->has_default_values( primitive_attr_t::skip_mask_t::post_ops) + && IMPLICATION(utils::one_of(desc()->alg_kind, alg_kind::pooling_avg_include_padding, alg_kind::pooling_avg_exclude_padding), + utils::one_of(dst_md()->data_type, data_type::u8, data_type::s8, data_type::f32)) && set_default_params() == status::success && memory_desc_matches_one_of_tag( *src_md(), nwc, nhwc, ndhwc) diff --git a/src/cpu/x64/jit_uni_planar_conv_kernel_f32.cpp b/src/cpu/x64/jit_uni_planar_conv_kernel_f32.cpp new file mode 100644 index 00000000000..84b77bd3925 --- /dev/null +++ b/src/cpu/x64/jit_uni_planar_conv_kernel_f32.cpp @@ -0,0 +1,804 @@ +/******************************************************************************* +* Copyright 2019-2021 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "common/c_types_map.hpp" +#include "common/memory.hpp" +#include "common/nstl.hpp" +#include "common/type_helpers.hpp" +#include "common/utils.hpp" + +#include "cpu/x64/jit_uni_planar_conv_kernel_f32.hpp" + +#define GET_OFF(field) offsetof(jit_conv_call_s, field) + +namespace dnnl { +namespace impl { +namespace cpu { +namespace x64 { + +using namespace dnnl::impl::prop_kind; +using namespace dnnl::impl::utils; + +using namespace Xbyak; + +template +void jit_uni_planar_conv_fwd_kernel_f32::load_src_scalar(int ur_h) { + Label init_done_label; + Label init_first_label; + + mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]); + if (jcp.with_bias) + mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]); + + if (!jcp.with_sum) { + test(reg_ci_flag, FLAG_IC_FIRST); + jne(init_first_label, T_NEAR); + } + + for (int kk = 0; kk < ur_h; kk++) { + size_t offt = sizeof(float) * (kk * jcp.ow * jcp.oh_block_step); + movss(Xmm(kk), make_safe_addr(reg_output, offt, reg_long_offt)); + } + + if (jcp.with_sum && jcp.with_bias) { + test(reg_ci_flag, FLAG_IC_FIRST); + je(init_done_label, T_NEAR); + + movss(xmm_tmp, make_safe_addr(reg_bias, 0, reg_long_offt)); + for (int kk = 0; kk < ur_h; kk++) { + uni_vaddps(Vmm(kk), Vmm(kk), vmm_tmp); + } + } + + jmp(init_done_label, T_NEAR); + + L(init_first_label); + if (this->jcp.with_bias) { + movss(xmm_tmp, make_safe_addr(reg_bias, 0, reg_long_offt)); + for (int kk = 0; kk < ur_h; kk++) { + uni_vmovups(Vmm(kk), vmm_tmp); + } + } else { + for (int kk = 0; kk < ur_h; kk++) { + uni_vpxor(Vmm(kk), Vmm(kk), Vmm(kk)); + } + } + + L(init_done_label); +} + +template +void jit_uni_planar_conv_fwd_kernel_f32::filter_scalar(int ur_h) { + Label iter_exit_label; + + int iw = jcp.iw; + int ih = jcp.ih; + int id = jcp.id; + int dilate_w = jcp.dilate_w + 1; + int ic_blk = jcp.ic_block; + int kw = jcp.kw; + int kh = jcp.kh; + int kd = jcp.kd; + + cmp(reg_kw, 0); + je(iter_exit_label, T_NEAR); + + mov(aux_reg_input_w, aux_reg_input_h); + mov(aux_reg_kernel_w, aux_reg_kernel_h); + mov(kw_iter, reg_kw); + + Label kw_label; + L(kw_label); + { + for (size_t ifm2 = 0; ifm2 < (size_t)ic_blk; ifm2++) { + for (int kk = 0; kk < ur_h; kk++) { + size_t inp_off = sizeof(float) * (ifm2 * id * ih * iw + kk * jcp.iw * jcp.oh_block_step); + movss(xmm_src, make_safe_addr(aux_reg_input_w, inp_off, reg_long_offt)); + + size_t ker_off = sizeof(float) * (ifm2 * kd * kh * kw); + movss(xmm_ker, ptr[aux_reg_kernel_w + ker_off]); + + uni_vfmadd231ps(Vmm(kk), vmm_src, vmm_ker); + } + } + + add(aux_reg_kernel_w, sizeof(float)); + add(aux_reg_input_w, dilate_w * sizeof(float)); + + dec(kw_iter); + cmp(kw_iter, 0); + jg(kw_label, T_NEAR); + } + + L(iter_exit_label); +} + +template +void jit_uni_planar_conv_fwd_kernel_f32::apply_filter_scalar(int ur_h) { + int iw = jcp.iw; + int kw = jcp.kw; + int dilate_h = jcp.dilate_h + 1; + int dilate_d = jcp.dilate_h + 1; + const int inp_mult_h = dilate_h; + const int inp_mult_d = dilate_d; + + Label skip_kh_loop, skip_kd_loop, kd_label; + if (jcp.ndims == 5) { + push(reg_kernel); + push(reg_output); + + mov(reg_kd, ptr[param1 + GET_OFF(kd_padding)]); + mov(aux_reg_ker_d, aux_reg_kernel_h); + mov(aux_reg_inp_d, aux_reg_input_h); + + cmp(reg_kd, 0); + je(skip_kd_loop, T_NEAR); + + L(kd_label); + mov(kh_iter, ptr[param1 + GET_OFF(kh_padding)]); + } else { + mov(kh_iter, reg_kh); + } + + if (jcp.ndims == 5) { + mov(aux_reg_input_h, aux_reg_inp_d); + mov(aux_reg_kernel_h, aux_reg_ker_d); + } + + 
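+    // For the 5D (NCDHW) case the kd loop set up above wraps this kh loop;
+    // both trip counts arrive pre-clipped through kd_padding/kh_padding in
+    // the call params, so taps falling into front/top padding never execute.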
cmp(kh_iter, 0); + je(skip_kh_loop, T_NEAR); + + Label kh_label; + L(kh_label); + { + filter_scalar(ur_h); + + add(aux_reg_kernel_h, sizeof(float) * kw); + add(aux_reg_input_h, sizeof(float) * iw * inp_mult_h); + + dec(kh_iter); + cmp(kh_iter, 0); + jg(kh_label, T_NEAR); + } + + L(skip_kh_loop); + + if (jcp.ndims == 5) { + add(aux_reg_ker_d, sizeof(float) * jcp.kw * jcp.kh); + add(aux_reg_inp_d, sizeof(float) * jcp.ih * jcp.iw * inp_mult_d); + + dec(reg_kd); + cmp(reg_kd, 0); + jg(kd_label, T_NEAR); + L(skip_kd_loop); + + pop(reg_output); + pop(reg_kernel); + } +} + +template +void jit_uni_planar_conv_fwd_kernel_f32::apply_postprocess_scalar(int ur_h) { + Label regular_store_label; + + mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]); + test(reg_ci_flag, FLAG_IC_LAST); + je(regular_store_label, T_NEAR); + + int eltwise_inj_idx = 0; + const auto &p = attr_.post_ops_; + + if (p.len() == 0 && eltwise_injectors.size() == 1) { + eltwise_injectors[0]->compute_vector_range(0, ur_h); + } + + for (int i = 0; i < p.len(); i++) { + auto& post_op = p.entry_[i]; + if (post_op.is_eltwise()) { + eltwise_injectors[eltwise_inj_idx]->compute_vector_range(0, ur_h); + eltwise_inj_idx++; + } + } + + L(regular_store_label); +} + +template +void jit_uni_planar_conv_fwd_kernel_f32::store_dst_scalar(int ur_h) { + for (int kk = 0; kk < ur_h; kk++) { + size_t o_off = sizeof(float) * (kk * jcp.ow * jcp.oh_block_step); + movss(make_safe_addr(reg_output, o_off, reg_long_offt), Xmm(kk)); + } +} + +template +void jit_uni_planar_conv_fwd_kernel_f32::load_src(int ur_h, int ur_w) { + Label init_done_label; + Label init_first_label; + + mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]); + if (jcp.with_bias) + mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]); + + if (!jcp.with_sum) { + test(reg_ci_flag, FLAG_IC_FIRST); + jne(init_first_label, T_NEAR); + } + + for (int kk = 0; kk < ur_h; kk++) { + for (int jj = 0; jj < ur_w; jj++) { + size_t offt = sizeof(float) * (jj * jcp.ow_block + kk * jcp.ow * jcp.oh_block_step); + uni_vmovups(Vmm(kk * ur_w + jj), make_safe_addr(reg_output, offt, reg_long_offt)); + } + } + + if (jcp.with_sum && jcp.with_bias) { + test(reg_ci_flag, FLAG_IC_FIRST); + je(init_done_label, T_NEAR); + + uni_vbroadcastss(vmm_tmp, make_safe_addr(reg_bias, 0, reg_long_offt)); + for (int kk = 0; kk < ur_h; kk++) { + for (int jj = 0; jj < ur_w; jj++) { + uni_vaddps(Vmm(kk * ur_w + jj), Vmm(kk * ur_w + jj), vmm_tmp); + } + } + } + + jmp(init_done_label, T_NEAR); + + L(init_first_label); + if (this->jcp.with_bias) { + uni_vbroadcastss(vmm_tmp, make_safe_addr(reg_bias, 0, reg_long_offt)); + for (int kk = 0; kk < ur_h; kk++) { + for (int jj = 0; jj < ur_w; jj++) { + uni_vmovups(Vmm(kk * ur_w + jj), vmm_tmp); + } + } + } else { + for (int kk = 0; kk < ur_h; kk++) { + for (int jj = 0; jj < ur_w; jj++) { + uni_vpxor(Vmm(kk * ur_w + jj), Vmm(kk * ur_w + jj), Vmm(kk * ur_w + jj)); + } + } + } + + L(init_done_label); +} + +template +void jit_uni_planar_conv_fwd_kernel_f32::filter_unrolled(int ur_h, int ur_w) { + int iw = jcp.iw; + int ih = jcp.ih; + int id = jcp.id; + int stride_w = jcp.stride_w; + int dilate_w = jcp.dilate_w + 1; + int ic_blk = jcp.ic_block; + int kw = jcp.kw; + int kh = jcp.kh; + int kd = jcp.kd; + int ow_blk = jcp.ow_block; + + for (int ki = 0; ki < kw; ki++) { + for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) { + for (int kk = 0; kk < ur_h; kk++) { + for (int jj = 0; jj < ur_w; jj++) { + size_t inp_off = sizeof(float) * ((size_t) ifm2 * id * ih * iw + ki * dilate_w + + jj * stride_w * ow_blk + kk * 
jcp.ow * jcp.oh_block_step); + uni_vmovups(vmm_src, make_safe_addr(aux_reg_input_h, inp_off, reg_long_offt)); + + int ker_off = sizeof(float) * ((size_t) ifm2 * kd * kh * kw + ki); + uni_vbroadcastss(vmm_ker, ptr[aux_reg_kernel_h + ker_off]); + + uni_vfmadd231ps(Vmm(kk * ur_w + jj), vmm_src, vmm_ker); + } + } + } + } +} + +template +void jit_uni_planar_conv_fwd_kernel_f32::filter(int ur_h) { + Label iter_exit_label; + + int iw = jcp.iw; + int ih = jcp.ih; + int id = jcp.id; + int dilate_w = jcp.dilate_w + 1; + int ic_blk = jcp.ic_block; + int kw = jcp.kw; + int kh = jcp.kh; + int kd = jcp.kd; + + cmp(reg_kw, 0); + je(iter_exit_label, T_NEAR); + + mov(aux_reg_input_w, aux_reg_input_h); + mov(aux_reg_kernel_w, aux_reg_kernel_h); + mov(kw_iter, reg_kw); + + Label kw_label; + L(kw_label); + { + for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) { + for (int kk = 0; kk < ur_h; kk++) { + size_t inp_off = sizeof(float) * ((size_t) ifm2 * id * ih * iw + kk * jcp.ow * jcp.oh_block_step); + uni_vmovups(vmm_src, make_safe_addr(aux_reg_input_w, inp_off, reg_long_offt)); + + size_t ker_off = sizeof(float) * ((size_t) ifm2 * kd * kh * kw); + uni_vbroadcastss(vmm_ker, ptr[aux_reg_kernel_w + ker_off]); + + uni_vfmadd231ps(Vmm(kk), vmm_src, vmm_ker); + } + } + + add(aux_reg_kernel_w, sizeof(float)); + add(aux_reg_input_w, dilate_w * sizeof(float)); + + dec(kw_iter); + cmp(kw_iter, 0); + jg(kw_label, T_NEAR); + } + + L(iter_exit_label); +} + +template +void jit_uni_planar_conv_fwd_kernel_f32::apply_filter(int ur_h, int ur_w) { + int iw = jcp.iw; + int kw = jcp.kw; + int dilate_h = jcp.dilate_h + 1; + int dilate_d = jcp.dilate_h + 1; + const int inp_mult_h = dilate_h; + const int inp_mult_d = dilate_d; + + Label skip_kh_loop, skip_kd_loop, kd_label; + if (jcp.ndims == 5) { + push(reg_kernel); + push(reg_output); + + mov(reg_kd, ptr[param1 + GET_OFF(kd_padding)]); + mov(aux_reg_ker_d, aux_reg_kernel_h); + mov(aux_reg_inp_d, aux_reg_input_h); + + cmp(reg_kd, 0); + je(skip_kd_loop, T_NEAR); + + L(kd_label); + mov(kh_iter, ptr[param1 + GET_OFF(kh_padding)]); + } else { + mov(kh_iter, reg_kh); + } + + if (jcp.ndims == 5) { + mov(aux_reg_input_h, aux_reg_inp_d); + mov(aux_reg_kernel_h, aux_reg_ker_d); + } + + cmp(kh_iter, 0); + je(skip_kh_loop, T_NEAR); + + Label kh_label; + L(kh_label); + { + if (ur_w == jcp.nb_ow_blocking) + filter_unrolled(ur_h, ur_w); + else + filter(ur_h); + + add(aux_reg_kernel_h, sizeof(float) * kw); + add(aux_reg_input_h, sizeof(float) * iw * inp_mult_h); + + dec(kh_iter); + cmp(kh_iter, 0); + jg(kh_label, T_NEAR); + } + + L(skip_kh_loop); + + if (jcp.ndims == 5) { + add(aux_reg_ker_d, sizeof(float) * jcp.kw * jcp.kh); + add(aux_reg_inp_d, sizeof(float) * jcp.ih * jcp.iw * inp_mult_d); + + dec(reg_kd); + cmp(reg_kd, 0); + jg(kd_label, T_NEAR); + L(skip_kd_loop); + + pop(reg_output); + pop(reg_kernel); + } +} + +template +void jit_uni_planar_conv_fwd_kernel_f32::apply_postprocess(int ur_h, int ur_w) { + Label regular_store_label; + + mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]); + test(reg_ci_flag, FLAG_IC_LAST); + je(regular_store_label, T_NEAR); + + int eltwise_inj_idx = 0; + const auto &p = attr_.post_ops_; + + if (p.len() == 0 && eltwise_injectors.size() == 1) { + eltwise_injectors[0]->compute_vector_range(0, ur_w * ur_h); + } + + for (int i = 0; i < p.len(); i++) { + auto& post_op = p.entry_[i]; + if (post_op.is_eltwise()) { + eltwise_injectors[eltwise_inj_idx]->compute_vector_range(0, ur_w * ur_h); + eltwise_inj_idx++; + } + } + + L(regular_store_label); +} + +template +void 
jit_uni_planar_conv_fwd_kernel_f32::store_dst(int ur_h, int ur_w) { + for (int kk = 0; kk < ur_h; kk++) { + for (int jj = 0; jj < ur_w; jj++) { + size_t o_off = sizeof(float) * (jj * jcp.ow_block + kk * jcp.ow * jcp.oh_block_step); + uni_vmovups(make_safe_addr(reg_output, o_off, reg_long_offt), Vmm(kk * ur_w + jj)); + } + } +} + +template +void jit_uni_planar_conv_fwd_kernel_f32::solve_common(int ur_h) { + auto solve_loop = [&](int ur_w, int step_w) { + Label loop_label; + Label exit_label; + + L(loop_label); + { + if (step_w == 1) { + load_src_scalar(ur_h); + apply_filter_scalar(ur_h); + apply_postprocess_scalar(ur_h); + store_dst_scalar(ur_h); + } else { + load_src(ur_h, ur_w); + apply_filter(ur_h, ur_w); + apply_postprocess(ur_h, ur_w); + store_dst(ur_h, ur_w); + } + + add(reg_input, sizeof(float) * step_w * jcp.stride_w); + add(reg_output, sizeof(float) * step_w); + } + + L(exit_label); + }; + + Label left_border_label; + Label main_loop_unrolled_label; + Label main_loop_label; + Label right_border_label; + Label exit_label; + + xor_(reg_ow, reg_ow); + sub(reg_input, sizeof(float) * jcp.l_pad); + + auto adjust_indexes_left = [&]() { + Label border_indexes_label; + Label border_indexes_exit_label; + + mov(reg_wj, jcp.l_pad); + sub(reg_wj, reg_ow); + L(border_indexes_label); + { + cmp(reg_wj, 0); + jle(border_indexes_exit_label, T_NEAR); + + add(aux_reg_kernel_h, sizeof(float)); + add(aux_reg_input_h, sizeof(float) * (jcp.dilate_w + 1)); + dec(reg_kw); + sub(reg_wj, jcp.dilate_w + 1); + + jmp(border_indexes_label); + + L(border_indexes_exit_label); + } + }; + + auto adjust_indexes_right = [&]() { + Label border_indexes_right_label; + Label border_indexes_right_exit_label; + + imul(reg_wj, reg_ow, jcp.stride_w); + add(reg_wj, (jcp.kw-1) * (jcp.dilate_w+1) - jcp.l_pad+1 - jcp.iw); + + L(border_indexes_right_label); + { + cmp(reg_wj, 0); + jle(border_indexes_right_exit_label, T_NEAR); + + dec(reg_kw); + sub(reg_wj, jcp.dilate_w + 1); + + jmp(border_indexes_right_label); + + L(border_indexes_right_exit_label); + } + }; + + int left_border_end = nstl::min(div_up(jcp.l_pad, jcp.stride_w), jcp.ow); + L(left_border_label); { + cmp(reg_ow, left_border_end); + jge(main_loop_unrolled_label, T_NEAR); + + mov(aux_reg_input_h, reg_input); + mov(aux_reg_kernel_h, reg_kernel); + mov(reg_kw, jcp.kw); + + adjust_indexes_left(); + adjust_indexes_right(); + + solve_loop(1, 1); // scalar + + inc(reg_ow); + jmp(left_border_label, T_NEAR); + } + + int main_loop_end = (jcp.iw - (jcp.kw - 1)*(jcp.dilate_w + 1) + jcp.l_pad - 1) / jcp.stride_w + 1; + L(main_loop_unrolled_label); { + cmp(reg_ow, main_loop_end - jcp.nb_ow_blocking * jcp.ow_block); + jg(main_loop_label, T_NEAR); + + mov(aux_reg_input_h, reg_input); + mov(aux_reg_kernel_h, reg_kernel); + mov(reg_kw, jcp.kw); + + solve_loop(jcp.nb_ow_blocking, jcp.nb_ow_blocking * jcp.ow_block); + + add(reg_ow, jcp.nb_ow_blocking * jcp.ow_block); + jmp(main_loop_unrolled_label, T_NEAR); + } + + L(main_loop_label); { + cmp(reg_ow, main_loop_end - jcp.ow_block); + jg(right_border_label, T_NEAR); + + mov(aux_reg_input_h, reg_input); + mov(aux_reg_kernel_h, reg_kernel); + mov(reg_kw, jcp.kw); + + solve_loop(1, jcp.ow_block); // vectorized + + add(reg_ow, jcp.ow_block); + jmp(main_loop_label, T_NEAR); + } + + int right_border_end = jcp.ow; + L(right_border_label); { + cmp(reg_ow, right_border_end); + jge(exit_label, T_NEAR); + + mov(aux_reg_input_h, reg_input); + mov(aux_reg_kernel_h, reg_kernel); + mov(reg_kw, jcp.kw); + + adjust_indexes_left(); + adjust_indexes_right(); + 
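+        // Right border: the adjust_indexes_* helpers shrink reg_kw (and, on
+        // the left side, advance the input/kernel pointers) so the scalar
+        // filter only reads taps inside the row; outputs overlapping the
+        // padded region are produced one point at a time.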
+ solve_loop(1, 1); // scalar + + inc(reg_ow); + jmp(right_border_label, T_NEAR); + } + + L(exit_label); +} + +template +void jit_uni_planar_conv_fwd_kernel_f32::generate() { + const auto &p = attr_.post_ops_; + for (int i = 0; i < p.len(); i++) { + auto &post_op = p.entry_[i]; + if (post_op.is_eltwise()) { + eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32( + this, + post_op.eltwise + )); + } + } + + this->preamble(); + + mov(reg_input, ptr[this->param1 + GET_OFF(src)]); + mov(reg_output, ptr[this->param1 + GET_OFF(dst)]); + mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); + mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); + mov(reg_oh_blocks, ptr[this->param1 + GET_OFF(oh_blocks)]); + + Label tail_label; + Label exit_label; + + solve_common(1); + + this->postamble(); + + for (auto& inj : eltwise_injectors) + inj->prepare_table(); +} + +template +bool jit_uni_planar_conv_fwd_kernel_f32::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + auto is_simple = [&](int idx) { return is_eltwise(idx); }; + + switch (p.len()) { + case 0: return true; // no post_ops + case 1: + return true // sum OR eltwise OR depthwise + && !jcp.with_eltwise && (is_simple(0) || is_sum(0)); + case 2: + return true // sum->relu + && !jcp.with_eltwise && ((is_sum(0) && is_simple(1)) || + (is_simple(0) && is_simple(1))); + case 3: + return true // sum->relu + && !jcp.with_eltwise && (is_sum(0) && is_simple(1) && is_simple(2)); + default: return false; + } + + return false; +} + +template +status_t jit_uni_planar_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, memory_desc_t &src_md, + memory_desc_t &weights_md, memory_desc_t &dst_md, + memory_desc_t &bias_md, const primitive_attr_t &attr) +{ + if (!mayiuse(isa)) return status::unimplemented; + + const memory_desc_wrapper src_d(&src_md); + const memory_desc_wrapper weights_d(&weights_md); + const memory_desc_wrapper dst_d(&dst_md); + const memory_desc_wrapper bias_d(&bias_md); + + jcp.prop_kind = cd.prop_kind; + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + int ndims = src_d.ndims(); + jcp.ndims = ndims; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + jcp.id = (ndims == 5) ? src_d.dims()[2] : 1; + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2]; + jcp.iw = src_d.dims()[ndims-1]; + jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1; + jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims-2]; + jcp.ow = dst_d.dims()[ndims-1]; + jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1; + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims-2]; + jcp.kw = weights_d.dims()[with_groups + ndims-1]; + + jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; + jcp.l_pad = cd.padding[0][ndims-3]; + jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4]; + jcp.stride_w = cd.strides[ndims-3]; + + jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; + jcp.dilate_h = (ndims == 3) ? 
0 : cd.dilates[ndims-4]; + jcp.dilate_w = cd.dilates[ndims-3]; + + jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) + - (jcp.ih + jcp.t_pad - 1); + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + jcp.with_eltwise = false; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + jcp.with_sum = p.find(primitive_kind::sum) != -1; + + const int simd_w = isa == avx512_common ? 16 : 8; + + auto set_or_check_wei_format = [&]() { + using namespace format_tag; + format_tag_t wei_tag = with_groups ? ndims == 5 ? goidhw : goihw + : ndims == 5 ? oidhw : oihw; + + memory_desc_t want_wei_md = weights_md; + memory_desc_init_by_tag(want_wei_md, wei_tag); + + if (weights_md.format_kind == format_kind::any) { + weights_md = want_wei_md; + return true; + } + + return weights_md == want_wei_md; + }; + + if (!set_or_check_wei_format()) + return status::unimplemented; + + auto dat_tag = ndims == 5 ? format_tag::ncdhw : format_tag::nchw; + if (src_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(src_md, dat_tag)); + jcp.src_tag = dat_tag; + } else { + jcp.src_tag = src_d.mb_stride_relaxed_match(dat_tag); + } + if (jcp.src_tag != dat_tag) + return status::unimplemented; + + if (dst_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(dst_md, dat_tag)); + jcp.dst_tag = dat_tag; + } else { + jcp.dst_tag = dst_d.mb_stride_relaxed_match(dat_tag); + } + if (jcp.dst_tag != dat_tag) + return status::unimplemented; + + if (jcp.with_bias) { + if (bias_d.format_kind() == format_kind::any) + CHECK(memory_desc_init_by_tag(bias_md, format_tag::x)); + } + + // This convolution implementation was introduced as a workaround to provide competitive performance on the MSD topology. + // The conditions below bound its applicability scope. + bool args_ok = jcp.ngroups == 1 && + jcp.oc == 1 && + jcp.stride_d == 1 && jcp.stride_h == 1 && jcp.stride_w == 1; + if (!args_ok) return status::unimplemented; + + jcp.ur_w = 1; + + jcp.ow_block = simd_w; + jcp.nb_ow_blocking = 3; + + jcp.oh_block = 1; + jcp.nb_oh_blocking = 1; + jcp.oh_block_step = 1; + + jcp.oc_block = 1; + jcp.nb_oc = jcp.oc / jcp.oc_block; + jcp.nb_oc_blocking = 1; + + jcp.ic_block = 1; + jcp.nb_ic = jcp.ic / jcp.ic_block; + jcp.nb_ic_blocking = 1; + + return status::success; +} + +template struct jit_uni_planar_conv_fwd_kernel_f32<avx512_common>; +template struct jit_uni_planar_conv_fwd_kernel_f32<avx2>; + +} +} +} +} diff --git a/src/cpu/x64/jit_uni_planar_conv_kernel_f32.hpp b/src/cpu/x64/jit_uni_planar_conv_kernel_f32.hpp new file mode 100644 index 00000000000..350c46d5687 --- /dev/null +++ b/src/cpu/x64/jit_uni_planar_conv_kernel_f32.hpp @@ -0,0 +1,135 @@ +/******************************************************************************* +* Copyright 2019-2021 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License.
+*******************************************************************************/ + +#ifndef CPU_X64_JIT_UNI_PLANAR_CONV_KERNEL_F32_HPP +#define CPU_X64_JIT_UNI_PLANAR_CONV_KERNEL_F32_HPP + +#include "common/c_types_map.hpp" +#include "common/memory_tracking.hpp" + +#include "cpu/x64/jit_generator.hpp" +#include "cpu/x64/jit_primitive_conf.hpp" +#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp" +#include "cpu/x64/injectors/jit_uni_depthwise_injector.hpp" +#include "cpu/x64/injectors/jit_uni_quantization_injector.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace x64 { + +template +struct jit_uni_planar_conv_fwd_kernel_f32: public jit_generator { + jit_uni_planar_conv_fwd_kernel_f32(jit_conv_conf_t ajcp, + const primitive_attr_t &attr): jcp(ajcp), attr_(attr) {} + + ~jit_uni_planar_conv_fwd_kernel_f32() { + for (auto inj : eltwise_injectors) + delete inj; + eltwise_injectors.clear(); + + for (auto inj : depthwise_injectors) + delete inj; + depthwise_injectors.clear(); + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_planar_conv_fwd_kernel_f32) + + static bool post_ops_ok(jit_conv_conf_t &jcp, + const primitive_attr_t &attr); + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, memory_desc_t &src_md, + memory_desc_t &weights_md, memory_desc_t &dst_md, + memory_desc_t &bias_md, const primitive_attr_t &attr); + + jit_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker)(jit_conv_call_s *); + +private: + using Vmm = typename utils::conditional3::type; + using reg64_t = const Xbyak::Reg64; + using reg32_t = const Xbyak::Reg32; + const Xbyak::AddressFrame &vmmword = (isa == sse41) + ? xword : (isa == avx2) ? yword : zword; + + reg64_t reg_input = r8; + reg64_t reg_kernel = r9; + reg64_t reg_output = r10; + + reg64_t aux_reg_input_h = r11; + reg64_t aux_reg_kernel_h = r12; + + reg64_t aux_reg_input_w = r13; + reg64_t aux_reg_kernel_w = r14; + + reg64_t aux_reg_inp_d = r9; + reg64_t aux_reg_ker_d = r10; + + reg64_t reg_kd = rbx; + reg64_t reg_kh = rdx; + reg64_t reg_kw = rsi; + + reg64_t kh_iter = rax; + reg64_t kw_iter = abi_not_param1; + + reg64_t reg_bias = r13; + reg64_t reg_long_offt = r15; + reg32_t reg_ci_flag = r15d; + + reg64_t reg_d_weights = r15; + reg64_t reg_d_bias = kh_iter; + + reg64_t reg_ow = rbp; + + reg64_t reg_oh_blocks = aux_reg_kernel_w; + + reg64_t reg_wj = aux_reg_input_w; + + Vmm vmm_ker = Vmm(15); + Vmm vmm_tmp = Vmm(15); + Vmm vmm_src = Vmm(14); + Xbyak::Xmm xmm_ker = Xbyak::Xmm(15); + Xbyak::Xmm xmm_tmp = Xbyak::Xmm(15); + Xbyak::Xmm xmm_src = Xbyak::Xmm(14); + + nstl::vector*> eltwise_injectors; + nstl::vector*> depthwise_injectors; + + inline void load_src(int ur_h, int ur_w); + inline void filter(int ur_h); + inline void filter_unrolled(int ur_h, int ur_w); + inline void apply_filter(int ur_h, int ur_w); + inline void apply_postprocess(int ur_h, int ur_w); + inline void store_dst(int ur_h, int ur_w); + inline void solve_common(int ur_h); + + inline void filter_scalar(int ur_h); + inline void load_src_scalar(int ur_h); + inline void apply_filter_scalar(int ur_h); + inline void apply_postprocess_scalar(int ur_h); + inline void store_dst_scalar(int ur_h); + + void generate() override; +}; + +} +} +} +} + +#endif diff --git a/src/cpu/x64/jit_uni_planar_convolution.cpp b/src/cpu/x64/jit_uni_planar_convolution.cpp new file mode 100644 index 00000000000..6146a6881b2 --- /dev/null +++ b/src/cpu/x64/jit_uni_planar_convolution.cpp @@ -0,0 +1,171 @@ 
+/******************************************************************************* +* Copyright 2019-2021 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_uni_planar_convolution.hpp" + +#include "common/c_types_map.hpp" +#include "common/dnnl_thread.hpp" +#include "common/memory_tracking.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace x64 { + +using namespace dnnl::impl::status; +using namespace dnnl::impl::utils; + +#define src_blk_off(f, n, c, d, h, w) \ + pd()->ndims() == 5 \ + ? (f).blk_off(n, c, d, h, w) \ + : (f).blk_off(n, c, h, w) + +#define wht_blk_off(f, g, oc, ic, kd, kh, kw) \ + pd()->ndims() == 5 \ + ? pd()->with_groups() \ + ? (f).blk_off(g, oc, ic, kd, kh, kw) \ + : (f).blk_off(oc, ic, kd, kh, kw) \ + : pd()->with_groups() \ + ? (f).blk_off(g, oc, ic, kh, kw) \ + : (f).blk_off(oc, ic, kh, kw) + +template +void _jit_uni_planar_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC); + auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const data_t *, DNNL_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST); + + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const auto &jcp = pd()->jcp_; + + std::vector od_indexes(jcp.od); + + int idx = 0; + for (int i = 0; i < (jcp.dilate_d + 1); i++) { + for (int ib = 0; ib < jcp.od; ib += (jcp.dilate_d + 1)) { + if (ib + i >= jcp.od) + continue; + + od_indexes[idx++] = ib + i; + if (idx >= jcp.od) + break; + } + if (idx >= jcp.od) + break; + } + + int threads_count = dnnl_get_max_threads(); + int odb_size = div_up(jcp.od, threads_count); + + auto kernel_params = [&](int n, int g, int icb, int oc, int od, int oh, int oh_blocks, int id, int wd, int kd_padding) { + auto par_conv = jit_conv_call_s(); + + const int hj = oh * jcp.stride_h; + const int i_t_overflow = nstl::max(0, jcp.t_pad - hj); + const int i_b_overflow = nstl::max(jcp.ih, hj + (jcp.kh - 1) * (jcp.dilate_h + 1) - jcp.t_pad + 1) - jcp.ih; + const int ih = nstl::max(hj - jcp.t_pad + div_up(i_t_overflow, (jcp.dilate_h + 1)) * (jcp.dilate_h + 1), 0); + const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1)); + const int kh_padding = jcp.kh - div_up(i_t_overflow, (jcp.dilate_h + 1)) - div_up(i_b_overflow, (jcp.dilate_h + 1)); + + const size_t _oc = oc; + const size_t _ic = g * jcp.nb_ic + icb; + + par_conv.src = &src[src_blk_off(src_d, n, _ic, id, ih, 0)]; + par_conv.dst = &dst[src_blk_off(dst_d, n, _oc, od, oh, 0)]; + par_conv.filt = &weights[wht_blk_off(weights_d, g, _oc, _ic, wd, wh, 0)]; + + if (icb == 0) { + if (bias) + par_conv.bias = &bias[bias_d.blk_off(_oc)]; + par_conv.flags |= FLAG_IC_FIRST; + } + + if (icb + 1 == jcp.nb_ic) 
{ + par_conv.flags |= FLAG_IC_LAST; + } + + par_conv.oc_off = _oc * sizeof(float); + par_conv.oh_blocks = (size_t)oh_blocks; + + par_conv.kh_padding = (size_t)nstl::max(0, kh_padding); + par_conv.kd_padding = (size_t)nstl::max(0, kd_padding); + + return par_conv; + }; + + auto ker = [&](const int ithr, const int nthr) { + int g = 0; + int oc = 0; + + for (int n = 0; n < MB; n++) { + int icbb = 0; + while (icbb < jcp.nb_ic) { + int icb_step = jcp.nb_ic_blocking; + int icb_step_rem = jcp.nb_ic - icbb; + if (icb_step_rem < jcp.nb_ic_blocking_max) + icb_step = icb_step_rem; + + for (int icb = icbb; icb < icbb + icb_step; ++icb) { + for (int ohb = 0; ohb < (jcp.dilate_h + 1); ohb++) { + for (int oh = ohb; oh < jcp.oh; oh += (jcp.dilate_h + 1)) { + int od_idx_off = ithr * odb_size; + for (int od_idx = 0; od_idx < odb_size; od_idx++) { + if ((od_idx_off + od_idx) >= jcp.od || od_indexes[od_idx_off + od_idx] >= jcp.od) + continue; + int od = od_indexes[od_idx_off + od_idx]; + + const int dj = od * jcp.stride_d; + const int d_t_overflow = nstl::max(0, jcp.f_pad - dj); + const int d_b_overflow = + nstl::max(jcp.id, dj + (jcp.kd - 1) * (jcp.dilate_d + 1) - jcp.f_pad + 1) - + jcp.id; + const int id = nstl::max(dj - jcp.f_pad + + div_up(d_t_overflow, (jcp.dilate_d + 1)) * (jcp.dilate_d + 1), + 0); + const int wd = div_up(d_t_overflow, (jcp.dilate_d + 1)); + const int kd_padding = jcp.kd - div_up(d_t_overflow, (jcp.dilate_d + 1)) - + div_up(d_b_overflow, (jcp.dilate_d + 1)); + + jit_conv_call_s par_conv = kernel_params(n, g, icb, oc, od, oh, 1, id, wd, kd_padding); + + (*kernel_)(&par_conv); + } + } + } + } + icbb += icb_step; + } + } + }; + + parallel(0, ker); +} + + +template struct _jit_uni_planar_convolution_fwd_t; +template struct _jit_uni_planar_convolution_fwd_t; + +} +} +} +} diff --git a/src/cpu/x64/jit_uni_planar_convolution.hpp b/src/cpu/x64/jit_uni_planar_convolution.hpp new file mode 100644 index 00000000000..3d1e60836e3 --- /dev/null +++ b/src/cpu/x64/jit_uni_planar_convolution.hpp @@ -0,0 +1,96 @@ +/******************************************************************************* +* Copyright 2019-2021 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_X64_JIT_UNI_PLANAR_CONVOLUTION_HPP +#define CPU_X64_JIT_UNI_PLANAR_CONVOLUTION_HPP + +#include "jit_primitive_conf.hpp" +#include "jit_uni_planar_conv_kernel_f32.hpp" + +#include "common/c_types_map.hpp" +#include "common/memory_tracking.hpp" +#include "common/primitive.hpp" + +#include "cpu/cpu_convolution_pd.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace x64 { + +template +struct _jit_uni_planar_convolution_fwd_t: public primitive_t { + struct pd_t : public cpu_convolution_fwd_pd_t { + pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd), jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_planar:", isa, ""), + _jit_uni_planar_convolution_fwd_t); + + status_t init(engine_t *engine) { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && !this->has_zero_dim_memory() + && utils::everyone_is(data_type::f32, + this->desc()->src_desc.data_type, + this->desc()->weights_desc.data_type, + this->desc()->dst_desc.data_type) + && IMPLICATION(this->with_bias(), data_type::f32 == this->desc()->bias_desc.data_type) + && attr()->has_default_values(primitive_attr_t::skip_mask_t::post_ops); + if (!ok) return status::unimplemented; + + status_t sts = jit_uni_planar_conv_fwd_kernel_f32::init_conf(jcp_, *desc(), src_md_, weights_md_, dst_md_, bias_md_, *attr()); + + return sts; + } + + jit_conv_conf_t jcp_; + }; + + _jit_uni_planar_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {} + + typedef typename prec_traits::type data_t; + + status_t init(engine_t *engine) override { + CHECK(safe_ptr_assign(kernel_, new jit_uni_planar_conv_fwd_kernel_f32(pd()->jcp_, *pd()->attr()))); + return kernel_->create_kernel(); + } + + status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } + + std::unique_ptr> kernel_; +}; + +using jit_avx512_common_planar_convolution_fwd_t = _jit_uni_planar_convolution_fwd_t; +using jit_avx2_planar_convolution_fwd_t = _jit_uni_planar_convolution_fwd_t; + +} +} +} +} + +#endif diff --git a/src/cpu/x64/jit_uni_pooling.cpp b/src/cpu/x64/jit_uni_pooling.cpp index 0913ac36e8d..f55946693b3 100644 --- a/src/cpu/x64/jit_uni_pooling.cpp +++ b/src/cpu/x64/jit_uni_pooling.cpp @@ -546,6 +546,7 @@ jit_uni_pooling_fwd_t::~jit_uni_pooling_fwd_t() = default; template void jit_uni_pooling_fwd_t::execute_forward(const data_t *src, data_t *dst, char *indices, const exec_ctx_t &ctx) const { + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); const memory_desc_wrapper src_d = pd()->src_md(); const memory_desc_wrapper dst_d = pd()->dst_md(); @@ -620,7 +621,7 @@ void jit_uni_pooling_fwd_t::execute_forward(const data_t *src, if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) { const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc); - parallel_nd(jpp.mb, jpp.oh, nb2_c, [&](dim_t n, dim_t oh, dim_t b2_c) { + parallel_nd(MB, jpp.oh, nb2_c, [&](dim_t n, dim_t oh, dim_t b2_c) { const auto b_c = b2_c * jpp.ur_bc; const auto ur_bc = nstl::min(dim_t(jpp.ur_bc), jpp.nb_c - b_c); ker(0, n, b_c, oh, ur_bc); @@ -628,7 +629,7 @@ void jit_uni_pooling_fwd_t::execute_forward(const data_t *src, } else { if (trans_src || trans_dst) { // ncsp format - 
parallel_nd_ext(nthr, jpp.mb, jpp.nb_c, + parallel_nd_ext(nthr, MB, jpp.nb_c, [&](dim_t ithr, dim_t nthr, dim_t n, dim_t b_c) { if (trans_src) transpose_facade.execute_transpose_input( @@ -642,7 +643,7 @@ void jit_uni_pooling_fwd_t::execute_forward(const data_t *src, } else { // nChw16c, nChw8c format parallel(nthr, [&](dim_t ithr, dim_t nthr) { - dim_t work_amount = jpp.mb * jpp.nb_c * jpp.oh; + dim_t work_amount = MB * jpp.nb_c * jpp.oh; if (ithr >= work_amount) return; dim_t start {0}, end {0}; @@ -650,12 +651,12 @@ void jit_uni_pooling_fwd_t::execute_forward(const data_t *src, balance211(work_amount, nthr, ithr, start, end); utils::nd_iterator_init( - start, n, jpp.mb, b_c, jpp.nb_c, oh, jpp.oh); + start, n, MB, b_c, jpp.nb_c, oh, jpp.oh); for (dim_t iwork = start; iwork < end; ++iwork) { ker(ithr, n, b_c, oh, 1); utils::nd_iterator_step( - n, jpp.mb, b_c, jpp.nb_c, oh, jpp.oh); + n, MB, b_c, jpp.nb_c, oh, jpp.oh); } }); } @@ -667,6 +668,9 @@ void jit_uni_pooling_fwd_t::execute_forward_3d(const data_t *src, data_t *dst, char *indices, const exec_ctx_t &ctx) const { const auto &jpp = pd()->jpp_; + + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + const memory_desc_wrapper src_d(pd()->src_md()); const memory_desc_wrapper dst_d(pd()->dst_md()); const memory_desc_wrapper indices_d(pd()->workspace_md()); @@ -748,7 +752,7 @@ void jit_uni_pooling_fwd_t::execute_forward_3d(const data_t *src, if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) { const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc); - parallel_nd(jpp.mb, jpp.od, nb2_c, [&](dim_t n, dim_t od, dim_t b2_c) { + parallel_nd(MB, jpp.od, nb2_c, [&](dim_t n, dim_t od, dim_t b2_c) { const dim_t b_c = b2_c * jpp.ur_bc; const dim_t ur_bc = nstl::min(dim_t(jpp.ur_bc), jpp.nb_c - b_c); @@ -765,7 +769,7 @@ void jit_uni_pooling_fwd_t::execute_forward_3d(const data_t *src, }); } else { if (trans_src || trans_dst) { - parallel_nd_ext(nthr, jpp.mb, jpp.nb_c, + parallel_nd_ext(nthr, MB, jpp.nb_c, [&](dim_t ithr, dim_t nthr, dim_t n, dim_t b_c) { if (trans_src) transpose_facade.execute_transpose_input( @@ -790,7 +794,7 @@ void jit_uni_pooling_fwd_t::execute_forward_3d(const data_t *src, ithr, n, b_c); }); } else { - parallel_nd(jpp.mb, jpp.nb_c, jpp.od, + parallel_nd(MB, jpp.nb_c, jpp.od, [&](dim_t n, dim_t b_c, dim_t od) { const int ik = od * jpp.stride_d; const int d_t_overflow = nstl::max(0, jpp.f_pad - ik); diff --git a/src/cpu/x64/jit_uni_reorder.cpp b/src/cpu/x64/jit_uni_reorder.cpp index 5a0102d2caf..ae34645a9ba 100644 --- a/src/cpu/x64/jit_uni_reorder.cpp +++ b/src/cpu/x64/jit_uni_reorder.cpp @@ -2450,11 +2450,28 @@ status_t jit_blk_reorder_t::pd_t::create(reorder_pd_t **reorder_pd, // TODO: Add tail processing support in blk_reorder if (prb.is_tail_present) return status::unimplemented; + // Detect whether input and output are both batch-strided (used for the fallback below) + bool batch_strided_input = false; + bool batch_strided_output = false; + if (prb.ndims > 1) { + int batch_idx = prb.nodes[0].is > prb.nodes[1].is ? 0 : 1; + int channel_idx = batch_idx == 0 ? 1 : 0; + batch_strided_input = + (ptrdiff_t) prb.nodes[channel_idx].n * prb.nodes[channel_idx].is < prb.nodes[batch_idx].is; + batch_idx = prb.nodes[0].os > prb.nodes[1].os ? 0 : 1; + channel_idx = batch_idx == 0 ? 1 : 0; + batch_strided_output = + (ptrdiff_t) prb.nodes[channel_idx].n * prb.nodes[channel_idx].os < prb.nodes[batch_idx].os; + } + prb_tile_normalize(prb); DEBUG({ printf("tile : "); prb_dump(prb); }); + // NB!
Fall back to ref, if input and output both batch-strided + if (batch_strided_input && batch_strided_output) + return status::unimplemented; if (!tr::jit_single_blk_kernel_t::applicable(prb)) { return status::unimplemented; diff --git a/src/cpu/x64/jit_uni_softmax.cpp b/src/cpu/x64/jit_uni_softmax.cpp index 45429a30078..ebdd5d00328 100644 --- a/src/cpu/x64/jit_uni_softmax.cpp +++ b/src/cpu/x64/jit_uni_softmax.cpp @@ -28,7 +28,7 @@ #include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp" #include "cpu/x64/jit_uni_softmax.hpp" -#if __INTEL_COMPILER && __INTEL_COMPILER < 1900 +#if defined(__INTEL_COMPILER) && __INTEL_COMPILER < 1900 // Intel Compilers 17.x and 18.x do not like that diff_src_ptr() is only used // in a single descendant class and marks it as unused. This breaks builds // with DNNL_WERROR=on. Disabling the warning for this file seems to be less @@ -723,7 +723,8 @@ status_t jit_uni_softmax_fwd_t::execute(const exec_ctx_t &ctx) const { auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC); auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST); - const memory_desc_wrapper data_d(pd()->src_md()); + auto real_src_md = ctx.input(DNNL_ARG_SRC)->md(); + const memory_desc_wrapper data_d(real_src_md); const auto data_type_size = data_d.data_type() == data_type::bf16 ? sizeof(bfloat16_t) : sizeof(float); diff --git a/src/cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.cpp b/src/cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.cpp index 5fa3043c353..437f9955414 100644 --- a/src/cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.cpp +++ b/src/cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.cpp @@ -48,7 +48,7 @@ _jit_uni_x8s8s32x_1x1_conv_kernel::_jit_uni_x8s8s32x_1x1_conv_kernel( const jit_1x1_conv_conf_t &ajcp, const primitive_attr_t &attr, const memory_desc_t &dst_md) : jit_generator(nullptr, MAX_CODE_SIZE, true, isa), jcp(ajcp), attr_(attr) { - if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum || jcp.with_depthwise || jcp.with_quantization) { using namespace binary_injector; static constexpr bool preserve_gpr = true; static constexpr bool preserve_vmm = true; @@ -57,10 +57,12 @@ _jit_uni_x8s8s32x_1x1_conv_kernel::_jit_uni_x8s8s32x_1x1_conv_kernel( GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), memory_desc_wrapper(dst_md)}; static_params_t static_params {this->param1, rhs_arg_static_params}; + quantization_injector::static_params_t quantization_static_params + {vmm_d_weights.getIdx(), vmm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; postops_injector_ = utils::make_unique>( - this, jcp.post_ops, static_params); + this, jcp.post_ops, static_params, quantization_static_params); } } @@ -186,7 +188,17 @@ void _jit_uni_x8s8s32x_1x1_conv_kernel::apply_postops(const int ur, const int load_loop_blk, const bool mask_flag_in, const float *p_sum_scale, const int32_t *p_sum_zp) { - if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum || jcp.with_depthwise || jcp.with_quantization) { + std::map vmm_idx_off; + iterate(ur, load_loop_blk, [&](const int i_ur, const int i_load) { + vmm_idx_off.insert({vreg_accum_idx(load_loop_blk, i_load, i_ur), i_load * jcp.load_block * sizeof(float)}); + }); + depthwise_injector::dynamic_params_t ddp {vmm_d_weights.getIdx(), vmm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + reg_oc_off, vmm_idx_off, + this->rsp, base_post_ops_data_offset}; + quantization_injector::dynamic_params_t qdp {reg_oc_off, vmm_idx_off, jcp.dst_dt, + this->rsp, base_post_ops_data_offset}; + if 
(jcp.with_sum && *p_sum_zp != 0) mov(ptr[rsp + reg_bcast_loop_iter_off], reg_ptr_sum_zp); apply_sum(ur, load_loop_blk, mask_flag_in, p_sum_scale, p_sum_zp); @@ -210,12 +222,13 @@ void _jit_uni_x8s8s32x_1x1_conv_kernel::apply_postops(const int ur, vmm_idx, aux_output_offset); }); - postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params); + + postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params, ddp, qdp); } else { iterate(ur, load_loop_blk, [&](const int i_ur, const int i_load) { vmm_idxs.emplace(vreg_accum_idx(load_loop_blk, i_load, i_ur)); }); - postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params); + postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params, ddp, qdp); } if (jcp.with_sum && *p_sum_zp != 0) mov(reg_ptr_sum_zp, ptr[rsp + reg_bcast_loop_iter_off]); @@ -317,14 +330,14 @@ void _jit_uni_x8s8s32x_1x1_conv_kernel::reduce_loop( const auto ptr_scales_offset = jcp.is_oc_scale * (sizeof(float) * jcp.oc_block * i_load); if (jcp.with_bias) { - if (jcp.signed_input) + if (jcp.signed_input || jcp.with_input_zp) mov(reg_bias_data, ptr[rsp + reg_bias_data_off]); cvt2ps(jcp.bia_dt, vmm_bias, reg_bias_data, jcp.typesize_bia * jcp.oc_block * i_load, load_size); if (jcp.signed_input && jcp.ver != ver_vnni) uni_vmulps(vmm_bias, vmm_bias, vmm_bias_alpha()); } - if (jcp.signed_input) { + if (jcp.signed_input || jcp.with_input_zp) { mov(reg_comp_data, ptr[rsp + reg_comp_data_off]); cvt2ps(data_type::s32, vmm_comp, reg_comp_data, sizeof(int32_t) * jcp.oc_block * i_load, load_size); @@ -350,7 +363,7 @@ void _jit_uni_x8s8s32x_1x1_conv_kernel::reduce_loop( for (int i_ur = 0; i_ur < ur; ++i_ur) { const auto r = vreg_accum(load_loop_blk, i_load, i_ur); uni_vcvtdq2ps(r, r); - if (jcp.signed_input) uni_vaddps(r, r, vmm_comp); + if (jcp.signed_input || jcp.with_input_zp) uni_vaddps(r, r, vmm_comp); if (jcp.src_zero_point) uni_vaddps(r, r, vmm_zp_comp); if (jcp.with_bias) uni_vaddps(r, r, vmm_bias); @@ -446,6 +459,8 @@ void _jit_uni_x8s8s32x_1x1_conv_kernel::reduce_loop( Label reduce_loop; Label reduce_loop_tail; + push(reg_oc_off); + mov(aux_reg_load_data, reg_load_data); mov(aux_reg_bcast_data, aux1_reg_bcast_data); @@ -467,6 +482,8 @@ void _jit_uni_x8s8s32x_1x1_conv_kernel::reduce_loop( L(reduce_loop_tail); fma_block(jcp.ic != jcp.ic_without_padding); + pop(reg_oc_off); + if (jcp.oc_without_padding != jcp.oc) { Label end_store, common_store; mov(ptr[rsp + reg_bcast_data_off], reg_bcast_data); @@ -498,7 +515,12 @@ template void _jit_uni_x8s8s32x_1x1_conv_kernel::generate() { preamble(); + if (postops_injector_) + postops_injector_->push_post_ops_data_on_stack(this->param1, GET_OFF(post_ops_binary_rhs_arg_vec), reg_load_data, reg_output_data); + sub(rsp, stack_space_needed); + base_post_ops_data_offset += stack_space_needed; + if (jcp.with_binary) { // zero initialize binary post_ops offset accumulator (store on stack) const auto binary_post_op_acc_off_reg = r15; @@ -507,7 +529,7 @@ void _jit_uni_x8s8s32x_1x1_conv_kernel::generate() { } if (jcp.with_bias) mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]); - if (jcp.signed_input) { + if (jcp.signed_input || jcp.with_input_zp) { mov(ptr[rsp + reg_bias_data_off], reg_bias_data); mov(reg_comp_data, ptr[param1 + GET_OFF(compensation)]); mov(ptr[rsp + reg_comp_data_off], reg_comp_data); @@ -533,16 +555,17 @@ void _jit_uni_x8s8s32x_1x1_conv_kernel::generate() { mov(ptr[rsp + bcast_loop_work_off], reg_bcast_loop_work); mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]); mov(reg_reduce_pos_flag, ptr[param1 + 
GET_OFF(first_last_flag)]); + mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]); auto load_loop_body = [&](int load_loop_blk) { bcast_loop(load_loop_blk); add(reg_load_data, load_loop_blk * jcp.load_loop_load_step); if (jcp.with_bias) { - if (jcp.signed_input) + if (jcp.signed_input || jcp.with_input_zp) mov(reg_bias_data, ptr[rsp + reg_bias_data_off]); add(reg_bias_data, load_loop_blk * jcp.load_block * jcp.typesize_bia); - if (jcp.signed_input) + if (jcp.signed_input || jcp.with_input_zp) mov(ptr[rsp + reg_bias_data_off], reg_bias_data); } if (jcp.with_binary) { @@ -552,7 +575,7 @@ void _jit_uni_x8s8s32x_1x1_conv_kernel::generate() { mov(EVEX_compress_addr(rsp, reg_binary_post_op_acc_off), aux_reg_load_data); } - if (jcp.signed_input) { + if (jcp.signed_input || jcp.with_input_zp) { mov(reg_comp_data, ptr[rsp + reg_comp_data_off]); add(reg_comp_data, load_loop_blk * jcp.load_block * sizeof(int32_t)); @@ -573,6 +596,7 @@ void _jit_uni_x8s8s32x_1x1_conv_kernel::generate() { mov(reg_bcast_data, ptr[rsp + reg_bcast_data_off]); add(reg_output_data, load_loop_blk * jcp.load_block * jcp.typesize_out); sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); + add(reg_oc_off, load_loop_blk * jcp.oc_block * sizeof(float)); }; static const int ur_cases[] = {2, 3, 5, 12}; @@ -616,7 +640,13 @@ void _jit_uni_x8s8s32x_1x1_conv_kernel::generate() { } } L(load_loop_blk[num_ur_cases]); + + base_post_ops_data_offset -= stack_space_needed; add(rsp, stack_space_needed); + + if (postops_injector_) + postops_injector_->reset_stack_pointer(); + postamble(); if (jcp.with_eltwise) postops_injector_->prepare_table(); @@ -664,6 +694,20 @@ status_t jit_uni_x8s8s32x_1x1_conv_kernel::init_conf( jcp.signed_input = (src_d.data_type() == data_type::s8); + jcp.with_input_zp = !attr.input_zero_points_.has_default_values(); + jcp.with_weights_zp = !attr.weights_zero_points_.has_default_values(); + + if (jcp.with_input_zp) { + if (attr.input_zero_points_.count_ != 1 && attr.input_zero_points_.count_ != jcp.ic * jcp.ngroups) + return status::unimplemented; + + if (attr.output_compensations_.count_ != jcp.oc * jcp.ngroups) + return status::unimplemented; + } + + if (jcp.with_weights_zp) + return status::unimplemented; + jcp.os = jcp.od * jcp.oh * jcp.ow; jcp.is = jcp.id * jcp.ih * jcp.iw; @@ -683,6 +727,9 @@ status_t jit_uni_x8s8s32x_1x1_conv_kernel::init_conf( const int sum_ind = post_ops.find(primitive_kind::sum, 0, dw_conv_ind); jcp.with_sum = sum_ind != -1; + jcp.with_depthwise = post_ops.find(primitive_kind::depthwise, 0, dw_conv_ind) != -1; + jcp.with_quantization = post_ops.find(primitive_kind::quantization, 0, dw_conv_ind) != -1; + const auto zp = attr.zero_points_; jcp.dst_zero_point = !zp.has_default_values(DNNL_ARG_DST); jcp.src_zero_point = !zp.has_default_values(DNNL_ARG_SRC); @@ -721,7 +768,7 @@ status_t jit_uni_x8s8s32x_1x1_conv_kernel::init_conf( } using namespace injector; - const bool post_ops_ok_ = post_ops_ok({isa, {eltwise, binary, sum}, + const bool post_ops_ok_ = post_ops_ok({isa, {eltwise, binary, sum, depthwise, quantization}, jcp.post_ops, &dst_d, false, false, false}); if (!post_ops_ok_) return status::unimplemented; @@ -734,7 +781,8 @@ status_t jit_uni_x8s8s32x_1x1_conv_kernel::init_conf( jcp.bia_dt = jcp.with_bias ? 
cd.bias_desc.data_type : data_type::undef; jcp.dst_dt = cd.dst_desc.data_type; - jcp.sum_dt = post_ops.get_sum_dt(jcp.dst_dt); + if (jcp.with_sum) + jcp.sum_dt = post_ops.get_sum_dt(jcp.dst_dt); jcp.ic_block = jcp.oc_block = simd_w; diff --git a/src/cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.hpp b/src/cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.hpp index 67a72c8fded..74c3705cda6 100644 --- a/src/cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.hpp +++ b/src/cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.hpp @@ -34,7 +34,6 @@ struct _jit_uni_x8s8s32x_1x1_conv_kernel : public jit_generator { DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_uni_x8s8s32x_1x1_conv_kernel) _jit_uni_x8s8s32x_1x1_conv_kernel(const jit_1x1_conv_conf_t &ajcp, const primitive_attr_t &attr, const memory_desc_t &dst_md); - int get_tail_size() { return jcp.oc_without_padding % jcp.oc_block; } jit_1x1_conv_conf_t jcp; @@ -61,7 +60,7 @@ struct _jit_uni_x8s8s32x_1x1_conv_kernel : public jit_generator { const Xbyak::Reg64 reg_reduce_loop_iter = r13; const Xbyak::Reg64 aux_reg_bcast_data = r14; const Xbyak::Reg64 aux_reg_load_data = r15; - const Xbyak::Reg64 aux_reg_saturation = r15; + const Xbyak::Reg64 aux_reg_saturation = r14; const Xbyak::Reg64 reg_reduce_pos_flag = rax; const Xbyak::Reg64 aux1_reg_bcast_data = rbx; const Xbyak::Reg64 reg_bcast_loop_work = rbx; @@ -73,6 +72,14 @@ struct _jit_uni_x8s8s32x_1x1_conv_kernel : public jit_generator { const Xbyak::Reg64 reg_src_zero_point = aux_reg_bcast_data; // r14 const Xbyak::Reg64 reg_dst_zero_point = reg_src_zero_point; + const Xbyak::Reg64 reg_d_weights = aux_reg_bcast_data; + const Xbyak::Reg64 reg_d_bias = abi_param1; + const Xbyak::Reg64 reg_oc_off = aux_reg_load_data; + int base_post_ops_data_offset = 0; + + Vmm vmm_d_weights = Vmm(0); + Vmm vmm_d_bias = Vmm(1); + const Vmm vmm_tmp = Vmm(3); const Vmm vmm_one = Vmm(2); const Vmm vmm_zero = Vmm(1); diff --git a/src/cpu/x64/jit_uni_x8s8s32x_1x1_convolution.cpp b/src/cpu/x64/jit_uni_x8s8s32x_1x1_convolution.cpp index 2bfedc09921..472086b1d9e 100644 --- a/src/cpu/x64/jit_uni_x8s8s32x_1x1_convolution.cpp +++ b/src/cpu/x64/jit_uni_x8s8s32x_1x1_convolution.cpp @@ -60,6 +60,10 @@ status_t jit_uni_x8s8s32x_1x1_convolution_fwd_t::execute_forward( DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC); DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST); + DEFINE_OUTPUT_COMPENSATION_BUFFER(output_compensation, pd()->jcp_); + + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + auto scratchpad = ctx.get_scratchpad_grantor(); if (pd()->jcp_.signed_input && pd()->jcp_.ver != ver_vnni) { @@ -101,7 +105,7 @@ status_t jit_uni_x8s8s32x_1x1_convolution_fwd_t::execute_forward( execute_forward_thr(ithr, nthr, src, weights, bias, weights_dw, bias_dw, dst, src_zero_point, dst_zero_point, scratchpad, post_ops_binary_rhs_arg_vec.data(), - post_ops_binary_rhs_arg_vec_dw.data()); + post_ops_binary_rhs_arg_vec_dw.data(), MB, output_compensation); }); return status::success; } @@ -113,7 +117,8 @@ void jit_uni_x8s8s32x_1x1_convolution_fwd_t::execute_forward_thr( char *dst, const int32_t *src_zero_point, const int32_t *dst_zero_point, const memory_tracking::grantor_t &scratchpad, const void *post_ops_binary_rhs_arg_vec, - const void *post_ops_binary_rhs_arg_vec_dw) const { + const void *post_ops_binary_rhs_arg_vec_dw, int MB, + const int32_t *output_compensation) const { const memory_desc_wrapper src_d(pd()->src_md()); const memory_desc_wrapper dst_d(pd()->dst_md()); const memory_desc_wrapper weights_d(pd()->weights_md(0)); @@ -134,7 +139,7 @@ void 
jit_uni_x8s8s32x_1x1_convolution_fwd_t::execute_forward_thr( auto local_scales = scratchpad.get(key_conv_adjusted_scales); - const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + const int work_amount = MB * jcp.ngroups * jcp.nb_bcast; const int ndims = dst_d.ndims(); const int stride_d = (ndims == 5) ? pd()->desc()->strides[0] : 1; @@ -149,9 +154,8 @@ void jit_uni_x8s8s32x_1x1_convolution_fwd_t::execute_forward_thr( auto offset = weights_d.size() - weights_d.additional_buffer_size(); char *w = const_cast(weights); - const int32_t *compensation = (jcp.signed_input) - ? reinterpret_cast(w + offset) - : nullptr; + const int32_t *compensation = (jcp.signed_input) ? reinterpret_cast(w + offset) : + (jcp.with_input_zp) ? output_compensation : nullptr; const int32_t *zp_compensation = jcp.src_zero_point ? reinterpret_cast(&w[offset]) + (jcp.signed_input ? jcp.ngroups * jcp.oc : 0) @@ -211,7 +215,7 @@ void jit_uni_x8s8s32x_1x1_convolution_fwd_t::execute_forward_thr( int &bcast_step, int &od, int &oh, int &ow, int &id, int &ih, int &iw) { int osb {0}; - nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, nb_bcast); + nd_iterator_init(iwork, n, MB, g, jcp.ngroups, osb, nb_bcast); bcast_step = step( nb_bcast_blocking, nb_bcast - osb, nb_bcast_blocking_max); bcast_step = nstl::min(bcast_step, bcast_end - iwork); @@ -266,8 +270,7 @@ void jit_uni_x8s8s32x_1x1_convolution_fwd_t::execute_forward_thr( : weights_d.blk_off(ocb, icb); p.load_data = weights + wei_offset; p.bias_data = &bias[_ocb * jcp.oc_block * bia_dt_size]; - p.compensation = (jcp.signed_input) ? &compensation[_ocb * jcp.oc_block] - : nullptr; + p.compensation = (jcp.signed_input || jcp.with_input_zp) ? &compensation[_ocb * jcp.oc_block] : nullptr; p.zp_compensation = jcp.src_zero_point ? zp_compensation + _ocb * jcp.oc_block : nullptr; @@ -292,6 +295,7 @@ void jit_uni_x8s8s32x_1x1_convolution_fwd_t::execute_forward_thr( p.oc_l_off = g * nb_oc + ocb * jcp.oc_block; p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec; p.dst_orig = jcp.with_dw_conv ? 
pbuf : dst; + p.oc_off = _ocb * jcp.oc_block * sizeof(float); (*kernel_)(&p); }; @@ -421,6 +425,7 @@ void jit_uni_x8s8s32x_1x1_convolution_fwd_t::execute_forward_thr( par_conv_dw.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec_dw; par_conv_dw.dst_orig = dst; + par_conv_dw.oc_off = ocb * jcp_dw->ch_block * sizeof(float); (*kernel_dw_)(&par_conv_dw); @@ -440,7 +445,7 @@ void jit_uni_x8s8s32x_1x1_convolution_fwd_t::execute_forward_thr( addrs.resize(jcp_dw->kh); int bcast_start {0}, bcast_end {0}, ocb_start, ocb_end; - balance2D(nthr, ithr, jcp.mb * jcp.ngroups * jcp_dw->oh, bcast_start, + balance2D(nthr, ithr, MB * jcp.ngroups * jcp_dw->oh, bcast_start, bcast_end, nb_oc, ocb_start, ocb_end, jcp.load_grp_count); while (ocb_start < ocb_end) { @@ -451,7 +456,7 @@ void jit_uni_x8s8s32x_1x1_convolution_fwd_t::execute_forward_thr( auto bcast_iter = bcast_start; while (bcast_iter < bcast_end) { int n {0}, g {0}, oh_dw {0}; - nd_iterator_init(bcast_iter, n, jcp.mb, g, jcp.ngroups, oh_dw, + nd_iterator_init(bcast_iter, n, MB, g, jcp.ngroups, oh_dw, jcp_dw->oh); if (oh_dw == 0) oh_1x1 = 0; // Reset over mb boundary const int oh_1x1_range diff --git a/src/cpu/x64/jit_uni_x8s8s32x_1x1_convolution.hpp b/src/cpu/x64/jit_uni_x8s8s32x_1x1_convolution.hpp index c372a850cf7..3d728efb4f8 100644 --- a/src/cpu/x64/jit_uni_x8s8s32x_1x1_convolution.hpp +++ b/src/cpu/x64/jit_uni_x8s8s32x_1x1_convolution.hpp @@ -72,7 +72,10 @@ struct jit_uni_x8s8s32x_1x1_convolution_fwd_t : public primitive_t { && desc()->accum_data_type == s32 && attr()->has_default_values(smask_t::oscale | smask_t::zero_points_runtime - | smask_t::post_ops | smask_t::sum_dt, + | smask_t::post_ops + | smask_t::sum_dt + | smask_t::input_zero_points + | smask_t::output_compensations, dst_md(0)->data_type) && attr()->post_ops_.check_sum_consistent_dt( dst_md(0)->data_type) @@ -199,11 +202,10 @@ struct jit_uni_x8s8s32x_1x1_convolution_fwd_t : public primitive_t { memory_desc_init_by_tag(want_wei_md, wei_tag); if (is_src_s8) { want_wei_md.extra.flags - = 0 | compensation_conv_s8s8 | scale_adjust; + = 0 | compensation_conv_s8s8; want_wei_md.extra.compensation_mask = with_groups() ? g_mask : c_mask; - want_wei_md.extra.scale_adjust - = mayiuse(avx2_vnni) ? 1.0f : 0.5f; + want_wei_md.extra.scale_adjust = 1.0f; } if (is_src_zero_point) { want_wei_md.extra.flags |= compensation_conv_asymmetric_src; @@ -354,7 +356,8 @@ struct jit_uni_x8s8s32x_1x1_convolution_fwd_t : public primitive_t { const int32_t *dst_zero_point, const memory_tracking::grantor_t &scratchpad, const void *post_ops_binary_rhs_arg_vec, - const void *post_ops_binary_rhs_arg_vec_dw) const; + const void *post_ops_binary_rhs_arg_vec_dw, int MB, + const int32_t *output_compensation) const; const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } std::unique_ptr> kernel_; diff --git a/src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.cpp b/src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.cpp index 4d14fcd2e1e..32c9bdbf230 100644 --- a/src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.cpp +++ b/src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.cpp @@ -44,7 +44,7 @@ void pick_loop_order(jit_conv_conf_t &jcp) { jcp.loop_order = loop_cwgn; if (jcp.ngroups > 1) { jcp.loop_order = loop_ngcw; - if (jcp.mb < jcp.nthr) + if (jcp.mb < jcp.nthr && jcp.ndims != 5) jcp.loop_order = jcp.ndims == 3 ?
loop_nwcg : loop_nhwcg; } else if (jcp.mb >= jcp.nthr && jcp.ic_without_padding <= 8) { jcp.loop_order = loop_ngcw; @@ -57,7 +57,7 @@ _jit_uni_x8s8s32x_fwd_kernel::_jit_uni_x8s8s32x_fwd_kernel( const jit_conv_conf_t &ajcp, const primitive_attr_t &attr, const memory_desc_t &dst_md) : jit_generator(nullptr, MAX_CODE_SIZE, true, isa), jcp(ajcp), attr_(attr) { - if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum || jcp.with_depthwise || jcp.with_quantization) { using namespace binary_injector; static constexpr bool preserve_gpr = true; static constexpr bool preserve_vmm = false; @@ -73,10 +73,12 @@ _jit_uni_x8s8s32x_fwd_kernel::_jit_uni_x8s8s32x_fwd_kernel( memory_desc_wrapper(dst_md), tail_size, true}; const static_params_t static_params { this->param1, rhs_arg_static_params}; + quantization_injector::static_params_t quantization_static_params = + {vmm_d_weights.getIdx(), vmm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; postops_injector_ = utils::make_unique>( - this, jcp.post_ops, static_params); + this, jcp.post_ops, static_params, quantization_static_params); } } @@ -173,14 +175,28 @@ template void _jit_uni_x8s8s32x_fwd_kernel::apply_postops( const int nb_oc_block, const int ur_w, const bool last_oc_block_flag, const int oc_block, const float *p_sum_scale, const int32_t *p_sum_zp) { - if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) { - if (jcp.with_sum && *p_sum_zp != 0) push(reg_ptr_sum_zp); + if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum || jcp.with_depthwise || jcp.with_quantization) { + std::map vmm_idx_off; + iterate(nb_oc_block, ur_w, + [&](const bool, const int k, const int j) { + vmm_idx_off.insert({vmm_out_idx(j, k), k * oc_block * sizeof(float)}); + }); + depthwise_injector::dynamic_params_t ddp {vmm_d_weights.getIdx(), vmm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off, + this->rsp, base_post_ops_data_offset}; + quantization_injector::dynamic_params_t qdp {ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off, jcp.dst_dt, + this->rsp, base_post_ops_data_offset}; + + if (jcp.with_sum && *p_sum_zp != 0) { + base_post_ops_data_offset += reg64_size; + push(reg_ptr_sum_zp); + } apply_sum(nb_oc_block, ur_w, last_oc_block_flag, oc_block, p_sum_scale, p_sum_zp); vmm_index_set_t vmm_idxs; + binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; if (jcp.with_binary) { - binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; const bool oc_blk_is_smaller_than_vmm = oc_block < isa_simd_width_; iterate(nb_oc_block, ur_w, last_oc_block_flag, oc_blk_is_smaller_than_vmm, @@ -207,9 +223,12 @@ void _jit_uni_x8s8s32x_fwd_kernel::apply_postops( [&](const bool, const int k, const int j) { vmm_idxs.emplace(vmm_out_idx(j, k)); }); - postops_injector_->compute_vector_range(vmm_idxs); + postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params, ddp, qdp); + } + if (jcp.with_sum && *p_sum_zp != 0) { + base_post_ops_data_offset -= reg64_size; + pop(reg_ptr_sum_zp); } - if (jcp.with_sum && *p_sum_zp != 0) pop(reg_ptr_sum_zp); } } @@ -222,7 +241,7 @@ void _jit_uni_x8s8s32x_fwd_kernel::store_output( mov(reg_bias, ptr[param1 + GET_OFF(bias)]); mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]); - if (jcp.signed_input) + if (jcp.signed_input || jcp.with_input_zp) mov(reg_compensation, ptr[param1 + GET_OFF(compensation)]); if (jcp.src_zero_point) { @@ -258,7 +277,7 @@ void _jit_uni_x8s8s32x_fwd_kernel::store_output( if (jcp.signed_input && jcp.ver != ver_vnni) /* bias *= 
0.5 */ uni_vmulps(vmm_bias, vmm_bias, vmm_bias_alpha()); } - if (jcp.signed_input) { + if (jcp.signed_input || jcp.with_input_zp) { const int comp_offset = sizeof(int32_t) * k * oc_block; load_data(data_type::s32, vmm_comp, reg_compensation, comp_offset, load_size); @@ -284,10 +303,9 @@ void _jit_uni_x8s8s32x_fwd_kernel::store_output( /* add comp in s32 to avoid loss of precision when convert s32 to f32 in integer (2^24) TODO: do the same to bias */ - if (jcp.signed_input) uni_vpaddd(vmm, vmm, vmm_comp); + if (jcp.signed_input || jcp.with_input_zp) uni_vpaddd(vmm, vmm, vmm_comp); if (jcp.src_zero_point) uni_vpaddd(vmm, vmm, vmm_zp_comp); uni_vcvtdq2ps(vmm, vmm); - if (jcp.with_bias) uni_vaddps(vmm, vmm, vmm_bias); uni_vmulps(vmm, vmm, vmm_scale); @@ -371,9 +389,10 @@ void _jit_uni_x8s8s32x_fwd_kernel::compute_ker_dw(int ur_w, int pad_l, && std::is_same::value)) assert(!"invalid group blocking for depthwise convolution"); - const bool compute_kernel = IMPLICATION(h_padded, jcp.signed_input); + const bool compute_kernel = IMPLICATION(h_padded, jcp.signed_input || jcp.with_input_zp); if (jcp.src_zero_point) { + base_post_ops_data_offset += reg64_size; push(aux_reg_ker_d); mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]); uni_vpbroadcastd(vmm_zp, ptr[reg_src_zero_point]); @@ -396,7 +415,7 @@ void _jit_uni_x8s8s32x_fwd_kernel::compute_ker_dw(int ur_w, int pad_l, }; auto kernel_offset = [=](int ci, int ki) { - return jcp.typesize_in * ((ci * jcp.kh * jcp.kw + ki) * jcp.ch_block); + return jcp.typesize_in * ((ci * jcp.kd * jcp.kh * jcp.kw + ki) * jcp.ch_block); }; auto compute = [=](Vmm vreg_acc, Vmm vreg_wei, Vmm vreg_src) { @@ -427,6 +446,10 @@ void _jit_uni_x8s8s32x_fwd_kernel::compute_ker_dw(int ur_w, int pad_l, } for (int ci = 0; ci < jcp.nb_ch_blocking; ++ci) { + if (jcp.with_input_zp && (h_padded || get_ow_start(0, pad_l) != 0 || get_ow_end(ur_w, jcp.kw-1, pad_r) != ur_w)) { + load_data(data_type::u8, vmm_shift, reg_input_zp, ci * jcp.ch_block, get_blocking_size()); + } + const bool mask_flag = last_ic_block_flag != no_last_block && ci == jcp.nb_ch_blocking - 1; if (jcp.is_resrc_depthwise && !h_padded) { @@ -451,12 +474,12 @@ void _jit_uni_x8s8s32x_fwd_kernel::compute_ker_dw(int ur_w, int pad_l, if (compute_kernel) { uni_vpmovsxbd(vmm_wei, ptr[aux_reg_ker + aux_kernel_offset]); if (h_padded) { - assert(jcp.signed_input); + assert(jcp.signed_input || jcp.with_input_zp); for (int oi = 0; oi < ur_w; ++oi) compute(vmm_out(oi, ci), vmm_wei, vmm_shift); } else { - int start = jcp.signed_input ? 0 : oi_start; - int end = jcp.signed_input ? ur_w : oi_end; + int start = (jcp.signed_input || jcp.with_input_zp) ? 0 : oi_start; + int end = (jcp.signed_input || jcp.with_input_zp) ? 
ur_w : oi_end; for (int oi = start; oi < end; ++oi) { if (oi >= oi_start && oi < oi_end) { if (jcp.is_resrc_depthwise) { @@ -475,7 +498,7 @@ void _jit_uni_x8s8s32x_fwd_kernel::compute_ker_dw(int ur_w, int pad_l, } compute(vmm_out(oi, ci), vmm_wei, vmm_dw_src); } else { - assert(jcp.signed_input); + assert(jcp.signed_input || jcp.with_input_zp); compute(vmm_out(oi, ci), vmm_wei, vmm_shift); } } @@ -500,7 +523,10 @@ void _jit_uni_x8s8s32x_fwd_kernel::compute_ker_dw(int ur_w, int pad_l, } } - if (jcp.src_zero_point) pop(aux_reg_ker_d); + if (jcp.src_zero_point) { + base_post_ops_data_offset -= reg64_size; + pop(aux_reg_ker_d); + } } template @@ -517,11 +543,12 @@ void _jit_uni_x8s8s32x_fwd_kernel::compute_ker(int ur_w, int pad_l, int nb_oc_block = jcp.nb_oc_blocking; - const bool compute_kernel = IMPLICATION(h_padded, jcp.signed_input); + const bool compute_kernel = IMPLICATION(h_padded, jcp.signed_input || jcp.with_input_zp); - assert(IMPLICATION(h_padded, jcp.src_zero_point || jcp.signed_input)); + assert(IMPLICATION(h_padded, jcp.src_zero_point || jcp.signed_input || jcp.with_input_zp)); if (jcp.src_zero_point) { + base_post_ops_data_offset += reg64_size; push(aux_reg_ker_d); mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]); } @@ -553,8 +580,8 @@ void _jit_uni_x8s8s32x_fwd_kernel::compute_ker(int ur_w, int pad_l, const int ow_end = get_ow_end(ur_w, ki, pad_r); const int ic_tail_size = jcp.ic_without_padding % ic_sub_step; - const int _start = jcp.signed_input ? 0 : ow_start; - const int _end = jcp.signed_input ? ur_w : ow_end; + const int _start = (jcp.signed_input || jcp.with_input_zp) ? 0 : ow_start; + const int _end = (jcp.signed_input || jcp.with_input_zp) ? ur_w : ow_end; /* Skip the last loads of input if (ic % 8) / ic_sub_step < ic_block / ic_sub_step */ @@ -565,6 +592,9 @@ void _jit_uni_x8s8s32x_fwd_kernel::compute_ker(int ur_w, int pad_l, if (compute_kernel) { for (int ic = 0; ic < icb; ++ic) { if (h_padded) { + if (jcp.with_input_zp) + uni_vpbroadcastd(vmm_shift, ptr[reg_input_zp + ic_sub_step * ic * sizeof(uint8_t)]); + // fill padded area with shifted value in first iteration if (ic == 0) { const Vmm inp = vmm_inp(0, nb_oc_block); @@ -597,7 +627,10 @@ void _jit_uni_x8s8s32x_fwd_kernel::compute_ker(int ur_w, int pad_l, } else { // fill padded area with shifted value in // first iteration - if (jcp.signed_input && ic == 0) { + if ((jcp.signed_input || jcp.with_input_zp) && ic == 0) { + if (jcp.with_input_zp) + uni_vpbroadcastd(vmm_shift, ptr[reg_input_zp + 4 * ic * sizeof(uint8_t)]); + const Vmm inp = vmm_inp(jj, nb_oc_block); uni_vmovups(inp, vmm_shift); } @@ -644,7 +677,10 @@ void _jit_uni_x8s8s32x_fwd_kernel::compute_ker(int ur_w, int pad_l, } } } - if (jcp.src_zero_point) pop(aux_reg_ker_d); + if (jcp.src_zero_point) { + base_post_ops_data_offset -= reg64_size; + pop(aux_reg_ker_d); + } } template @@ -672,7 +708,7 @@ void _jit_uni_x8s8s32x_fwd_kernel::kh_loop( if (jcp.ndims == 5) { mov(aux_reg_ker_d, reg_ker); mov(aux_reg_inp_d, reg_inp); - if (jcp.signed_input || jcp.src_zero_point) { + if (jcp.signed_input || jcp.src_zero_point || jcp.with_input_zp) { //TODO: May be avoided when f_pad=0 and dd0 //TODO: Potential optimization by precomputing, when kd <<< od? 
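+ // With signed input or zero points, the front-padded kd positions still + // contribute a compensation (zero-point shift) term, so the f_overflow + // region below is walked explicitly instead of being skipped.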
mov(reg_ki, ptr[param1 + GET_OFF(f_overflow)]); @@ -697,8 +733,8 @@ void _jit_uni_x8s8s32x_fwd_kernel::kh_loop( } mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]); - if ((jcp.signed_input || jcp.src_zero_point) || (jcp.dilate_d >= jcp.id) - || (!(jcp.signed_input || jcp.src_zero_point) + if ((jcp.signed_input || jcp.src_zero_point || jcp.with_input_zp) || (jcp.dilate_d >= jcp.id) + || (!(jcp.signed_input || jcp.src_zero_point || jcp.with_input_zp) && (jcp.kd - 1) * (jcp.dilate_d + 1) < nstl::max(jcp.f_pad, jcp.back_pad))) { cmp(reg_ki, 0); @@ -716,7 +752,7 @@ void _jit_uni_x8s8s32x_fwd_kernel::kh_loop( mov(aux_reg_ker, reg_ker); } - if ((jcp.signed_input || jcp.src_zero_point) && jcp.ndims > 3) { + if ((jcp.signed_input || jcp.src_zero_point || jcp.with_input_zp) && jcp.ndims > 3) { mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]); cmp(reg_overflow, 0); je(no_t_overflow_label, T_NEAR); @@ -732,8 +768,8 @@ void _jit_uni_x8s8s32x_fwd_kernel::kh_loop( L(no_t_overflow_label); } mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]); - if ((jcp.signed_input || jcp.src_zero_point) || (jcp.dilate_h >= jcp.ih) - || (!(jcp.signed_input || jcp.src_zero_point) + if ((jcp.signed_input || jcp.src_zero_point || jcp.with_input_zp) || (jcp.dilate_h >= jcp.ih) + || (!(jcp.signed_input || jcp.src_zero_point || jcp.with_input_zp) && (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad))) { cmp(reg_kj, 0); @@ -758,7 +794,7 @@ void _jit_uni_x8s8s32x_fwd_kernel::kh_loop( jg(kh_label, T_NEAR); } L(skip_kh_loop); - if ((jcp.signed_input || jcp.src_zero_point) && jcp.ndims > 3) { + if ((jcp.signed_input || jcp.src_zero_point || jcp.with_input_zp) && jcp.ndims > 3) { mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]); cmp(reg_overflow, 0); je(no_b_overflow_label, T_NEAR); @@ -780,7 +816,7 @@ void _jit_uni_x8s8s32x_fwd_kernel::kh_loop( jne(kd_label, T_NEAR); L(skip_kd_loop); - if (jcp.signed_input || jcp.src_zero_point) { + if (jcp.signed_input || jcp.src_zero_point || jcp.with_input_zp) { mov(reg_ki, ptr[param1 + GET_OFF(back_overflow)]); cmp(reg_ki, 0); je(no_back_overflow_label, T_NEAR); @@ -812,6 +848,9 @@ void _jit_uni_x8s8s32x_fwd_kernel::icb_loop( // IC loop Label icb_label; mov(reg_icb, jcp.nb_ic); + if (jcp.with_input_zp) + mov(reg_input_zp, ptr[param1 + GET_OFF(input_zp)]); + L(icb_label); const bool do_icb_loop = jcp.is_depthwise ? 
jcp.nb_ch > jcp.nb_ch_blocking : jcp.nb_ic > 1; @@ -844,6 +883,8 @@ void _jit_uni_x8s8s32x_fwd_kernel::icb_loop( * jcp.ic_block; add(reg_inp, jcp.typesize_in * inp_step); safe_add(reg_ker, jcp.typesize_in * ker_step, reg_ker_long_offt); + if (jcp.with_input_zp) + add(reg_input_zp, sizeof(uint8_t) * inp_step); dec(reg_icb); cmp(reg_icb, 0); @@ -892,16 +933,19 @@ void _jit_uni_x8s8s32x_fwd_kernel::generate() { * (jcp.ur_w * jcp.oc_without_padding * jcp.ngroups); preamble(); + if (postops_injector_) + postops_injector_->push_post_ops_data_on_stack(this->param1, GET_OFF(post_ops_binary_rhs_arg_vec), reg_inp, reg_out); + if (jcp.is_depthwise) { const bool is_zero_point = jcp.src_zero_point || jcp.dst_zero_point; int idx = ker_max_reg + 1 - jcp.max_regs_ur - 2 * is_zero_point; if (!jcp.is_resrc_depthwise) vmm_dw_src = Vmm(--idx); if (jcp.ver != ver_vnni) vmm_dw_tmp = Vmm(--idx); - if (jcp.signed_input) { + if (jcp.signed_input || jcp.with_input_zp) { --idx; // due to extra register used for compensations } assert(IMPLICATION( - !is_zero_point, idx == ker_max_reg - ker_dw_reg_base_idx)); + !(is_zero_point || jcp.with_input_zp), idx == ker_max_reg - ker_dw_reg_base_idx)); } if (!jcp.is_depthwise && jcp.ver != ver_vnni) { @@ -1217,6 +1261,9 @@ void _jit_uni_x8s8s32x_fwd_kernel::generate() { L(done_compute); assert(ow_block_jmp_table.size() == static_cast(label_cntr)); + if (postops_injector_) + postops_injector_->reset_stack_pointer(); + postamble(); if (jcp.with_eltwise) postops_injector_->prepare_table(); @@ -1307,7 +1354,22 @@ status_t jit_uni_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp, if ((jcp.dst_zero_point || jcp.src_zero_point) && jcp.is_fused_conv) return status::unimplemented; - if (is_3d && jcp.is_depthwise) return status::unimplemented; + jcp.with_input_zp = !attr.input_zero_points_.has_default_values(); + jcp.with_weights_zp = !attr.weights_zero_points_.has_default_values(); + + if (jcp.with_input_zp) { + if (attr.input_zero_points_.count_ != 1 && attr.input_zero_points_.count_ != jcp.ic * jcp.ngroups) + return status::unimplemented; + + if (attr.output_compensations_.count_ != jcp.oc * jcp.ngroups) + return status::unimplemented; + } + + if (jcp.with_input_zp && jcp.is_depthwise && !utils::one_of(ndims, 3, 4)) + return status::unimplemented; + + if (jcp.with_weights_zp) + return status::unimplemented; if (jcp.is_depthwise) { jcp.ch_block = is_avx2 ? 8 : 4; @@ -1339,7 +1401,7 @@ status_t jit_uni_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp, && jcp.kw < 4 && jcp.dilate_w == 0; if (jcp.is_depthwise) { - jcp.max_regs_ur = 14 - !jcp.is_resrc_depthwise - jcp.signed_input + jcp.max_regs_ur = 14 - !jcp.is_resrc_depthwise - (jcp.signed_input || jcp.with_input_zp) + (jcp.ver == ver_vnni); } else { jcp.max_regs_ur = jcp.ver == ver_vnni ? 15 - jcp.signed_input : 12; @@ -1361,7 +1423,8 @@ status_t jit_uni_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp, wei_tag = with_groups ? jcp.is_depthwise ? Goihw8g : gOIhw2i8o4i : OIhw2i8o4i; } else { - wei_tag = with_groups ? gOIdhw2i8o4i : OIdhw2i8o4i; + wei_tag = with_groups ? jcp.is_depthwise ? Goidhw8g : gOIdhw2i8o4i + : OIdhw2i8o4i; } } else { if (is_avx2) { @@ -1376,7 +1439,9 @@ status_t jit_uni_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp, ? jcp.is_depthwise ? Goihw4g : gOIhw4o4i : OIhw4o4i; } else { - wei_tag = with_groups ? gOIdhw4o4i : OIdhw4o4i; + wei_tag = with_groups + ? jcp.is_depthwise ? 
Goidhw4g : gOIdhw4o4i + : OIdhw4o4i; } } } @@ -1442,13 +1507,17 @@ status_t jit_uni_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp, jcp.with_binary = binary_ind != -1; const int sum_ind = post_ops.find(primitive_kind::sum); jcp.with_sum = sum_ind != -1; - jcp.sum_dt = post_ops.get_sum_dt(jcp.dst_dt); + if (jcp.with_sum) + jcp.sum_dt = post_ops.get_sum_dt(jcp.dst_dt); + + jcp.with_depthwise = post_ops.find(primitive_kind::depthwise) != -1; + jcp.with_quantization = post_ops.find(primitive_kind::quantization) != -1; jcp.post_ops = post_ops; using namespace injector; - const bool post_ops_ok_ = post_ops_ok({isa, {eltwise, binary, sum}, + const bool post_ops_ok_ = post_ops_ok({isa, {eltwise, binary, sum, depthwise, quantization}, jcp.post_ops, &dst_d, false, false, false}); if (!post_ops_ok_) return status::unimplemented; diff --git a/src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.hpp b/src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.hpp index ffcf4526899..4cb6dec6f8c 100644 --- a/src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.hpp +++ b/src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.hpp @@ -96,6 +96,16 @@ struct _jit_uni_x8s8s32x_fwd_kernel : public jit_generator { /* binary post-ops operand */ const Xbyak::Reg64 temp_offset_reg = r12; + const Xbyak::Reg64 reg_input_zp = reg_bias_alpha; + + const Xbyak::Reg64 reg_d_weights = r15; + const Xbyak::Reg64 reg_d_bias = r13; + int base_post_ops_data_offset = 0; + constexpr static int reg64_size = 8; + + const Vmm vmm_d_weights = Vmm(0); + const Vmm vmm_d_bias = Vmm(1); + const Vmm vmm_wei = Vmm(0); /* used during bias/comp/scale section of store_output */ const Vmm vmm_bias = Vmm(0); diff --git a/src/cpu/x64/jit_uni_x8s8s32x_convolution.cpp b/src/cpu/x64/jit_uni_x8s8s32x_convolution.cpp index 854aa84df4e..84548a6687d 100644 --- a/src/cpu/x64/jit_uni_x8s8s32x_convolution.cpp +++ b/src/cpu/x64/jit_uni_x8s8s32x_convolution.cpp @@ -53,6 +53,11 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t::execute_forward_2d( DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC); DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST); + DEFINE_INPUT_ZERO_POINTS_BUFFER(input_zp, jcp); + DEFINE_OUTPUT_COMPENSATION_BUFFER(output_compensation, jcp); + + auto MB = CTX_IN_BATCH(DNNL_ARG_SRC); + const memory_desc_wrapper src_d(pd()->src_md()); const memory_desc_wrapper dst_d(pd()->dst_md()); const memory_desc_wrapper weights_d(pd()->weights_md(0)); @@ -84,16 +89,15 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t::execute_forward_2d( size_t offset = weights_d.size() - weights_d.additional_buffer_size(); auto w = const_cast(weights); - const int32_t *compensation = (jcp.signed_input) - ? reinterpret_cast(&w[offset]) - : nullptr; + const int32_t *compensation = (jcp.signed_input) ? reinterpret_cast(&w[offset]) : + (jcp.with_input_zp) ? output_compensation : nullptr; const int32_t *zp_compensation = jcp.src_zero_point ? reinterpret_cast(&w[offset]) + (jcp.signed_input ? 
@@ -109,14 +113,14 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_2d(
         switch (jcp.loop_order) {
             case loop_cwgn:
                 nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, g,
-                        nb_groups, n, jcp.mb, oh_s, jcp.oh);
+                        nb_groups, n, MB, oh_s, jcp.oh);
                 break;
             case loop_ngcw:
-                nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks,
+                nd_iterator_init(start, n, MB, g, nb_groups, occ, oc_chunks,
                         owb, jcp.nb_ow, oh_s, jcp.oh);
                 break;
             case loop_nhwcg:
-                nd_iterator_init(start, n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow,
+                nd_iterator_init(start, n, MB, oh_s, jcp.oh, owb, jcp.nb_ow,
                         occ, oc_chunks, g, nb_groups);
                 break;
             default: assert(!"unsupported loop order");
@@ -140,7 +144,7 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_2d(
             auto bias_w = bias ? bias + (bias_d.blk_off(g_oc) * bia_dt_size)
                                : nullptr;
             const int32_t *compensation_w
-                    = (jcp.signed_input) ? compensation + g_oc : nullptr;
+                    = (jcp.signed_input || jcp.with_input_zp) ? compensation + g_oc : nullptr;
 
             auto dst_w = dst + dst_dt_size * dst_d.blk_off(n, g_oc, oh_s, ow_s);
@@ -163,7 +167,7 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_2d(
                         0, jcp.kh - i_t_overflow - i_b_overflow);
 
                 const size_t wei_stride
-                        = (jcp.signed_input || jcp.src_zero_point)
+                        = (jcp.signed_input || jcp.src_zero_point || jcp.with_input_zp)
                         ? 0
                         : i_t_overflow * wht_h_stride;
                 p.src = src_w + i_t_overflow * dilate_h * src_h_stride;
@@ -189,6 +193,9 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_2d(
                 p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
                 p.dst_orig = dst;
+                p.oc_off = g_oc * sizeof(float);
+                if (jcp.with_input_zp)
+                    p.input_zp = input_zp + g_ic;
 
                 (*kernel_)(&p);
 
                 src_w += src_h_stride * jcp.stride_h;
@@ -198,15 +205,15 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_2d(
             switch (jcp.loop_order) {
                 case loop_cwgn:
                     nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow,
-                            g, nb_groups, n, jcp.mb, oh_s, jcp.oh);
+                            g, nb_groups, n, MB, oh_s, jcp.oh);
                     break;
                 case loop_ngcw:
-                    nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ,
+                    nd_iterator_jump(start, end, n, MB, g, nb_groups, occ,
                             oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
                     break;
                 case loop_nhwcg:
                     ++start;
-                    nd_iterator_step(n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow,
+                    nd_iterator_step(n, MB, oh_s, jcp.oh, owb, jcp.nb_ow,
                             occ, oc_chunks, g, nb_groups);
                     break;
                 default: assert(!"unsupported loop order");
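Every loop initializer and stepper above now iterates over MB, the batch read from the execution context, rather than the jcp.mb baked in at primitive creation; as I read it, the intent is to let one compiled primitive serve smaller runtime batches. A condensed sketch of the pattern, with the fork's CTX_IN_BATCH macro stubbed out:

    #include <algorithm>
    #include <cstdint>

    // Stand-in for CTX_IN_BATCH: the batch actually carried by the source
    // memory of this execute() call, assumed never above jcp.mb.
    inline int64_t runtime_batch(int64_t jcp_mb, int64_t src_mb) {
        return std::min(jcp_mb, src_mb);
    }

    // Thread partitioning must use the runtime value, or work would be
    // balanced over images that are absent from this call.
    inline int64_t work_amount_2d(int64_t MB, int64_t nb_groups,
            int64_t oc_chunks, int64_t oh, int64_t nb_ow) {
        return MB * nb_groups * oc_chunks * oh * nb_ow;
    }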
@@ -230,6 +237,11 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_1d(
     DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
     DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);
 
+    DEFINE_INPUT_ZERO_POINTS_BUFFER(input_zp, jcp);
+    DEFINE_OUTPUT_COMPENSATION_BUFFER(output_compensation, jcp);
+
+    auto MB = CTX_IN_BATCH(DNNL_ARG_SRC);
+
     const memory_desc_wrapper src_d(pd()->src_md());
     const memory_desc_wrapper dst_d(pd()->dst_md());
     const memory_desc_wrapper weights_d(pd()->weights_md(0));
@@ -262,9 +274,8 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_1d(
     size_t ch_offset = jcp.is_depthwise ? jcp.nb_ch * jcp.ch_block
                                         : jcp.ngroups * jcp.oc;
     auto w = const_cast<char *>(weights);
-    const int32_t *compensation = (jcp.signed_input)
-            ? reinterpret_cast<const int32_t *>(&w[extra_data_offset])
-            : nullptr;
+    const int32_t *compensation = (jcp.signed_input) ? reinterpret_cast<const int32_t *>(&w[extra_data_offset]) :
+            (jcp.with_input_zp) ? output_compensation : nullptr;
     const int32_t *zp_compensation = jcp.src_zero_point
             ? reinterpret_cast<const int32_t *>(&w[extra_data_offset])
                     + (jcp.signed_input ? ch_offset : 0)
@@ -273,7 +284,7 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_1d(
     int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
     int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking;
     int group_block = jcp.ch_block;
-    int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.nb_ow;
+    int work_amount = MB * nb_groups * oc_chunks * jcp.nb_ow;
     parallel(jcp.nthr, [&](const int ithr, const int nthr) {
         int start {0}, end {0};
         balance211(work_amount, nthr, ithr, start, end);
@@ -284,18 +295,18 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_1d(
         switch (jcp.loop_order) {
             case loop_cwgn:
                 nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg,
-                        nb_groups, n, jcp.mb);
+                        nb_groups, n, MB);
                 break;
             case loop_gncw:
-                nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ,
+                nd_iterator_init(start, gg, nb_groups, n, MB, occ,
                         oc_chunks, owb, jcp.nb_ow);
                 break;
             case loop_ngcw:
-                nd_iterator_init(start, n, jcp.mb, gg, nb_groups, occ,
+                nd_iterator_init(start, n, MB, gg, nb_groups, occ,
                         oc_chunks, owb, jcp.nb_ow);
                 break;
             case loop_nwcg:
-                nd_iterator_init(start, n, jcp.mb, owb, jcp.nb_ow, occ,
+                nd_iterator_init(start, n, MB, owb, jcp.nb_ow, occ,
                         oc_chunks, gg, nb_groups);
                 break;
             default: assert(!"unsupported loop order");
@@ -311,7 +322,7 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_1d(
             p.bias = bias ? bias + (bias_d.blk_off(g_oc) * bia_dt_size)
                           : nullptr;
-            p.compensation = (jcp.signed_input) ? compensation + g_oc : nullptr;
+            p.compensation = (jcp.signed_input || jcp.with_input_zp) ? compensation + g_oc : nullptr;
             p.zp_compensation
                     = jcp.src_zero_point ? zp_compensation + g_oc : nullptr;
             p.src_zero_point = jcp.src_zero_point ? src_zero_point : nullptr;
@@ -329,6 +340,9 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_1d(
             p.oc_l_off = g_oc;
             p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
             p.dst_orig = dst;
+            p.oc_off = g_oc * sizeof(float);
+            if (jcp.with_input_zp)
+                p.input_zp = input_zp + g_ic;
 
             (*kernel_)(&p);
 
@@ -336,18 +350,18 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_1d(
             switch (jcp.loop_order) {
                 case loop_cwgn:
                     nd_iterator_step(occ, oc_chunks, owb, jcp.nb_ow, gg,
-                            nb_groups, n, jcp.mb);
+                            nb_groups, n, MB);
                     break;
                 case loop_gncw:
-                    nd_iterator_step(gg, nb_groups, n, jcp.mb, occ, oc_chunks,
+                    nd_iterator_step(gg, nb_groups, n, MB, occ, oc_chunks,
                             owb, jcp.nb_ow);
                     break;
                 case loop_ngcw:
-                    nd_iterator_step(n, jcp.mb, gg, nb_groups, occ, oc_chunks,
+                    nd_iterator_step(n, MB, gg, nb_groups, occ, oc_chunks,
                             owb, jcp.nb_ow);
                     break;
                 case loop_nwcg:
-                    nd_iterator_step(n, jcp.mb, owb, jcp.nb_ow, occ, oc_chunks,
+                    nd_iterator_step(n, MB, owb, jcp.nb_ow, occ, oc_chunks,
                             gg, nb_groups);
                     break;
                 default: assert(!"unsupported loop order");
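One subtlety in the 1D path above: ch_offset sizes the compensation tail differently for depthwise kernels, because their channel dimension is stored padded to full blocks. A one-liner restating that rule (a sketch, not a library function):

    #include <cstddef>

    // Entries in each int32 tail buffer appended to the weights: depthwise
    // kernels store channels padded to whole blocks (nb_ch * ch_block),
    // while regular kernels use the logical ngroups * oc count.
    inline size_t comp_entries(bool is_depthwise, size_t nb_ch,
            size_t ch_block, size_t ngroups, size_t oc) {
        return is_depthwise ? nb_ch * ch_block : ngroups * oc;
    }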
@@ -371,6 +385,11 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_2d_dw(
     DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
     DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);
 
+    DEFINE_INPUT_ZERO_POINTS_BUFFER(input_zp, jcp);
+    DEFINE_OUTPUT_COMPENSATION_BUFFER(output_compensation, jcp);
+
+    auto MB = CTX_IN_BATCH(DNNL_ARG_SRC);
+
     const memory_desc_wrapper src_d(pd()->src_md());
     const memory_desc_wrapper dst_d(pd()->dst_md());
     const memory_desc_wrapper weights_d(pd()->weights_md(0));
@@ -404,9 +423,8 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_2d_dw(
 
     size_t offset = weights_d.size() - weights_d.additional_buffer_size();
     auto w = const_cast<char *>(weights);
-    const int32_t *compensation = (jcp.signed_input)
-            ? reinterpret_cast<const int32_t *>(&w[offset])
-            : nullptr;
+    const int32_t *compensation = (jcp.signed_input) ? reinterpret_cast<const int32_t *>(&w[offset]) :
+            (jcp.with_input_zp) ? output_compensation : nullptr;
     const int32_t *zp_compensation = jcp.src_zero_point
             ? reinterpret_cast<const int32_t *>(&w[offset])
                     + (jcp.signed_input ? jcp.nb_ch * jcp.ch_block : 0)
@@ -414,7 +432,7 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_2d_dw(
     int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking;
     int group_block = jcp.ch_block;
 
-    parallel_nd(jcp.mb, jcp.oh, jcp.nb_ow, nb_groups,
+    parallel_nd(MB, jcp.oh, jcp.nb_ow, nb_groups,
             [&](dim_t n, dim_t oh_s, dim_t owb, dim_t gg) {
                 auto p = jit_conv_call_s();
@@ -431,7 +449,7 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_2d_dw(
                 auto bias_w = bias ? bias + (bias_d.blk_off(g) * bia_dt_size)
                                    : nullptr;
                 const int32_t *compensation_w
-                        = jcp.signed_input ? compensation + g : nullptr;
+                        = (jcp.signed_input || jcp.with_input_zp) ? compensation + g : nullptr;
 
                 auto dst_w = dst + dst_dt_size * dst_d.blk_off(n, g, oh_s, ow_s);
@@ -451,7 +469,7 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_2d_dw(
                 int kh_padding
                         = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow);
 
-                size_t wei_stride = (jcp.signed_input || jcp.src_zero_point)
+                size_t wei_stride = (jcp.signed_input || jcp.src_zero_point || jcp.with_input_zp)
                         ? 0
                         : i_t_overflow * wht_h_stride;
                 p.src = src_w + i_t_overflow * dilate_h * src_h_stride;
@@ -476,6 +494,9 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_2d_dw(
                 p.post_ops_binary_rhs_arg_vec
                         = post_ops_binary_rhs_arg_vec.data();
                 p.dst_orig = dst;
+                p.oc_off = g * sizeof(float);
+                if (jcp.with_input_zp)
+                    p.input_zp = input_zp + g;
 
                 (*kernel_)(&p);
             });
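The i_t_overflow / i_b_overflow / kh_padding trio used throughout these drivers clips the filter window against the top and bottom padding; when any compensation is in play, wei_stride stays 0 because the kernel then walks the full filter itself. The same arithmetic in isolation, as a standalone sketch:

    #include <algorithm>

    // Effective step is dilate_h = jcp.dilate_h + 1. For an output row whose
    // first input row is ih_s (possibly negative), count the filter taps
    // falling above (t) and below (b) the image, then keep the rest.
    struct kh_clip { int t_overflow, b_overflow, kh_padding; };

    inline kh_clip clip_kh(int ih_s, int IH, int KH, int dilate_h) {
        auto div_up = [](int a, int b) { return (a + b - 1) / b; };
        const int t = std::min(KH, div_up(std::max(0, -ih_s), dilate_h));
        const int b = std::min(KH,
                div_up(std::max(0, ih_s - IH + (KH - 1) * dilate_h + 1),
                        dilate_h));
        return {t, b, std::max(0, KH - t - b)};
    }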
@@ -496,6 +517,11 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_3d(
     DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
     DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);
 
+    DEFINE_INPUT_ZERO_POINTS_BUFFER(input_zp, jcp);
+    DEFINE_OUTPUT_COMPENSATION_BUFFER(output_compensation, jcp);
+
+    auto MB = CTX_IN_BATCH(DNNL_ARG_SRC);
+
     const memory_desc_wrapper src_d(pd()->src_md());
     const memory_desc_wrapper dst_d(pd()->dst_md());
     const memory_desc_wrapper weights_d(pd()->weights_md(0));
@@ -527,9 +553,8 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_3d(
 
     size_t offset = weights_d.size() - weights_d.additional_buffer_size();
     auto w = const_cast<char *>(weights);
-    const int32_t *compensation = (jcp.signed_input)
-            ? reinterpret_cast<const int32_t *>(&w[offset])
-            : nullptr;
+    const int32_t *compensation = (jcp.signed_input) ? reinterpret_cast<const int32_t *>(&w[offset]) :
+            (jcp.with_input_zp) ? output_compensation : nullptr;
     const int32_t *zp_compensation = jcp.src_zero_point
             ? reinterpret_cast<const int32_t *>(&w[offset])
                     + (jcp.signed_input ? jcp.ngroups * jcp.oc : 0)
@@ -537,7 +562,7 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_3d(
     int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk;
     int nb_groups = jcp.nb_ch;
     int work_amount
-            = jcp.mb * nb_groups * oc_chunks * jcp.od * jcp.oh * jcp.nb_ow;
+            = MB * nb_groups * oc_chunks * jcp.od * jcp.oh * jcp.nb_ow;
 
     parallel(jcp.nthr, [&](const int ithr, const int nthr) {
         int start {0}, end {0};
@@ -555,14 +580,14 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_3d(
         switch (jcp.loop_order) {
             case loop_cwgn:
                 nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, g,
-                        nb_groups, n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh);
+                        nb_groups, n, MB, od_s, jcp.od, oh_s, jcp.oh);
                 break;
             case loop_ngcw:
-                nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks,
+                nd_iterator_init(start, n, MB, g, nb_groups, occ, oc_chunks,
                         owb, jcp.nb_ow, od_s, jcp.od, oh_s, jcp.oh);
                 break;
             case loop_nhwcg:
-                nd_iterator_init(start, n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh,
+                nd_iterator_init(start, n, MB, od_s, jcp.od, oh_s, jcp.oh,
                         owb, jcp.nb_ow, occ, oc_chunks, g, nb_groups);
                 break;
             default: assert(!"unsupported loop order");
@@ -598,7 +623,7 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_3d(
             auto bias_w = bias ? bias + (bias_d.blk_off(g_oc) * bia_dt_size)
                                : nullptr;
             const int32_t *compensation_w
-                    = (jcp.signed_input) ? compensation + g_oc : nullptr;
+                    = (jcp.signed_input || jcp.with_input_zp) ? compensation + g_oc : nullptr;
             p.zp_compensation
                     = jcp.src_zero_point ? zp_compensation + g_oc : nullptr;
             p.src_zero_point
@@ -612,7 +637,7 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_3d(
             auto src_w = src + src_d.blk_off(n, g_ic, id_s, ih_s, iw_s)
                     + d_f_overflow * dilate_d * src_d_stride;
             auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0)
-                    + ((jcp.signed_input || jcp.src_zero_point)
+                    + ((jcp.signed_input || jcp.src_zero_point || jcp.with_input_zp)
                                       ? 0
                                       : d_f_overflow)
                             * wht_d_stride;
@@ -632,7 +657,7 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_3d(
                 int kh_padding = nstl::max(
                         0, jcp.kh - i_t_overflow - i_b_overflow);
 
-                size_t wei_stride = (jcp.signed_input || jcp.src_zero_point)
+                size_t wei_stride = (jcp.signed_input || jcp.src_zero_point || jcp.with_input_zp)
                         ? 0
                         : wht_h_stride * i_t_overflow;
                 p.src = src_w + i_t_overflow * dilate_h * src_h_stride;
@@ -654,6 +679,9 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_3d(
                 p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
                 p.dst_orig = dst;
+                p.oc_off = g_oc * sizeof(float);
+                if (jcp.with_input_zp)
+                    p.input_zp = input_zp + g_ic;
 
                 (*kernel_)(&p);
 
                 src_w += src_h_stride * jcp.stride_h;
@@ -663,17 +691,17 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_3d(
             switch (jcp.loop_order) {
                 case loop_cwgn:
                     nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow,
-                            g, nb_groups, n, jcp.mb, od_s, jcp.od, oh_s,
+                            g, nb_groups, n, MB, od_s, jcp.od, oh_s,
                             jcp.oh);
                     break;
                 case loop_ngcw:
-                    nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ,
+                    nd_iterator_jump(start, end, n, MB, g, nb_groups, occ,
                             oc_chunks, owb, jcp.nb_ow, od_s, jcp.od, oh_s,
                             jcp.oh);
                     break;
                 case loop_nhwcg:
                     ++start;
-                    nd_iterator_step(n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh, owb,
+                    nd_iterator_step(n, MB, od_s, jcp.od, oh_s, jcp.oh, owb,
                             jcp.nb_ow, occ, oc_chunks, g, nb_groups);
                     break;
                 default: assert(!"unsupported loop order");
@@ -683,6 +711,131 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_3d(
     return status::success;
 }
 
+template <cpu_isa_t isa>
+status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_3d_dw(
+        const exec_ctx_t &ctx) const {
+    auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
+    auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
+    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
+    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
+
+    const auto &jcp = pd()->jcp_;
+    const auto post_ops_binary_rhs_arg_vec
+            = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx);
+
+    DEFINE_INPUT_ZERO_POINTS_BUFFER(input_zp, jcp);
+    DEFINE_OUTPUT_COMPENSATION_BUFFER(output_compensation, jcp);
+
+    auto MB = CTX_IN_BATCH(DNNL_ARG_SRC);
+
+    const memory_desc_wrapper src_d(pd()->src_md());
+    const memory_desc_wrapper dst_d(pd()->dst_md());
+    const memory_desc_wrapper weights_d(pd()->weights_md(0));
+    const memory_desc_wrapper bias_d(pd()->weights_md(1));
+
+    const size_t bia_dt_size
+            = pd()->with_bias() ? types::data_type_size(bias_d.data_type()) : 0;
+    const size_t dst_dt_size = types::data_type_size(dst_d.data_type());
+
+    assert(jcp.ic_block == 1);
+    assert(jcp.oc_block == 1);
+    assert(jcp.nb_ic == 1);
+    assert(jcp.nb_oc == 1);
+    assert(jcp.nb_oc_blocking == 1);
+    assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);
+
+    const float *oscales = pd()->attr()->output_scales_.scales_;
+    if (jcp.signed_input && jcp.ver != ver_vnni) {
+        auto local_scales = ctx.get_scratchpad_grantor().template get<float>(
+                key_conv_adjusted_scales);
+        size_t count = pd()->attr()->output_scales_.count_;
+        float factor = 1.f / pd()->jcp_.wei_adj_scale;
+        if (count == 1) {
+            utils::array_set(local_scales, oscales[0] * factor, 16);
+        } else {
+            for (size_t c = 0; c < count; c++)
+                local_scales[c] = oscales[c] * factor;
+        }
+        oscales = local_scales;
+    }
+
+    size_t offset = weights_d.size() - weights_d.additional_buffer_size();
+    auto w = const_cast<char *>(weights);
+    const int32_t *compensation = (jcp.signed_input) ? reinterpret_cast<const int32_t *>(&w[offset]) :
+            (jcp.with_input_zp) ? output_compensation : nullptr;
+    int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking;
+    int group_block = jcp.ch_block;
+
+    parallel_nd(MB, jcp.od, jcp.oh, jcp.nb_ow, nb_groups,
+            [&](dim_t n, dim_t od_s, dim_t oh_s, dim_t owb, dim_t gg) {
+        auto p = jit_conv_call_s();
+
+        size_t src_d_stride = src_d.blk_off(0, 0, 1);
+        size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
+
+        size_t src_h_stride = src_d.blk_off(0, 0, 0, 1);
+        size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1);
+
+        int gb = gg * jcp.nb_ch_blocking;
+        int g = gb * group_block;
+
+        int id_s = -jcp.f_pad + od_s * jcp.stride_d;
+
+        int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
+        int ow_s = owb * jcp.ow_block;
+        int iw_s = ow_s * jcp.stride_w;
+
+        auto bias_w = bias ? bias + (bias_d.blk_off(g) * bia_dt_size) : nullptr;
+        const int32_t *compensation_w
+                = (jcp.signed_input || jcp.with_input_zp) ? compensation + g : nullptr;
+
+        auto dst_w = dst + dst_dt_size * dst_d.blk_off(n, g, od_s, oh_s, ow_s);
+        auto src_w = src + src_d.blk_off(n, g, id_s, ih_s, iw_s);
+        auto wht_w = weights + wht_blk_off(weights_d, gb, 0);
+
+        auto scales = &oscales[jcp.is_oc_scale * g];
+
+        int dilate_d = jcp.dilate_d + 1;
+        int i_f_overflow = nstl::min(jcp.kd, div_up(nstl::max(0, -id_s), dilate_d));
+        int i_back_overflow = nstl::min(jcp.kd,
+                div_up(nstl::max(0, id_s - jcp.id + (jcp.kd - 1) * dilate_d + 1),
+                        dilate_d));
+        int kd_padding = nstl::max(0, jcp.kd - i_f_overflow - i_back_overflow);
+
+        size_t wei_d_stride = (jcp.signed_input || jcp.with_input_zp)
+                ? 0 : i_f_overflow * wht_d_stride;
+
+        int dilate_h = jcp.dilate_h + 1;
+        int i_t_overflow = nstl::min(jcp.kh, div_up(nstl::max(0, -ih_s), dilate_h));
+        int i_b_overflow = nstl::min(jcp.kh,
+                div_up(nstl::max(0, ih_s - jcp.ih + (jcp.kh - 1) * dilate_h + 1),
+                        dilate_h));
+        int kh_padding = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow);
+
+        size_t wei_h_stride = (jcp.signed_input || jcp.with_input_zp)
+                ? 0 : i_t_overflow * wht_h_stride;
+        p.src = src_w + i_t_overflow * dilate_h * src_h_stride
+                + i_f_overflow * dilate_d * src_d_stride;
+        p.dst = dst_w;
+        p.filt = wht_w + wei_d_stride + wei_h_stride;
+        p.bias = bias_w;
+        p.compensation = compensation_w;
+        p.oc_blocks = gb;
+        p.kd_padding = kd_padding;
+        p.kh_padding = kh_padding;
+        p.scales = scales;
+        p.f_overflow = i_f_overflow;
+        p.back_overflow = i_back_overflow;
+        p.t_overflow = i_t_overflow;
+        p.b_overflow = i_b_overflow;
+        p.owb = owb;
+        p.post_ops_binary_rhs_arg_vec
+                = post_ops_binary_rhs_arg_vec.data();
+
+        p.oc_off = g * sizeof(float);
+        if (jcp.with_input_zp)
+            p.input_zp = input_zp + g;
+
+        (*kernel_)(&p);
+    });
+    return status::success;
+}
+
 template struct jit_uni_x8s8s32x_convolution_fwd_t<avx2>;
 template struct jit_uni_x8s8s32x_convolution_fwd_t<sse41>;
diff --git a/src/cpu/x64/jit_uni_x8s8s32x_convolution.hpp b/src/cpu/x64/jit_uni_x8s8s32x_convolution.hpp
index 2985feb045c..83827fa7e1e 100644
--- a/src/cpu/x64/jit_uni_x8s8s32x_convolution.hpp
+++ b/src/cpu/x64/jit_uni_x8s8s32x_convolution.hpp
@@ -59,7 +59,10 @@ struct jit_uni_x8s8s32x_convolution_fwd_t : public primitive_t {
                     && desc()->accum_data_type == s32
                     && attr()->has_default_values(smask_t::oscale
                                     | smask_t::zero_points_runtime
-                                    | smask_t::post_ops | smask_t::sum_dt,
+                                    | smask_t::post_ops
+                                    | smask_t::sum_dt
+                                    | smask_t::input_zero_points
+                                    | smask_t::output_compensations,
                             dst_md(0)->data_type)
                     && attr()->post_ops_.check_sum_consistent_dt(
                             dst_md(0)->data_type)
@@ -112,7 +115,9 @@ struct jit_uni_x8s8s32x_convolution_fwd_t : public primitive_t {
             case 4:
                 if (is_dw) return execute_forward_2d_dw(ctx);
                 return execute_forward_2d(ctx);
-            case 5: return execute_forward_3d(ctx);
+            case 5:
+                if (is_dw) return execute_forward_3d_dw(ctx);
+                return execute_forward_3d(ctx);
         }
         return status::unimplemented;
     }
@@ -122,6 +127,7 @@ struct jit_uni_x8s8s32x_convolution_fwd_t : public primitive_t {
     status_t execute_forward_2d(const exec_ctx_t &ctx) const;
     status_t execute_forward_3d(const exec_ctx_t &ctx) const;
     status_t execute_forward_2d_dw(const exec_ctx_t &ctx) const;
+    status_t execute_forward_3d_dw(const exec_ctx_t &ctx) const;
 
     const pd_t *pd() const {
         return static_cast<const pd_t *>(primitive_t::pd().get());
     }
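The execute() dispatch in the hpp hunk above keys off ndims plus a depthwise flag, so 5D depthwise cases now reach the new dedicated path. Roughly, as a standalone sketch (not the actual class):

    // ndims counts all tensor dims, batch and channels included:
    // 3 -> 1D, 4 -> 2D, 5 -> 3D spatial convolutions.
    enum class fwd_path { _1d, _2d, _2d_dw, _3d, _3d_dw, unsupported };

    inline fwd_path pick_path(int ndims, bool is_dw) {
        switch (ndims) {
            case 3: return fwd_path::_1d;
            case 4: return is_dw ? fwd_path::_2d_dw : fwd_path::_2d;
            case 5: return is_dw ? fwd_path::_3d_dw : fwd_path::_3d;
            default: return fwd_path::unsupported;
        }
    }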
diff --git a/src/cpu/x64/jit_uni_x8s8s32x_deconvolution.cpp b/src/cpu/x64/jit_uni_x8s8s32x_deconvolution.cpp
index 3743c41d205..694becd43ab 100644
--- a/src/cpu/x64/jit_uni_x8s8s32x_deconvolution.cpp
+++ b/src/cpu/x64/jit_uni_x8s8s32x_deconvolution.cpp
@@ -147,11 +147,10 @@ status_t jit_uni_x8s8s32x_deconv_fwd_kernel<isa>::init_conf(
         memory_desc_init_by_tag(want_wei_md, wei_tag);
         if (jcp.signed_input && !jcp.is_depthwise) {
             want_wei_md.extra.flags = 0
-                    | memory_extra_flags::compensation_conv_s8s8
-                    | memory_extra_flags::scale_adjust;
+                    | memory_extra_flags::compensation_conv_s8s8;
             want_wei_md.extra.compensation_mask = (1 << 0)
                     + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0);
-            want_wei_md.extra.scale_adjust = (jcp.ver == ver_vnni) ? 1.f : 0.5f;
+            want_wei_md.extra.scale_adjust = 1.f;
         }
         if (jcp.src_zero_point) set_zp_src_comp_flags(want_wei_md, with_groups);
@@ -250,6 +249,12 @@ status_t jit_uni_x8s8s32x_deconv_fwd_kernel<isa>::init_conf(
     const int sum_ind = p.find(primitive_kind::sum);
     jcp.with_sum = sum_ind != -1;
 
+    const int depthwise_ind = p.find(primitive_kind::depthwise);
+    jcp.with_depthwise = depthwise_ind != -1;
+
+    const int quantization_ind = p.find(primitive_kind::quantization);
+    jcp.with_quantization = quantization_ind != -1;
+
     const auto &oscales = attr.output_scales_;
     jcp.is_oc_scale = oscales.mask_ == 1 << 1;
@@ -383,7 +388,7 @@ bool jit_uni_x8s8s32x_deconv_fwd_kernel<isa>::post_ops_ok(jit_conv_conf_t &jcp,
         const memory_desc_wrapper &dst_d, const primitive_attr_t &attr) {
     using namespace injector;
 
-    return injector::post_ops_ok(post_ops_ok_args_t(isa, {sum, eltwise, binary},
+    return injector::post_ops_ok(post_ops_ok_args_t(isa, {sum, eltwise, binary, depthwise, quantization},
             attr.post_ops_, &dst_d, false /*sum_at_pos_0_only*/,
             false /*sum_requires_scale_one*/, false /*sum_requires_zp_zero*/,
             {broadcasting_strategy_t::per_oc,
@@ -398,7 +403,7 @@ _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::_jit_uni_x8s8s32x_deconv_fwd_kernel(
         const binary_injector::static_params_t bsp {this->param1_, rhs_sp};
+        const quantization_injector::static_params_t qsp {vmm_d_weights.getIdx(),
+                vmm_d_bias.getIdx(), reg_d_weights, reg_d_bias};
+
         postops_injector_ = utils::make_unique<
                 injector::jit_uni_postops_injector_t<isa, Vmm>>(
-                this, jcp_.post_ops, bsp);
+                this, jcp_.post_ops, bsp, qsp);
     }
 }
@@ -1041,10 +1048,23 @@ void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::apply_postops(int ur_w,
             }
         }
     }
+
+    std::map<size_t, int> vmm_idx_off;
+    for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) {
+        for (int ur = 0; ur < ur_w; ur++) {
+            vmm_idx_off.insert({vmm_out(ur, ocb).getIdx(), ocb * jcp_.oc_block * sizeof(float)});
+        }
+    }
+    depthwise_injector::dynamic_params_t ddp {vmm_d_weights.getIdx(),
+            vmm_d_bias.getIdx(), reg_d_weights, reg_d_bias,
+            ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off,
+            this->rsp, base_post_ops_data_offset};
+    quantization_injector::dynamic_params_t qdp {ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off, jcp_.dst_dt,
+            this->rsp, base_post_ops_data_offset};
+
     const int nb_oc_block
             = jcp_.is_depthwise ? jcp_.nb_ch_blocking : jcp_.nb_oc_blocking;
     postops_injector_->compute_vector_range(
-            16 - nb_oc_block * ur_w, 16, rhs_arg_params);
+            16 - nb_oc_block * ur_w, 16, rhs_arg_params, ddp, qdp);
 }
@@ -1131,7 +1151,7 @@ void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::store_output(
     if (p_sum_zp && *p_sum_zp != 0) {
         mov(reg_ptr_sum_zp_, reinterpret_cast<size_t>(p_sum_zp));
     }
-    if (jcp_.with_eltwise || jcp_.with_binary || jcp_.with_sum)
+    if (jcp_.with_eltwise || jcp_.with_binary || jcp_.with_sum || jcp_.with_depthwise || jcp_.with_quantization)
         apply_postops(ur_w, last_oc_block, p_sum_scale, p_sum_zp);
     if (jcp_.dst_zero_point) {
         mov(reg_zp_dst_, ptr[param1_ + GET_OFF(dst_zero_point)]);
@@ -1283,8 +1303,13 @@ template <cpu_isa_t isa, typename Vmm>
 void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::generate() {
     preamble();
 
-    if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp_))
+    if (postops_injector_)
+        postops_injector_->push_post_ops_data_on_stack(param1, GET_OFF(post_ops_binary_rhs_arg_vec), reg_src_, reg_filt_);
+
+    if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp_)) {
         sub(rsp, reserved_stack_size_);
+        base_post_ops_data_offset += reserved_stack_size_;
+    }
 
     const auto vmm_one_128 = Xbyak::Xmm(vmm_one_.getIdx());
     mov(reg_scratch_, 0x10001);
@@ -1350,8 +1375,13 @@ void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::generate() {
         }
     }
 
-    if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp_))
+    if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp_)) {
         add(rsp, reserved_stack_size_);
+        base_post_ops_data_offset -= reserved_stack_size_;
+    }
+
+    if (postops_injector_)
+        postops_injector_->reset_stack_pointer();
 
     postamble();
@@ -1528,6 +1558,8 @@ status_t jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::execute_forward_1d(
             p.dst_zero_point = zp_dst;
             p.dst_orig = dst;
 
+            p.oc_off = g_oc * sizeof(float);
+
             (*kernel_)(&p);
 
             ++start;
@@ -1706,6 +1738,8 @@ status_t jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::execute_forward_2d(
                 p.dst_zero_point = zp_dst;
                 p.dst_orig = dst;
 
+                p.oc_off = g_oc * sizeof(float);
+
                 (*kernel_)(&p);
             }
             if (jcp.loop_order == loop_ngc)
@@ -1941,6 +1975,8 @@ status_t jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::execute_forward_3d(
                 p.dst_zero_point = zp_dst;
                 p.dst_orig = dst;
 
+                p.oc_off = g_oc * sizeof(float);
+
                 (*kernel_)(&p);
             }
diff --git a/src/cpu/x64/jit_uni_x8s8s32x_deconvolution.hpp b/src/cpu/x64/jit_uni_x8s8s32x_deconvolution.hpp
index 24ee691b42f..7f3bcac93c6 100644
--- a/src/cpu/x64/jit_uni_x8s8s32x_deconvolution.hpp
+++ b/src/cpu/x64/jit_uni_x8s8s32x_deconvolution.hpp
@@ -141,6 +141,14 @@ struct _jit_uni_x8s8s32x_deconv_fwd_kernel : public jit_generator {
             int ur_w, int l_overflow, int r_overflow, bool h_padded);
     void append_zp_src_pad_str_comp(int ur_w, int l_overflow, int r_overflow,
             bool h_padded, bool last_oc_block);
+
+    /* depthwise and quantization post ops */
+    const Xbyak::Reg64 reg_d_weights = r15;
+    const Xbyak::Reg64 reg_d_bias = r13;
+    int base_post_ops_data_offset = 0;
+    Vmm vmm_d_weights = Vmm(0);
+    Vmm vmm_d_bias = Vmm(1);
+
     void kh_loop(int ur_w, int pad_l, int pad_r, ker_block_t last_ker_block);
     void icb_loop(int ur_w, int pad_l, int pad_r, bool last_block);
     void generate() override;
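The generate() hunks above maintain two invariants: every sub(rsp, ...) reserving the zero-point scratch area is mirrored by an add(rsp, ...), and base_post_ops_data_offset records the current displacement so that the stack-relative reloads inside apply_postops still find the post-ops data pushed in the preamble. A hypothetical RAII sketch of the same bookkeeping (nothing like this exists in the patch; the emitted instructions are shown as comments):

    // Pairs a stack reservation with its release and keeps the running
    // offset that stack-relative post-ops loads must add back to rsp.
    struct stack_reservation {
        int &base_offset; // mirrors base_post_ops_data_offset
        const int size;
        stack_reservation(int &off, int sz) : base_offset(off), size(sz) {
            // emit: sub(rsp, size);
            base_offset += size;
        }
        ~stack_reservation() {
            // emit: add(rsp, size);
            base_offset -= size;
        }
    };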
diff --git a/src/cpu/x64/lrn/lrn_avx512_blocked_executor.hpp b/src/cpu/x64/lrn/lrn_avx512_blocked_executor.hpp
index c6d85b9cc63..e6a9677daea 100644
--- a/src/cpu/x64/lrn/lrn_avx512_blocked_executor.hpp
+++ b/src/cpu/x64/lrn/lrn_avx512_blocked_executor.hpp
@@ -84,6 +84,8 @@ class lrn_avx512_blocked_executor_fwd_t : public i_lrn_executor_t {
         const auto ws = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_WORKSPACE, status);
         CHECK(status);
 
+        auto MB = CTX_IN_BATCH(DNNL_ARG_SRC);
+
         const auto ker = ker_.get();
         const auto ker_first = ker_first_.get();
         const auto ker_last = ker_last_.get();
@@ -92,12 +94,12 @@ class lrn_avx512_blocked_executor_fwd_t : public i_lrn_executor_t {
             size_t start {0}, end {0};
             const int C16 = C_ / vsize_;
             const size_t work_amount
-                    = use_h_parallelism_ ? N_ * C16 * H_ : N_ * C16;
+                    = use_h_parallelism_ ? MB * C16 * H_ : MB * C16;
             balance211(work_amount, nthr, ithr, start, end);
             if (use_h_parallelism_) {
                 int n {0}, c16 {0}, h {0};
-                nd_iterator_init(start, n, N_, c16, C16, h, H_);
+                nd_iterator_init(start, n, MB, c16, C16, h, H_);
                 for (size_t iwork = start; iwork < end; ++iwork) {
                     const auto offset = n * C_ * H_ * W_
                             + c16 * H_ * W_ * vsize_ + h * W_ * vsize_;
@@ -120,11 +122,11 @@ class lrn_avx512_blocked_executor_fwd_t : public i_lrn_executor_t {
                         (*ker_last)(&args);
                     else
                         (*ker)(&args);
-                    nd_iterator_step(n, N_, c16, C16, h, H_);
+                    nd_iterator_step(n, MB, c16, C16, h, H_);
                 }
             } else {
                 int n {0}, c16 {0};
-                nd_iterator_init(start, n, N_, c16, C16);
+                nd_iterator_init(start, n, MB, c16, C16);
                 for (size_t iwork = start; iwork < end; ++iwork) {
                     const auto offset = n * C_ * H_ * W_
                             + c16 * H_ * W_ * vsize_;
@@ -148,7 +150,7 @@ class lrn_avx512_blocked_executor_fwd_t : public i_lrn_executor_t {
                     else
                         (*ker)(&args);
 
-                    nd_iterator_step(n, N_, c16, C16);
+                    nd_iterator_step(n, MB, c16, C16);
                 }
             }
         });
diff --git a/src/cpu/x64/lrn/lrn_avx512_nhwc_executor.hpp b/src/cpu/x64/lrn/lrn_avx512_nhwc_executor.hpp
index d4d5f519452..41a7c92335f 100644
--- a/src/cpu/x64/lrn/lrn_avx512_nhwc_executor.hpp
+++ b/src/cpu/x64/lrn/lrn_avx512_nhwc_executor.hpp
@@ -54,8 +54,10 @@ class lrn_avx512_nhwc_executor_fwd_t : public i_lrn_executor_t {
         const auto ws = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_WORKSPACE, status);
         CHECK(status);
 
+        auto MB = CTX_IN_BATCH(DNNL_ARG_SRC);
+
         const auto ker = ker_.get();
-        parallel_nd(N_, H_ * W_, [&](dim_t n, dim_t pixel_id) {
+        parallel_nd(MB, H_ * W_, [&](dim_t n, dim_t pixel_id) {
             typename lrn::jit_avx512_common_lrn_kernel_fwd_t<
                     d_type>::jit_args_fwd_t args;
             const auto offset = n * C_ * H_ * W_ + pixel_id * C_;
diff --git a/src/cpu/x64/matmul/brgemm_matmul.hpp b/src/cpu/x64/matmul/brgemm_matmul.hpp
index a3d5d7615d3..86547e9825e 100644
--- a/src/cpu/x64/matmul/brgemm_matmul.hpp
+++ b/src/cpu/x64/matmul/brgemm_matmul.hpp
@@ -62,7 +62,7 @@ struct brgemm_matmul_t : public primitive_t {
         using ::dnnl::impl::cpu::matmul::cpu_matmul_pd_t::cpu_matmul_pd_t;
 
         DECLARE_COMMON_PD_T(
-                JIT_IMPL_NAME_HELPER("brg:", isa, ""), brgemm_matmul_t);
+                JIT_IMPL_NAME_HELPER("brgemm:", isa, ""), brgemm_matmul_t);
 
         status_t init(engine_t *engine);
         int get_brg_kernel_idx(bool do_initialization, bool is_M_tail,
diff --git a/src/cpu/x64/prelu/jit_prelu_backward.cpp b/src/cpu/x64/prelu/jit_prelu_backward.cpp
index c0f4718633a..ff5243a4ccd 100644
--- a/src/cpu/x64/prelu/jit_prelu_backward.cpp
+++ b/src/cpu/x64/prelu/jit_prelu_backward.cpp
@@ -293,7 +293,7 @@ void jit_prelu_bwd_t::fill_scratchpad_zeros(float *const scratchpad,
     parallel(nthr, [&](std::size_t ithr, std::size_t) {
         float *scratchpad_ithr = scratchpad + ithr * thread_scratchpad_size;
 
-#if SAFE_TO_USE_OMP_SIMD
+#if defined(SAFE_TO_USE_OMP_SIMD)
         PRAGMA_OMP_SIMD()
         for (int i = 0; i < thread_scratchpad_size; i++)
             scratchpad_ithr[i] = 0.0f;
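On the prelu hunk just above: switching to defined() makes the guard independent of the macro's value and silences -Wundef builds when the macro is absent. Note the subtle difference that a macro defined as 0 would now enable the branch, so this assumes SAFE_TO_USE_OMP_SIMD is only ever defined when the pragma is actually safe. A minimal illustration (FLAG is a made-up macro):

    #include <cstdio>

    // With `#if FLAG`, an undefined FLAG is treated as 0 (and -Wundef warns);
    // with `#if defined(FLAG)`, only the macro's existence matters.
    int main() {
    #if defined(FLAG)
        std::puts("FLAG is defined");
    #else
        std::puts("FLAG is not defined");
    #endif
        return 0;
    }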
diff --git a/src/cpu/x64/xbyak/xbyak.h b/src/cpu/x64/xbyak/xbyak.h
index bb8234c7698..6c88004fc8c 100644
--- a/src/cpu/x64/xbyak/xbyak.h
+++ b/src/cpu/x64/xbyak/xbyak.h
@@ -312,22 +312,13 @@ inline const char *ConvertErrorToString(int err)
 #ifdef XBYAK_NO_EXCEPTION
 namespace local {
 
-inline int& GetErrorRef() {
-    static XBYAK_TLS int err = 0;
-    return err;
-}
-
-inline void SetError(int err) {
-    if (local::GetErrorRef()) return; // keep the first err code
-    local::GetErrorRef() = err;
-}
+static XBYAK_TLS int l_err = 0;
+inline void SetError(int err) { if (l_err == 0) l_err = err; } // keep the first err code
 
 } // local
 
-inline void ClearError() {
-    local::GetErrorRef() = 0;
-}
-inline int GetError() { return local::GetErrorRef(); }
+inline void ClearError() { local::l_err = 0; }
+inline int GetError() { return local::l_err; }
 
 #define XBYAK_THROW(err) { local::SetError(err); return; }
 #define XBYAK_THROW_RET(err, r) { local::SetError(err); return r; }
diff --git a/tests/gtests/test_eltwise.cpp b/tests/gtests/test_eltwise.cpp
index ebe87e9b073..edf670a2c86 100644
--- a/tests/gtests/test_eltwise.cpp
+++ b/tests/gtests/test_eltwise.cpp
@@ -113,7 +113,7 @@ T bounded_relu_bwd(T dd, T s, A alpha) {
 
 template <typename T>
 T soft_relu_fwd(T s) {
-    return s < (T)logf(FLT_MAX) ? T(log1pf(::expf(s))) : s;
+    return s < (T)logf(20.f) ? T(log1pf(::expf(s))) : s;
 }
 
 template