From a48cb9aa54bbc18deb2a56f5a3ab9477f72d6a4e Mon Sep 17 00:00:00 2001
From: Bijan Haney
Date: Thu, 14 Nov 2024 14:34:43 -0500
Subject: [PATCH] add sensor component to module

---
Review notes (everything between the "---" line and the first "diff --git"
line is ignored by `git am`):
- Makefile: `make test` now runs every package, so the new countsensor test
  is actually executed.
- countsensor: Close stops the polling goroutine and waits for it; a failed
  poll is retried after a delay instead of immediately; a poll interrupted by
  cancellation is no longer logged as an error; label and count are published
  as one snapshot and cleared on reconfigure; Readings reports clearly when no
  result exists yet.
- The blob ids on the `index` lines predate these review edits.

 Makefile                           |   5 +-
 README.md                          |  19 ++
 cmd/module/main.go                 |   8 +-
 countclassifier/countclassifier.go |  15 +-
 countsensor/countsensor.go         | 282 +++++++++++++++++++++++++++++
 countsensor/countsensor_test.go    |  19 ++
 go.mod                             |   2 +-
 meta.json                          |   4 +
 8 files changed, 344 insertions(+), 10 deletions(-)
 create mode 100644 countsensor/countsensor.go
 create mode 100644 countsensor/countsensor_test.go

diff --git a/Makefile b/Makefile
index 6583ed0..849a819 100644
--- a/Makefile
+++ b/Makefile
@@ -5,7 +5,8 @@ test:
-	go test ./countclassifier/
+	go test ./...
 
 lint:
-	golangci-lint run ./countclassifier/
+	go mod tidy
+	golangci-lint run
 
 module.tar.gz:
 	go build -a -o module ./cmd/module
diff --git a/README.md b/README.md
index 76a6d64..5e250fe 100644
--- a/README.md
+++ b/README.md
@@ -3,6 +3,7 @@ models that summarize information from underlying vision models
 
 ## Example Config
 
+### for count-classifier
 ```
 {
     "count_thresholds": {
@@ -17,3 +18,21 @@ models that summarize information from underlying vision models
     }
 }
 ```
