add TFJS compatibility mode

fix #23

patlevin committed Oct 5, 2020
1 parent d05d468 commit 66a1614
Showing 8 changed files with 233 additions and 20 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -55,6 +55,7 @@ for quick and easy model conversion.
| `-h`, `--help` | Show help message and exit |
| `--output_format` | Use `tf_frozen_model` (the default) to save a Tensorflow frozen model. `tf_saved_model` exports to a Tensorflow _SavedModel_ instead. |
| `--saved_model_tags` | Specifies the tags of the MetaGraphDef to save, in comma separated string format. Defaults to "serve". Applicable only if `--output_format` is `tf_saved_model` |
| `-c`, `--compat_mode` | Keep the input types compatible with TensorflowJS <=2.4.x |
| `-v`, `--version` | Shows the version of the converter and its dependencies. |
| `-s`, `--silent` | Suppresses any output besides error messages. |

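For reference, a minimal sketch of the programmatic equivalent of the new `-c`/`--compat_mode` flag, using the `graph_model_to_frozen_graph` API extended later in this commit; the input and output paths are placeholders:

```python
# Sketch: Python equivalent of passing -c / --compat_mode on the command line.
# "models/posenet" and "out/frozen_graph.pb" are placeholder paths.
from tfjs_graph_converter import api

api.graph_model_to_frozen_graph(
    "models/posenet",       # directory containing model.json and the weight files
    "out/frozen_graph.pb",  # output file; the directory must already exist
    compat_mode=True,       # keep input dtypes compatible with TensorflowJS <=2.4.x
)
```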
28 changes: 24 additions & 4 deletions docs/api.rst
@@ -101,7 +101,8 @@ __ https://www.tensorflow.org/api_docs/python/tf/Graph
.. code:: python
load_graph_model_and_signature(
model_dir: str
model_dir: str,
compat_mode: bool = False
) -> Tuple[tf.Graph, Optional[SignatureDef]]
Loads a tensorflowjs graph model from a directory and returns a TF v1
@@ -118,6 +119,10 @@ that contains the inputs and outputs of the model.
specified directly. Weight files must be located in the
same directory as the model file.

**compat_mode**
Set this argument to ``True`` to ensure that the resulting graph is
compatible with TensorflowJS if possible.

..
**Returns:**
@@ -282,7 +287,8 @@ input tensors as arguments and returns a list of model outputs as tensors.
graph_model_to_frozen_graph(
model_dir: str,
export_path: str
export_path: str,
compat_mode: bool = False
) -> str
Converts a tensorflowjs graph model to a tensorflow frozen graph.
@@ -303,6 +309,10 @@ The resulting graph is written to a **binary** protobuf message.
The file name usually ends in `.pb` and the directory
must exist.

**compat_mode**
Set this argument to ``True`` to ensure that the resulting graph is
compatible with TensorflowJS if possible.

..
**Returns:**
@@ -333,7 +343,8 @@ written.
export_dir: str,
tags: Union[str, List[str]] = None,
signature_def_map: dict = None,
signature_key_map: RenameMap = None
signature_key_map: RenameMap = None,
compat_mode: bool = False
) -> str
Converts a tensorflowjs graph model to a tensorflow `SavedModel`__
@@ -380,6 +391,10 @@ __ https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_
Optional mapping of tensor names to custom input or output names, see
`RenameMap`_.

**compat_mode**
Set this argument to ``True`` to ensure that the resulting graph is
compatible with TensorflowJS if possible.

..
**Returns:**
@@ -439,7 +454,8 @@ multi-head model):
model_list: List[Tuple[str, List[str]]],
export_dir: str,
signatures: dict = None,
signature_keys: Dict[str, RenameMap] = None
signature_keys: Dict[str, RenameMap] = None,
compat_mode: bool = False
) -> str
This function merges several tensorflowjs graph models into a single
@@ -478,6 +494,10 @@ __ https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_
``model_list``) to `RenameMap`_ instances for assigning new names to model
inputs and outputs.

**compat_mode**
Set this argument to ``True`` to ensure that the resulting graph is
compatible with TensorflowJS if possible.

..
**Returns:**
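All four api.rst entries above document the same optional argument. A short usage sketch for the SavedModel path, with placeholder paths and an explicit "serve" tag:

```python
# Sketch: exporting a SavedModel with the new compat_mode argument.
# Paths are placeholders.
from tfjs_graph_converter import api

export_path = api.graph_model_to_saved_model(
    "models/posenet",              # TFJS graph model directory
    "export/posenet_saved_model",  # SavedModel output directory
    tags=["serve"],                # MetaGraphDef tags (string or list of strings)
    compat_mode=True,              # restrict dtypes to TFJS-compatible ones
)
print(export_path)
```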
3 changes: 3 additions & 0 deletions docs/converter.rst
@@ -53,6 +53,9 @@ aren't very useful when called from within another script, though
Multiple tags can be given as a
comma-separated list.
-------------------------- ----------------------------------------------
--compat_mode, -c Keep the input types compatible with
TensorflowJS <=v2.4.x
-------------------------- ----------------------------------------------
--version, -v Prints the library version and the versions of
dependencies (TF, TFJS). *Useful only in CLI*
-------------------------- ----------------------------------------------
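To see what the flag actually changes, a sketch that loads a model twice and compares the input dtypes reported by the signature; the model path is a placeholder, and a difference only shows up for models that use int64 inputs:

```python
# Sketch: comparing signature input dtypes with and without compat_mode.
# "models/posenet" is a placeholder path.
import tensorflow as tf
from tfjs_graph_converter import api

for compat in (False, True):
    _, signature_def = api.load_graph_model_and_signature(
        "models/posenet", compat_mode=compat)
    dtypes = {
        key: tf.dtypes.as_dtype(info.dtype).name
        for key, info in signature_def.inputs.items()
    }
    print(f"compat_mode={compat}: {dtypes}")
```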
55 changes: 42 additions & 13 deletions tfjs_graph_converter/api.py
@@ -24,6 +24,7 @@
from google.protobuf.json_format import ParseDict

import tfjs_graph_converter.common as common
from tfjs_graph_converter.compat import convert_int64_to_int32
from tfjs_graph_converter.convert_prelu import replace_prelu, split_fused_prelu
from tfjs_graph_converter.convert_fused_depthwise import split_fused_depthwise
from tfjs_graph_converter.graph_rewrite_util import validate_supported_ops
@@ -172,6 +173,20 @@ def _remove_channel_from_key(mapfield):
return signature_def


def _set_signature_dtypes(graph: tf.Graph, signature_def: util.SignatureDef
) -> Tuple[tf.Graph, util.SignatureDef]:
"""Set the dtype of each input and output to match the graph and return
both
"""
for key, value in signature_def.inputs.items():
node = graph.get_tensor_by_name(value.name)
value.dtype = node.dtype.as_datatype_enum
for key, value in signature_def.outputs.items():
node = graph.get_tensor_by_name(value.name)
value.dtype = node.dtype.as_datatype_enum
return graph, signature_def


def _create_graph(graph_def: GraphDef,
weight_dict: Dict[str, Tensor],
modifiers: Dict[str, Callable]) -> tf.Graph:
@@ -220,14 +235,16 @@ def _replace_unsupported_operations(


def _convert_graph_model_to_graph(model_json: Dict[str, Any],
base_path: str
base_path: str,
compat_mode: bool = False
) -> Tuple[tf.Graph, util.SignatureDef]:
"""
Convert TFJS JSON model to TF Graph
Args:
model_json: JSON dict from TFJS model file
base_path: Path to the model file (where to find the model weights)
compat_mode: True, if only TFJS datatypes should be used
Returns:
Tuple of TF Graph for inference or saving and TF signature definition
Expand All @@ -249,14 +266,16 @@ def _convert_graph_model_to_graph(model_json: Dict[str, Any],
name, data = common.TFJS_NAME_KEY, common.TFJS_DATA_KEY
weight_dict = dict((weight[name], weight[data]) for weight in weight_list)
graph_def, weight_modifiers = _replace_unsupported_operations(graph_def)

if compat_mode:
graph_def = convert_int64_to_int32(graph_def)
graph = _create_graph(graph_def, weight_dict, weight_modifiers)
signature_def = _extract_signature_def(model_json) or util.infer_signature(
graph)
return (graph, signature_def)
return _set_signature_dtypes(graph, signature_def)


def load_graph_model_and_signature(model_dir: str
def load_graph_model_and_signature(model_dir: str,
compat_mode: bool = False
) -> Tuple[tf.Graph, util.SignatureDef]:
"""
Load a TFJS Graph Model from a directory
Expand All @@ -265,6 +284,7 @@ def load_graph_model_and_signature(model_dir: str
model_dir: Directory that contains the tfjs model.json and weights;
alternatively name and path of the model.json if the name
differs from the default ("model.json")
compat_mode: If True, only TFJS-compatible datatypes are used
Returns:
Tuple of TF frozen graph for inference or saving and TF signature def
Expand All @@ -273,22 +293,23 @@ def load_graph_model_and_signature(model_dir: str
model_file_path = os.path.join(model_path, model_name)
with open(model_file_path, "r") as model_file:
model_json = json.load(model_file)
return _convert_graph_model_to_graph(model_json, model_path)
return _convert_graph_model_to_graph(model_json, model_path, compat_mode)


def load_graph_model(model_dir: str) -> tf.Graph:
def load_graph_model(model_dir: str, compat_mode: bool = False) -> tf.Graph:
"""
Load a TFJS Graph Model from a directory
Args:
model_dir: Directory that contains the tfjs model.json and weights;
alternatively name and path of the model.json if the name
differs from the default ("model.json")
compat_mode: If True, only TFJS-compatible datatypes are used
Returns:
TF frozen graph for inference or saving
"""
graph, _ = load_graph_model_and_signature(model_dir)
graph, _ = load_graph_model_and_signature(model_dir, compat_mode)
return graph


@@ -338,29 +359,32 @@ def _imports_graph_def():
tf.nest.map_structure(import_graph.as_graph_element, outputs))


def graph_model_to_frozen_graph(model_dir: str, export_path: str) -> str:
def graph_model_to_frozen_graph(model_dir: str, export_path: str,
compat_mode: bool = False) -> str:
"""
Convert a TFJS graph model to a frozen TF graph
Args:
model_dir: Directory that contains the TFJS JSON model and weights
export_path: Path to the frozen graph (e.g. './output.pb')
compat_mode: If True, only TFJS-compatible datatypes are used
Returns:
The path to the output proto-file.
"""
export_dir = os.path.dirname(export_path)
model_name = os.path.basename(export_path)

graph = load_graph_model(model_dir)
graph = load_graph_model(model_dir, compat_mode)
return tf.io.write_graph(graph, export_dir, model_name, as_text=False)


def graph_model_to_saved_model(model_dir: str,
export_dir: str,
tags: List[str] = None,
signature_def_map: dict = None,
signature_key_map: RenameMap = None) -> str:
signature_key_map: RenameMap = None,
compat_mode: bool = False) -> str:
"""
Convert a TFJS graph model to a SavedModel
Expand All @@ -387,11 +411,13 @@ def graph_model_to_saved_model(model_dir: str,
keys. The default signature uses tensor names for
signature keys. This argument allows to map tensor
names to different keys.
compat_mode: If True, only TFJS-compatible datatypes are used
Returns:
The path to which the model was written.
"""
graph, signature_def = load_graph_model_and_signature(model_dir)
graph, signature_def = load_graph_model_and_signature(model_dir,
compat_mode)
builder = tf.compat.v1.saved_model.Builder(export_dir)
signature_map = _get_signature_map(graph, signature_def, signature_def_map)
tags = _get_tags(tags)
@@ -408,7 +434,8 @@ def graph_model_to_saved_model(model_dir: str,
def graph_models_to_saved_model(model_list: List[Tuple[str, List[str]]],
export_dir: str,
signatures: dict = None,
signature_keys: Dict[str, RenameMap] = None
signature_keys: Dict[str, RenameMap] = None,
compat_mode: bool = False
) -> str:
"""
Read multiple TFJS graph models and saves them in a single SavedModel
Expand Down Expand Up @@ -438,6 +465,7 @@ def graph_models_to_saved_model(model_list: List[Tuple[str, List[str]]],
`model_list` tuples) to per-model signature key mappings. This
allows a remapping of signature inputs and outputs to different
keys (the tensor names stay unaffected).
compat_mode: If True, only TFJS-compatible datatypes are used
Returns:
The path to which the model was written.
Expand All @@ -454,7 +482,8 @@ def _apply_key_map(model, signature_map):

model_dir, tags = model_list[0]
tags = _get_tags(tags)
graph, signature_def = load_graph_model_and_signature(model_dir)
graph, signature_def = load_graph_model_and_signature(model_dir,
compat_mode)
signature = signatures[model_dir] if model_dir in signatures else None
signature_map = _get_signature_map(graph, signature_def, signature)
_apply_key_map(model_dir, signature_map)
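The `compat` module imported at the top of api.py (`convert_int64_to_int32`) is among the files whose diff is not expanded on this page. As a rough illustration only, not the actual implementation, an int64-to-int32 rewrite over a `GraphDef` could look roughly like the following; a real version would also have to rewrite int64 constant tensors and weights:

```python
# Rough illustration of an int64 -> int32 GraphDef rewrite; NOT the actual
# tfjs_graph_converter.compat.convert_int64_to_int32, whose diff is collapsed
# on this page.
from tensorflow.core.framework import types_pb2
from tensorflow.core.framework.graph_pb2 import GraphDef


def downcast_int64_attrs(graph_def: GraphDef) -> GraphDef:
    """Rewrite DT_INT64 type attributes to DT_INT32 and return the graph."""
    for node in graph_def.node:
        for attr in node.attr.values():
            # plain dtype attributes such as "T", "dtype", "Tidx", ...
            if attr.type == types_pb2.DT_INT64:
                attr.type = types_pb2.DT_INT32
            # list-valued dtype attributes
            if attr.list.type:
                attr.list.type[:] = [
                    types_pb2.DT_INT32 if t == types_pb2.DT_INT64 else t
                    for t in attr.list.type
                ]
    return graph_def
```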
1 change: 1 addition & 0 deletions tfjs_graph_converter/common.py
@@ -37,3 +37,4 @@
CLI_SIGNATURE_KEY = 'signature_key'
CLI_METHOD_NAME = 'method_name'
CLI_RENAME = 'rename'
CLI_COMPATIBLE = 'compat_mode'
(The diffs for the remaining three changed files are not expanded in this view.)
