Skip to content

Commit

Permalink
Merge pull request #79 from lambdaclass/add_imports
Browse files Browse the repository at this point in the history
Add imports
  • Loading branch information
igaray authored Jan 22, 2024
2 parents 4f68ec5 + 68423e9 commit 4d5f4a8
Show file tree
Hide file tree
Showing 12 changed files with 277 additions and 35 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/concrete_ast/src/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::{
pub struct ConstantDecl {
pub doc_string: Option<DocString>,
pub name: Ident,
pub is_pub: bool,
pub r#type: TypeSpec,
}

Expand Down
3 changes: 2 additions & 1 deletion crates/concrete_ast/src/modules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub struct Module {
pub enum ModuleDefItem {
Constant(ConstantDef),
Function(FunctionDef),
Record(StructDecl),
Struct(StructDecl),
Type(TypeDecl),
Module(Module),
}
14 changes: 7 additions & 7 deletions crates/concrete_ast/src/structs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ use crate::{

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct StructDecl {
doc_string: Option<DocString>,
name: Ident,
type_params: Vec<GenericParam>,
fields: Vec<Field>,
pub doc_string: Option<DocString>,
pub name: Ident,
pub type_params: Vec<GenericParam>,
pub fields: Vec<Field>,
}

#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct Field {
doc_string: Option<DocString>,
name: Ident,
r#type: TypeSpec,
pub doc_string: Option<DocString>,
pub name: Ident,
pub r#type: TypeSpec,
}
4 changes: 3 additions & 1 deletion crates/concrete_ast/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::common::{DocString, Ident};
use crate::common::{DocString, Ident, Span};

#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum TypeSpec {
Expand All @@ -8,11 +8,13 @@ pub enum TypeSpec {
Generic {
name: Ident,
type_params: Vec<TypeSpec>,
span: Span,
},
}

#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct TypeDecl {
pub doc_string: Option<DocString>,
pub name: Ident,
pub value: TypeSpec,
}
1 change: 1 addition & 0 deletions crates/concrete_codegen_mlir/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ edition = "2021"
bumpalo = { version = "3.14.0", features = ["std"] }
concrete_ast = { path = "../concrete_ast"}
concrete_session = { path = "../concrete_session"}
itertools = "0.12.0"
llvm-sys = "170.0.1"
melior = { version = "0.15.0", features = ["ods-dialects"] }
mlir-sys = "0.2.1"
Expand Down
127 changes: 127 additions & 0 deletions crates/concrete_codegen_mlir/src/ast_helper.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
use std::collections::HashMap;

use concrete_ast::{
common::Ident,
constants::ConstantDef,
functions::FunctionDef,
modules::{Module, ModuleDefItem},
structs::StructDecl,
types::TypeDecl,
Program,
};

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ModuleInfo<'p> {
pub name: String,
pub functions: HashMap<String, &'p FunctionDef>,
pub constants: HashMap<String, &'p ConstantDef>,
pub structs: HashMap<String, &'p StructDecl>,
pub types: HashMap<String, &'p TypeDecl>,
pub modules: HashMap<String, ModuleInfo<'p>>,
}

impl<'p> ModuleInfo<'p> {
pub fn get_module_from_import(&self, import: &[Ident]) -> Option<&ModuleInfo<'p>> {
let next = import.first()?;
let module = self.modules.get(&next.name)?;

if import.len() > 1 {
module.get_module_from_import(&import[1..])
} else {
Some(module)
}
}

/// Returns the symbol name from a local name.
pub fn get_symbol_name(&self, local_name: &str) -> String {
if local_name == "main" {
return local_name.to_string();
}

let mut result = self.name.clone();

result.push_str("::");
result.push_str(local_name);

result
}
}

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AstHelper<'p> {
pub root: &'p Program,
pub modules: HashMap<String, ModuleInfo<'p>>,
}

