Skip to content

Commit

Permalink
add sensor component to module
Browse files Browse the repository at this point in the history
  • Loading branch information
bhaney committed Nov 14, 2024
1 parent fc27667 commit a48cb9a
Show file tree
Hide file tree
Showing 8 changed files with 314 additions and 9 deletions.
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ test:
go test ./countclassifier/

lint:
golangci-lint run ./countclassifier/
go mod tidy
golangci-lint run

module.tar.gz:
go build -a -o module ./cmd/module
Expand Down
19 changes: 19 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ models that summarize information from underlying vision models

## Example Config

### for count-classifier
```
{
"count_thresholds": {
Expand All @@ -17,3 +18,21 @@ models that summarize information from underlying vision models
}
}
```

### for count-sensor
```
{
"count_thresholds": {
"high": 1000,
"none": 0,
"low": 10,
"medium": 20
},
"detector_name": "vision-1",
"camera_name": "camera-1",
"poll_frequency_hz": 0.5,
"chosen_labels": {
"person": 0.3
}
}
```
8 changes: 7 additions & 1 deletion cmd/module/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,19 @@ package main
import (
"context"

"go.viam.com/rdk/components/sensor"
"go.viam.com/rdk/services/vision"

"go.viam.com/rdk/logging"
"go.viam.com/rdk/module"
"go.viam.com/utils"

"github.com/viam-modules/vision-summary/countclassifier"
"github.com/viam-modules/vision-summary/countsensor"
)

func main() {
utils.ContextualMain(mainWithArgs, module.NewLoggerFromArgs("count-classifier"))
utils.ContextualMain(mainWithArgs, module.NewLoggerFromArgs("vision-summary"))
}

