Skip to content

Commit

Permalink
Add DifferentialBinarization model
Browse files Browse the repository at this point in the history
  • Loading branch information
gowthamkpr committed Sep 13, 2024
1 parent 753047d commit 3d06308
Show file tree
Hide file tree
Showing 4 changed files with 454 additions and 0 deletions.
13 changes: 13 additions & 0 deletions keras_nlp/src/models/diffbin/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
243 changes: 243 additions & 0 deletions keras_nlp/src/models/diffbin/diffbin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import keras
from keras import layers

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.models.diffbin.losses import DBLoss
from keras_nlp.src.models.task import Task


@keras_nlp_export("keras_nlp.models.DifferentialBinarization")
class DifferentialBinarization(Task):
    """Scene text detection task based on Differentiable Binarization.

    A Keras model implementing the Differential Binarization
    architecture for scene text detection, described in
    [Real-time Scene Text Detection with Differentiable Binarization](
    https://arxiv.org/abs/1911.08947).

    The backbone's pyramid outputs are fused by `diffbin_fpn_model` and fed
    to two `diffbin_head` branches that predict a probability map and a
    threshold map. Both are combined by `step_function` into an approximate
    binary map, and the three maps are concatenated along the last axis to
    form the model output.

    Args:
        backbone: A `keras_nlp.models.ResNetBackbone` instance. Its `input`
            is used as the model input and its `pyramid_outputs` (levels
            `"P2"` to `"P5"`) feed the Feature Pyramid Network.
        fpn_channels: int. The number of channels to output by the Feature
            Pyramid Network. Defaults to `256`.
        head_kernel_list: list of three ints. Kernel sizes of the
            convolution and the two transposed convolutions that make up
            each of the probability and threshold heads. Defaults to
            `[3, 2, 2]`.
        step_function_k: float. Steepness `k` of the differentiable step
            function `1 / (1 + exp(-k * (P - T)))` that turns the
            probability map `P` and threshold map `T` into the binary map.
            Defaults to `50.0`.
        preprocessor: Unused placeholder, accepted so that the saved model
            test can construct the task. Defaults to `None`.

    Examples:
    ```python
    import numpy as np
    import keras_nlp

    input_data = np.ones(shape=(8, 224, 224, 3))
    backbone = keras_nlp.models.ResNetBackbone.from_preset("resnet50_vd")
    detector = keras_nlp.models.DifferentialBinarization(
        backbone=backbone
    )
    detector(input_data)
    ```
    """

    def __init__(
        self,
        backbone,
        fpn_channels=256,
        head_kernel_list=None,
        step_function_k=50.0,
        preprocessor=None,  # adding this dummy arg for saved model test
        **kwargs,
    ):
        # Resolve the default here rather than in the signature so that no
        # list object is shared between instances; copy caller-supplied
        # sequences for the same reason.
        if head_kernel_list is None:
            head_kernel_list = [3, 2, 2]
        else:
            head_kernel_list = list(head_kernel_list)

        # === Functional model ===
        inputs = backbone.input
        x = backbone.pyramid_outputs
        x = diffbin_fpn_model(x, out_channels=fpn_channels)
        probability_maps = diffbin_head(
            x,
            in_channels=fpn_channels,
            kernel_list=head_kernel_list,
            name="head_prob",
        )
        threshold_maps = diffbin_head(
            x,
            in_channels=fpn_channels,
            kernel_list=head_kernel_list,
            name="head_thresh",
        )
        binary_maps = step_function(
            probability_maps, threshold_maps, k=step_function_k
        )
        outputs = layers.Concatenate(axis=-1)(
            [probability_maps, threshold_maps, binary_maps]
        )

        super().__init__(inputs=inputs, outputs=outputs, **kwargs)

        # === Config ===
        self.backbone = backbone
        self.fpn_channels = fpn_channels
        self.head_kernel_list = head_kernel_list
        self.step_function_k = step_function_k

    def compile(
        self,
        optimizer="auto",
        loss="auto",
        **kwargs,
    ):
        """Configures the `DifferentialBinarization` task for training.

        `DifferentialBinarization` extends the default compilation signature
        of `keras.Model.compile` with defaults for `optimizer` and `loss`. To
        override these defaults, pass any value to these arguments during
        compilation.

        Args:
            optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
                instance. Defaults to `"auto"`, which uses SGD with a
                learning rate of `0.007`, weight decay of `0.0001` and
                momentum of `0.9`. See `keras.Model.compile` and
                `keras.optimizers` for more info on possible `optimizer`
                values.
            loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
                Defaults to `"auto"`, in which case a `DBLoss` instance is
                used. See `keras.Model.compile` and `keras.losses` for more
                info on possible `loss` values.
            **kwargs: See `keras.Model.compile` for a full list of arguments
                supported by the compile method.
        """
        if optimizer == "auto":
            optimizer = keras.optimizers.SGD(
                learning_rate=0.007, weight_decay=0.0001, momentum=0.9
            )
        if loss == "auto":
            loss = DBLoss()
        super().compile(
            optimizer=optimizer,
            loss=loss,
            **kwargs,
        )

    def get_config(self):
        """Returns the task config, extended with the DB-specific arguments."""
        # Backbone serialized in `super`
        config = super().get_config()
        config.update(
            {
                "fpn_channels": self.fpn_channels,
                "head_kernel_list": self.head_kernel_list,
                "step_function_k": self.step_function_k,
            }
        )
        return config


