Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add imports #79

Merged
merged 2 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/concrete_ast/src/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::{
pub struct ConstantDecl {
pub doc_string: Option<DocString>,
pub name: Ident,
pub is_pub: bool,
pub r#type: TypeSpec,
}

Expand Down
3 changes: 2 additions & 1 deletion crates/concrete_ast/src/modules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub struct Module {
pub enum ModuleDefItem {
Constant(ConstantDef),
Function(FunctionDef),
Record(StructDecl),
Struct(StructDecl),
Type(TypeDecl),
Module(Module),
}
14 changes: 7 additions & 7 deletions crates/concrete_ast/src/structs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ use crate::{

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct StructDecl {
doc_string: Option<DocString>,
name: Ident,
type_params: Vec<GenericParam>,
fields: Vec<Field>,
pub doc_string: Option<DocString>,
pub name: Ident,
pub type_params: Vec<GenericParam>,
pub fields: Vec<Field>,
}

#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct Field {
doc_string: Option<DocString>,
name: Ident,
r#type: TypeSpec,
pub doc_string: Option<DocString>,
pub name: Ident,
pub r#type: TypeSpec,
}
4 changes: 3 additions & 1 deletion crates/concrete_ast/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::common::{DocString, Ident};
use crate::common::{DocString, Ident, Span};

#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum TypeSpec {
Expand All @@ -8,11 +8,13 @@ pub enum TypeSpec {
Generic {
name: Ident,
type_params: Vec<TypeSpec>,
span: Span,
},
}

#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct TypeDecl {
pub doc_string: Option<DocString>,
pub name: Ident,
pub value: TypeSpec,
}
1 change: 1 addition & 0 deletions crates/concrete_codegen_mlir/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ edition = "2021"
bumpalo = { version = "3.14.0", features = ["std"] }
concrete_ast = { path = "../concrete_ast"}
concrete_session = { path = "../concrete_session"}
itertools = "0.12.0"
llvm-sys = "170.0.1"
melior = { version = "0.15.0", features = ["ods-dialects"] }
mlir-sys = "0.2.1"
Expand Down
127 changes: 127 additions & 0 deletions crates/concrete_codegen_mlir/src/ast_helper.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
use std::collections::HashMap;

use concrete_ast::{
common::Ident,
constants::ConstantDef,
functions::FunctionDef,
modules::{Module, ModuleDefItem},
structs::StructDecl,
types::TypeDecl,
Program,
};

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ModuleInfo<'p> {
pub name: String,
pub functions: HashMap<String, &'p FunctionDef>,
pub constants: HashMap<String, &'p ConstantDef>,
pub structs: HashMap<String, &'p StructDecl>,
pub types: HashMap<String, &'p TypeDecl>,
pub modules: HashMap<String, ModuleInfo<'p>>,
}

impl<'p> ModuleInfo<'p> {
pub fn get_module_from_import(&self, import: &[Ident]) -> Option<&ModuleInfo<'p>> {
let next = import.first()?;
let module = self.modules.get(&next.name)?;

if import.len() > 1 {
module.get_module_from_import(&import[1..])
} else {
Some(module)
}
}

/// Returns the symbol name from a local name.
pub fn get_symbol_name(&self, local_name: &str) -> String {
if local_name == "main" {
return local_name.to_string();
}

let mut result = self.name.clone();

result.push_str("::");
result.push_str(local_name);

result
}
}

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AstHelper<'p> {
pub root: &'p Program,
pub modules: HashMap<String, ModuleInfo<'p>>,
}

impl<'p> AstHelper<'p> {
pub fn new(root: &'p Program) -> Self {
let mut modules = HashMap::default();

for module in &root.modules {
modules.insert(
module.name.name.clone(),
Self::create_module_info(module, None),
);
}

Self { root, modules }
}

pub fn get_module_from_import(&self, import: &[Ident]) -> Option<&ModuleInfo<'p>> {
let next = import.first()?;
let module = self.modules.get(&next.name)?;

if import.len() > 1 {
module.get_module_from_import(&import[1..])
} else {
Some(module)
}
}

fn create_module_info(module: &Module, parent_name: Option<String>) -> ModuleInfo<'_> {
let mut functions = HashMap::default();
let mut constants = HashMap::default();
let mut structs = HashMap::default();
let mut types = HashMap::default();
let mut child_modules = HashMap::default();
let mut name = parent_name.clone().unwrap_or_default();

if name.is_empty() {
name = module.name.name.clone();
} else {
name.push_str(&format!("::{}", module.name.name));
}

for stmt in &module.contents {
match stmt {
ModuleDefItem::Constant(info) => {
constants.insert(info.decl.name.name.clone(), info);
}
ModuleDefItem::Function(info) => {
functions.insert(info.decl.name.name.clone(), info);
}
ModuleDefItem::Struct(info) => {
structs.insert(info.name.name.clone(), info);
}
ModuleDefItem::Type(info) => {
types.insert(info.name.name.clone(), info);
}
ModuleDefItem::Module(info) => {
child_modules.insert(
info.name.name.clone(),
Self::create_module_info(info, Some(name.clone())),
);
}
}
}

ModuleInfo {
name,
functions,
structs,
constants,
types,
modules: child_modules,
}
}
}
100 changes: 78 additions & 22 deletions crates/concrete_codegen_mlir/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,21 @@ use melior::{
Context as MeliorContext,
};

use crate::ast_helper::{AstHelper, ModuleInfo};

