model.fit - class_weight broken #20542

Open
GICodeWarrior opened this issue Nov 23, 2024 · 1 comment

@GICodeWarrior

It seems `tf.argmax` returns `dtype=int64` in the true branch, while the false branch returns `int32`, so the two branches of this conditional in `class_weights_map_fn` do not match:

y_classes = tf.__internal__.smart_cond.smart_cond(
    tf.shape(y)[-1] > 1,
    lambda: tf.argmax(y, axis=-1),
    lambda: tf.cast(tf.round(tf.squeeze(y, axis=-1)), tf.int32),
)
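The two branch dtypes can be checked in isolation, outside `model.fit`. A minimal sketch assuming TensorFlow 2.x; the label tensors `y_onehot` and `y_binary` are illustrative and not taken from the report:

```python
import tensorflow as tf

# Illustrative labels (not from the report).
# One-hot labels (last dim > 1) would take the argmax branch.
y_onehot = tf.constant([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
# Single-column labels (last dim == 1) would take the round/squeeze branch.
y_binary = tf.constant([[0.0], [1.0]])

# Same expressions as the two lambdas above, evaluated eagerly.
true_branch = tf.argmax(y_onehot, axis=-1)  # output_type defaults to tf.int64
false_branch = tf.cast(tf.round(tf.squeeze(y_binary, axis=-1)), tf.int32)

print(true_branch.dtype, false_branch.dtype)
```

`tf.argmax` defaults to `output_type=tf.int64`, so the branches disagree unless one of them is cast.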

Stacktrace:

Traceback (most recent call last):
  File "/home/example/workspace/fir/trainer/train.py", line 122, in <module>
    model.fit(
  File "/home/example/.local/share/virtualenvs/trainer-gT8lgKB3/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 113, in error_handler
    return fn(*args, **kwargs)
  File "/home/example/.local/share/virtualenvs/trainer-gT8lgKB3/lib/python3.10/site-packages/keras/src/backend/tensorflow/trainer.py", line 282, in fit
    epoch_iterator = TFEpochIterator(
  File "/home/example/.local/share/virtualenvs/trainer-gT8lgKB3/lib/python3.10/site-packages/keras/src/backend/tensorflow/trainer.py", line 664, in __init__
    super().__init__(*args, **kwargs)
  File "/home/example/.local/share/virtualenvs/trainer-gT8lgKB3/lib/python3.10/site-packages/keras/src/trainers/epoch_iterator.py", line 64, in __init__
    self.data_adapter = data_adapters.get_data_adapter(
  File "/home/example/.local/share/virtualenvs/trainer-gT8lgKB3/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__init__.py", line 56, in get_data_adapter
    return TFDatasetAdapter(
  File "/home/example/.local/share/virtualenvs/trainer-gT8lgKB3/lib/python3.10/site-packages/keras/src/trainers/data_adapters/tf_dataset_adapter.py", line 30, in __init__
    dataset = dataset.map(
  File "/home/example/.local/share/virtualenvs/trainer-gT8lgKB3/lib/python3.10/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 2341, in map
    return map_op._map_v2(
  File "/home/example/.local/share/virtualenvs/trainer-gT8lgKB3/lib/python3.10/site-packages/tensorflow/python/data/ops/map_op.py", line 43, in _map_v2
    return _MapDataset(
  File "/home/example/.local/share/virtualenvs/trainer-gT8lgKB3/lib/python3.10/site-packages/tensorflow/python/data/ops/map_op.py", line 157, in __init__
    self._map_func = structured_function.StructuredFunctionWrapper(
  File "/home/example/.local/share/virtualenvs/trainer-gT8lgKB3/lib/python3.10/site-packages/tensorflow/python/data/ops/structured_function.py", line 265, in __init__
    self._function = fn_factory()
  File "/home/example/.local/share/virtualenvs/trainer-gT8lgKB3/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 1251, in get_concrete_function
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
  File "/home/example/.local/share/virtualenvs/trainer-gT8lgKB3/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 1221, in _get_concrete_function_garbage_collected
    self._initialize(args, kwargs, add_initializers_to=initializers)
  File "/home/example/.local/share/virtualenvs/trainer-gT8lgKB3/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 696, in _initialize
    self._concrete_variable_creation_fn = tracing_compilation.trace_function(
  File "/home/example/.local/share/virtualenvs/trainer-gT8lgKB3/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 178, in trace_function
    concrete_function = _maybe_define_function(
  File "/home/example/.local/share/virtualenvs/trainer-gT8lgKB3/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 283, in _maybe_define_function
    concrete_function = _create_concrete_function(
  File "/home/example/.local/share/virtualenvs/trainer-gT8lgKB3/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 310, in _create_concrete_function
    traced_func_graph = func_graph_module.func_graph_from_py_func(
  File "/home/example/.local/share/virtualenvs/trainer-gT8lgKB3/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py", line 1059, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/example/.local/share/virtualenvs/trainer-gT8lgKB3/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 599, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/example/.local/share/virtualenvs/trainer-gT8lgKB3/lib/python3.10/site-packages/tensorflow/python/data/ops/structured_function.py", line 231, in wrapped_fn
    ret = wrapper_helper(*args)
  File "/home/example/.local/share/virtualenvs/trainer-gT8lgKB3/lib/python3.10/site-packages/tensorflow/python/data/ops/structured_function.py", line 161, in wrapper_helper
    ret = autograph.tf_convert(self._func, ag_ctx)(*nested_args)
  File "/home/example/.local/share/virtualenvs/trainer-gT8lgKB3/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 690, in wrapper
    return converted_call(f, args, kwargs, options=options)
  File "/home/example/.local/share/virtualenvs/trainer-gT8lgKB3/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 377, in converted_call
    return _call_unconverted(f, args, kwargs, options)
  File "/home/example/.local/share/virtualenvs/trainer-gT8lgKB3/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 459, in _call_unconverted
    return f(*args, **kwargs)
  File "/home/example/.local/share/virtualenvs/trainer-gT8lgKB3/lib/python3.10/site-packages/keras/src/trainers/data_adapters/tf_dataset_adapter.py", line 129, in class_weights_map_fn
    y_classes = tf.__internal__.smart_cond.smart_cond(
  File "/home/example/.local/share/virtualenvs/trainer-gT8lgKB3/lib/python3.10/site-packages/tensorflow/python/framework/smart_cond.py", line 57, in smart_cond
    return cond.cond(pred, true_fn=true_fn, false_fn=false_fn,
  File "/home/example/.local/share/virtualenvs/trainer-gT8lgKB3/lib/python3.10/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/example/.local/share/virtualenvs/trainer-gT8lgKB3/lib/python3.10/site-packages/tensorflow/python/ops/cond_v2.py", line 880, in error
    raise TypeError(
TypeError: true_fn and false_fn arguments to tf.cond must have the same number, type, and overall structure of return values.

true_fn output: Tensor("cond/Identity:0", shape=(2048,), dtype=int64)
false_fn output: Tensor("cond/Identity:0", shape=(2048,), dtype=int32)

Error details:
Tensor("cond/Identity:0", shape=(2048,), dtype=int64) and Tensor("cond/Identity:0", shape=(2048,), dtype=int32) have different types
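Inside `dataset.map` the function is traced as a graph, so `tf.cond` traces both branches and requires their outputs to have matching dtypes. The sketch below reproduces that with a `tf.function` whose label tensor has an unknown last dimension (an assumption; the report only shows the `(2048,)` branch outputs), and shows one possible patch: passing `output_type=tf.int32` to `tf.argmax`. The function names are hypothetical, it assumes TensorFlow 2.x, and it is not necessarily how Keras resolves this upstream.

```python
import tensorflow as tf

# Hypothetical input spec: batch and class dimensions unknown at trace time,
# so tf.cond cannot pick a branch statically and traces both.
SPEC = tf.TensorSpec(shape=[None, None], dtype=tf.float32)


@tf.function(input_signature=[SPEC])
def y_classes_original(y):
    # Same branches as the snippet in the report.
    return tf.cond(
        tf.shape(y)[-1] > 1,
        lambda: tf.argmax(y, axis=-1),  # int64 by default
        lambda: tf.cast(tf.round(tf.squeeze(y, axis=-1)), tf.int32),  # int32
    )


@tf.function(input_signature=[SPEC])
def y_classes_patched(y):
    # Ask argmax for int32 so both branches return the same dtype.
    return tf.cond(
        tf.shape(y)[-1] > 1,
        lambda: tf.argmax(y, axis=-1, output_type=tf.int32),
        lambda: tf.cast(tf.round(tf.squeeze(y, axis=-1)), tf.int32),
    )


# Tracing the original version fails with the same TypeError as above.
try:
    y_classes_original(tf.zeros([2, 3]))
except TypeError as e:
    print("original:", type(e).__name__)

# The patched version handles both one-hot and single-column labels.
print(y_classes_patched(tf.constant([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]])).numpy())
print(y_classes_patched(tf.constant([[0.0], [1.0]])).numpy())
```

Wrapping the `argmax` result in `tf.cast(..., tf.int32)` would also work; `output_type` simply avoids the extra op.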
@fchollet
Member

Thanks for the report. Can you propose a unit test to reproduce this issue?
