Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

patch/emcomposition/field_memory_indices #3127

Merged
merged 20 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -672,13 +672,14 @@ def _validate_params(self, request_set, target_set=None, context=None):
f"in its variable ({len(self.variable)}).")

# Ensure shape of learning_signals matches shapes of matrices for match nodes (i.e., either keys or concatenate)
key_indices = [i for i, field_type in enumerate(field_types) if field_type == 1]
for i, learning_signal in enumerate(learning_signals[:num_match_fields]):
learning_signal_shape = learning_signal.parameters.matrix._get(context).shape
if concatenate_queries:
memory_matrix_field_shape = np.array([np.concatenate(row, dtype=object).flatten()
for row in memory_matrix[:,0:num_keys]]).T.shape
else:
memory_matrix_field_shape = np.array(memory_matrix[:,i].tolist()).T.shape
memory_matrix_field_shape = np.array(memory_matrix[:,key_indices[i]].tolist()).T.shape
assert learning_signal_shape == memory_matrix_field_shape, \
f"The shape ({learning_signal_shape}) of the matrix for the Projection {learning_signal.name} " \
f"used to specify learning signal {i} of {self.name} does not match the shape " \
Expand Down
45 changes: 21 additions & 24 deletions psyneulink/library/compositions/emcomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,7 +1021,7 @@ def field_weights_setter(field_weights, owning_component=None, context=None):
raise EMCompositionError(f"The number of field_weights ({len(field_weights)}) must match the number of fields "
f"{len(owning_component.field_weights)}")
if owning_component.normalize_field_weights:
denominator = np.sum(np.where(field_weights is not None, field_weights, 0))
denominator = np.sum(np.where(field_weights is not None, field_weights, 0)) or 1
field_weights = [fw / denominator if fw is not None else None for fw in field_weights]