pub fn compile_program(
session: &Session,
ctx: &MeliorContext,
mlir_module: &MeliorModule,
program: &Program,
) -> Result<(), Box<dyn Error>> {
let ast_helper = AstHelper::new(program);
for module in &program.modules {
compile_module(session, ctx, mlir_module, module)?;
let module_info = ast_helper
.modules
.get(&module.name.name)
.unwrap_or_else(|| panic!("module info not found for {}", module.name.name));
compile_module(session, ctx, mlir_module, &ast_helper, module_info, module)?;
}
Ok(())
}
Expand Down Expand Up @@ -67,8 +74,9 @@ impl<'ctx, 'parent: 'ctx> LocalVar<'ctx, 'parent> {
#[derive(Debug, Clone)]
struct ScopeContext<'ctx, 'parent: 'ctx> {
pub locals: HashMap<String, LocalVar<'ctx, 'parent>>,
pub functions: HashMap<String, FunctionDef>,
pub function: Option<FunctionDef>,
pub imports: HashMap<String, &'parent ModuleInfo<'parent>>,
pub module_info: &'parent ModuleInfo<'parent>,
}

struct BlockHelper<'ctx, 'region: 'ctx> {
Expand All @@ -87,6 +95,34 @@ impl<'ctx, 'region> BlockHelper<'ctx, 'region> {
}

impl<'ctx, 'parent> ScopeContext<'ctx, 'parent> {
/// Returns the symbol name from a local name.
pub fn get_symbol_name(&self, local_name: &str) -> String {
if local_name == "main" {
return local_name.to_string();
}

if let Some(module) = self.imports.get(local_name) {
// a import
module.get_symbol_name(local_name)
} else {
let mut result = self.module_info.name.clone();

result.push_str("::");
result.push_str(local_name);

result
}
}

pub fn get_function(&self, local_name: &str) -> Option<&FunctionDef> {
if let Some(module) = self.imports.get(local_name) {
// a import
module.functions.get(local_name).copied()
} else {
self.module_info.functions.get(local_name).copied()
}
}

fn resolve_type(
&self,
context: &'ctx MeliorContext,
Expand All @@ -111,10 +147,7 @@ impl<'ctx, 'parent> ScopeContext<'ctx, 'parent> {
) -> Result<Type<'ctx>, Box<dyn Error>> {
Ok(match spec {
TypeSpec::Simple { name } => self.resolve_type(context, &name.name)?,
TypeSpec::Generic {
name,
type_params: _,
} => self.resolve_type(context, &name.name)?,
TypeSpec::Generic { name, .. } => self.resolve_type(context, &name.name)?,
})
}

Expand All @@ -139,27 +172,38 @@ fn compile_module(
session: &Session,
context: &MeliorContext,
mlir_module: &MeliorModule,
ast_helper: &AstHelper<'_>,
module_info: &ModuleInfo<'_>,
module: &Module,
) -> Result<(), Box<dyn Error>> {
// todo: handle imports

let body = mlir_module.body();

let mut scope_ctx: ScopeContext = ScopeContext {
functions: Default::default(),
locals: Default::default(),
function: None,
};
let mut imports = HashMap::new();

// save all function signatures
for statement in &module.contents {
if let ModuleDefItem::Function(info) = statement {
scope_ctx
.functions
.insert(info.decl.name.name.clone(), info.clone());
for import in &module.imports {
let target_module = ast_helper
.get_module_from_import(&import.module)
.unwrap_or_else(|| {
panic!(
"failed to find import {:?} in module {}",
import, module.name.name
)
});

for symbol in &import.symbols {
imports.insert(symbol.name.clone(), target_module);
}
}

let scope_ctx: ScopeContext = ScopeContext {
locals: Default::default(),
function: None,
module_info,
imports,
};

for statement in &module.contents {
match statement {
ModuleDefItem::Constant(_) => todo!(),
Expand All @@ -169,8 +213,17 @@ fn compile_module(
let op = compile_function_def(session, context, &scope_ctx, info)?;
body.append_operation(op);
}
ModuleDefItem::Record(_) => todo!(),
ModuleDefItem::Struct(_) => todo!(),
ModuleDefItem::Type(_) => todo!(),
ModuleDefItem::Module(info) => {
let module_info = module_info.modules.get(&info.name.name).unwrap_or_else(|| {
panic!(
"submodule {} not found while compiling module {}",
info.name.name, module.name.name
)
});
compile_module(session, context, mlir_module, ast_helper, module_info, info)?;
}
}
}

Expand Down Expand Up @@ -251,9 +304,11 @@ fn compile_function_def<'ctx, 'parent: 'ctx>(
}
}

let fn_name = scope_ctx.get_symbol_name(&info.decl.name.name);

Ok(func::func(
context,
StringAttribute::new(context, &info.decl.name.name),
StringAttribute::new(context, &fn_name),
func_type,
region,
&[],
Expand Down Expand Up @@ -755,8 +810,7 @@ fn compile_fn_call<'ctx, 'parent: 'ctx>(
let location = get_location(context, session, info.target.span.from);

let target_fn = scope_ctx
.functions
.get(&info.target.name)
.get_function(&info.target.name)
.expect("function not found")
.clone();

Expand Down Expand Up @@ -785,10 +839,12 @@ fn compile_fn_call<'ctx, 'parent: 'ctx>(
vec![]
};

let fn_name = scope_ctx.get_symbol_name(&info.target.name);

Ok(block
.append_operation(func::call(
context,
FlatSymbolRefAttribute::new(context, &info.target.name),
FlatSymbolRefAttribute::new(context, &fn_name),
&args,
&return_type,
location,
Expand Down
1 change: 1 addition & 0 deletions crates/concrete_codegen_mlir/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use llvm_sys::{
};
use module::MLIRModule;

mod ast_helper;
mod codegen;
mod context;
mod error;
Expand Down
Loading