func mainWithArgs(ctx context.Context, args []string, logger logging.Logger) (err error) {
Expand All @@ -27,6 +29,10 @@ func mainWithArgs(ctx context.Context, args []string, logger logging.Logger) (er
if err != nil {
return err
}
err = myMod.AddModelFromRegistry(ctx, sensor.API, countsensor.Model)
if err != nil {
return err
}

err = myMod.Start(ctx)
defer myMod.Close(ctx)
Expand Down
15 changes: 9 additions & 6 deletions countclassifier/countclassifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,24 +140,26 @@ func (cc *countcls) Reconfigure(ctx context.Context, deps resource.Dependencies,
return nil
}

func (cc *countcls) count(dets []objdet.Detection) string {
func (cc *countcls) count(dets []objdet.Detection) (string, []objdet.Detection) {
// get the number of boxes with the right label and confidences
count := 0
outDets := []objdet.Detection{}
for _, d := range dets {
label := strings.ToLower(d.Label())
if conf, ok := cc.labels[label]; ok {
if d.Score() >= conf {
count++
outDets = append(outDets, d)
}
}
}
// associated the number with the right label
for _, thresh := range cc.thresholds {
if count <= thresh.UpperBound {
return thresh.Label
return thresh.Label, outDets
}
}
return OverflowLabel
return OverflowLabel, outDets
}

// Detections just calls the underlying detector
Expand Down Expand Up @@ -186,7 +188,7 @@ func (cc *countcls) ClassificationsFromCamera(
if err != nil {
return nil, errors.Wrapf(err, "error from underlying detector %s", cc.detName)
}
label := cc.count(dets)
label, _ := cc.count(dets)
c := classification.NewClassification(1.0, label)
cls = append(cls, c)
return classification.Classifications(cls), nil
Expand All @@ -201,7 +203,7 @@ func (cc *countcls) Classifications(ctx context.Context, img image.Image,
if err != nil {
return nil, errors.Wrapf(err, "error from underlying vision model %s", cc.detName)
}
label := cc.count(dets)
label, _ := cc.count(dets)
c := classification.NewClassification(1.0, label)
cls = append(cls, c)
return classification.Classifications(cls), nil
Expand Down Expand Up @@ -232,11 +234,12 @@ func (cc *countcls) CaptureAllFromCamera(
if err != nil {
return visCapture, errors.Wrapf(err, "error from underlying detector %s", cc.detName)
}
label := cc.count(visCapture.Detections)
label, dets := cc.count(visCapture.Detections)
cls := []classification.Classification{}
c := classification.NewClassification(1.0, label)
cls = append(cls, c)
visCapture.Classifications = classification.Classifications(cls)
visCapture.Detections = dets
return visCapture, nil
}

Expand Down
258 changes: 258 additions & 0 deletions countsensor/countsensor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
package countsensor

import (
"context"
"sort"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/pkg/errors"

"go.viam.com/rdk/components/sensor"
"go.viam.com/rdk/logging"
"go.viam.com/rdk/resource"
"go.viam.com/rdk/services/vision"
objdet "go.viam.com/rdk/vision/objectdetection"
viamutils "go.viam.com/utils"
)

const (
// ModelName is the name of the model
ModelName = "count-sensor"
// OverflowLabel is the label if the counts exceed what was specified by the user
OverflowLabel = "Overflow"
// DefaulMaxFrequency is how often the vision service will poll the camera for a new image
DefaultPollFrequency = 1.0
)

var (
// Model is the resource
Model = resource.NewModel("viam", "vision-summary", ModelName)
)

func init() {
resource.RegisterComponent(sensor.API, Model, resource.Registration[sensor.Sensor, *Config]{
Constructor: newCountSensor,
})
}

// Config contains names for necessary resources
type Config struct {
DetectorName string `json:"detector_name"`
CameraName string `json:"camera_name"`
ChosenLabels map[string]float64 `json:"chosen_labels"`
CountThresholds map[string]int `json:"count_thresholds"`
PollFrequency float64 `json:"poll_frequency_hz"`
}

// Validate validates the config and returns implicit dependencies,
// this Validate checks if the camera and detector exist for the module's vision model.
func (cfg *Config) Validate(path string) ([]string, error) {
if cfg.DetectorName == "" {
return nil, errors.New("attribute detector_name cannot be left blank")
}
if cfg.CameraName == "" {
return nil, errors.New("attribute camera_name cannot be left blank")
}
if len(cfg.CountThresholds) == 0 {
return nil, errors.New("attribute count_thresholds is required")
}
if cfg.PollFrequency < 0 {
return nil, errors.New("attribute poll_frequency_hz cannot be negative")
}
testMap := map[int]string{}
for label, v := range cfg.CountThresholds {
if _, ok := testMap[v]; ok {
return nil, errors.Errorf("cannot have two labels for the same threshold in count_thresholds. Threshold value %v appears more than once", v)
}
testMap[v] = label
}
return []string{cfg.DetectorName, cfg.CameraName}, nil
}

// Bin stores the thresholds that turns counts into labels
type Bin struct {
UpperBound int
Label string
}

// NewThresholds creates a list of thresholds for labeling counts
func NewThresholds(t map[string]int) []Bin {
// first invert the map, Validate ensures a 1-1 mapping
thresholds := map[int]string{}
for label, val := range t {
thresholds[val] = label
}
out := []Bin{}
keys := []int{}
for k := range thresholds {
keys = append(keys, int(k))
}
sort.Ints(keys)
for _, key := range keys {
b := Bin{key, thresholds[key]}
out = append(out, b)
}
return out
}

type counter struct {
resource.Named
cancelFunc context.CancelFunc
cancelContext context.Context
activeBackgroundWorkers sync.WaitGroup
logger logging.Logger
detName string
camName string
detector vision.Service
labels map[string]float64
thresholds []Bin
frequency float64
num atomic.Int64
class atomic.Value
}

func newCountSensor(
ctx context.Context,
deps resource.Dependencies,
conf resource.Config,
logger logging.Logger) (sensor.Sensor, error) {
cs := &counter{
Named: conf.ResourceName().AsNamed(),
logger: logger,
}

if err := cs.Reconfigure(ctx, deps, conf); err != nil {
return nil, err
}
return cs, nil
}

// Reconfigure resets the underlying detector as well as the thresholds and labels for the count
func (cs *counter) Reconfigure(ctx context.Context, deps resource.Dependencies, conf resource.Config) error {
if cs.cancelFunc != nil {
cs.cancelFunc()
cs.activeBackgroundWorkers.Wait()
}
cancelableCtx, cancel := context.WithCancel(context.Background())
cs.cancelFunc = cancel
cs.cancelContext = cancelableCtx

countConf, err := resource.NativeConfig[*Config](conf)
if err != nil {
return errors.Errorf("Could not assert proper config for %s", ModelName)
}
cs.frequency = DefaultPollFrequency
if countConf.PollFrequency > 0 {
cs.frequency = countConf.PollFrequency
}
cs.camName = countConf.CameraName
cs.detName = countConf.DetectorName
cs.detector, err = vision.FromDependencies(deps, countConf.DetectorName)
if err != nil {
return errors.Wrapf(err, "unable to get vision service %v for count classifier", countConf.DetectorName)
}
// put everything in lower case
labels := map[string]float64{}
for l, c := range countConf.ChosenLabels {
labels[strings.ToLower(l)] = c
}
cs.labels = labels
cs.thresholds = NewThresholds(countConf.CountThresholds)
// now start the background thread
cs.activeBackgroundWorkers.Add(1)
viamutils.ManagedGo(func() {
// if you get an error while running just keep trying forever
for {
runErr := cs.run(cs.cancelContext)
if runErr != nil {
cs.logger.Errorw("background thread exited with error", "error", runErr)
continue // keep trying to run, forever
}
return
}
}, func() {
cs.activeBackgroundWorkers.Done()
})
return nil
}

func (cs *counter) count(dets []objdet.Detection) (string, int) {
// get the number of boxes with the right label and confidences
count := 0
for _, d := range dets {
label := strings.ToLower(d.Label())
if conf, ok := cs.labels[label]; ok {
if d.Score() >= conf {
count++
}
}
}
// associated the number with the right label
for _, thresh := range cs.thresholds {
if count <= thresh.UpperBound {
return thresh.Label, count
}
}
return OverflowLabel, count
}

func (cs *counter) run(ctx context.Context) error {
freq := cs.frequency
for {
select {
case <-ctx.Done():
return nil
default:
start := time.Now()
dets, err := cs.detector.DetectionsFromCamera(ctx, cs.camName, nil)
if err != nil {
return errors.Errorf("vision service error in background thread: %q", err)
}
class, num := cs.count(dets)
cs.class.Store(class)
cs.num.Store(int64(num))
took := time.Since(start)
waitFor := time.Duration((1/freq)*float64(time.Second)) - took // only poll according to set freq
if waitFor > time.Microsecond {
select {
case <-ctx.Done():
return nil
case <-time.After(waitFor):
}
}
}
}
}

// Readings contains both the label and the count of the underlying detector
func (cs *counter) Readings(ctx context.Context, extra map[string]interface{}) (map[string]interface{}, error) {
select {
case <-ctx.Done():
return nil, errors.Wrap(ctx.Err(), "module might be configuring")
case <-cs.cancelContext.Done():
return nil, errors.Wrap(cs.cancelContext.Err(), "lost connection with background vision service loop")
default:
className, ok := cs.class.Load().(string)
if !ok {
return nil, errors.Errorf("class string was not a string, but %T", className)
}
countNumber := cs.num.Load()
return map[string]interface{}{
"label": className,
"count": countNumber,
}, nil
}
}

// Close does nothing
func (cs *counter) Close(ctx context.Context) error {
return nil
}

// DoCommand implements nothing
func (cs *counter) DoCommand(ctx context.Context, cmd map[string]interface{}) (map[string]interface{}, error) {
return nil, nil
}
Loading

0 comments on commit a48cb9a

Please sign in to comment.