Skip to content

Commit

Permalink
feat(extract_labels): add support for finding labels from bytes
Browse files Browse the repository at this point in the history
  • Loading branch information
stakach committed May 25, 2023
1 parent 14f64a0 commit e58745c
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 9 deletions.
2 changes: 1 addition & 1 deletion shard.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: tensorflow_lite
version: 1.6.2
version: 1.6.3

development_dependencies:
ameba:
Expand Down
13 changes: 10 additions & 3 deletions src/tensorflow_lite/client.cr
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,23 @@ class TensorflowLite::Client
include Indexable(Tensor)

# Configures the tensorflow interpreter with the options provided
def initialize(model : URI| Bytes | Path | Model | String, delegate : Delegate? = nil, threads : Int? = nil, labels : URI | Hash(Int32, String)? = nil, &on_error : String -> Nil)
def initialize(model : URI | Bytes | Path | Model | String, delegate : Delegate? = nil, threads : Int? = nil, labels : URI | Hash(Int32, String)? = nil, &on_error : String -> Nil)
@labels_fetched = !!@labels
@model = case model
in String, Path
path = Path.new(model)
@model_path = path
Model.new(path)
in Bytes
@model_bytes = model
Model.new(model)
in Model
model
in URI
HTTP::Client.get(model) do |response|
raise "model download failed with #{response.status} (#{response.status_code}) while fetching #{model}" unless response.success?
Model.new response.body_io.getb_to_end
raise "model download failed with #{response.status} (#{response.status_code}) while fetching #{model}" unless response.success?
@model_bytes = model_bytes = response.body_io.getb_to_end
Model.new model_bytes
end
end

Expand Down Expand Up @@ -102,6 +104,7 @@ class TensorflowLite::Client

getter labels_fetched : Bool
@labels : Hash(Int32, String)?
@model_bytes : Bytes? = nil

# attempt to extract any labels in the model
def labels
Expand All @@ -110,6 +113,10 @@ class TensorflowLite::Client
elsif path = @model_path
@labels_fetched = true
@labels = Utilities::ExtractLabels.from(path)
elsif bytes = @model_bytes
@labels_fetched = true
@model_bytes = nil
@labels = Utilities::ExtractLabels.from(bytes)
end
end
end
12 changes: 7 additions & 5 deletions src/tensorflow_lite/utilities/extract_labels.cr
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@ module TensorflowLite::Utilities::ExtractLabels
MAGIC_ZIP = Bytes[0x50, 0x4b, 0x03, 0x04]

# extracts the label names from tensorflow lite model at the path specified
def self.from(input : Path, metadata_file : String = ".txt") : Hash(Int32, String)?
def self.from(input : Path | Bytes, metadata_file : String = ".txt") : Hash(Int32, String)?
# TODO:: we should update this to search the file more optimally
# and work more memory effciently
file = File.new input
bytes = Bytes.new file.size
file.read_fully bytes
file.close
bytes = case input
in Path
File.open(input, &.getb_to_end)
in Bytes
input
end

io = IO::Memory.new(bytes)
found = 0
Expand Down

0 comments on commit e58745c

Please sign in to comment.