Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inference for all the images in one folder is added. #150

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions basicsr/demo_folder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
import torch

# from basicsr.data import create_dataloader, create_dataset
from basicsr.models import create_model
from basicsr.train import parse_options
from basicsr.utils import FileClient, imfrombytes, img2tensor, padding, tensor2img, imwrite
import os

# from basicsr.utils import (get_env_info, get_root_logger, get_time_str,
# make_exp_dirs)
# from basicsr.utils.options import dict2str

def main():
# parse options, set distributed setting, set ramdom seed
opt = parse_options(is_train=False)
opt['num_gpu'] = torch.cuda.device_count()
input_folder = opt['img_path'].get('input_folder')
output_folder = opt['img_path'].get('output_folder')
# Get a list of all image files in the input folder
image_files = [f for f in os.listdir(input_folder) if os.path.isfile(os.path.join(input_folder, f))]
opt['dist'] = False
model = create_model(opt)
for image_file in image_files:
# Construct the input and output paths for each image
img_path = os.path.join(input_folder, image_file)
output_path = os.path.join(output_folder, image_file)
## 1. read image
file_client = FileClient('disk')
img_bytes = file_client.get(img_path, None)
try:
img = imfrombytes(img_bytes, float32=True)
except:
raise Exception("path {} not working".format(img_path))
img = img2tensor(img, bgr2rgb=True, float32=True)
## 2. run inference

model.feed_data(data={'lq': img.unsqueeze(dim=0)})
if model.opt['val'].get('grids', False):
model.grids()
model.test()
if model.opt['val'].get('grids', False):
model.grids_inverse()
visuals = model.get_current_visuals()
sr_img = tensor2img([visuals['result']])
imwrite(sr_img, output_path)
print(f'inference {img_path} .. finished. saved to {output_path}')
if __name__ == '__main__':
main()
8 changes: 7 additions & 1 deletion basicsr/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@ def parse_options(is_train=True):
default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)

parser.add_argument('--input_path', type=str, required=False, help='The path to the input image. For single image inference only.')
parser.add_argument('--input_folder', type=str, required=False, help='The path to the input folder. For multiple image inference.')
parser.add_argument('--output_path', type=str, required=False, help='The path to the output image. For single image inference only.')
parser.add_argument('--output_folder', type=str, required=False, help='The path to the output folder. For multiple image inference.')

args = parser.parse_args()
opt = parse(args.opt, is_train=is_train)
Expand Down Expand Up @@ -68,6 +69,11 @@ def parse_options(is_train=True):
'input_img': args.input_path,
'output_img': args.output_path
}
elif args.input_folder is not None and args.output_folder is not None:
opt['img_path'] = {
'input_folder': args.input_folder,
'output_folder': args.output_folder
}

return opt

Expand Down
4 changes: 2 additions & 2 deletions basicsr/version.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# GENERATED VERSION FILE
# TIME: Mon Apr 18 21:35:20 2022
__version__ = '1.2.0+386ca20'
# TIME: Mon Jun 17 23:20:12 2024
__version__ = '1.2.0+2b4af71'
short_version = '1.2.0'
version_info = (1, 2, 0)
Binary file modified demo/denoise_img.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.