Skip to content

Commit

Permalink
cgen, checker: fix generic variable resolution on generic func return…
Browse files Browse the repository at this point in the history
… assignment (#21712)
  • Loading branch information
felipensp authored Jun 23, 2024
1 parent 53d7a55 commit cc14272
Show file tree
Hide file tree
Showing 10 changed files with 222 additions and 46 deletions.
2 changes: 2 additions & 0 deletions vlib/v/ast/ast.v
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,7 @@ pub mut:
receiver_type Type // User / T, if receiver is generic, then cgen requires receiver_type to be T
receiver_concrete_type Type // if receiver_type is T, then receiver_concrete_type is concrete type, otherwise it is the same as receiver_type
return_type Type
return_type_generic Type // the original generic return type from fn def
fn_var_type Type // the fn type, when `is_fn_a_const` or `is_fn_var` is true
const_name string // the fully qualified name of the const, i.e. `main.c`, given `const c = abc`, and callexpr: `c()`
should_be_skipped bool // true for calls to `[if someflag?]` functions, when there is no `-d someflag`
Expand Down Expand Up @@ -825,6 +826,7 @@ pub enum ComptimeVarKind {
value_var // map value from `for k,v in t.$(field.name)`
field_var // comptime field var `a := t.$(field.name)`
generic_param // generic fn parameter
generic_var // generic var
smartcast // smart cast when used in `is v` (when `v` is from $for .variants)
}

Expand Down
10 changes: 10 additions & 0 deletions vlib/v/checker/assign.v
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,16 @@ fn (mut c Checker) assign_stmt(mut node ast.AssignStmt) {
&& right.expr is ast.ComptimeSelector {
left.obj.ct_type_var = .field_var
left.obj.typ = c.comptime.comptime_for_field_type
} else if mut right is ast.CallExpr {
if left.obj.ct_type_var == .no_comptime
&& c.table.cur_fn != unsafe { nil }
&& c.table.cur_fn.generic_names.len != 0
&& !right.comptime_ret_val
&& right.return_type_generic.has_flag(.generic)
&& c.is_generic_expr(right) {
// mark variable as generic var because its type changes according to fn return generic resolution type
left.obj.ct_type_var = .generic_var
}
}
}
ast.GlobalField {
Expand Down
29 changes: 28 additions & 1 deletion vlib/v/checker/fn.v
Original file line number Diff line number Diff line change
Expand Up @@ -1509,6 +1509,9 @@ fn (mut c Checker) fn_call(mut node ast.CallExpr, mut continue_check &bool) ast.
} else {
node.return_type = func.return_type
}
if func.return_type.has_flag(.generic) {
node.return_type_generic = func.return_type
}
if node.concrete_types.len > 0 && func.return_type != 0 && c.table.cur_fn != unsafe { nil }
&& c.table.cur_fn.generic_names.len == 0 {
if typ := c.table.resolve_generic_to_concrete(func.return_type, func.generic_names,
Expand Down Expand Up @@ -1582,6 +1585,27 @@ fn (mut c Checker) register_trace_call(node ast.CallExpr, func ast.Fn) {
}
}

// is_generic_expr checks if the expr relies on fn generic argument
fn (mut c Checker) is_generic_expr(node ast.Expr) bool {
return match node {
ast.Ident {
c.comptime.is_generic_param_var(node)
}
ast.IndexExpr {
c.comptime.is_generic_param_var(node.left)
}
ast.CallExpr {
node.args.any(c.comptime.is_generic_param_var(it.expr))
}
ast.SelectorExpr {
c.comptime.is_generic_param_var(node.expr)
}
else {
false
}
}
}

fn (mut c Checker) resolve_comptime_args(func ast.Fn, node_ ast.CallExpr, concrete_types []ast.Type) map[int]ast.Type {
mut comptime_args := map[int]ast.Type{}
has_dynamic_vars := (c.table.cur_fn != unsafe { nil } && c.table.cur_fn.generic_names.len > 0)
Expand All @@ -1602,7 +1626,7 @@ fn (mut c Checker) resolve_comptime_args(func ast.Fn, node_ ast.CallExpr, concre
param_typ := param.typ
if call_arg.expr is ast.Ident {
if call_arg.expr.obj is ast.Var {
if call_arg.expr.obj.ct_type_var !in [.generic_param, .no_comptime] {
if call_arg.expr.obj.ct_type_var !in [.generic_var, .generic_param, .no_comptime] {
mut ctyp := c.comptime.get_comptime_var_type(call_arg.expr)
if ctyp != ast.void_type {
arg_sym := c.table.sym(ctyp)
Expand Down Expand Up @@ -2159,6 +2183,9 @@ fn (mut c Checker) method_call(mut node ast.CallExpr) ast.Type {
node.is_noreturn = method.is_noreturn
node.is_ctor_new = method.is_ctor_new
node.return_type = method.return_type
if method.return_type.has_flag(.generic) {
node.return_type_generic = method.return_type
}
if !method.is_pub && method.mod != c.mod {
// If a private method is called outside of the module
// its receiver type is defined in, show an error.
Expand Down
62 changes: 39 additions & 23 deletions vlib/v/comptime/comptimeinfo.v
Original file line number Diff line number Diff line change
Expand Up @@ -44,33 +44,49 @@ pub fn (mut ct ComptimeInfo) get_ct_type_var(node ast.Expr) ast.ComptimeVarKind
}
}

@[inline]
pub fn (mut ct ComptimeInfo) is_generic_param_var(node ast.Expr) bool {
return node is ast.Ident && node.info is ast.IdentVar && node.obj is ast.Var
&& (node.obj as ast.Var).ct_type_var == .generic_param
}

// get_comptime_var_type retrieves the actual type from a comptime related ast node
@[inline]
pub fn (mut ct ComptimeInfo) get_comptime_var_type(node ast.Expr) ast.Type {
if node is ast.Ident && node.obj is ast.Var {
return match (node.obj as ast.Var).ct_type_var {
.generic_param {
// generic parameter from current function
node.obj.typ
}
.smartcast {
ctyp := ct.type_map['${ct.comptime_for_variant_var}.typ'] or { node.obj.typ }
return if (node.obj as ast.Var).is_unwrapped {
ctyp.clear_flag(.option)
} else {
ctyp
if node is ast.Ident {
if node.obj is ast.Var {
return match node.obj.ct_type_var {
.generic_param {
// generic parameter from current function
node.obj.typ
}
.generic_var {
// generic var used on fn call assignment
if node.obj.smartcasts.len > 0 {
node.obj.smartcasts.last()
} else {
ct.type_map['g.${node.name}.${node.obj.pos.pos}'] or { node.obj.typ }
}
}
.smartcast {
ctyp := ct.type_map['${ct.comptime_for_variant_var}.typ'] or { node.obj.typ }
return if (node.obj as ast.Var).is_unwrapped {
ctyp.clear_flag(.option)
} else {
ctyp
}
}
.key_var, .value_var {
// key and value variables from normal for stmt
ct.type_map[node.name] or { ast.void_type }
}
.field_var {
// field var from $for loop
ct.comptime_for_field_type
}
else {
ast.void_type
}
}
.key_var, .value_var {
// key and value variables from normal for stmt
ct.type_map[node.name] or { ast.void_type }
}
.field_var {
// field var from $for loop
ct.comptime_for_field_type
}
else {
ast.void_type
}
}
} else if node is ast.ComptimeSelector {
Expand Down
13 changes: 12 additions & 1 deletion vlib/v/gen/c/assign.v
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ fn (mut g Gen) assign_stmt(node_ ast.AssignStmt) {
}
g.assign_ct_type = var_type
} else if val is ast.IndexExpr {
if val.left is ast.Ident && g.is_generic_param_var(val.left) {
if val.left is ast.Ident && g.comptime.is_generic_param_var(val.left) {
ctyp := g.unwrap_generic(g.get_gn_var_type(val.left))
if ctyp != ast.void_type {
var_type = ctyp
Expand All @@ -293,6 +293,17 @@ fn (mut g Gen) assign_stmt(node_ ast.AssignStmt) {
g.assign_ct_type = var_type
}
}
} else if left.obj.ct_type_var == .generic_var && val is ast.CallExpr {
if val.return_type_generic != 0 && val.return_type_generic.has_flag(.generic) {
fn_ret_type := g.resolve_fn_return_type(val)
if fn_ret_type != ast.void_type {
var_type = fn_ret_type
val_type = var_type
left.obj.typ = var_type
g.comptime.type_map['g.${left.name}.${left.obj.pos.pos}'] = var_type
// eprintln('>> ${func.name} > resolve ${left.name}.${left.obj.pos.pos}.generic to ${g.table.type_to_str(var_type)}')
}
}
}
is_auto_heap = left.obj.is_auto_heap
}
Expand Down
6 changes: 0 additions & 6 deletions vlib/v/gen/c/cgen.v
Original file line number Diff line number Diff line change
Expand Up @@ -4721,12 +4721,6 @@ fn (mut g Gen) select_expr(node ast.SelectExpr) {
}
}

@[inline]
pub fn (mut g Gen) is_generic_param_var(node ast.Expr) bool {
return node is ast.Ident && node.info is ast.IdentVar && node.obj is ast.Var
&& (node.obj as ast.Var).ct_type_var == .generic_param
}

fn (mut g Gen) get_const_name(node ast.Ident) string {
if g.pref.translated && !g.is_builtin_mod
&& !util.module_is_builtin(node.name.all_before_last('.')) {
Expand Down
91 changes: 81 additions & 10 deletions vlib/v/gen/c/fn.v
Original file line number Diff line number Diff line change
Expand Up @@ -1105,10 +1105,6 @@ fn (mut g Gen) gen_to_str_method_call(node ast.CallExpr) bool {
rec_type = g.comptime.get_comptime_var_type(left_node)
g.gen_expr_to_string(left_node, rec_type)
return true
} else if g.comptime.type_map.len > 0 {
rec_type = left_node.obj.typ
g.gen_expr_to_string(left_node, rec_type)
return true
} else if left_node.obj.smartcasts.len > 0 {
rec_type = g.unwrap_generic(left_node.obj.smartcasts.last())
cast_sym := g.table.sym(rec_type)
Expand Down Expand Up @@ -1154,6 +1150,79 @@ fn (mut g Gen) get_gn_var_type(var ast.Ident) ast.Type {
return ast.void_type
}

// resolve_fn_return_type resolves the generic return type of fn
fn (mut g Gen) resolve_fn_return_type(node ast.CallExpr) ast.Type {
if node.is_method {
if func := g.table.find_method(g.table.sym(node.left_type), node.name) {
if func.generic_names.len > 0 {
mut concrete_types := node.concrete_types.map(g.unwrap_generic(it))
mut rec_len := 0
if node.left_type.has_flag(.generic) {
rec_sym := g.table.final_sym(g.unwrap_generic(node.left_type))
match rec_sym.info {
ast.Struct, ast.Interface, ast.SumType {
rec_len += rec_sym.info.generic_types.len
}
else {}
}
}

mut call_ := unsafe { node }
comptime_args := g.resolve_comptime_args(func, mut call_, concrete_types)
if concrete_types.len > 0 {
for k, v in comptime_args {
if (rec_len + k) < concrete_types.len {
if !node.concrete_types[k].has_flag(.generic) {
concrete_types[rec_len + k] = g.unwrap_generic(v)
}
}
}
}
if gen_type := g.table.resolve_generic_to_concrete(node.return_type_generic,
func.generic_names, concrete_types)
{
if !gen_type.has_flag(.generic) {
return if node.or_block.kind == .absent {
gen_type
} else {
gen_type.clear_option_and_result()
}
}
}
}
}
} else {
if func := g.table.find_fn(node.name) {
if func.generic_names.len > 0 {
mut concrete_types := node.concrete_types.map(g.unwrap_generic(it))
mut call_ := unsafe { node }
comptime_args := g.resolve_comptime_args(func, mut call_, concrete_types)
if concrete_types.len > 0 {
for k, v in comptime_args {
if k < concrete_types.len {
if !node.concrete_types[k].has_flag(.generic) {
concrete_types[k] = g.unwrap_generic(v)
}
}
}
}
if gen_type := g.table.resolve_generic_to_concrete(node.return_type_generic,
func.generic_names, concrete_types)
{
if !gen_type.has_flag(.generic) {
return if node.or_block.kind == .absent {
gen_type
} else {
gen_type.clear_option_and_result()
}
}
}
}
}
}
return ast.void_type
}

fn (g Gen) get_generic_array_element_type(array ast.Array) ast.Type {
mut cparam_elem_info := array as ast.Array
mut cparam_elem_sym := g.table.sym(cparam_elem_info.elem_type)
Expand Down Expand Up @@ -1194,7 +1263,7 @@ fn (mut g Gen) resolve_comptime_args(func ast.Fn, mut node_ ast.CallExpr, concre
if mut call_arg.expr is ast.Ident {
if mut call_arg.expr.obj is ast.Var {
node_.args[i].typ = call_arg.expr.obj.typ
if call_arg.expr.obj.ct_type_var !in [.generic_param, .no_comptime] {
if call_arg.expr.obj.ct_type_var !in [.generic_var, .generic_param, .no_comptime] {
mut ctyp := g.comptime.get_comptime_var_type(call_arg.expr)
if ctyp != ast.void_type {
arg_sym := g.table.sym(ctyp)
Expand Down Expand Up @@ -1293,11 +1362,13 @@ fn (mut g Gen) resolve_comptime_args(func ast.Fn, mut node_ ast.CallExpr, concre
comptime_args[k] = comptime_args[k].set_nr_muls(0)
}
} else if mut call_arg.expr.right is ast.Ident {
mut ctyp := g.comptime.get_comptime_var_type(call_arg.expr.right)
if ctyp != ast.void_type {
comptime_args[k] = ctyp
if param_typ.nr_muls() > 0 && comptime_args[k].nr_muls() > 0 {
comptime_args[k] = comptime_args[k].set_nr_muls(0)
if g.comptime.get_ct_type_var(call_arg.expr.right) != .generic_var {
mut ctyp := g.comptime.get_comptime_var_type(call_arg.expr.right)
if ctyp != ast.void_type {
comptime_args[k] = ctyp
if param_typ.nr_muls() > 0 && comptime_args[k].nr_muls() > 0 {
comptime_args[k] = comptime_args[k].set_nr_muls(0)
}
}
}
}
Expand Down
8 changes: 6 additions & 2 deletions vlib/v/gen/c/for.v
Original file line number Diff line number Diff line change
Expand Up @@ -450,9 +450,13 @@ fn (mut g Gen) for_in_stmt(node_ ast.ForInStmt) {
g.writeln('${t_expr});')
g.writeln('\tif (${t_var}.state != 0) break;')
val := if node.val_var in ['', '_'] { g.new_tmp_var() } else { node.val_var }
val_styp := g.typ(node.val_type)
val_styp := g.typ(ret_typ.clear_option_and_result())
if node.val_is_mut {
g.writeln('\t${val_styp} ${val} = (${val_styp})${t_var}.data;')
if ret_typ.has_flag(.option) {
g.writeln('\t${val_styp}* ${val} = (${val_styp}*)${t_var}.data;')
} else {
g.writeln('\t${val_styp} ${val} = (${val_styp})${t_var}.data;')
}
} else {
g.writeln('\t${val_styp} ${val} = *(${val_styp}*)${t_var}.data;')
}
Expand Down
4 changes: 1 addition & 3 deletions vlib/v/gen/c/str_intp.v
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,7 @@ fn (mut g Gen) str_val(node ast.StringInterLiteral, i int, fmts []u8) {
mut exp_typ := typ
if expr is ast.Ident {
if expr.obj is ast.Var {
if g.comptime.type_map.len > 0 || g.comptime.comptime_for_method.len > 0 {
exp_typ = expr.obj.typ
} else if expr.obj.smartcasts.len > 0 {
if expr.obj.smartcasts.len > 0 {
exp_typ = g.unwrap_generic(expr.obj.smartcasts.last())
cast_sym := g.table.sym(exp_typ)
if cast_sym.info is ast.Aggregate {
Expand Down
43 changes: 43 additions & 0 deletions vlib/v/tests/generic_return_test.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
fn mkey[K, V](m map[K]V) K {
return K{}
}

fn mvalue[K, V](m map[K]V) V {
return V{}
}

fn aelem[E](a []E) E {
return E{}
}

fn g[T](x T) {
$if T is $map {
dk := mkey(x)
dv := mvalue(x)
eprintln('default k: `${dk}` | typeof dk: ${typeof(dk).name}')
eprintln('default v: `${dv}` | typeof dv: ${typeof(dv).name}')
for k, v in x {
eprintln('> k: ${k} | v: ${v}')
}
}
$if T is $array {
de := aelem(x)
eprintln('default e: `${de}` | typeof de: ${typeof(de).name}')
for idx, e in x {
eprintln('> idx: ${idx} | e: ${e}')
}
}
}

fn test_main() {
g({
'abc': 123
'def': 456
})
g([1, 2, 3])
g({
123: 'ggg'
456: 'hhh'
})
g(['xyz', 'zzz'])
}

0 comments on commit cc14272

Please sign in to comment.