Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix for auto transforms #604

Merged
merged 11 commits into from
Nov 7, 2024
2 changes: 1 addition & 1 deletion scripts/metavar.py
Original file line number Diff line number Diff line change
Expand Up @@ -1634,7 +1634,7 @@ def add_variable(self, newvar, run_env, exists_ok=False, gen_unique=False,
# end if
if cvar is not None:
compat = cvar.compatible(newvar, run_env)
if compat:
if compat.compat:
# Check for intent mismatch
vintent = cvar.get_prop_value('intent')
dintent = newvar.get_prop_value('intent')
Expand Down
132 changes: 102 additions & 30 deletions scripts/suite_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,14 @@ def add_variable(self, newvar, run_env, exists_ok=False, gen_unique=False,
super().add_variable(newvar, run_env, exists_ok=exists_ok,
gen_unique=gen_unique, adjust_intent=adjust_intent)

def call_string(self, cldicts=None, is_func_call=False, subname=None):
def call_string(self, cldicts=None, is_func_call=False, subname=None, sub_lname_list=None):
"""Return a dummy argument string for this call list.
<cldict> may be a list of VarDictionary objects to search for
local_names (default is to use self).
<is_func_call> should be set to True to construct a call statement.
If <is_func_call> is False, construct a subroutine dummy argument
list.
<sub_lname_list> may be a list of local_name substitutions.
"""
arg_str = ""
arg_sep = ""
Expand Down Expand Up @@ -157,6 +158,17 @@ def call_string(self, cldicts=None, is_func_call=False, subname=None):
lname = dummy
# end if
# end if
# Modify Scheme call_list to handle local_name change for this var.
dustinswales marked this conversation as resolved.
Show resolved Hide resolved
# Are there any variable transforms for this scheme?
# If so, change Var's local_name need to local dummy array containing
dustinswales marked this conversation as resolved.
Show resolved Hide resolved
# transformed argument, var_trans_local.
if sub_lname_list:
for (var_trans_local, var_lname, sname, rindices, lindices, compat_obj) in sub_lname_list:
if (sname == stdname):
lname = var_trans_local
# end if
# end for
# end if
if is_func_call:
if cldicts is not None:
use_dicts = cldicts
Expand Down Expand Up @@ -898,9 +910,6 @@ def match_variable(self, var, run_env):
new_dict_dims = dict_dims
match = True
# end if
# Create compatability object, containing any necessary forward/reverse
# transforms from <var> and <dict_var>
compat_obj = var.compatible(dict_var, run_env)
# If variable is defined as "inactive" by the host, ensure that
# this variable is declared as "optional" by the scheme. If
# not satisfied, return error.
Expand Down Expand Up @@ -933,6 +942,14 @@ def match_variable(self, var, run_env):
# end if
# end if
# end if
# We have a match!
# Are the Scheme's <var> and Host's <dict_var> compatible?
dustinswales marked this conversation as resolved.
Show resolved Hide resolved
# If so, create compatibility object, containing any necessary
# forward/reverse transforms to/from <var> and <dict_var>.
if dict_var is not None:
dict_var = self.parent.find_variable(source_var=var, any_scope=True)
compat_obj = var.compatible(dict_var, run_env)
# end if
return found_var, dict_var, var_vdim, new_vdims, missing_vert, compat_obj

def in_process_split(self):
Expand Down Expand Up @@ -1262,7 +1279,7 @@ def analyze(self, phase, group, scheme_library, suite_vars, level):
# end if

# Is this a conditionally allocated variable?
# If so, declare localpointer varaible. This is needed to
# If so, declare localpointer variable. This is needed to
# pass inactive (not present) status through the caps.
if var.get_prop_value('optional'):
newvar_ptr = var.clone(var.get_prop_value('local_name')+'_ptr')
Expand Down Expand Up @@ -1611,9 +1628,25 @@ def add_var_transform(self, var, compat_obj, vert_dim):
from <var> to perform the transformation. Determine the indices needed
for the transform and save for use during write stage"""

# Add dummy variable (<var>_local) needed for transformation.
dummy = var.clone(var.get_prop_value('local_name')+'_local')
self.__group.manage_variable(dummy)
# Add local variable (<var(local_name)>_local) needed for transformation.
# Do not let the Group manage this variable. Handle local var
dustinswales marked this conversation as resolved.
Show resolved Hide resolved
# when writing Group.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# when writing Group.
# when writing group.

prop_dict = var.copy_prop_dict()
prop_dict['local_name'] = var.get_prop_value('local_name')+'_local'
# This is a local variable.
if 'intent' in prop_dict:
del prop_dict['intent']
# end if
local_trans_var = Var(prop_dict,
ParseSource(_API_SOURCE_NAME,
_API_LOCAL_VAR_NAME, var.context),
self.run_env)
found = self.__group.find_variable(source_var=local_trans_var, any_scope=False)
if not found:
lmsg = "Adding new local variable, '{}', for variable transform"
self.run_env.logger.info(lmsg.format(local_trans_var.get_prop_value('local_name')))
self.__group.transform_locals.append(local_trans_var)
# end if

# Create indices (default) for transform.
lindices = [':']*var.get_rank()
Expand Down Expand Up @@ -1652,26 +1685,35 @@ def add_var_transform(self, var, compat_obj, vert_dim):
#hdim = find_horizontal_dimension(var.get_dimensions())
#if compat_obj.has_dim_transforms:

#
# Register any reverse (pre-Scheme) transforms.
#
# Register any reverse (pre-Scheme) transforms. Also, save local_name used in
dustinswales marked this conversation as resolved.
Show resolved Hide resolved
# transform (used in write stage).
if (var.get_prop_value('intent') != 'out'):
self.__reverse_transforms.append([dummy.get_prop_value('local_name'),
lmsg = "Automatic unit conversion from '{}' to '{}' for '{}' before entering '{}'"
self.run_env.logger.info(lmsg.format(compat_obj.v2_units,
compat_obj.v1_units,
compat_obj.v2_stdname,
compat_obj.v1_stdname))
self.__reverse_transforms.append([local_trans_var.get_prop_value('local_name'),
var.get_prop_value('local_name'),
var.get_prop_value('standard_name'),
rindices, lindices, compat_obj])

#
# end if
# Register any forward (post-Scheme) transforms.
#
if (var.get_prop_value('intent') != 'in'):
lmsg = "Automatic unit conversion from '{}' to '{}' for '{}' after returning '{}'"
self.run_env.logger.info(lmsg.format(compat_obj.v1_units,
compat_obj.v2_units,
compat_obj.v1_stdname,
compat_obj.v2_stdname))
self.__forward_transforms.append([var.get_prop_value('local_name'),
dummy.get_prop_value('local_name'),
var.get_prop_value('standard_name'),
local_trans_var.get_prop_value('local_name'),
lindices, rindices, compat_obj])

# end if
def write_var_transform(self, var, dummy, rindices, lindices, compat_obj,
outfile, indent, forward):
"""Write variable transformation needed to call this Scheme in <outfile>.
<var> is the varaible that needs transformation before and after calling Scheme.
<var> is the variable that needs transformation before and after calling Scheme.
<dummy> is the local variable needed for the transformation..
<lindices> are the LHS indices of <dummy> for reverse transforms (before Scheme).
<rindices> are the RHS indices of <var> for reverse transforms (before Scheme).
Expand Down Expand Up @@ -1709,7 +1751,8 @@ def write(self, outfile, errcode, errmsg, indent):
cldicts.extend(self.__group.suite_dicts())
my_args = self.call_list.call_string(cldicts=cldicts,
is_func_call=True,
subname=self.subroutine_name)
subname=self.subroutine_name,
sub_lname_list = self.__reverse_transforms)
#
outfile.write('', indent)
outfile.write('if ({} == 0) then'.format(errcode), indent)
Expand Down Expand Up @@ -1737,8 +1780,15 @@ def write(self, outfile, errcode, errmsg, indent):
if len(self.__reverse_transforms) > 0:
outfile.comment('Compute reverse (pre-scheme) transforms', indent+1)
# end if
for (dummy, var, rindices, lindices, compat_obj) in self.__reverse_transforms:
tstmt = self.write_var_transform(var, dummy, rindices, lindices, compat_obj, outfile, indent+1, False)
for rcnt, (dummy, var_lname, var_sname, rindices, lindices, compat_obj) in enumerate(self.__reverse_transforms):
# Any transform(s) were added during the Group's analyze phase, but
dustinswales marked this conversation as resolved.
Show resolved Hide resolved
# the local_name(s) of the <var> assoicated with the transform(s)
# may have since changed. Here we need to use the standard_name
# from <var> and replace its local_name with the local_name from the
# Group's call_list.
climbfuji marked this conversation as resolved.
Show resolved Hide resolved
lvar = self.__group.call_list.find_variable(standard_name=var_sname)
lvar_lname = lvar.get_prop_value('local_name')
tstmt = self.write_var_transform(lvar_lname, dummy, rindices, lindices, compat_obj, outfile, indent+1, False)
# end for
outfile.write('',indent+1)
#
Expand Down Expand Up @@ -1778,8 +1828,15 @@ def write(self, outfile, errcode, errmsg, indent):
if len(self.__forward_transforms) > 0:
outfile.comment('Compute forward (post-scheme) transforms', indent+1)
# end if
for (var, dummy, lindices, rindices, compat_obj) in self.__forward_transforms:
tstmt = self.write_var_transform(var, dummy, rindices, lindices, compat_obj, outfile, indent+1, True)
for fcnt, (var_lname, var_sname, dummy, lindices, rindices, compat_obj) in enumerate(self.__forward_transforms):
# Any transform(s) were added during the Group's analyze phase, but
dustinswales marked this conversation as resolved.
Show resolved Hide resolved
# the local_name(s) of the <var> assoicated with the transform(s)
# may have since changed. Here we need to use the standard_name
# from <var> and replace its local_name with the local_name from the
# Group's call_list.
lvar = self.__group.call_list.find_variable(standard_name=var_sname)
lvar_lname = lvar.get_prop_value('local_name')
tstmt = self.write_var_transform(lvar_lname, dummy, rindices, lindices, compat_obj, outfile, indent+1, True)
# end for
outfile.write('', indent)
outfile.write('end if', indent)
Expand Down Expand Up @@ -2125,6 +2182,7 @@ def __init__(self, group_xml, transition, parent, context, run_env):
self._phase_check_stmts = list()
self._set_state = None
self._ddt_library = None
self.transform_locals = list()

def phase_match(self, scheme_name):
"""If scheme_name matches the group phase, return the group and
Expand Down Expand Up @@ -2391,6 +2449,19 @@ def write(self, outfile, host_arglist, indent, const_mod,
# end if
pointer_var_set.append([name,kind,dimstr,vtype])
# end for
# Any arguments used in variable transforms before or after the
# Scheme call? If so, declare local copy for reuse in the Group cap.
dustinswales marked this conversation as resolved.
Show resolved Hide resolved
for ivar in self.transform_locals:
lname = ivar.get_prop_value('local_name')
climbfuji marked this conversation as resolved.
Show resolved Hide resolved
opt_var = ivar.get_prop_value('optional')
dims = ivar.get_dimensions()
if (dims is not None) and dims:
subpart_allocate_vars[lname] = (ivar, item, opt_var)
allocatable_var_set.add(lname)
else:
subpart_scalar_vars[lname] = (ivar, item, opt_var)
# end if
# end for

# end for
# First, write out the subroutine header
Expand Down Expand Up @@ -2499,6 +2570,14 @@ def write(self, outfile, host_arglist, indent, const_mod,
self._phase_check_stmts.write(outfile, indent,
{'errcode' : errcode, 'errmsg' : errmsg,
'funcname' : self.name})
# Write any loop match calculations
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this was a bug. This needed to be need to BEFORE the allocation.

outfile.write("! Set horizontal loop extent",indent+1)
for vmatch in self._loop_var_matches:
action = vmatch.write_action(self, dict2=self.call_list)
if action:
outfile.write(action, indent+1)
# end if
# end for
# Allocate local arrays
outfile.write('\n! Allocate local arrays', indent+1)
alloc_stmt = "allocate({}({}))"
Expand Down Expand Up @@ -2530,13 +2609,6 @@ def write(self, outfile, host_arglist, indent, const_mod,
# end if dims (do not allocate scalars)
# end for
# end if
# Write any loop match calculations
for vmatch in self._loop_var_matches:
action = vmatch.write_action(self, dict2=self.call_list)
if action:
outfile.write(action, indent+1)
# end if
# end for
# Write the scheme and subcycle calls
for item in self.parts:
item.write(outfile, errcode, errmsg, indent + 1)
Expand Down
6 changes: 5 additions & 1 deletion scripts/var_props.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,10 @@ def __init__(self, var1_stdname, var1_type, var1_kind, var1_units,
self.__v2_context = v2_context
self.__v1_kind = var1_kind
self.__v2_kind = var2_kind
self.v1_units = var1_units
climbfuji marked this conversation as resolved.
Show resolved Hide resolved
self.v2_units = var2_units
self.v1_stdname = var1_stdname
self.v2_stdname = var2_stdname
# Default (null) transform information
self.__dim_transforms = None
self.__kind_transforms = None
Expand Down Expand Up @@ -966,8 +970,8 @@ def __init__(self, var1_stdname, var1_type, var1_kind, var1_units,
incompat_reason.append(emsg)
# end if
elif var1_units != var2_units:
self.__equiv = False
climbfuji marked this conversation as resolved.
Show resolved Hide resolved
# Try to find a set of unit conversions
self.__equiv = False
self.__unit_transforms = self._get_unit_convstrs(var1_units,
var2_units)
# end if
Expand Down
4 changes: 2 additions & 2 deletions test/unit_tests/test_var_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def test_valid_unit_change(self):
compat = real_scalar1.compatible(real_scalar2, self.__run_env)
self.assertIsInstance(compat, VarCompatObj,
msg=self.__inst_emsg.format(type(compat)))
self.assertFalse(compat)
self.assertFalse(compat.equiv)
self.assertTrue(compat.compat)
self.assertEqual(compat.incompat_reason, '')
self.assertFalse(compat.has_kind_transforms)
Expand All @@ -150,7 +150,7 @@ def test_valid_unit_change(self):
compat = real_scalar1.compatible(real_scalar2, self.__run_env)
self.assertIsInstance(compat, VarCompatObj,
msg=self.__inst_emsg.format(type(compat)))
self.assertFalse(compat)
self.assertFalse(compat.equiv)
self.assertTrue(compat.compat)
self.assertEqual(compat.incompat_reason, '')
self.assertFalse(compat.has_kind_transforms)
Expand Down
42 changes: 42 additions & 0 deletions test/var_compatibility_test/effr_diag.F90
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
!Test unit conversions for intent in, inout, out variables
!

module effr_diag

use ccpp_kinds, only: kind_phys

implicit none
private

public :: effr_diag_run

contains

!> \section arg_table_effr_diag_run Argument Table
!! \htmlinclude arg_table_effr_diag_run.html
!!
subroutine effr_diag_run( effrr_in, errmsg, errflg)

real(kind_phys), intent(in) :: effrr_in(:,:)
character(len=512), intent(out) :: errmsg
integer, intent(out) :: errflg
!----------------------------------------------------------------
real(kind_phys) :: effrr_min, effrr_max

errmsg = ''
errflg = 0

call cmp_effr_diag(effrr_in, effrr_min, effrr_max)

end subroutine effr_diag_run

subroutine cmp_effr_diag(effr, effr_min, effr_max)
real(kind_phys), intent(in) :: effr(:,:)
real(kind_phys), intent(out) :: effr_min, effr_max

! Do some diagnostic calcualtions...
effr_min = minval(effr)
effr_max = maxval(effr)

end subroutine cmp_effr_diag
end module effr_diag
31 changes: 31 additions & 0 deletions test/var_compatibility_test/effr_diag.meta
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
[ccpp-table-properties]
name = effr_diag
type = scheme
dependencies =
[ccpp-arg-table]
name = effr_diag_run
type = scheme
[effrr_in]
standard_name = effective_radius_of_stratiform_cloud_rain_particle
long_name = effective radius of cloud rain particle in micrometer
units = um
dimensions = (horizontal_loop_extent,vertical_layer_dimension)
type = real
kind = kind_phys
intent = in
top_at_one = True
[ errmsg ]
standard_name = ccpp_error_message
long_name = Error message for error handling in CCPP
units = none
dimensions = ()
type = character
kind = len=512
intent = out
[ errflg ]
standard_name = ccpp_error_code
long_name = Error flag for error handling in CCPP
units = 1
dimensions = ()
type = integer
intent = out
Loading
Loading