+
+### for count-sensor
+```
+{
+    "count_thresholds": {
+        "high": 1000,
+        "none": 0,
+        "low": 10,
+        "medium": 20
+    },
+    "detector_name": "vision-1",
+    "camera_name": "camera-1",
+    "poll_frequency_hz": 0.5,
+    "chosen_labels": {
+        "person": 0.3
+    }
+}
+```
diff --git a/cmd/module/main.go b/cmd/module/main.go
index 56c3280..aa7029b 100644
--- a/cmd/module/main.go
+++ b/cmd/module/main.go
@@ -4,6 +4,7 @@ package main
 import (
 	"context"
 
+	"go.viam.com/rdk/components/sensor"
 	"go.viam.com/rdk/services/vision"
 
 	"go.viam.com/rdk/logging"
@@ -11,10 +12,11 @@ import (
 	"go.viam.com/utils"
 
 	"github.com/viam-modules/vision-summary/countclassifier"
+	"github.com/viam-modules/vision-summary/countsensor"
 )
 
 func main() {
-	utils.ContextualMain(mainWithArgs, module.NewLoggerFromArgs("count-classifier"))
+	utils.ContextualMain(mainWithArgs, module.NewLoggerFromArgs("vision-summary"))
 }
 
 func mainWithArgs(ctx context.Context, args []string, logger logging.Logger) (err error) {
@@ -27,6 +29,10 @@ func mainWithArgs(ctx context.Context, args []string, logger logging.Logger) (er
 	if err != nil {
 		return err
 	}
+	err = myMod.AddModelFromRegistry(ctx, sensor.API, countsensor.Model)
+	if err != nil {
+		return err
+	}
 	err = myMod.Start(ctx)
 	defer myMod.Close(ctx)
 
diff --git a/countclassifier/countclassifier.go b/countclassifier/countclassifier.go
index 73329e2..bb19ca9 100644
--- a/countclassifier/countclassifier.go
+++ b/countclassifier/countclassifier.go
@@ -140,24 +140,26 @@ func (cc *countcls) Reconfigure(ctx context.Context, deps resource.Dependencies,
 	return nil
 }
 
-func (cc *countcls) count(dets []objdet.Detection) string {
+func (cc *countcls) count(dets []objdet.Detection) (string, []objdet.Detection) {
 	// get the number of boxes with the right label and confidences
 	count := 0
+	outDets := []objdet.Detection{}
 	for _, d := range dets {
 		label := strings.ToLower(d.Label())
 		if conf, ok := cc.labels[label]; ok {
 			if d.Score() >= conf {
 				count++
+				outDets = append(outDets, d)
 			}
 		}
 	}
 	// associated the number with the right label
 	for _, thresh := range cc.thresholds {
 		if count <= thresh.UpperBound {
-			return thresh.Label
+			return thresh.Label, outDets
 		}
 	}
-	return OverflowLabel
+	return OverflowLabel, outDets
 }
 
 // Detections just calls the underlying detector
@@ -186,7 +188,7 @@ func (cc *countcls) ClassificationsFromCamera(
 	if err != nil {
 		return nil, errors.Wrapf(err, "error from underlying detector %s", cc.detName)
 	}
-	label := cc.count(dets)
+	label, _ := cc.count(dets)
 	c := classification.NewClassification(1.0, label)
 	cls = append(cls, c)
 	return classification.Classifications(cls), nil
@@ -201,7 +203,7 @@ func (cc *countcls) Classifications(ctx context.Context, img image.Image,
 	if err != nil {
 		return nil, errors.Wrapf(err, "error from underlying vision model %s", cc.detName)
 	}
-	label := cc.count(dets)
+	label, _ := cc.count(dets)
 	c := classification.NewClassification(1.0, label)
 	cls = append(cls, c)
 	return classification.Classifications(cls), nil
@@ -232,11 +234,12 @@ func (cc *countcls) CaptureAllFromCamera(
 	if err != nil {
 		return visCapture, errors.Wrapf(err, "error from underlying detector %s", cc.detName)
 	}
-	label := cc.count(visCapture.Detections)
+	label, dets := cc.count(visCapture.Detections)
 	cls := []classification.Classification{}
 	c := classification.NewClassification(1.0, label)
 	cls = append(cls, c)
 	visCapture.Classifications = classification.Classifications(cls)
+	visCapture.Detections = dets
 	return visCapture, nil
 }
 
diff --git a/countsensor/countsensor.go b/countsensor/countsensor.go
new file mode 100644
index 0000000..510f1ba
--- /dev/null
+++ b/countsensor/countsensor.go
@@ -0,0 +1,282 @@
+// Package countsensor implements a sensor that reports how many objects an underlying vision service detects.
+package countsensor
+
+import (
+	"context"
+	"sort"
+	"strings"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/pkg/errors"
+
+	"go.viam.com/rdk/components/sensor"
+	"go.viam.com/rdk/logging"
+	"go.viam.com/rdk/resource"
+	"go.viam.com/rdk/services/vision"
+	objdet "go.viam.com/rdk/vision/objectdetection"
+	viamutils "go.viam.com/utils"
+)
+
+const (
+	// ModelName is the name of the model
+	ModelName = "count-sensor"
+	// OverflowLabel is the label if the counts exceed what was specified by the user
+	OverflowLabel = "Overflow"
+	// DefaultPollFrequency is how often, in Hz, the sensor polls the vision service when poll_frequency_hz is not set
+	DefaultPollFrequency = 1.0
+	// retryDelay is how long the background worker waits before polling again after a failed poll
+	retryDelay = time.Second
+)
+
+var (
+	// Model is the resource
+	Model = resource.NewModel("viam", "vision-summary", ModelName)
+)
+
+func init() {
+	resource.RegisterComponent(sensor.API, Model, resource.Registration[sensor.Sensor, *Config]{
+		Constructor: newCountSensor,
+	})
+}
+
+// Config contains names for necessary resources
+type Config struct {
+	DetectorName    string             `json:"detector_name"`
+	CameraName      string             `json:"camera_name"`
+	ChosenLabels    map[string]float64 `json:"chosen_labels"`
+	CountThresholds map[string]int     `json:"count_thresholds"`
+	PollFrequency   float64            `json:"poll_frequency_hz"`
+}
+
+// Validate validates the config and returns the detector and camera names as implicit dependencies.
+// It rejects blank names, missing or duplicated count thresholds, and a negative poll frequency.
+func (cfg *Config) Validate(path string) ([]string, error) {
+	if cfg.DetectorName == "" {
+		return nil, errors.New("attribute detector_name cannot be left blank")
+	}
+	if cfg.CameraName == "" {
+		return nil, errors.New("attribute camera_name cannot be left blank")
+	}
+	if len(cfg.CountThresholds) == 0 {
+		return nil, errors.New("attribute count_thresholds is required")
+	}
+	if cfg.PollFrequency < 0 {
+		return nil, errors.New("attribute poll_frequency_hz cannot be negative")
+	}
+	testMap := map[int]string{}
+	for label, v := range cfg.CountThresholds {
+		if _, ok := testMap[v]; ok {
+			return nil, errors.Errorf("cannot have two labels for the same threshold in count_thresholds. Threshold value %v appears more than once", v)
+		}
+		testMap[v] = label
+	}
+	return []string{cfg.DetectorName, cfg.CameraName}, nil
+}
+
+// Bin stores the thresholds that turns counts into labels
+type Bin struct {
+	UpperBound int
+	Label      string
+}
+
+// NewThresholds creates a list of thresholds for labeling counts, sorted by ascending upper bound
+func NewThresholds(t map[string]int) []Bin {
+	// first invert the map, Validate ensures a 1-1 mapping
+	thresholds := make(map[int]string, len(t))
+	for label, val := range t {
+		thresholds[val] = label
+	}
+	keys := make([]int, 0, len(thresholds))
+	for k := range thresholds {
+		keys = append(keys, k)
+	}
+	sort.Ints(keys)
+	out := make([]Bin, 0, len(keys))
+	for _, key := range keys {
+		out = append(out, Bin{UpperBound: key, Label: thresholds[key]})
+	}
+	return out
+}
+
+// reading is the label and count produced by a single poll of the detector
+type reading struct {
+	label string
+	count int64
+}
+
+// counter is a sensor that polls a detector in the background and serves the most recent count
+type counter struct {
+	resource.Named
+	cancelFunc              context.CancelFunc
+	cancelContext           context.Context
+	activeBackgroundWorkers sync.WaitGroup
+	logger                  logging.Logger
+	detName                 string
+	camName                 string
+	detector                vision.Service
+	labels                  map[string]float64
+	thresholds              []Bin
+	frequency               float64
+	latest                  atomic.Pointer[reading]
+}
+
+func newCountSensor(
+	ctx context.Context,
+	deps resource.Dependencies,
+	conf resource.Config,
+	logger logging.Logger) (sensor.Sensor, error) {
+	cs := &counter{
+		Named:  conf.ResourceName().AsNamed(),
+		logger: logger,
+	}
+
+	if err := cs.Reconfigure(ctx, deps, conf); err != nil {
+		return nil, err
+	}
+	return cs, nil
+}
+
+// Reconfigure resets the underlying detector as well as the thresholds and labels for the count
+func (cs *counter) Reconfigure(ctx context.Context, deps resource.Dependencies, conf resource.Config) error {
+	// stop the previous worker first: it reads the fields that are rewritten below
+	if cs.cancelFunc != nil {
+		cs.cancelFunc()
+		cs.activeBackgroundWorkers.Wait()
+	}
+
+	// resolve everything that can fail before any state is replaced
+	countConf, err := resource.NativeConfig[*Config](conf)
+	if err != nil {
+		return errors.Wrapf(err, "Could not assert proper config for %s", ModelName)
+	}
+	detector, err := vision.FromDependencies(deps, countConf.DetectorName)
+	if err != nil {
+		return errors.Wrapf(err, "unable to get vision service %v for count sensor", countConf.DetectorName)
+	}
+
+	cs.frequency = DefaultPollFrequency
+	if countConf.PollFrequency > 0 {
+		cs.frequency = countConf.PollFrequency
+	}
+	cs.camName = countConf.CameraName
+	cs.detName = countConf.DetectorName
+	cs.detector = detector
+	// put everything in lower case
+	labels := make(map[string]float64, len(countConf.ChosenLabels))
+	for l, c := range countConf.ChosenLabels {
+		labels[strings.ToLower(l)] = c
+	}
+	cs.labels = labels
+	cs.thresholds = NewThresholds(countConf.CountThresholds)
+	// drop the reading computed under the previous configuration
+	cs.latest.Store(nil)
+
+	// now start the background thread
+	workerCtx, cancel := context.WithCancel(context.Background())
+	cs.cancelFunc = cancel
+	cs.cancelContext = workerCtx
+	cs.activeBackgroundWorkers.Add(1)
+	viamutils.ManagedGo(func() {
+		// if you get an error while running just keep trying until the worker is canceled
+		for {
+			runErr := cs.run(workerCtx)
+			if runErr == nil {
+				return
+			}
+			cs.logger.Errorw("background thread exited with error", "error", runErr)
+			// wait before retrying so a persistently failing detector does not cause a busy loop
+			select {
+			case <-workerCtx.Done():
+				return
+			case <-time.After(retryDelay):
+			}
+		}
+	}, cs.activeBackgroundWorkers.Done)
+	return nil
+}
+
+// count returns the threshold label and the number of detections matching the chosen labels and confidences
+func (cs *counter) count(dets []objdet.Detection) (string, int) {
+	// get the number of boxes with the right label and confidences
+	count := 0
+	for _, d := range dets {
+		label := strings.ToLower(d.Label())
+		if conf, ok := cs.labels[label]; ok {
+			if d.Score() >= conf {
+				count++
+			}
+		}
+	}
+	// associate the number with the right label
+	for _, thresh := range cs.thresholds {
+		if count <= thresh.UpperBound {
+			return thresh.Label, count
+		}
+	}
+	return OverflowLabel, count
+}
+
+// run polls the detector at the configured frequency until ctx is canceled (returns nil) or a poll fails
+func (cs *counter) run(ctx context.Context) error {
+	period := time.Duration((1 / cs.frequency) * float64(time.Second))
+	for {
+		if ctx.Err() != nil {
+			return nil
+		}
+		start := time.Now()
+		dets, err := cs.detector.DetectionsFromCamera(ctx, cs.camName, nil)
+		if err != nil {
+			if ctx.Err() != nil {
+				// the call was interrupted by Close or Reconfigure, not by a real failure
+				return nil
+			}
+			return errors.Wrap(err, "vision service error in background thread")
+		}
+		class, num := cs.count(dets)
+		cs.latest.Store(&reading{label: class, count: int64(num)})
+		// only poll according to set freq
+		waitFor := period - time.Since(start)
+		if waitFor > time.Microsecond {
+			select {
+			case <-ctx.Done():
+				return nil
+			case <-time.After(waitFor):
+			}
+		}
+	}
+}
+
+// Readings contains both the label and the count of the underlying detector
+func (cs *counter) Readings(ctx context.Context, extra map[string]interface{}) (map[string]interface{}, error) {
+	select {
+	case <-ctx.Done():
+		return nil, errors.Wrap(ctx.Err(), "module might be configuring")
+	case <-cs.cancelContext.Done():
+		return nil, errors.Wrap(cs.cancelContext.Err(), "lost connection with background vision service loop")
+	default:
+	}
+	latest := cs.latest.Load()
+	if latest == nil {
+		return nil, errors.New("no reading available yet, the vision service has not returned a result")
+	}
+	return map[string]interface{}{
+		"label": latest.label,
+		"count": latest.count,
+	}, nil
+}
+
+// Close stops the background polling loop and waits for it to exit
+func (cs *counter) Close(ctx context.Context) error {
+	if cs.cancelFunc != nil {
+		cs.cancelFunc()
+	}
+	cs.activeBackgroundWorkers.Wait()
+	return nil
+}
+
+// DoCommand implements nothing: every command is accepted and ignored
+func (cs *counter) DoCommand(ctx context.Context, cmd map[string]interface{}) (map[string]interface{}, error) {
+	return nil, nil
+}
diff --git a/countsensor/countsensor_test.go b/countsensor/countsensor_test.go
new file mode 100644
index 0000000..bd5231e
--- /dev/null
+++ b/countsensor/countsensor_test.go
@@ -0,0 +1,19 @@
+package countsensor
+
+import (
+	"testing"
+
+	"go.viam.com/test"
+)
+
+func TestValidate(t *testing.T) {
+	cfg := Config{}
+	_, err := cfg.Validate("")
+	test.That(t, err, test.ShouldNotBeNil)
+	test.That(t, err.Error(), test.ShouldContainSubstring, "detector_name")
+}
+
+func TestNewThresholds(t *testing.T) {
+	bins := NewThresholds(map[string]int{"high": 1000, "none": 0, "low": 10, "medium": 20})
+	test.That(t, bins, test.ShouldResemble, []Bin{{0, "none"}, {10, "low"}, {20, "medium"}, {1000, "high"}})
+}
diff --git a/go.mod b/go.mod
index 61945a9..e7381eb 100644
--- a/go.mod
+++ b/go.mod
@@ -5,6 +5,7 @@ go 1.23
 require (
 	github.com/pkg/errors v0.9.1
 	go.viam.com/rdk v0.50.0
+	go.viam.com/test v1.2.3
 	go.viam.com/utils v0.1.113
 )
 
@@ -131,7 +132,6 @@ require (
 	go.uber.org/multierr v1.11.0 // indirect
 	go.uber.org/zap v1.27.0 // indirect
 	go.viam.com/api v0.1.357 // indirect
-	go.viam.com/test v1.2.3 // indirect
 	golang.org/x/crypto v0.28.0 // indirect
 	golang.org/x/exp v0.0.0-20240904232852-e7e105dedf7e // indirect
 	golang.org/x/image v0.19.0 // indirect
diff --git a/meta.json b/meta.json
index 3c5c581..a9dc411 100644
--- a/meta.json
+++ b/meta.json
@@ -12,6 +12,10 @@
     {
       "api": "rdk:service:vision",
       "model": "viam:vision-summary:count-classifier"
+    },
+    {
+      "api": "rdk:component:sensor",
+      "model": "viam:vision-summary:count-sensor"
     }
   ],
   "entrypoint": "module"