Skip to content

Commit

Permalink
[chore] refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
botbw committed Oct 14, 2024
1 parent b3744b6 commit 3c08659
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 18 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
packaging
click
torch
colossalai>=0.4.4
# colossalai>=0.4.4
18 changes: 1 addition & 17 deletions tensornvme/async_file_io.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import ctypes
from functools import partial
import torch

from typing import Dict, List
from typing import List
from io import IOBase
from tensornvme._C import AsyncFileWriter as AsyncFileWriterC

from colossalai.utils.safetensors import prepare

class AsyncFileWriter:
def __init__(self, fp: IOBase, n_entries: int = 16, backend=None) -> None:
fd = fp.fileno()
Expand All @@ -34,19 +31,6 @@ def write_raw(self, py_ref: object, buffer: int, n_bytes: int, offset: int) -> N
self.io.write(buffer, n_bytes, offset, partial(AsyncFileWriter.gc_callback, self.buffers, len(self.buffers) - 1))
self.offset += n_bytes

def save(
self,
state_dict: Dict[str, torch.Tensor]
) -> None:
prepared_data, tensors = prepare(state_dict)
n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset

self.write(n.to_bytes(8, byteorder='little'))
self.write(header_bytes)

for tensor in tensors:
self.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), self.offset)

@staticmethod
def gc_callback(listt: List, idx: int) -> None:
listt[idx] = None
Expand Down

0 comments on commit 3c08659

Please sign in to comment.