-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathmanage_models.exs
117 lines (98 loc) · 3.51 KB
/
manage_models.exs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
defmodule Comparison.Models do
  @moduledoc """
  Manages downloading and loading the models when benchmarking them.
  It is inspired by the `App.Models` module in the Phoenix app.

  Every function receives a `model` value from which the fields `name`,
  `cache_path`, `load_featurizer`, `load_tokenizer` and
  `load_generation_config` are read.
  """
  require Logger

  @doc """
  Verifies that the model is cached locally and downloads it when it is not.

  When `force_download?` is truthy, `model.cache_path` is deleted recursively
  and the model is downloaded again.

  Otherwise the model is downloaded only when the `huggingface` directory inside
  `model.cache_path` does not exist or is empty.

  Download failures are logged and not raised. Returns `nil` when the cache is
  already populated, otherwise whatever the last download step returned.
  """
  def verify_and_download_model(model, force_download? \\ false) do
    if force_download? do
      # Delete any cached pre-existing model so that everything is fetched again.
      File.rm_rf!(model.cache_path)
      download_model(model)
    else
      model_location = Path.join(model.cache_path, "huggingface")

      # `and` short-circuits, so the directory is only listed when it exists.
      cached? = File.exists?(model_location) and File.ls!(model_location) != []

      if not cached? do
        download_model(model)
      end
    end
  end

  @doc """
  Builds the `Bumblebee` image-to-text serving for the given model.

  The result is meant to be run by `Nx`, like `Nx.Serving.run(serving, image)`.

  This assumes the model already exists locally, because everything is loaded
  with `offline: true`; a failed load raises a `MatchError`.

  The featurizer, tokenizer and generation config are all passed to
  `Bumblebee.Vision.image_to_text/5`, so `model.load_featurizer`,
  `model.load_tokenizer` and `model.load_generation_config` must all be truthy.
  Otherwise the corresponding key is missing and a `KeyError` is raised.
  """
  def serving(model) do
    loaded = load_offline_model_params(model)

    Bumblebee.Vision.image_to_text(
      loaded.model_info,
      loaded.featurizer,
      loaded.tokenizer,
      loaded.generation_config,
      compile: [batch_size: 10],
      defn_options: [compiler: EXLA],
      preallocate_params: true
    )
  end

  # Loads the model from the cache folder.
  # It loads the model and, when the matching `load_*` flag is set, the respective
  # featurizer, tokenizer and generation config.
  # Returns a map with `:model_info` plus one key per component that was loaded.
  defp load_offline_model_params(model) do
    Logger.info("ℹ️ Loading #{model.name}...")

    # `offline: true` makes Bumblebee rely on the files in the cache directory only.
    loading_settings = {:hf, model.name, cache_dir: model.cache_path, offline: true}
    {:ok, model_info} = Bumblebee.load_model(loading_settings)

    %{model_info: model_info}
    |> put_when_enabled(model.load_featurizer, :featurizer, fn ->
      Bumblebee.load_featurizer(loading_settings)
    end)
    |> put_when_enabled(model.load_tokenizer, :tokenizer, fn ->
      Bumblebee.load_tokenizer(loading_settings)
    end)
    |> put_when_enabled(model.load_generation_config, :generation_config, fn ->
      Bumblebee.load_generation_config(loading_settings)
    end)
  end

  # Runs `load_fun` and stores its `{:ok, value}` payload under `key` when `enabled?`
  # is truthy; returns `info` untouched otherwise. Anything other than `{:ok, _}`
  # raises a `MatchError`.
  defp put_when_enabled(info, enabled?, key, load_fun) do
    if enabled? do
      {:ok, component} = load_fun.()
      Map.put(info, key, component)
    else
      info
    end
  end

  # Downloads the model into `model.cache_path`.
  # It fetches the model and, when the matching `load_*` flag is set, the respective
  # featurizer, tokenizer and generation config.
  # Failures are logged instead of raised, so the remaining steps still run.
  # Returns the value of the last step (`nil` when the generation config is skipped).
  defp download_model(model) do
    Logger.info("ℹ️ Downloading #{model.name}...")

    downloading_settings = {:hf, model.name, cache_dir: model.cache_path}

    downloading_settings
    |> Bumblebee.load_model()
    |> log_download_failure(model, "model")

    if model.load_featurizer do
      downloading_settings
      |> Bumblebee.load_featurizer()
      |> log_download_failure(model, "featurizer")
    end

    if model.load_tokenizer do
      downloading_settings
      |> Bumblebee.load_tokenizer()
      |> log_download_failure(model, "tokenizer")
    end

    if model.load_generation_config do
      downloading_settings
      |> Bumblebee.load_generation_config()
      |> log_download_failure(model, "generation config")
    end
  end

  # Logs `{:error, reason}` results and hands every result back unchanged.
  defp log_download_failure({:error, reason} = result, model, component) do
    Logger.error("Failed to download the #{component} of #{model.name}: #{inspect(reason)}")
    result
  end

  defp log_download_failure(result, _model, _component), do: result
end