Skip to content

Commit

Permalink
fix signature bugs (#10)
Browse files Browse the repository at this point in the history
* fix signature bugs

* Update setup.py
  • Loading branch information
adiazulay authored Nov 3, 2020
1 parent 7960f35 commit a9d5159
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 9 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from setuptools import setup, find_packages
setup(
name="lobe",
version="0.2.0",
version="0.2.1",
packages=find_packages("src"),
package_dir={"": "src"},
install_requires=[
Expand Down
2 changes: 1 addition & 1 deletion src/lobe/Signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def load(model_path: str) -> Signature:
class Signature:
def __init__(self, signature_path: str):
signature_path = pathlib.Path(signature_path)
self.__model_path = signature_path.parent
self.__model_path = str(signature_path.parent)

with open(signature_path, "r") as f:
self.__signature = json.load(f)
Expand Down
12 changes: 5 additions & 7 deletions src/lobe/backends/_backend_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@

class ImageClassificationModel():
__input_key_image = 'Image'
__input_key_batch_size = "batch_size"
__output_key_labels = 'Labels_idx_000'
__output_key_confidences = 'Labels_idx_001'
__output_key_confidences = 'Confidences'
__output_key_prediction = 'Prediction'

def __init__(self, signature):
self.__model_path = signature.model_path
self.__tf_predict_fn = None
self.__labels = signature.classes

def __load(self):
self.__tf_predict_fn = predictor.from_saved_model(self.__model_path)
Expand All @@ -34,11 +33,10 @@ def predict(self, image: Image.Image) -> PredictionResult:
np_image = np_image[np.newaxis, ...]

predictions = self.__tf_predict_fn({
self.__input_key_image: np_image,
self.__input_key_batch_size: 1 })
self.__input_key_image: np_image
})

labels = [label.decode('utf-8') for label in predictions[self.__output_key_labels][0].tolist()]
confidences = predictions[self.__output_key_confidences][0]
top_prediction = predictions[self.__output_key_prediction][0].decode('utf-8')

return PredictionResult(labels=labels, confidences=confidences, prediction=top_prediction)
return PredictionResult(labels=self.__labels, confidences=confidences, prediction=top_prediction)

0 comments on commit a9d5159

Please sign in to comment.