@SupreethRao99
Created April 8, 2022 03:48
Inference code for RNN-Transducer
  1. To convert the model to TFLite, clone the TensorFlowASR repository and install its dependencies:
git clone https://github.com/TensorSpeech/TensorFlowASR.git
cd TensorFlowASR
pip3 install -e ".[tf2.8]"
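Before converting, it can help to confirm that the modules gen_tflite_model.py imports (tensorflow, tensorflow_asr, fire) are actually importable, and which TensorFlow version the `[tf2.8]` extra installed. The following is a minimal sketch using only the standard library. The helper names are hypothetical, and it assumes TensorFlow is installed under the distribution name `tensorflow` (other builds may use a different name):

```python
import importlib.util
from importlib import metadata
from typing import Dict, Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def check_environment() -> Dict[str, object]:
    """Report whether the top-level modules used by gen_tflite_model.py can be found."""
    report: Dict[str, object] = {}
    # find_spec locates a top-level module without importing it.
    for module in ("tensorflow", "tensorflow_asr", "fire"):
        report[module] = importlib.util.find_spec(module) is not None
    # Assumption: TensorFlow is installed as the "tensorflow" distribution.
    report["tensorflow_version"] = installed_version("tensorflow")
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```

If any entry prints `False`, or the version is not a 2.8.x release, the installation step is the first thing to revisit.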
  2. Download the model weights (.h5) and the config file from Google Drive: https://drive.google.com/drive/folders/1rYpiYF0F9JIsAKN2DCFFtEdfNzVbBLHe?usp=sharing
  3. Place the downloaded config file in the same directory as the conversion script, i.e. ./TensorFlowASR/examples/rnn_transducer/inference/. By default the script (DEFAULT_YAML) looks for a file named config.yml there.
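  4. Replace the contents of gen_tflite_model.py with the code below, substituting the path to the downloaded .h5 weights file for the placeholder in load_weights: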
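A quick pre-flight check that the config and weights are where the conversion script expects them can rule out path mistakes before a long model build. This is a minimal sketch with a hypothetical helper, `check_conversion_inputs`. It only checks that files exist and that the weights file has an `.h5` suffix, not that their contents are valid:

```python
from pathlib import Path
from typing import List


def check_conversion_inputs(inference_dir: str, weights_path: str) -> List[str]:
    """Return a list of problems; an empty list means the inputs look ready."""
    problems: List[str] = []
    config = Path(inference_dir) / "config.yml"
    # gen_tflite_model.py defaults to config.yml next to the script.
    if not config.is_file():
        problems.append(f"missing config file: {config}")
    weights = Path(weights_path)
    # Keras load_weights(..., by_name=True) only supports HDF5 (.h5) files.
    if weights.suffix != ".h5":
        problems.append(f"weights file should end in .h5: {weights}")
    elif not weights.is_file():
        problems.append(f"missing weights file: {weights}")
    return problems


if __name__ == "__main__":
    issues = check_conversion_inputs(
        "./TensorFlowASR/examples/rnn_transducer/inference/",
        "./model.h5",  # hypothetical path; use the file you downloaded
    )
    for issue in issues:
        print(issue)
    print("ready" if not issues else f"{len(issues)} problem(s) found")
```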
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import fire
from tensorflow_asr.utils import env_util

logger = env_util.setup_environment()
import tensorflow as tf

from tensorflow_asr.configs.config import Config
from tensorflow_asr.models.transducer.rnn_transducer import RnnTransducer
from tensorflow_asr.helpers import exec_helpers, featurizer_helpers

DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")

tf.compat.v1.enable_control_flow_v2()


def main(
    config: str = DEFAULT_YAML,
    subwords: bool = True,
    sentence_piece: bool = False,
    output: str = "./rnnt-tflite.tflite",
):

    tf.keras.backend.clear_session()
    tf.compat.v1.enable_control_flow_v2()

    config = Config(config)
    speech_featurizer, text_featurizer = featurizer_helpers.prepare_featurizers(
        config=config,
        subwords=subwords,
        sentence_piece=sentence_piece,
    )

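    # Build the model; the output vocabulary size comes from the text featurizer.
    rnn_transducer = RnnTransducer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
    rnn_transducer.make(speech_featurizer.shape)
    # Replace the placeholder with the path to the downloaded .h5 weights file.
    rnn_transducer.load_weights("<path-to-downloaded-weights (.h5) file>", by_name=True)
    rnn_transducer.summary(line_length=100)
    # Attach the featurizers used by the TFLite inference function for feature extraction and decoding.
    rnn_transducer.add_featurizers(speech_featurizer, text_featurizer)
    exec_helpers.convert_tflite(model=rnn_transducer, output=output)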


if __name__ == "__main__":
    fire.Fire(main)
  5. Run the conversion from the terminal with python3 gen_tflite_model.py.
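Because the script passes main to fire.Fire, each keyword argument of main (config, subwords, sentence_piece, output) can also be set as a --flag on the command line. Below is a minimal sketch of driving the conversion from Python with the current interpreter. The helper names `build_convert_command` and `run_conversion` are hypothetical:

```python
import subprocess
import sys
from typing import List, Optional


def build_convert_command(
    script: str = "gen_tflite_model.py",
    config: Optional[str] = None,
    output: Optional[str] = None,
    subwords: Optional[bool] = None,
    sentence_piece: Optional[bool] = None,
) -> List[str]:
    """Build the argv for gen_tflite_model.py.

    Arguments left as None are omitted, so main's defaults apply.
    """
    cmd = [sys.executable, script]
    options = {
        "config": config,
        "output": output,
        "subwords": subwords,
        "sentence_piece": sentence_piece,
    }
    for name, value in options.items():
        if value is not None:
            # fire parses "--subwords=False" into the Python value False.
            cmd.append(f"--{name}={value}")
    return cmd


def run_conversion(**kwargs) -> None:
    """Run the conversion script; raises CalledProcessError on failure."""
    subprocess.run(build_convert_command(**kwargs), check=True)
```

For example, `run_conversion(output="./rnnt.tflite")` writes the model to a custom path and keeps every other default.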
  6. The conversion fails with the following error:
tensorflow.lite.python.convert_phase.ConverterError: input resource[0] expected type resource != float, the type of rnn_transducer_greedy_while_rnn_transducer_decoder_rnn_transducer_prediction_embedding_embedding_lookup_13010_0[0]
        In {{node rnn_transducer_greedy/while/rnn_transducer_decoder/rnn_transducer_prediction_embedding/embedding_lookup}}
        Failed to functionalize Control Flow V1 ops. Consider using Control Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/tf/compat/v1/enable_control_flow_v2.