Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CTC Decoding for JAX and Tensorflow #19366

Merged
merged 29 commits into from
Apr 20, 2024

Conversation

MaanasArora
Copy link
Contributor

@MaanasArora MaanasArora commented Mar 24, 2024

Hello,

This is WIP, but I wanted to create a PR for the changes so far as I continue to work. The PR currently contains:

  • A Keras op for decoding using greedy and beam search strategies
  • Implementation for both greedy and beam search strategies in TF
  • Implementation for greedy strategy in JAX (beam search is WIP)

I will add torch support and unit testing. Any feedback is appreciated.

Thank you!

@abhaskumarsinha
Copy link
Contributor

Hello, can you add these codes in keras-nlp? I believe it'd be more helpful to have token search strategies in Keras NLP because all the LLMs to test these features are there only.
Src: https://github.com/keras-team/keras-nlp

@mattdangerw Please check!

Copy link
Member

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the PR!

I believe it'd be more helpful to have token search strategies in Keras NLP because all the LLMs to test these features are there only.

This is an interesting suggestion -- the Sampler functionality in KerasNLP is similar. + @mattdangerw to provide a judgement call here.

merge_repeated=True,
mask_index=None,
):
if strategy == "greedy":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would need the fallback case else:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, added, thanks.

)
else:
raise ValueError(
"Invalid strategy. Supported values are 'greedy' and 'beam_search'."
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Print the strategy that was passed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed and done, thanks!


sequence_length = tf.convert_to_tensor(sequence_length)
if strategy == "greedy":
decoded, probs = tf.nn.ctc_greedy_decoder(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can just do return tf.nn.ctc_greedy_decoder(...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is done, thanks for noticing.

blank_index=mask_index,
)
elif strategy == "beam_search":
decoded, probs = tf.nn.ctc_beam_search_decoder(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done as well.

keras/ops/nn.py Outdated

if any_symbolic_tensors((inputs, sequence_lengths)):
raise NotImplementedError(
"CTC decoding is not supported in graph mode. "
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit of a misleading message. You could simply "is not supported with KerasTensors. Use it inside the call() method of a Layer or the predict_step method of a model."

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's definitely a much better and more accurate message! Just replaced it. Thank you!

@codecov-commenter
Copy link

codecov-commenter commented Mar 25, 2024

Codecov Report

Attention: Patch coverage is 93.75000% with 8 lines in your changes are missing coverage. Please review.

Project coverage is 76.36%. Comparing base (7386704) to head (25b67f5).
Report is 3 commits behind head on master.

Files Patch % Lines
keras/src/backend/jax/nn.py 96.33% 1 Missing and 3 partials ⚠️
keras/src/backend/tensorflow/nn.py 77.77% 1 Missing and 1 partial ⚠️
keras/api/_tf_keras/keras/ops/__init__.py 0.00% 1 Missing ⚠️
keras/api/_tf_keras/keras/ops/nn/__init__.py 0.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #19366      +/-   ##
==========================================
+ Coverage   76.31%   76.36%   +0.04%     
==========================================
  Files         496      496              
  Lines       44077    44325     +248     
  Branches     8170     8214      +44     
==========================================
+ Hits        33637    33847     +210     
- Misses       8745     8764      +19     
- Partials     1695     1714      +19     
Flag Coverage Δ
keras 76.20% <93.75%> (+0.04%) ⬆️
keras-jax 61.18% <88.28%> (-0.03%) ⬇️
keras-numpy 55.03% <7.81%> (-0.22%) ⬇️
keras-tensorflow 62.12% <13.28%> (-0.25%) ⬇️
keras-torch 61.02% <7.03%> (-0.27%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@gbaned gbaned added this to Assigned Reviewer in PR Queue via automation Mar 25, 2024
@mattdangerw
Copy link
Member

I won't claim to be an expert on CTC classification (so may be missing lots here). But I would probably keep this separate from sampling strategies for open ended generation in KerasNLP. Seems like there's too many differences that might rear their head given that end tasks the two are solving are different.

E.g. I see jax.lax.scan here, but I don't think that would work for generation with early stopping on a "end token". For early stopping with a condition we need a jax.lax.while_loop (or a python loop with an inner compiled step).

So given that we are adding this here in part for backward compat with keras 2 (at least I think?), probably can continue to land as is. We can always stitch KerasNLP and these Keras utils together in an example.

@fchollet
Copy link
Member

So given that we are adding this here in part for backward compat with keras 2 (at least I think?), probably can continue to land as is. We can always stitch KerasNLP and these Keras utils together in an example.

Sounds good -- let's keep them separate then.

@MaanasArora we'll be able to merge this soon! Please add unit tests.

@MaanasArora
Copy link
Contributor Author

Thank you @fchollet! I have added a unit test for correctness. I omitted the shape test to maintain consistency because ctc_decode does not support KerasTensors.

@abhaskumarsinha
Copy link
Contributor

I'm not an expert either tbh.

E.g. I see jax.lax.scan here, but I don't think that would work for generation with early stopping on a "end token". For early stopping with a condition we need a jax.lax.while_loop (or a python loop with an inner compiled step).

Ah, I see. Thank you for pointing it.

Copy link
Member

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the update. Do you also intend to add the torch implementation? I think there's a built-in utility in torch that we could use. If not, you can add a NotImplementError in the torch backend.

score_labels = np.array([[-1.2], [-1.3], [-0.4]])

(decoded,), scores = knn.ctc_decode(
inputs, sequence_lengths=[3, 3, 1], strategy="greedy"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also check the different strategies (when available) and the behavior of the different arguments for the beam search strategy.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just added a unit test for beam search when the backend is TF, thanks!

@MaanasArora
Copy link
Contributor Author

Thank you for reviewing! Yes, I do intend to add the torch implementation. The built-in implementation seems to be in torch audio rather than torch core, so I wondered if we should use it directly. If you would like, I can include both the JAX beam search implementation and the torch implementations in another PR.

@fchollet
Copy link
Member

Cool -- you can keep adding the functions to the current PR, that will make it easier to handle any potential API changes.

@fchollet
Copy link
Member

@MaanasArora are you still working on this PR?

@MaanasArora
Copy link
Contributor Author

Hi @fchollet, yes, sorry for the long delay. I have an implementation for CTC beam search in JAX but am struggling with debugging a discrepancy between TF's implementation and mine.

@fchollet
Copy link
Member

Any way we can help?

@MaanasArora
Copy link
Contributor Author

MaanasArora commented Apr 11, 2024

For sure, thank you. I have pushed the changes I have so far. Please note they currently only work for merge_repeated=True and do not yet handle the final removal of redundant tokens and zero padding. I will provide a description of my issue.

The tests for CTC beam search decode fail because of the way I am handling masked entries. In the way I am using _pad to replace mask_index on beam extensions, repetitions of the same character are not possible since there is no record at each time step of whether a path ends with mask_index. However, when we do use mask_index appropriately to extend sequences, at least once, we cannot use jnp.unique to accurately filter for unique paths, as paths with a final mask token and those without are treated as unique.

I have been struggling to incorporate jnp.unique somehow because of simplicity, but given it seems difficult, I am now planning to use jnp.lexsort (or lax.scan, if needs be) to find unique paths instead, so we can use custom functionality to merge the scores for paths that are identical except for a final mask_index.

Perhaps this will work. But I am also wondering whether I am approaching this wrong? Any help may speed up coding and would be appreciated. Thanks again.

@MaanasArora
Copy link
Contributor Author

MaanasArora commented Apr 13, 2024

Hello, I was able to fix the issue I mentioned. I used a boolean array to decouple masking from the path arrays. I also addressed the effect of the sequence_length and merge_repetitions arguments. But there is still a bug regarding the storage of beam entries; the output differs from the TF implementation when the beam width is small. I will be working on that. Thank you!

@MaanasArora
Copy link
Contributor Author

MaanasArora commented Apr 15, 2024

Hello, I've fixed the bugs I mentioned, refactored for clarity, added JAX to the unit test for beam search, and also improved some of the unit testing. I'm planning to try to improve the code structure further if possible. Thank you.

@MaanasArora
Copy link
Contributor Author

Just fixed another minor bug--it seems that jnp.argsort does not support the order parameter, without which sorting causes a discrepancy with the tensorflow implementation. I've flipped the inputs on the class axis and then flipped back to the right classes at the end as a workaround.

Copy link
Member

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice work!

mask_index=None,
):
inputs = jnp.array(inputs)
sequence_length = jnp.array(sequence_length)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we cast this to int32?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yes, nice catch! Done.

mask_index=None,
):
inputs = jnp.array(inputs)
sequence_length = jnp.array(sequence_length)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

keras/ops/nn.py Outdated
A tuple containing:

- a list of decoded sequences.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove line break

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thanks!

@sampathweb
Copy link
Collaborator

Please note that we refactored code to move all sources to keras/src and public API is generated in keras/api folder. Please rebase your code with master.

Please run ./shell/api_gen.sh script when making changes that affect keras_export.

@MaanasArora
Copy link
Contributor Author

@sampathweb This is done; thank you for the information.

@fchollet
Copy link
Member

+16,586 −6,035

It looks like you will need to rebase your PR on master to avoid extraneous changes.

@MaanasArora
Copy link
Contributor Author

Yes, this is done @fchollet

@MaanasArora
Copy link
Contributor Author

Hi @fchollet, so although torchaudio does have a CTC decoding implementation, it is not a dependency of Keras. Also, I am not sure if the numerics would be consistent with TF and (our replicated implementation in) JAX. Is it better to create a custom solution?

@fchollet
Copy link
Member

Hi @fchollet, so although torchaudio does have a CTC decoding implementation, it is not a dependency of Keras. Also, I am not sure if the numerics would be consistent with TF and (our replicated implementation in) JAX. Is it better to create a custom solution?

It depends!

  1. If we want to use it, we can do an inline import of torchaudio in the backend function, except the ImportError, and display an error that asks users to install it. We this with torchvision in a few places.
  2. We could also take a look at the implementation and judge how difficult it would be to simply fork it.
  3. We could also simply convert the JAX implementation to torch. Would it be difficult?

I think a big factor here is whether we're getting the same numerics with the torchaudio version. If not, better to reimplement (based on the JAX code, and potentially using the torchaudio code as reference).

@MaanasArora
Copy link
Contributor Author

MaanasArora commented Apr 18, 2024

It seems torchaudio uses the Flashlight library under the hood. I don't seem to be able to get the same numerics, unfortunately. It should not be too difficult to convert the JAX code, though I'm not very familiar with torch's low-level functions yet, so it may take longer. In case you'd like to release this first or open this up to others, we can probably create another PR for that? Thank you!

@fchollet
Copy link
Member

In case you'd like to release this first or open this up to others, we can probably create another PR for that? Thank you!

Sounds good -- please add the op function in backend/torch/nn.py and backend/numpy/nn.py and raise NotImplementedError in the functions. Also format the code via sh shell/format.sh. I'll do a full review tomorrow.

@MaanasArora
Copy link
Contributor Author

Sounds good -- please add the op function in backend/torch/nn.py and backend/numpy/nn.py and raise
NotImplementedError in the functions. Also format the code via sh shell/format.sh. I'll do a full review tomorrow.

Both done! Thank you for reviewing!

PR Queue automation moved this from Assigned Reviewer to Approved by Reviewer Apr 20, 2024
Copy link
Member

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking great -- thank you for the contribution!

Will do a few minor docstring edits post-merge.

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Apr 20, 2024
@fchollet fchollet merged commit 2e31633 into keras-team:master Apr 20, 2024
6 checks passed
PR Queue automation moved this from Approved by Reviewer to Merged Apr 20, 2024
@google-ml-butler google-ml-butler bot removed ready to pull Ready to be merged into the codebase kokoro:force-run labels Apr 20, 2024
fchollet added a commit that referenced this pull request May 3, 2024
* Introduce float8 training (#19488)

* Add float8 training support

* Add tests for fp8 training

* Add `quantize_and_dequantize` test

* Fix bugs and add float8 correctness tests

* Cleanup

* Address comments and cleanup

* Add docstrings and some minor refactoring

* Add `QuantizedFloat8DTypePolicy`

* Add dtype policy setter

* Fix torch dynamo issue by using `self._dtype_policy`

* Improve test coverage

* Add LoRA to ConvND layers (#19516)

* Add LoRA to `BaseConv`

* Add tests

* Fix typo

* Fix tests

* Fix tests

* Add path to run keras on dm-tree when optree is not available.

* feat(losses): add Tversky loss implementation (#19511)

* feat(losses): add Tversky loss implementation

* adjusted documentation

* Update KLD docs

* Models and layers now return owned metrics recursively. (#19522)

- added `Layer.metrics` to return all metrics owned by the layer and its sub-layers recursively.
- `Layer.metrics_variables` now returns variables from all metrics recursively, not just the layer and its direct sub-layers.
- `Model.metrics` now returns all metrics recursively, not just the model level metrics.
- `Model.metrics_variables` now returns variables from all metrics recursively, not just the model level metrics.
- added test coverage to test metrics and variables 2 levels deep.

This is consistent with the Keras 2 behavior and how `Model/Layer.variables` and `Model/Layer.weights` work.

* Update IoU ignore_class handling

* Fix `RandomBrightness`, Enhance `IndexLookup` Initialization and Expand Test Coverage for `Preprocessing Layers` (#19513)

* Add tests for CategoryEncoding class in category_encoding_test.py

* fix

* Fix IndexLookup class initialization and add test cases

* Add test case for IndexLookupLayerTest without vocabulary

* Fix IndexLookup class initialization

* Add normalization test cases

* Add test cases for Hashing class

* Fix value range validation error in RandomBrightness class

* Refactor IndexLookup class initialization and add test cases

* Reffix ndexLookup class initialization and afix est cases

* Add test for spectral norm

* Add missing test decorator

* Fix torch test

* Fix code format

* Generate API (#19530)

* API Generator for Keras

* API Generator for Keras

* Generates API Gen via api_gen.sh

* Remove recursive import of _tf_keras

* Generate API Files via api_gen.sh

* Update APIs

* Added metrics from custom `train_step`/`test_step` are now returned. (#19529)

This works the same way as in Keras 2, whereby the metrics are returned directly from the logs if the set of keys doesn't match the model metrics.

* Use temp dir and abs path in `api_gen.py` (#19533)

* Use temp dir and abs path

* Use temp dir and abs path

* Update Readme

* Update API

* Fix gradient accumulation when using `overwrite_with_gradient` during float8 training (#19534)

* Fix gradient accumulation with `overwrite_with_gradient` in float8 training

* Add comments

* Fix annotation

* Update code path in ignore path (#19537)

* Add operations per run (#19538)

* Include input shapes in model visualization.

* Add pad_to_aspect_ratio feature in ops.image.resize

* Add pad_to_aspect_ratio feature in Resizing layer.

* Fix incorrect usage of `quantize` (#19541)

* Add logic to prevent double quantization

* Add detailed info for double quantization error

* Update error msg

* Add eigh op.

* Add keepdim in argmax/argmin.

* Fix small bug in model.save_weights (#19545)

* Update public APIs.

* eigh should work on JAX GPU

* Copy init to keras/__init__.py (#19551)

* Revert "Copy init to keras/__init__.py (#19551)" (#19552)

This reverts commit da9af61.

* sum-reduce inlined losses

* Remove the dependency on `tensorflow.experimental.numpy` and support negative indices for `take` and `take_along_axis` (#19556)

* Remove `tfnp`

* Update numpy api

* Improve test coverage

* Improve test coverage

* Fix `Tri` and `Eye` and increase test converage

* Update `round` test

* Fix `jnp.round`

* Fix `diag` bug for iou_metrics

* Add op.select.

* Add new API for select

* Make `ops.abs` and `ops.absolute` consistent between backends. (#19563)

- The TensorFlow implementation was missing `convert_to_tensor`
- The sparse annotation was unnecessarily applied twice
- Now `abs` calls `absolute` in all backends

Also fixed TensorFlow `ops.select`.

* Add pickle support for Keras model (#19555)

* Implement unit tests for pickling

* Reformat model_test

* Reformat model_test

* Rename depickle to unpickle

* Rename depickle to unpickle

* Reformat

* remove a comment

* Ellipsis Serialization and tests (#19564)

* Serialization and tests

* Serialization and tests

* Serialization and tests

* Make TF one_hot input dtype less strict.

* Fix einsum `_int8_call` (#19570)

* CTC Decoding for JAX and Tensorflow (#19366)

* Tensorflow OP for CTC decoding

* JAX op for CTC greedy decoding

* Update CTC decoding documentation

* Fix linting issues

* Fix trailing whitespace

* Simplify returns in tensorflow CTC wrapper

* Fix CTC decoding error messages

* Fix line too long

* Bug fixes to JAX CTC greedy decoder

* Force int typecast in TF CTC decoder

* Unit tests for CTC greedy decoding

* Add unit test for CTC beam search decoding

* Fix mask index set location in JAX CTC decoding

* CTC beam search decoding for JAX

* Fix unhandled token repetitions in ctc_beam_search_decode

* Fix merge_repeated bug in CTC beam search decode

* Fix beam storage and repetition bugs in JAX ctc_decode

* Remove trailing whitespace

* Fix ordering bug for ties in JAX CTC beam search

* Cast sequence lengths to integers in JAX ctc_decode

* Remove line break in docstring

* CTC beam search decoding for JAX

* Fix unhandled token repetitions in ctc_beam_search_decode

* Fix merge_repeated bug in CTC beam search decode

* Fix beam storage and repetition bugs in JAX ctc_decode

* Fix ordering bug for ties in JAX CTC beam search

* Generate public api directory

* Add not implemented errors for NumPy and Torch CTC decoding

* Remove unused redefinition of JAX ctc_beam_search_decode

* Docstring edits

* Expand nan_to_num args.

* Add vectorize op.

* list insert requires index (#19575)

* Add signature and exclude args to knp.vectorize.

* Fix the apis of `dtype_polices` (#19580)

* Fix api of `dtype_polices`

* Update docstring

* Increase test coverage

* Fix format

* Fix keys of `save_own_variables` and `load_own_variables` (#19581)

* Fix JAX CTC test.

* Fix loss_weights handling in single output case

* Fix JAX vectorize.

* Move _tf_keras directory to the root of the pip package.

* One time fix to _tf_keras API.

* Convert return type imdb.load_data to nparray (#19598)

Convert return type imdb.load_data to Numpy array. Currently X_train and X-test returned as list.

* Fix typo

* fix api_gen.py for legacy (#19590)

* fix api_gen.py for legacy

* merge api and legacy for _tf_keras

* Improve int8 for `Embedding` (#19595)

* pin torch < 2.3.0 (#19603)

* Clean up duplicated `inputs_quantizer` (#19604)

* Cleanup duplicated `inputs_quantizer` and add type check for `input_spec` and `supports_masking`

* Revert setter

* output format changes and errors in github (#19608)

* Provide write permission to action for cache management. (#19606)

* Pickle support for all saveables (#19592)

* Pickle support

* Add keras pickleable mixin

* Reformat

* Implement pickle all over

* reformat

* Reformat

* Keras saveable

* Keras saveable

* Keras saveable

* Keras saveable

* Keras saveable

* obj_type

* Update pickleable

* Saveable logic touchups

* Add slogdet op.

* Update APIs

* Remove unused import

* Refactor CTC APIs (#19611)

* Add `ctc_loss` and `ctc_decode` for numpy backend, improve imports and tests

* Support "beam_search" strategy for torch's `ctc_decode`

* Improve `ctc_loss`

* Cleanup

* Refactor `ctc_decode`

* Update docstring

* Update docstring

* Add `CTCDecode` operation and ensure dtype inference of `ctc_decode`

* Fix `name` of `losses.CTC`

* update the namex version requirements (#19617)

* Add `PSNR` API (#19616)

* PSNR

* Fix

* Docstring format

* Remove `PYTORCH_ENABLE_MPS_FALLBACK` flag requirement for mps (#19618)

* Remove `PYTORCH_ENABLE_MPS_FALLBACK` flag requirement for mps

* Formatting

* Implement custom layer insertion in clone_model. (#19610)

* Implement custom layer insertion in clone_model.

* Add recursive arg and tests.

* Add nested sequential cloning test

* Fix bidir lstm saving issue.

* Fix CI

* Fix cholesky tracing with jax

* made extract_patches dtype agnostic (#19621)

* Simplify Bidirectional implementation

* Add support for infinite `PyDataset`s. (#19624)

`PyDataset` now uses the `num_batches` property instead of `__len__` to support `None`, which is how one indicates the dataset is infinite. Note that infinite datasets are not shuffled.

Fixes #19528

Also added exception reporting when using multithreading / multiprocessing. Previously, the program would just hang with no error reported.

* Fix dataset shuffling issue.

* Update version string.

* Minor fix

* Restore version string resolution in pip_build.

* Speed up `DataAdapter` tests by testing only the current backend. (#19625)

There is no use case for using an iterator for a different backend than the current backend.

Also:
- limit the number of tests using multiprocessing, the threading tests give us good coverage.
- fixed the `test_exception_reported` test, which was not actually exercising the multiprocessing / multithreading cases.
- removed unused `init_pool` method.

* feat(ops): support np.argpartition (#19588)

* feat(ops): support np.argpartition

* updated documentation, type-casting, and tf implementation

* fixed tf implementation

* added torch cast to int32

* updated torch type and API generated files

* added torch output type cast

* test(trainers): add test_errors implementation for ArrayDataAdapter class (#19626)

* Fix torch GPU CI

* Fix argmax/argmin keepdims with defined axis in TF

* Misc fixes in TF backend ops.

* Fix `argpartition` cuda bug in torch (#19634)

* fix(ops): specify NonZero output dtype and add test coverage (#19635)

* Fix `ops.ctc_decode` (#19633)

* Fix greedy ctc decode

* Remove print

* Fix `tf.nn.ctc_beam_search_decoder`

* Change default `mask_index` to `0`

* Fix losses test

* Update

* Ensure the same rule applies for np arrays in autocasting (#19636)

* Ensure the same rule applies for np arrays in autocasting

* Trigger CI by adding docstring

* Update

* Update docstring

* Fix `istft` and add class `TestMathErrors` in `ops/math_test.py` (#19594)

* Fix and test math functions for jax backend

* run /workspaces/keras/shell/format.sh

* refix

* fix

* fix _get_complex_tensor_from_tuple

* fix

* refix

* Fix istft function to handle inputs with less than 2 dimensions

* fix

* Fix ValueError in istft function for inputs with less than 2 dimensions

* Return a tuple from `ops.shape` with the Torch backend. (#19640)

With Torch, `x.shape` returns a `torch.Size`, which is a subclass of `tuple` but can cause different behaviors. In particular `convert_to_tensor` does not work on `torch.Size`.

This fixes #18900

* support conv3d on cpu for TF (#19641)

* Enable cudnn rnns when dropout is set (#19645)

* Enable cudnn rnns when dropout is set

* Fix

* Fix plot_model for input dicts.

* Fix deprecation warning in torch

* Bump the github-actions group with 2 updates (#19653)

Bumps the github-actions group with 2 updates: [actions/upload-artifact](https://github.com/actions/upload-artifact) and [github/codeql-action](https://github.com/github/codeql-action).


Updates `actions/upload-artifact` from 4.3.1 to 4.3.3
- [Release notes](https://github.com/actions/upload-artifact/releases)
- [Commits](actions/upload-artifact@5d5d22a...6546280)

Updates `github/codeql-action` from 3.24.9 to 3.25.3
- [Release notes](https://github.com/github/codeql-action/releases)
- [Changelog](https://github.com/github/codeql-action/blob/main/CHANGELOG.md)
- [Commits](github/codeql-action@1b1aada...d39d31e)

---
updated-dependencies:
- dependency-name: actions/upload-artifact
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: github-actions
- dependency-name: github/codeql-action
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: github-actions
...

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Bump the python group with 2 updates (#19654)

Bumps the python group with 2 updates: torch and torchvision.


Updates `torch` from 2.2.1+cu121 to 2.3.0+cu121

Updates `torchvision` from 0.17.1+cu121 to 0.18.0+cu121

---
updated-dependencies:
- dependency-name: torch
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: python
- dependency-name: torchvision
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: python
...

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Revert "Bump the python group with 2 updates (#19654)" (#19655)

This reverts commit 09133f4.

---------

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: james77777778 <[email protected]>
Co-authored-by: Francois Chollet <[email protected]>
Co-authored-by: Luca Pizzini <[email protected]>
Co-authored-by: hertschuh <[email protected]>
Co-authored-by: Faisal Alsrheed <[email protected]>
Co-authored-by: Ramesh Sampath <[email protected]>
Co-authored-by: Sachin Prasad <[email protected]>
Co-authored-by: Uwe Schmidt <[email protected]>
Co-authored-by: Luke Wood <[email protected]>
Co-authored-by: Maanas Arora <[email protected]>
Co-authored-by: AlexanderLavelle <[email protected]>
Co-authored-by: Surya <[email protected]>
Co-authored-by: Shivam Mishra <[email protected]>
Co-authored-by: Haifeng Jin <[email protected]>
Co-authored-by: IMvision12 <[email protected]>
Co-authored-by: Gabriel Rasskin <[email protected]>
Co-authored-by: Vachan V Y <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
PR Queue
Merged
Development

Successfully merging this pull request may close these issues.

None yet

7 participants