Skip to content

Commit

Permalink
Merge pull request #922 from fpjentzsch/feature/swg_reordering
Browse files Browse the repository at this point in the history
[RTL SWG] Support SIMD < C in window-parallel mode
  • Loading branch information
auphelia authored Jan 8, 2024
2 parents e9985e6 + eddbd27 commit e3cb226
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 62 deletions.
9 changes: 8 additions & 1 deletion docs/finn/internals.rst
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,13 @@ Depending on the amount of parallelism requested, one of two implementation styl
- 1
- default
- depthwise-agnostic
* - < C
- 1
- 1
- 1
- K
- parallel
- depthwise only
* - C
- 1
- 1
Expand Down Expand Up @@ -343,4 +350,4 @@ The RTL SWG is supported by the basic automatic folding algorithm in FINN (:py:m

**MVAU:** Although it is recommended to unfold SIMD first, SIMD and PE can be set independently. Full (and balanced) parallelism is achieved by using the SWG in parallel window mode and setting MVAU SIMD and PE to their maximum values (SIMD = MW = C_in * K, PE = MH = C_out).

**VVAU:** While the VVAU HLS component supports SIMD unfolding independently from PE, the RTL SWG requires full unfolding across the channel dimension (SIMD of the SWG = PE of the VVAU) before enabling window-parallelism. Unlike the MVAU, the VVAU can't accept datawidth-converted input from a fully-parallel SWG in this case due to the depthwise data layout. As a result, the VVAU should be unfolded by PE first (up to PE = C), followed by SIMD (up to SIMD = K).
**VVAU:** The VVAU component supports SIMD unfolding (up to SIMD = K) independently from PE unfolding (up to PE = C). However, due to the depthwise data layout, it cannot accept a datawidth-converted input from a fully-parallel SWG when PE is not fully unfolded. Therefore, the SWG's SIMD must be set equal to the VVAU's PE whenever window-parallelism is enabled. In this scenario, VVAU SIMD < K is still supported via an automatically inserted DWC.
3 changes: 1 addition & 2 deletions finn-rtllib/swg/swg_common.sv
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,7 @@ for (genvar e=0; e<DEPTH; e++)

always @ (posedge clk) begin
if (shift_enable) begin
for (int i=DEPTH-1; i>0; i--)
Data[i] <= Data[i-1];
if (DEPTH > 1) Data[DEPTH-1:1] <= Data[DEPTH-2:0];
Data[0] <= shift_in;
end
end
Expand Down
14 changes: 1 addition & 13 deletions finn-rtllib/swg/swg_template_parallel.sv
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ module $TOP_MODULE_NAME$_impl #(
// counters/address registers
logic signed [$clog2(LAST_READ_ELEM+1)+1-1:0] Newest_buffered_elem = -1;
logic [$clog2(LAST_READ_ELEM+1)+1-1:0] Current_elem = FIRST_WRITE_ELEM;
logic [$clog2(LAST_READ_ELEM+1)+1-1:0] First_elem_next_window = 0;

// control registers/signals
logic Writing_done = 0;
Expand All @@ -146,13 +145,7 @@ module $TOP_MODULE_NAME$_impl #(
uwire write_blocked = write_cmd && !out_V_V_TREADY && !Write_done;

uwire reading_done = Newest_buffered_elem == LAST_READ_ELEM;
uwire read_cmd =
!reading_done && ( // if there is still an input element left to read
Writing_done || ( // if writing is done (e.g. for skipped rows at FM end due to stride)
$signed(((Newest_buffered_elem - ($signed(BUF_ELEM_TOTAL) - 1)))) < $signed(First_elem_next_window) &&
$signed(((Newest_buffered_elem - ($signed(BUF_ELEM_TOTAL) - 1)))) < $signed(Current_elem)
) // (over-)write to buffer if oldest buffered element will no longer be needed
);
uwire read_cmd = !reading_done && (Writing_done || Newest_buffered_elem <= $signed(Current_elem));
uwire read_ok = read_cmd && in0_V_V_TVALID && !write_blocked;

// includes waiting on W if W-only cycle: wait only on W no R/W to wait for
Expand Down Expand Up @@ -186,7 +179,6 @@ module $TOP_MODULE_NAME$_impl #(
if(!ap_rst_n) begin
Newest_buffered_elem <= -1;
Current_elem <= FIRST_WRITE_ELEM;
First_elem_next_window <= 0;
Writing_done <= 0;
end
else begin
Expand All @@ -199,14 +191,11 @@ module $TOP_MODULE_NAME$_impl #(
// todo: allow for read overlapping between feature maps (i.e., reading first elements from next FM while still writing last window of current FM)
Newest_buffered_elem <= -1;
Current_elem <= FIRST_WRITE_ELEM;
First_elem_next_window <= 0;
Writing_done <= 0;
end
end

if (write_ok) begin
First_elem_next_window <= First_elem_next_window + tail_incr;

// check if this is the last write cycle (Writing_done will be true afterwards)
if (Current_elem == LAST_WRITE_ELEM) begin
Writing_done <= 1;
Expand All @@ -215,7 +204,6 @@ module $TOP_MODULE_NAME$_impl #(
// start processing of next FM if reading is done already, or completes in the same cycle
Newest_buffered_elem <= -1;
Current_elem <= FIRST_WRITE_ELEM;
First_elem_next_window <= 0;
Writing_done <= 0;
end
end
Expand Down
87 changes: 46 additions & 41 deletions src/finn/custom_op/fpgadataflow/convolutioninputgenerator_rtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,12 +237,11 @@ def get_buffer_depth(self):
mmv_in = 1
mmv_out = 1
channel_factor = int(ifm_ch / simd)

# compute minimal buffer length (assuming it holds 1 complete window)
buffer_min_size = ((k_h - 1) * dilation_h * w + (k_w - 1) * dilation_w + 1) * channel_factor

impl_style = self.select_impl_style()
if impl_style == "default":
buffer_min_size = (
(k_h - 1) * dilation_h * w + (k_w - 1) * dilation_w + 1
) * channel_factor
# add additional buffer space in case of stride > 1
# this minimizes cycle count as it allows an earlier pre-load of inputs
buffer_depth = (
Expand All @@ -257,6 +256,9 @@ def get_buffer_depth(self):
)
)
elif impl_style == "parallel":
buffer_min_size = (
(k_h - 1) * dilation_h * w + (k_w - 1) * dilation_w
) * channel_factor + 1
buffer_depth = buffer_min_size + 1
return buffer_depth

