Skip to content

Commit

Permalink
Merge pull request #45 from Photoroom/ben/fix_source_sources
Browse files Browse the repository at this point in the history
[nitfix] Not handling multiple sources correctly
  • Loading branch information
blefaudeux authored Nov 25, 2024
2 parents 9b11d87 + 9c3b04f commit be6c978
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 5 deletions.
3 changes: 2 additions & 1 deletion pkg/generator_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type urlLatent struct {

type dbSampleMetadata struct {
Id string `json:"id"`
Source string `json:"source"`
Attributes map[string]interface{} `json:"attributes"`
DuplicateState int `json:"duplicate_state"`
ImageDirectURL string `json:"image_direct_url"`
Expand Down Expand Up @@ -129,7 +130,7 @@ func (c *SourceDBConfig) setDefaults() {

func (c *SourceDBConfig) getDbRequest() dbRequest {

fields := "attributes,image_direct_url"
fields := "attributes,image_direct_url,source"
if len(c.HasLatents) > 0 || len(c.HasMasks) > 0 {
fields += ",latents"
fmt.Println("Including some latents:", c.HasLatents, c.HasMasks)
Expand Down
6 changes: 3 additions & 3 deletions pkg/serdes.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ func fetchSample(config *SourceDBConfig, http_client *http.Client, sample_result
}

return &Sample{ID: sample_result.Id,
Source: config.Sources,
Source: sample_result.Source,
Attributes: sample_result.Attributes,
Image: *img_payload,
Latents: latents,
Expand Down Expand Up @@ -281,8 +281,8 @@ func getHTTPRequest(api_url string, api_key string, request dbRequest) *http.Req
}

maybeAddField(&req, "fields", request.fields)
maybeAddField(&req, "source", request.sources)
maybeAddField(&req, "source__ne", request.sourcesNE)
maybeAddField(&req, "sources", request.sources)
maybeAddField(&req, "sources__ne", request.sourcesNE)
maybeAddField(&req, "page_size", request.pageSize)

maybeAddField(&req, "tags", request.tags)
Expand Down
34 changes: 33 additions & 1 deletion tests/client_db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,4 +346,36 @@ func TestTags(t *testing.T) {
}
}

// FIXME: Could do with a lot of tests on the filesystem side
// TestMultipleSources checks that a client configured with two
// comma-separated sources returns samples carrying each of those sources.
func TestMultipleSources(t *testing.T) {
	clientConfig := get_default_test_config()
	clientConfig.SamplesBufferSize = 1

	// Fail the test with a readable message instead of panicking if the
	// default test config does not hold a DB source configuration.
	dbConfig, ok := clientConfig.SourceConfig.(datago.SourceDBConfig)
	if !ok {
		t.Fatalf("expected a datago.SourceDBConfig source config, got %T", clientConfig.SourceConfig)
	}
	dbConfig.Sources = "LAION_ART,LAION_AESTHETICS"
	clientConfig.SourceConfig = dbConfig

	client := datago.GetClient(clientConfig)
	// Deferred so the client is stopped on every exit path of the test.
	defer client.Stop()

	// Pull at most maxSamples samples and record every distinct source seen,
	// stopping early once as many distinct sources as requested have shown up.
	const maxSamples = 100
	wanted := []string{"LAION_ART", "LAION_AESTHETICS"}
	seen := make(map[string]bool, len(wanted))
	for i := 0; i < maxSamples && len(seen) < len(wanted); i++ {
		seen[client.GetSample().Source] = true
	}

	// The loop never collects more than len(wanted) distinct sources, so
	// checking that each wanted source is present is sufficient.
	for _, source := range wanted {
		if !seen[source] {
			t.Errorf("Missing the required source %q, collected sources: %v", source, seen)
		}
	}
}

0 comments on commit be6c978

Please sign in to comment.