Skip to content

Commit

Permalink
Addressing feedback from nov 21 - part 1
Browse files Browse the repository at this point in the history
  • Loading branch information
yash-puligundla committed Jan 10, 2024
1 parent 58ace69 commit f9b066c
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 95 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ public class ByteModel {
// the cumulative frequencies of all symbols prior to this symbol,
// and the total of all frequencies.
public int totalFrequency;
public int maxSymbol;
public final int maxSymbol;
public final int[] symbols;
public final int[] frequencies;

Expand All @@ -24,17 +24,7 @@ public ByteModel(final int numSymbols) {
}
}

// TODO: use this method to reset
public void reset() {
totalFrequency = 0;
for (int i = 0; i <= maxSymbol; i++) {
symbols[i] = 0;
frequencies[i] = 0;
}
// maxSymbol = 0; // TODO: ???
}

public int modelDecode(ByteBuffer inBuffer, RangeCoder rangeCoder){
public int modelDecode(final ByteBuffer inBuffer, final RangeCoder rangeCoder){

// decodes one symbol
final int freq = rangeCoder.rangeGetFrequency(totalFrequency);
Expand All @@ -45,7 +35,7 @@ public int modelDecode(ByteBuffer inBuffer, RangeCoder rangeCoder){
}

// update rangecoder
rangeCoder.rangeDecode(inBuffer,cumulativeFrequency,frequencies[x],totalFrequency);
rangeCoder.rangeDecode(inBuffer,cumulativeFrequency,frequencies[x]);

// update model frequencies
frequencies[x] += Constants.STEP;
Expand All @@ -57,7 +47,7 @@ public int modelDecode(ByteBuffer inBuffer, RangeCoder rangeCoder){
}

// keep symbols approximately frequency sorted
int symbol = symbols[x];
final int symbol = symbols[x];
if (x > 0 && frequencies[x] > frequencies[x-1]){
// Swap frequencies[x], frequencies[x-1]
int tmp = frequencies[x];
Expand All @@ -81,7 +71,7 @@ public void modelRenormalize(){
}
}

public void modelEncode(final ByteBuffer outBuffer, RangeCoder rangeCoder, int symbol){
public void modelEncode(final ByteBuffer outBuffer, final RangeCoder rangeCoder, final int symbol){

// encodes one input symbol
int cumulativeFrequency = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,35 +21,39 @@ protected RangeCoder() {
this.cache = 0;
}

protected void rangeDecodeStart(ByteBuffer inBuffer){
protected void rangeDecodeStart(final ByteBuffer inBuffer){
for (int i = 0; i < 5; i++){

// Get next 5 bytes. Ensure it is +ve
code = (code << 8) + (inBuffer.get() & 0xFF);
}
}

protected void rangeDecode(ByteBuffer inBuffer, int sym_low, int sym_freq, int tot_freq){
code -= sym_low * range;
range *= sym_freq;
protected void rangeDecode(final ByteBuffer inBuffer, final int cumulativeFrequency, final int symbolFrequency){
code -= cumulativeFrequency * range;
range *= symbolFrequency;

while (range < (1<<24)) {
range <<= 8;
code = (code << 8) + (inBuffer.get() & 0xFF); // Ensure code is positive
}
}

protected int rangeGetFrequency(final int tot_freq){
range = (long) Math.floor(range / tot_freq);
protected int rangeGetFrequency(final int totalFrequency){
range = (long) Math.floor(range / totalFrequency);
return (int) Math.floor(code / range);
}

protected void rangeEncode(final ByteBuffer outBuffer, final int sym_low, final int sym_freq, final int tot_freq){
long old_low = low;
range = (long) Math.floor(range/tot_freq);
low += sym_low * range;
protected void rangeEncode(
final ByteBuffer outBuffer,
final int cumulativeFrequency,
final int symbolFrequency,
final int totalFrequency){
final long old_low = low;
range = (long) Math.floor(range/totalFrequency);
low += cumulativeFrequency * range;
low &= 0xFFFFFFFFL; // keep bottom 4 bytes, shift the top byte out of low
range *= sym_freq;
range *= symbolFrequency;

if (low < old_low) {
carry = true;
Expand All @@ -70,7 +74,7 @@ protected void rangeEncodeEnd(final ByteBuffer outBuffer){
}
}

private void rangeShiftLow(ByteBuffer outBuffer) {
private void rangeShiftLow(final ByteBuffer outBuffer) {
// rangeShiftLow tracks the total number of extra bytes to emit and
// carry indicates whether they are a string of 0xFF or 0x00 values

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import htsjdk.samtools.cram.CRAMException;
import htsjdk.samtools.cram.compression.BZIP2ExternalCompressor;
import htsjdk.samtools.cram.compression.rans.Utils;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
Expand Down Expand Up @@ -39,16 +38,16 @@ private ByteBuffer uncompress(final ByteBuffer inBuffer, int outSize) {
// if pack, get pack metadata, which will be used later to decode packed data
int packDataLength = 0;
int numSymbols = 0;
int[] packMappingTable = new int[0];
byte[] packMappingTable = null;
if (rangeParams.isPack()){
packDataLength = outSize;
numSymbols = inBuffer.get() & 0xFF;

// if (numSymbols > 16 or numSymbols==0), raise exception
if (numSymbols <= 16 && numSymbols!=0) {
packMappingTable = new int[numSymbols];
packMappingTable = new byte[numSymbols];
for (int i = 0; i < numSymbols; i++) {
packMappingTable[i] = inBuffer.get() & 0xFF;
packMappingTable[i] = inBuffer.get();
}
outSize = Utils.readUint7(inBuffer);
} else {
Expand Down Expand Up @@ -92,8 +91,8 @@ private ByteBuffer uncompress(final ByteBuffer inBuffer, int outSize) {
}

// if pack, then decodePack
if (rangeParams.isPack() && packMappingTable.length > 0) {
outBuffer = decodePack(outBuffer, packMappingTable, numSymbols, packDataLength);
if (rangeParams.isPack()) {
outBuffer = Utils.decodePack(outBuffer, packMappingTable, numSymbols, packDataLength);
}
outBuffer.rewind();
return outBuffer;
Expand Down Expand Up @@ -227,55 +226,6 @@ private ByteBuffer uncompressEXT(
return outBuffer;
}

private ByteBuffer decodePack(ByteBuffer inBuffer, final int[] packMappingTable, int numSymbols, int uncompressedPackOutputLength) {
ByteBuffer outBufferPack = ByteBuffer.allocate(uncompressedPackOutputLength);
int j = 0;

if (numSymbols <= 1) {
for (int i=0; i < uncompressedPackOutputLength; i++){
outBufferPack.put(i, (byte) packMappingTable[0]);
}
}

// 1 bit per value
else if (numSymbols <= 2) {
int v = 0;
for (int i=0; i < uncompressedPackOutputLength; i++){
if (i % 8 == 0){
v = inBuffer.get(j++);
}
outBufferPack.put(i, (byte) packMappingTable[v & 1]);
v >>=1;
}
}

// 2 bits per value
else if (numSymbols <= 4){
int v = 0;
for(int i=0; i < uncompressedPackOutputLength; i++){
if (i % 4 == 0){
v = inBuffer.get(j++);
}
outBufferPack.put(i, (byte) packMappingTable[v & 3]);
v >>=2;
}
}

// 4 bits per value
else if (numSymbols <= 16){
int v = 0;
for(int i=0; i < uncompressedPackOutputLength; i++){
if (i % 2 == 0){
v = inBuffer.get(j++);
}
outBufferPack.put(i, (byte) packMappingTable[v & 15]);
v >>=4;
}
}
inBuffer = outBufferPack;
return inBuffer;
}

private ByteBuffer decodeStripe(ByteBuffer inBuffer, final int outSize){

final int numInterleaveStreams = inBuffer.get() & 0xFF;
Expand Down
52 changes: 52 additions & 0 deletions src/main/java/htsjdk/samtools/cram/compression/range/Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,56 @@ public static int readUint7(ByteBuffer cp) {
} while ((c & 0x80) != 0);
return i;
}

public static ByteBuffer decodePack(
final ByteBuffer inBuffer,
final byte[] packMappingTable,
final int numSymbols,
final int uncompressedPackOutputLength) {
ByteBuffer outBufferPack = ByteBuffer.allocate(uncompressedPackOutputLength);
int j = 0;

if (numSymbols <= 1) {
for (int i=0; i < uncompressedPackOutputLength; i++){
outBufferPack.put(i, packMappingTable[0]);
}
}

// 1 bit per value
else if (numSymbols <= 2) {
int v = 0;
for (int i=0; i < uncompressedPackOutputLength; i++){
if (i % 8 == 0){
v = inBuffer.get(j++);
}
outBufferPack.put(i, packMappingTable[v & 1]);
v >>=1;
}
}

// 2 bits per value
else if (numSymbols <= 4){
int v = 0;
for(int i=0; i < uncompressedPackOutputLength; i++){
if (i % 4 == 0){
v = inBuffer.get(j++);
}
outBufferPack.put(i, packMappingTable[v & 3]);
v >>=2;
}
}

// 4 bits per value
else if (numSymbols <= 16){
int v = 0;
for(int i=0; i < uncompressedPackOutputLength; i++){
if (i % 2 == 0){
v = inBuffer.get(j++);
}
outBufferPack.put(i, packMappingTable[v & 15]);
v >>=4;
}
}
return outBufferPack;
}
}
26 changes: 13 additions & 13 deletions src/test/java/htsjdk/samtools/cram/RangeInteropTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,6 @@
import java.util.ArrayList;
import java.util.List;

import static htsjdk.samtools.cram.CRAMInteropTestUtils.filterEmbeddedNewlines;
import static htsjdk.samtools.cram.CRAMInteropTestUtils.getInteropCompressedFilePaths;
import static htsjdk.samtools.cram.CRAMInteropTestUtils.getParamsFormatFlags;
import static htsjdk.samtools.cram.CRAMInteropTestUtils.getUnCompressedFilePath;

public class RangeInteropTest extends HtsjdkTest {
public static final String COMPRESSED_RANGE_DIR = "arith";

Expand All @@ -34,13 +29,13 @@ public Object[][] getRoundTripTestCases() throws IOException {
// compressed testfile path, uncompressed testfile path,
// Range encoder, Range decoder, Range params
final List<Object[]> testCases = new ArrayList<>();
for (Path path : getInteropCompressedFilePaths(COMPRESSED_RANGE_DIR)) {
for (Path path : CRAMInteropTestUtils.getInteropCompressedFilePaths(COMPRESSED_RANGE_DIR)) {
Object[] objects = new Object[]{
path,
getUnCompressedFilePath(path),
CRAMInteropTestUtils.getUnCompressedFilePath(path),
new RangeEncode(),
new RangeDecode(),
new RangeParams(getParamsFormatFlags(path))
new RangeParams(CRAMInteropTestUtils.getParamsFormatFlags(path))
};
testCases.add(objects);
}
Expand Down Expand Up @@ -70,7 +65,7 @@ public void testRangeRoundTrip(
// preprocess the uncompressed data (to match what the htscodecs-library test harness does)
// by filtering out the embedded newlines, and then round trip through Range codec and compare the
// results
final ByteBuffer uncompressedInteropBytes = ByteBuffer.wrap(filterEmbeddedNewlines(IOUtils.toByteArray(uncompressedInteropStream)));
final ByteBuffer uncompressedInteropBytes = ByteBuffer.wrap(CRAMInteropTestUtils.filterEmbeddedNewlines(IOUtils.toByteArray(uncompressedInteropStream)));

if (params.isStripe()) {
Assert.assertThrows(CRAMException.class, () -> rangeEncode.compress(uncompressedInteropBytes, params));
Expand All @@ -95,11 +90,16 @@ public void testDecodeOnly(
try (final InputStream uncompressedInteropStream = Files.newInputStream(uncompressedInteropPath);
final InputStream preCompressedInteropStream = Files.newInputStream(compressedFilePath)
) {

// preprocess the uncompressed data (to match what the htscodecs-library test harness does)
// by filtering out the embedded newlines, and then round trip through Range codec and compare the
// results
final ByteBuffer uncompressedInteropBytes = ByteBuffer.wrap(filterEmbeddedNewlines(IOUtils.toByteArray(uncompressedInteropStream)));
// by filtering out the embedded newlines, and then round trip through Range codec
// and compare the results

final ByteBuffer uncompressedInteropBytes;
if (uncompressedInteropPath.toString().contains("htscodecs/tests/dat/u")) {
uncompressedInteropBytes = ByteBuffer.wrap(IOUtils.toByteArray(uncompressedInteropStream));
} else {
uncompressedInteropBytes = ByteBuffer.wrap(CRAMInteropTestUtils.filterEmbeddedNewlines(IOUtils.toByteArray(uncompressedInteropStream)));
}
final ByteBuffer preCompressedInteropBytes = ByteBuffer.wrap(IOUtils.toByteArray(preCompressedInteropStream));

// Use htsjdk to uncompress the precompressed file from htscodecs repo
Expand Down

0 comments on commit f9b066c

Please sign in to comment.