-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinferOperators.m
155 lines (138 loc) · 4.48 KB
/
inferOperators.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
function [operators] = inferOperators(X, U, Vr, params, rhs)
% Infers linear, quadratic, bilinear, input, and constant matrix operators
% for state data in X projected onto basis in Vr and optional input data U.
%
% INPUTS
% X N-by-K full state data matrix
% U K-by-m input data matrix
% Vr N-by-r basis in which to learn reduced model
% params struct with operator inference parameters - see PARAMS below
% rhs N-by-K optional user-specified RHS for least-squares solve to be used
% e.g., if user has Xdot data in the continuous time setting or
% if data in X is non-uniformly spaced in time
%
% PARAMETERS - below are the possible fields of the input struct 'params'
% modelform string indicating which terms to learn: e.g. 'LI' corresponds
% to linear model with input; dxdt = A*x + B*u(t)
% Options: 'L'inear, 'Q'uadratic, 'B'ilinear, 'I'nput, 'C'onstant
% Order of inputs does not matter, e.g. 'LQI' vs 'QIL'
% modeltime 'continuous' e.g., if model form is dxdt = A*x OR
% 'discrete' e.g., if model form is x(k+1) = A*x(k)
% (required only when rhs is not supplied)
% dt timestep used to calculate state time deriv for
% continuous-time models (use 0 if not needed)
% lambda L2 penalty weighting (default 0)
% scale if true, scale data matrix to within [-1,1] before LS solve
% (default false)
% ddt_order passed to ddt.m to determine scheme used to calculate
% derivative, default is first order forward difference
%
% OUTPUT
% operators struct with inferred operators A, F, H, N, B, C. Terms that
% are not part of the model are returned as empty matrices
% (H is whatever F2H returns for the extracted F).
%
% ERRORS
% Raises an error if rhs is omitted and params.modeltime is missing or is
% neither 'discrete' nor 'continuous', or if a continuous-time derivative
% must be computed but params.dt is missing.
%
% AUTHOR (CODE)
% Elizabeth Qian ([email protected]) 11 July 2019
%
% PUBLICATIONS
% Peherstorfer, B. and Willcox, K., "Data-driven operator inference for
% non-intrusive projection-based model reduction." Computer Methods in
% Applied Mechanics and Engineering, 306:196-215, 2016.
%
% Qian, E., Kramer, B., Marques, A. and Willcox, K., "Transform & Learn:
% A data-driven approach to nonlinear model reduction." In AIAA Aviation
% 2019 Forum, June 17-21, Dallas, TX.

    rhsGiven = nargin >= 5;

    % --- validate required parameters -----------------------------------
    if ~rhsGiven && ~isfield(params,'modeltime')
        error('Discrete vs continuous not specified and no RHS provided for LS solve.')
    end
    % dt is only needed when the time derivative has to be computed here,
    % i.e. continuous-time model and no user-supplied rhs. The check is on
    % modeltime (not modelform); short-circuit && guarantees modeltime
    % exists before it is read.
    if ~rhsGiven && strcmp(params.modeltime,'continuous') && ~isfield(params,'dt')
        error('No dXdt data provided and no timestep provided in params with which to calculate dXdt')
    end

    % --- fill in optional parameters ------------------------------------
    if ~isfield(params,'ddt_order')
        params.ddt_order = 1;
    end
    if ~isfield(params,'lambda')
        params.lambda = 0;
    end
    if ~isfield(params,'scale')
        params.scale = false;
    end

    % Number of columns the input operator B occupies in the LS solution.
    % The data matrix only contains input columns when 'I' is requested,
    % so B must be empty otherwise regardless of the size of U.
    if contains(params.modelform,'I')
        m = size(U,2);
    else
        m = 0;
    end

    % if no right-hand side for LS problem is provided, calculate the rhs
    % from state data based on specified model time
    if ~rhsGiven
        switch params.modeltime
            case 'discrete'
                % x(k+1) regressed on x(k): drop the last snapshot from
                % the data side, the first from the rhs side
                rhs = X(:,2:end)'*Vr;
                ind = 1:(size(X,2)-1);
            case 'continuous'
                % ddt (external helper) returns the derivative estimate and
                % the snapshot indices it corresponds to
                [Xdot,ind] = ddt(X,params.dt,params.ddt_order);
                rhs = Xdot'*Vr;
            otherwise
                error('Unrecognized params.modeltime ''%s''; expected ''discrete'' or ''continuous''.', ...
                    params.modeltime)
        end
    else
        % project the user-supplied rhs onto the basis; all snapshots used
        rhs = rhs'*Vr;
        ind = 1:size(X,2);
    end

    % get least-squares data matrix based on desired model form; columns
    % are ordered [linear (l), quadratic (s), bilinear (mr), input (m),
    % constant (c)]
    [D,l,c,s,mr] = getDataMatrix(X,Vr,U,ind,params.modelform);

    % scale data before LS solve if desired
    if params.scale
        scl = max(abs(D),[],1);
        % an all-zero column would give 0/0 = NaN; leave it unscaled
        scl(scl == 0) = 1;
    else
        scl = ones(1,size(D,2));
    end
    Dscl = D./scl;

    % Solve LS problem (tikhonov is an external helper) and transpose so
    % that each operator occupies a block of columns
    temp = tikhonov(rhs,Dscl,params.lambda)';
    temp = temp./scl; % un-scale solution

    % pull operators from result, walking through the column blocks
    col = 0;
    operators.A = temp(:,col+1:col+l);
    col = col + l;
    operators.F = temp(:,col+1:col+s);
    operators.H = F2H(operators.F);
    col = col + s;
    operators.N = temp(:,col+1:col+mr);
    col = col + mr;
    operators.B = temp(:,col+1:col+m);
    col = col + m;
    % range (not a single index) so that C is empty when c == 0
    operators.C = temp(:,col+1:col+c);
