From e58745cf556f705868c1ad1eb60dee0654352164 Mon Sep 17 00:00:00 2001 From: Stephen von Takach Date: Fri, 26 May 2023 01:09:29 +1000 Subject: [PATCH] feat(extract_labels): add support for finding labels from bytes --- shard.yml | 2 +- src/tensorflow_lite/client.cr | 13 ++++++++++--- src/tensorflow_lite/utilities/extract_labels.cr | 12 +++++++----- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/shard.yml b/shard.yml index a6d946b..d9a5c8c 100644 --- a/shard.yml +++ b/shard.yml @@ -1,5 +1,5 @@ name: tensorflow_lite -version: 1.6.2 +version: 1.6.3 development_dependencies: ameba: diff --git a/src/tensorflow_lite/client.cr b/src/tensorflow_lite/client.cr index 80ceff3..43359f8 100644 --- a/src/tensorflow_lite/client.cr +++ b/src/tensorflow_lite/client.cr @@ -8,7 +8,7 @@ class TensorflowLite::Client include Indexable(Tensor) # Configures the tensorflow interpreter with the options provided - def initialize(model : URI| Bytes | Path | Model | String, delegate : Delegate? = nil, threads : Int? = nil, labels : URI | Hash(Int32, String)? = nil, &on_error : String -> Nil) + def initialize(model : URI | Bytes | Path | Model | String, delegate : Delegate? = nil, threads : Int? = nil, labels : URI | Hash(Int32, String)? = nil, &on_error : String -> Nil) @labels_fetched = !!@labels @model = case model in String, Path @@ -16,13 +16,15 @@ class TensorflowLite::Client @model_path = path Model.new(path) in Bytes + @model_bytes = model Model.new(model) in Model model in URI HTTP::Client.get(model) do |response| - raise "model download failed with #{response.status} (#{response.status_code}) while fetching #{model}" unless response.success? - Model.new response.body_io.getb_to_end + raise "model download failed with #{response.status} (#{response.status_code}) while fetching #{model}" unless response.success? + @model_bytes = model_bytes = response.body_io.getb_to_end + Model.new model_bytes end end @@ -102,6 +104,7 @@ class TensorflowLite::Client getter labels_fetched : Bool @labels : Hash(Int32, String)? + @model_bytes : Bytes? = nil # attempt to extract any labels in the model def labels @@ -110,6 +113,10 @@ class TensorflowLite::Client elsif path = @model_path @labels_fetched = true @labels = Utilities::ExtractLabels.from(path) + elsif bytes = @model_bytes + @labels_fetched = true + @model_bytes = nil + @labels = Utilities::ExtractLabels.from(bytes) end end end diff --git a/src/tensorflow_lite/utilities/extract_labels.cr b/src/tensorflow_lite/utilities/extract_labels.cr index a2971fc..7b33f7e 100644 --- a/src/tensorflow_lite/utilities/extract_labels.cr +++ b/src/tensorflow_lite/utilities/extract_labels.cr @@ -8,13 +8,15 @@ module TensorflowLite::Utilities::ExtractLabels MAGIC_ZIP = Bytes[0x50, 0x4b, 0x03, 0x04] # extracts the label names from tensorflow lite model at the path specified - def self.from(input : Path, metadata_file : String = ".txt") : Hash(Int32, String)? + def self.from(input : Path | Bytes, metadata_file : String = ".txt") : Hash(Int32, String)? # TODO:: we should update this to search the file more optimally # and work more memory effciently - file = File.new input - bytes = Bytes.new file.size - file.read_fully bytes - file.close + bytes = case input + in Path + File.open(input, &.getb_to_end) + in Bytes + input + end io = IO::Memory.new(bytes) found = 0