Skip to content

Commit

Permalink
fix bug (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
GreatV authored Aug 27, 2024
1 parent 0080f9a commit 711d915
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 40 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "labelme2yolo"
version = "0.2.4"
version = "0.2.5"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand Down
67 changes: 29 additions & 38 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
use std::collections::HashMap;
use std::fs::{self, copy, File};
use std::io::{BufWriter, Write};
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::{
atomic::{AtomicUsize, Ordering::Relaxed},
Arc, Mutex,
};

use clap::{Parser, ValueEnum};
use env_logger;
use glob::glob;
Expand All @@ -9,15 +19,6 @@ use rand::SeedableRng;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use serde_json;
use std::collections::HashMap;
use std::fs::{self, copy, File};
use std::io::{BufWriter, Write};
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc, Mutex,
};

#[derive(Debug, Serialize, Deserialize, Clone)]
struct Shape {
Expand Down Expand Up @@ -46,45 +47,33 @@ struct ImageAnnotation {
#[command(version, about = "Convert LabelMe JSON to YOLO format", long_about = None)]
struct Args {
// Directory containing LabelMe JSON files
#[arg(
short = 'd',
long = "json_dir",
help = "Directory containing LabelMe JSON files"
)]
#[arg(short = 'd', long = "json_dir")]
json_dir: String,

// Proportion of the dataset to use for validation
#[arg(long = "val_size", default_value_t = 0.2, value_parser = validate_size, help = "Proportion of the dataset to use for validation (between 0.0 and 1.0)")]
#[arg(long = "val_size", default_value_t = 0.2, value_parser = validate_size)]
val_size: f32,

// Proportion of the dataset to use for testing
#[arg(long = "test_size", default_value_t = 0.0, value_parser = validate_size, help = "Proportion of the dataset to use for testing (between 0.0 and 1.0)")]
#[arg(long = "test_size", default_value_t = 0.0, value_parser = validate_size)]
test_size: f32,

// Output format (bbox or polygon) for YOLO annotations
// Output format for YOLO annotations: 'bbox' or 'polygon'
#[arg(
long = "output_format",
visible_alias = "format",
value_enum,
default_value = "bbox",
help = "Output format for YOLO annotations: 'bbox' or 'polygon'"
default_value = "bbox"
)]
output_format: Format,

// List of labels in the dataset
#[arg(
use_value_delimiter = true,
help = "Comma-separated list of labels in the dataset"
)]
label_list: Vec<String>,

// Seed for random shuffling
#[arg(
long = "seed",
default_value_t = 42,
help = "Seed for random shuffling"
)]
#[arg(long = "seed", default_value_t = 42)]
seed: u64,

// List of labels in the dataset
#[arg(use_value_delimiter = true)]
label_list: Vec<String>,
}

// Enumeration for the YOLO output format
Expand Down Expand Up @@ -159,9 +148,10 @@ fn create_output_directory(path: &Path) -> std::io::Result<PathBuf> {
"Directory {:?} already exists. Deleting and recreating it.",
path
);
fs::remove_dir_all(path)?;
fs::remove_dir_all(path).and_then(|_| fs::create_dir_all(path))?;
} else {
fs::create_dir_all(path)?;
}
fs::create_dir_all(path)?;
Ok(path.to_path_buf())
}

Expand Down Expand Up @@ -264,7 +254,7 @@ fn initialize_label_map(
for (id, label) in args.label_list.iter().enumerate() {
map.insert(label.clone(), id);
}
next_class_id.store(args.label_list.len(), Ordering::Relaxed);
next_class_id.store(args.label_list.len(), Relaxed);
} else {
// Otherwise, use labels found in the dataset
split_data
Expand All @@ -275,7 +265,7 @@ fn initialize_label_map(
.flat_map(|(_, annotation)| annotation.shapes.iter())
.for_each(|shape| {
if !map.contains_key(&shape.label) {
let new_id = next_class_id.fetch_add(1, Ordering::Relaxed);
let new_id = next_class_id.fetch_add(1, Relaxed);
map.insert(shape.label.clone(), new_id);
}
});
Expand Down Expand Up @@ -392,7 +382,7 @@ fn process_annotation(

let yolo_data = convert_to_yolo_format(annotation, args, label_map);

let sanitized_name = sanitize_filename::sanitize(path.file_stem().unwrap().to_str().unwrap());
let sanitized_name = sanitize_filename::sanitize(path.file_name().unwrap().to_str().unwrap());
let output_path = labels_dir.join(&sanitized_name).with_extension("txt");

let file = File::create(&output_path)?;
Expand Down Expand Up @@ -474,11 +464,12 @@ fn process_polygon_shape(yolo_data: &mut String, annotation: &ImageAnnotation, s
yolo_data.push_str(&format!(" {:.6} {:.6}", x_norm, y_norm));
}
} else if shape.shape_type == "circle" {
const CIRCLE_POINTS: usize = 12;
let (cx, cy) = shape.points[0];
let (px, py) = shape.points[1];
let radius = ((cx - px).powi(2) + (cy - py).powi(2)).sqrt();
for i in 0..12 {
let angle = 2.0 * std::f64::consts::PI * i as f64 / 12.0;
for i in 0..CIRCLE_POINTS {
let angle = 2.0 * std::f64::consts::PI * i as f64 / CIRCLE_POINTS as f64;
let x = cx + radius * angle.cos();
let y = cy + radius * angle.sin();
let x_norm = x / annotation.image_width as f64;
Expand Down

0 comments on commit 711d915

Please sign in to comment.