Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: impl DataType::Float32/64 #240

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ aws = ["fusio-dispatch/aws", "fusio/aws"]
bench = ["redb", "rocksdb", "sled"]
bytes = ["dep:bytes"]
datafusion = ["dep:async-trait", "dep:datafusion"]
default = ["aws", "bytes", "tokio", "tokio-http"]
default = ["aws", "bytes", "float", "tokio", "tokio-http"]
float = ["dep:ordered-float"]
load_tbl = []
object-store = ["fusio/object_store"]
opfs = [
Expand All @@ -42,7 +43,7 @@ wasm = ["aws", "bytes", "opfs"]

[[example]]
name = "declare"
required-features = ["bytes", "tokio"]
required-features = ["bytes", "float", "tokio"]

[[example]]
name = "datafusion"
Expand Down Expand Up @@ -90,6 +91,7 @@ futures-io = "0.3"
futures-util = "0.3"
lockable = "0.1.1"
once_cell = "1"
ordered-float = { version = "4", optional = true }
parquet = { version = "53", default-features = false, features = [
"async",
"base64",
Expand Down
6 changes: 5 additions & 1 deletion examples/declare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::ops::Bound;
use bytes::Bytes;
use fusio::path::Path;
use futures_util::stream::StreamExt;
use ordered_float::OrderedFloat;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is better to re-export this type

use tokio::fs;
use tonbo::{executor::tokio::TokioExecutor, DbOption, Projection, Record, DB};

Expand All @@ -15,6 +16,7 @@ pub struct User {
email: Option<String>,
age: u8,
bytes: Bytes,
float: Option<OrderedFloat<f32>>,
}

#[tokio::main]
Expand All @@ -32,6 +34,7 @@ async fn main() {
email: Some("[email protected]".into()),
age: 22,
bytes: Bytes::from(vec![0, 1, 2]),
float: Some(OrderedFloat(1.1)),
})
.await
.unwrap();
Expand Down Expand Up @@ -61,7 +64,7 @@ async fn main() {
let mut scan = txn
.scan((Bound::Included(&name), Bound::Excluded(&upper)))
// tonbo supports pushing down projection
.projection(vec![1, 3])
.projection(vec![1, 3, 4])
// push down limitation
.limit(1)
.take()
Expand All @@ -75,6 +78,7 @@ async fn main() {
email: Some("[email protected]"),
age: None,
bytes: Some(&[0, 1, 2]),
float: Some(OrderedFloat(1.1)),
})
);
}
Expand Down
2 changes: 2 additions & 0 deletions src/serdes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ mod bytes;
mod list;
mod num;
pub(crate) mod option;
#[cfg(feature = "float")]
mod ordered_float;
mod string;

use std::future::Future;
Expand Down
10 changes: 10 additions & 0 deletions src/serdes/num.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ implement_encode_decode!(u8);
implement_encode_decode!(u16);
implement_encode_decode!(u32);
implement_encode_decode!(u64);
implement_encode_decode!(f32);
implement_encode_decode!(f64);

