Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
Clément Walter committed Mar 19, 2020
1 parent 1a49f2d commit bf063db
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions keras_fsl/losses/yolo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

def yolo_loss(anchors, threshold):
"""
Args:
anchors (pandas.DataFrame): dataframe of the anchors with width and height columns.
threshold:
Expand All @@ -14,9 +13,19 @@ def yolo_loss(anchors, threshold):
def _yolo_loss(y_true, y_pred):
"""
y_true and y_pred are (batch_size, number of boxes, 4 (+ 1) + number of classes (+ anchor_id for y_pred)).
The number of boxes is determined by the network architecture as in single-shot detection one can only predict
The number of boxes is determined by the network architecture: as in single-shot detection one can only predict
grid_width x grid_height boxes per anchor.
In all Yolo implementation, it is assumed that the input shape is squared. For simplicity we do the same here
@TODO: rewrite loss without assuming a squared input_shape
"""
batch_size = tf.shape(y_pred)
anchor_feature_maps = [tf.reshape(y_pred[y_pred[..., -1] == i], (batch_size, -1, tf.shape(y_pred)[-1])) for i in anchors.index]
anchor_feature_maps = [tf.reshape(map, (batch_size, tf.math.sqrt(tf.shape(map)[1]), -1, )) for map in anchor_feature_maps]
y_pred = [
tf.
]
y_true = tf.reshape()

loss_coordinates = tf.Variable(0.0)
loss_box = tf.Variable(0.0)
loss_objectness = tf.Variable(0.0)
Expand Down

0 comments on commit bf063db

Please sign in to comment.