def diffbin_fpn_model(inputs, out_channels):
    """Fuse the backbone pyramid features into a single feature map.

    Each pyramid level gets a 1x1 lateral convolution, coarser levels are
    upsampled and added into finer ones (top-down pathway), every merged
    level is reduced to `out_channels // 4` channels by a 3x3 convolution,
    and the results are upsampled to the `"P2"` resolution and concatenated.

    Args:
        inputs: mapping with the keys `"P2"`, `"P3"`, `"P4"` and `"P5"`
            holding the pyramid feature tensors.
        out_channels: int. Number of filters of the lateral convolutions;
            each per-level output convolution uses `out_channels // 4`.

    Returns:
        The tensor obtained by concatenating the four per-level outputs
        along the last axis, in the order P5, P4, P3, P2.
    """
    # Lateral 1x1 projections, built from the finest to the coarsest level.
    lateral = {}
    for level, conv_name in (
        ("P2", "neck_in2"),
        ("P3", "neck_in3"),
        ("P4", "neck_in4"),
        ("P5", "neck_in5"),
    ):
        projection = layers.Conv2D(
            out_channels, kernel_size=1, use_bias=False, name=conv_name
        )
        lateral[level] = projection(inputs[level])

    # Top-down pathway: upsample the coarser merged map and add it to the
    # next finer lateral map. The coarsest level has nothing to merge with.
    merged = {"P5": lateral["P5"]}
    for add_name, coarse, fine in (
        ("add1", "P5", "P4"),
        ("add2", "P4", "P3"),
        ("add3", "P3", "P2"),
    ):
        adder = layers.Add(name=add_name)
        merged[fine] = adder(
            [layers.UpSampling2D()(merged[coarse]), lateral[fine]]
        )

    # 3x3 output convolutions, built from the coarsest to the finest level.
    reduced = {}
    for level, conv_name in (
        ("P5", "neck_p5"),
        ("P4", "neck_p4"),
        ("P3", "neck_p3"),
        ("P2", "neck_p2"),
    ):
        reduced[level] = layers.Conv2D(
            out_channels // 4,
            kernel_size=3,
            padding="same",
            use_bias=False,
            name=conv_name,
        )(merged[level])

    # Bring every level to the P2 resolution; P2 itself needs no resizing.
    aligned = [
        layers.UpSampling2D((factor, factor))(reduced[level])
        for level, factor in (("P5", 8), ("P4", 4), ("P3", 2))
    ]
    aligned.append(reduced["P2"])

    return layers.Concatenate(axis=-1)(aligned)


def step_function(x, y, k):
    """Differentiable approximation of the step function `x > y`.

    Computes `1 / (1 + exp(-k * (x - y)))`, i.e. a logistic function of the
    difference `x - y` whose steepness is controlled by `k`.

    Args:
        x: first operand (the probability maps at the call site).
        y: second operand (the threshold maps at the call site).
        k: amplification factor applied to `x - y`.

    Returns:
        The element-wise result of the formula above.
    """
    negated_scaled_margin = -k * (x - y)
    denominator = 1.0 + keras.ops.exp(negated_scaled_margin)
    return 1.0 / denominator


