Add the cuda.core.experimental.Linker class #229
Open
ksimpson-work wants to merge 5 commits into main from ksimpson/cuda_core_linker_155
5 commits
81086e0 merge main and update the options class (ksimpson-work)
c467550 address some review comments (ksimpson-work)
e982657 address some review comments (ksimpson-work)
275eb71 remove the add_code_object public entry and update tests (ksimpson-work)
fc11337 improve the comment on the LinkerOptions class (ksimpson-work)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,282 @@ | ||
from cuda.core.experimental._module import ObjectCode | ||
from cuda.core.experimental._utils import check_or_create_options | ||
from dataclasses import dataclass | ||
from typing import Optional, List | ||
from cuda.bindings import nvjitlink | ||
|
||
|
||
@dataclass | ||
class LinkerOptions: | ||
"""Customizable :obj:`LinkerOptions` for nvJitLink. | ||
|
||
Attributes | ||
---------- | ||
arch : str | ||
Pass SM architecture value. Can use compute_<N> value instead if only generating PTX. | ||
This is a required option. | ||
Acceptable value type: str | ||
Maps to: -arch=sm_<N> | ||
max_register_count : int, optional | ||
Maximum register count. | ||
Default: None | ||
Acceptable value type: int | ||
Maps to: -maxrregcount=<N> | ||
time : bool, optional | ||
Print timing information to InfoLog. | ||
Default: False | ||
Acceptable value type: bool | ||
Maps to: -time | ||
verbose : bool, optional | ||
Print verbose messages to InfoLog. | ||
Default: False | ||
Acceptable value type: bool | ||
Maps to: -verbose | ||
link_time_optimization : bool, optional | ||
Perform link time optimization. | ||
Default: False | ||
Acceptable value type: bool | ||
Maps to: -lto | ||
ptx : bool, optional | ||
Emit PTX after linking instead of CUBIN; only supported with -lto. | ||
Default: False | ||
Acceptable value type: bool | ||
Maps to: -ptx | ||
optimization_level : int, optional | ||
Set optimization level. Only 0 and 3 are accepted. | ||
Default: None | ||
Acceptable value type: int | ||
Maps to: -O<N> | ||
debug : bool, optional | ||
Generate debug information. | ||
Default: False | ||
Acceptable value type: bool | ||
Maps to: -g | ||
lineinfo : bool, optional | ||
Generate line information. | ||
Default: False | ||
Acceptable value type: bool | ||
Maps to: -lineinfo | ||
ftz : bool, optional | ||
Flush denormal values to zero. | ||
Default: False | ||
Acceptable value type: bool | ||
Maps to: -ftz=<n> | ||
prec_div : bool, optional | ||
Use precise division. | ||
Default: True | ||
Acceptable value type: bool | ||
Maps to: -prec-div=<n> | ||
prec_sqrt : bool, optional | ||
Use precise square root. | ||
Default: True | ||
Acceptable value type: bool | ||
Maps to: -prec-sqrt=<n> | ||
fma : bool, optional | ||
Use fast multiply-add. | ||
Default: True | ||
Acceptable value type: bool | ||
Maps to: -fma=<n> | ||
kernels_used : List[str], optional | ||
Pass list of kernels that are used; any not in the list can be removed. This option can be specified multiple times. | ||
Default: None | ||
Acceptable value type: list of str | ||
Maps to: -kernels-used=<name> | ||
variables_used : List[str], optional | ||
Pass list of variables that are used; any not in the list can be removed. This option can be specified multiple times. | ||
Default: None | ||
Acceptable value type: list of str | ||
Maps to: -variables-used=<name> | ||
optimize_unused_variables : bool, optional | ||
Assume that if a variable is not referenced in device code, it can be removed. | ||
Default: False | ||
Acceptable value type: bool | ||
Maps to: -optimize-unused-variables | ||
xptxas : List[str], optional | ||
Pass options to PTXAS. This option can be specified multiple times. | ||
Default: None | ||
Acceptable value type: list of str | ||
Maps to: -Xptxas=<opt> | ||
split_compile : int, optional | ||
Split compilation maximum thread count. Use 0 to use all available processors. Value of 1 disables split compilation (default). | ||
Default: 1 | ||
Acceptable value type: int | ||
Maps to: -split-compile=<N> | ||
split_compile_extended : int, optional | ||
A more aggressive form of split compilation available in LTO mode only. Accepts a maximum thread count value. Use 0 to use all available processors. Value of 1 disables extended split compilation (default). Note: This option can potentially impact performance of the compiled binary. | ||
Default: 1 | ||
Acceptable value type: int | ||
Maps to: -split-compile-extended=<N> | ||
jump_table_density : int, optional | ||
When doing LTO, specify the case density percentage in switch statements, and use it as a minimal threshold to determine whether jump table (brx.idx instruction) will be used to implement a switch statement. Default value is 101. The percentage ranges from 0 to 101 inclusively. | ||
Default: 101 | ||
Acceptable value type: int | ||
Maps to: -jump-table-density=<N> | ||
no_cache : bool, optional | ||
Do not cache the intermediate steps of nvJitLink. | ||
Default: False | ||
Acceptable value type: bool | ||
Maps to: -no-cache | ||
device_stack_protector : bool, optional | ||
Enable stack canaries in device code. Stack canaries make it more difficult to exploit certain types of memory safety bugs involving stack-local variables. The compiler uses heuristics to assess the risk of such a bug in each function. Only those functions which are deemed high-risk make use of a stack canary. | ||
Default: False | ||
Acceptable value type: bool | ||
Maps to: -device-stack-protector | ||
""" | ||
arch: str | ||
max_register_count: Optional[int] = None | ||
time: Optional[bool] = None | ||
verbose: Optional[bool] = None | ||
link_time_optimization: Optional[bool] = None | ||
ptx: Optional[bool] = None | ||
optimization_level: Optional[int] = None | ||
debug: Optional[bool] = None | ||
lineinfo: Optional[bool] = None | ||
ftz: Optional[bool] = None | ||
prec_div: Optional[bool] = None | ||
prec_sqrt: Optional[bool] = None | ||
fma: Optional[bool] = None | ||
kernels_used: Optional[List[str]] = None | ||
variables_used: Optional[List[str]] = None | ||
optimize_unused_variables: Optional[bool] = None | ||
xptxas: Optional[List[str]] = None | ||
split_compile: Optional[int] = None | ||
split_compile_extended: Optional[int] = None | ||
jump_table_density: Optional[int] = None | ||
no_cache: Optional[bool] = None | ||
device_stack_protector: Optional[bool] = None | ||
|
||
def __post_init__(self): | ||
self.formatted_options = [] | ||
if self.arch is not None: | ||
self.formatted_options.append(f"-arch={self.arch}") | ||
if self.max_register_count is not None: | ||
self.formatted_options.append(f"-maxrregcount={self.max_register_count}") | ||
if self.time: | ||
self.formatted_options.append("-time") | ||
if self.verbose: | ||
self.formatted_options.append("-verbose") | ||
if self.link_time_optimization: | ||
self.formatted_options.append("-lto") | ||
if self.ptx: | ||
self.formatted_options.append("-ptx") | ||
if self.optimization_level is not None: | ||
self.formatted_options.append(f"-O{self.optimization_level}") | ||
if self.debug: | ||
self.formatted_options.append("-g") | ||
if self.lineinfo: | ||
self.formatted_options.append("-lineinfo") | ||
if self.ftz is not None: | ||
self.formatted_options.append(f"-ftz={'true' if self.ftz else 'false'}") | ||
if self.prec_div is not None: | ||
self.formatted_options.append(f"-prec-div={'true' if self.prec_div else 'false'}") | ||
if self.prec_sqrt is not None: | ||
self.formatted_options.append(f"-prec-sqrt={'true' if self.prec_sqrt else 'false'}") | ||
if self.fma is not None: | ||
self.formatted_options.append(f"-fma={'true' if self.fma else 'false'}") | ||
if self.kernels_used is not None: | ||
for kernel in self.kernels_used: | ||
self.formatted_options.append(f"-kernels-used={kernel}") | ||
if self.variables_used is not None: | ||
for variable in self.variables_used: | ||
self.formatted_options.append(f"-variables-used={variable}") | ||
if self.optimize_unused_variables: | ||
self.formatted_options.append("-optimize-unused-variables") | ||
if self.xptxas is not None: | ||
for opt in self.xptxas: | ||
self.formatted_options.append(f"-Xptxas={opt}") | ||
if self.split_compile is not None: | ||
self.formatted_options.append(f"-split-compile={self.split_compile}") | ||
if self.split_compile_extended is not None: | ||
self.formatted_options.append(f"-split-compile-extended={self.split_compile_extended}") | ||
if self.jump_table_density is not None: | ||
self.formatted_options.append(f"-jump-table-density={self.jump_table_density}") | ||
if self.no_cache: | ||
self.formatted_options.append("-no-cache") | ||
if self.device_stack_protector: | ||
self.formatted_options.append("-device-stack-protector") | ||
|
||
|
||
class Linker: | ||
|
||
__slots__ = ("_handle",) | ||
|
||
def __init__(self, *object_codes: ObjectCode, options: Optional[LinkerOptions] = None): | ||
self._handle = None | ||
options = check_or_create_options(LinkerOptions, options, "Linker options") | ||
self._handle = nvjitlink.create(len(options.formatted_options), options.formatted_options) | ||
|
||
if object_codes is not None: | ||
for code in object_codes: | ||
assert isinstance(code, ObjectCode) | ||
self._add_code_object(code) | ||
|
||
|
||
def _add_code_object(self, object_code: ObjectCode): | ||
data = object_code._module | ||
assert isinstance(data, bytes) | ||
nvjitlink.add_data( | ||
self._handle, | ||
self._input_type_from_code_type(object_code._code_type), | ||
data, | ||
len(data), | ||
f"{object_code._handle}_{object_code._code_type}", | ||
) | ||
|
||
|
||
def link(self, target_type) -> ObjectCode: | ||
if target_type not in ["cubin", "ptx"]: | ||
raise ValueError(f"Unsupported target type: {target_type}") | ||
nvjitlink.complete(self._handle) | ||
code = None | ||
if target_type == "cubin": | ||
cubin_size = nvjitlink.get_linked_cubin_size(self._handle) | ||
code = bytearray(cubin_size) | ||
nvjitlink.get_linked_cubin(self._handle, code) | ||
else: | ||
ptx_size = nvjitlink.get_linked_ptx_size(self._handle) | ||
code = bytearray(ptx_size) | ||
nvjitlink.get_linked_ptx(self._handle, code) | ||
|
||
return ObjectCode(bytes(code), target_type) | ||
|
||
|
||
def get_error_log(self) -> str: | ||
log_size = nvjitlink.get_error_log_size(self._handle) | ||
log = bytearray(log_size) | ||
nvjitlink.get_error_log(self._handle, log) | ||
return log.decode() | ||
|
||
|
||
def get_info_log(self) -> str: | ||
log_size = nvjitlink.get_info_log_size(self._handle) | ||
log = bytearray(log_size) | ||
nvjitlink.get_info_log(self._handle, log) | ||
return log.decode() | ||
|
||
|
||
def _input_type_from_code_type(self, code_type: str) -> nvjitlink.InputType: | ||
# This list is based on the supported values for code_type in the ObjectCode class definition; nvjitlink supports other input types as well. | ||
if code_type == "ptx": | ||
return nvjitlink.InputType.PTX | ||
elif code_type == "cubin": | ||
return nvjitlink.InputType.CUBIN | ||
elif code_type == "fatbin": | ||
return nvjitlink.InputType.FATBIN | ||
elif code_type == "ltoir": | ||
return nvjitlink.InputType.LTOIR | ||
elif code_type == "object": | ||
return nvjitlink.InputType.OBJECT | ||
else: | ||
raise ValueError( | ||
f"Unknown code_type associated with ObjectCode: {code_type}" | ||
) | ||
|
||
|
||
@property | ||
def handle(self) -> int: | ||
return self._handle | ||
|
||
def __del__(self): | ||
if self._handle is not None: | ||
nvjitlink.destroy(self._handle) | ||
self._handle = None |
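For readers skimming the diff, the option-to-flag mapping in `__post_init__` comes in three shapes: on/off flags, flags that carry a true/false value, and repeatable flags. A minimal sketch of those shapes follows, assuming a trimmed-down stand-in class. `MiniLinkerOptions` and its `formatted_options()` method are hypothetical names used only for illustration, not part of the PR, and treating an explicit `False` like `None` for on/off flags is one possible design choice.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class MiniLinkerOptions:
    """Hypothetical, trimmed-down stand-in for LinkerOptions (not the PR's class)."""

    arch: str
    time: Optional[bool] = None               # on/off flag
    ftz: Optional[bool] = None                # flag that carries a true/false value
    kernels_used: Optional[List[str]] = None  # repeatable flag

    def formatted_options(self) -> List[str]:
        opts = [f"-arch={self.arch}"]
        # On/off flags: emit only when truthy, so None and False both omit the flag.
        if self.time:
            opts.append("-time")
        # Valued flags: None leaves the linker default alone; True and False are both emitted.
        if self.ftz is not None:
            opts.append(f"-ftz={'true' if self.ftz else 'false'}")
        # Repeatable flags: one entry per list element.
        for kernel in self.kernels_used or []:
            opts.append(f"-kernels-used={kernel}")
        return opts


opts = MiniLinkerOptions(arch="sm_80", time=False, ftz=False, kernels_used=["A"])
print(opts.formatted_options())
# ['-arch=sm_80', '-ftz=false', '-kernels-used=A']
```

Keeping the three shapes separate makes it easy to check the mapping without a GPU, since the output is just a list of strings.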
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import pytest | ||
from cuda.core.experimental._linker import Linker, LinkerOptions | ||
from cuda.core.experimental._module import ObjectCode | ||
from cuda.core.experimental._program import Program | ||
|
||
ARCH = "sm_80" # use sm_80 for testing the oop nvJitLink wrapper | ||
empty_entrypoint_kernel = "__global__ void A() {}" | ||
empty_kernel = "__device__ void B() {}" | ||
addition_kernel = "__device__ int C(int a, int b) { return a + b; }" | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def compile_ptx_functions(init_cuda): | ||
|
||
object_code_a_ptx = Program(empty_entrypoint_kernel, "c++").compile("ptx") | ||
object_code_b_ptx = Program(empty_kernel, "c++").compile("ptx") | ||
object_code_c_ptx = Program(addition_kernel, "c++").compile("ptx") | ||
|
||
return object_code_a_ptx, object_code_b_ptx, object_code_c_ptx | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def compile_ltoir_functions(init_cuda): | ||
object_code_a_ltoir = Program(empty_entrypoint_kernel, "c++").compile("ltoir", options=("-dlto",)) | ||
object_code_b_ltoir = Program(empty_kernel, "c++").compile("ltoir", options=("-dlto",)) | ||
object_code_c_ltoir = Program(addition_kernel, "c++").compile("ltoir", options=("-dlto",)) | ||
|
||
return object_code_a_ltoir, object_code_b_ltoir, object_code_c_ltoir | ||
|
||
|
||
@pytest.mark.parametrize("options", [ | ||
LinkerOptions(arch=ARCH), | ||
LinkerOptions(arch=ARCH, max_register_count=32), | ||
LinkerOptions(arch=ARCH, time=True), | ||
LinkerOptions(arch=ARCH, verbose=True), | ||
LinkerOptions(arch=ARCH, optimization_level=3), | ||
LinkerOptions(arch=ARCH, debug=True), | ||
LinkerOptions(arch=ARCH, lineinfo=True), | ||
LinkerOptions(arch=ARCH, ftz=True), | ||
LinkerOptions(arch=ARCH, prec_div=True), | ||
LinkerOptions(arch=ARCH, prec_sqrt=True), | ||
LinkerOptions(arch=ARCH, fma=True), | ||
LinkerOptions(arch=ARCH, kernels_used=["kernel1"]), | ||
LinkerOptions(arch=ARCH, variables_used=["var1"]), | ||
LinkerOptions(arch=ARCH, optimize_unused_variables=True), | ||
LinkerOptions(arch=ARCH, xptxas=["-v"]), | ||
LinkerOptions(arch=ARCH, split_compile=0), | ||
LinkerOptions(arch=ARCH, split_compile_extended=1), | ||
LinkerOptions(arch=ARCH, jump_table_density=100), | ||
LinkerOptions(arch=ARCH, no_cache=True) | ||
]) | ||
def test_linker_init(compile_ptx_functions, options): | ||
linker = Linker(*compile_ptx_functions, options=options) | ||
object_code = linker.link("cubin") | ||
assert isinstance(object_code, ObjectCode) | ||
|
||
|
||
def test_linker_init_invalid_arch(): | ||
options = LinkerOptions(arch=None) | ||
with pytest.raises(TypeError): | ||
Linker(options) | ||
|
||
|
||
def test_linker_link_ptx(compile_ltoir_functions): | ||
options = LinkerOptions(arch=ARCH, link_time_optimization=True, ptx=True) | ||
linker = Linker(*compile_ltoir_functions, options=options) | ||
linked_code = linker.link("ptx") | ||
assert isinstance(linked_code, ObjectCode) | ||
|
||
|
||
def test_linker_link_cubin(compile_ptx_functions): | ||
options = LinkerOptions(arch=ARCH) | ||
linker = Linker(*compile_ptx_functions, options=options) | ||
linked_code = linker.link("cubin") | ||
assert isinstance(linked_code, ObjectCode) | ||
|
||
|
||
def test_linker_link_invalid_target_type(compile_ptx_functions): | ||
options = LinkerOptions(arch=ARCH) | ||
linker = Linker(*compile_ptx_functions, options=options) | ||
with pytest.raises(ValueError): | ||
linker.link("invalid_target") | ||
|
||
|
||
def test_linker_get_error_log(compile_ptx_functions): | ||
options = LinkerOptions(arch=ARCH) | ||
linker = Linker(*compile_ptx_functions, options=options) | ||
linker.link("cubin") | ||
log = linker.get_error_log() | ||
assert isinstance(log, str) | ||
|
||
|
||
def test_linker_get_info_log(compile_ptx_functions): | ||
options = LinkerOptions(arch=ARCH) | ||
linker = Linker(*compile_ptx_functions, options=options) | ||
linker.link("cubin") | ||
log = linker.get_info_log() | ||
assert isinstance(log, str) |
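The two log tests above only check that a `str` comes back. The pattern behind `get_error_log` and `get_info_log` in the diff is a two-step read: ask for the size, allocate a `bytearray` of that size, let the backend fill it, then decode. A rough sketch of that pattern, runnable without CUDA, might look like this. `FakeLogBackend` and `read_info_log` are hypothetical stand-ins written for this illustration; they are not the real `nvjitlink` bindings and make no claim about how those bindings behave internally.

```python
class FakeLogBackend:
    """Hypothetical stand-in for the two log calls; NOT the real nvjitlink module."""

    def __init__(self, message: str):
        self._raw = message.encode()

    def get_info_log_size(self, handle) -> int:
        # Step 1: report how many bytes the caller must allocate.
        return len(self._raw)

    def get_info_log(self, handle, buffer: bytearray) -> None:
        # Step 2: copy the log into the caller-provided buffer.
        buffer[: len(self._raw)] = self._raw


def read_info_log(backend, handle) -> str:
    """Size query, allocation, fill, decode: the same order as in the diff."""
    size = backend.get_info_log_size(handle)
    log = bytearray(size)
    backend.get_info_log(handle, log)
    return log.decode()


backend = FakeLogBackend("linked 3 inputs")
print(read_info_log(backend, handle=0))  # prints: linked 3 inputs
```

A stand-in like this would let a test assert on log contents rather than only on the return type, at the cost of not exercising the real library.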
I'm unclear about what combinations of functions {ptx, ltoir} and target_types {ptx, cubin} are valid.
To reduce the code duplication, could this work as a general idea?
I'm not sure if/how that works with fixtures though.
I'd also add these right here:
I think that could consolidate four test functions into one.
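To make the combinations question concrete, here is a small sketch that enumerates the four (input kind, target type) pairs and marks the two that the tests in this diff exercise. It only records what the diff shows and does not claim which of the remaining pairs nvJitLink accepts. The names `INPUT_KINDS`, `TARGET_TYPES`, `EXERCISED`, and `link_cases` are hypothetical and chosen for this illustration.

```python
from itertools import product

INPUT_KINDS = ("ptx", "ltoir")   # what the two fixtures in the test file compile to
TARGET_TYPES = ("cubin", "ptx")  # what Linker.link() accepts in this diff

# Pairs the tests in this diff exercise, with the LinkerOptions
# keyword arguments they pass in addition to arch.
EXERCISED = {
    ("ptx", "cubin"): {},
    ("ltoir", "ptx"): {"link_time_optimization": True, "ptx": True},
}


def link_cases():
    """Return (input_kind, target_type, option_kwargs, exercised) for every pair."""
    rows = []
    for input_kind, target_type in product(INPUT_KINDS, TARGET_TYPES):
        key = (input_kind, target_type)
        # option_kwargs is None for pairs the diff does not test.
        rows.append((input_kind, target_type, EXERCISED.get(key), key in EXERCISED))
    return rows


for row in link_cases():
    print(row)
```

The exercised rows could then feed `pytest.mark.parametrize`, with the fixture looked up by name through `request.getfixturevalue`, which is one way a single parametrized test might replace several near-identical ones.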
I like the idea of using parametrize, so I added it to the linker init test. In general, I write tests to be as minimal as possible and then have them build on each other, i.e., I like to test add_code_object() in isolation before calling it in another test such as test_linker_link_ptx. This is something I picked up on my last project: there were a lot of fragile components, and it made it faster to determine where an issue was. That said, you have a lot more experience than me, so I am tempted to go with your suggestion of consolidating tests. I've left it as is for now.
Actually, considering the removal of add_code_object as a public entry point, that test can go.