Skip to content

Commit

Permalink
-
Browse files Browse the repository at this point in the history
  • Loading branch information
jdcpni committed Nov 25, 2024
1 parent bc40c3d commit 0c78c45
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 53 deletions.
2 changes: 1 addition & 1 deletion psyneulink/library/compositions/emcomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -2064,7 +2064,7 @@ def _parse_fields_dict(name, fields, num_fields)->(list,list,list,list):

if normalize_field_weights and not all([fw == 0 for fw in field_weights]): # noqa: E127
fld_wts_0s_for_Nones = [fw if fw is not None else 0 for fw in field_weights]
parsed_field_weights = fld_wts_0s_for_Nones / np.sum(fld_wts_0s_for_Nones)
parsed_field_weights = list(np.array(fld_wts_0s_for_Nones) / (np.sum(fld_wts_0s_for_Nones) or 1))
parsed_field_weights = [pfw if fw is not None else None
for pfw, fw in zip(parsed_field_weights, field_weights)]
else:
Expand Down
89 changes: 37 additions & 52 deletions tests/composition/test_emcomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,17 +434,44 @@ def test_field_args_and_map_assignments(self,
assert em._field_index_map[[k for k in em._field_index_map.keys()
if 'WEIGHT to WEIGHTED MATCH for KEY B' in k.name][0]] == 2

def test_field_weights_all_None_and_or_0(self):
with pytest.raises(EMCompositionError) as error_text:
EMComposition(memory_template=(3,1), memory_capacity=1, field_weights=[None, None, None])
assert error_text.value.error_value == (f"The entries in 'field_weights' arg for EM_Composition can't all "
f"be 'None' since that will preclude the construction of any keys.")
@pytest.mark.parametrize('field_weight_1', ([None], [0], [1]), ids=['None', '0', '1'])
@pytest.mark.parametrize('field_weight_2', ([None], [0], [1]), ids=['None', '0', '1'])
@pytest.mark.parametrize('field_weight_3', ([None], [0], [1]), ids=['None', '0', '1'])
def test_order_fields_in_memory(self, field_weight_1, field_weight_2, field_weight_3):
"""Test that order of keys and values doesn't matter"""

with pytest.warns(UserWarning) as warning:
EMComposition(memory_template=(3,1), memory_capacity=1, field_weights=[0, None, 0])
warning_msg = (f"All of the entries in the 'field_weights' arg for EM_Composition are either None or set to 0; "
f"this will result in no retrievals unless/until the 0(s) is(are) changed to a positive value.")
assert warning_msg in str(warning[0].message)
# pytest.skip(<UNECESSARY TESTS>>)

def construct_em(field_weights):
return pnl.EMComposition(memory_template=[[[5,0], [5], [5,0,3]], [[20,0], [20], [20,1,199]]],
memory_capacity=4,
field_weights=field_weights)

field_weights = field_weight_1 + field_weight_2 + field_weight_3

if all([fw is None for fw in field_weights]):
with pytest.raises(EMCompositionError) as error_text:
construct_em(field_weights)
assert ("The entries in 'field_weights' arg for EM_Composition can't all be 'None' "
"since that will preclude the construction of any keys." in str(error_text.value))

elif not any(field_weights):
with pytest.warns(UserWarning) as warning:
construct_em(field_weights)
warning_msg = ("All of the entries in the 'field_weights' arg for EM_Composition "
"are either None or set to 0; this will result in no retrievals "
"unless/until one or more of them are changed to a positive value.")
assert warning_msg in str(warning[0].message)

elif any([fw == 0 for fw in field_weights]):
with pytest.warns(UserWarning) as warning:
construct_em(field_weights)
warning_msg = ("Some of the entries in the 'field_weights' arg for EM_Composition are set to 0; those "
"fields will be ignored during retrieval unless/until they are changed to a positive value.")
assert warning_msg in str(warning[0].message)

else:
construct_em(field_weights)


@pytest.mark.pytorch
Expand Down Expand Up @@ -1057,45 +1084,3 @@ def test_backpropagation_of_error_in_learning(self):
# axes[2].set_ylabel('Correct Logit')
# plt.suptitle(f"Blocked Training")
# plt.show()

@pytest.mark.parametrize('field_weight_1', ([None], [0], [1]), ids=['None', '0', '1'])
@pytest.mark.parametrize('field_weight_2', ([None], [0], [1]), ids=['None', '0', '1'])
@pytest.mark.parametrize('field_weight_3', ([None], [0], [1]), ids=['None', '0', '1'])
@pytest.mark.composition
def test_order_fields_in_memory(self, field_weight_1, field_weight_2, field_weight_3):
"""Test that order of keys and values doesn't matter"""

# pytest.skip('All field weights are None')

def construct_em(field_weights):
return pnl.EMComposition(memory_template=[[[5,0], [5], [5,0,3]], [[20,0], [20], [20,1,199]]],
memory_capacity=4,
field_weights=field_weights)

field_weights = field_weight_1 + field_weight_2 + field_weight_3

if all([fw is None for fw in field_weights]):
with pytest.raises(EMCompositionError) as error_text:
construct_em(field_weights)
assert ("The entries in 'field_weights' arg for EM_Composition can't all be 'None' "
"since that will preclude the construction of any keys." in str(error_text.value))

elif not any(field_weights):
with pytest.warns(UserWarning) as warning:
construct_em(field_weights)
warning_msg = ("All of the entries in the 'field_weights' arg for EM_Composition "
"are either None or set to 0; this will result in no retrievals "
"unless/until one or more of them are changed to a positive value.")
assert warning_msg in str(warning[0].message)

elif any([fw == 0 for fw in field_weights]):
with pytest.warns(UserWarning) as warning:
construct_em(field_weights)
warning_msg = ("Some of the entries in the 'field_weights' arg for EM_Composition are set to 0; those "
"fields will be ignored during retrieval unless/until they are changed to a positive value.")
assert warning_msg in str(warning[0].message)

else:
construct_em(field_weights)


0 comments on commit 0c78c45

Please sign in to comment.