diff --git a/keras_fsl/losses/yolo_loss.py b/keras_fsl/losses/yolo_loss.py index ecda644..aee469f 100644 --- a/keras_fsl/losses/yolo_loss.py +++ b/keras_fsl/losses/yolo_loss.py @@ -3,7 +3,6 @@ def yolo_loss(anchors, threshold): """ - Args: anchors (pandas.DataFrame): dataframe of the anchors with width and height columns. threshold: @@ -14,9 +13,19 @@ def yolo_loss(anchors, threshold): def _yolo_loss(y_true, y_pred): """ y_true and y_pred are (batch_size, number of boxes, 4 (+ 1) + number of classes (+ anchor_id for y_pred)). - The number of boxes is determined by the network architecture as in single-shot detection one can only predict + The number of boxes is determined by the network architecture: as in single-shot detection one can only predict grid_width x grid_height boxes per anchor. + In all Yolo implementation, it is assumed that the input shape is squared. For simplicity we do the same here + @TODO: rewrite loss without assuming a squared input_shape """ + batch_size = tf.shape(y_pred) + anchor_feature_maps = [tf.reshape(y_pred[y_pred[..., -1] == i], (batch_size, -1, tf.shape(y_pred)[-1])) for i in anchors.index] + anchor_feature_maps = [tf.reshape(map, (batch_size, tf.math.sqrt(tf.shape(map)[1]), -1, )) for map in anchor_feature_maps] + y_pred = [ + tf. + ] + y_true = tf.reshape() + loss_coordinates = tf.Variable(0.0) loss_box = tf.Variable(0.0) loss_objectness = tf.Variable(0.0)