impl<'p> AstHelper<'p> {
pub fn new(root: &'p Program) -> Self {
let mut modules = HashMap::default();

for module in &root.modules {
modules.insert(
module.name.name.clone(),
Self::create_module_info(module, None),
);
}

Self { root, modules }
}

pub fn get_module_from_import(&self, import: &[Ident]) -> Option<&ModuleInfo<'p>> {
let next = import.first()?;
let module = self.modules.get(&next.name)?;

if import.len() > 1 {
module.get_module_from_import(&import[1..])
} else {
Some(module)
}
}

fn create_module_info(module: &Module, parent_name: Option<String>) -> ModuleInfo<'_> {
let mut functions = HashMap::default();
let mut constants = HashMap::default();
let mut structs = HashMap::default();
let mut types = HashMap::default();
let mut child_modules = HashMap::default();
let mut name = parent_name.clone().unwrap_or_default();

if name.is_empty() {
name = module.name.name.clone();
} else {
name.push_str(&format!("::{}", module.name.name));
}

for stmt in &module.contents {
match stmt {
ModuleDefItem::Constant(info) => {
constants.insert(info.decl.name.name.clone(), info);
}
ModuleDefItem::Function(info) => {
functions.insert(info.decl.name.name.clone(), info);
}
ModuleDefItem::Struct(info) => {
structs.insert(info.name.name.clone(), info);
}
ModuleDefItem::Type(info) => {
types.insert(info.name.name.clone(), info);
}
ModuleDefItem::Module(info) => {
child_modules.insert(
info.name.name.clone(),
Self::create_module_info(info, Some(name.clone())),
);
}
}
}

ModuleInfo {
name,
functions,
structs,
constants,
types,
modules: child_modules,
}
}
}
100 changes: 78 additions & 22 deletions crates/concrete_codegen_mlir/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,21 @@ use melior::{
Context as MeliorContext,
};

use crate::ast_helper::{AstHelper, ModuleInfo};

pub fn compile_program(
session: &Session,
ctx: &MeliorContext,
mlir_module: &MeliorModule,
program: &Program,
) -> Result<(), Box<dyn Error>> {
let ast_helper = AstHelper::new(program);
for module in &program.modules {
compile_module(session, ctx, mlir_module, module)?;
let module_info = ast_helper
.modules
.get(&module.name.name)
.unwrap_or_else(|| panic!("module info not found for {}", module.name.name));
compile_module(session, ctx, mlir_module, &ast_helper, module_info, module)?;
}
Ok(())
}
Expand Down Expand Up @@ -67,8 +74,9 @@ impl<'ctx, 'parent: 'ctx> LocalVar<'ctx, 'parent> {
#[derive(Debug, Clone)]
struct ScopeContext<'ctx, 'parent: 'ctx> {
pub locals: HashMap<String, LocalVar<'ctx, 'parent>>,
pub functions: HashMap<String, FunctionDef>,
pub function: Option<FunctionDef>,
pub imports: HashMap<String, &'parent ModuleInfo<'parent>>,
pub module_info: &'parent ModuleInfo<'parent>,
}

struct BlockHelper<'ctx, 'region: 'ctx> {
Expand All @@ -87,6 +95,34 @@ impl<'ctx, 'region> BlockHelper<'ctx, 'region> {
}

