Skip to content

Commit

Permalink
[Feature] Adding test & train API to be used directly in code (#1138)
Browse files Browse the repository at this point in the history
* [Feature] Adding test & train API

This changes added capability for test & training to be directly invoked
in code, e.g., inside a Jupyter notebook cell.

The change also ensures the original command-line usage remains the
same.

* fix linting issue

* [bug fix] fix polygon points ordering

bug fix for 4-point polygon, sort the points in clock-wise order

* revert changes for PR #1138

Co-authored-by: gaotongxiao <[email protected]>
  • Loading branch information
wybryan and gaotongxiao authored Sep 29, 2022
1 parent d70e1b6 commit b422ded
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 10 deletions.
23 changes: 18 additions & 5 deletions tools/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,20 @@
from mmocr.utils import revert_sync_batchnorm, setup_multi_processes


def parse_args():
class TestArg:

def __init__(self, config=None, checkpoint=None):
self.arg_list = None
if config is not None and checkpoint is not None:
self.arg_list = [config, checkpoint]

def add_arg(self, key, value=None):
self.arg_list.append(key)
if value is not None:
self.arg_list.append(value)


def parse_args(arg_list=None):
parser = argparse.ArgumentParser(
description='MMOCR test (and eval) a model.')
parser.add_argument('config', help='Test config file path.')
Expand Down Expand Up @@ -96,7 +109,7 @@ def parse_args():
default='none',
help='Options for job launcher.')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()
args = parser.parse_args(arg_list)
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)

Expand All @@ -110,8 +123,7 @@ def parse_args():
return args


def main():
args = parse_args()
def run_test_cmd(args):

assert (
args.out or args.eval or args.format_only or args.show
Expand Down Expand Up @@ -232,4 +244,5 @@ def main():


if __name__ == '__main__':
main()
args = parse_args(TestArg().arg_list)
run_test_cmd(args)
23 changes: 18 additions & 5 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,20 @@
setup_multi_processes)


def parse_args():
class TrainArg:

def __init__(self, config=None):
self.arg_list = None
if config is not None:
self.arg_list = [config]

def add_arg(self, key, value=None):
self.arg_list.append(key)
if value is not None:
self.arg_list.append(value)


def parse_args(arg_list=None):
parser = argparse.ArgumentParser(description='Train a detector.')
parser.add_argument('config', help='Train config file path.')
parser.add_argument('--work-dir', help='The dir to save logs and models.')
Expand Down Expand Up @@ -85,7 +98,7 @@ def parse_args():
help='Options for job launcher.')
parser.add_argument('--local_rank', type=int, default=0)

args = parser.parse_args()
args = parser.parse_args(arg_list)
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)

Expand All @@ -100,8 +113,7 @@ def parse_args():
return args


def main():
args = parse_args()
def run_train_cmd(args):

cfg = Config.fromfile(args.config)
if args.cfg_options is not None:
Expand Down Expand Up @@ -227,4 +239,5 @@ def main():


if __name__ == '__main__':
main()
args = parse_args(TrainArg().arg_list)
run_train_cmd(args)

0 comments on commit b422ded

Please sign in to comment.