Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stdlib: base64 stream decoder #21348

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 154 additions & 59 deletions lib/std/base64.zig
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ pub const Base64Encoder = struct {
}
}

// dest must be compatible with std.io.Writer's writeAll interface
/// `dest` must be compatible with `std.io.Writer`'s `writeAll` interface.
pub fn encodeWriter(encoder: *const Base64Encoder, dest: anytype, source: []const u8) !void {
var chunker = window(u8, source, 3, 3);
while (chunker.next()) |chunk| {
Expand All @@ -109,19 +109,19 @@ pub const Base64Encoder = struct {
}
}

// destWriter must be compatible with std.io.Writer's writeAll interface
// sourceReader must be compatible with std.io.Reader's read interface
pub fn encodeFromReaderToWriter(encoder: *const Base64Encoder, destWriter: anytype, sourceReader: anytype) !void {
/// `dest_writer` must be compatible with `std.io.Writer`'s `writeAll` interface.
/// `source_reader` must be compatible with `std.io.Reader`'s `read` interface.
pub fn encodeFromReaderToWriter(encoder: *const Base64Encoder, dest_writer: anytype, source_reader: anytype) !void {
while (true) {
var tempSource: [3]u8 = undefined;
const bytesRead = try sourceReader.read(&tempSource);
var temp_source: [3]u8 = undefined;
const bytesRead = try source_reader.read(&temp_source);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missed one: `bytesRead` should be renamed to `bytes_read`.

if (bytesRead == 0) {
break;
}

var temp: [5]u8 = undefined;
const s = encoder.encode(&temp, tempSource[0..bytesRead]);
try destWriter.writeAll(s);
const s = encoder.encode(&temp, temp_source[0..bytesRead]);
try dest_writer.writeAll(s);
}
}

Expand Down Expand Up @@ -301,6 +301,33 @@ pub const Base64Decoder = struct {
if (padding_chars != padding_len) return error.InvalidPadding;
}
}

/// Decodes `source` in groups of up to 4 characters and writes the decoded
/// bytes to `dest_writer`.
///
/// `dest_writer` must be compatible with `std.io.Writer`'s `writeAll` interface.
///
/// Returns any error produced by `calcSizeForSlice`, `decode`, or
/// `dest_writer.writeAll`. A group that decodes to fewer than 3 bytes (padded
/// or truncated) must be the last one; input that continues after such a group
/// results in `error.InvalidPadding`. Bytes decoded from earlier groups have
/// already been written to `dest_writer` when an error is returned.
pub fn decodeWriter(decoder: *const Base64Decoder, dest_writer: anytype, source: []const u8) !void {
    // Four base64 characters decode to at most three bytes.
    var decoded: [3]u8 = undefined;
    var saw_final_group = false;
    var chunker = window(u8, source, 4, 4);
    while (chunker.next()) |chunk| {
        // Groups are decoded independently of each other, so data that
        // follows a short group has to be rejected here.
        if (saw_final_group) return error.InvalidPadding;

        const size = try decoder.calcSizeForSlice(chunk);
        // Hand `decode` a destination of exactly the computed size.
        const out = decoded[0..size];
        try decoder.decode(out, chunk);
        try dest_writer.writeAll(out);

        saw_final_group = size < decoded.len;
    }
}

/// Reads base64 text from `source_reader`, decodes it in groups of up to
/// 4 characters, and writes the decoded bytes to `dest_writer`.
///
/// `dest_writer` must be compatible with `std.io.Writer`'s `writeAll` interface.
/// `source_reader` must be compatible with `std.io.Reader`'s `read` interface.
///
/// Returns any error produced by `source_reader.read`, `calcSizeForSlice`,
/// `decode`, or `dest_writer.writeAll`. A group that decodes to fewer than
/// 3 bytes must be the last one; input that continues after such a group
/// results in `error.InvalidPadding`. Bytes decoded from earlier groups have
/// already been written to `dest_writer` when an error is returned.
pub fn decodeFromReaderToWriter(decoder: *const Base64Decoder, dest_writer: anytype, source_reader: anytype) !void {
    // Four base64 characters decode to at most three bytes.
    var decoded: [3]u8 = undefined;
    var group: [4]u8 = undefined;
    var saw_final_group = false;
    while (true) {
        // `read` may return fewer bytes than requested before the end of the
        // stream, so keep reading until the group is full or `read` returns 0.
        var filled: usize = 0;
        while (filled < group.len) {
            const bytes_read = try source_reader.read(group[filled..]);
            if (bytes_read == 0) break;
            filled += bytes_read;
        }
        if (filled == 0) break;

        // Groups are decoded independently of each other, so data that
        // follows a short group has to be rejected here.
        if (saw_final_group) return error.InvalidPadding;

        const chunk = group[0..filled];
        const size = try decoder.calcSizeForSlice(chunk);
        // Hand `decode` a destination of exactly the computed size.
        const out = decoded[0..size];
        try decoder.decode(out, chunk);
        try dest_writer.writeAll(out);

        // A partially filled group means `read` already reported end of stream.
        if (filled < group.len) break;
        saw_final_group = size < decoded.len;
    }
}
};