impl<'ctx, 'parent> ScopeContext<'ctx, 'parent> {
/// Returns the symbol name from a local name.
pub fn get_symbol_name(&self, local_name: &str) -> String {
if local_name == "main" {
return local_name.to_string();
}

if let Some(module) = self.imports.get(local_name) {
// a import
module.get_symbol_name(local_name)
} else {
let mut result = self.module_info.name.clone();

result.push_str("::");
result.push_str(local_name);

result
}
}

pub fn get_function(&self, local_name: &str) -> Option<&FunctionDef> {
if let Some(module) = self.imports.get(local_name) {
// a import
module.functions.get(local_name).copied()
} else {
self.module_info.functions.get(local_name).copied()
}
}

fn resolve_type(
&self,
context: &'ctx MeliorContext,
Expand All @@ -111,10 +147,7 @@ impl<'ctx, 'parent> ScopeContext<'ctx, 'parent> {
) -> Result<Type<'ctx>, Box<dyn Error>> {
Ok(match spec {
TypeSpec::Simple { name } => self.resolve_type(context, &name.name)?,
TypeSpec::Generic {
name,
type_params: _,
} => self.resolve_type(context, &name.name)?,
TypeSpec::Generic { name, .. } => self.resolve_type(context, &name.name)?,
})
}

Expand All @@ -139,27 +172,38 @@ fn compile_module(
session: &Session,
context: &MeliorContext,
mlir_module: &MeliorModule,
ast_helper: &AstHelper<'_>,
module_info: &ModuleInfo<'_>,
module: &Module,
) -> Result<(), Box<dyn Error>> {
// todo: handle imports

let body = mlir_module.body();

let mut scope_ctx: ScopeContext = ScopeContext {
functions: Default::default(),
locals: Default::default(),
function: None,
};
let mut imports = HashMap::new();

// save all function signatures
for statement in &module.contents {
if let ModuleDefItem::Function(info) = statement {
scope_ctx
.functions
.insert(info.decl.name.name.clone(), info.clone());
for import in &module.imports {
let target_module = ast_helper
.get_module_from_import(&import.module)
.unwrap_or_else(|| {
panic!(
"failed to find import {:?} in module {}",
import, module.name.name
)
});

for symbol in &import.symbols {
imports.insert(symbol.name.clone(), target_module);
}
}

let scope_ctx: ScopeContext = ScopeContext {
locals: Default::default(),
function: None,
module_info,
imports,
};

for statement in &module.contents {
match statement {
ModuleDefItem::Constant(_) => todo!(),
Expand All @@ -169,8 +213,17 @@ fn compile_module(
let op = compile_function_def(session, context, &scope_ctx, info)?;
body.append_operation(op);
}
ModuleDefItem::Record(_) => todo!(),
ModuleDefItem::Struct(_) => todo!(),
ModuleDefItem::Type(_) => todo!(),
ModuleDefItem::Module(info) => {
let module_info = module_info.modules.get(&info.name.name).unwrap_or_else(|| {
panic!(
"submodule {} not found while compiling module {}",
info.name.name, module.name.name
)
});
compile_module(session, context, mlir_module, ast_helper, module_info, info)?;
}
}
}

Expand Down Expand Up @@ -251,9 +304,11 @@ fn compile_function_def<'ctx, 'parent: 'ctx>(
}
}

let fn_name = scope_ctx.get_symbol_name(&info.decl.name.name);

Ok(func::func(
context,
StringAttribute::new(context, &info.decl.name.name),
StringAttribute::new(context, &fn_name),
func_type,
region,
&[],
Expand Down Expand Up @@ -755,8 +810,7 @@ fn compile_fn_call<'ctx, 'parent: 'ctx>(
let location = get_location(context, session, info.target.span.from);

let target_fn = scope_ctx
.functions
.get(&info.target.name)
.get_function(&info.target.name)
.expect("function not found")
.clone();

Expand Down Expand Up @@ -785,10 +839,12 @@ fn compile_fn_call<'ctx, 'parent: 'ctx>(
vec![]
};

let fn_name = scope_ctx.get_symbol_name(&info.target.name);

Ok(block
.append_operation(func::call(
context,
FlatSymbolRefAttribute::new(context, &info.target.name),
FlatSymbolRefAttribute::new(context, &fn_name),
&args,
&return_type,
location,
Expand Down
1 change: 1 addition & 0 deletions crates/concrete_codegen_mlir/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use llvm_sys::{
};
use module::MLIRModule;

mod ast_helper;
mod codegen;
mod context;
mod error;
Expand Down
Loading

0 comments on commit 4d5f4a8

Please sign in to comment.