def diffbin_head(inputs, in_channels, kernel_list, name):
    """Build one prediction head on top of the fused FPN features.

    The head is a convolution followed by two stride-2 transposed
    convolutions, with batch normalization and ReLU after the first two
    stages and a sigmoid activation on the final single-filter stage.

    Args:
        inputs: feature tensor the head is applied to.
        in_channels: int. The hidden stages use `in_channels // 4` filters,
            and the same value sets the range of the bias initializers.
        kernel_list: sequence of three ints, the kernel sizes of the three
            stages in order.
        name: str. Prefix used for the names of all layers in the head.

    Returns:
        The output tensor of the final sigmoid-activated transposed
        convolution.
    """
    hidden_channels = in_channels // 4
    bias_limit = 1.0 / math.sqrt(hidden_channels * 1.0)

    def new_bias_initializer():
        # Each layer gets its own initializer instance rather than a
        # shared one, matching one-initializer-per-layer construction.
        return keras.initializers.RandomUniform(
            minval=-bias_limit,
            maxval=bias_limit,
        )

    def batch_norm_relu(tensor, bn_name, relu_name):
        normalized = layers.BatchNormalization(
            beta_initializer=keras.initializers.Constant(1e-4),
            gamma_initializer=keras.initializers.Constant(1.0),
            name=bn_name,
        )(tensor)
        return layers.ReLU(name=relu_name)(normalized)

    # Stage 0: same-padded convolution.
    features = layers.Conv2D(
        hidden_channels,
        kernel_size=kernel_list[0],
        padding="same",
        use_bias=False,
        name=f"{name}_conv0_weights",
    )(inputs)
    features = batch_norm_relu(
        features,
        bn_name=f"{name}_conv0_bn",
        relu_name=f"{name}_conv0_relu",
    )

    # Stage 1: stride-2 transposed convolution.
    features = layers.Conv2DTranspose(
        hidden_channels,
        kernel_size=kernel_list[1],
        strides=2,
        padding="valid",
        bias_initializer=new_bias_initializer(),
        name=f"{name}_conv1_weights",
    )(features)
    features = batch_norm_relu(
        features,
        bn_name=f"{name}_conv1_bn",
        relu_name=f"{name}_conv1_relu",
    )

    # Stage 2: stride-2 transposed convolution down to one sigmoid channel.
    return layers.Conv2DTranspose(
        1,
        kernel_size=kernel_list[2],
        strides=2,
        padding="valid",
        activation="sigmoid",
        bias_initializer=new_bias_initializer(),
        name=f"{name}_conv2_weights",
    )(features)
59 changes: 59 additions & 0 deletions keras_nlp/src/models/diffbin/diffbin_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from keras import ops

from keras_nlp.src.models.diffbin.diffbin import DifferentialBinarization
from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone
from keras_nlp.src.tests.test_case import TestCase


class DifferentialBinarizationTest(TestCase):
    """Tests for the `DifferentialBinarization` task on a ResNet backbone."""

    def setUp(self):
        """Create dummy images and labels plus the task constructor args."""
        batch_size, height, width = 2, 224, 224
        self.images = ops.ones((batch_size, height, width, 3))
        self.labels = ops.zeros((batch_size, height, width, 4))

        backbone_config = {
            "input_conv_filters": [64],
            "input_conv_kernel_sizes": [7],
            "stackwise_num_filters": [64, 128, 256, 512],
            "stackwise_num_blocks": [3, 4, 6, 3],
            "stackwise_num_strides": [1, 2, 2, 2],
            "block_type": "bottleneck_block",
            "image_shape": (height, width, 3),
            "include_rescaling": False,
        }
        self.backbone = ResNetBackbone(**backbone_config)

        self.init_kwargs = dict(backbone=self.backbone)
        self.train_data = (self.images, self.labels)

    def test_basics(self):
        """Run the generic task test (currently always skipped)."""
        # The skip is unconditional, so the call below does not execute yet;
        # it is kept as the check to enable later.
        pytest.skip(
            reason="TODO: enable after preprocessor flow is figured out"
        )
        expected_shape = (2, 224, 224, 3)
        self.run_task_test(
            cls=DifferentialBinarization,
            init_kwargs=self.init_kwargs,
            train_data=self.train_data,
            expected_output_shape=expected_shape,
        )

    @pytest.mark.large
    def test_saved_model(self):
        """Run the generic model saving test on the dummy images."""
        self.run_model_saving_test(
            cls=DifferentialBinarization,
            init_kwargs=self.init_kwargs,
            input_data=self.images,
        )
Loading

0 comments on commit 3d06308

Please sign in to comment.