diff --git a/src/main/java/uk/ac/ebi/eva/pipeline/configuration/jobs/steps/processors/VariantStatsProcessorConfiguration.java b/src/main/java/uk/ac/ebi/eva/pipeline/configuration/jobs/steps/processors/VariantStatsProcessorConfiguration.java index 23b1b92f..997d67b5 100644 --- a/src/main/java/uk/ac/ebi/eva/pipeline/configuration/jobs/steps/processors/VariantStatsProcessorConfiguration.java +++ b/src/main/java/uk/ac/ebi/eva/pipeline/configuration/jobs/steps/processors/VariantStatsProcessorConfiguration.java @@ -21,6 +21,7 @@ import org.springframework.context.annotation.Configuration; import uk.ac.ebi.eva.commons.models.mongo.entity.VariantDocument; import uk.ac.ebi.eva.pipeline.io.processors.VariantStatsProcessor; +import uk.ac.ebi.eva.pipeline.parameters.InputParameters; import static uk.ac.ebi.eva.pipeline.configuration.BeanNames.VARIANT_STATS_PROCESSOR; @@ -29,7 +30,7 @@ public class VariantStatsProcessorConfiguration { @Bean(VARIANT_STATS_PROCESSOR) @StepScope - public ItemProcessor variantStatsProcessor() { - return new VariantStatsProcessor(); + public ItemProcessor variantStatsProcessor(InputParameters inputParameters) { + return new VariantStatsProcessor(inputParameters.getStudyId()); } } diff --git a/src/main/java/uk/ac/ebi/eva/pipeline/io/processors/VariantStatsProcessor.java b/src/main/java/uk/ac/ebi/eva/pipeline/io/processors/VariantStatsProcessor.java index 6e510345..a12e9b35 100644 --- a/src/main/java/uk/ac/ebi/eva/pipeline/io/processors/VariantStatsProcessor.java +++ b/src/main/java/uk/ac/ebi/eva/pipeline/io/processors/VariantStatsProcessor.java @@ -14,7 +14,6 @@ import java.util.Arrays; import java.util.Comparator; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; @@ -30,18 +29,30 @@ public class VariantStatsProcessor implements ItemProcessor MISSING_GENOTYPE_ALLELE_REPRESENTATIONS = Arrays.asList(".", "-1"); - public VariantStatsProcessor() { + private String studyId; + + public VariantStatsProcessor(String studyId) { + this.studyId = studyId; } @Override public VariantDocument process(VariantDocument variant) { Map filesIdNumberOfSamplesMap = VariantStatsReader.getFilesIdAndNumberOfSamplesMap(); + Set fidSet = filesIdNumberOfSamplesMap.keySet(); String variantRef = variant.getReference(); String variantAlt = variant.getAlternate(); - Set variantStatsSet = new HashSet<>(); - Set variantSourceEntrySet = variant.getVariantSources(); + // copy the stats that should not be changed/updated and will be copied as it is + Set variantStatsSet = variant.getVariantStatsMongo().stream() + .filter(st -> !st.getStudyId().equals(studyId) || !fidSet.contains(st.getFileId())) + .collect(Collectors.toSet()); + + // get only the ones for which we can calculate the stats + Set variantSourceEntrySet = variant.getVariantSources().stream() + .filter(vse -> vse.getStudyId().equals(studyId) && fidSet.contains(vse.getFileId())) + .collect(Collectors.toSet()); + for (VariantSourceEntryMongo variantSourceEntry : variantSourceEntrySet) { String studyId = variantSourceEntry.getStudyId(); String fileId = variantSourceEntry.getFileId(); diff --git a/src/main/java/uk/ac/ebi/eva/pipeline/io/readers/VariantStatsReader.java b/src/main/java/uk/ac/ebi/eva/pipeline/io/readers/VariantStatsReader.java index eda7561a..35871ae9 100644 --- a/src/main/java/uk/ac/ebi/eva/pipeline/io/readers/VariantStatsReader.java +++ b/src/main/java/uk/ac/ebi/eva/pipeline/io/readers/VariantStatsReader.java @@ -94,10 +94,11 @@ private void populateFilesIdAndNumberOfSamplesMap() { computed("fid", "$fid"), computed("numOfSamples", new Document("$size", new Document("$objectToArray", "$samp"))) )); - Bson groupStage = group("$fid", sum("totalNumOfSamples", "$numOfSamples")); + Bson groupStage = group("$fid", sum("totalNumOfSamples", "$numOfSamples"), sum("count", 1)); + Bson filterStage = match(Filters.eq("count", 1)); filesIdNumberOfSamplesMap = mongoTemplate.getCollection(databaseParameters.getCollectionFilesName()) - .aggregate(asList(matchStage, projectStage, groupStage)) + .aggregate(asList(matchStage, projectStage, groupStage, filterStage)) .into(new ArrayList<>()) .stream() .collect(Collectors.toMap(doc -> doc.getString("_id"), doc -> doc.getInteger("totalNumOfSamples"))); diff --git a/src/test/java/uk/ac/ebi/eva/pipeline/configuration/jobs/steps/VariantStatsStepTest.java b/src/test/java/uk/ac/ebi/eva/pipeline/configuration/jobs/steps/VariantStatsStepTest.java index fae03763..1a13faed 100644 --- a/src/test/java/uk/ac/ebi/eva/pipeline/configuration/jobs/steps/VariantStatsStepTest.java +++ b/src/test/java/uk/ac/ebi/eva/pipeline/configuration/jobs/steps/VariantStatsStepTest.java @@ -15,6 +15,7 @@ */ package uk.ac.ebi.eva.pipeline.configuration.jobs.steps; +import com.mongodb.client.MongoCollection; import org.bson.Document; import org.junit.After; import org.junit.Assert; @@ -39,6 +40,7 @@ import uk.ac.ebi.eva.utils.EvaJobParameterBuilder; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import static org.junit.Assert.assertEquals; @@ -76,7 +78,6 @@ public class VariantStatsStepTest { @Before public void setUp() throws Exception { mongoRule.getTemporaryDatabase(DATABASE_NAME).drop(); - mongoRule.restoreDump(getResourceUrl(MONGO_DUMP), DATABASE_NAME); } @After @@ -85,7 +86,10 @@ public void cleanUp() { } @Test - public void variantStatsStepShouldCalculateAndLoadStats() { + public void variantStatsStepShouldCalculateAndLoadStats() throws Exception { + mongoRule.restoreDump(getResourceUrl(MONGO_DUMP), DATABASE_NAME); + + JobParameters jobParameters = new EvaJobParameterBuilder() .collectionFilesName(COLLECTION_FILES_NAME) .collectionVariantsName(COLLECTION_VARIANTS_NAME) @@ -122,4 +126,103 @@ public void variantStatsStepShouldCalculateAndLoadStats() { assertEquals(0, variantStats.get("missGt")); } + @Test + public void variantStatsStepShouldCalculateAndLoadStats_WhereFidHasMoreThanOneFile() { + MongoCollection filesCollection = mongoRule.getCollection(DATABASE_NAME, COLLECTION_FILES_NAME); + MongoCollection variantsCollection = mongoRule.getCollection(DATABASE_NAME, COLLECTION_VARIANTS_NAME); + + filesCollection.insertMany(Arrays.asList( + new Document("sid", "sid1").append("fid", "fid1").append("fname", "fname11") + .append("samp", new Document("samp11", 0).append("samp12", 1).append("samp13", 2)), + // multiple entries for fid2 in the files collection + new Document("sid", "sid1").append("fid", "fid2").append("fname", "fname21") + .append("samp", new Document("samp21", 0).append("samp22", 1)).append("samp23", 2), + new Document("sid", "sid1").append("fid", "fid2").append("fname", "fname22") + .append("samp", new Document("samp31", 0).append("samp32", 1)).append("samp33", 2) + )); + + variantsCollection.insertMany(Arrays.asList( + new Document("_id", "chr1_11111111_A_G").append("ref", "A").append("alt", "G") + .append("files", Arrays.asList( + // stats should be calculated + new Document("sid", "sid1").append("fid", "fid1") + .append("samp", new Document("def", "0|0").append("0|1", Arrays.asList(1))), + // should not calculate stats - fid2 has more than one entry in the files collection + new Document("sid", "sid1").append("fid", "fid2") + .append("samp", new Document("def", "0|0").append("0|1", Arrays.asList(1))), + // should not calculate stats - no entry for fid3 in files collection + new Document("sid", "sid1").append("fid", "fid3") + .append("samp", new Document("def", "0|0").append("0|1", Arrays.asList(1))), + // should not calculate stats - different study + new Document("sid", "sid2").append("fid", "fid1") + .append("samp", new Document("def", "0|0").append("0|1", Arrays.asList(1))) + )) + .append("st", Arrays.asList( + // should be updated with new values + new Document("sid", "sid1").append("fid", "fid1") + .append("maf", 0.1).append("mgf", 0.1) + .append("mafAl", "A").append("mgfGt", "0|0"), + // should not change as it belongs to different study id + new Document("sid", "sid2").append("fid", "fid1") + .append("maf", 0.20000000298023224).append("mgf", 0.20000000298023224) + .append("mafAl", "A").append("mgfGt", "0|0"), + // should not change as it belongs to different fid + new Document("sid", "sid1").append("fid", "fid3") + .append("maf", 0.30000001192092896).append("mgf", 0.30000001192092896) + .append("mafAl", "A").append("mgfGt", "0|0") + + )) + )); + + JobParameters jobParameters = new EvaJobParameterBuilder() + .collectionFilesName(COLLECTION_FILES_NAME) + .collectionVariantsName(COLLECTION_VARIANTS_NAME) + .databaseName(DATABASE_NAME) + .inputStudyId("sid1") + .chunkSize("100") + .toJobParameters(); + + JobExecution jobExecution = jobLauncherTestUtils.launchStep(BeanNames.VARIANT_STATS_STEP, jobParameters); + + // check job completed successfully + assertCompleted(jobExecution); + List documents = mongoRule.getTemporaryDatabase(DATABASE_NAME).getCollection(COLLECTION_VARIANTS_NAME) + .find().into(new ArrayList<>()); + Assert.assertTrue(documents.size() == 1); + + // assert data + ArrayList variantStatsList = documents.stream().filter(doc -> doc.get("_id").equals("chr1_11111111_A_G")) + .findFirst().get().get("st", ArrayList.class); + assertEquals(3, variantStatsList.size()); + + // assert remained unchanged + Document variantStatsForSid2Fid1 = variantStatsList.stream() + .filter(st -> st.get("sid").equals("sid2") && st.get("fid").equals("fid1")).findFirst().get(); + assertEquals(0.20000000298023224, variantStatsForSid2Fid1.get("maf")); + assertEquals(0.20000000298023224, variantStatsForSid2Fid1.get("mgf")); + assertEquals("A", variantStatsForSid2Fid1.get("mafAl")); + assertEquals("0|0", variantStatsForSid2Fid1.get("mgfGt")); + + // assert remained unchanged + Document variantStatsForSid1Fid3 = variantStatsList.stream() + .filter(st -> st.get("sid").equals("sid1") && st.get("fid").equals("fid3")).findFirst().get(); + assertEquals(0.30000001192092896, variantStatsForSid1Fid3.get("maf")); + assertEquals(0.30000001192092896, variantStatsForSid1Fid3.get("mgf")); + assertEquals("A", variantStatsForSid1Fid3.get("mafAl")); + assertEquals("0|0", variantStatsForSid1Fid3.get("mgfGt")); + + // assert updated with new stats + Document variantStatsForSid1Fid1 = variantStatsList.stream() + .filter(st -> st.get("sid").equals("sid1") && st.get("fid").equals("fid1")).findFirst().get(); + Document numOfGT = (Document) variantStatsForSid1Fid1.get("numGt"); + assertEquals(2, numOfGT.get("0|0")); + assertEquals(1, numOfGT.get("0|1")); + assertEquals(0.1666666716337204, variantStatsForSid1Fid1.get("maf")); + assertEquals(0.3333333432674408, variantStatsForSid1Fid1.get("mgf")); + assertEquals("G", variantStatsForSid1Fid1.get("mafAl")); + assertEquals("0|1", variantStatsForSid1Fid1.get("mgfGt")); + assertEquals(0, variantStatsForSid1Fid1.get("missAl")); + assertEquals(0, variantStatsForSid1Fid1.get("missGt")); + } + }