Commit e589f08

refactor: impl ReferenceSerialization

KKould committed Oct 16, 2024
1 parent 770b505 commit e589f08
Showing 84 changed files with 1,616 additions and 1,867 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
@@ -59,6 +59,7 @@ regex = { version = "1" }
rocksdb = { version = "0.22.0" }
rust_decimal = { version = "1" }
serde = { version = "1", features = ["derive", "rc"] }
serde_macros = { path = "serde_macros" }
siphasher = { version = "1", features = ["serde"] }
sqlparser = { version = "0.34", features = ["serde"] }
strum_macros = { version = "0.26.2" }
@@ -83,4 +84,4 @@ pprof = { version = "0.13", features = ["flamegraph", "criterion"] }
members = [
"tests/sqllogictest",
"tests/macros-test"
]
, "serde_macros"]
14 changes: 14 additions & 0 deletions serde_macros/Cargo.toml
@@ -0,0 +1,14 @@
[package]
name = "serde_macros"
version = "0.1.0"
edition = "2021"

[dependencies]
darling = "0.20"
proc-macro2 = "1"
quote = "1"
syn = "2"

[lib]
path = "src/lib.rs"
proc-macro = true
15 changes: 15 additions & 0 deletions serde_macros/src/lib.rs
@@ -0,0 +1,15 @@
mod reference_serialization;

use proc_macro::TokenStream;
use syn::{parse_macro_input, DeriveInput};

#[proc_macro_derive(ReferenceSerialization, attributes(reference_serialization))]
pub fn reference_serialization(input: TokenStream) -> TokenStream {
    let ast = parse_macro_input!(input as DeriveInput);

    let result = reference_serialization::handle(ast);
    match result {
        Ok(codegen) => codegen.into(),
        Err(e) => e.to_compile_error().into(),
    }
}
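
A minimal sketch of how the derive is meant to be consumed (the `Example` type is invented for illustration; since the generated impl names `crate::serdes::...` paths directly, the derive only compiles inside the main crate, and it assumes `ReferenceSerialization` impls exist for the field types):

// Hypothetical caller inside the main crate.
use serde_macros::ReferenceSerialization;

#[derive(ReferenceSerialization)]
pub struct Example {
    pub id: u32,                // assumes a primitive impl for u32 exists
    pub children: Vec<Example>, // containers are handled via process_type
}
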
205 changes: 205 additions & 0 deletions serde_macros/src/reference_serialization.rs
@@ -0,0 +1,205 @@
use darling::ast::Data;
use darling::{FromDeriveInput, FromField, FromVariant};
use proc_macro2::{Ident, Span, TokenStream};
use quote::quote;
use syn::{
    AngleBracketedGenericArguments, DeriveInput, Error, GenericArgument, PathArguments, Type,
    TypePath,
};

#[derive(Debug, FromDeriveInput)]
#[darling(attributes(record))]
struct SerializationOpts {
    ident: Ident,
    data: Data<SerializationVariantOpts, SerializationFieldOpt>,
}

#[derive(Debug, FromVariant)]
#[darling(attributes(record))]
struct SerializationVariantOpts {
    ident: Ident,
    fields: darling::ast::Fields<SerializationFieldOpt>,
}

#[derive(Debug, FromField)]
#[darling(attributes(record))]
struct SerializationFieldOpt {
    ident: Option<Ident>,
    ty: Type,
}

/// Rewrites container types (`Vec<T>`, `Option<T>`, `Arc<T>`, `Box<T>`, ...)
/// into turbofish form (`Vec::<T>`), recursing into the generic argument, so
/// that the generated `#ty::decode(...)` call parses in expression position.
fn process_type(ty: &Type) -> TokenStream {
    if let Type::Path(TypePath { path, .. }) = ty {
        let ident = &path.segments.last().unwrap().ident;

        match ident.to_string().as_str() {
            "Vec" | "Option" | "Arc" | "Box" | "PhantomData" | "Bound" | "CountMinSketch" => {
                if let PathArguments::AngleBracketed(AngleBracketedGenericArguments {
                    args, ..
                }) = &path.segments.last().unwrap().arguments
                {
                    if let Some(GenericArgument::Type(inner_ty)) = args.first() {
                        let inner_processed = process_type(inner_ty);

                        return quote! {
                            #ident::<#inner_processed>
                        };
                    }
                }
            }
            _ => {}
        }

        quote! { #ty }
    } else {
        quote! { #ty }
    }
}

pub(crate) fn handle(ast: DeriveInput) -> Result<TokenStream, Error> {
    let record_opts: SerializationOpts = SerializationOpts::from_derive_input(&ast)?;
    let struct_name = &record_opts.ident;

    Ok(match record_opts.data {
        Data::Struct(data_struct) => {
            let mut encode_fields: Vec<TokenStream> = Vec::new();
            let mut decode_fields: Vec<TokenStream> = Vec::new();
            let mut init_fields: Vec<TokenStream> = Vec::new();
            let mut is_tuple = false;

            for (i, field_opts) in data_struct.fields.into_iter().enumerate() {
                is_tuple = is_tuple || field_opts.ident.is_none();

                let field_name = field_opts
                    .ident
                    .unwrap_or_else(|| Ident::new(&format!("field_{}", i), Span::call_site()));
                let ty = process_type(&field_opts.ty);

                encode_fields.push(quote! {
                    #field_name.encode(writer, is_direct, reference_tables)?;
                });
                decode_fields.push(quote! {
                    let #field_name = #ty::decode(reader, drive, reference_tables)?;
                });
                init_fields.push(quote! {
                    #field_name,
                })
            }
            let init_stream = if is_tuple {
                quote! { #struct_name ( #(#init_fields)* ) }
            } else {
                quote! { #struct_name { #(#init_fields)* } }
            };

            quote! {
                impl crate::serdes::ReferenceSerialization for #struct_name {
                    fn encode<W: std::io::Write>(
                        &self,
                        writer: &mut W,
                        is_direct: bool,
                        reference_tables: &mut crate::serdes::ReferenceTables,
                    ) -> Result<(), crate::errors::DatabaseError> {
                        let #init_stream = self;

                        #(#encode_fields)*

                        Ok(())
                    }

                    fn decode<T: crate::storage::Transaction, R: std::io::Read>(
                        reader: &mut R,
                        drive: Option<(&T, &crate::storage::TableCache)>,
                        reference_tables: &crate::serdes::ReferenceTables,
                    ) -> Result<Self, crate::errors::DatabaseError> {
                        #(#decode_fields)*

                        Ok(#init_stream)
                    }
                }
            }
        }
        Data::Enum(data_enum) => {
            let mut variant_encode_fields: Vec<TokenStream> = Vec::new();
            let mut variant_decode_fields: Vec<TokenStream> = Vec::new();

            for (i, variant_opts) in data_enum.into_iter().enumerate() {
                // One-byte variant tag, assigned in declaration order.
                let i = i as u8;
                let mut encode_fields: Vec<TokenStream> = Vec::new();
                let mut decode_fields: Vec<TokenStream> = Vec::new();
                let mut init_fields: Vec<TokenStream> = Vec::new();
                let enum_name = variant_opts.ident;
                let mut is_tuple = false;

                for (i, field_opts) in variant_opts.fields.into_iter().enumerate() {
                    is_tuple = is_tuple || field_opts.ident.is_none();

                    let field_name = field_opts
                        .ident
                        .unwrap_or_else(|| Ident::new(&format!("field_{}", i), Span::call_site()));
                    let ty = process_type(&field_opts.ty);

                    encode_fields.push(quote! {
                        #field_name.encode(writer, is_direct, reference_tables)?;
                    });
                    decode_fields.push(quote! {
                        let #field_name = #ty::decode(reader, drive, reference_tables)?;
                    });
                    init_fields.push(quote! {
                        #field_name,
                    })
                }

                let init_stream = if is_tuple {
                    quote! { #struct_name::#enum_name ( #(#init_fields)* ) }
                } else {
                    quote! { #struct_name::#enum_name { #(#init_fields)* } }
                };
                variant_encode_fields.push(quote! {
                    #init_stream => {
                        std::io::Write::write_all(writer, &[#i])?;

                        #(#encode_fields)*
                    }
                });
                variant_decode_fields.push(quote! {
                    #i => {
                        #(#decode_fields)*

                        #init_stream
                    }
                });
            }

            quote! {
                impl crate::serdes::ReferenceSerialization for #struct_name {
                    fn encode<W: std::io::Write>(
                        &self,
                        writer: &mut W,
                        is_direct: bool,
                        reference_tables: &mut crate::serdes::ReferenceTables,
                    ) -> Result<(), crate::errors::DatabaseError> {
                        match self {
                            #(#variant_encode_fields)*
                        }

                        Ok(())
                    }

                    fn decode<T: crate::storage::Transaction, R: std::io::Read>(
                        reader: &mut R,
                        drive: Option<(&T, &crate::storage::TableCache)>,
                        reference_tables: &crate::serdes::ReferenceTables,
                    ) -> Result<Self, crate::errors::DatabaseError> {
                        let mut type_bytes = [0u8; 1];
                        std::io::Read::read_exact(reader, &mut type_bytes)?;

                        Ok(match type_bytes[0] {
                            #(#variant_decode_fields)*
                            _ => unreachable!(),
                        })
                    }
                }
            }
        }
    })
}
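
For concreteness, here is approximately what the derive expands to for a small enum (a sketch: the `Expr` type is invented, paths are written out as the macro emits them, and `u32` is assumed to implement `ReferenceSerialization`):

// Hypothetical input:
#[derive(ReferenceSerialization)]
enum Expr {
    Ref(u32),
}

// Approximate expansion: each variant is tagged with its index as one byte.
impl crate::serdes::ReferenceSerialization for Expr {
    fn encode<W: std::io::Write>(
        &self,
        writer: &mut W,
        is_direct: bool,
        reference_tables: &mut crate::serdes::ReferenceTables,
    ) -> Result<(), crate::errors::DatabaseError> {
        match self {
            Expr::Ref(field_0) => {
                std::io::Write::write_all(writer, &[0u8])?;
                field_0.encode(writer, is_direct, reference_tables)?;
            }
        }
        Ok(())
    }

    fn decode<T: crate::storage::Transaction, R: std::io::Read>(
        reader: &mut R,
        drive: Option<(&T, &crate::storage::TableCache)>,
        reference_tables: &crate::serdes::ReferenceTables,
    ) -> Result<Self, crate::errors::DatabaseError> {
        let mut type_bytes = [0u8; 1];
        std::io::Read::read_exact(reader, &mut type_bytes)?;
        Ok(match type_bytes[0] {
            0u8 => {
                let field_0 = u32::decode(reader, drive, reference_tables)?;
                Expr::Ref(field_0)
            }
            _ => unreachable!(),
        })
    }
}
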
30 changes: 26 additions & 4 deletions src/binder/copy.rs
@@ -2,23 +2,45 @@ use std::path::PathBuf
use std::str::FromStr;
use std::sync::Arc;

use super::*;
use crate::errors::DatabaseError;
use crate::planner::operator::copy_from_file::CopyFromFileOperator;
use crate::planner::operator::copy_to_file::CopyToFileOperator;
use crate::planner::operator::Operator;
use serde::{Deserialize, Serialize};
use serde_macros::ReferenceSerialization;
use sqlparser::ast::{CopyOption, CopySource, CopyTarget};

use super::*;

#[derive(Debug, PartialEq, PartialOrd, Ord, Hash, Eq, Clone, Serialize, Deserialize)]
#[derive(
    Debug,
    PartialEq,
    PartialOrd,
    Ord,
    Hash,
    Eq,
    Clone,
    Serialize,
    Deserialize,
    ReferenceSerialization,
)]
pub struct ExtSource {
    pub path: PathBuf,
    pub format: FileFormat,
}

/// File format.
#[derive(Debug, PartialEq, PartialOrd, Ord, Hash, Eq, Clone, Serialize, Deserialize)]
#[derive(
    Debug,
    PartialEq,
    PartialOrd,
    Ord,
    Hash,
    Eq,
    Clone,
    Serialize,
    Deserialize,
    ReferenceSerialization,
)]
pub enum FileFormat {
    Csv {
        /// Delimiter to parse.
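
With both derives in place, values like `ExtSource` serialize through either path. A hedged sketch of driving the new trait from inside the main crate (`ReferenceTables::default()` and passing `true` for `is_direct` are assumptions for illustration, not APIs confirmed by this diff):

// Sketch: serialize any ReferenceSerialization value into a byte buffer.
fn to_bytes<V: crate::serdes::ReferenceSerialization>(
    value: &V,
) -> Result<Vec<u8>, crate::errors::DatabaseError> {
    let mut buf = Vec::new();
    // Assumed constructor; the real ReferenceTables API may differ.
    let mut tables = crate::serdes::ReferenceTables::default();
    value.encode(&mut buf, true, &mut tables)?;
    Ok(buf)
}
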
12 changes: 6 additions & 6 deletions src/binder/expr.rs
@@ -11,8 +11,8 @@ use std::slice
use std::sync::Arc;

use super::{lower_ident, Binder, BinderContext, QueryBindStep, SubQueryType};
use crate::expression::function::scala::ScalarFunction;
use crate::expression::function::table::TableFunction;
use crate::expression::function::scala::{ArcScalarFunctionImpl, ScalarFunction};
use crate::expression::function::table::{ArcTableFunctionImpl, TableFunction};
use crate::expression::function::FunctionSummary;
use crate::expression::{AliasType, ScalarExpression};
use crate::planner::LogicalPlan;
@@ -235,7 +235,7 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {

let alias_expr = ScalarExpression::Alias {
expr: Box::new(ScalarExpression::ColumnRef(column)),
alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef(Arc::new(
alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef(ColumnRef::from(
alias_column,
)))),
};
@@ -246,7 +246,7 @@
fn bind_subquery(
&mut self,
subquery: &Query,
) -> Result<(LogicalPlan, Arc<ColumnCatalog>), DatabaseError> {
) -> Result<(LogicalPlan, ColumnRef), DatabaseError> {
let BinderContext {
table_cache,
transaction,
@@ -601,13 +601,13 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
if let Some(function) = self.context.scala_functions.get(&summary) {
return Ok(ScalarExpression::ScalaFunction(ScalarFunction {
args,
inner: function.clone(),
inner: ArcScalarFunctionImpl(function.clone()),
}));
}
if let Some(function) = self.context.table_functions.get(&summary) {
return Ok(ScalarExpression::TableFunction(TableFunction {
args,
inner: function.clone(),
inner: ArcTableFunctionImpl(function.clone()),
}));
}

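The switch from `function.clone()` to `ArcScalarFunctionImpl(function.clone())` / `ArcTableFunctionImpl(function.clone())` suggests newtype wrappers over the `Arc`-ed function objects, presumably so `ReferenceSerialization` can be implemented for them by hand. A plausible shape, assumed rather than shown in this diff (the inner trait names are guesses):

pub struct ArcScalarFunctionImpl(pub Arc<dyn ScalarFunctionImpl>);
pub struct ArcTableFunctionImpl(pub Arc<dyn TableFunctionImpl>);
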
6 changes: 3 additions & 3 deletions src/binder/select.rs
@@ -16,7 +16,7 @@ use crate::{

use super::{lower_case_name, lower_ident, Binder, BinderContext, QueryBindStep, SubQueryType};

use crate::catalog::{ColumnCatalog, ColumnSummary, TableName};
use crate::catalog::{ColumnCatalog, ColumnRef, ColumnSummary, TableName};
use crate::errors::DatabaseError;
use crate::execution::dql::join::joins_nullable;
use crate::expression::{AliasType, BinaryOperator};
@@ -356,7 +356,7 @@ impl<'a: 'b, 'b, T: Transaction> Binder<'a, 'b, T> {

let alias_column_expr = ScalarExpression::Alias {
expr: Box::new(ScalarExpression::ColumnRef(column)),
alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef(Arc::new(
alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef(ColumnRef::from(
alias_column,
)))),
};
@@ -736,7 +736,7 @@ impl<'a: 'b, 'b, T: Transaction> Binder<'a, 'b, T> {
let mut new_col = ColumnCatalog::clone(col);
new_col.nullable = *nullable;

*col = Arc::new(new_col);
*col = ColumnRef::from(new_col);
});
}
}
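
The `Arc::new(new_col)` → `ColumnRef::from(new_col)` replacements here and in expr.rs suggest that `ColumnRef` is now a dedicated wrapper type rather than a plain `Arc<ColumnCatalog>` alias. A plausible shape, assumed for illustration rather than taken from this diff:

pub struct ColumnRef(pub Arc<ColumnCatalog>);

impl From<ColumnCatalog> for ColumnRef {
    fn from(col: ColumnCatalog) -> Self {
        ColumnRef(Arc::new(col))
    }
}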