Expand Down Expand Up @@ -691,7 +693,7 @@ def prepare_codegen_parallel(self):
channel_factor = int(ifm_ch / simd)

# compute minimal buffer length (assuming it holds 1 complete window)
buffer_min_size = ((k_h - 1) * dilation_h * w + (k_w - 1) * dilation_w + 1) * channel_factor
buffer_min_size = ((k_h - 1) * dilation_h * w + (k_w - 1) * dilation_w) * channel_factor + 1

buffer_actual_size = self.get_buffer_depth()
code_gen_dict["$BUF_ELEM_TOTAL$"] = [str(buffer_actual_size)]
Expand All @@ -710,38 +712,31 @@ def prepare_codegen_parallel(self):
]

# re-use default controller loop structure
code_gen_dict["$IS_DEPTHWISE$"] = ["0"]
loop_h_iterations = out_dim_h
loop_w_iterations = out_dim_w # now the innermost loop
loop_kh_iterations = 1
loop_w_iterations = out_dim_w
loop_kh_iterations = channel_factor
loop_kw_iterations = 1
loop_simd_iterations = 1

if loop_w_iterations == 1:
code_gen_dict["$INNERMOST_STATE$"] = ["STATE_LOOP_H"]
loop_h_iterations -= 1 # -1 because state is initial state
if loop_kh_iterations == 1:
if loop_w_iterations == 1:
code_gen_dict["$INNERMOST_STATE$"] = ["STATE_LOOP_H"]
loop_h_iterations -= 1 # -1 because state is initial state
else:
code_gen_dict["$INNERMOST_STATE$"] = ["STATE_LOOP_W"]
loop_w_iterations -= 1 # -1 because state is initial state
else:
code_gen_dict["$INNERMOST_STATE$"] = ["STATE_LOOP_W"]
loop_w_iterations -= 1 # -1 because state is initial state

# set head and tail address increment values
addr_incr_end_window = -buffer_min_size + stride_w * channel_factor + 1
addr_incr_end_row = (
-buffer_min_size
+ ((skip_columns + kernel_width) * channel_factor) # remaining line
+ ((stride_h - 1) * w * channel_factor) # skip lines
+ 1
)

tail_incr_w = addr_incr_end_window + buffer_min_size - 1
tail_incr_h = addr_incr_end_row + buffer_min_size - 1
tail_incr_last_window = stride_w
code_gen_dict["$INNERMOST_STATE$"] = ["STATE_LOOP_KH"]
loop_kh_iterations -= 1 # -1 because state is initial state

# set head address increment values
addr_incr_end_simd = 1
addr_incr_end_window_elem = 1
addr_incr_end_window_row = 1
addr_incr_end_window = tail_incr_w
addr_incr_end_row = tail_incr_h
addr_incr_end_window = (stride_w - 1) * channel_factor + 1
addr_incr_end_row = ((skip_columns + (kernel_width - 1)) * channel_factor + 1) + (
(stride_h - 1) * w * channel_factor
)

