Skip to content

Commit

Permalink
Move sampler code to samplers.lua
Browse files Browse the repository at this point in the history
  • Loading branch information
achalddave committed May 6, 2017
1 parent 9b67d64 commit dda7cf9
Show file tree
Hide file tree
Showing 8 changed files with 564 additions and 552 deletions.
534 changes: 2 additions & 532 deletions data_loader.lua

Large diffs are not rendered by default.

9 changes: 5 additions & 4 deletions experimental/samplers.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ require 'classic.torch'

local data_loader = require 'data_loader'
local log = require 'util/log'
local samplers = require 'samplers'

local Sampler = data_loader.Sampler
local BalancedSampler = data_loader.BalancedSampler
local Sampler = samplers.Sampler
local BalancedSampler = samplers.BalancedSampler

-- TODO(achald): Implement greedy balancing sampler.
local GreedyBalancingSampler, GreedyBalancingSamplerSuper =
Expand Down Expand Up @@ -232,5 +233,5 @@ function MarkSeenBalancingSampler:_advance_label_index(label)
end
end

data_loader.GreedyBalancingSampler = GreedyBalancingSampler
data_loader.MarkSeenBalancingSampler = MarkSeenBalancingSampler
samplers.GreedyBalancingSampler = GreedyBalancingSampler
samplers.MarkSeenBalancingSampler = MarkSeenBalancingSampler
11 changes: 6 additions & 5 deletions main.lua
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ local data_loader = require 'data_loader'
local data_source = require 'data_source'
local experiment_saver = require 'util/experiment_saver'
local log = require 'util/log'
local samplers = require 'samplers'
local trainer = require 'trainer'
require 'last_step_criterion'
require 'layers/init'
Expand Down Expand Up @@ -329,9 +330,9 @@ end
log.info('Loaded model')

local sampling_strategies = {
permuted = data_loader.PermutedSampler,
balanced = data_loader.BalancedSampler,
sequential = data_loader.SequentialSampler
permuted = samplers.PermutedSampler,
balanced = samplers.BalancedSampler,
sequential = samplers.SequentialSampler
}

local train_source = data_source[config.data_source_class](
Expand Down Expand Up @@ -362,14 +363,14 @@ end

local val_sampler
if config.sampling_strategy:lower() == 'sequential' then
val_sampler = data_loader.SequentialSampler(
val_sampler = samplers.SequentialSampler(
val_source,
config.sequence_length,
config.step_size,
config.use_boundary_frames,
config.sampling_strategy_options)
else
val_sampler = data_loader.PermutedSampler(
val_sampler = samplers.PermutedSampler(
val_source,
config.sequence_length,
config.step_size,
Expand Down
Loading

0 comments on commit dda7cf9

Please sign in to comment.