Merge pull request #71 from pytorch-labs/wrapper
consistent wrapper for ET & AOTI
mikekgfb authored Apr 7, 2024
2 parents 301bf31 + c78662f commit 06c9bb6
Showing 3 changed files with 28 additions and 24 deletions.
export.py (22 additions, 1 deletion)
@@ -37,6 +37,26 @@ def device_sync(device):
     else:
         print(f"device={device} is not yet suppported")
 
 
+class model_wrapper(nn.Module):
+    def __init__(self, model, device):
+        super().__init__()
+
+        max_seq_length = 350
+        with torch.device(device):
+            model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+
+        self.model = model
+        # init model here if necessary
+
+    def forward(self, idx, input_pos):
+        # input_pos: [B, 1]
+        # assert failed on symbolic shape during aot_compile?!
+        # but not for ET?
+        # assert input_pos.shape[-1] == 1
+        logits = self.model(idx, input_pos)
+        return logits  # sample(logits, **sampling_kwargs)
+
+
 def main(checkpoint_path, device, quantize = "{ }", args = None):
     assert checkpoint_path.is_file(), checkpoint_path
@@ -53,7 +73,8 @@ def main(checkpoint_path, device, quantize = "{ }", args = None):
     print(f"Time to load model: {time.time() - t0:.02f} seconds")
 
     quantize_model(model, args.quantize)
-
+
+    model = model_wrapper(model, device=device)
 
     output_pte_path = args.output_pte_path
     output_dso_path = args.output_dso_path
export_aoti.py (2 additions, 2 deletions)
@@ -33,8 +33,8 @@ def device_sync(device):
 
 def export_model(model: nn.Module, device, output_path, args=None):
     max_seq_length = 350
-    with torch.device(device):
-        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+    # with torch.device(device):
+    #     model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
 
     input = (
         torch.tensor([[1, 9038, 2501, 263, 931]], dtype=torch.int, device=device),
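
With cache setup hoisted into the shared wrapper, export_aoti.py now assumes it receives an already-wrapped module. For context, the AOTI flow around this hunk looks roughly like the sketch below; it is an assumption based on the early-2024 torch._export.aot_compile API, not code from this commit, and export_model_sketch plus its second input tensor are illustrative:

import torch

def export_model_sketch(wrapped_model, device, output_path):
    # Example inputs mirroring the tensors built in export_aoti.py; the
    # input_pos companion tensor is assumed, since the diff truncates it.
    example_inputs = (
        torch.tensor([[1, 9038, 2501, 263, 931]], dtype=torch.int, device=device),
        torch.arange(0, 5, device=device),
    )
    # Ahead-of-time Inductor compile to a shared library; the options key
    # pins the .so output location in that era's API.
    so_path = torch._export.aot_compile(
        wrapped_model,
        example_inputs,
        options={"aot_inductor.output_path": output_path},
    )
    return so_path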
export_et.py (4 additions, 21 deletions)
@@ -70,32 +70,15 @@ def materialze_broadcast_of_rope_freq_cis(
     return module
 
 
-class model_wrapper(nn.Module):
-    def __init__(self, model, device):
-        super().__init__()
-
-        max_seq_length = 350
-        with torch.device(device):
-            model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
-
-        self.model = model
-        # init model here if necessary
-
-    def forward(self, x, input_pos):
-        # input_pos: [B, 1]
-        assert input_pos.shape[-1] == 1
-        logits = self.model(x, input_pos)
-        return logits  # sample(logits, **sampling_kwargs)
-
-
 def canonical_path(path):
     return path
 
 ## align AOTI and ET export
 # def export_model(model: nn.Module, device, output_path):
 
 def export_model(model, device, output_path, args=None) -> str:  # noqa: C901
-
-    export_model = model_wrapper(model, device=device)
+    # applied wrapper already in export.
+    # export_model = model_wrapper(model, device=device)
+    export_model = model
     print(export_model)
 
     input = (
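
On the ExecuTorch side, the wrapped module then feeds the usual export-to-.pte pipeline. A minimal sketch under stated assumptions (standard executorch.exir entry points; the repo's actual export_et.py also applies quantization and backend passes not shown here, and export_to_pte_sketch is an illustrative name):

import torch
from executorch.exir import to_edge  # assumes the ExecuTorch pip package

def export_to_pte_sketch(wrapped_model, example_inputs, output_path):
    # 1. Capture the wrapper's forward as an ExportedProgram.
    exported = torch.export.export(wrapped_model, example_inputs)
    # 2. Lower to the edge dialect, then to an ExecuTorch program.
    program = to_edge(exported).to_executorch()
    # 3. Serialize the flatbuffer as a .pte file.
    with open(output_path, "wb") as f:
        f.write(program.buffer)
    return output_path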
