From 6a4b5a27b11191dec8af41dbaa594394b2ea9810 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Walter?=
Date: Fri, 21 Feb 2020 16:07:06 +0100
Subject: [PATCH 1/5] Style fix Darknet

---
 keras_fsl/models/encoders/darknet.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/keras_fsl/models/encoders/darknet.py b/keras_fsl/models/encoders/darknet.py
index 3924e32..3aa3fa0 100644
--- a/keras_fsl/models/encoders/darknet.py
+++ b/keras_fsl/models/encoders/darknet.py
@@ -10,12 +10,13 @@ def conv_2d(*args, **kwargs):
     return Conv2D(*args, **kwargs, kernel_regularizer=l2(5e-4), padding="valid" if kwargs.get("strides") == (2, 2) else "same")
 
 
+@wraps(Conv2D)
 def conv_block(*args, **kwargs):
-    layer = Sequential()
-    layer.add(conv_2d(*args, **kwargs, use_bias=False))
-    layer.add(BatchNormalization())
-    layer.add(LeakyReLU(alpha=0.1))
-    return layer
+    return Sequential([
+        conv_2d(*args, **kwargs, use_bias=False),
+        BatchNormalization(),
+        LeakyReLU(alpha=0.1),
+    ])
 
 
 def residual_block(input_shape, num_filters, num_blocks):
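
Thanks to the @wraps(Conv2D) decorator, the refactored builder advertises Conv2D's signature; a minimal sketch of how it is meant to be called (illustrative, not part of the patch):

    from keras_fsl.models.encoders.darknet import conv_block

    # conv -> batch norm -> leaky relu; use_bias is forced to False by the builder,
    # and conv_2d switches to "valid" padding for strided convolutions
    block = conv_block(32, (3, 3), strides=(2, 2))
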
From 43dbb5661a8c44988b24fafffa0db990c2e0f896 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Walter?=
Date: Fri, 21 Feb 2020 16:07:23 +0100
Subject: [PATCH 2/5] Add FeaturePyramidNet model builder

---
 keras_fsl/models/feature_pyramid_net.py | 116 ++++++++++++++++++++++++
 1 file changed, 116 insertions(+)
 create mode 100644 keras_fsl/models/feature_pyramid_net.py

diff --git a/keras_fsl/models/feature_pyramid_net.py b/keras_fsl/models/feature_pyramid_net.py
new file mode 100644
index 0000000..380e378
--- /dev/null
+++ b/keras_fsl/models/feature_pyramid_net.py
@@ -0,0 +1,116 @@
+from functools import wraps
+
+import pandas as pd
+from tensorflow.keras import Model, Sequential
+from tensorflow.keras.layers import Conv2D, BatchNormalization, UpSampling2D, ReLU, Concatenate
+
+from keras_fsl.models import branch_models
+
+
+ANCHORS = pd.DataFrame([
+    [0, 116, 90],
+    [0, 156, 198],
+    [0, 373, 326],
+    [1, 30, 61],
+    [1, 62, 45],
+    [1, 59, 119],
+    [2, 10, 13],
+    [2, 16, 30],
+    [2, 33, 23],
+], columns=['scale', 'width', 'height'])
+
+
+@wraps(Conv2D)
+def conv_block(*args, **kwargs):
+    return Sequential([
+        Conv2D(*args, **kwargs, use_bias=False),
+        BatchNormalization(),
+        ReLU(),
+    ])
+
+
+def bottleneck(filters, *args, **kwargs):
+    return Sequential([
+        conv_block(filters // 4, (1, 1), padding='same'),
+        conv_block(filters, (3, 3), padding='same'),
+    ], *args, **kwargs)
+
+
+def up_sampling_block(filters, *args, **kwargs):
+    return Sequential([
+        conv_block(filters, (1, 1), padding='same'),
+        UpSampling2D(2),
+    ], *args, **kwargs)
+
+
+def FeaturePyramidNet(
+    backbone='MobileNet',
+    *args,
+    feature_maps=3,
+    objectness=True,
+    anchors=None,
+    classes=None,
+    weights=None,
+    **kwargs,
+):
+    """
+    Multi-scale feature extractor following the [Feature Pyramid Network for Object Detection](https://arxiv.org/pdf/1612.03144.pdf)
+    framework.
+
+    It analyses the given backbone architecture to extract the feature maps at relevant positions (the last position before each downsampling).
+    Then it builds a model with as many feature maps (outputs) as requested, starting from the deepest.
+
+    Args:
+        backbone (Union[str, dict, tensorflow.keras.Model]): parameters of the feature extractor
+        feature_maps (int): number of feature maps to extract from the backbone.
+        objectness (bool): whether to add a score for object presence probability (similar to adding a background class, see Yolo
+            for instance).
+        anchors (pandas.DataFrame): containing scale, width and height columns. The scale column is used to select the corresponding
+            feature map: 0 for the smallest resolution, 1 for the next one, etc.
+        classes (pandas.Series): provide classes to build a single-shot detector from the anchors and the feature maps.
+        weights (Union[str, pathlib.Path]): path to the weights file to load with tensorflow.keras.Model.load_weights
+    """
+    if not isinstance(backbone, Model):
+        if isinstance(backbone, str):
+            backbone = {'name': backbone, 'init': {'include_top': False, 'input_shape': (416, 416, 3)}}
+        backbone_name = backbone['name']
+        backbone = getattr(branch_models, backbone_name)(**backbone.get('init', {}))
+
+    output_shapes = (
+        pd.DataFrame([
+            layer.input_shape[0]
+            if isinstance(layer.input_shape, list)
+            else layer.output_shape
+            for layer in backbone.layers
+        ], columns=['batch_size', 'width', 'height', 'channels'])
+        .loc[lambda df: df.width.iloc[0] % df.width == 0]
+        .drop_duplicates(['width', 'height'], keep='last')
+        .sort_index(ascending=False)
+    )
+
+    outputs = []
+    for output_shape in output_shapes.iloc[:feature_maps].itertuples():
+        input_ = backbone.layers[output_shape.Index].output
+        if outputs:
+            pyramid_input = up_sampling_block(output_shape.channels, name=f'up_sampling_{output_shape.channels}')(outputs[-1])
+            input_ = Concatenate()([input_, pyramid_input])
+        outputs += [bottleneck(output_shape.channels, name=f'bottleneck_{output_shape.channels}')(input_)]
+
+    if classes is not None:
+        if anchors is None:
+            anchors = ANCHORS.copy()
+        anchors = anchors.assign(id=lambda df: 'scale_' + df.scale.astype(str) + '_' + df.width.astype(str) + 'x' + df.height.astype(str))
+        outputs = [
+            Concatenate(axis=3)(
+                [Conv2D(4, (1, 1), name=f'{anchor.id}_box')(outputs[anchor.scale])] +
+                ([Conv2D(1, (1, 1), name=f'{anchor.id}_objectness')(outputs[anchor.scale])] if objectness else []) +
+                [Conv2D(1, (1, 1), name=f'{anchor.id}_{label}')(outputs[anchor.scale]) for label in classes],
+            )
+            for anchor in anchors.itertuples()
+        ]
+
+    model = Model(backbone.input, outputs, *args, **kwargs)
+    if weights is not None:
+        model.load_weights(weights)
+
+    return model
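
A sketch of the intended calls at this point in the series, assuming a MobileNet builder is exposed in keras_fsl.models.branch_models (illustrative, not part of the patch):

    import pandas as pd
    from keras_fsl.models.feature_pyramid_net import FeaturePyramidNet

    # plain multi-scale feature extractor: one output per pyramid level
    fpn = FeaturePyramidNet("MobileNet", feature_maps=3)

    # single-shot detector: default anchors, one (4 + 1 + 2)-channel head per anchor
    detector = FeaturePyramidNet("MobileNet", classes=pd.Series(["cat", "dog"]))
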
From 8c8dc11b6cc98a3423be679fd028bef07b5df93c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Walter?=
Date: Tue, 3 Mar 2020 09:06:02 +0100
Subject: [PATCH 3/5] WIP yolo loss

---
 keras_fsl/losses/yolo_loss.py                 |  48 ++++++++
 keras_fsl/models/__init__.py                  |   3 +-
 keras_fsl/models/activations/__init__.py      |   7 ++
 keras_fsl/models/activations/yolo_box.py      |  27 ++++
 .../models/activations/yolo_coordinates.py    |  38 ++++++
 keras_fsl/models/encoders/darknet.py          |   6 +-
 keras_fsl/models/feature_pyramid_net.py       | 125 +++++++++++-------
 keras_fsl/models/layers/__init__.py           |   3 +
 keras_fsl/utils/training.py                   |   2 +
 9 files changed, 205 insertions(+), 54 deletions(-)
 create mode 100644 keras_fsl/losses/yolo_loss.py
 create mode 100644 keras_fsl/models/activations/__init__.py
 create mode 100644 keras_fsl/models/activations/yolo_box.py
 create mode 100644 keras_fsl/models/activations/yolo_coordinates.py
 create mode 100644 keras_fsl/models/layers/__init__.py

diff --git a/keras_fsl/losses/yolo_loss.py b/keras_fsl/losses/yolo_loss.py
new file mode 100644
index 0000000..b311653
--- /dev/null
+++ b/keras_fsl/losses/yolo_loss.py
@@ -0,0 +1,48 @@
+import tensorflow as tf
+
+
+def yolo_loss(anchors, threshold):
+    """
+    Build a YOLO-style loss for the given anchors (WIP).
+
+    Args:
+        anchors (pandas.DataFrame): dataframe of the anchors with width and height columns.
+        threshold (float): IoU threshold below which a prediction is counted as background in the confidence loss.
+
+    Returns:
+        a keras-compatible loss function with signature (y_true, y_pred)
+    """
+
+    def _yolo_loss(y_true, y_pred):
+        """
+        y_true and y_pred are (batch_size, number of boxes, 4 (+ 1) + number of classes (+ anchor_id for y_pred)).
+        The number of boxes is determined by the network architecture as in single-shot detection one can only predict
+        grid_width x grid_height boxes per anchor.
+        """
+        # 1. Find matching anchors: the anchor with the best IoU is chosen for predicting each true box.
+        # Boxes are compared as if they shared the same center
+        y_true_broadcast = tf.expand_dims(y_true, axis=2)
+        anchors_tensor = tf.cast(tf.broadcast_to(anchors[["height", "width"]].values, [1, 1, len(anchors), 2]), y_true.dtype)
+        height_width_min = tf.minimum(y_true_broadcast[..., 2:4], anchors_tensor)
+        intersection = tf.reduce_prod(height_width_min, axis=-1)
+        true_box_area = tf.reduce_prod(y_true_broadcast[..., 2:4], axis=-1)
+        anchor_boxes_area = tf.reduce_prod(anchors_tensor, axis=-1)
+        union = true_box_area + anchor_boxes_area - intersection
+        iou = intersection / union
+        best_anchor = tf.math.argmax(iou, axis=-1)
+
+        # 2. Find grid cell: for each selected anchor, select the prediction coming from the cell
+        # which contains the true box center
+        batch_size, boxes = tf.unstack(tf.shape(y_true)[:2])
+        # TODO: vectorize; python-level loops over tensors only run eagerly
+        for image in tf.range(batch_size):
+            for box in tf.range(boxes):
+                selected_anchor = tf.cast(best_anchor[image, box], y_pred.dtype)
+                prediction_for_anchor = tf.boolean_mask(y_pred[image], y_pred[image, :, -1] == selected_anchor, axis=0)
+                # TODO: select the cell containing the center of y_true[image, box] and add its localization loss
+
+        # 3. For confidence loss: for each selected anchor, compute confidence loss for boxes with IoU < threshold
+        non_empty_boxes_mask = tf.cast(tf.math.reduce_prod(y_true[..., 2:4], axis=-1) > 0, tf.bool)
+        # TODO: use non_empty_boxes_mask and threshold to build the confidence term and return the total loss
+
+    return _yolo_loss
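
The anchor-matching step above compares boxes as if they shared the same center, so the IoU only involves the height/width channels. A toy eager check with illustrative values:

    import pandas as pd
    import tensorflow as tf

    anchors = pd.DataFrame({"width": [0.1, 0.5], "height": [0.1, 0.5]})
    y_true = tf.constant([[[0.5, 0.5, 0.4, 0.6]]])  # (batch, boxes, y/x/h/w), normalized
    height_width = tf.expand_dims(y_true, 2)[..., 2:4]
    anchors_tensor = tf.cast(tf.broadcast_to(anchors[["height", "width"]].values, [1, 1, 2, 2]), height_width.dtype)
    intersection = tf.reduce_prod(tf.minimum(height_width, anchors_tensor), axis=-1)
    union = tf.reduce_prod(height_width, -1) + tf.reduce_prod(anchors_tensor, -1) - intersection
    print(tf.argmax(intersection / union, -1))  # [[1]]: the 0.5 x 0.5 anchor is the best match
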
diff --git a/keras_fsl/models/__init__.py b/keras_fsl/models/__init__.py
index bba21f2..65b244d 100644
--- a/keras_fsl/models/__init__.py
+++ b/keras_fsl/models/__init__.py
@@ -1,3 +1,4 @@
+from .feature_pyramid_net import FeaturePyramidNet
 from .siamese_nets import SiameseNets
 
-__all__ = ["SiameseNets"]
+__all__ = ["FeaturePyramidNet", "SiameseNets"]
diff --git a/keras_fsl/models/activations/__init__.py b/keras_fsl/models/activations/__init__.py
new file mode 100644
index 0000000..c2199f8
--- /dev/null
+++ b/keras_fsl/models/activations/__init__.py
@@ -0,0 +1,7 @@
+from .yolo_box import YoloBox
+from .yolo_coordinates import YoloCoordinates
+
+__all__ = [
+    "YoloBox",
+    "YoloCoordinates",
+]
diff --git a/keras_fsl/models/activations/yolo_box.py b/keras_fsl/models/activations/yolo_box.py
new file mode 100644
index 0000000..7976130
--- /dev/null
+++ b/keras_fsl/models/activations/yolo_box.py
@@ -0,0 +1,27 @@
+"""
+Activation function for mapping features into output coordinates as in Yolo V3
+"""
+import tensorflow as tf
+from tensorflow.keras.models import Sequential
+from tensorflow.keras.layers import Activation, Lambda
+
+
+def YoloBox(anchor):
+    """
+    Activation function for the box dimension regression. Dimensions are relative to the image dimension, i.e. between 0
+    and 1
+
+    Args:
+        anchor (Union[pandas.Series, collections.namedtuple]): with keys width and height. Note that given a tensor with shape
+            (batch_size, i, j, channels), i is related to height and j to width
+    """
+    return Sequential(
+        [
+            Activation("exponential"),
+            Lambda(
+                lambda input_, anchor_=anchor: (
+                    input_ * tf.convert_to_tensor([anchor_.height, anchor_.width], dtype=tf.float32)
+                )
+            ),
+        ]
+    )
diff --git a/keras_fsl/models/activations/yolo_coordinates.py b/keras_fsl/models/activations/yolo_coordinates.py
new file mode 100644
index 0000000..8325be9
--- /dev/null
+++ b/keras_fsl/models/activations/yolo_coordinates.py
@@ -0,0 +1,38 @@
+"""
+Activation function for mapping features into output coordinates as in Yolo V3
+"""
+import tensorflow as tf
+from tensorflow.keras.models import Sequential
+from tensorflow.keras.layers import Activation, Lambda
+
+
+@tf.function
+def build_grid_coordinates(grid_shape):
+    """
+    Build a grid coordinate tensor with shape (*grid_shape, 2) where grid[i, j, 0] = i and grid[i, j, 1] = j
+    Args:
+        grid_shape (Union[tuple, list, tensorflow.TensorShape]): height and width of the grid
+
+    Returns:
+        (tensorflow.Tensor)
+    """
+    height, width = tf.meshgrid(tf.range(0, grid_shape[0]), tf.range(0, grid_shape[1]))
+    width = tf.transpose(width)
+    height = tf.transpose(height)
+    return tf.stack([height, width], -1)
+
+
+def YoloCoordinates():
+    """
+    Activation function for the box center coordinates regression. Coordinates are relative to the image dimension, i.e.
+    between 0 and 1
+    """
+    return Sequential(
+        [
+            Activation("sigmoid"),
+            Lambda(
+                lambda input_: input_ + tf.cast(tf.expand_dims(build_grid_coordinates(tf.shape(input_)[1:3]), 0), input_.dtype)
+            ),
+            Lambda(lambda input_: input_ / tf.cast(tf.shape(input_)[1:3], input_.dtype)),
+        ]
+    )
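
For intuition, build_grid_coordinates can be checked eagerly; with grid_shape (2, 3) the two channels hold the row and column indices:

    from keras_fsl.models.activations.yolo_coordinates import build_grid_coordinates

    grid = build_grid_coordinates((2, 3))
    print(grid[..., 0].numpy())  # [[0 0 0], [1 1 1]]  (i, mapped to height)
    print(grid[..., 1].numpy())  # [[0 1 2], [0 1 2]]  (j, mapped to width)
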
diff --git a/keras_fsl/models/encoders/darknet.py b/keras_fsl/models/encoders/darknet.py
index 3aa3fa0..93401a6 100644
--- a/keras_fsl/models/encoders/darknet.py
+++ b/keras_fsl/models/encoders/darknet.py
@@ -12,11 +12,7 @@ def conv_2d(*args, **kwargs):
 
 @wraps(Conv2D)
 def conv_block(*args, **kwargs):
-    return Sequential([
-        conv_2d(*args, **kwargs, use_bias=False),
-        BatchNormalization(),
-        LeakyReLU(alpha=0.1),
-    ])
+    return Sequential([conv_2d(*args, **kwargs, use_bias=False), BatchNormalization(), LeakyReLU(alpha=0.1),])
 
 
 def residual_block(input_shape, num_filters, num_blocks):
diff --git a/keras_fsl/models/feature_pyramid_net.py b/keras_fsl/models/feature_pyramid_net.py
index 380e378..10b0383 100644
--- a/keras_fsl/models/feature_pyramid_net.py
+++ b/keras_fsl/models/feature_pyramid_net.py
@@ -1,56 +1,57 @@
 from functools import wraps
 
 import pandas as pd
+import tensorflow as tf
 from tensorflow.keras import Model, Sequential
-from tensorflow.keras.layers import Conv2D, BatchNormalization, UpSampling2D, ReLU, Concatenate
-
-from keras_fsl.models import branch_models
-
-
-ANCHORS = pd.DataFrame([
-    [0, 116, 90],
-    [0, 156, 198],
-    [0, 373, 326],
-    [1, 30, 61],
-    [1, 62, 45],
-    [1, 59, 119],
-    [2, 10, 13],
-    [2, 16, 30],
-    [2, 33, 23],
-], columns=['scale', 'width', 'height'])
+from tensorflow.keras.layers import Conv2D, BatchNormalization, UpSampling2D, ReLU, Concatenate, Reshape, Lambda
+
+from keras_fsl.models import branch_models, activations
+
+ANCHORS = pd.DataFrame(
+    [
+        [0, 116 / 416, 90 / 416],
+        [0, 156 / 416, 198 / 416],
+        [0, 373 / 416, 326 / 416],
+        [1, 30 / 416, 61 / 416],
+        [1, 62 / 416, 45 / 416],
+        [1, 59 / 416, 119 / 416],
+        [2, 10 / 416, 13 / 416],
+        [2, 16 / 416, 30 / 416],
+        [2, 33 / 416, 23 / 416],
+    ],
+    columns=["scale", "width", "height"],
+)
 
 
 @wraps(Conv2D)
 def conv_block(*args, **kwargs):
-    return Sequential([
-        Conv2D(*args, **kwargs, use_bias=False),
-        BatchNormalization(),
-        ReLU(),
-    ])
+    return Sequential([Conv2D(*args, **kwargs, use_bias=False), BatchNormalization(), ReLU(),])
 
 
 def bottleneck(filters, *args, **kwargs):
-    return Sequential([
-        conv_block(filters // 4, (1, 1), padding='same'),
-        conv_block(filters, (3, 3), padding='same'),
-    ], *args, **kwargs)
+    return Sequential(
+        [conv_block(filters // 4, (1, 1), padding="same"), conv_block(filters, (3, 3), padding="same"),], *args, **kwargs
+    )
 
 
 def up_sampling_block(filters, *args, **kwargs):
-    return Sequential([
-        conv_block(filters, (1, 1), padding='same'),
-        UpSampling2D(2),
-    ], *args, **kwargs)
+    return Sequential([conv_block(filters, (1, 1), padding="same"), UpSampling2D(2),], *args, **kwargs)
+
+
+def regression_block(activation, *args, **kwargs):
+    return Sequential([Conv2D(2, (1, 1)), getattr(activations, activation)(*args),], **kwargs)
 
 
 def FeaturePyramidNet(
-    backbone='MobileNet',
+    backbone="MobileNet",
     *args,
     feature_maps=3,
     objectness=True,
     anchors=None,
     classes=None,
     weights=None,
+    coordinates_activation="YoloCoordinates",
+    box_activation="YoloBox",
     **kwargs,
 ):
     """
@@ -60,6 +61,12 @@ def FeaturePyramidNet(
     It analyses the given backbone architecture to extract the feature maps at relevant positions (the last position before each downsampling).
     Then it builds a model with as many feature maps (outputs) as requested, starting from the deepest.
 
+    When classes is not None, it builds a single-shot detector from the features based on a given list of anchors. In this case, all
+    dimensions are relative to the image dimension: coordinates and box dimensions will be floats in [0, 1]. Hence anchors are defined with
+    floats for width and height. Each anchor should also specify which feature map it is based on: the current implementation counts backward,
+    with 0 meaning the smallest resolution, 1 the following one, etc. The output of the model is then a list of boxes for each image,
+    i.e. (batch_size, number of boxes, {coordinates, (objectness,) labels, anchor_id}).
+
     Args:
         backbone (Union[str, dict, tensorflow.keras.Model]): parameters of the feature extractor
         feature_maps (int): number of feature maps to extract from the backbone.
@@ -69,22 +76,22 @@ def FeaturePyramidNet(
             feature map: 0 for the smallest resolution, 1 for the next one, etc.
         classes (pandas.Series): provide classes to build a single-shot detector from the anchors and the feature maps.
         weights (Union[str, pathlib.Path]): path to the weights file to load with tensorflow.keras.Model.load_weights
+        coordinates_activation (str): activation function to be used for the center coordinates regression
+        box_activation (str): activation function to be used for the box height and width regression
     """
     if not isinstance(backbone, Model):
         if isinstance(backbone, str):
-            backbone = {'name': backbone, 'init': {'include_top': False, 'input_shape': (416, 416, 3)}}
-        backbone_name = backbone['name']
-        backbone = getattr(branch_models, backbone_name)(**backbone.get('init', {}))
+            backbone = {"name": backbone, "init": {"include_top": False, "input_shape": (416, 416, 3)}}
+        backbone_name = backbone["name"]
+        backbone = getattr(branch_models, backbone_name)(**backbone.get("init", {}))
 
     output_shapes = (
-        pd.DataFrame([
-            layer.input_shape[0]
-            if isinstance(layer.input_shape, list)
-            else layer.output_shape
-            for layer in backbone.layers
-        ], columns=['batch_size', 'width', 'height', 'channels'])
+        pd.DataFrame(
+            [layer.input_shape[0] if isinstance(layer.input_shape, list) else layer.output_shape for layer in backbone.layers],
+            columns=["batch_size", "height", "width", "channels"],
+        )
         .loc[lambda df: df.width.iloc[0] % df.width == 0]
-        .drop_duplicates(['width', 'height'], keep='last')
+        .drop_duplicates(["width", "height"], keep="last")
         .sort_index(ascending=False)
     )
@@ -92,22 +99,44 @@ def FeaturePyramidNet(
     outputs = []
     for output_shape in output_shapes.iloc[:feature_maps].itertuples():
         input_ = backbone.layers[output_shape.Index].output
         if outputs:
-            pyramid_input = up_sampling_block(output_shape.channels, name=f'up_sampling_{output_shape.channels}')(outputs[-1])
+            pyramid_input = up_sampling_block(output_shape.channels, name=f"up_sampling_{output_shape.channels}")(outputs[-1])
             input_ = Concatenate()([input_, pyramid_input])
-        outputs += [bottleneck(output_shape.channels, name=f'bottleneck_{output_shape.channels}')(input_)]
+        outputs += [bottleneck(output_shape.channels, name=f"bottleneck_{output_shape.channels}")(input_)]
 
     if classes is not None:
         if anchors is None:
-            anchors = ANCHORS.copy()
-        anchors = anchors.assign(id=lambda df: 'scale_' + df.scale.astype(str) + '_' + df.width.astype(str) + 'x' + df.height.astype(str))
+            anchors = ANCHORS.copy().round(3)
+        anchors = anchors.assign(
+            id=lambda df: "scale_" + df.scale.astype(str) + "_" + df.width.astype(str) + "x" + df.height.astype(str)
+        )
         outputs = [
-            Concatenate(axis=3)(
-                [Conv2D(4, (1, 1), name=f'{anchor.id}_box')(outputs[anchor.scale])] +
-                ([Conv2D(1, (1, 1), name=f'{anchor.id}_objectness')(outputs[anchor.scale])] if objectness else []) +
-                [Conv2D(1, (1, 1), name=f'{anchor.id}_{label}')(outputs[anchor.scale]) for label in classes],
+            Reshape((-1, 4 + int(objectness) + len(classes)))(
+                Concatenate(axis=3, name=f"anchor_{anchor.id}_output")(
+                    [regression_block(coordinates_activation, name=f"{anchor.id}_box_yx")(outputs[anchor.scale])]
+                    + [regression_block(box_activation, anchor, name=f"{anchor.id}_box_hw")(outputs[anchor.scale])]
+                    + (
+                        [Conv2D(1, (1, 1), name=f"{anchor.id}_objectness", activation="sigmoid")(outputs[anchor.scale])]
+                        if objectness
+                        else []
+                    )
+                    + [
+                        Conv2D(1, (1, 1), name=f"{anchor.id}_{label}", activation="sigmoid")(outputs[anchor.scale])
+                        for label in classes
+                    ]
+                )
             )
             for anchor in anchors.itertuples()
         ]
+        outputs = Concatenate(axis=1)(
+            [
+                Lambda(
+                    lambda output, index_=index: tf.concat(
+                        [output, tf.expand_dims(tf.ones(tf.shape(output)[:2], dtype=output.dtype) * index_, -1)], axis=-1
+                    )
+                )(outputs[index])
+                for index, anchor in anchors.iterrows()
+            ]
+        )
 
     model = Model(backbone.input, outputs, *args, **kwargs)
     if weights is not None:
diff --git a/keras_fsl/models/layers/__init__.py b/keras_fsl/models/layers/__init__.py
new file mode 100644
index 0000000..423a131
--- /dev/null
+++ b/keras_fsl/models/layers/__init__.py
@@ -0,0 +1,3 @@
+from .classification import Classification
+from .gram_matrix import GramMatrix
+from .slicing import CenterSlicing2D
diff --git a/keras_fsl/utils/training.py b/keras_fsl/utils/training.py
index d012d33..496327f 100644
--- a/keras_fsl/utils/training.py
+++ b/keras_fsl/utils/training.py
@@ -2,6 +2,8 @@
 from functools import reduce, wraps
 from unittest.mock import patch
 
+import tensorflow as tf
+
 
 def patch_len(fit_generator):
     """
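
As a sanity check of the advertised output layout: with the default 416 x 416 input, a backbone whose last three scales have strides 32, 16 and 8 yields 13 x 13, 26 x 26 and 52 x 52 grids, so the nine default anchors produce 3 * (13*13 + 26*26 + 52*52) = 10647 boxes; with objectness and two classes the channel count is 4 (coordinates) + 1 (objectness) + 2 (labels) + 1 (anchor_id) = 8. A sketch, assuming such a MobileNet-like backbone is available in branch_models:

    import pandas as pd
    from keras_fsl.models import FeaturePyramidNet

    detector = FeaturePyramidNet("MobileNet", classes=pd.Series(["cat", "dog"]))
    print(detector.output_shape)  # expected: (None, 10647, 8)
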
From 928fe442deeddb042852df7228a1ef595cb3cd56 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Walter?=
Date: Thu, 5 Mar 2020 14:41:47 +0100
Subject: [PATCH 4/5] WIP

---
 keras_fsl/models/activations/yolo_box.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/keras_fsl/models/activations/yolo_box.py b/keras_fsl/models/activations/yolo_box.py
index 7976130..0f34b5d 100644
--- a/keras_fsl/models/activations/yolo_box.py
+++ b/keras_fsl/models/activations/yolo_box.py
@@ -1,5 +1,5 @@
 """
-Activation function for mapping features into output coordinates as in Yolo V3
+Activation function for mapping features into output box dimensions as in Yolo V3
 """
 import tensorflow as tf
 from tensorflow.keras.models import Sequential
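
A toy check of the decoding implemented by YoloBox: since exp(0) = 1, an all-zero feature map regresses exactly the anchor dimensions. Any object with width and height attributes can stand in for the anchor row here (illustrative values):

    import tensorflow as tf
    from types import SimpleNamespace
    from keras_fsl.models.activations import YoloBox

    anchor = SimpleNamespace(width=0.2, height=0.1)
    print(YoloBox(anchor)(tf.zeros((1, 13, 13, 2))))  # every cell -> [0.1, 0.2], i.e. (height, width)
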
backbone["name"] - backbone = getattr(branch_models, backbone_name)(**backbone.get("init", {})) + backbone = getattr(encoders, backbone_name)(**backbone.get("init", {})) output_shapes = ( pd.DataFrame(