# Assign new fields_weights to default_variable of field_weight_nodes
Expand Down Expand Up @@ -1898,10 +1898,16 @@ def _validate_memory_specs(self, memory_template, memory_capacity, memory_fill,
if all([fw is None for fw in _field_wts]):
raise EMCompositionError(f"The entries in 'field_weights' arg for {name} can't all be 'None' "
f"since that will preclude the construction of any keys.")
if all([fw in {0, None} for fw in _field_wts]):
warnings.warn(f"All of the entries in the 'field_weights' arg for {name} are either None or "
f"set to 0; this will result in no retrievals unless/until the 0(s) is(are) changed "
f"to a positive value.")

if not any(_field_wts):
warnings.warn(f"All of the entries in the 'field_weights' arg for {name} "
f"are either None or set to 0; this will result in no retrievals "
f"unless/until one or more of them are changed to a positive value.")

elif any([fw == 0 for fw in _field_wts if fw is not None]):
warnings.warn(f"Some of the entries in the 'field_weights' arg for {name} "
f"are set to 0; those fields will be ignored during retrieval "
f"unless/until they are changed to a positive value.")

# If field_names has more than one value it must match the first dimension (axis 0) of memory_template:
if field_names and len(field_names) != num_fields:
Expand Down Expand Up @@ -2058,7 +2064,7 @@ def _parse_fields_dict(name, fields, num_fields)->(list,list,list,list):

if normalize_field_weights and not all([fw == 0 for fw in field_weights]): # noqa: E127
fld_wts_0s_for_Nones = [fw if fw is not None else 0 for fw in field_weights]
parsed_field_weights = fld_wts_0s_for_Nones / np.sum(fld_wts_0s_for_Nones)
parsed_field_weights = list(np.array(fld_wts_0s_for_Nones) / (np.sum(fld_wts_0s_for_Nones) or 1))
parsed_field_weights = [pfw if fw is not None else None
for pfw, fw in zip(parsed_field_weights, field_weights)]
else:
Expand Down Expand Up @@ -2380,13 +2386,14 @@ def _construct_match_nodes(self, memory_template, memory_capacity, concatenate_q
else:
# Assign each key Field its own match_node and "memory" Projection to it
for i in range(self.num_keys):
field = self.fields[self.key_indices[i]]
memory_projection = MappingProjection(name=f'MEMORY for {self.key_names[i]} [KEY]',
sender=self.query_input_nodes[i].output_port,
matrix = np.array(
memory_template[:,i].tolist()).transpose().astype(float),
function=MatrixTransform(operation=args[i][OPERATION],
normalize=args[i][NORMALIZE]))
key_idx = self.key_indices[i]
field = self.fields[key_idx]
memory_projection = (
MappingProjection(name=f'MEMORY for {self.key_names[i]} [KEY]',
sender=self.query_input_nodes[i].output_port,
matrix = np.array(memory_template[:,key_idx].tolist()).transpose().astype(float),
function=MatrixTransform(operation=args[key_idx][OPERATION],
normalize=args[key_idx][NORMALIZE])))
field.match_node = (ProcessingMechanism(name=self.key_names[i] + MATCH_TO_KEYS_AFFIX,
input_ports= {INPUT_SHAPES:memory_capacity,
PROJECTIONS: memory_projection}))
Expand Down Expand Up @@ -2506,24 +2513,14 @@ def _construct_softmax_node(self, memory_capacity, softmax_gain, softmax_thresho
def _construct_retrieved_nodes(self, memory_template)->list:
"""Create nodes that report the value field(s) for the item(s) matched in memory.
"""
key_idx = 0
value_idx = 0
for field in self.fields:
# FIX: 11/24/24 - REFACTOR TO USE memory_template[:,self.index] ONCE MEMORY IS REFACTORED BASED ON FIELDS
if field.type == FieldType.KEY:
matrix = memory_template[:,key_idx]
key_idx += 1
else:
matrix = memory_template[:,self.num_keys + value_idx]
key_idx += 1

field.retrieved_node = (
ProcessingMechanism(name=field.name + RETRIEVED_AFFIX,
input_ports={INPUT_SHAPES: len(field.input_node.variable[0]),
PROJECTIONS:
MappingProjection(
sender=self.softmax_node,
matrix=matrix,
matrix=memory_template[:,field.index],
name=f'MEMORY FOR {field.name} '
f'[RETRIEVE {field.type.name}]')}))
field.retrieve_projection = field.retrieved_node.path_afferents[0]
Expand Down
47 changes: 37 additions & 10 deletions tests/composition/test_emcomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,17 +434,44 @@ def test_field_args_and_map_assignments(self,
assert em._field_index_map[[k for k in em._field_index_map.keys()
if 'WEIGHT to WEIGHTED MATCH for KEY B' in k.name][0]] == 2

def test_field_weights_all_None_and_or_0(self):
with pytest.raises(EMCompositionError) as error_text:
EMComposition(memory_template=(3,1), memory_capacity=1, field_weights=[None, None, None])
assert error_text.value.error_value == (f"The entries in 'field_weights' arg for EM_Composition can't all "
f"be 'None' since that will preclude the construction of any keys.")
@pytest.mark.parametrize('field_weight_1', ([None], [0], [1]), ids=['None', '0', '1'])
@pytest.mark.parametrize('field_weight_2', ([None], [0], [1]), ids=['None', '0', '1'])
@pytest.mark.parametrize('field_weight_3', ([None], [0], [1]), ids=['None', '0', '1'])
def test_order_fields_in_memory(self, field_weight_1, field_weight_2, field_weight_3):
"""Test that order of keys and values doesn't matter"""

with pytest.warns(UserWarning) as warning:
EMComposition(memory_template=(3,1), memory_capacity=1, field_weights=[0, None, 0])
warning_msg = (f"All of the entries in the 'field_weights' arg for EM_Composition are either None or set to 0; "
f"this will result in no retrievals unless/until the 0(s) is(are) changed to a positive value.")
assert warning_msg in str(warning[0].message)
# pytest.skip(<UNECESSARY TESTS>>)

def construct_em(field_weights):
return pnl.EMComposition(memory_template=[[[5,0], [5], [5,0,3]], [[20,0], [20], [20,1,199]]],
memory_capacity=4,
field_weights=field_weights)

field_weights = field_weight_1 + field_weight_2 + field_weight_3

if all([fw is None for fw in field_weights]):
with pytest.raises(EMCompositionError) as error_text:
construct_em(field_weights)
assert ("The entries in 'field_weights' arg for EM_Composition can't all be 'None' "
"since that will preclude the construction of any keys." in str(error_text.value))

elif not any(field_weights):
with pytest.warns(UserWarning) as warning:
construct_em(field_weights)
warning_msg = ("All of the entries in the 'field_weights' arg for EM_Composition "
"are either None or set to 0; this will result in no retrievals "
"unless/until one or more of them are changed to a positive value.")
assert warning_msg in str(warning[0].message)

elif any([fw == 0 for fw in field_weights]):
with pytest.warns(UserWarning) as warning:
construct_em(field_weights)
warning_msg = ("Some of the entries in the 'field_weights' arg for EM_Composition are set to 0; those "
"fields will be ignored during retrieval unless/until they are changed to a positive value.")
assert warning_msg in str(warning[0].message)

else:
construct_em(field_weights)


@pytest.mark.pytorch
Expand Down
Loading