end
%% builds data matrix based on desired form of learned model
function [D,l,c,s,mr] = getDataMatrix(X,Vr,U,ind,modelform)
% Assembles the least-squares data matrix for the requested model form.
%
% INPUTS
% X full state data matrix (snapshots in columns)
% Vr basis; the reduced state is Vr'*X(:,ind)
% U input data matrix (one row per snapshot); only read when
% modelform contains 'I' or 'B'
% ind indices of the snapshots (columns of X / rows of U) to use
% modelform character flags selecting terms: 'L', 'Q', 'B', 'I', 'C'
%
% OUTPUTS
% D K-by-(l+s+mr+m+c) data matrix with column blocks ordered
% [linear, quadratic, bilinear, input, constant], K = length(ind)
% l number of linear columns (r if 'L' requested, else 0)
% c number of constant columns (1 if 'C' requested, else 0)
% s number of quadratic columns (r*(r+1)/2 if 'Q' requested, else 0)
% mr number of bilinear columns (size(U,2)*r if 'B' requested, else 0)

    K = length(ind);
    r = size(Vr,2);

    % reduced state, one snapshot per row; kept separately from the linear
    % block so the quadratic and bilinear terms do not depend on 'L'
    XhatT = (Vr'*X(:,ind))';

    % if rhs contains linear A*x term
    if contains(modelform,'L')
        Xlin = XhatT;
        l = r;
    else
        Xlin = [];
        l = 0;
    end

    % if rhs contains quadratic H*kron(x,x) term
    if contains(modelform,'Q')
        % get_x_sq is an external helper; s assumes it returns the
        % r*(r+1)/2 unique quadratic products per snapshot
        Xsq = get_x_sq(XhatT);
        s = r*(r+1)/2;
    else
        Xsq = [];
        s = 0;
    end

    % if rhs contains bilinear N*x*u term
    if contains(modelform,'B')
        % uses the inputs directly so that 'B' works with or without 'I'
        Ub = U(ind,:);
        m = size(Ub,2);
        mr = m*r;
        XU = zeros(K,mr);
        for i = 1:m
            % block i holds the reduced state scaled row-wise by input i
            XU(:,(i-1)*r+1:i*r) = XhatT.*Ub(:,i);
        end
    else
        XU = [];
        mr = 0;
    end

    % if rhs contains B*u(t) input term
    if contains(modelform,'I')
        U0 = U(ind,:);
    else
        U0 = [];
    end

    % if rhs contains constant term
    if contains(modelform,'C')
        Y = ones(K,1);
        c = 1;
    else
        Y = [];
        c = 0;
    end

    D = [Xlin, Xsq, XU, U0, Y];
end