diff --git a/src/main/java/htsjdk/samtools/cram/compression/range/ByteModel.java b/src/main/java/htsjdk/samtools/cram/compression/range/ByteModel.java index 86b891ee2e..f2f71c4e2a 100644 --- a/src/main/java/htsjdk/samtools/cram/compression/range/ByteModel.java +++ b/src/main/java/htsjdk/samtools/cram/compression/range/ByteModel.java @@ -8,7 +8,7 @@ public class ByteModel { // the cumulative frequencies of all symbols prior to this symbol, // and the total of all frequencies. public int totalFrequency; - public int maxSymbol; + public final int maxSymbol; public final int[] symbols; public final int[] frequencies; @@ -24,17 +24,7 @@ public ByteModel(final int numSymbols) { } } - // TODO: use this method to reset - public void reset() { - totalFrequency = 0; - for (int i = 0; i <= maxSymbol; i++) { - symbols[i] = 0; - frequencies[i] = 0; - } - // maxSymbol = 0; // TODO: ??? - } - - public int modelDecode(ByteBuffer inBuffer, RangeCoder rangeCoder){ + public int modelDecode(final ByteBuffer inBuffer, final RangeCoder rangeCoder){ // decodes one symbol final int freq = rangeCoder.rangeGetFrequency(totalFrequency); @@ -45,7 +35,7 @@ public int modelDecode(ByteBuffer inBuffer, RangeCoder rangeCoder){ } // update rangecoder - rangeCoder.rangeDecode(inBuffer,cumulativeFrequency,frequencies[x],totalFrequency); + rangeCoder.rangeDecode(inBuffer,cumulativeFrequency,frequencies[x]); // update model frequencies frequencies[x] += Constants.STEP; @@ -57,7 +47,7 @@ public int modelDecode(ByteBuffer inBuffer, RangeCoder rangeCoder){ } // keep symbols approximately frequency sorted - int symbol = symbols[x]; + final int symbol = symbols[x]; if (x > 0 && frequencies[x] > frequencies[x-1]){ // Swap frequencies[x], frequencies[x-1] int tmp = frequencies[x]; @@ -81,7 +71,7 @@ public void modelRenormalize(){ } } - public void modelEncode(final ByteBuffer outBuffer, RangeCoder rangeCoder, int symbol){ + public void modelEncode(final ByteBuffer outBuffer, final RangeCoder rangeCoder, final int symbol){ // encodes one input symbol int cumulativeFrequency = 0; diff --git a/src/main/java/htsjdk/samtools/cram/compression/range/RangeCoder.java b/src/main/java/htsjdk/samtools/cram/compression/range/RangeCoder.java index 763a65a3b2..05b1a0f33c 100644 --- a/src/main/java/htsjdk/samtools/cram/compression/range/RangeCoder.java +++ b/src/main/java/htsjdk/samtools/cram/compression/range/RangeCoder.java @@ -21,7 +21,7 @@ protected RangeCoder() { this.cache = 0; } - protected void rangeDecodeStart(ByteBuffer inBuffer){ + protected void rangeDecodeStart(final ByteBuffer inBuffer){ for (int i = 0; i < 5; i++){ // Get next 5 bytes. Ensure it is +ve @@ -29,9 +29,9 @@ protected void rangeDecodeStart(ByteBuffer inBuffer){ } } - protected void rangeDecode(ByteBuffer inBuffer, int sym_low, int sym_freq, int tot_freq){ - code -= sym_low * range; - range *= sym_freq; + protected void rangeDecode(final ByteBuffer inBuffer, final int cumulativeFrequency, final int symbolFrequency){ + code -= cumulativeFrequency * range; + range *= symbolFrequency; while (range < (1<<24)) { range <<= 8; @@ -39,17 +39,21 @@ protected void rangeDecode(ByteBuffer inBuffer, int sym_low, int sym_freq, int t } } - protected int rangeGetFrequency(final int tot_freq){ - range = (long) Math.floor(range / tot_freq); + protected int rangeGetFrequency(final int totalFrequency){ + range = (long) Math.floor(range / totalFrequency); return (int) Math.floor(code / range); } - protected void rangeEncode(final ByteBuffer outBuffer, final int sym_low, final int sym_freq, final int tot_freq){ - long old_low = low; - range = (long) Math.floor(range/tot_freq); - low += sym_low * range; + protected void rangeEncode( + final ByteBuffer outBuffer, + final int cumulativeFrequency, + final int symbolFrequency, + final int totalFrequency){ + final long old_low = low; + range = (long) Math.floor(range/totalFrequency); + low += cumulativeFrequency * range; low &= 0xFFFFFFFFL; // keep bottom 4 bytes, shift the top byte out of low - range *= sym_freq; + range *= symbolFrequency; if (low < old_low) { carry = true; @@ -70,7 +74,7 @@ protected void rangeEncodeEnd(final ByteBuffer outBuffer){ } } - private void rangeShiftLow(ByteBuffer outBuffer) { + private void rangeShiftLow(final ByteBuffer outBuffer) { // rangeShiftLow tracks the total number of extra bytes to emit and // carry indicates whether they are a string of 0xFF or 0x00 values diff --git a/src/main/java/htsjdk/samtools/cram/compression/range/RangeDecode.java b/src/main/java/htsjdk/samtools/cram/compression/range/RangeDecode.java index 6f892ae667..4a6f367106 100644 --- a/src/main/java/htsjdk/samtools/cram/compression/range/RangeDecode.java +++ b/src/main/java/htsjdk/samtools/cram/compression/range/RangeDecode.java @@ -2,7 +2,6 @@ import htsjdk.samtools.cram.CRAMException; import htsjdk.samtools.cram.compression.BZIP2ExternalCompressor; -import htsjdk.samtools.cram.compression.rans.Utils; import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -39,16 +38,16 @@ private ByteBuffer uncompress(final ByteBuffer inBuffer, int outSize) { // if pack, get pack metadata, which will be used later to decode packed data int packDataLength = 0; int numSymbols = 0; - int[] packMappingTable = new int[0]; + byte[] packMappingTable = null; if (rangeParams.isPack()){ packDataLength = outSize; numSymbols = inBuffer.get() & 0xFF; // if (numSymbols > 16 or numSymbols==0), raise exception if (numSymbols <= 16 && numSymbols!=0) { - packMappingTable = new int[numSymbols]; + packMappingTable = new byte[numSymbols]; for (int i = 0; i < numSymbols; i++) { - packMappingTable[i] = inBuffer.get() & 0xFF; + packMappingTable[i] = inBuffer.get(); } outSize = Utils.readUint7(inBuffer); } else { @@ -92,8 +91,8 @@ private ByteBuffer uncompress(final ByteBuffer inBuffer, int outSize) { } // if pack, then decodePack - if (rangeParams.isPack() && packMappingTable.length > 0) { - outBuffer = decodePack(outBuffer, packMappingTable, numSymbols, packDataLength); + if (rangeParams.isPack()) { + outBuffer = Utils.decodePack(outBuffer, packMappingTable, numSymbols, packDataLength); } outBuffer.rewind(); return outBuffer; @@ -227,55 +226,6 @@ private ByteBuffer uncompressEXT( return outBuffer; } - private ByteBuffer decodePack(ByteBuffer inBuffer, final int[] packMappingTable, int numSymbols, int uncompressedPackOutputLength) { - ByteBuffer outBufferPack = ByteBuffer.allocate(uncompressedPackOutputLength); - int j = 0; - - if (numSymbols <= 1) { - for (int i=0; i < uncompressedPackOutputLength; i++){ - outBufferPack.put(i, (byte) packMappingTable[0]); - } - } - - // 1 bit per value - else if (numSymbols <= 2) { - int v = 0; - for (int i=0; i < uncompressedPackOutputLength; i++){ - if (i % 8 == 0){ - v = inBuffer.get(j++); - } - outBufferPack.put(i, (byte) packMappingTable[v & 1]); - v >>=1; - } - } - - // 2 bits per value - else if (numSymbols <= 4){ - int v = 0; - for(int i=0; i < uncompressedPackOutputLength; i++){ - if (i % 4 == 0){ - v = inBuffer.get(j++); - } - outBufferPack.put(i, (byte) packMappingTable[v & 3]); - v >>=2; - } - } - - // 4 bits per value - else if (numSymbols <= 16){ - int v = 0; - for(int i=0; i < uncompressedPackOutputLength; i++){ - if (i % 2 == 0){ - v = inBuffer.get(j++); - } - outBufferPack.put(i, (byte) packMappingTable[v & 15]); - v >>=4; - } - } - inBuffer = outBufferPack; - return inBuffer; - } - private ByteBuffer decodeStripe(ByteBuffer inBuffer, final int outSize){ final int numInterleaveStreams = inBuffer.get() & 0xFF; diff --git a/src/main/java/htsjdk/samtools/cram/compression/range/Utils.java b/src/main/java/htsjdk/samtools/cram/compression/range/Utils.java index 0f6b1507dd..abb0969320 100644 --- a/src/main/java/htsjdk/samtools/cram/compression/range/Utils.java +++ b/src/main/java/htsjdk/samtools/cram/compression/range/Utils.java @@ -28,4 +28,56 @@ public static int readUint7(ByteBuffer cp) { } while ((c & 0x80) != 0); return i; } + + public static ByteBuffer decodePack( + final ByteBuffer inBuffer, + final byte[] packMappingTable, + final int numSymbols, + final int uncompressedPackOutputLength) { + ByteBuffer outBufferPack = ByteBuffer.allocate(uncompressedPackOutputLength); + int j = 0; + + if (numSymbols <= 1) { + for (int i=0; i < uncompressedPackOutputLength; i++){ + outBufferPack.put(i, packMappingTable[0]); + } + } + + // 1 bit per value + else if (numSymbols <= 2) { + int v = 0; + for (int i=0; i < uncompressedPackOutputLength; i++){ + if (i % 8 == 0){ + v = inBuffer.get(j++); + } + outBufferPack.put(i, packMappingTable[v & 1]); + v >>=1; + } + } + + // 2 bits per value + else if (numSymbols <= 4){ + int v = 0; + for(int i=0; i < uncompressedPackOutputLength; i++){ + if (i % 4 == 0){ + v = inBuffer.get(j++); + } + outBufferPack.put(i, packMappingTable[v & 3]); + v >>=2; + } + } + + // 4 bits per value + else if (numSymbols <= 16){ + int v = 0; + for(int i=0; i < uncompressedPackOutputLength; i++){ + if (i % 2 == 0){ + v = inBuffer.get(j++); + } + outBufferPack.put(i, packMappingTable[v & 15]); + v >>=4; + } + } + return outBufferPack; + } } \ No newline at end of file diff --git a/src/test/java/htsjdk/samtools/cram/RangeInteropTest.java b/src/test/java/htsjdk/samtools/cram/RangeInteropTest.java index c00e568ad0..1f9547e744 100644 --- a/src/test/java/htsjdk/samtools/cram/RangeInteropTest.java +++ b/src/test/java/htsjdk/samtools/cram/RangeInteropTest.java @@ -19,11 +19,6 @@ import java.util.ArrayList; import java.util.List; -import static htsjdk.samtools.cram.CRAMInteropTestUtils.filterEmbeddedNewlines; -import static htsjdk.samtools.cram.CRAMInteropTestUtils.getInteropCompressedFilePaths; -import static htsjdk.samtools.cram.CRAMInteropTestUtils.getParamsFormatFlags; -import static htsjdk.samtools.cram.CRAMInteropTestUtils.getUnCompressedFilePath; - public class RangeInteropTest extends HtsjdkTest { public static final String COMPRESSED_RANGE_DIR = "arith"; @@ -34,13 +29,13 @@ public Object[][] getRoundTripTestCases() throws IOException { // compressed testfile path, uncompressed testfile path, // Range encoder, Range decoder, Range params final List testCases = new ArrayList<>(); - for (Path path : getInteropCompressedFilePaths(COMPRESSED_RANGE_DIR)) { + for (Path path : CRAMInteropTestUtils.getInteropCompressedFilePaths(COMPRESSED_RANGE_DIR)) { Object[] objects = new Object[]{ path, - getUnCompressedFilePath(path), + CRAMInteropTestUtils.getUnCompressedFilePath(path), new RangeEncode(), new RangeDecode(), - new RangeParams(getParamsFormatFlags(path)) + new RangeParams(CRAMInteropTestUtils.getParamsFormatFlags(path)) }; testCases.add(objects); } @@ -70,7 +65,7 @@ public void testRangeRoundTrip( // preprocess the uncompressed data (to match what the htscodecs-library test harness does) // by filtering out the embedded newlines, and then round trip through Range codec and compare the // results - final ByteBuffer uncompressedInteropBytes = ByteBuffer.wrap(filterEmbeddedNewlines(IOUtils.toByteArray(uncompressedInteropStream))); + final ByteBuffer uncompressedInteropBytes = ByteBuffer.wrap(CRAMInteropTestUtils.filterEmbeddedNewlines(IOUtils.toByteArray(uncompressedInteropStream))); if (params.isStripe()) { Assert.assertThrows(CRAMException.class, () -> rangeEncode.compress(uncompressedInteropBytes, params)); @@ -95,11 +90,16 @@ public void testDecodeOnly( try (final InputStream uncompressedInteropStream = Files.newInputStream(uncompressedInteropPath); final InputStream preCompressedInteropStream = Files.newInputStream(compressedFilePath) ) { - // preprocess the uncompressed data (to match what the htscodecs-library test harness does) - // by filtering out the embedded newlines, and then round trip through Range codec and compare the - // results - final ByteBuffer uncompressedInteropBytes = ByteBuffer.wrap(filterEmbeddedNewlines(IOUtils.toByteArray(uncompressedInteropStream))); + // by filtering out the embedded newlines, and then round trip through Range codec + // and compare the results + + final ByteBuffer uncompressedInteropBytes; + if (uncompressedInteropPath.toString().contains("htscodecs/tests/dat/u")) { + uncompressedInteropBytes = ByteBuffer.wrap(IOUtils.toByteArray(uncompressedInteropStream)); + } else { + uncompressedInteropBytes = ByteBuffer.wrap(CRAMInteropTestUtils.filterEmbeddedNewlines(IOUtils.toByteArray(uncompressedInteropStream))); + } final ByteBuffer preCompressedInteropBytes = ByteBuffer.wrap(IOUtils.toByteArray(preCompressedInteropStream)); // Use htsjdk to uncompress the precompressed file from htscodecs repo