
Commit

Support RNN models
guoqingbao committed May 20, 2024
1 parent cb93d1b commit 8e7aa85
Showing 7 changed files with 811 additions and 93 deletions.
572 changes: 572 additions & 0 deletions examples/compilation_test_torch.ipynb

Large diffs are not rendered by default.

65 changes: 64 additions & 1 deletion examples/keras_def.py
@@ -786,4 +786,67 @@ def ResNet18(classes, input_shape, weight_decay=1e-4):
x = Flatten()(x)
x = Dense(classes, activation='softmax')(x)
model = Model(input, x, name='ResNet18')
return model

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer, RNN
class LSTMBlockCell(Layer):
def __init__(self, units, **kwargs):
self.units = units
self.state_size = [units, units]
super(LSTMBlockCell, self).__init__(**kwargs)

def build(self, input_shape):
self.w_xi = self.add_weight(name='w_xi',
shape=(input_shape[-1], self.units), initializer='uniform')
self.w_xf = self.add_weight(name='w_xf',
shape=(input_shape[-1], self.units), initializer='uniform')
self.w_xo = self.add_weight(name='w_xo',
shape=(input_shape[-1], self.units), initializer='uniform')
self.w_xc = self.add_weight(name='w_xc',
shape=(input_shape[-1], self.units), initializer='uniform')
self.w_hi = self.add_weight(name='w_hi',
shape=(self.units, self.units), initializer='uniform')
self.w_hf = self.add_weight(name='w_hf',
shape=(self.units, self.units), initializer='uniform')
self.w_ho = self.add_weight(name='w_ho',
shape=(self.units, self.units), initializer='uniform')
self.w_hc = self.add_weight(name='w_hc',
shape=(self.units, self.units), initializer='uniform')
self.b_i = self.add_weight(name='b_i',
shape=(1, self.units), initializer='zeros')
self.b_f = self.add_weight(name='b_f',
shape=(1, self.units), initializer='zeros')
self.b_o = self.add_weight(name='b_o',
shape=(1, self.units), initializer='zeros')
self.b_c = self.add_weight(name='b_c',
shape=(1, self.units), initializer='zeros')

self.built = True

def call(self, x, states):
h, c = states
i = K.sigmoid(K.dot(x, self.w_xi) + K.dot(h, self.w_hi) + self.b_i)
f = K.sigmoid(K.dot(x, self.w_xf) + K.dot(h, self.w_hf) + self.b_f)
o = K.sigmoid(K.dot(x, self.w_xo) + K.dot(h, self.w_ho) + self.b_o)

c_in = K.tanh(K.dot(x, self.w_xc) + K.dot(h, self.w_hc) + self.b_c)
c_n = f * c + i * c_in
h_n = o * K.tanh(c_n)

return h_n, [h_n, c_n]

def KerasLSTM(input_shape, seq_size, hidden_size):
h0 = np.zeros((seq_size, hidden_size), dtype=np.float32)
c0 = np.zeros((seq_size, hidden_size), dtype=np.float32)
state = (K.constant(h0), K.constant(c0))

input = Input(shape=input_shape)
x = input
for i in range(16):
x, state = LSTMBlockCell(hidden_size)(x, state)

model = Model(input, x, name='LSTM'+str(seq_size))
return model
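
RNN is imported above but not used directly by KerasLSTM. Because LSTMBlockCell follows the Keras cell contract (it defines state_size and its call(x, states) returns (output, new_states)), it can also be driven by keras.layers.RNN. A minimal sketch with illustrative sizes, not part of this commit:

import numpy as np
from tensorflow.keras.layers import Input, RNN
from tensorflow.keras.models import Model

# Feed a (timesteps, features) sequence through the custom cell via the RNN wrapper.
inp = Input(shape=(10, 8))                      # sizes are illustrative
out = RNN(LSTMBlockCell(16))(inp)               # last hidden state, shape (batch, 16)
rnn_model = Model(inp, out, name='LSTMBlockCellRNN')
print(rnn_model.predict(np.zeros((2, 10, 8), dtype=np.float32)).shape)  # (2, 16)
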
46 changes: 43 additions & 3 deletions examples/torch_def.py
@@ -1,9 +1,8 @@
import torch
import torch.nn as nn
from torch.onnx import TrainingMode
import math

# Sample PyTorch model definitions
# input: 32x32
class SimpleCNN1(nn.Module):
def __init__(self):
super().__init__()
@@ -113,4 +112,45 @@ def forward(self, x):
x = self.flat(self.batch_norm(self.pool(self.conv(x))))

x = self.linear(x)
return self.relu(x)

class SimpleLSTM(nn.Module):
def __init__(self, input_size, hidden_size, seq_size):
super().__init__()
self.input_sz = input_size
self.hidden_size = hidden_size
self.seq_sz = seq_size
self.W = nn.Parameter(torch.Tensor(hidden_size, hidden_size * 4))
self.U = nn.Parameter(torch.Tensor(hidden_size, hidden_size * 4))
self.bias = nn.Parameter(torch.Tensor(hidden_size * 4))
self.init_weights()

def init_weights(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
weight.data.uniform_(-stdv, stdv)

def forward(self, x,
h_t, c_t):
"""Assumes x is of shape (batch, sequence, feature)"""
bs, seq_sz, _ = x.size()
hidden_seq = []

HS = self.hidden_size
for t in range(self.seq_sz):
x_t = x[:, t, :]
# batch the computations into a single matrix multiplication
gates = x_t @ self.W + h_t @ self.U + self.bias
i_t, f_t, g_t, o_t = (
torch.sigmoid(gates[:, :HS]), # input
torch.sigmoid(gates[:, HS:HS*2]), # forget
torch.tanh(gates[:, HS*2:HS*3]),
torch.sigmoid(gates[:, HS*3:]), # output
)
c_t = f_t * c_t + i_t * g_t
h_t = o_t * torch.tanh(c_t)
hidden_seq.append(h_t.unsqueeze(0))
hidden_seq = torch.cat(hidden_seq, dim=0)
# reshape from shape (sequence, batch, feature) to (batch, sequence, feature)
hidden_seq = hidden_seq.transpose(0, 1).contiguous()
return hidden_seq, (h_t, c_t)
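
A quick shape check for SimpleLSTM (sizes are illustrative, not taken from the commit). Note that the gate weights W have shape (hidden_size, hidden_size * 4), so the feature dimension of x must equal hidden_size:

batch, seq, hidden = 2, 8, 32
lstm = SimpleLSTM(input_size=hidden, hidden_size=hidden, seq_size=seq)
x = torch.randn(batch, seq, hidden)
h0 = torch.zeros(batch, hidden)
c0 = torch.zeros(batch, hidden)
out, (h_n, c_n) = lstm(x, h0, c0)
print(out.shape)  # torch.Size([2, 8, 32]) -- (batch, sequence, hidden)
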
9 changes: 5 additions & 4 deletions python/ufront/keras/model.py
@@ -30,15 +30,15 @@
from ..utils import list_product, onnx_to_ufront_dtype, numpy_to_ufront_dtype, ufront_to_numpy_dtype

class BaseModel(object):
def __init__(self, inputs, onnx_model, batch_size, transformer, pass_weights):
def __init__(self, inputs, onnx_model, batch_size, simplify, transformer, pass_weights):
self.ufront_model = Model()
self.transformer=transformer
self._onnx_model = onnx_model
self.pass_weights=pass_weights
self._loss = None
self._metrics = []
self._label_type = DataType.Float
self._my_onnx_model = ONNXModelKeras(self._onnx_model, self.ufront_model, self.transformer, self.pass_weights)
self._my_onnx_model = ONNXModelKeras(self._onnx_model, self.ufront_model, simplify, self.transformer, self.pass_weights)
self._num_samples = 0
self._input_dataloaders = []
self._input_dataloaders_dim = []
@@ -293,7 +293,7 @@ def _train(self, epochs, callbacks, eval=False):

class UFrontKeras(tf_keras_Model):
def __init__(self, base_model, inputs,
batch_size, verbose=False, transformer=False, pass_weights=False):
batch_size, verbose=False, simplify=False, transformer=False, pass_weights=False):
super(UFrontKeras, self).__init__(inputs=base_model.inputs, outputs=base_model.output, name=base_model.name)
if (isinstance(inputs, list) == False):
assert 0, "Inputs must be in list format, e.g., [input_tensor1, input_tensor2]"
@@ -305,7 +305,8 @@ def __init__(self, base_model, inputs,
# onnx_model = onnx.load_model("resnet18.onnx")
# self._base_model = BaseModel(inputs=input_dict, onnx_model=onnx_model[0], batch_size=batch_size, transformer=transformer, pass_weights=pass_weights)
self.onnx_model = tf2onnx.convert.from_keras(self, opset=18 if transformer else 17)
self._base_model = BaseModel(inputs=input_dict, onnx_model=self.onnx_model[0], batch_size=batch_size, transformer=transformer, pass_weights=pass_weights)
# onnx.save_model(self.onnx_model[0], f="Keras_LSTM.onnx")
self._base_model = BaseModel(inputs=input_dict, onnx_model=self.onnx_model[0], batch_size=batch_size, simplify=simplify, transformer=transformer, pass_weights=pass_weights)


def umodel(self):
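The new simplify argument is threaded from UFrontKeras through BaseModel into ONNXModelKeras/ONNXModel, where it runs onnxsim on the exported ONNX graph (see the onnx/model.py diff below). A hedged usage sketch, assuming the definitions in examples/keras_def.py are importable and onnxsim is installed; the module path and sizes are inferred, not taken from the commit:

from keras_def import ResNet18               # example definitions shown above
from ufront.keras.model import UFrontKeras   # path inferred from this diff's file layout

keras_model = ResNet18(classes=10, input_shape=(32, 32, 3))
ufront_model = UFrontKeras(keras_model, inputs=[keras_model.input], batch_size=1,
                           simplify=True,    # ask ONNXModel to run onnxsim on the exported graph
                           transformer=False, pass_weights=False)
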
57 changes: 36 additions & 21 deletions python/ufront/onnx/model.py
@@ -69,8 +69,9 @@ def __init__(self, onnx_model, umodel=None, simplify = False, pass_weights=False
self.umodel.weight_type = WeightType.EXTERNAL if pass_weights else WeightType.INTERNAL

if type(onnx_model) == str:
model = onnx.load(onnx_model)
elif simplify:
onnx_model = onnx.load(onnx_model)

if simplify:
try:
# simplify ONNX models, e.g., merge sub-operators (for chunk) and remove redundant operators
onnx_model_, check = onnxsim.simplify(onnx_model)
@@ -83,19 +84,15 @@
except:
print("Some of the ONNX models requires onnxsim library!")

model = onnx_model
else:
model = onnx_model

self.inputs = {}
self.operators = []
for input in model.graph.input:
for input in onnx_model.graph.input:
tensor = ONNXTensor(input.name, input.type.tensor_type.shape.dim, 1)
self.inputs[input.name] = tensor
self.outputs = {}
for output in model.graph.output:
for output in onnx_model.graph.output:
self.outputs[output.name] = output
self.model = model
self.model = onnx_model
self.symbol_table = {}

def handleAdd(self, node, node_to_output):
@@ -225,6 +222,11 @@ def handleExpand(self, node, node_to_output):
operator = self.addTensor(input_tensor, False, node.input[0])
self.operators.append(operator)
input_tensor = operator.get_output(0)
elif type(input_tensor) != Tensor:
input_tensor = np.array([input_tensor]).astype(np.float32)
operator = self.addTensor(input_tensor, False, node.input[0])
self.operators.append(operator)
input_tensor = operator.get_output(0)

return self.umodel.expand(
input=input_tensor, sizes=output_shape, name=node.name,
@@ -233,15 +235,18 @@
def handleSplit(self, node, node_to_output):
input = node_to_output[node.input[0]]
attribute = {x.name: x for x in node.attribute}
if 'axis' in attribute:
axis = attribute['axis'].i
else:
axis = 0

if "split" in attribute:
split = list(attribute['split'].ints)
elif len(node.input) > 1:
split = node_to_output[node.input[1]]
if 'axis' in attribute:
axis = attribute['axis'].i
else:
axis = 0

if type(input) == list:
return input
return self.umodel.split(input=input, sizes=list(split), axis=axis, name=node.name)

def handleAveragePool(self, node, node_to_output):
@@ -546,7 +551,7 @@ def handleUnsqueeze(self, node, node_to_output):
input=input, shape=list(shape), name=node.name,
)

def unpack_rawdata(self, raw_data, data_type, shape, name):
def unpack_rawdata(self, raw_data, data_type, shape, name=None):
if len(shape) > 0:
length = list_product(shape)
else:
@@ -571,7 +576,7 @@ def unpack_rawdata(self, raw_data, data_type, shape, name):
output = output[0] #scalar
elif len(shape) > 1:
output = output.reshape(shape) # ndarray
if name =="class_token" or name =="encoder.pos_embedding":
if name != None and (name =="class_token" or name =="encoder.pos_embedding"):
weight_op = self.umodel.parameter(np_tensor=output, dtype=numpy_to_ufront_dtype(output.dtype), requires_grad=True, name=name)
output = weight_op.get_output(0)
return output
@@ -582,10 +587,11 @@ def handleConstant(self, node, node_to_output):
data_type = onnx_to_ufront_dtype(tensor.data_type)
raw_data = tensor.raw_data

if len(tensor.dims) > 1: #TODO set raw_data array to constant tensor
if len(tensor.dims) > 1:
np_tensor = self.unpack_rawdata(raw_data, data_type, tensor.dims)
output = self.umodel.create_tensor(tensor.dims, DataType.Float, True, "constant_tensor" + str(ONNXModel.const_tensor_idx))
output.set_ndarray(np_tensor)
output = self.addTensor(np_tensor, False, "constant_tensor" + str(ONNXModel.const_tensor_idx))
# output = self.umodel.create_tensor(tensor.dims, DataType.Float, True, "constant_tensor" + str(ONNXModel.const_tensor_idx))
# output.ndarray = np_tensor
ONNXModel.const_tensor_idx += 1
else:
output = self.unpack_rawdata(raw_data, data_type, tensor.dims)
@@ -741,6 +747,8 @@ def handleSlice(self, node, node_to_output):

elif type(input) != list and type(input) != tuple:
output_shape = input.shape
if ends > output_shape[axis]: # clamp ends that exceed the dim size
ends = output_shape[axis]
output_shape[axis] = ends - starts
if axis == -1:
axis = len(input.shape) - 1
@@ -1000,7 +1008,14 @@ def apply(self, input_tensors):
for name in outputs.keys():
if name in node_to_output.keys():
tensor_outputs.append(node_to_output[name])
return tensor_outputs if len(tensor_outputs) > 0 else node_to_output[next(reversed(node_to_output))]

tensor_outputs = tensor_outputs if len(tensor_outputs) > 0 else node_to_output[next(reversed(node_to_output))]
output_names = []
for tensor in tensor_outputs:
output_names.append(tensor.name)
if len(output_names) > 1:
self.umodel.output(outputs=output_names)
return tensor_outputs

def get_output_operator(self):
if len(self.operators) > 0:
@@ -1115,8 +1130,8 @@ def _fusion_layer(self, name):


class ONNXModelKeras(ONNXModel):
def __init__(self, onnx_model, umodel=None, transformer=False, pass_weights=False):
super(ONNXModelKeras, self).__init__(onnx_model=onnx_model, umodel=umodel, transformer=transformer, pass_weights=pass_weights)
def __init__(self, onnx_model, umodel=None, simplify=False, transformer=False, pass_weights=False):
super(ONNXModelKeras, self).__init__(onnx_model=onnx_model, umodel=umodel, simplify=simplify, transformer=transformer, pass_weights=pass_weights)
self.transformer=transformer
for node in onnx_model.graph.node:
if node.name.find("u_front_keras/") != -1:
