Skip to content

Commit

Permalink
Merge pull request #49 from Photoroom/ben/fix_sources_ne
Browse files Browse the repository at this point in the history
[DB API] Respect Sources NE
  • Loading branch information
blefaudeux authored Nov 26, 2024
2 parents a971595 + 22ca422 commit 78b26bf
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 55 deletions.
57 changes: 57 additions & 0 deletions pkg/generator_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"log"
"net/http"
"net/url"
"os"
"time"
)
Expand Down Expand Up @@ -107,6 +108,7 @@ func (c *SourceDBConfig) setDefaults() {
c.WorldSize = -1

c.Sources = ""
c.SourcesNE = ""
c.RequireImages = true
c.RequireEmbeddings = false
c.Tags = ""
Expand Down Expand Up @@ -176,6 +178,7 @@ func (c *SourceDBConfig) getDbRequest() dbRequest {
return dbRequest{
fields: fields,
sources: c.Sources,
sourcesNE: c.SourcesNE,
pageSize: fmt.Sprintf("%d", c.PageSize),
tags: c.Tags,
tagsNE: c.TagsNE,
Expand Down Expand Up @@ -318,3 +321,57 @@ func (f datagoGeneratorDB) generatePages(ctx context.Context, chanPages chan Pag
}
}
}

// getHTTPRequest builds the GET request used to query the image database.
//
// The endpoint is api_url with "images/" appended verbatim (no separator is
// inserted), or "images/random/" when request.randomSampling is set. api_key
// is sent in the Authorization header as "Token <api_key>". Every non-empty
// field of request is forwarded as a query parameter; empty fields are
// omitted from the query string.
//
// The final URL is printed to stdout. If the endpoint cannot be turned into a
// valid request, the error is printed and nil is returned.
func getHTTPRequest(api_url string, api_key string, request dbRequest) *http.Request {
	endpoint := api_url + "images/"
	if request.randomSampling {
		endpoint = api_url + "images/random/"
	}

	httpReq, err := http.NewRequest("GET", endpoint, nil)
	if err != nil {
		// Previously the error was dropped and the nil request dereferenced
		// right below, which surfaced as an opaque nil-pointer panic.
		fmt.Println("Error creating the DB request:", err)
		return nil
	}
	httpReq.Header.Add("Authorization", "Token "+api_key)

	// Start from whatever query the base URL already carried.
	query := httpReq.URL.Query()

	// addIfSet only emits parameters which have a value.
	addIfSet := func(values url.Values, field string, value string) {
		if value != "" {
			values.Add(field, value)
		}
	}

	// Limit the returned latents to the ones we asked for (latents and masks).
	// The comma is only inserted between two non-empty parts, so asking for
	// masks alone does not produce a list starting with an empty entry.
	returnLatents := request.hasLatents
	if request.hasMasks != "" {
		if returnLatents != "" {
			returnLatents += ","
		}
		returnLatents += request.hasMasks
	}

	addIfSet(query, "fields", request.fields)
	addIfSet(query, "sources", request.sources)
	addIfSet(query, "sources__ne", request.sourcesNE)
	addIfSet(query, "page_size", request.pageSize)

	addIfSet(query, "tags", request.tags)
	addIfSet(query, "tags__ne", request.tagsNE)

	addIfSet(query, "has_attributes", request.hasAttributes)
	addIfSet(query, "lacks_attributes", request.lacksAttributes)

	addIfSet(query, "has_masks", request.hasMasks)
	addIfSet(query, "lacks_masks", request.lacksMasks)

	addIfSet(query, "has_latents", request.hasLatents)
	addIfSet(query, "lacks_latents", request.lacksLatents)
	addIfSet(query, "return_latents", returnLatents)

	addIfSet(query, "short_edge__gte", request.minShortEdge)
	addIfSet(query, "short_edge__lte", request.maxShortEdge)
	addIfSet(query, "pixel_count__gte", request.minPixelCount)
	addIfSet(query, "pixel_count__lte", request.maxPixelCount)

	addIfSet(query, "partitions_count", request.partitionsCount)
	addIfSet(query, "partition", request.partition)

	httpReq.URL.RawQuery = query.Encode()
	fmt.Println("Request URL:", httpReq.URL.String())
	fmt.Println()
	return httpReq
}
55 changes: 0 additions & 55 deletions pkg/serdes.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"

Expand Down Expand Up @@ -262,57 +261,3 @@ func fetchSample(config *SourceDBConfig, http_client *http.Client, sample_result
Tags: sample_result.Tags,
CocaEmbedding: cocaEmbedding}
}

// getHTTPRequest builds the GET request used to query the image database.
//
// The endpoint is api_url with "images/" appended verbatim (no separator is
// inserted), or "images/random/" when request.randomSampling is set. api_key
// is sent in the Authorization header as "Token <api_key>". Every non-empty
// field of request is forwarded as a query parameter; empty fields are
// omitted from the query string.
//
// The final URL is printed to stdout. If the endpoint cannot be turned into a
// valid request, the error is printed and nil is returned.
func getHTTPRequest(api_url string, api_key string, request dbRequest) *http.Request {
	endpoint := api_url + "images/"
	if request.randomSampling {
		endpoint = api_url + "images/random/"
	}

	httpReq, err := http.NewRequest("GET", endpoint, nil)
	if err != nil {
		// Previously the error was dropped and the nil request dereferenced
		// right below, which surfaced as an opaque nil-pointer panic.
		fmt.Println("Error creating the DB request:", err)
		return nil
	}
	httpReq.Header.Add("Authorization", "Token "+api_key)

	// Start from whatever query the base URL already carried.
	query := httpReq.URL.Query()

	// addIfSet only emits parameters which have a value.
	addIfSet := func(values url.Values, field string, value string) {
		if value != "" {
			values.Add(field, value)
		}
	}

	// Limit the returned latents to the ones we asked for (latents and masks).
	// The comma is only inserted between two non-empty parts, so asking for
	// masks alone does not produce a list starting with an empty entry.
	returnLatents := request.hasLatents
	if request.hasMasks != "" {
		if returnLatents != "" {
			returnLatents += ","
		}
		returnLatents += request.hasMasks
	}

	addIfSet(query, "fields", request.fields)
	addIfSet(query, "sources", request.sources)
	addIfSet(query, "sources__ne", request.sourcesNE)
	addIfSet(query, "page_size", request.pageSize)

	addIfSet(query, "tags", request.tags)
	addIfSet(query, "tags__ne", request.tagsNE)

	addIfSet(query, "has_attributes", request.hasAttributes)
	addIfSet(query, "lacks_attributes", request.lacksAttributes)

	addIfSet(query, "has_masks", request.hasMasks)
	addIfSet(query, "lacks_masks", request.lacksMasks)

	addIfSet(query, "has_latents", request.hasLatents)
	addIfSet(query, "lacks_latents", request.lacksLatents)
	addIfSet(query, "return_latents", returnLatents)

	addIfSet(query, "short_edge__gte", request.minShortEdge)
	addIfSet(query, "short_edge__lte", request.maxShortEdge)
	addIfSet(query, "pixel_count__gte", request.minPixelCount)
	addIfSet(query, "pixel_count__lte", request.maxPixelCount)

	addIfSet(query, "partitions_count", request.partitionsCount)
	addIfSet(query, "partition", request.partition)

	httpReq.URL.RawQuery = query.Encode()
	fmt.Println("Request URL:", httpReq.URL.String())
	fmt.Println()
	return httpReq
}
27 changes: 27 additions & 0 deletions tests/client_db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ func TestRanks(t *testing.T) {
dbConfig := clientConfig.SourceConfig.(datago.SourceDBConfig)
dbConfig.WorldSize = 2
dbConfig.Rank = 0
dbConfig.RequireImages = false
clientConfig.SourceConfig = dbConfig

client_0 := datago.GetClient(clientConfig)
Expand Down Expand Up @@ -295,6 +296,7 @@ func TestTags(t *testing.T) {

dbConfig := clientConfig.SourceConfig.(datago.SourceDBConfig)
dbConfig.Tags = "v4_trainset_hq"
dbConfig.RequireImages = false
clientConfig.SourceConfig = dbConfig

client := datago.GetClient(clientConfig)
Expand Down Expand Up @@ -322,6 +324,8 @@ func TestTags(t *testing.T) {
dbConfig := clientConfig.SourceConfig.(datago.SourceDBConfig)
dbConfig.Tags = ""
dbConfig.TagsNE = "v4_trainset_hq"
dbConfig.RequireImages = false

clientConfig.SourceConfig = dbConfig

client := datago.GetClient(clientConfig)
Expand Down Expand Up @@ -351,6 +355,7 @@ func TestMultipleSources(t *testing.T) {

dbConfig := clientConfig.SourceConfig.(datago.SourceDBConfig)
dbConfig.Sources = "LAION_ART,LAION_AESTHETICS"
dbConfig.RequireImages = false
clientConfig.SourceConfig = dbConfig

client := datago.GetClient(clientConfig)
Expand Down Expand Up @@ -379,6 +384,28 @@ func TestMultipleSources(t *testing.T) {
client.Stop()
}

// TestSourcesNE requests two sources while excluding one of them through
// SourcesNE, then pulls 100 samples and reports an error for every sample
// whose Source is the excluded one.
func TestSourcesNE(t *testing.T) {
	const excludedSource = "LAION_ART"

	config := get_default_test_config()
	config.SamplesBufferSize = 1

	sourceConfig := config.SourceConfig.(datago.SourceDBConfig)
	sourceConfig.RequireImages = false
	sourceConfig.Sources = "LAION_ART,LAION_AESTHETICS"
	sourceConfig.SourcesNE = excludedSource
	config.SourceConfig = sourceConfig

	client := datago.GetClient(config)

	// Every pulled sample must come from a source other than the excluded one.
	for remaining := 100; remaining > 0; remaining-- {
		if client.GetSample().Source != excludedSource {
			continue
		}
		t.Error("We're not supposed to get samples from LAION_ART")
	}

	client.Stop()
}

func TestRandomSampling(t *testing.T) {
clientConfig := get_default_test_config()
clientConfig.SamplesBufferSize = 1
Expand Down

0 comments on commit 78b26bf

Please sign in to comment.