Skip to content

Commit

Permalink
feat: port ClassificationResult and Landmark
Browse files Browse the repository at this point in the history
  • Loading branch information
homuler committed Aug 12, 2023
1 parent 915f22d commit fbecbba
Show file tree
Hide file tree
Showing 15 changed files with 441 additions and 16 deletions.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// Copyright (c) 2023 homuler
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

using System.Collections.Generic;

namespace Mediapipe.Tasks.Components.Containers
{
/// <summary>
/// Defines classification results for a given classifier head.
/// </summary>
public readonly struct Classifications
{
/// <summary>
/// The array of predicted categories, usually sorted by descending scores,
/// e.g. from high to low probability.
/// </summary>
public readonly IReadOnlyList<Category> categories;
/// <summary>
/// The index of the classifier head (i.e. output tensor) these categories
/// refer to. This is useful for multi-head models.
/// </summary>
public readonly int headIndex;
/// <summary>
/// The optional name of the classifier head, as provided in the TFLite Model
/// Metadata [1] if present. This is useful for multi-head models.
///
/// [1]: https://www.tensorflow.org/lite/convert/metadata
/// </summary>
public readonly string headName;

internal Classifications(IReadOnlyList<Category> categories, int headIndex, string headName)
{
this.categories = categories;
this.headIndex = headIndex;
this.headName = headName;
}

public static Classifications CreateFrom(Proto.Classifications proto)
{
var categories = new List<Category>(proto.ClassificationList.Classification.Count);
foreach (var classification in proto.ClassificationList.Classification)
{
categories.Add(Category.CreateFrom(classification));
}
return new Classifications(categories, proto.HeadIndex, proto.HasHeadName ? proto.HeadName : null);
}

public static Classifications CreateFrom(ClassificationList proto, int headIndex = 0, string headName = null)
{
var categories = new List<Category>(proto.Classification.Count);
foreach (var classification in proto.Classification)
{
categories.Add(Category.CreateFrom(classification));
}
return new Classifications(categories, headIndex, headName);
}

public override string ToString()
=> $"{{ \"categories\": {Util.Format(categories)}, \"headIndex\": {headIndex}, \"headName\": {Util.Format(headName)} }}";
}

/// <summary>
/// Defines classification results of a model.
/// </summary>
public readonly struct ClassificationResult
{
/// <summary>
/// The classification results for each head of the model.
/// </summary>
public readonly IReadOnlyList<Classifications> classifications;

/// <summary>
/// The optional timestamp (in milliseconds) of the start of the chunk of data
/// corresponding to these results.
///
/// This is only used for classification on time series (e.g. audio
/// classification). In these use cases, the amount of data to process might
/// exceed the maximum size that the model can process: to solve this, the
/// input data is split into multiple chunks starting at different timestamps.
/// </summary>
public readonly long? timestampMs;

internal ClassificationResult(IReadOnlyList<Classifications> classifications, long? timestampMs)
{
this.classifications = classifications;
this.timestampMs = timestampMs;
}

public static ClassificationResult CreateFrom(Proto.ClassificationResult proto)
{
var classifications = new List<Classifications>(proto.Classifications.Count);
foreach (var classification in proto.Classifications)
{
classifications.Add(Classifications.CreateFrom(classification));
}
return new ClassificationResult(classifications, proto.HasTimestampMs ? proto.TimestampMs : null);
}

public override string ToString() => $"{{ \"classifications\": {Util.Format(classifications)}, \"timestampMs\": {Util.Format(timestampMs)} }}";
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
// https://opensource.org/licenses/MIT.

using System.Collections.Generic;
using System.Linq;

namespace Mediapipe.Tasks.Components.Containers
{
Expand Down Expand Up @@ -81,11 +80,7 @@ public static Detection CreateFrom(Mediapipe.Detection proto)
}

public override string ToString()
{
var categoriesStr = $"[{string.Join(", ", categories.Select(category => category.ToString()))}]";
var keypointsStr = keypoints == null ? "null" : $"[{string.Join(", ", keypoints.Select(keypoint => keypoint.ToString()))}]";
return $"{{\"categories\": {categoriesStr}, \"boundingBox\": {boundingBox}, \"keypoints\": {keypointsStr}}}";
}
=> $"{{ \"categories\": {Util.Format(categories)}, \"boundingBox\": {boundingBox}, \"keypoints\": {Util.Format(keypoints)} }}";
}

/// <summary>
Expand Down Expand Up @@ -113,10 +108,6 @@ public static DetectionResult CreateFrom(IReadOnlyList<Mediapipe.Detection> dete
return new DetectionResult(detections);
}

public override string ToString()
{
var detectionsStr = string.Join(", ", detections.Select(detection => detection.ToString()));
return $"{{ \"detections\": [{detectionsStr}] }}";
}
public override string ToString() => $"{{ \"detections\": {Util.Format(detections)} }}";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,6 @@ internal NormalizedKeypoint(float x, float y, string label, float? score)
this.score = score;
}

public override string ToString()
{
var scoreStr = score == null ? "null" : $"{score}";
return $"{{ \"x\": {x}, \"y\": {y}, \"label\": \"{label}\", \"score\": {scoreStr} }}";
}
public override string ToString() => $"{{ \"x\": {x}, \"y\": {y}, \"label\": \"{label}\", \"score\": {Util.Format(score)} }}";
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
// Copyright (c) 2023 homuler
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

using System;
using System.Collections.Generic;

// TODO: use System.MathF
using Mathf = UnityEngine.Mathf;

namespace Mediapipe.Tasks.Components.Containers
{
/// <summary>
/// Landmark represents a point in 3D space with x, y, z coordinates. The
/// landmark coordinates are in meters. z represents the landmark depth, and the
/// smaller the value the closer the world landmark is to the camera.
/// </summary>
public readonly struct Landmark : IEquatable<Landmark>
{
private const float _LandmarkTolerance = 1e-6f;

public readonly float x;
public readonly float y;
public readonly float z;
/// <summary>
/// Landmark visibility. Should stay unset if not supported.
/// Float score of whether landmark is visible or occluded by other objects.
/// Landmark considered as invisible also if it is not present on the screen
/// (out of scene bounds). Depending on the model, visibility value is either a
/// sigmoid or an argument of sigmoid.
/// </summary>
public readonly float? visibility;
/// <summary>
/// Landmark presence. Should stay unset if not supported.
/// Float score of whether landmark is present on the scene (located within
/// scene bounds). Depending on the model, presence value is either a result of
/// sigmoid or an argument of sigmoid function to get landmark presence
/// probability.
/// </summary>
public readonly float? presence;
/// <summary>
/// Landmark name. Should stay unset if not supported.
/// </summary>
public readonly string name;

internal Landmark(float x, float y, float z, float? visibility, float? presence) : this(x, y, z, visibility, presence, null)
{
}

internal Landmark(float x, float y, float z, float? visibility, float? presence, string name)
{
this.x = x;
this.y = y;
this.z = z;
this.visibility = visibility;
this.presence = presence;
this.name = name;
}

#nullable enable
public override bool Equals(object? obj) => obj is Landmark other && Equals(other);
#nullable disable

bool IEquatable<Landmark>.Equals(Landmark other)
{
return Mathf.Abs(x - other.x) < _LandmarkTolerance &&
Mathf.Abs(y - other.y) < _LandmarkTolerance &&
Mathf.Abs(z - other.z) < _LandmarkTolerance;
}

// TODO: use HashCode.Combine
public override int GetHashCode() => Tuple.Create(x, y, z).GetHashCode();
public static bool operator ==(in Landmark lhs, in Landmark rhs) => lhs.Equals(rhs);
public static bool operator !=(in Landmark lhs, in Landmark rhs) => !(lhs == rhs);

public static Landmark CreateFrom(Mediapipe.Landmark proto)
{
return new Landmark(
proto.X, proto.Y, proto.Z,
proto.HasVisibility ? proto.Visibility : null,
proto.HasPresence ? proto.Presence : null);
}

public override string ToString()
=> $"{{ \"x\": {x}, \"y\": {y}, \"z\": {z}, \"visibility\": {Util.Format(visibility)}, \"presence\": {Util.Format(presence)}, \"name\": \"{name}\" }}";
}

/// <summary>
/// A normalized version of above Landmark struct. All coordinates should be
/// within [0, 1].
/// </summary>
public readonly struct NormalizedLandmark : IEquatable<NormalizedLandmark>
{
private const float _LandmarkTolerance = 1e-6f;

public readonly float x;
public readonly float y;
public readonly float z;
public readonly float? visibility;
public readonly float? presence;
public readonly string name;

internal NormalizedLandmark(float x, float y, float z, float? visibility, float? presence) : this(x, y, z, visibility, presence, null)
{
}

internal NormalizedLandmark(float x, float y, float z, float? visibility, float? presence, string name)
{
this.x = x;
this.y = y;
this.z = z;
this.visibility = visibility;
this.presence = presence;
this.name = name;
}

#nullable enable
public override bool Equals(object? obj) => obj is NormalizedLandmark other && Equals(other);
#nullable disable

bool IEquatable<NormalizedLandmark>.Equals(NormalizedLandmark other)
{
return Mathf.Abs(x - other.x) < _LandmarkTolerance &&
Mathf.Abs(y - other.y) < _LandmarkTolerance &&
Mathf.Abs(z - other.z) < _LandmarkTolerance;
}

// TODO: use HashCode.Combine
public override int GetHashCode() => Tuple.Create(x, y, z).GetHashCode();
public static bool operator ==(in NormalizedLandmark lhs, in NormalizedLandmark rhs) => lhs.Equals(rhs);
public static bool operator !=(in NormalizedLandmark lhs, in NormalizedLandmark rhs) => !(lhs == rhs);

public static NormalizedLandmark CreateFrom(Mediapipe.NormalizedLandmark proto)
{
return new NormalizedLandmark(
proto.X, proto.Y, proto.Z,
proto.HasVisibility ? proto.Visibility : null,
proto.HasPresence ? proto.Presence : null);
}

public override string ToString()
=> $"{{ \"x\": {x}, \"y\": {y}, \"z\": {z}, \"visibility\": {Util.Format(visibility)}, \"presence\": {Util.Format(presence)}, \"name\": \"{name}\" }}";
}

/// <summary>
/// A list of Landmarks.
/// </summary>
public readonly struct Landmarks
{
public readonly IReadOnlyList<Landmark> landmarks;

internal Landmarks(IReadOnlyList<Landmark> landmarks)
{
this.landmarks = landmarks;
}

public static Landmarks CreateFrom(LandmarkList proto)
{
var landmarks = new List<Landmark>(proto.Landmark.Count);
foreach (var landmark in proto.Landmark)
{
landmarks.Add(Landmark.CreateFrom(landmark));
}
return new Landmarks(landmarks);
}

public override string ToString() => $"{{ \"landmarks\": {Util.Format(landmarks)} }}";
}

/// <summary>
/// A list of NormalizedLandmarks.
/// </summary>
public readonly struct NormalizedLandmarks
{
public readonly IReadOnlyList<NormalizedLandmark> landmarks;

internal NormalizedLandmarks(IReadOnlyList<NormalizedLandmark> landmarks)
{
this.landmarks = landmarks;
}

public static NormalizedLandmarks CreateFrom(NormalizedLandmarkList proto)
{
var landmarks = new List<NormalizedLandmark>(proto.Landmark.Count);
foreach (var landmark in proto.Landmark)
{
landmarks.Add(NormalizedLandmark.CreateFrom(landmark));
}
return new NormalizedLandmarks(landmarks);
}

public override string ToString() => $"{{ \"landmarks\": {Util.Format(landmarks)} }}";
}
}
Loading

0 comments on commit fbecbba

Please sign in to comment.