Skip to content

Commit

Permalink
Merge pull request #128 from m3dev/rand-dev
Browse files Browse the repository at this point in the history
Fix random seed
  • Loading branch information
vaaaaanquish authored Mar 12, 2020
2 parents 96c078d + 068c508 commit ba1dc49
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 28 deletions.
42 changes: 42 additions & 0 deletions examples/sample_fix_random_seed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import random

import gokart
import luigi
import numpy as np


class SampleTask(gokart.TaskOnKart):
    """Example task that dumps numbers drawn from random, numpy and torch.

    The torch values are an empty list when torch cannot be imported.
    """
    task_namespace = 'sample_fix_random_seed'
    sample_param = luigi.Parameter()

    def run(self):
        """Draw values from each generator and dump them as a single dict."""
        draws = {
            'random': [random.randint(0, 100) for _ in range(10)],
            'numpy': [np.random.randint(0, 100) for _ in range(10)],
        }
        try:
            # torch is optional, so it is imported only when the task runs
            import torch
            draws['torch'] = [torch.randn(1).tolist()[0] for _ in range(5)]
        except ImportError:
            draws['torch'] = []
        self.dump(draws)


if __name__ == '__main__':
    # //---------------------------------------------------------------------
    # Pass --fix-random-seed-methods to choose which seed functions are called
    # and --fix-random-seed-value to choose the seed they are called with.
    #
    # //--- The output is as follows every time (with pytorch installed). ---
    # {'random': [65, 41, 61, 37, 55, 81, 48, 2, 94, 21],
    #  'numpy': [79, 86, 5, 22, 79, 98, 56, 40, 81, 37],
    #  'torch': [0.14460121095180511, -0.11649507284164429,
    #            0.6928958296775818, -0.916053831577301, 0.7317505478858948]}
    #
    # //------------------------- without pytorch ---------------------------
    # {'random': [65, 41, 61, 37, 55, 81, 48, 2, 94, 21],
    #  'numpy': [79, 86, 5, 22, 79, 98, 56, 40, 81, 37], 'torch': []}
    #
    # //---------------------------------------------------------------------
    run_arguments = [
        'sample_fix_random_seed.SampleTask',
        '--local-scheduler',
        '--rerun',
        '--sample-param=a',
        '--fix-random-seed-methods=["random.seed","numpy.random.seed","torch.random.manual_seed"]',
        '--fix-random-seed-value=57',
    ]
    gokart.run(run_arguments)
