Skip to content

Commit

Permalink
Fixed ONNX support
Browse files Browse the repository at this point in the history
  • Loading branch information
chiefMarlin committed Aug 30, 2023
1 parent f23b688 commit ef5ad08
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 121 deletions.
144 changes: 76 additions & 68 deletions firescrew.go
Original file line number Diff line number Diff line change
Expand Up @@ -479,10 +479,16 @@ func getStreamInfo(rtspURL string) (StreamInfo, error) {

func CheckFFmpegAndFFprobe() (bool, error) {
if _, err := exec.LookPath("ffmpeg"); err != nil {
// Print PATH
path := os.Getenv("PATH")
Log("error", fmt.Sprintf("PATH: %s", path))
return false, fmt.Errorf("ffmpeg binary not found: %w", err)
}

if _, err := exec.LookPath("ffprobe"); err != nil {
// Print PATH
path := os.Getenv("PATH")
Log("error", fmt.Sprintf("PATH: %s", path))
return false, fmt.Errorf("ffprobe binary not found: %w", err)
}

Expand Down Expand Up @@ -659,7 +665,7 @@ func recodeToMP4(inputFile string) (string, error) {
var cmd *exec.Cmd
// Create the FFmpeg command
if globalConfig.Video.OnlyRemuxMp4 {
cmd = exec.Command("ffmpeg", "-i", inputFile, "-c", "copy", outputFile)
cmd = exec.Command("ffmpeg", "-i", inputFile, "-c", "copy", "-hls_segment_type", "fmp4", outputFile)
} else {
cmd = exec.Command("ffmpeg", "-i", inputFile, "-c:v", "libx264", "-c:a", "aac", outputFile)
}
Expand Down Expand Up @@ -706,7 +712,11 @@ func main() {
fmt.Fprintf(os.Stderr, ("Usage: firescrew -s [path] [addr]\n"))
return
}
firescrewServe.Serve(os.Args[2], os.Args[3])
err := firescrewServe.Serve(os.Args[2], os.Args[3])
if err != nil {
fmt.Fprintf(os.Stderr, "Error starting server: %v\n", err)
return
}
case "-v", "--version", "v":
// Print version
fmt.Println(Version)
Expand All @@ -733,7 +743,7 @@ func main() {
// Check if ffmpeg/ffprobe binaries are available
_, err := CheckFFmpegAndFFprobe()
if err != nil {
Log("error", "Unable to find ffmpeg/ffprobe binaries. Please install them")
Log("error", fmt.Sprintf("Unable to find ffmpeg/ffprobe binaries. Please install them: %s", err))
os.Exit(2)
}

Expand Down Expand Up @@ -997,80 +1007,78 @@ func main() {
runtimeConfig.MotionMutex.Unlock()
}

if globalConfig.Motion.NetworkObjectDetectServer != "" {
// Python motion detection
if predictFrameCounter%everyNthFrame == 0 {
if predictFrameCounter > 10000 {
predictFrameCounter = 0
}
// Only run this on every 5th frame
if msg.Frame != nil {

var predict []Prediction
var err error
// If globalConfig.Motion.OnnxModel is blank run this
// Send data to objectPredict
if globalConfig.Motion.OnnxModel == "" {
predict, err = objectPredict(msg.Frame)
if err != nil {
Log("error", fmt.Sprintf("Error running objectPredict: %v", err))
return
}
} else {
timer := time.Now()
objects, err := runtimeConfig.ObjectPredictClient.Predict(msg.Frame)
if err != nil {
fmt.Println("Cannot predict:", err)
return
}
// Python motion detection
if predictFrameCounter%everyNthFrame == 0 {
if predictFrameCounter > 10000 {
predictFrameCounter = 0
}
// Only run this on every 5th frame
if msg.Frame != nil {

var predict []Prediction
var err error
// If globalConfig.Motion.OnnxModel is blank run this
// Send data to objectPredict
if globalConfig.Motion.OnnxModel == "" {
predict, err = objectPredict(msg.Frame)
if err != nil {
Log("error", fmt.Sprintf("Error running objectPredict: %v", err))
return
}
} else {
timer := time.Now()
objects, err := runtimeConfig.ObjectPredictClient.Predict(msg.Frame)
if err != nil {
fmt.Println("Cannot predict:", err)
return
}

// Detect took
took := time.Since(timer).Milliseconds()

for _, object := range objects {
pred := Prediction{
Object: object.ClassID,
ClassName: object.ClassName,
Box: []float32{object.X1, object.Y1, object.X2, object.Y2},
Top: int(object.Y1),
Bottom: int(object.Y2),
Left: int(object.X1),
Right: int(object.X2),
Confidence: object.Confidence,
Took: float64(took),
}
predict = append(predict, pred)
// Detect took
took := time.Since(timer).Milliseconds()

for _, object := range objects {
pred := Prediction{
Object: object.ClassID,
ClassName: object.ClassName,
Box: []float32{object.X1, object.Y1, object.X2, object.Y2},
Top: int(object.Y1),
Bottom: int(object.Y2),
Left: int(object.X1),
Right: int(object.X2),
Confidence: object.Confidence,
Took: float64(took),
}

predict = append(predict, pred)
}
calcInferenceStats(predict) // Calculate inference stats

if len(predict) > 0 {
// Notify in realtime about detected objects
type Event struct {
Type string `json:"type"`
Timestamp time.Time `json:"timestamp"`
PredictedObjects []Prediction `json:"predicted_objects"`
}

eventRaw := Event{
Type: "objects_predicted",
Timestamp: time.Now(),
PredictedObjects: predict,
}
eventJson, err := json.Marshal(eventRaw)
if err != nil {
Log("error", fmt.Sprintf("Error marshalling object_predicted event: %v", err))
return
}
eventHandler("objects_detected", eventJson)
}
calcInferenceStats(predict) // Calculate inference stats

if len(predict) > 0 {
// Notify in realtime about detected objects
type Event struct {
Type string `json:"type"`
Timestamp time.Time `json:"timestamp"`
PredictedObjects []Prediction `json:"predicted_objects"`
}

performDetectionOnObject(rgba, predict)
eventRaw := Event{
Type: "objects_predicted",
Timestamp: time.Now(),
PredictedObjects: predict,
}
eventJson, err := json.Marshal(eventRaw)
if err != nil {
Log("error", fmt.Sprintf("Error marshalling object_predicted event: %v", err))
return
}
eventHandler("objects_detected", eventJson)
}

performDetectionOnObject(rgba, predict)
}
predictFrameCounter++
}
predictFrameCounter++

if globalConfig.EnableOutputStream {
streamImage(rgba, stream) // Stream the image to the web
Expand Down
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ module github.com/8ff/firescrew
go 1.21.0

require (
github.com/8ff/gonnx v0.0.0-20230829153731-be8bf833e043
github.com/8ff/prettyTimer v0.0.0-20230829162136-99737bc17c1e
github.com/8ff/onnxruntime_go v0.0.0-20230830191505-2b14a218432e
github.com/8ff/prettyTimer v0.0.0-20230830184900-c96793faf613
github.com/8ff/tuna v0.0.0-20230811173825-52af88c52674
github.com/asticode/go-astits v1.13.0
github.com/bluenviron/gortsplib/v3 v3.10.0
Expand Down
8 changes: 4 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
github.com/8ff/gonnx v0.0.0-20230829153731-be8bf833e043 h1:vMooYcj/8AQmjoeWufnrX67GpHJadI/WX4fFjDqmQZc=
github.com/8ff/gonnx v0.0.0-20230829153731-be8bf833e043/go.mod h1:OvWCkxwg/6UPaf4WtAlfomCkjxoJKomjHwX4+qvr92k=
github.com/8ff/prettyTimer v0.0.0-20230829162136-99737bc17c1e h1:c4wqswZSJIuF35aETvcSIiRTWNb0/nKQWSi+dr5gR6Y=
github.com/8ff/prettyTimer v0.0.0-20230829162136-99737bc17c1e/go.mod h1:iQAVuoCXBrrxT875kd25GCALLf+ulTOt/mCikuQs2j8=
github.com/8ff/onnxruntime_go v0.0.0-20230830191505-2b14a218432e h1:471ee5xiyPAM4P5bUS03t462ZZk0tAKVVzJ30DBzEpA=
github.com/8ff/onnxruntime_go v0.0.0-20230830191505-2b14a218432e/go.mod h1:0OvBfqX5NGJJIMwZhLtZ9udqfyBpcXObhCbzWn4Iq6M=
github.com/8ff/prettyTimer v0.0.0-20230830184900-c96793faf613 h1:mIPSzE+OciNlYwNQs1qi7GoKRI3SKGKrVsGnap20iqQ=
github.com/8ff/prettyTimer v0.0.0-20230830184900-c96793faf613/go.mod h1:iQAVuoCXBrrxT875kd25GCALLf+ulTOt/mCikuQs2j8=
github.com/8ff/tuna v0.0.0-20230811173825-52af88c52674 h1:9L0K8szFUXJ0V71/I5YJeCmOXVhgU0+v0+9nf8mHqG0=
github.com/8ff/tuna v0.0.0-20230811173825-52af88c52674/go.mod h1:brULTDkAKe2Ut39W20RPVcc0M6MhaHLTS/oGJiC5tVs=
github.com/asticode/go-astikit v0.30.0 h1:DkBkRQRIxYcknlaU7W7ksNfn4gMFsB0tqMJflxkRsZA=
Expand Down
124 changes: 77 additions & 47 deletions pkg/objectPredict/objectPredict.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
"sort"
"sync"

"github.com/8ff/gonnx"
onnx "github.com/8ff/onnxruntime_go"
"github.com/nfnt/resize"
)

Expand All @@ -43,11 +43,17 @@ type Client struct {
ModelHeight int
LibPath string
LibExtractPath string
Session *gonnx.SessionV3
RuntimeSession ModelSession
EnableCuda bool
EnableCoreMl bool
}

type ModelSession struct {
Session *onnx.AdvancedSession
Input *onnx.Tensor[float32]
Output *onnx.Tensor[float32]
}

type Object struct {
ClassName string
ClassID int
Expand Down Expand Up @@ -144,78 +150,86 @@ func Init(opt Config) (*Client, error) {
if err != nil {
return &Client{}, err
}
client.Session = ses
client.RuntimeSession = ses
return &client, nil
}

func (c *Client) Predict(imgRaw image.Image) ([]Object, error) {
input, img_width, img_height := c.prepareInput(imgRaw)
inputShape := gonnx.NewShape(1, 3, 640, 640)
inputTensor, err := gonnx.NewTensor(inputShape, input)
if err != nil {
return nil, err
}

output, e := c.Session.Run([]*gonnx.TensorWithType{{
Tensor: inputTensor,
TensorType: "float32",
}})
if e != nil {
return nil, fmt.Errorf("error running session: %w", e)
}

var allFloat32Data []float32
inputTensor := c.RuntimeSession.Input.GetData()

for i := range output {
data := output[i].GetData()
float32Data, ok := data.([]float32)
if !ok {
continue
}
allFloat32Data = append(allFloat32Data, float32Data...)
// inTensor := modelSes.Input.GetData()
copy(inputTensor, input)
err := c.RuntimeSession.Session.Run()
if err != nil {
return nil, fmt.Errorf("error running session: %w", err)
}

objects := processOutput(allFloat32Data, img_width, img_height)

objects := processOutput(c.RuntimeSession.Output.GetData(), img_width, img_height)
return objects, nil
}

func (c *Client) initSession() (*gonnx.SessionV3, error) {
func (c *Client) initSession() (ModelSession, error) {
// Change dir to libExtractPath and then change back
cwd, err := os.Getwd()
if err != nil {
return nil, err
return ModelSession{}, err
}
err = os.Chdir(c.LibExtractPath) // Change dir to libExtractPath
if err != nil {
return nil, err
return ModelSession{}, err
}

gonnx.SetSharedLibraryPath(c.LibPath) // Set libPath
err = gonnx.InitializeEnvironment()
onnx.SetSharedLibraryPath(c.LibPath) // Set libPath
err = onnx.InitializeEnvironment()
if err != nil {
return nil, err
return ModelSession{}, err
}
err = os.Chdir(cwd) // Change back to cwd
if err != nil {
return nil, err
return ModelSession{}, err
}

var opts string
switch {
case c.EnableCuda:
opts = "cuda"
case c.EnableCoreMl:
opts = "coreml"
default:
opts = ""
options, e := onnx.NewSessionOptions()
if e != nil {
return ModelSession{}, fmt.Errorf("error creating session options: %w", e)
}
defer options.Destroy()

session, e := gonnx.NewSessionV3(c.ModelPath, opts)
if e != nil {
return nil, fmt.Errorf("error creating session: %w", e)
if c.EnableCoreMl { // If CoreML is enabled, append the CoreML execution provider
e = options.AppendExecutionProviderCoreML(0)
if e != nil {
options.Destroy()
return ModelSession{}, err
}
defer options.Destroy()
}

// Create and prepare a blank image
blankImage := CreateBlankImage(c.ModelWidth, c.ModelHeight)
input, _, _ := c.prepareInput(blankImage)

inputShape := onnx.NewShape(1, 3, 640, 640)
inputTensor, err := onnx.NewTensor(inputShape, input)
if err != nil {
return ModelSession{}, fmt.Errorf("error creating input tensor: %w", err)
}

outputShape := onnx.NewShape(1, 84, 8400)
outputTensor, err := onnx.NewEmptyTensor[float32](outputShape)
if err != nil {
return ModelSession{}, fmt.Errorf("error creating output tensor: %w", err)
}
return session, nil

session, err := onnx.NewAdvancedSession(c.ModelPath,
[]string{"images"}, []string{"output0"},
[]onnx.ArbitraryTensor{inputTensor}, []onnx.ArbitraryTensor{outputTensor}, options)

return ModelSession{
Session: session,
Input: inputTensor,
Output: outputTensor,
}, nil
}

func (c *Client) prepareInput(imageObj image.Image) ([]float32, int64, int64) {
Expand Down Expand Up @@ -543,5 +557,21 @@ func (c *Client) Close() {
}
}

c.Session.Destroy() // Cleanup session
c.RuntimeSession.Session.Destroy() // Cleanup session
c.RuntimeSession.Input.Destroy() // Cleanup input
c.RuntimeSession.Output.Destroy() // Cleanup output
}

func CreateBlankImage(width, height int) image.Image {
// Create a new blank image with the given dimensions
img := image.NewRGBA(image.Rect(0, 0, width, height))

// Fill the image with the background color
for y := 0; y < height; y++ {
for x := 0; x < width; x++ {
img.Set(x, y, color.RGBA{0, 0, 0, 0})
}
}

return img
}

0 comments on commit ef5ad08

Please sign in to comment.