Removed functions iterating over tensors from torch compilation process #224

Open · wants to merge 1 commit into base: habana-main
5 changes: 0 additions & 5 deletions server/text_generation_server/models/causal_lm.py
@@ -92,7 +92,6 @@ def biggest_single_chunk(offset):
return 0


-@torch_compile_for_eager


Why not use torch._dynamo.graph_break() in place of mark_step, instead of removing compilation of the graph entirely?
The same question applies to all the cases below.


Does the change cause the ops within the function to be executed eagerly?

@jczaja (Author), Sep 26, 2024


> Does the change cause the ops within the function to be executed eagerly?

I assume (I have not run logging of fallback-to-eager events) that the functions excluded from the torch.compile regions in this PR now run eagerly, i.e. PyTorch ops in code whose torch.compile decorator was removed execute eagerly.
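The author's assumption above could be checked without event logging. A minimal sketch, assuming a stock PyTorch install (CPU, no HPU): a custom Dynamo backend records every graph it is asked to compile, so a function whose decorator was removed contributes no graphs and therefore runs purely eagerly.

```python
import torch

compile_calls = []

def counting_backend(gm, example_inputs):
    # Record each FX graph Dynamo hands us, then just run it as-is.
    compile_calls.append(gm)
    return gm.forward

@torch.compile(backend=counting_backend)
def compiled_fn(x):
    return x + 1

def plain_fn(x):
    # Decorator removed, as in this PR: executes eagerly, never hits Dynamo.
    return x + 1

compiled_fn(torch.ones(3))
plain_fn(torch.ones(3))
```

After both calls, `compile_calls` contains exactly one graph, confirming only the decorated function was compiled.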

@jczaja (Author)

In some internal testing it was revealed that excluding those functions (as in this PR) from the torch.compile region had no impact on performance or accuracy.

def grouped_pad(tensor_groups, dims, values):
grouped_result = []
for tensors, dim, value in zip(tensor_groups, dims, values):
@@ -125,7 +124,6 @@ def grouped_roll(tensor_groups, chunk, dims, merge_graphs):
return tensor_groups


-@torch_compile_for_eager
def grouped_shift(tensor_groups, dims, offset, merge_graphs):
chunks = calculate_chunks(offset)
for c in chunks:
@@ -155,7 +153,6 @@ def extend_tensor(tensor, padding, dim):
return result


-@torch_compile_for_eager
def extend_batch(tensors, target_bs, dim):
diff = target_bs - tensors[0].size(dim)
# TODO: add support for shrinking bs
@@ -173,14 +170,12 @@ def grouped_extend_batch(tensor_groups, target_bs, bs_dims):
return tensor_groups


-@torch_compile_for_eager
def merge(tensor_group):
tensor_group = [torch.stack(tensor_group)]
htorch.core.mark_step()
return tensor_group


-@torch_compile_for_eager
def split(tensor_group, clone_data):
tensor_group = [t.squeeze(0) for t in torch.split(tensor_group[0], 1)]
if clone_data: