Skip to content

Commit

Permalink
Added --data-path to densenet and alexnet
Browse files Browse the repository at this point in the history
  • Loading branch information
graham63 committed Mar 6, 2023
1 parent 7392542 commit 5fbe41e
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
6 changes: 5 additions & 1 deletion applications/vision/alexnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
parser.add_argument(
'--num-classes', action='store', default=1000, type=int,
help='number of ImageNet classes (default: 1000)', metavar='NUM')
parser.add_argument(
'--data-path', action='store', default=None, type=str,
help='Path to top-level imagenet directory. default: None')
lbann.contrib.args.add_optimizer_arguments(parser)
args = parser.parse_args()

Expand Down Expand Up @@ -64,7 +67,8 @@
opt = lbann.contrib.args.create_optimizer(args)

# Setup data reader
data_reader = data.imagenet.make_data_reader(num_classes=args.num_classes)
data_reader = data.imagenet.make_data_reader(num_classes=args.num_classes,
data_path=args.data_path)

# Setup trainer
trainer = lbann.Trainer(mini_batch_size=args.mini_batch_size)
Expand Down
8 changes: 6 additions & 2 deletions applications/vision/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,8 @@ def get_args():
parser.add_argument("--print-matrix-summary", dest="print_matrix_summary",
action="store_const",
const=True, default=False)
parser.add_argument('--data-path', action='store', default=None, type=str,
help='Path to top-level imagenet directory. default: None')
args = parser.parse_args()
return args

Expand All @@ -438,7 +440,7 @@ def set_up_experiment(args,
labels):
algo = lbann.BatchedIterativeOptimizer("sgd", epoch_count=args.num_epochs)


# Set up objective function
cross_entropy = lbann.CrossEntropy([probs, labels])
layers = list(lbann.traverse_layer_graph(input_))
Expand Down Expand Up @@ -472,7 +474,9 @@ def set_up_experiment(args,
callbacks=callbacks)

# Set up data reader
data_reader = data.imagenet.make_data_reader(num_classes=args.num_classes, small_testing=True)
data_reader = data.imagenet.make_data_reader(num_classes=args.num_classes,
small_testing=True,
data_path=args.data_path)

percentage = 0.001 * 2 * (args.mini_batch_size / 16) * 2

Expand Down

0 comments on commit 5fbe41e

Please sign in to comment.