#[cfg(test)]
mod tests {
Expand All @@ -63,6 +65,8 @@ mod tests {
let source_5 = 16i16;
let source_6 = 32i32;
let source_7 = 64i64;
let source_8 = 32f32;
let source_9 = 64f64;

let mut bytes = Vec::new();
let mut cursor = Cursor::new(&mut bytes);
Expand All @@ -75,6 +79,8 @@ mod tests {
source_5.encode(&mut cursor).await.unwrap();
source_6.encode(&mut cursor).await.unwrap();
source_7.encode(&mut cursor).await.unwrap();
source_8.encode(&mut cursor).await.unwrap();
source_9.encode(&mut cursor).await.unwrap();

cursor.seek(std::io::SeekFrom::Start(0)).await.unwrap();
let decoded_0 = u8::decode(&mut cursor).await.unwrap();
Expand All @@ -85,6 +91,8 @@ mod tests {
let decoded_5 = i16::decode(&mut cursor).await.unwrap();
let decoded_6 = i32::decode(&mut cursor).await.unwrap();
let decoded_7 = i64::decode(&mut cursor).await.unwrap();
let decoded_8 = f32::decode(&mut cursor).await.unwrap();
let decoded_9 = f64::decode(&mut cursor).await.unwrap();

assert_eq!(source_0, decoded_0);
assert_eq!(source_1, decoded_1);
Expand All @@ -94,5 +102,7 @@ mod tests {
assert_eq!(source_5, decoded_5);
assert_eq!(source_6, decoded_6);
assert_eq!(source_7, decoded_7);
assert_eq!(source_8, decoded_8);
assert_eq!(source_9, decoded_9);
}
}
65 changes: 65 additions & 0 deletions src/serdes/ordered_float.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
use fusio::{SeqRead, Write};
use ordered_float::OrderedFloat;

use crate::serdes::{Decode, Encode};

/// Decoding for [`OrderedFloat`] delegates entirely to the wrapped type `T`.
///
/// The `FloatCore` bound is what makes the `From<T>` conversion into
/// `OrderedFloat<T>` available.
impl<T> Decode for OrderedFloat<T>
where
    T: ordered_float::FloatCore + Decode,
{
    /// Failures are exactly those reported by `T::decode`.
    type Error = T::Error;

    /// Reads a single `T` from `reader` and wraps it in an [`OrderedFloat`].
    async fn decode<R>(reader: &mut R) -> Result<Self, Self::Error>
    where
        R: SeqRead,
    {
        let inner = T::decode(reader).await?;
        Ok(Self::from(inner))
    }
}

/// Encoding for [`OrderedFloat`] forwards to the wrapped value (`self.0`);
/// this impl writes nothing beyond what `T::encode` writes.
impl<T> Encode for OrderedFloat<T>
where
    T: Encode + Send + Sync,
{
    /// Failures are exactly those reported by `T::encode`.
    type Error = T::Error;

    /// Writes the inner value to `writer` using `T`'s own encoder.
    async fn encode<W>(&self, writer: &mut W) -> Result<(), Self::Error>
    where
        W: Write,
    {
        <T as Encode>::encode(&self.0, writer).await
    }

    /// Reports the size of the inner value as computed by `T`'s `Encode::size`.
    fn size(&self) -> usize {
        <T as Encode>::size(&self.0)
    }
}

#[cfg(test)]
mod tests {
    use std::io::{Cursor, SeekFrom};

    use ordered_float::OrderedFloat;
    use tokio::io::AsyncSeekExt;

    use crate::serdes::{Decode, Encode};

    /// Round-trips one `OrderedFloat<f32>` and one `OrderedFloat<f64>` through
    /// the same in-memory buffer and checks that both come back unchanged.
    #[tokio::test]
    async fn test_encode_decode() {
        let single = OrderedFloat(32f32);
        let double = OrderedFloat(64f64);

        let mut buffer = Vec::new();
        let mut cursor = Cursor::new(&mut buffer);

        // Write both values back-to-back, in a fixed order.
        single.encode(&mut cursor).await.unwrap();
        double.encode(&mut cursor).await.unwrap();

        // Rewind so decoding starts at the first encoded value.
        cursor.seek(SeekFrom::Start(0)).await.unwrap();
        let single_back = OrderedFloat::<f32>::decode(&mut cursor).await.unwrap();
        let double_back = OrderedFloat::<f64>::decode(&mut cursor).await.unwrap();

        assert_eq!(single, single_back);
        assert_eq!(double, double_back);
    }
}
77 changes: 77 additions & 0 deletions tonbo_macros/src/data_type.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use proc_macro2::Ident;
use quote::quote;
use syn::{GenericArgument, Type};

pub(crate) enum DataType {
UInt8,
Expand All @@ -10,6 +11,8 @@ pub(crate) enum DataType {
Int16,
Int32,
Int64,
Float32,
Float64,
String,
Boolean,
Bytes,
Expand All @@ -33,6 +36,20 @@ impl DataType {
DataType::Int32
} else if path.is_ident("i64") {
DataType::Int64
} else if path.segments[0].ident == "OrderedFloat" {
if let syn::PathArguments::AngleBracketed(ref generic_args) = path.segments[0].arguments
{
if generic_args.args.len() == 1 {
if let GenericArgument::Type(Type::Path(type_path)) = &generic_args.args[0] {
if type_path.path.is_ident("f32") {
return DataType::Float32;
} else if type_path.path.is_ident("f64") {
return DataType::Float64;
}
}
}
}
unreachable!("only f32/f64 is allowed in `OrderFloat`");
} else if path.is_ident("String") {
DataType::String
} else if path.is_ident("bool") {
Expand Down Expand Up @@ -76,6 +93,12 @@ impl DataType {
DataType::Boolean => {
quote!(bool)
}
DataType::Float32 => {
quote!(ordered_float::OrderedFloat::<f32>)
}
DataType::Float64 => {
quote!(ordered_float::OrderedFloat::<f64>)
}
DataType::Bytes => {
quote!(bytes::Bytes)
}
Expand Down Expand Up @@ -108,6 +131,12 @@ impl DataType {
DataType::Int64 => {
quote!(::tonbo::arrow::datatypes::DataType::Int64)
}
DataType::Float32 => {
quote!(::tonbo::arrow::datatypes::DataType::Float32)
}
DataType::Float64 => {
quote!(::tonbo::arrow::datatypes::DataType::Float64)
}
DataType::String => {
quote!(::tonbo::arrow::datatypes::DataType::Utf8)
}
Expand Down Expand Up @@ -146,6 +175,12 @@ impl DataType {
DataType::Int64 => {
quote!(::tonbo::arrow::array::Int64Array)
}
DataType::Float32 => {
quote!(::tonbo::arrow::array::Float32Array)
}
DataType::Float64 => {
quote!(::tonbo::arrow::array::Float64Array)
}
DataType::String => {
quote!(::tonbo::arrow::array::StringArray)
}
Expand Down Expand Up @@ -188,6 +223,12 @@ impl DataType {
DataType::Int64 => {
quote!(as_primitive::<::tonbo::arrow::datatypes::Int64Type>())
}
DataType::Float32 => {
quote!(as_primitive::<::tonbo::arrow::datatypes::Float32Type>())
}
DataType::Float64 => {
quote!(as_primitive::<::tonbo::arrow::datatypes::Float64Type>())
}
DataType::String => {
quote!(as_string::<i32>())
}
Expand Down Expand Up @@ -242,6 +283,16 @@ impl DataType {
::tonbo::arrow::datatypes::Int64Type,
>::with_capacity(capacity))
}
DataType::Float32 => {
quote!(::tonbo::arrow::array::PrimitiveBuilder::<
::tonbo::arrow::datatypes::Float32Type,
>::with_capacity(capacity))
}
DataType::Float64 => {
quote!(::tonbo::arrow::array::PrimitiveBuilder::<
::tonbo::arrow::datatypes::Float64Type,
>::with_capacity(capacity))
}
DataType::String => {
quote!(::tonbo::arrow::array::StringBuilder::with_capacity(
capacity, 0
Expand Down Expand Up @@ -318,6 +369,20 @@ impl DataType {
>
)
}
DataType::Float32 => {
quote!(
::tonbo::arrow::array::PrimitiveBuilder<
::tonbo::arrow::datatypes::Float32Type,
>
)
}
DataType::Float64 => {
quote!(
::tonbo::arrow::array::PrimitiveBuilder<
::tonbo::arrow::datatypes::Float64Type,
>
)
}
DataType::String => {
quote!(::tonbo::arrow::array::StringBuilder)
}
Expand Down Expand Up @@ -359,6 +424,12 @@ impl DataType {
DataType::Int64 => {
quote!(std::mem::size_of_val(self.#field_name.values_slice()))
}
DataType::Float32 => {
quote!(std::mem::size_of_val(self.#field_name.values_slice()))
}
DataType::Float64 => {
quote!(std::mem::size_of_val(self.#field_name.values_slice()))
}
DataType::String => {
quote!(self.#field_name.values_slice().len())
}
Expand Down Expand Up @@ -400,6 +471,12 @@ impl DataType {
DataType::Int64 => {
quote! {std::mem::size_of::<i64>()}
}
DataType::Float32 => {
quote! {std::mem::size_of::<f32>()}
}
DataType::Float64 => {
quote! {std::mem::size_of::<f64>()}
}
DataType::String => {
if is_nullable {
quote!(self.#field_name.as_ref().map(String::len).unwrap_or(0))
Expand Down
Loading
Loading