Skip to content

Commit

Permalink
feat: impl DataType::Float32/64
Browse files Browse the repository at this point in the history
  • Loading branch information
KKould committed Nov 25, 2024
1 parent b509436 commit 8e8ffcc
Show file tree
Hide file tree
Showing 7 changed files with 181 additions and 7 deletions.
6 changes: 4 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ aws = ["fusio-dispatch/aws", "fusio/aws"]
bench = ["redb", "rocksdb", "sled"]
bytes = ["dep:bytes"]
datafusion = ["dep:async-trait", "dep:datafusion"]
default = ["aws", "bytes", "tokio", "tokio-http"]
default = ["aws", "bytes", "float", "tokio", "tokio-http"]
float = ["dep:ordered-float"]
load_tbl = []
object-store = ["fusio/object_store"]
opfs = [
Expand All @@ -42,7 +43,7 @@ wasm = ["aws", "bytes", "opfs"]

[[example]]
name = "declare"
required-features = ["bytes", "tokio"]
required-features = ["bytes", "float", "tokio"]

[[example]]
name = "datafusion"
Expand Down Expand Up @@ -90,6 +91,7 @@ futures-io = "0.3"
futures-util = "0.3"
lockable = "0.1.1"
once_cell = "1"
ordered-float = { version = "4", optional = true }
parquet = { version = "53", default-features = false, features = [
"async",
"base64",
Expand Down
6 changes: 5 additions & 1 deletion examples/declare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::ops::Bound;
use bytes::Bytes;
use fusio::path::Path;
use futures_util::stream::StreamExt;
use ordered_float::OrderedFloat;
use tokio::fs;
use tonbo::{executor::tokio::TokioExecutor, DbOption, Projection, Record, DB};

Expand All @@ -15,6 +16,7 @@ pub struct User {
email: Option<String>,
age: u8,
bytes: Bytes,
float: Option<OrderedFloat<f32>>,
}

#[tokio::main]
Expand All @@ -32,6 +34,7 @@ async fn main() {
email: Some("[email protected]".into()),
age: 22,
bytes: Bytes::from(vec![0, 1, 2]),
float: Some(OrderedFloat(1.1)),
})
.await
.unwrap();
Expand Down Expand Up @@ -61,7 +64,7 @@ async fn main() {
let mut scan = txn
.scan((Bound::Included(&name), Bound::Excluded(&upper)))
// tonbo supports pushing down projection
.projection(vec![1, 3])
.projection(vec![1, 3, 4])
// push down limitation
.limit(1)
.take()
Expand All @@ -75,6 +78,7 @@ async fn main() {
email: Some("[email protected]"),
age: None,
bytes: Some(&[0, 1, 2]),
float: Some(OrderedFloat(1.1)),
})
);
}
Expand Down
2 changes: 2 additions & 0 deletions src/serdes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ mod bytes;
mod list;
mod num;
pub(crate) mod option;
#[cfg(feature = "float")]
mod ordered_float;
mod string;

use std::future::Future;
Expand Down
10 changes: 10 additions & 0 deletions src/serdes/num.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ implement_encode_decode!(u8);
implement_encode_decode!(u16);
implement_encode_decode!(u32);
implement_encode_decode!(u64);
implement_encode_decode!(f32);
implement_encode_decode!(f64);