pub const Base64DecoderWithIgnore = struct {
Expand Down Expand Up @@ -332,54 +359,82 @@ pub const Base64DecoderWithIgnore = struct {
return result;
}

/// Returns an iterator type that pulls bytes one at a time from a reader of
/// type `ReaderType`, drops every byte the decoder marks as ignored, and hands
/// the remaining bytes back in groups of up to 4.
///
/// `ReaderType` must declare `NoEofError` and provide `readByte`.
fn WindowWithIgnore(comptime ReaderType: type) type {
    return struct {
        const Self = @This();
        const Err = ReaderType.NoEofError;

        reader: ReaderType,
        decoder: *const Base64DecoderWithIgnore,

        /// Creates an iterator over `reader` that uses `decoder`'s ignore table.
        pub fn init(reader: ReaderType, decoder: *const Base64DecoderWithIgnore) Self {
            return Self{
                .reader = reader,
                .decoder = decoder,
            };
        }

        /// Stores up to 4 non-ignored bytes at the start of `buffer` and
        /// returns the filled prefix. The group is shorter than 4 only when
        /// the reader reports end of stream first. Returns `EndOfStream` when
        /// no non-ignored byte was available; any other reader error is
        /// returned unchanged.
        pub fn next(self: *Self, buffer: []u8) Err![]u8 {
            var filled: usize = 0;
            while (filled < 4) {
                const c = self.reader.readByte() catch |err| {
                    // End of stream ends the current group; it is only
                    // reported to the caller when the group is empty.
                    if (err == Self.Err.EndOfStream) break;
                    return err;
                };
                if (!self.decoder.char_is_ignored[c]) {
                    buffer[filled] = c;
                    filled += 1;
                }
            }
            if (filled == 0) return Self.Err.EndOfStream;
            return buffer[0..filled];
        }
    };
}

/// Invalid characters that are not ignored result in error.InvalidCharacter.
/// Invalid padding results in error.InvalidPadding.
/// Decoding more data than can fit in dest results in error.NoSpaceLeft. See also ::calcSizeUpperBound.
/// Returns the number of bytes written to dest.
pub fn decode(decoder_with_ignore: *const Base64DecoderWithIgnore, dest: []u8, source: []const u8) Error!usize {
const decoder = &decoder_with_ignore.decoder;
var acc: u12 = 0;
var acc_len: u4 = 0;
var dest_idx: usize = 0;
var leftover_idx: ?usize = null;
for (source, 0..) |c, src_idx| {
if (decoder_with_ignore.char_is_ignored[c]) continue;
const d = decoder.char_to_index[c];
if (d == Base64Decoder.invalid_char) {
if (decoder.pad_char == null or c != decoder.pad_char.?) return error.InvalidCharacter;
leftover_idx = src_idx;
break;
}
acc = (acc << 6) + d;
acc_len += 6;
if (acc_len >= 8) {
if (dest_idx == dest.len) return error.NoSpaceLeft;
acc_len -= 8;
dest[dest_idx] = @as(u8, @truncate(acc >> acc_len));
dest_idx += 1;
}
}
if (acc_len > 4 or (acc & (@as(u12, 1) << acc_len) - 1) != 0) {
return error.InvalidPadding;
}
const padding_len = acc_len / 2;
if (leftover_idx == null) {
if (decoder.pad_char != null and padding_len != 0) return error.InvalidPadding;
return dest_idx;
}
const leftover = source[leftover_idx.?..];
if (decoder.pad_char) |pad_char| {
var padding_chars: usize = 0;
for (leftover) |c| {
if (decoder_with_ignore.char_is_ignored[c]) continue;
if (c != pad_char) {
return if (c == Base64Decoder.invalid_char) error.InvalidCharacter else error.InvalidPadding;
}
padding_chars += 1;
}
if (padding_chars != padding_len) return error.InvalidPadding;
var sourceStream = std.io.fixedBufferStream(source);
const source_reader = sourceStream.reader();
var dest_stream = std.io.fixedBufferStream(dest);
const DestStreamType = @TypeOf(dest_stream);
const dest_writer = dest_stream.writer();
decoder_with_ignore.decodeFromReaderToWriter(dest_writer, source_reader) catch |err| switch (err) {
DestStreamType.WriteError.NoSpaceLeft => return error.NoSpaceLeft,
WindowWithIgnore(@TypeOf(source_reader)).Err.EndOfStream => unreachable,
error.InvalidCharacter, error.InvalidPadding => |e| return e,
};
return dest_stream.pos;
}

/// Decodes the in-memory `source` by wrapping it in a fixed buffer stream and
/// forwarding to `decodeFromReaderToWriter`.
///
/// `dest_writer` must be compatible with `std.io.Writer`'s `writeAll` interface.
pub fn decodeWriter(decoder_with_ignore: *const Base64DecoderWithIgnore, dest_writer: anytype, source: []const u8) !void {
    var source_stream = std.io.fixedBufferStream(source);
    try decoder_with_ignore.decodeFromReaderToWriter(dest_writer, source_stream.reader());
}

/// Reads base64 text from `source_reader`, skips ignored characters, decodes
/// the remaining characters in groups of up to 4, and writes the decoded bytes
/// to `dest_writer`.
///
/// `dest_writer` must be compatible with `std.io.Writer`'s `writeAll` interface.
/// `source_reader` must be compatible with `std.io.Reader`'s `readByte` interface,
/// and its type must declare `NoEofError` (required by `WindowWithIgnore`).
///
/// Returns any error produced by the reader (other than end of stream),
/// `calcSizeForSlice`, `decode`, or `dest_writer.writeAll`. A group that
/// decodes to fewer than 3 bytes must be the last one; non-ignored input that
/// continues after such a group results in `error.InvalidPadding`. Bytes
/// decoded from earlier groups have already been written to `dest_writer`
/// when an error is returned.
pub fn decodeFromReaderToWriter(decoder_with_ignore: *const Base64DecoderWithIgnore, dest_writer: anytype, source_reader: anytype) !void {
    const decoder = &decoder_with_ignore.decoder;
    var group: [4]u8 = undefined;
    // Four base64 characters decode to at most three bytes.
    var decoded: [3]u8 = undefined;
    var saw_final_group = false;

    const WindowType = WindowWithIgnore(@TypeOf(source_reader));
    var chunker = WindowType.init(source_reader, decoder_with_ignore);
    while (chunker.next(&group)) |chunk| {
        // Groups are decoded independently of each other, so data that
        // follows a short group has to be rejected here.
        if (saw_final_group) return error.InvalidPadding;

        const size = try decoder.calcSizeForSlice(chunk);
        // Hand `decode` a destination of exactly the computed size.
        const out = decoded[0..size];
        try decoder.decode(out, chunk);
        try dest_writer.writeAll(out);

        saw_final_group = size < decoded.len;
    } else |err| switch (err) {
        // The chunker reports exhaustion of the input as `EndOfStream`.
        WindowType.Err.EndOfStream => return,
        else => return err,
    }
}
};

Expand Down Expand Up @@ -523,20 +578,55 @@ fn testAllApis(codecs: Codecs, expected_decoded: []const u8, expected_encoded: [

// Base64Decoder
{
var buffer: [0x100]u8 = undefined;
const decoded = buffer[0..try codecs.Decoder.calcSizeForSlice(expected_encoded)];
try codecs.Decoder.decode(decoded, expected_encoded);
try testing.expectEqualSlices(u8, expected_decoded, decoded);
{
var buffer: [0x100]u8 = undefined;
const decoded = buffer[0..try codecs.Decoder.calcSizeForSlice(expected_encoded)];
try codecs.Decoder.decode(decoded, expected_encoded);
try testing.expectEqualSlices(u8, expected_decoded, decoded);
}

//stream version
{
var list = try std.BoundedArray(u8, 0x100).init(0);
try codecs.Decoder.decodeWriter(list.writer(), expected_encoded);
try testing.expectEqualSlices(u8, expected_decoded, list.slice());
}

// from reader to writer version
{
var list = try std.BoundedArray(u8, 0x100).init(0);
var stream = std.io.fixedBufferStream(expected_encoded);
try codecs.Decoder.decodeFromReaderToWriter(list.writer(), stream.reader());
try testing.expectEqualSlices(u8, expected_decoded, list.slice());
}
}

// Base64DecoderWithIgnore
{
const decoder_ignore_nothing = codecs.decoderWithIgnore("");
var buffer: [0x100]u8 = undefined;
const decoded = buffer[0..try decoder_ignore_nothing.calcSizeUpperBound(expected_encoded.len)];
const written = try decoder_ignore_nothing.decode(decoded, expected_encoded);
try testing.expect(written <= decoded.len);
try testing.expectEqualSlices(u8, expected_decoded, decoded[0..written]);

{
var buffer: [0x100]u8 = undefined;
const decoded = buffer[0..try decoder_ignore_nothing.calcSizeUpperBound(expected_encoded.len)];
const written = try decoder_ignore_nothing.decode(decoded, expected_encoded);
try testing.expect(written <= decoded.len);
try testing.expectEqualSlices(u8, expected_decoded, decoded[0..written]);
}

//stream version
{
var list = try std.BoundedArray(u8, 0x100).init(0);
try decoder_ignore_nothing.decodeWriter(list.writer(), expected_encoded);
try testing.expectEqualSlices(u8, expected_decoded, list.slice());
}

// from reader to writer
{
var list = try std.BoundedArray(u8, 0x100).init(0);
var stream = std.io.fixedBufferStream(expected_encoded);
try decoder_ignore_nothing.decodeFromReaderToWriter(list.writer(), stream.reader());
try testing.expectEqualSlices(u8, expected_decoded, list.slice());
}
}
}

Expand All @@ -546,6 +636,11 @@ fn testDecodeIgnoreSpace(codecs: Codecs, expected_decoded: []const u8, encoded:
const decoded = buffer[0..try decoder_ignore_space.calcSizeUpperBound(encoded.len)];
const written = try decoder_ignore_space.decode(decoded, encoded);
try testing.expectEqualSlices(u8, expected_decoded, decoded[0..written]);

//stream version
var list = try std.BoundedArray(u8, 0x100).init(0);
try decoder_ignore_space.decodeWriter(list.writer(), encoded);
try testing.expectEqualSlices(u8, expected_decoded, list.slice());
}

fn testError(codecs: Codecs, encoded: []const u8, expected_err: anyerror) !void {
Expand Down
Loading