diff --git a/docs/finn/internals.rst b/docs/finn/internals.rst index 652c94ac24..a3d18bed77 100644 --- a/docs/finn/internals.rst +++ b/docs/finn/internals.rst @@ -311,6 +311,13 @@ Depending on the amount of parallelism requested, one of two implementation styl - 1 - default - depthwise-agnostic + * - < C + - 1 + - 1 + - 1 + - K + - parallel + - depthwise only * - C - 1 - 1 @@ -343,4 +350,4 @@ The RTL SWG is supported by the basic automatic folding algorithm in FINN (:py:m **MVAU:** Although it is recommended to unfold SIMD first, SIMD and PE can be set independently. Full (and balanced) parallelism is achieved by using the SWG in parallel window mode and setting MVAU SIMD and PE to their maximum values (SIMD = MW = C_in * K, PE = MH = C_out). -**VVAU:** While the VVAU HLS component supports SIMD unfolding independently from PE, the RTL SWG requires full unfolding across the channel dimension (SIMD of the SWG = PE of the VVAU) before enabling window-parallelism. Unlike the MVAU, the VVAU can't accept datawidth-converted input from a fully-parallel SWG in this case due to the depthwise data layout. As a result, the VVAU should be unfolded by PE first (up to PE = C), followed by SIMD (up to SIMD = K). +**VVAU:** The VVAU component supports SIMD unfolding (up to SIMD = K) independently from PE unfolding (up to PE = C). However, due to the depthwise data layout, it can't accept a datawidth-converted input from a fully-parallel SWG when PE is not fully unfolded. Therefore, SIMD of the SWG must be set equal to PE of the VVAU when window-parallelism is enabled. In this scenario, VVAU SIMD < K is supported via an automatically inserted DWC. 
diff --git a/finn-rtllib/swg/swg_common.sv b/finn-rtllib/swg/swg_common.sv index f2cdc333ca..c1d388550a 100644 --- a/finn-rtllib/swg/swg_common.sv +++ b/finn-rtllib/swg/swg_common.sv @@ -195,8 +195,7 @@ for (genvar e=0; e0; i--) - Data[i] <= Data[i-1]; + if (DEPTH > 1) Data[DEPTH-1:1] <= Data[DEPTH-2:0]; Data[0] <= shift_in; end end diff --git a/finn-rtllib/swg/swg_template_parallel.sv b/finn-rtllib/swg/swg_template_parallel.sv index 83a525ff36..b92f27b2ca 100644 --- a/finn-rtllib/swg/swg_template_parallel.sv +++ b/finn-rtllib/swg/swg_template_parallel.sv @@ -136,7 +136,6 @@ module $TOP_MODULE_NAME$_impl #( // counters/address registers logic signed [$clog2(LAST_READ_ELEM+1)+1-1:0] Newest_buffered_elem = -1; logic [$clog2(LAST_READ_ELEM+1)+1-1:0] Current_elem = FIRST_WRITE_ELEM; - logic [$clog2(LAST_READ_ELEM+1)+1-1:0] First_elem_next_window = 0; // control registers/signals logic Writing_done = 0; @@ -146,13 +145,7 @@ module $TOP_MODULE_NAME$_impl #( uwire write_blocked = write_cmd && !out_V_V_TREADY && !Write_done; uwire reading_done = Newest_buffered_elem == LAST_READ_ELEM; - uwire read_cmd = - !reading_done && ( // if there is still an input element left to read - Writing_done || ( // if writing is done (e.g. 
for skipped rows at FM end due to stride) - $signed(((Newest_buffered_elem - ($signed(BUF_ELEM_TOTAL) - 1)))) < $signed(First_elem_next_window) && - $signed(((Newest_buffered_elem - ($signed(BUF_ELEM_TOTAL) - 1)))) < $signed(Current_elem) - ) // (over-)write to buffer if oldest buffered element will no longer be needed - ); + uwire read_cmd = !reading_done && (Writing_done || Newest_buffered_elem <= $signed(Current_elem)); uwire read_ok = read_cmd && in0_V_V_TVALID && !write_blocked; // includes waiting on W if W-only cycle: wait only on W no R/W to wait for @@ -186,7 +179,6 @@ module $TOP_MODULE_NAME$_impl #( if(!ap_rst_n) begin Newest_buffered_elem <= -1; Current_elem <= FIRST_WRITE_ELEM; - First_elem_next_window <= 0; Writing_done <= 0; end else begin @@ -199,14 +191,11 @@ module $TOP_MODULE_NAME$_impl #( // todo: allow for read overlapping between feature maps (i.e., reading first elements from next FM while still writing last window of current FM) Newest_buffered_elem <= -1; Current_elem <= FIRST_WRITE_ELEM; - First_elem_next_window <= 0; Writing_done <= 0; end end if (write_ok) begin - First_elem_next_window <= First_elem_next_window + tail_incr; - // check if this is the last write cycle (Writing_done will be true afterwards) if (Current_elem == LAST_WRITE_ELEM) begin Writing_done <= 1; @@ -215,7 +204,6 @@ module $TOP_MODULE_NAME$_impl #( // start processing of next FM if reading is done already, or completes in the same cycle Newest_buffered_elem <= -1; Current_elem <= FIRST_WRITE_ELEM; - First_elem_next_window <= 0; Writing_done <= 0; end end diff --git a/src/finn/custom_op/fpgadataflow/convolutioninputgenerator_rtl.py b/src/finn/custom_op/fpgadataflow/convolutioninputgenerator_rtl.py index a55cdcc0be..734f75a973 100755 --- a/src/finn/custom_op/fpgadataflow/convolutioninputgenerator_rtl.py +++ b/src/finn/custom_op/fpgadataflow/convolutioninputgenerator_rtl.py @@ -237,12 +237,11 @@ def get_buffer_depth(self): mmv_in = 1 mmv_out = 1 channel_factor = 
int(ifm_ch / simd) - - # compute minimal buffer length (assuming it holds 1 complete window) - buffer_min_size = ((k_h - 1) * dilation_h * w + (k_w - 1) * dilation_w + 1) * channel_factor - impl_style = self.select_impl_style() if impl_style == "default": + buffer_min_size = ( + (k_h - 1) * dilation_h * w + (k_w - 1) * dilation_w + 1 + ) * channel_factor # add additional buffer space in case of stride > 1 # this minimizes cycle count as it allows an earlier pre-load of inputs buffer_depth = ( @@ -257,6 +256,9 @@ def get_buffer_depth(self): ) ) elif impl_style == "parallel": + buffer_min_size = ( + (k_h - 1) * dilation_h * w + (k_w - 1) * dilation_w + ) * channel_factor + 1 buffer_depth = buffer_min_size + 1 return buffer_depth @@ -691,7 +693,7 @@ def prepare_codegen_parallel(self): channel_factor = int(ifm_ch / simd) # compute minimal buffer length (assuming it holds 1 complete window) - buffer_min_size = ((k_h - 1) * dilation_h * w + (k_w - 1) * dilation_w + 1) * channel_factor + buffer_min_size = ((k_h - 1) * dilation_h * w + (k_w - 1) * dilation_w) * channel_factor + 1 buffer_actual_size = self.get_buffer_depth() code_gen_dict["$BUF_ELEM_TOTAL$"] = [str(buffer_actual_size)] @@ -710,38 +712,31 @@ def prepare_codegen_parallel(self): ] # re-use default controller loop structure - code_gen_dict["$IS_DEPTHWISE$"] = ["0"] loop_h_iterations = out_dim_h - loop_w_iterations = out_dim_w # now the innermost loop - loop_kh_iterations = 1 + loop_w_iterations = out_dim_w + loop_kh_iterations = channel_factor loop_kw_iterations = 1 loop_simd_iterations = 1 - if loop_w_iterations == 1: - code_gen_dict["$INNERMOST_STATE$"] = ["STATE_LOOP_H"] - loop_h_iterations -= 1 # -1 because state is initial state + if loop_kh_iterations == 1: + if loop_w_iterations == 1: + code_gen_dict["$INNERMOST_STATE$"] = ["STATE_LOOP_H"] + loop_h_iterations -= 1 # -1 because state is initial state + else: + code_gen_dict["$INNERMOST_STATE$"] = ["STATE_LOOP_W"] + loop_w_iterations -= 1 # -1 because 
state is initial state else: - code_gen_dict["$INNERMOST_STATE$"] = ["STATE_LOOP_W"] - loop_w_iterations -= 1 # -1 because state is initial state - - # set head and tail address increment values - addr_incr_end_window = -buffer_min_size + stride_w * channel_factor + 1 - addr_incr_end_row = ( - -buffer_min_size - + ((skip_columns + kernel_width) * channel_factor) # remaining line - + ((stride_h - 1) * w * channel_factor) # skip lines - + 1 - ) - - tail_incr_w = addr_incr_end_window + buffer_min_size - 1 - tail_incr_h = addr_incr_end_row + buffer_min_size - 1 - tail_incr_last_window = stride_w + code_gen_dict["$INNERMOST_STATE$"] = ["STATE_LOOP_KH"] + loop_kh_iterations -= 1 # -1 because state is initial state + # set head address increment values addr_incr_end_simd = 1 addr_incr_end_window_elem = 1 addr_incr_end_window_row = 1 - addr_incr_end_window = tail_incr_w - addr_incr_end_row = tail_incr_h + addr_incr_end_window = (stride_w - 1) * channel_factor + 1 + addr_incr_end_row = ((skip_columns + (kernel_width - 1)) * channel_factor + 1) + ( + (stride_h - 1) * w * channel_factor + ) # add init value for CURRENT_ELEM counter = last elem of first window code_gen_dict["$FIRST_WRITE_ELEM$"] = [str(buffer_min_size - 1)] @@ -772,9 +767,6 @@ def prepare_codegen_parallel(self): abs(addr_incr_end_window_row) + 1, abs(addr_incr_end_window) + 1, abs(addr_incr_end_row) + 1, - abs(tail_incr_w) + 1, - abs(tail_incr_h) + 1, - abs(tail_incr_last_window) + 1, ) ) ) @@ -784,9 +776,11 @@ def prepare_codegen_parallel(self): code_gen_dict["$HEAD_INCR_KH$"] = [str(addr_incr_end_window_row)] code_gen_dict["$HEAD_INCR_W$"] = [str(addr_incr_end_window)] code_gen_dict["$HEAD_INCR_H$"] = [str(addr_incr_end_row)] - code_gen_dict["$TAIL_INCR_W$"] = [str(tail_incr_w)] - code_gen_dict["$TAIL_INCR_H$"] = [str(tail_incr_h)] - code_gen_dict["$TAIL_INCR_LAST$"] = [str(tail_incr_last_window)] + # not used, set to zero: + code_gen_dict["$TAIL_INCR_W$"] = ["0"] + code_gen_dict["$TAIL_INCR_H$"] = ["0"] + 
code_gen_dict["$TAIL_INCR_LAST$"] = ["0"] + code_gen_dict["$IS_DEPTHWISE$"] = ["0"] code_gen_dict["$SIMD$"] = [str(simd)] code_gen_dict["$MMV_IN$"] = [str(mmv_in)] @@ -810,15 +804,21 @@ def prepare_codegen_parallel(self): for ky in range(k_h): reg_fifo = [] for kx in range(k_w): - reg_fifo.append(px_idx) - px_idx += 1 + for c in range(channel_factor): + if c < (channel_factor - 1): + if not (ky == 0 and kx == 0): + reg_fifo.append(-1) + px_idx += 1 + else: + reg_fifo.append(px_idx) + px_idx += 1 if kx < (k_w - 1): - reg_fifo.extend([-1] * (dilation_w - 1)) - px_idx += dilation_w - 1 + reg_fifo.extend([-1] * ((dilation_w - 1) * channel_factor)) + px_idx += (dilation_w - 1) * channel_factor reg_fifos.append(reg_fifo) if ky < (k_h - 1): - line_buffer_len = (w - kernel_width) + w * (dilation_h - 1) + line_buffer_len = ((w - kernel_width) + w * (dilation_h - 1)) * channel_factor bram_fifos_depth.append(line_buffer_len) px_idx += line_buffer_len @@ -926,6 +926,7 @@ def select_impl_style(self): """Selects implementation style based on folding configuration.""" simd = self.get_nodeattr("SIMD") M = self.get_nodeattr("M") + depthwise = self.get_nodeattr("depthwise") ifm_ch = self.get_nodeattr("IFMChannels") ifm_dim = self.get_nodeattr("IFMDim") stride = self.get_nodeattr("Stride") @@ -950,7 +951,6 @@ def select_impl_style(self): if self.get_nodeattr("parallel_window"): # mmv_in = M * 1 mmv_out = M * k_h * k_w - assert ifm_ch == simd, "Constraint violated: SIMD must be equal to IFMChannels" else: # mmv_in = 1 mmv_out = 1 @@ -959,7 +959,12 @@ def select_impl_style(self): # choose implementation style if mmv_out > 1 or (k_h == 1 and k_w == 1): impl_style = "parallel" - assert ifm_ch == simd, "Constraint violated: SIMD must be equal to IFMChannels" + if depthwise or (k_h == 1 and k_w == 1): + # allow SIMD < IFM_CH in depthwise mode (VVAU supports the resulting data layout) + # also allowed for 1x1 kernel since depthwise and non-depthwise are equivalent + assert ifm_ch % simd == 
0, "Constraint violated: SIMD must divide IFMChannels" + else: + assert ifm_ch == simd, "Constraint violated: SIMD must be equal to IFMChannels" else: impl_style = "default" diff --git a/tests/fpgadataflow/test_fpgadataflow_convinputgenerator_rtl.py b/tests/fpgadataflow/test_fpgadataflow_convinputgenerator_rtl.py index 53d7be0ebb..62b7abe536 100755 --- a/tests/fpgadataflow/test_fpgadataflow_convinputgenerator_rtl.py +++ b/tests/fpgadataflow/test_fpgadataflow_convinputgenerator_rtl.py @@ -192,12 +192,10 @@ def test_fpgadataflow_slidingwindow_rtl( pytest.skip("Illegal convolution configuration: kernel or stride > FM dimension") if (k_h == 1 and dilation_h != 1) or (k_w == 1 and dilation_w != 1): pytest.skip("Illegal convolution configuration: dilation for unitary kernel dim") - if (stride_h > k_h) or (stride_w > k_w) and not parallel_window: + if ((stride_h > k_h) or (stride_w > k_w)) and not (parallel_window or (k_h == 1 and k_w == 1)): pytest.skip("Not all combinations for stride > k edge case supported in default mode") - if k_h == 1 and k_w == 1 and simd != ifm_ch: - pytest.skip("1x1 Kernel only supported in parallel mode (SIMD=C)") - if parallel_window and simd != ifm_ch: - pytest.skip("Parallel window requires SIMD=C") + if parallel_window and simd != ifm_ch and not (dw or (k_h == 1 and k_w == 1)): + pytest.skip("Parallel window requires SIMD=C for non-depthwise case") ofm_dim_h = compute_conv_output_dim(ifm_dim_h, k_h, stride_h, 0, dilation_h) ofm_dim_w = compute_conv_output_dim(ifm_dim_w, k_w, stride_w, 0, dilation_w)