forked from sands-lab/grace
-
Notifications
You must be signed in to change notification settings - Fork 0
/
__init__.py
58 lines (43 loc) · 1.94 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from abc import ABC, abstractmethod
class Memory(ABC):
    """Interface for residual (error-feedback) memory used around compression.

    ``Communicator.send_step`` calls :meth:`compensate` on a tensor right
    before it is compressed, and :meth:`update` right after compression, so an
    implementation can fold stored residuals into the tensor and record new
    ones.
    """

    @abstractmethod
    def compensate(self, tensor, name):
        """Update the tensor with the residuals.

        Args:
            tensor: The tensor about to be compressed.
            name: Identifier of the tensor; presumably used to key per-tensor
                residual state (confirm in concrete implementations).

        Returns:
            The tensor that should be handed to the compressor.

        Raises:
            NotImplementedError: Always, in this abstract base.
        """
        # NotImplementedError is the exception class. ``NotImplemented`` is a
        # non-callable constant, and calling it would raise TypeError instead.
        raise NotImplementedError("compensate was not implemented.")

    def update(self, tensor, name, compressor, tensor_compressed, ctx):
        """Update the residuals.

        The default implementation is a no-op, so no residuals are kept.

        Args:
            tensor: The compensated tensor that was just compressed.
            name: Identifier of the tensor.
            compressor: The compressor that produced ``tensor_compressed``.
            tensor_compressed: First element returned by ``compressor.compress``.
            ctx: Context returned by ``compressor.compress`` alongside it.
        """


class Compressor(ABC):
    """Interface for compressing and decompressing a given tensor."""

    def __init__(self, average=True, tensors_size_are_same=True):
        """Store the configuration flags shared by every compressor.

        Args:
            average: Flag kept on the instance. Presumably tells callers to
                average aggregated results (TODO confirm against its users).
            tensors_size_are_same: Flag kept on the instance. Presumably
                signals whether compressed tensors have the same size on
                every participant (TODO confirm against its users).
        """
        self.average = average
        self.tensors_size_are_same = tensors_size_are_same

    @abstractmethod
    def compress(self, tensor, name):
        """Compresses a tensor and returns it with the context needed to decompress it.

        Args:
            tensor: The tensor to compress.
            name: Identifier of the tensor.

        Returns:
            A ``(tensors_compressed, ctx)`` pair. ``Communicator.send_step``
            unpacks exactly two values from this call.

        Raises:
            NotImplementedError: Always, in this abstract base.
        """
        # ``NotImplemented`` is a constant, not an exception; calling it
        # would raise TypeError.
        raise NotImplementedError("compress was not implemented.")

    @abstractmethod
    def decompress(self, tensors, ctx):
        """Decompress the tensor with the given context.

        Args:
            tensors: Compressed payload produced by :meth:`compress`.
            ctx: Context produced by :meth:`compress` for that payload.

        Raises:
            NotImplementedError: Always, in this abstract base.
        """
        raise NotImplementedError("decompress was not implemented.")

    def aggregate(self, tensors):
        """Aggregate a list of tensors by summing them.

        Uses the built-in ``sum``, which starts from ``0``, so an empty
        sequence yields ``0``.
        """
        return sum(tensors)


class Communicator(ABC):
    """Drives the compensate -> compress -> send / receive pipeline.

    Subclasses supply the transport via :meth:`async_send` and
    :meth:`wait_receive`. The compression and residual handling are delegated
    to the ``compressor`` and ``memory`` objects given at construction.
    """

    def __init__(self, compressor, memory):
        """Store the compressor and memory used by :meth:`send_step`.

        Args:
            compressor: Object exposing ``compress(tensor, name)`` that
                returns a ``(tensors_compressed, ctx)`` pair.
            memory: Object exposing ``compensate(tensor, name)`` and
                ``update(tensor, name, compressor, tensors_compressed, ctx)``.
        """
        self.compressor = compressor
        self.memory = memory

    @abstractmethod
    def async_send(self, tensors, name):
        """Start sending the compressed tensors and return handle(s) for them.

        Raises:
            NotImplementedError: Always, in this abstract base.
        """
        # ``NotImplemented`` is not callable; the exception class is intended.
        raise NotImplementedError("async_send was not implemented.")

    @abstractmethod
    def wait_receive(self, handles, ctx):
        """Wait on ``handles`` from :meth:`async_send` and return the result.

        Raises:
            NotImplementedError: Always, in this abstract base.
        """
        raise NotImplementedError("wait_receive was not implemented.")

    def send_step(self, tensor, name):
        """Compensate, compress, record residuals and start sending one tensor.

        Returns:
            ``(handles, ctx)``: the value returned by :meth:`async_send` and
            the compression context needed later by :meth:`receive_step`.
        """
        tensor = self.memory.compensate(tensor, name)
        tensors_compressed, ctx = self.compressor.compress(tensor, name)
        # The residuals are updated from the compensated tensor before it is sent.
        self.memory.update(tensor, name, self.compressor, tensors_compressed, ctx)
        handles = self.async_send(tensors_compressed, name)
        return handles, ctx

    def receive_step(self, handles, ctx):
        """Finish a step started by :meth:`send_step` via :meth:`wait_receive`."""
        return self.wait_receive(handles, ctx)