90 changes: 62 additions & 28 deletions gokart/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,24 @@ class TaskOnKart(luigi.Task):
* :py:meth:`dump` - this save a object as output of this task.
"""

workspace_directory = luigi.Parameter(
default='./resources/', description='A directory to set outputs on. Please use a path starts with s3:// when you use s3.',
significant=False) # type: str
workspace_directory = luigi.Parameter(default='./resources/',
description='A directory to set outputs on. Please use a path starts with s3:// when you use s3.',
significant=False) # type: str
local_temporary_directory = luigi.Parameter(default='./resources/tmp/', description='A directory to save temporary files.', significant=False) # type: str
rerun = luigi.BoolParameter(default=False, description='If this is true, this task will run even if all output files exist.', significant=False)
strict_check = luigi.BoolParameter(
default=False, description='If this is true, this task will not run only if all input and output files exist.', significant=False)
modification_time_check = luigi.BoolParameter(
default=False,
description='If this is true, this task will not run only if all input and output files exist,'
' and all input files are modified before output file are modified.',
significant=False)
strict_check = luigi.BoolParameter(default=False,
description='If this is true, this task will not run only if all input and output files exist.',
significant=False)
modification_time_check = luigi.BoolParameter(default=False,
description='If this is true, this task will not run only if all input and output files exist,'
' and all input files are modified before output file are modified.',
significant=False)
delete_unnecessary_output_files = luigi.BoolParameter(default=False, description='If this is true, delete unnecessary output files.', significant=False)
significant = luigi.BoolParameter(
default=True,
description='If this is false, this task is not treated as a part of dependent tasks for the unique id.',
significant=False)
significant = luigi.BoolParameter(default=True,
description='If this is false, this task is not treated as a part of dependent tasks for the unique id.',
significant=False)
fix_random_seed_methods = luigi.ListParameter(default=['random.seed', 'numpy.random.seed'], description='Fix random seed method list.', significant=False)
fix_random_seed_value = luigi.IntParameter(default=None, description='Fix random seed method value.', significant=False)

def __init__(self, *args, **kwargs):
self._add_configuration(kwargs, self.get_task_family())
Expand Down Expand Up @@ -132,15 +133,14 @@ def make_target(self, relative_file_path: str, use_unique_id: bool = True, proce
unique_id = self.make_unique_id() if use_unique_id else None
return gokart.target.make_target(file_path=file_path, unique_id=unique_id, processor=processor)

def make_large_data_frame_target(self, relative_file_path: str, use_unique_id: bool = True, max_byte=int(2 ** 26)) -> TargetOnKart:
def make_large_data_frame_target(self, relative_file_path: str, use_unique_id: bool = True, max_byte=int(2**26)) -> TargetOnKart:
file_path = os.path.join(self.workspace_directory, relative_file_path)
unique_id = self.make_unique_id() if use_unique_id else None
return gokart.target.make_model_target(
file_path=file_path,
temporary_directory=self.local_temporary_directory,
unique_id=unique_id,
save_function=gokart.target.LargeDataFrameProcessor(max_byte=max_byte).save,
load_function=gokart.target.LargeDataFrameProcessor.load)
return gokart.target.make_model_target(file_path=file_path,
temporary_directory=self.local_temporary_directory,
unique_id=unique_id,
save_function=gokart.target.LargeDataFrameProcessor(max_byte=max_byte).save,
load_function=gokart.target.LargeDataFrameProcessor.load)

def make_model_target(self,
relative_file_path: str,
Expand All @@ -158,12 +158,11 @@ def make_model_target(self,
file_path = os.path.join(self.workspace_directory, relative_file_path)
assert relative_file_path[-3:] == 'zip', f'extension must be zip, but {relative_file_path} is passed.'
unique_id = self.make_unique_id() if use_unique_id else None
return gokart.target.make_model_target(
file_path=file_path,
temporary_directory=self.local_temporary_directory,
unique_id=unique_id,
save_function=save_function,
load_function=load_function)
return gokart.target.make_model_target(file_path=file_path,
temporary_directory=self.local_temporary_directory,
unique_id=unique_id,
save_function=save_function,
load_function=load_function)

def load(self, target: Union[None, str, TargetOnKart] = None) -> Any:
def _load(targets):
Expand All @@ -188,10 +187,13 @@ def _load(targets):

return _load(self._get_input_targets(target))

def load_data_frame(self, target: Union[None, str, TargetOnKart] = None, required_columns: Optional[Set[str]] = None,
def load_data_frame(self,
target: Union[None, str, TargetOnKart] = None,
required_columns: Optional[Set[str]] = None,
drop_columns: bool = False) -> pd.DataFrame:
data = self.load(target=target)
if isinstance(data, list):

def _pd_concat(dfs):
if isinstance(dfs, list):
return pd.concat([_pd_concat(df) for df in dfs])
Expand Down Expand Up @@ -279,6 +281,38 @@ def get_task_params(self) -> Dict:
return self.load(target)
return dict()

@luigi.Task.event_handler(luigi.Event.START)
def _set_random_seed(self):
    """Seed the configured generators on the START event and record the result.

    The seed and the list of seed methods that were applied successfully are
    dumped to the random-seed log target of this task.
    """
    seed = self._get_random_seed()
    applied_methods = self.try_set_seed(self.fix_random_seed_methods, seed)
    record = {'seed': seed, 'seed_methods': applied_methods}
    self.dump(record, self._get_random_seeds_target())

def _get_random_seeds_target(self):
return self.make_target(f'log/random_seed/{type(self).__name__}.pkl')

@staticmethod
def try_set_seed(methods: List[str], random_seed: int) -> List[str]:
success_methods = []
for method_name in methods:
try:
for i, x in enumerate(method_name.split('.')):
if i == 0:
m = import_module(x)
else:
m = getattr(m, x)
m(random_seed)
success_methods.append(method_name)
except ModuleNotFoundError:
pass
except AttributeError:
pass
return success_methods

def _get_random_seed(self):
if self.fix_random_seed_value:
return self.fix_random_seed_value
return int(self.make_unique_id(), 16) % (2**32 - 1) # maximum numpy.random.seed

@luigi.Task.event_handler(luigi.Event.START)
def _dump_task_params(self):
self.dump(self.to_str_params(only_significant=True), self._get_task_params_target())
Expand Down

0 comments on commit ba1dc49

Please sign in to comment.