# add init value for CURRENT_ELEM counter = last elem of first window
code_gen_dict["$FIRST_WRITE_ELEM$"] = [str(buffer_min_size - 1)]
Expand Down Expand Up @@ -772,9 +767,6 @@ def prepare_codegen_parallel(self):
abs(addr_incr_end_window_row) + 1,
abs(addr_incr_end_window) + 1,
abs(addr_incr_end_row) + 1,
abs(tail_incr_w) + 1,
abs(tail_incr_h) + 1,
abs(tail_incr_last_window) + 1,
)
)
)
Expand All @@ -784,9 +776,11 @@ def prepare_codegen_parallel(self):
code_gen_dict["$HEAD_INCR_KH$"] = [str(addr_incr_end_window_row)]
code_gen_dict["$HEAD_INCR_W$"] = [str(addr_incr_end_window)]
code_gen_dict["$HEAD_INCR_H$"] = [str(addr_incr_end_row)]
code_gen_dict["$TAIL_INCR_W$"] = [str(tail_incr_w)]
code_gen_dict["$TAIL_INCR_H$"] = [str(tail_incr_h)]
code_gen_dict["$TAIL_INCR_LAST$"] = [str(tail_incr_last_window)]
# not used, set to zero:
code_gen_dict["$TAIL_INCR_W$"] = ["0"]
code_gen_dict["$TAIL_INCR_H$"] = ["0"]
code_gen_dict["$TAIL_INCR_LAST$"] = ["0"]
code_gen_dict["$IS_DEPTHWISE$"] = ["0"]

code_gen_dict["$SIMD$"] = [str(simd)]
code_gen_dict["$MMV_IN$"] = [str(mmv_in)]
Expand All @@ -810,15 +804,21 @@ def prepare_codegen_parallel(self):
for ky in range(k_h):
reg_fifo = []
for kx in range(k_w):
reg_fifo.append(px_idx)
px_idx += 1
for c in range(channel_factor):
if c < (channel_factor - 1):
if not (ky == 0 and kx == 0):
reg_fifo.append(-1)
px_idx += 1
else:
reg_fifo.append(px_idx)
px_idx += 1
if kx < (k_w - 1):
reg_fifo.extend([-1] * (dilation_w - 1))
px_idx += dilation_w - 1
reg_fifo.extend([-1] * ((dilation_w - 1) * channel_factor))
px_idx += (dilation_w - 1) * channel_factor
reg_fifos.append(reg_fifo)

if ky < (k_h - 1):
line_buffer_len = (w - kernel_width) + w * (dilation_h - 1)
line_buffer_len = ((w - kernel_width) + w * (dilation_h - 1)) * channel_factor
bram_fifos_depth.append(line_buffer_len)
px_idx += line_buffer_len

Expand Down Expand Up @@ -926,6 +926,7 @@ def select_impl_style(self):
"""Selects implementation style based on folding configuration."""
simd = self.get_nodeattr("SIMD")
M = self.get_nodeattr("M")
depthwise = self.get_nodeattr("depthwise")
ifm_ch = self.get_nodeattr("IFMChannels")
ifm_dim = self.get_nodeattr("IFMDim")
stride = self.get_nodeattr("Stride")
Expand All @@ -950,7 +951,6 @@ def select_impl_style(self):
if self.get_nodeattr("parallel_window"):
# mmv_in = M * 1
mmv_out = M * k_h * k_w
assert ifm_ch == simd, "Constraint violated: SIMD must be equal to IFMChannels"
else:
# mmv_in = 1
mmv_out = 1
Expand All @@ -959,7 +959,12 @@ def select_impl_style(self):
# choose implementation style
if mmv_out > 1 or (k_h == 1 and k_w == 1):
impl_style = "parallel"
assert ifm_ch == simd, "Constraint violated: SIMD must be equal to IFMChannels"
if depthwise or (k_h == 1 and k_w == 1):
# allow SIMD < IFM_CH in depthwise mode (VVAU supports the resulting data layout)
# also allowed for 1x1 kernel since depthwise and non-depthwise are equivalent
assert ifm_ch % simd == 0, "Constraint violated: SIMD must divide IFMChannels"
else:
assert ifm_ch == simd, "Constraint violated: SIMD must be equal to IFMChannels"
else:
impl_style = "default"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,12 +192,10 @@ def test_fpgadataflow_slidingwindow_rtl(
pytest.skip("Illegal convolution configuration: kernel or stride > FM dimension")
if (k_h == 1 and dilation_h != 1) or (k_w == 1 and dilation_w != 1):
pytest.skip("Illegal convolution configuration: dilation for unitary kernel dim")
if (stride_h > k_h) or (stride_w > k_w) and not parallel_window:
if ((stride_h > k_h) or (stride_w > k_w)) and not (parallel_window or (k_h == 1 and k_w == 1)):
pytest.skip("Not all combinations for stride > k edge case supported in default mode")
if k_h == 1 and k_w == 1 and simd != ifm_ch:
pytest.skip("1x1 Kernel only supported in parallel mode (SIMD=C)")
if parallel_window and simd != ifm_ch:
pytest.skip("Parallel window requires SIMD=C")
if parallel_window and simd != ifm_ch and not (dw or (k_h == 1 and k_w == 1)):
pytest.skip("Parallel window requires SIMD=C for non-depthwise case")

ofm_dim_h = compute_conv_output_dim(ifm_dim_h, k_h, stride_h, 0, dilation_h)
ofm_dim_w = compute_conv_output_dim(ifm_dim_w, k_w, stride_w, 0, dilation_w)
Expand Down

0 comments on commit e3cb226

Please sign in to comment.