diff --git a/vlib/v/ast/ast.v b/vlib/v/ast/ast.v index 6f0b6da3b8a11b..47b0c6a865f7e9 100644 --- a/vlib/v/ast/ast.v +++ b/vlib/v/ast/ast.v @@ -774,6 +774,7 @@ pub mut: receiver_type Type // User / T, if receiver is generic, then cgen requires receiver_type to be T receiver_concrete_type Type // if receiver_type is T, then receiver_concrete_type is concrete type, otherwise it is the same as receiver_type return_type Type + return_type_generic Type // the original generic return type from fn def fn_var_type Type // the fn type, when `is_fn_a_const` or `is_fn_var` is true const_name string // the fully qualified name of the const, i.e. `main.c`, given `const c = abc`, and callexpr: `c()` should_be_skipped bool // true for calls to `[if someflag?]` functions, when there is no `-d someflag` @@ -825,6 +826,7 @@ pub enum ComptimeVarKind { value_var // map value from `for k,v in t.$(field.name)` field_var // comptime field var `a := t.$(field.name)` generic_param // generic fn parameter + generic_var // generic var smartcast // smart cast when used in `is v` (when `v` is from $for .variants) } diff --git a/vlib/v/checker/assign.v b/vlib/v/checker/assign.v index 4d3fee67d17813..eb4c7347d175d9 100644 --- a/vlib/v/checker/assign.v +++ b/vlib/v/checker/assign.v @@ -397,6 +397,16 @@ fn (mut c Checker) assign_stmt(mut node ast.AssignStmt) { && right.expr is ast.ComptimeSelector { left.obj.ct_type_var = .field_var left.obj.typ = c.comptime.comptime_for_field_type + } else if mut right is ast.CallExpr { + if left.obj.ct_type_var == .no_comptime + && c.table.cur_fn != unsafe { nil } + && c.table.cur_fn.generic_names.len != 0 + && !right.comptime_ret_val + && right.return_type_generic.has_flag(.generic) + && c.is_generic_expr(right) { + // mark variable as generic var because its type changes according to fn return generic resolution type + left.obj.ct_type_var = .generic_var + } } } ast.GlobalField { diff --git a/vlib/v/checker/fn.v b/vlib/v/checker/fn.v index 20eff3b01d7bf0..f810342c2ff41a 100644 --- a/vlib/v/checker/fn.v +++ b/vlib/v/checker/fn.v @@ -1509,6 +1509,9 @@ fn (mut c Checker) fn_call(mut node ast.CallExpr, mut continue_check &bool) ast. } else { node.return_type = func.return_type } + if func.return_type.has_flag(.generic) { + node.return_type_generic = func.return_type + } if node.concrete_types.len > 0 && func.return_type != 0 && c.table.cur_fn != unsafe { nil } && c.table.cur_fn.generic_names.len == 0 { if typ := c.table.resolve_generic_to_concrete(func.return_type, func.generic_names, @@ -1582,6 +1585,27 @@ fn (mut c Checker) register_trace_call(node ast.CallExpr, func ast.Fn) { } } +// is_generic_expr checks if the expr relies on fn generic argument +fn (mut c Checker) is_generic_expr(node ast.Expr) bool { + return match node { + ast.Ident { + c.comptime.is_generic_param_var(node) + } + ast.IndexExpr { + c.comptime.is_generic_param_var(node.left) + } + ast.CallExpr { + node.args.any(c.comptime.is_generic_param_var(it.expr)) + } + ast.SelectorExpr { + c.comptime.is_generic_param_var(node.expr) + } + else { + false + } + } +} + fn (mut c Checker) resolve_comptime_args(func ast.Fn, node_ ast.CallExpr, concrete_types []ast.Type) map[int]ast.Type { mut comptime_args := map[int]ast.Type{} has_dynamic_vars := (c.table.cur_fn != unsafe { nil } && c.table.cur_fn.generic_names.len > 0) @@ -1602,7 +1626,7 @@ fn (mut c Checker) resolve_comptime_args(func ast.Fn, node_ ast.CallExpr, concre param_typ := param.typ if call_arg.expr is ast.Ident { if call_arg.expr.obj is ast.Var { - if call_arg.expr.obj.ct_type_var !in [.generic_param, .no_comptime] { + if call_arg.expr.obj.ct_type_var !in [.generic_var, .generic_param, .no_comptime] { mut ctyp := c.comptime.get_comptime_var_type(call_arg.expr) if ctyp != ast.void_type { arg_sym := c.table.sym(ctyp) @@ -2159,6 +2183,9 @@ fn (mut c Checker) method_call(mut node ast.CallExpr) ast.Type { node.is_noreturn = method.is_noreturn node.is_ctor_new = method.is_ctor_new node.return_type = method.return_type + if method.return_type.has_flag(.generic) { + node.return_type_generic = method.return_type + } if !method.is_pub && method.mod != c.mod { // If a private method is called outside of the module // its receiver type is defined in, show an error. diff --git a/vlib/v/comptime/comptimeinfo.v b/vlib/v/comptime/comptimeinfo.v index 0e77d42a578102..86388a513e6be4 100644 --- a/vlib/v/comptime/comptimeinfo.v +++ b/vlib/v/comptime/comptimeinfo.v @@ -44,33 +44,49 @@ pub fn (mut ct ComptimeInfo) get_ct_type_var(node ast.Expr) ast.ComptimeVarKind } } +@[inline] +pub fn (mut ct ComptimeInfo) is_generic_param_var(node ast.Expr) bool { + return node is ast.Ident && node.info is ast.IdentVar && node.obj is ast.Var + && (node.obj as ast.Var).ct_type_var == .generic_param +} + // get_comptime_var_type retrieves the actual type from a comptime related ast node @[inline] pub fn (mut ct ComptimeInfo) get_comptime_var_type(node ast.Expr) ast.Type { - if node is ast.Ident && node.obj is ast.Var { - return match (node.obj as ast.Var).ct_type_var { - .generic_param { - // generic parameter from current function - node.obj.typ - } - .smartcast { - ctyp := ct.type_map['${ct.comptime_for_variant_var}.typ'] or { node.obj.typ } - return if (node.obj as ast.Var).is_unwrapped { - ctyp.clear_flag(.option) - } else { - ctyp + if node is ast.Ident { + if node.obj is ast.Var { + return match node.obj.ct_type_var { + .generic_param { + // generic parameter from current function + node.obj.typ + } + .generic_var { + // generic var used on fn call assignment + if node.obj.smartcasts.len > 0 { + node.obj.smartcasts.last() + } else { + ct.type_map['g.${node.name}.${node.obj.pos.pos}'] or { node.obj.typ } + } + } + .smartcast { + ctyp := ct.type_map['${ct.comptime_for_variant_var}.typ'] or { node.obj.typ } + return if (node.obj as ast.Var).is_unwrapped { + ctyp.clear_flag(.option) + } else { + ctyp + } + } + .key_var, .value_var { + // key and value variables from normal for stmt + ct.type_map[node.name] or { ast.void_type } + } + .field_var { + // field var from $for loop + ct.comptime_for_field_type + } + else { + ast.void_type } - } - .key_var, .value_var { - // key and value variables from normal for stmt - ct.type_map[node.name] or { ast.void_type } - } - .field_var { - // field var from $for loop - ct.comptime_for_field_type - } - else { - ast.void_type } } } else if node is ast.ComptimeSelector { diff --git a/vlib/v/gen/c/assign.v b/vlib/v/gen/c/assign.v index e4ddd4d1fbacfb..5b8753cf48e180 100644 --- a/vlib/v/gen/c/assign.v +++ b/vlib/v/gen/c/assign.v @@ -284,7 +284,7 @@ fn (mut g Gen) assign_stmt(node_ ast.AssignStmt) { } g.assign_ct_type = var_type } else if val is ast.IndexExpr { - if val.left is ast.Ident && g.is_generic_param_var(val.left) { + if val.left is ast.Ident && g.comptime.is_generic_param_var(val.left) { ctyp := g.unwrap_generic(g.get_gn_var_type(val.left)) if ctyp != ast.void_type { var_type = ctyp @@ -293,6 +293,17 @@ fn (mut g Gen) assign_stmt(node_ ast.AssignStmt) { g.assign_ct_type = var_type } } + } else if left.obj.ct_type_var == .generic_var && val is ast.CallExpr { + if val.return_type_generic != 0 && val.return_type_generic.has_flag(.generic) { + fn_ret_type := g.resolve_fn_return_type(val) + if fn_ret_type != ast.void_type { + var_type = fn_ret_type + val_type = var_type + left.obj.typ = var_type + g.comptime.type_map['g.${left.name}.${left.obj.pos.pos}'] = var_type + // eprintln('>> ${func.name} > resolve ${left.name}.${left.obj.pos.pos}.generic to ${g.table.type_to_str(var_type)}') + } + } } is_auto_heap = left.obj.is_auto_heap } diff --git a/vlib/v/gen/c/cgen.v b/vlib/v/gen/c/cgen.v index 36eeb364172f38..21739f5b851a76 100644 --- a/vlib/v/gen/c/cgen.v +++ b/vlib/v/gen/c/cgen.v @@ -4721,12 +4721,6 @@ fn (mut g Gen) select_expr(node ast.SelectExpr) { } } -@[inline] -pub fn (mut g Gen) is_generic_param_var(node ast.Expr) bool { - return node is ast.Ident && node.info is ast.IdentVar && node.obj is ast.Var - && (node.obj as ast.Var).ct_type_var == .generic_param -} - fn (mut g Gen) get_const_name(node ast.Ident) string { if g.pref.translated && !g.is_builtin_mod && !util.module_is_builtin(node.name.all_before_last('.')) { diff --git a/vlib/v/gen/c/fn.v b/vlib/v/gen/c/fn.v index 886a7e517e53b0..00475ab98d416f 100644 --- a/vlib/v/gen/c/fn.v +++ b/vlib/v/gen/c/fn.v @@ -1105,10 +1105,6 @@ fn (mut g Gen) gen_to_str_method_call(node ast.CallExpr) bool { rec_type = g.comptime.get_comptime_var_type(left_node) g.gen_expr_to_string(left_node, rec_type) return true - } else if g.comptime.type_map.len > 0 { - rec_type = left_node.obj.typ - g.gen_expr_to_string(left_node, rec_type) - return true } else if left_node.obj.smartcasts.len > 0 { rec_type = g.unwrap_generic(left_node.obj.smartcasts.last()) cast_sym := g.table.sym(rec_type) @@ -1154,6 +1150,79 @@ fn (mut g Gen) get_gn_var_type(var ast.Ident) ast.Type { return ast.void_type } +// resolve_fn_return_type resolves the generic return type of fn +fn (mut g Gen) resolve_fn_return_type(node ast.CallExpr) ast.Type { + if node.is_method { + if func := g.table.find_method(g.table.sym(node.left_type), node.name) { + if func.generic_names.len > 0 { + mut concrete_types := node.concrete_types.map(g.unwrap_generic(it)) + mut rec_len := 0 + if node.left_type.has_flag(.generic) { + rec_sym := g.table.final_sym(g.unwrap_generic(node.left_type)) + match rec_sym.info { + ast.Struct, ast.Interface, ast.SumType { + rec_len += rec_sym.info.generic_types.len + } + else {} + } + } + + mut call_ := unsafe { node } + comptime_args := g.resolve_comptime_args(func, mut call_, concrete_types) + if concrete_types.len > 0 { + for k, v in comptime_args { + if (rec_len + k) < concrete_types.len { + if !node.concrete_types[k].has_flag(.generic) { + concrete_types[rec_len + k] = g.unwrap_generic(v) + } + } + } + } + if gen_type := g.table.resolve_generic_to_concrete(node.return_type_generic, + func.generic_names, concrete_types) + { + if !gen_type.has_flag(.generic) { + return if node.or_block.kind == .absent { + gen_type + } else { + gen_type.clear_option_and_result() + } + } + } + } + } + } else { + if func := g.table.find_fn(node.name) { + if func.generic_names.len > 0 { + mut concrete_types := node.concrete_types.map(g.unwrap_generic(it)) + mut call_ := unsafe { node } + comptime_args := g.resolve_comptime_args(func, mut call_, concrete_types) + if concrete_types.len > 0 { + for k, v in comptime_args { + if k < concrete_types.len { + if !node.concrete_types[k].has_flag(.generic) { + concrete_types[k] = g.unwrap_generic(v) + } + } + } + } + if gen_type := g.table.resolve_generic_to_concrete(node.return_type_generic, + func.generic_names, concrete_types) + { + if !gen_type.has_flag(.generic) { + return if node.or_block.kind == .absent { + gen_type + } else { + gen_type.clear_option_and_result() + } + } + } + } + } + } + return ast.void_type +} + fn (g Gen) get_generic_array_element_type(array ast.Array) ast.Type { mut cparam_elem_info := array as ast.Array mut cparam_elem_sym := g.table.sym(cparam_elem_info.elem_type) @@ -1194,7 +1263,7 @@ fn (mut g Gen) resolve_comptime_args(func ast.Fn, mut node_ ast.CallExpr, concre if mut call_arg.expr is ast.Ident { if mut call_arg.expr.obj is ast.Var { node_.args[i].typ = call_arg.expr.obj.typ - if call_arg.expr.obj.ct_type_var !in [.generic_param, .no_comptime] { + if call_arg.expr.obj.ct_type_var !in [.generic_var, .generic_param, .no_comptime] { mut ctyp := g.comptime.get_comptime_var_type(call_arg.expr) if ctyp != ast.void_type { arg_sym := g.table.sym(ctyp) @@ -1293,11 +1362,13 @@ fn (mut g Gen) resolve_comptime_args(func ast.Fn, mut node_ ast.CallExpr, concre comptime_args[k] = comptime_args[k].set_nr_muls(0) } } else if mut call_arg.expr.right is ast.Ident { - mut ctyp := g.comptime.get_comptime_var_type(call_arg.expr.right) - if ctyp != ast.void_type { - comptime_args[k] = ctyp - if param_typ.nr_muls() > 0 && comptime_args[k].nr_muls() > 0 { - comptime_args[k] = comptime_args[k].set_nr_muls(0) + if g.comptime.get_ct_type_var(call_arg.expr.right) != .generic_var { + mut ctyp := g.comptime.get_comptime_var_type(call_arg.expr.right) + if ctyp != ast.void_type { + comptime_args[k] = ctyp + if param_typ.nr_muls() > 0 && comptime_args[k].nr_muls() > 0 { + comptime_args[k] = comptime_args[k].set_nr_muls(0) + } } } } diff --git a/vlib/v/gen/c/for.v b/vlib/v/gen/c/for.v index 312d7344ca8abd..ee520b523efe5c 100644 --- a/vlib/v/gen/c/for.v +++ b/vlib/v/gen/c/for.v @@ -450,9 +450,13 @@ fn (mut g Gen) for_in_stmt(node_ ast.ForInStmt) { g.writeln('${t_expr});') g.writeln('\tif (${t_var}.state != 0) break;') val := if node.val_var in ['', '_'] { g.new_tmp_var() } else { node.val_var } - val_styp := g.typ(node.val_type) + val_styp := g.typ(ret_typ.clear_option_and_result()) if node.val_is_mut { - g.writeln('\t${val_styp} ${val} = (${val_styp})${t_var}.data;') + if ret_typ.has_flag(.option) { + g.writeln('\t${val_styp}* ${val} = (${val_styp}*)${t_var}.data;') + } else { + g.writeln('\t${val_styp} ${val} = (${val_styp})${t_var}.data;') + } } else { g.writeln('\t${val_styp} ${val} = *(${val_styp}*)${t_var}.data;') } diff --git a/vlib/v/gen/c/str_intp.v b/vlib/v/gen/c/str_intp.v index bf74ef2be66fae..1a0829b3f4280a 100644 --- a/vlib/v/gen/c/str_intp.v +++ b/vlib/v/gen/c/str_intp.v @@ -192,9 +192,7 @@ fn (mut g Gen) str_val(node ast.StringInterLiteral, i int, fmts []u8) { mut exp_typ := typ if expr is ast.Ident { if expr.obj is ast.Var { - if g.comptime.type_map.len > 0 || g.comptime.comptime_for_method.len > 0 { - exp_typ = expr.obj.typ - } else if expr.obj.smartcasts.len > 0 { + if expr.obj.smartcasts.len > 0 { exp_typ = g.unwrap_generic(expr.obj.smartcasts.last()) cast_sym := g.table.sym(exp_typ) if cast_sym.info is ast.Aggregate { diff --git a/vlib/v/tests/generic_return_test.v b/vlib/v/tests/generic_return_test.v new file mode 100644 index 00000000000000..f2610d388ed34f --- /dev/null +++ b/vlib/v/tests/generic_return_test.v @@ -0,0 +1,43 @@ +fn mkey[K, V](m map[K]V) K { + return K{} +} + +fn mvalue[K, V](m map[K]V) V { + return V{} +} + +fn aelem[E](a []E) E { + return E{} +} + +fn g[T](x T) { + $if T is $map { + dk := mkey(x) + dv := mvalue(x) + eprintln('default k: `${dk}` | typeof dk: ${typeof(dk).name}') + eprintln('default v: `${dv}` | typeof dv: ${typeof(dv).name}') + for k, v in x { + eprintln('> k: ${k} | v: ${v}') + } + } + $if T is $array { + de := aelem(x) + eprintln('default e: `${de}` | typeof de: ${typeof(de).name}') + for idx, e in x { + eprintln('> idx: ${idx} | e: ${e}') + } + } +} + +fn test_main() { + g({ + 'abc': 123 + 'def': 456 + }) + g([1, 2, 3]) + g({ + 123: 'ggg' + 456: 'hhh' + }) + g(['xyz', 'zzz']) +}