Skip to content

Commit

Permalink
[rtl thresh] Rename top module and fix intf names
Browse files Browse the repository at this point in the history
  • Loading branch information
auphelia committed Nov 20, 2024
1 parent a44237b commit c75678b
Showing 1 changed file with 31 additions and 43 deletions.
74 changes: 31 additions & 43 deletions src/finn/custom_op/fpgadataflow/rtl/thresholding_rtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,12 @@
import numpy as np
import os
import shutil
from pyverilator.util.axi_utils import reset_rtlsim, rtlsim_multi_io
from qonnx.core.datatype import DataType
from qonnx.util.basic import roundup_to_integer_multiple

from finn.custom_op.fpgadataflow.rtlbackend import RTLBackend
from finn.custom_op.fpgadataflow.thresholding import Thresholding
from finn.util.basic import (
get_memutil_alternatives,
mem_primitives_versal,
pyverilate_get_liveness_threshold_cycles,
)
from finn.util.basic import get_memutil_alternatives, mem_primitives_versal
from finn.util.data_packing import (
npy_to_rtlsim_input,
pack_innermost_dim_as_hex_string,
Expand Down Expand Up @@ -243,9 +238,7 @@ def prepare_codegen_rtl_values(self, model):
code_gen_dict["$THRESHOLDS_PATH$"] = ['"./%s_"' % self.onnx_node.name]

# Identify the module name
code_gen_dict["$MODULE_NAME_AXI_WRAPPER$"] = [
self.get_verilog_top_module_name() + "_axi_wrapper"
]
code_gen_dict["$MODULE_NAME_AXI_WRAPPER$"] = [self.get_verilog_top_module_name()]
# Set the top module name - AXI wrapper
code_gen_dict["$TOP_MODULE$"] = code_gen_dict["$MODULE_NAME_AXI_WRAPPER$"]

Expand Down Expand Up @@ -287,14 +280,22 @@ def prepare_codegen_rtl_values(self, model):
code_gen_dict["$DEEP_PIPELINE$"] = [str(deep_pipeline)]
return code_gen_dict

def get_rtl_file_list(self):
def get_rtl_file_list(self, abspath=False):
"""Thresholding binary search RTL file list"""
return [
"axilite_if.v",
"thresholding.sv",
"thresholding_axi.sv",
self.get_nodeattr("gen_top_module") + ".v",
if abspath:
code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") + "/"
rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/thresholding/hdl/")
else:
code_gen_dir = ""
rtllib_dir = ""

verilog_files = [
rtllib_dir + "axilite_if.v",
rtllib_dir + "thresholding.sv",
rtllib_dir + "thresholding_axi.sv",
code_gen_dir + self.get_nodeattr("gen_top_module") + ".v",
]
return verilog_files

def generate_hdl(self, model, fpgapart, clk):
"""Prepare HDL files from templates for synthesis"""
Expand Down Expand Up @@ -373,38 +374,23 @@ def execute_node(self, context, graph):
# Create a PyVerilator wrapper of the RTLSim .so
sim = self.get_rtlsim()
nbits = self.get_instream_width()
inp = npy_to_rtlsim_input("{}/input_0.npy".format(code_gen_dir), export_idt, nbits)
io_names = self.get_verilog_top_module_intf_names()
istream_name = io_names["s_axis"][0][0]
ostream_name = io_names["m_axis"][0][0]
rtlsim_inp = npy_to_rtlsim_input(
"{}/input_0.npy".format(code_gen_dir), export_idt, nbits
)
io_dict = {
"inputs": {istream_name: inp},
"outputs": {ostream_name: []},
"inputs": {"in0": rtlsim_inp},
"outputs": {"out": []},
}

trace_file = self.get_nodeattr("rtlsim_trace")
if trace_file == "default":
trace_file = self.onnx_node.name + ".vcd"
sname = "_"

# Change into so directory to ensure threshold files can be found
rtlsim_so = self.get_nodeattr("rtlsim_so")
so_dir = os.path.dirname(os.path.realpath(rtlsim_so))
olcwd = os.getcwd()
os.chdir(so_dir)
num_out_values = self.get_number_output_values()
reset_rtlsim(sim)
total_cycle_count = rtlsim_multi_io(
sim,
io_dict,
num_out_values,
trace_file=trace_file,
sname=sname,
liveness_threshold=pyverilate_get_liveness_threshold_cycles(),
)
self.set_nodeattr("cycles_rtlsim", total_cycle_count)
os.chdir(olcwd)
output = io_dict["outputs"][ostream_name]

super().reset_rtlsim(sim)
if self.get_nodeattr("rtlsim_backend") == "pyverilator":
super().toggle_clk(sim)
self.rtlsim_multi_io(sim, io_dict)
super().close_rtlsim(sim)
rtlsim_output = io_dict["outputs"]["out"]

# Manage output data
odt = self.get_output_datatype()
Expand All @@ -413,7 +399,9 @@ def execute_node(self, context, graph):
out_npy_path = "{}/output.npy".format(code_gen_dir)
out_shape = self.get_folded_output_shape()

rtlsim_output_to_npy(output, out_npy_path, odt, out_shape, packed_bits, target_bits)
rtlsim_output_to_npy(
rtlsim_output, out_npy_path, odt, out_shape, packed_bits, target_bits
)

# load and reshape output
output = np.load(out_npy_path)
Expand Down

0 comments on commit c75678b

Please sign in to comment.