Skip to content

Commit

Permalink
add seed to algorithms
Browse files Browse the repository at this point in the history
  • Loading branch information
levongh committed Aug 17, 2023
1 parent 5165909 commit 85f1eaf
Show file tree
Hide file tree
Showing 8 changed files with 43 additions and 11 deletions.
3 changes: 3 additions & 0 deletions deeplake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .core.dataset import Dataset
from .core.transform import compute, compose
from .core.tensor import Tensor
from .core.seed import DeeplakeRandom
from .util.bugout_reporter import deeplake_reporter
from .compression import SUPPORTED_COMPRESSIONS
from .htype import HTYPE_CONFIGURATIONS
Expand All @@ -50,6 +51,7 @@
ingest_huggingface = huggingface.ingest_huggingface
dataset = api_dataset.init # type: ignore
tensor = Tensor
random = DeeplakeRandom()

__all__ = [
"tensor",
Expand All @@ -76,6 +78,7 @@
"delete",
"copy",
"rename",
"random",
]


Expand Down
8 changes: 2 additions & 6 deletions deeplake/client/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import deeplake
import requests
from typing import Any, Optional, Dict
from deeplake.core.seed import DeeplakeRandom
from deeplake.util.exceptions import (
AgreementNotAcceptedError,
AuthorizationException,
Expand Down Expand Up @@ -514,9 +515,4 @@ def get_seed(self) -> Optional[int]:
"""
Get the seed used in library
"""
import numpy as np

try:
return np.random.get_state()[1][0]
except Exception:
return None
return DeeplakeRandom().get_seed()
20 changes: 20 additions & 0 deletions deeplake/core/seed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import numpy as np
from typing import Optional

class DeeplakeRandom(object):
    """Singleton that stores the seed value shared across the library.

    Every ``DeeplakeRandom()`` call returns the same object, so a seed set
    through one reference is visible through every other reference.
    """

    def __new__(cls):
        # Create the shared instance lazily on the first construction and
        # start with no seed configured.
        if not hasattr(cls, "instance"):
            cls.instance = super(DeeplakeRandom, cls).__new__(cls)
            cls.instance.internal_seed = None
        return cls.instance

    def seed(self, seed: Optional[int] = None):
        """Store ``seed`` as the current seed value.

        Args:
            seed: An integer seed, or ``None`` to clear the stored seed.

        Raises:
            TypeError: If ``seed`` is neither ``None`` nor an ``int``.
        """
        # ``isinstance(seed, Optional[int])`` raises TypeError on Python < 3.10
        # (subscripted generics are not valid isinstance targets), so the two
        # accepted cases are checked explicitly instead.
        if seed is not None and not isinstance(seed, int):
            raise TypeError(
                f"provided seed type `{type(seed)}` is incorrect, seed must be an integer"
            )
        self.internal_seed = seed

    def get_seed(self) -> Optional[int]:
        """Return the stored seed, or ``None`` if no seed has been set."""
        return self.internal_seed


8 changes: 5 additions & 3 deletions deeplake/core/tests/test_deeplake_indra_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,8 @@ def test_query_tensors_polygon_htype_consistency(local_auth_ds_generator):
@requires_libdeeplake
def test_random_split_with_seed(local_auth_ds_generator):
deeplake_ds = local_auth_ds_generator()
from deeplake.core.seed import DeeplakeRandom

with deeplake_ds:
deeplake_ds.create_tensor("label", htype="generic", dtype=np.int32)
for i in range(1000):
Expand All @@ -304,17 +306,17 @@ def test_random_split_with_seed(local_auth_ds_generator):
deeplake_indra_ds = deeplake_ds.query("SELECT * GROUP BY label")

initial_state = np.random.get_state()
np.random.seed(100)
DeeplakeRandom().seed(100)
split1 = deeplake_indra_ds.random_split([0.2, 0.2, 0.6])
assert len(split1) == 3
assert len(split1[0]) == 20

np.random.seed(101)
DeeplakeRandom().seed(101)
split2 = deeplake_indra_ds.random_split([0.2, 0.2, 0.6])
assert len(split2) == 3
assert len(split2[0]) == 20

np.random.seed(100)
DeeplakeRandom().seed(100)
split3 = deeplake_indra_ds.random_split([0.2, 0.2, 0.6])
assert len(split3) == 3
assert len(split3[0]) == 20
Expand Down
2 changes: 1 addition & 1 deletion deeplake/integrations/pytorch/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def validate_decode_method(
for tensor_name, decode_method in decode_method.items():
if tensor_name not in all_tensor_keys:
raise ValueError(
"tensor {tensor_name} specified in decode_method not found in tensors."
f"tensor {tensor_name} specified in decode_method not found in tensors."
)
if tensor_name in jpeg_png_compressed_tensors_set:
if decode_method not in jpeg_png_supported_decode_methods:
Expand Down
4 changes: 3 additions & 1 deletion deeplake/integrations/pytorch/shuffle_buffer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List, Any, Sequence
import random
from random import randrange
from functools import reduce
from operator import mul
Expand Down Expand Up @@ -30,7 +31,8 @@ class ShuffleBuffer:
def __init__(self, size: int) -> None:
if size <= 0:
raise ValueError("Buffer size should be positive value more than zero")

from deeplake.core.seed import DeeplakeRandom
random.seed(DeeplakeRandom().get_seed())
self.size = size
self.buffer: List[Any] = list()
self.buffer_used = 0
Expand Down
4 changes: 4 additions & 0 deletions deeplake/util/scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
from collections import defaultdict
from deeplake.core.meta.encode.chunk_id import ChunkIdEncoder
from deeplake.core.seed import DeeplakeRandom


def find_primary_tensor(dataset):
Expand Down Expand Up @@ -37,6 +38,8 @@ def create_fetching_schedule(dataset, primary_tensor_name, shuffle_within_chunks
enc_array = chunk_id_encoder.array
num_chunks = chunk_id_encoder.num_chunks
# pick chunks randomly, one by one
prev_state = np.random.get_state()
np.random.seed(DeeplakeRandom().get_seed())
chunk_order = np.random.choice(num_chunks, num_chunks, replace=False)
schedule = []
for chunk_idx in chunk_order:
Expand All @@ -52,6 +55,7 @@ def create_fetching_schedule(dataset, primary_tensor_name, shuffle_within_chunks
elif isinstance(index_struct, dict):
idxs = filter(lambda idx: idx in index_struct, schedule)
schedule = [int(idx) for idx in idxs for _ in range(index_struct[idx])]
np.random.set_state(prev_state)
return schedule


Expand Down
5 changes: 5 additions & 0 deletions deeplake/util/shuffle.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import numpy as np
from deeplake.core.seed import DeeplakeRandom



def shuffle(ds):
    """Returns a shuffled wrapper of a given Dataset.

    The permutation is drawn from a generator seeded with the value stored in
    ``DeeplakeRandom``; when that value is ``None`` the generator is seeded
    from fresh entropy, so the order differs between calls.

    Args:
        ds: Object supporting ``len()`` and indexing with a list of integers.

    Returns:
        The result of indexing ``ds`` with the shuffled list of indices.
    """
    # A private RandomState produces the same permutation as seeding the
    # global legacy generator with the same seed, but leaves the global
    # ``np.random`` state untouched. This avoids the save/restore dance,
    # which was neither exception-safe nor safe against concurrent users
    # of ``np.random``.
    rng = np.random.RandomState(DeeplakeRandom().get_seed())
    idxs = np.arange(len(ds))
    rng.shuffle(idxs)
    return ds[idxs.tolist()]

0 comments on commit 85f1eaf

Please sign in to comment.