From 3316e352c8f537c5d7b374e8ea5e00ded6fe853b Mon Sep 17 00:00:00 2001 From: PSchroedl Date: Wed, 4 Sep 2024 03:28:02 -0700 Subject: [PATCH] feat(ai): add segment anything 2 pipeline image version(#3131) This commit adds support for the new [segment anything 2](https://ai.meta.com/sam2/) pipeline (SAM2) that was added to the AI-worker in [this pull request](https://github.com/livepeer/ai-worker/pull/185). While the new SAM pipeline can also do video segmentation this will be done in a subsequent pull request. Co-authored-by: John | Elite Encoder Co-authored-by: Peter Schroedl Co-authored-by: Rick Staa --- cmd/livepeer/starter/starter.go | 19 +++++- core/ai.go | 1 + core/capabilities.go | 3 + core/orchestrator.go | 8 +++ go.mod | 16 ++--- go.sum | 32 +++++----- server/ai_http.go | 45 +++++++++++++ server/ai_mediaserver.go | 54 ++++++++++++++++ server/ai_process.go | 108 ++++++++++++++++++++++++++++++++ server/rpc.go | 1 + server/rpc_test.go | 6 ++ 11 files changed, 267 insertions(+), 26 deletions(-) diff --git a/cmd/livepeer/starter/starter.go b/cmd/livepeer/starter/starter.go index 5a8bb5bb27..00c23ec3db 100755 --- a/cmd/livepeer/starter/starter.go +++ b/cmd/livepeer/starter/starter.go @@ -1227,10 +1227,11 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) { } } - // If the config contains a URL we call Warm() anyway because AIWorker will just register - // the endpoint for an external container if config.Warm || config.URL != "" { + // Register external container endpoint if URL is provided. endpoint := worker.RunnerEndpoint{URL: config.URL, Token: config.Token} + + // Warm the AI worker container or register the endpoint. if err := n.AIWorker.Warm(ctx, config.Pipeline, config.ModelID, endpoint, config.OptimizationFlags); err != nil { glog.Errorf("Error AI worker warming %v container: %v", config.Pipeline, err) return @@ -1313,6 +1314,20 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) { if *cfg.Network != "offchain" { n.SetBasePriceForCap("default", core.Capability_AudioToText, config.ModelID, autoPrice) } + case "segment-anything-2": + _, ok := capabilityConstraints[core.Capability_SegmentAnything2] + if !ok { + aiCaps = append(aiCaps, core.Capability_SegmentAnything2) + capabilityConstraints[core.Capability_SegmentAnything2] = &core.CapabilityConstraints{ + Models: make(map[string]*core.ModelConstraint), + } + } + + capabilityConstraints[core.Capability_SegmentAnything2].Models[config.ModelID] = modelConstraint + + if *cfg.Network != "offchain" { + n.SetBasePriceForCap("default", core.Capability_SegmentAnything2, config.ModelID, autoPrice) + } } if len(aiCaps) > 0 { diff --git a/core/ai.go b/core/ai.go index e36260fa6b..887e1f8c8c 100644 --- a/core/ai.go +++ b/core/ai.go @@ -22,6 +22,7 @@ type AI interface { ImageToVideo(context.Context, worker.ImageToVideoMultipartRequestBody) (*worker.VideoResponse, error) Upscale(context.Context, worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error) AudioToText(context.Context, worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error) + SegmentAnything2(context.Context, worker.SegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error Stop(context.Context) error HasCapacity(pipeline, modelID string) bool diff --git a/core/capabilities.go b/core/capabilities.go index b02e137801..734404ca83 100644 --- a/core/capabilities.go +++ b/core/capabilities.go @@ -78,6 +78,7 @@ const ( Capability_ImageToVideo Capability_Upscale Capability_AudioToText + Capability_SegmentAnything2 ) var CapabilityNameLookup = map[Capability]string{ @@ -114,6 +115,7 @@ var CapabilityNameLookup = map[Capability]string{ Capability_ImageToVideo: "Image to video", Capability_Upscale: "Upscale", Capability_AudioToText: "Audio to text", + Capability_SegmentAnything2: "Segment Anything 2", } var CapabilityTestLookup = map[Capability]CapabilityTest{ @@ -204,6 +206,7 @@ func OptionalCapabilities() []Capability { Capability_ImageToVideo, Capability_Upscale, Capability_AudioToText, + Capability_SegmentAnything2, } } diff --git a/core/orchestrator.go b/core/orchestrator.go index 51fb2206e4..ba9f470bd1 100644 --- a/core/orchestrator.go +++ b/core/orchestrator.go @@ -130,6 +130,10 @@ func (orch *orchestrator) AudioToText(ctx context.Context, req worker.AudioToTex return orch.node.AudioToText(ctx, req) } +func (orch *orchestrator) SegmentAnything2(ctx context.Context, req worker.SegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) { + return orch.node.SegmentAnything2(ctx, req) +} + func (orch *orchestrator) ProcessPayment(ctx context.Context, payment net.Payment, manifestID ManifestID) error { if orch.node == nil || orch.node.Recipient == nil { return nil @@ -963,6 +967,10 @@ func (n *LivepeerNode) AudioToText(ctx context.Context, req worker.AudioToTextMu return n.AIWorker.AudioToText(ctx, req) } +func (n *LivepeerNode) SegmentAnything2(ctx context.Context, req worker.SegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) { + return n.AIWorker.SegmentAnything2(ctx, req) +} + func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) { // We might support generating more than one video in the future (i.e. multiple input images/prompts) numVideos := 1 diff --git a/go.mod b/go.mod index bfbe2c4941..b0fd01bd98 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ require ( github.com/golang/protobuf v1.5.4 github.com/jaypipes/ghw v0.10.0 github.com/jaypipes/pcidb v1.0.0 - github.com/livepeer/ai-worker v0.1.2 + github.com/livepeer/ai-worker v0.2.0 github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b github.com/livepeer/livepeer-data v0.7.5-0.20231004073737-06f1f383fb18 github.com/livepeer/lpms v0.0.0-20240819180416-f87352959b85 @@ -31,7 +31,7 @@ require ( github.com/urfave/cli v1.22.12 go.opencensus.io v0.24.0 go.uber.org/goleak v1.3.0 - golang.org/x/net v0.25.0 + golang.org/x/net v0.28.0 google.golang.org/grpc v1.65.0 google.golang.org/protobuf v1.34.1 pgregory.net/rapid v1.1.0 @@ -215,15 +215,15 @@ require ( go.uber.org/atomic v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.24.0 // indirect - golang.org/x/crypto v0.23.0 // indirect + golang.org/x/crypto v0.26.0 // indirect golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect - golang.org/x/mod v0.17.0 // indirect + golang.org/x/mod v0.20.0 // indirect golang.org/x/oauth2 v0.20.0 // indirect - golang.org/x/sync v0.7.0 // indirect - golang.org/x/sys v0.20.0 // indirect - golang.org/x/text v0.15.0 // indirect + golang.org/x/sync v0.8.0 // indirect + golang.org/x/sys v0.23.0 // indirect + golang.org/x/text v0.17.0 // indirect golang.org/x/time v0.5.0 // indirect - golang.org/x/tools v0.21.0 // indirect + golang.org/x/tools v0.24.0 // indirect golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect google.golang.org/api v0.125.0 // indirect google.golang.org/appengine v1.6.7 // indirect diff --git a/go.sum b/go.sum index 6bcdec2fed..3299db083c 100644 --- a/go.sum +++ b/go.sum @@ -623,8 +623,8 @@ github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4n github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI= github.com/libp2p/go-openssl v0.1.0 h1:LBkKEcUv6vtZIQLVTegAil8jbNpJErQ9AnT+bWV+Ooo= github.com/libp2p/go-openssl v0.1.0/go.mod h1:OiOxwPpL3n4xlenjx2h7AwSGaFSC/KZvf6gNdOBQMtc= -github.com/livepeer/ai-worker v0.1.2 h1:I73J4zJYad95QE1JFSrqrjKKCTqLHypDcoPq/zZM5aw= -github.com/livepeer/ai-worker v0.1.2/go.mod h1:Xlnb0nFG2VsGeMG9hZmReVQXeFt0Dv28ODiUT2ooyLE= +github.com/livepeer/ai-worker v0.2.0 h1:u6m3nVQnisqWn2nMhgTgKLQby7VDAEp5xmeHt7Res/Y= +github.com/livepeer/ai-worker v0.2.0/go.mod h1:91lMzkzVuwR9kZ0EzXwf+7yVhLaNVmYAfmBtn7t3cQA= github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b h1:VQcnrqtCA2UROp7q8ljkh2XA/u0KRgVv0S1xoUvOweE= github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b/go.mod h1:hwJ5DKhl+pTanFWl+EUpw1H7ukPO/H+MFpgA7jjshzw= github.com/livepeer/joy4 v0.1.2-0.20191121080656-b2fea45cbded h1:ZQlvR5RB4nfT+cOQee+WqmaDOgGtP2oDMhcVvR4L0yA= @@ -1024,8 +1024,8 @@ golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad/go.mod h1:jdWPYTVW3xRLrWP golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220314234659-1baeb1ce4c0b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= -golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= +golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= +golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -1060,8 +1060,8 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= -golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0= +golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -1108,8 +1108,8 @@ golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= -golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= -golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= +golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= +golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -1132,8 +1132,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -1210,8 +1210,8 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= -golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM= +golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -1226,8 +1226,8 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= -golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= +golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -1288,8 +1288,8 @@ golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw= -golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24= +golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/server/ai_http.go b/server/ai_http.go index 6b007e0c13..a11164bc9b 100644 --- a/server/ai_http.go +++ b/server/ai_http.go @@ -44,6 +44,7 @@ func startAIServer(lp lphttp) error { lp.transRPC.Handle("/image-to-video", oapiReqValidator(lp.ImageToVideo())) lp.transRPC.Handle("/upscale", oapiReqValidator(lp.Upscale())) lp.transRPC.Handle("/audio-to-text", oapiReqValidator(lp.AudioToText())) + lp.transRPC.Handle("/segment-anything-2", oapiReqValidator(lp.SegmentAnything2())) return nil } @@ -157,6 +158,29 @@ func (h *lphttp) AudioToText() http.Handler { }) } +func (h *lphttp) SegmentAnything2() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + orch := h.orchestrator + + remoteAddr := getRemoteAddr(r) + ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr) + + multiRdr, err := r.MultipartReader() + if err != nil { + respondWithError(w, err.Error(), http.StatusBadRequest) + return + } + + var req worker.SegmentAnything2MultipartRequestBody + if err := runtime.BindMultipart(&req, *multiRdr); err != nil { + respondWithError(w, err.Error(), http.StatusInternalServerError) + return + } + + handleAIRequest(ctx, w, r, orch, req) + }) +} + func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, orch Orchestrator, req interface{}) { payment, err := getPayment(r.Header.Get(paymentHeader)) if err != nil { @@ -281,6 +305,25 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request return } outPixels *= 1000 // Convert to milliseconds + case worker.SegmentAnything2MultipartRequestBody: + pipeline = "segment-anything-2" + cap = core.Capability_SegmentAnything2 + modelID = *v.ModelId + submitFn = func(ctx context.Context) (interface{}, error) { + return orch.SegmentAnything2(ctx, v) + } + + imageRdr, err := v.Image.Reader() + if err != nil { + respondWithError(w, err.Error(), http.StatusBadRequest) + return + } + config, _, err := image.DecodeConfig(imageRdr) + if err != nil { + respondWithError(w, err.Error(), http.StatusBadRequest) + return + } + outPixels = int64(config.Height) * int64(config.Width) default: respondWithError(w, "Unknown request type", http.StatusBadRequest) return @@ -352,6 +395,8 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request if err == nil { latencyScore = CalculateAudioToTextLatencyScore(took, durationSeconds) } + case worker.SegmentAnything2MultipartRequestBody: + latencyScore = CalculateSegmentAnything2LatencyScore(took, outPixels) } var pricePerAIUnit float64 diff --git a/server/ai_mediaserver.go b/server/ai_mediaserver.go index 29d33b9fed..a600f3b176 100644 --- a/server/ai_mediaserver.go +++ b/server/ai_mediaserver.go @@ -69,6 +69,7 @@ func startAIMediaServer(ls *LivepeerServer) error { ls.HTTPMux.Handle("/image-to-video", oapiReqValidator(ls.ImageToVideo())) ls.HTTPMux.Handle("/image-to-video/result", ls.ImageToVideoResult()) ls.HTTPMux.Handle("/audio-to-text", oapiReqValidator(ls.AudioToText())) + ls.HTTPMux.Handle("/segment-anything-2", oapiReqValidator(ls.SegmentAnything2())) return nil } @@ -374,6 +375,59 @@ func (ls *LivepeerServer) AudioToText() http.Handler { }) } +func (ls *LivepeerServer) SegmentAnything2() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + remoteAddr := getRemoteAddr(r) + ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr) + requestID := string(core.RandomManifestID()) + ctx = clog.AddVal(ctx, "request_id", requestID) + + multiRdr, err := r.MultipartReader() + if err != nil { + respondJsonError(ctx, w, err, http.StatusBadRequest) + return + } + + var req worker.SegmentAnything2MultipartRequestBody + if err := runtime.BindMultipart(&req, *multiRdr); err != nil { + respondJsonError(ctx, w, err, http.StatusBadRequest) + return + } + + clog.V(common.VERBOSE).Infof(ctx, "Received SegmentAnything2 request; image_size=%v model_id=%v", req.Image.FileSize(), *req.ModelId) + + params := aiRequestParams{ + node: ls.LivepeerNode, + os: drivers.NodeStorage.NewSession(requestID), + sessManager: ls.AISessionManager, + } + + start := time.Now() + resp, err := processSegmentAnything2(ctx, params, req) + if err != nil { + var serviceUnavailableErr *ServiceUnavailableError + var badRequestErr *BadRequestError + if errors.As(err, &serviceUnavailableErr) { + respondJsonError(ctx, w, err, http.StatusServiceUnavailable) + return + } + if errors.As(err, &badRequestErr) { + respondJsonError(ctx, w, err, http.StatusBadRequest) + return + } + respondJsonError(ctx, w, err, http.StatusInternalServerError) + return + } + + took := time.Since(start) + clog.V(common.VERBOSE).Infof(ctx, "Processed SegmentAnything2 request model_id=%v took=%v", *req.ModelId, took) + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(resp) + }) +} + func (ls *LivepeerServer) ImageToVideoResult() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { remoteAddr := getRemoteAddr(r) diff --git a/server/ai_process.go b/server/ai_process.go index a8e884f68d..d69beb0610 100644 --- a/server/ai_process.go +++ b/server/ai_process.go @@ -31,6 +31,7 @@ const defaultImageToImageModelID = "stabilityai/sdxl-turbo" const defaultImageToVideoModelID = "stabilityai/stable-video-diffusion-img2vid-xt" const defaultUpscaleModelID = "stabilityai/stable-diffusion-x4-upscaler" const defaultAudioToTextModelID = "openai/whisper-large-v3" +const defaultSegmentAnything2ModelID = "facebook/sam2-hiera-large" type ServiceUnavailableError struct { err error @@ -586,6 +587,104 @@ func submitUpscale(ctx context.Context, params aiRequestParams, sess *AISession, return resp.JSON200, nil } +// CalculateSegmentAnything2LatencyScore computes the time taken per pixel for a segment-anything-2 request. +func CalculateSegmentAnything2LatencyScore(took time.Duration, outPixels int64) float64 { + if outPixels <= 0 { + return 0 + } + + return took.Seconds() / float64(outPixels) +} + +func processSegmentAnything2(ctx context.Context, params aiRequestParams, req worker.SegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) { + resp, err := processAIRequest(ctx, params, req) + if err != nil { + return nil, err + } + + txtResp := resp.(*worker.MasksResponse) + + return txtResp, nil +} + +func submitSegmentAnything2(ctx context.Context, params aiRequestParams, sess *AISession, req worker.BodySegmentAnything2SegmentAnything2Post) (*worker.MasksResponse, error) { + var buf bytes.Buffer + mw, err := worker.NewSegmentAnything2MultipartWriter(&buf, req) + if err != nil { + if monitor.Enabled { + monitor.AIRequestError(err.Error(), "segment anything 2", *req.ModelId, sess.OrchestratorInfo) + } + return nil, err + } + + client, err := worker.NewClientWithResponses(sess.Transcoder(), worker.WithHTTPClient(httpClient)) + if err != nil { + if monitor.Enabled { + monitor.AIRequestError(err.Error(), "segment anything 2", *req.ModelId, sess.OrchestratorInfo) + } + return nil, err + } + + imageRdr, err := req.Image.Reader() + if err != nil { + if monitor.Enabled { + monitor.AIRequestError(err.Error(), "segment anything 2", *req.ModelId, sess.OrchestratorInfo) + } + return nil, err + } + config, _, err := image.DecodeConfig(imageRdr) + if err != nil { + if monitor.Enabled { + monitor.AIRequestError(err.Error(), "segment anything 2", *req.ModelId, sess.OrchestratorInfo) + } + return nil, err + } + outPixels := int64(config.Height) * int64(config.Width) + + setHeaders, balUpdate, err := prepareAIPayment(ctx, sess, outPixels) + if err != nil { + if monitor.Enabled { + monitor.AIRequestError(err.Error(), "segment anything 2", *req.ModelId, sess.OrchestratorInfo) + } + return nil, err + } + defer completeBalanceUpdate(sess.BroadcastSession, balUpdate) + + start := time.Now() + resp, err := client.SegmentAnything2WithBodyWithResponse(ctx, mw.FormDataContentType(), &buf, setHeaders) + took := time.Since(start) + if err != nil { + if monitor.Enabled { + monitor.AIRequestError(err.Error(), "segment anything 2", *req.ModelId, sess.OrchestratorInfo) + } + return nil, err + } + + if resp.JSON200 == nil { + // TODO: Replace trim newline with better error spec from O + return nil, errors.New(strings.TrimSuffix(string(resp.Body), "\n")) + } + + // We treat a response as "receiving change" where the change is the difference between the credit and debit for the update + if balUpdate != nil { + balUpdate.Status = ReceivedChange + } + + // TODO: Refine this rough estimate in future iterations + sess.LatencyScore = CalculateSegmentAnything2LatencyScore(took, outPixels) + + if monitor.Enabled { + var pricePerAIUnit float64 + if priceInfo := sess.OrchestratorInfo.GetPriceInfo(); priceInfo != nil && priceInfo.PixelsPerUnit != 0 { + pricePerAIUnit = float64(priceInfo.PricePerUnit) / float64(priceInfo.PixelsPerUnit) + } + + monitor.AIRequestFinished(ctx, "segment anything 2", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo) + } + + return resp.JSON200, nil +} + // CalculateAudioToTextLatencyScore computes the time taken per second of audio for an audio-to-text request. func CalculateAudioToTextLatencyScore(took time.Duration, durationSeconds int64) float64 { if durationSeconds <= 0 { @@ -744,6 +843,15 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) { return submitAudioToText(ctx, params, sess, v) } + case worker.SegmentAnything2MultipartRequestBody: + cap = core.Capability_SegmentAnything2 + modelID = defaultSegmentAnything2ModelID + if v.ModelId != nil { + modelID = *v.ModelId + } + submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) { + return submitSegmentAnything2(ctx, params, sess, v) + } default: return nil, fmt.Errorf("unsupported request type %T", req) } diff --git a/server/rpc.go b/server/rpc.go index 0c8b9f8066..db797b9eba 100644 --- a/server/rpc.go +++ b/server/rpc.go @@ -68,6 +68,7 @@ type Orchestrator interface { ImageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) Upscale(ctx context.Context, req worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error) AudioToText(ctx context.Context, req worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error) + SegmentAnything2(ctx context.Context, req worker.SegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) } // Balance describes methods for a session's balance maintenance diff --git a/server/rpc_test.go b/server/rpc_test.go index 0ede88b261..197e2baa6b 100644 --- a/server/rpc_test.go +++ b/server/rpc_test.go @@ -202,6 +202,9 @@ func (r *stubOrchestrator) Upscale(ctx context.Context, req worker.UpscaleMultip func (r *stubOrchestrator) AudioToText(ctx context.Context, req worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error) { return nil, nil } +func (r *stubOrchestrator) SegmentAnything2(ctx context.Context, req worker.SegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) { + return nil, nil +} func (r *stubOrchestrator) CheckAICapacity(pipeline, modelID string) bool { return true } @@ -1385,6 +1388,9 @@ func (r *mockOrchestrator) Upscale(ctx context.Context, req worker.UpscaleMultip func (r *mockOrchestrator) AudioToText(ctx context.Context, req worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error) { return nil, nil } +func (r *mockOrchestrator) SegmentAnything2(ctx context.Context, req worker.SegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) { + return nil, nil +} func (r *mockOrchestrator) CheckAICapacity(pipeline, modelID string) bool { return true }