Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cgen, checker: fix generic variable resolution on generic func return assignment #21712

Merged
merged 11 commits into from
Jun 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions vlib/v/ast/ast.v
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,7 @@ pub mut:
receiver_type Type // User / T, if receiver is generic, then cgen requires receiver_type to be T
receiver_concrete_type Type // if receiver_type is T, then receiver_concrete_type is concrete type, otherwise it is the same as receiver_type
return_type Type
return_type_generic Type // the original generic return type from fn def
fn_var_type Type // the fn type, when `is_fn_a_const` or `is_fn_var` is true
const_name string // the fully qualified name of the const, i.e. `main.c`, given `const c = abc`, and callexpr: `c()`
should_be_skipped bool // true for calls to `[if someflag?]` functions, when there is no `-d someflag`
Expand Down Expand Up @@ -825,6 +826,7 @@ pub enum ComptimeVarKind {
value_var // map value from `for k,v in t.$(field.name)`
field_var // comptime field var `a := t.$(field.name)`
generic_param // generic fn parameter
generic_var // generic var
smartcast // smart cast when used in `is v` (when `v` is from $for .variants)
}

Expand Down
10 changes: 10 additions & 0 deletions vlib/v/checker/assign.v
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,16 @@ fn (mut c Checker) assign_stmt(mut node ast.AssignStmt) {
&& right.expr is ast.ComptimeSelector {
left.obj.ct_type_var = .field_var
left.obj.typ = c.comptime.comptime_for_field_type
} else if mut right is ast.CallExpr {
if left.obj.ct_type_var == .no_comptime
&& c.table.cur_fn != unsafe { nil }
&& c.table.cur_fn.generic_names.len != 0
&& !right.comptime_ret_val
&& right.return_type_generic.has_flag(.generic)
&& c.is_generic_expr(right) {
// mark variable as generic var because its type changes according to fn return generic resolution type
left.obj.ct_type_var = .generic_var
}
}
}
ast.GlobalField {
Expand Down
29 changes: 28 additions & 1 deletion vlib/v/checker/fn.v
Original file line number Diff line number Diff line change
Expand Up @@ -1509,6 +1509,9 @@ fn (mut c Checker) fn_call(mut node ast.CallExpr, mut continue_check &bool) ast.
} else {
node.return_type = func.return_type
}
if func.return_type.has_flag(.generic) {
node.return_type_generic = func.return_type
}
if node.concrete_types.len > 0 && func.return_type != 0 && c.table.cur_fn != unsafe { nil }
&& c.table.cur_fn.generic_names.len == 0 {
if typ := c.table.resolve_generic_to_concrete(func.return_type, func.generic_names,
Expand Down Expand Up @@ -1582,6 +1585,27 @@ fn (mut c Checker) register_trace_call(node ast.CallExpr, func ast.Fn) {
}
}

// is_generic_expr checks if the expr relies on fn generic argument
fn (mut c Checker) is_generic_expr(node ast.Expr) bool {
return match node {
ast.Ident {
c.comptime.is_generic_param_var(node)
}
ast.IndexExpr {
c.comptime.is_generic_param_var(node.left)
}
ast.CallExpr {
node.args.any(c.comptime.is_generic_param_var(it.expr))
}
ast.SelectorExpr {
c.comptime.is_generic_param_var(node.expr)
}
else {
false
}
}
}

fn (mut c Checker) resolve_comptime_args(func ast.Fn, node_ ast.CallExpr, concrete_types []ast.Type) map[int]ast.Type {
mut comptime_args := map[int]ast.Type{}
has_dynamic_vars := (c.table.cur_fn != unsafe { nil } && c.table.cur_fn.generic_names.len > 0)
Expand All @@ -1602,7 +1626,7 @@ fn (mut c Checker) resolve_comptime_args(func ast.Fn, node_ ast.CallExpr, concre
param_typ := param.typ
if call_arg.expr is ast.Ident {
if call_arg.expr.obj is ast.Var {
if call_arg.expr.obj.ct_type_var !in [.generic_param, .no_comptime] {
if call_arg.expr.obj.ct_type_var !in [.generic_var, .generic_param, .no_comptime] {
mut ctyp := c.comptime.get_comptime_var_type(call_arg.expr)
if ctyp != ast.void_type {
arg_sym := c.table.sym(ctyp)
Expand Down Expand Up @@ -2159,6 +2183,9 @@ fn (mut c Checker) method_call(mut node ast.CallExpr) ast.Type {
node.is_noreturn = method.is_noreturn
node.is_ctor_new = method.is_ctor_new
node.return_type = method.return_type
if method.return_type.has_flag(.generic) {
node.return_type_generic = method.return_type
}
if !method.is_pub && method.mod != c.mod {
// If a private method is called outside of the module
// its receiver type is defined in, show an error.
Expand Down
62 changes: 39 additions & 23 deletions vlib/v/comptime/comptimeinfo.v
Original file line number Diff line number Diff line change
Expand Up @@ -44,33 +44,49 @@ pub fn (mut ct ComptimeInfo) get_ct_type_var(node ast.Expr) ast.ComptimeVarKind
}
}

@[inline]
pub fn (mut ct ComptimeInfo) is_generic_param_var(node ast.Expr) bool {
return node is ast.Ident && node.info is ast.IdentVar && node.obj is ast.Var
&& (node.obj as ast.Var).ct_type_var == .generic_param
}

// get_comptime_var_type retrieves the actual type from a comptime related ast node
@[inline]
pub fn (mut ct ComptimeInfo) get_comptime_var_type(node ast.Expr) ast.Type {
if node is ast.Ident && node.obj is ast.Var {
return match (node.obj as ast.Var).ct_type_var {
.generic_param {
// generic parameter from current function
node.obj.typ
}
.smartcast {
ctyp := ct.type_map['${ct.comptime_for_variant_var}.typ'] or { node.obj.typ }
return if (node.obj as ast.Var).is_unwrapped {
ctyp.clear_flag(.option)
} else {
ctyp
if node is ast.Ident {
if node.obj is ast.Var {
return match node.obj.ct_type_var {
.generic_param {
// generic parameter from current function
node.obj.typ
}
.generic_var {
// generic var used on fn call assignment
if node.obj.smartcasts.len > 0 {
node.obj.smartcasts.last()
} else {
ct.type_map['g.${node.name}.${node.obj.pos.pos}'] or { node.obj.typ }
}
}
.smartcast {
ctyp := ct.type_map['${ct.comptime_for_variant_var}.typ'] or { node.obj.typ }
return if (node.obj as ast.Var).is_unwrapped {
ctyp.clear_flag(.option)
} else {
ctyp
}
}
.key_var, .value_var {
// key and value variables from normal for stmt
ct.type_map[node.name] or { ast.void_type }
}
.field_var {
// field var from $for loop
ct.comptime_for_field_type
}
else {
ast.void_type
}
}
.key_var, .value_var {
// key and value variables from normal for stmt
ct.type_map[node.name] or { ast.void_type }
}
.field_var {
// field var from $for loop
ct.comptime_for_field_type
}
else {
ast.void_type
}
}
} else if node is ast.ComptimeSelector {
Expand Down
13 changes: 12 additions & 1 deletion vlib/v/gen/c/assign.v
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ fn (mut g Gen) assign_stmt(node_ ast.AssignStmt) {
}
g.assign_ct_type = var_type
} else if val is ast.IndexExpr {
if val.left is ast.Ident && g.is_generic_param_var(val.left) {
if val.left is ast.Ident && g.comptime.is_generic_param_var(val.left) {
ctyp := g.unwrap_generic(g.get_gn_var_type(val.left))
if ctyp != ast.void_type {
var_type = ctyp
Expand All @@ -293,6 +293,17 @@ fn (mut g Gen) assign_stmt(node_ ast.AssignStmt) {
g.assign_ct_type = var_type
}
}
} else if left.obj.ct_type_var == .generic_var && val is ast.CallExpr {
if val.return_type_generic != 0 && val.return_type_generic.has_flag(.generic) {
fn_ret_type := g.resolve_fn_return_type(val)
if fn_ret_type != ast.void_type {
var_type = fn_ret_type
val_type = var_type
left.obj.typ = var_type
g.comptime.type_map['g.${left.name}.${left.obj.pos.pos}'] = var_type
// eprintln('>> ${func.name} > resolve ${left.name}.${left.obj.pos.pos}.generic to ${g.table.type_to_str(var_type)}')
}
}
}
is_auto_heap = left.obj.is_auto_heap
}
Expand Down
6 changes: 0 additions & 6 deletions vlib/v/gen/c/cgen.v
Original file line number Diff line number Diff line change
Expand Up @@ -4721,12 +4721,6 @@ fn (mut g Gen) select_expr(node ast.SelectExpr) {
}
}

@[inline]
pub fn (mut g Gen) is_generic_param_var(node ast.Expr) bool {
return node is ast.Ident && node.info is ast.IdentVar && node.obj is ast.Var
&& (node.obj as ast.Var).ct_type_var == .generic_param
}

fn (mut g Gen) get_const_name(node ast.Ident) string {
if g.pref.translated && !g.is_builtin_mod
&& !util.module_is_builtin(node.name.all_before_last('.')) {
Expand Down
91 changes: 81 additions & 10 deletions vlib/v/gen/c/fn.v
Original file line number Diff line number Diff line change
Expand Up @@ -1105,10 +1105,6 @@ fn (mut g Gen) gen_to_str_method_call(node ast.CallExpr) bool {
rec_type = g.comptime.get_comptime_var_type(left_node)
g.gen_expr_to_string(left_node, rec_type)
return true
} else if g.comptime.type_map.len > 0 {
rec_type = left_node.obj.typ
g.gen_expr_to_string(left_node, rec_type)
return true
} else if left_node.obj.smartcasts.len > 0 {
rec_type = g.unwrap_generic(left_node.obj.smartcasts.last())
cast_sym := g.table.sym(rec_type)
Expand Down Expand Up @@ -1154,6 +1150,79 @@ fn (mut g Gen) get_gn_var_type(var ast.Ident) ast.Type {
return ast.void_type
}

// resolve_fn_return_type resolves the generic return type of fn
fn (mut g Gen) resolve_fn_return_type(node ast.CallExpr) ast.Type {
if node.is_method {
if func := g.table.find_method(g.table.sym(node.left_type), node.name) {
if func.generic_names.len > 0 {
mut concrete_types := node.concrete_types.map(g.unwrap_generic(it))
mut rec_len := 0
if node.left_type.has_flag(.generic) {
rec_sym := g.table.final_sym(g.unwrap_generic(node.left_type))
match rec_sym.info {
ast.Struct, ast.Interface, ast.SumType {
rec_len += rec_sym.info.generic_types.len
}
else {}
}
}

mut call_ := unsafe { node }
comptime_args := g.resolve_comptime_args(func, mut call_, concrete_types)
if concrete_types.len > 0 {
for k, v in comptime_args {
if (rec_len + k) < concrete_types.len {
if !node.concrete_types[k].has_flag(.generic) {
concrete_types[rec_len + k] = g.unwrap_generic(v)
}
}
}
}
if gen_type := g.table.resolve_generic_to_concrete(node.return_type_generic,
func.generic_names, concrete_types)
{
if !gen_type.has_flag(.generic) {
return if node.or_block.kind == .absent {
gen_type
} else {
gen_type.clear_option_and_result()
}
}
}
}
}
} else {
if func := g.table.find_fn(node.name) {
if func.generic_names.len > 0 {
mut concrete_types := node.concrete_types.map(g.unwrap_generic(it))
mut call_ := unsafe { node }
comptime_args := g.resolve_comptime_args(func, mut call_, concrete_types)
if concrete_types.len > 0 {
for k, v in comptime_args {
if k < concrete_types.len {
if !node.concrete_types[k].has_flag(.generic) {
concrete_types[k] = g.unwrap_generic(v)
}
}
}
}
if gen_type := g.table.resolve_generic_to_concrete(node.return_type_generic,
func.generic_names, concrete_types)
{
if !gen_type.has_flag(.generic) {
return if node.or_block.kind == .absent {
gen_type
} else {
gen_type.clear_option_and_result()
}
}
}
}
}
}
return ast.void_type
}

fn (g Gen) get_generic_array_element_type(array ast.Array) ast.Type {
mut cparam_elem_info := array as ast.Array
mut cparam_elem_sym := g.table.sym(cparam_elem_info.elem_type)
Expand Down Expand Up @@ -1194,7 +1263,7 @@ fn (mut g Gen) resolve_comptime_args(func ast.Fn, mut node_ ast.CallExpr, concre
if mut call_arg.expr is ast.Ident {
if mut call_arg.expr.obj is ast.Var {
node_.args[i].typ = call_arg.expr.obj.typ
if call_arg.expr.obj.ct_type_var !in [.generic_param, .no_comptime] {
if call_arg.expr.obj.ct_type_var !in [.generic_var, .generic_param, .no_comptime] {
mut ctyp := g.comptime.get_comptime_var_type(call_arg.expr)
if ctyp != ast.void_type {
arg_sym := g.table.sym(ctyp)
Expand Down Expand Up @@ -1293,11 +1362,13 @@ fn (mut g Gen) resolve_comptime_args(func ast.Fn, mut node_ ast.CallExpr, concre
comptime_args[k] = comptime_args[k].set_nr_muls(0)
}
} else if mut call_arg.expr.right is ast.Ident {
mut ctyp := g.comptime.get_comptime_var_type(call_arg.expr.right)
if ctyp != ast.void_type {
comptime_args[k] = ctyp
if param_typ.nr_muls() > 0 && comptime_args[k].nr_muls() > 0 {
comptime_args[k] = comptime_args[k].set_nr_muls(0)
if g.comptime.get_ct_type_var(call_arg.expr.right) != .generic_var {
mut ctyp := g.comptime.get_comptime_var_type(call_arg.expr.right)
if ctyp != ast.void_type {
comptime_args[k] = ctyp
if param_typ.nr_muls() > 0 && comptime_args[k].nr_muls() > 0 {
comptime_args[k] = comptime_args[k].set_nr_muls(0)
}
}
}
}
Expand Down
8 changes: 6 additions & 2 deletions vlib/v/gen/c/for.v
Original file line number Diff line number Diff line change
Expand Up @@ -450,9 +450,13 @@ fn (mut g Gen) for_in_stmt(node_ ast.ForInStmt) {
g.writeln('${t_expr});')
g.writeln('\tif (${t_var}.state != 0) break;')
val := if node.val_var in ['', '_'] { g.new_tmp_var() } else { node.val_var }
val_styp := g.typ(node.val_type)
val_styp := g.typ(ret_typ.clear_option_and_result())
if node.val_is_mut {
g.writeln('\t${val_styp} ${val} = (${val_styp})${t_var}.data;')
if ret_typ.has_flag(.option) {
g.writeln('\t${val_styp}* ${val} = (${val_styp}*)${t_var}.data;')
} else {
g.writeln('\t${val_styp} ${val} = (${val_styp})${t_var}.data;')
}
} else {
g.writeln('\t${val_styp} ${val} = *(${val_styp}*)${t_var}.data;')
}
Expand Down
4 changes: 1 addition & 3 deletions vlib/v/gen/c/str_intp.v
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,7 @@ fn (mut g Gen) str_val(node ast.StringInterLiteral, i int, fmts []u8) {
mut exp_typ := typ
if expr is ast.Ident {
if expr.obj is ast.Var {
if g.comptime.type_map.len > 0 || g.comptime.comptime_for_method.len > 0 {
exp_typ = expr.obj.typ
} else if expr.obj.smartcasts.len > 0 {
if expr.obj.smartcasts.len > 0 {
exp_typ = g.unwrap_generic(expr.obj.smartcasts.last())
cast_sym := g.table.sym(exp_typ)
if cast_sym.info is ast.Aggregate {
Expand Down
43 changes: 43 additions & 0 deletions vlib/v/tests/generic_return_test.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
fn mkey[K, V](m map[K]V) K {
return K{}
}

fn mvalue[K, V](m map[K]V) V {
return V{}
}

fn aelem[E](a []E) E {
return E{}
}

fn g[T](x T) {
$if T is $map {
dk := mkey(x)
dv := mvalue(x)
eprintln('default k: `${dk}` | typeof dk: ${typeof(dk).name}')
eprintln('default v: `${dv}` | typeof dv: ${typeof(dv).name}')
for k, v in x {
eprintln('> k: ${k} | v: ${v}')
}
}
$if T is $array {
de := aelem(x)
eprintln('default e: `${de}` | typeof de: ${typeof(de).name}')
for idx, e in x {
eprintln('> idx: ${idx} | e: ${e}')
}
}
}

fn test_main() {
g({
'abc': 123
'def': 456
})
g([1, 2, 3])
g({
123: 'ggg'
456: 'hhh'
})
g(['xyz', 'zzz'])
}
Loading