Update mAP computation code
The main difference is that we were previously underestimating mAP by not
setting

    precision[j] = max(precision[j:])

Now, we use the code from
https://github.com/achalddave/average-precision
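
For illustration only (not part of the commit), here is a minimal plain-Lua
sketch of the difference described above. The 0/1 groundtruth list is made up
and assumed to be already sorted by descending prediction score; the old code
averaged the raw precision at each positive rank, while the new code first
applies the max step, which can only raise the result.

    -- Illustrative sketch only (not part of the commit): compare the old,
    -- uninterpolated AP with the new, interpolated AP on a toy ranked list.
    -- `labels` is a made-up 0/1 groundtruth, already sorted by descending
    -- prediction score.
    local labels = {1, 0, 0, 1, 1}

    local function toy_average_precision(labels, interpolate)
        local num_positives = 0
        for _, l in ipairs(labels) do num_positives = num_positives + l end

        -- Precision after each rank k: (# positives in top k) / k.
        local precisions, tp = {}, 0
        for k, l in ipairs(labels) do
            tp = tp + l
            precisions[k] = tp / k
        end

        -- The new code additionally sets precision[j] = max(precision[j:]).
        if interpolate then
            for k = #labels - 1, 1, -1 do
                precisions[k] = math.max(precisions[k], precisions[k + 1])
            end
        end

        -- Average the precision over the positive ranks. For a single ranked
        -- list with one item per rank, this equals the sum over recall-change
        -- points of (delta recall) * precision used in the new code.
        local ap = 0
        for k, l in ipairs(labels) do
            if l == 1 then ap = ap + precisions[k] end
        end
        return ap / num_positives
    end

    print(toy_average_precision(labels, false))  -- 0.7    (old behavior)
    print(toy_average_precision(labels, true))   -- ~0.733 (new behavior)
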
achalddave committed Mar 30, 2017
1 parent 0d10896 commit e8d5d38
Showing 1 changed file with 70 additions and 22 deletions.
92 changes: 70 additions & 22 deletions evaluator.lua
@@ -1,12 +1,72 @@
local classic = require 'classic'
local cudnn = require 'cudnn'
local cunn = require 'cunn'
local cutorch = require 'cutorch'
local torch = require 'torch'

local image_util = require 'util/image_util'
local log = require 'util/log'
local END_OF_SEQUENCE = require('data_loader').END_OF_SEQUENCE

local function compute_average_precision(predictions, groundtruth)
    --[[
    Compute average precision.

    From
    https://github.com/achalddave/average-precision/blob/e9edd7ef64f9d5f236cf2cf411627c234369eb72/lua/ap_torch.lua
    TODO(achald): Add average-precision as a submodule so it stays updated.

    Args:
        predictions ((num_samples) Tensor)
        groundtruth ((num_samples) Tensor): Contains 0/1 values.

    Returns:
        average_precision (num)
    ]]--
    predictions = predictions:float()
    groundtruth = groundtruth:byte()

    --[[
    Let P(k) be the precision at cut-off for item k. Then, we compute average
    precision for each label as

        \frac{ \sum_{k=1}^n (P(k) * is_positive(k)) }{ # of relevant documents }

    where is_positive(k) is 1 if the groundtruth labeled item k as positive.
    ]]--
    if not torch.any(groundtruth) then
        return 0
    end
    local _, sorted_indices = torch.sort(predictions, 1, true --[[descending]])

    local sorted_groundtruth = groundtruth:index(1, sorted_indices):float()

    local true_positives = torch.cumsum(sorted_groundtruth)
    local false_positives = torch.cumsum(1 - sorted_groundtruth)
    local num_positives = true_positives[-1]

    local precisions = torch.cdiv(
        true_positives,
        torch.cmax(true_positives + false_positives, 1e-16))
    local recalls = true_positives / num_positives

    -- Set precisions[i] = max(precisions[j] for j >= i).
    -- This is valid because, for j > i, recall[j] >= recall[i], so we can
    -- always use a lower threshold to get both the higher recall and the
    -- higher precision at j.
    for i = precisions:nElement() - 1, 1, -1 do
        precisions[i] = math.max(precisions[i], precisions[i+1])
    end

    -- Append end points of the precision-recall curve.
    local zero = torch.zeros(1):float()
    local one = torch.ones(1):float()
    precisions = torch.cat({zero, precisions, zero}, 1)
    recalls = torch.cat({zero, recalls, one}, 1)

    -- Find points where recall changes.
    local changes = torch.ne(recalls[{{2, -1}}], recalls[{{1, -2}}])
    local changes_plus_1 = torch.cat({torch.zeros(1):byte(), changes})
    changes = torch.cat({changes, torch.zeros(1):byte()})

    return torch.cmul((recalls[changes_plus_1] - recalls[changes]),
                      precisions[changes_plus_1]):sum()
end

-- TODO(achald): Move this to a separate util file.
function compute_mean_average_precision(predictions, groundtruth)
@@ -34,22 +94,10 @@ function compute_mean_average_precision(predictions, groundtruth)
    where is_positive(k) is 1 if the groundtruth labeled item k as positive.
    ]]--
    for label = 1, num_labels do
-       local label_groundtruth = groundtruth[{{}, label}]
-       if torch.any(label_groundtruth) then
+       if torch.any(groundtruth[{{}, label}]) then
            label_has_sample[label] = 1
-           local label_predictions = predictions[{{}, label}]
-           local _, sorted_indices = torch.sort(
-               label_predictions, 1, true --[[descending]])
-           local true_positives = 0
-           local average_precision = 0
-
-           local sorted_groundtruth = label_groundtruth:index(
-               1, sorted_indices):float()
-           local true_positives = torch.cumsum(sorted_groundtruth)
-           local num_guesses = torch.range(1, label_predictions:nElement())
-           local precisions = torch.cdiv(true_positives, num_guesses)
-           precisions = precisions[torch.eq(sorted_groundtruth, 1)]
-           average_precisions[label] = precisions:mean()
+           average_precisions[label] = compute_average_precision(
+               predictions[{{}, label}], groundtruth[{{}, label}])
        end
    end
    -- Return mean of average precisions for labels which had at least 1 sample
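
The tail of compute_average_precision (padding the precision-recall curve and
summing over the points where recall changes) can be spelled out on the same
toy example. This is a plain-Lua sketch of that logic with hand-computed
precision and recall values, not a call into evaluator.lua.

    -- Sketch of the recall-change summation at the end of
    -- compute_average_precision, on the toy example above. The interpolated
    -- precisions and recalls are written out by hand.
    local precisions = {1.0, 0.6, 0.6, 0.6, 0.6}
    local recalls    = {1/3, 1/3, 1/3, 2/3, 1.0}

    -- Append end points of the precision-recall curve.
    table.insert(precisions, 1, 0); table.insert(precisions, 0)
    table.insert(recalls, 1, 0);    table.insert(recalls, 1)

    -- Sum (delta recall) * precision over the points where recall changes.
    local ap = 0
    for i = 1, #recalls - 1 do
        if recalls[i + 1] ~= recalls[i] then
            ap = ap + (recalls[i + 1] - recalls[i]) * precisions[i + 1]
        end
    end
    print(ap)  -- ~0.733, matching the interpolated value above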