#[cfg(test)]
mod tests {
Expand All @@ -63,6 +65,8 @@ mod tests {
let source_5 = 16i16;
let source_6 = 32i32;
let source_7 = 64i64;
let source_8 = 32f32;
let source_9 = 64f64;

let mut bytes = Vec::new();
let mut cursor = Cursor::new(&mut bytes);
Expand All @@ -75,6 +79,8 @@ mod tests {
source_5.encode(&mut cursor).await.unwrap();
source_6.encode(&mut cursor).await.unwrap();
source_7.encode(&mut cursor).await.unwrap();
source_8.encode(&mut cursor).await.unwrap();
source_9.encode(&mut cursor).await.unwrap();

cursor.seek(std::io::SeekFrom::Start(0)).await.unwrap();
let decoded_0 = u8::decode(&mut cursor).await.unwrap();
Expand All @@ -85,6 +91,8 @@ mod tests {
let decoded_5 = i16::decode(&mut cursor).await.unwrap();
let decoded_6 = i32::decode(&mut cursor).await.unwrap();
let decoded_7 = i64::decode(&mut cursor).await.unwrap();
let decoded_8 = f32::decode(&mut cursor).await.unwrap();
let decoded_9 = f64::decode(&mut cursor).await.unwrap();

assert_eq!(source_0, decoded_0);
assert_eq!(source_1, decoded_1);
Expand All @@ -94,5 +102,7 @@ mod tests {
assert_eq!(source_5, decoded_5);
assert_eq!(source_6, decoded_6);
assert_eq!(source_7, decoded_7);
assert_eq!(source_8, decoded_8);
assert_eq!(source_9, decoded_9);
}
}
65 changes: 65 additions & 0 deletions src/serdes/ordered_float.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
use fusio::{SeqRead, Write};
use ordered_float::OrderedFloat;

use crate::serdes::{Decode, Encode};

/// Decodes an [`OrderedFloat`] by decoding the wrapped type `T` from the
/// reader and converting the decoded value into the wrapper.
///
/// An error produced while decoding `T` is returned unchanged.
impl<T> Decode for OrderedFloat<T>
where
    T: Decode + ordered_float::FloatCore,
{
    type Error = T::Error;

    async fn decode<R>(reader: &mut R) -> Result<Self, Self::Error>
    where
        R: SeqRead,
    {
        // Decode the inner value, then wrap it only on success.
        T::decode(reader).await.map(OrderedFloat::from)
    }
}

/// Encodes an [`OrderedFloat`] by delegating to the encoder of the wrapped
/// value `T`; both the written bytes and the reported size come from `T`.
impl<T> Encode for OrderedFloat<T>
where
    T: Encode + Send + Sync,
{
    type Error = T::Error;

    /// Writes the wrapped value to `writer` using `T`'s encoder.
    async fn encode<W>(&self, writer: &mut W) -> Result<(), Self::Error>
    where
        W: Write,
    {
        let inner = &self.0;
        <T as Encode>::encode(inner, writer).await
    }

    /// Returns the size that `T` reports for the wrapped value.
    fn size(&self) -> usize {
        <T as Encode>::size(&self.0)
    }
}

#[cfg(test)]
mod tests {
    use std::io::{Cursor, SeekFrom};

    use ordered_float::OrderedFloat;
    use tokio::io::AsyncSeekExt;

    use crate::serdes::{Decode, Encode};

    /// Round-trips one `OrderedFloat<f32>` and one `OrderedFloat<f64>`
    /// through a single in-memory buffer.
    #[tokio::test]
    async fn test_encode_decode() {
        let single = OrderedFloat(32f32);
        let double = OrderedFloat(64f64);

        let mut buffer = Vec::new();
        let mut io = Cursor::new(&mut buffer);

        // The values are written back to back, so they are read in the same order.
        single.encode(&mut io).await.unwrap();
        double.encode(&mut io).await.unwrap();

        // Rewind to the start of the buffer before decoding.
        io.seek(SeekFrom::Start(0)).await.unwrap();
        let single_back = OrderedFloat::<f32>::decode(&mut io).await.unwrap();
        let double_back = OrderedFloat::<f64>::decode(&mut io).await.unwrap();

        assert_eq!(single, single_back);
        assert_eq!(double, double_back);
    }
}
77 changes: 77 additions & 0 deletions tonbo_macros/src/data_type.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use proc_macro2::Ident;
use quote::quote;
use syn::{GenericArgument, Type};

pub(crate) enum DataType {
UInt8,
Expand All @@ -10,6 +11,8 @@ pub(crate) enum DataType {
Int16,
Int32,
Int64,
Float32,
Float64,
String,
Boolean,
Bytes,
Expand All @@ -33,6 +36,20 @@ impl DataType {
DataType::Int32
} else if path.is_ident("i64") {
DataType::Int64
} else if path.segments[0].ident == "OrderedFloat" {
if let syn::PathArguments::AngleBracketed(ref generic_args) = path.segments[0].arguments
{
if generic_args.args.len() == 1 {
if let GenericArgument::Type(Type::Path(type_path)) = &generic_args.args[0] {
if type_path.path.is_ident("f32") {
return DataType::Float32;
} else if type_path.path.is_ident("f64") {
return DataType::Float64;
}
}
}
}
unreachable!("only f32/f64 is allowed in `OrderFloat`");
} else if path.is_ident("String") {
DataType::String
} else if path.is_ident("bool") {
Expand Down Expand Up @@ -76,6 +93,12 @@ impl DataType {
DataType::Boolean => {
quote!(bool)
}
DataType::Float32 => {
quote!(ordered_float::OrderedFloat::<f32>)
}
DataType::Float64 => {
quote!(ordered_float::OrderedFloat::<f64>)
}
DataType::Bytes => {
quote!(bytes::Bytes)
}
Expand Down Expand Up @@ -108,6 +131,12 @@ impl DataType {
DataType::Int64 => {
quote!(::tonbo::arrow::datatypes::DataType::Int64)
}
DataType::Float32 => {
quote!(::tonbo::arrow::datatypes::DataType::Float32)
}
DataType::Float64 => {
quote!(::tonbo::arrow::datatypes::DataType::Float64)
}
DataType::String => {
quote!(::tonbo::arrow::datatypes::DataType::Utf8)
}
Expand Down Expand Up @@ -146,6 +175,12 @@ impl DataType {
DataType::Int64 => {
quote!(::tonbo::arrow::array::Int64Array)
}
DataType::Float32 => {
quote!(::tonbo::arrow::array::Float32Array)
}
DataType::Float64 => {
quote!(::tonbo::arrow::array::Float64Array)
}
DataType::String => {
quote!(::tonbo::arrow::array::StringArray)
}
Expand Down Expand Up @@ -188,6 +223,12 @@ impl DataType {
DataType::Int64 => {
quote!(as_primitive::<::tonbo::arrow::datatypes::Int64Type>())
}
DataType::Float32 => {
quote!(as_primitive::<::tonbo::arrow::datatypes::Float32Type>())
}
DataType::Float64 => {
quote!(as_primitive::<::tonbo::arrow::datatypes::Float64Type>())
}
DataType::String => {
quote!(as_string::<i32>())
}
Expand Down Expand Up @@ -242,6 +283,16 @@ impl DataType {
::tonbo::arrow::datatypes::Int64Type,
>::with_capacity(capacity))
}
DataType::Float32 => {
quote!(::tonbo::arrow::array::PrimitiveBuilder::<
::tonbo::arrow::datatypes::Float32Type,
>::with_capacity(capacity))
}
DataType::Float64 => {
quote!(::tonbo::arrow::array::PrimitiveBuilder::<
::tonbo::arrow::datatypes::Float64Type,
>::with_capacity(capacity))
}
DataType::String => {
quote!(::tonbo::arrow::array::StringBuilder::with_capacity(
capacity, 0
Expand Down Expand Up @@ -318,6 +369,20 @@ impl DataType {
>
)
}
DataType::Float32 => {
quote!(
::tonbo::arrow::array::PrimitiveBuilder<
::tonbo::arrow::datatypes::Float32Type,
>
)
}
DataType::Float64 => {
quote!(
::tonbo::arrow::array::PrimitiveBuilder<
::tonbo::arrow::datatypes::Float64Type,
>
)
}
DataType::String => {
quote!(::tonbo::arrow::array::StringBuilder)
}
Expand Down Expand Up @@ -359,6 +424,12 @@ impl DataType {
DataType::Int64 => {
quote!(std::mem::size_of_val(self.#field_name.values_slice()))
}
DataType::Float32 => {
quote!(std::mem::size_of_val(self.#field_name.values_slice()))
}
DataType::Float64 => {
quote!(std::mem::size_of_val(self.#field_name.values_slice()))
}
DataType::String => {
quote!(self.#field_name.values_slice().len())
}
Expand Down Expand Up @@ -400,6 +471,12 @@ impl DataType {
DataType::Int64 => {
quote! {std::mem::size_of::<i64>()}
}
DataType::Float32 => {
quote! {std::mem::size_of::<f32>()}
}
DataType::Float64 => {
quote! {std::mem::size_of::<f64>()}
}
DataType::String => {
if is_nullable {
quote!(self.#field_name.as_ref().map(String::len).unwrap_or(0))
Expand Down
Loading

0 comments on commit 8e8ffcc

Please sign in to comment.