-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathuser_value_functions.py
35 lines (25 loc) · 1.03 KB
/
user_value_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import numpy as np
class UserValueFunction:
    """Base class that scores a vector of objectives using fixed weights.

    Subclasses supply the actual scoring rule by overriding ``_function``;
    callers go through ``calculate``, which validates the input first.
    """

    def __init__(self, weights):
        """Validate and store the weights as a NumPy array.

        Args:
            weights: One-dimensional sequence of numbers whose sum equals
                1.0 after rounding to 7 decimal places.

        Raises:
            ValueError: If ``weights`` is not one-dimensional or does not
                add up to 1.0.
        """
        weight_array = np.array(weights)
        # A nested or scalar input can still sum to 1.0, but it would make
        # the later length comparison in ``calculate`` meaningless.
        if weight_array.ndim != 1:
            raise ValueError(f'weights must be one-dimensional, got shape '
                             f'{weight_array.shape}: {weights}')
        total = np.sum(weight_array)
        # Rounding absorbs floating-point noise such as 0.1 + 0.2 + 0.7.
        if np.round(total, 7) != 1.0:
            raise ValueError(f'weights do not add up to 1.0: {weights} (sum={total})')
        self.weights = weight_array

    def _function(self, objectives):
        """Score ``objectives``; must be overridden by subclasses."""
        raise NotImplementedError

    def calculate(self, objectives):
        """Check ``objectives`` against the weights and return the score.

        Args:
            objectives: One-dimensional sequence with exactly one entry per
                weight. It is passed to ``_function`` unchanged.

        Returns:
            Whatever the subclass's ``_function`` returns.

        Raises:
            ValueError: If ``objectives`` is not one-dimensional or its
                length differs from the number of weights.
            NotImplementedError: If the subclass did not override
                ``_function``.
        """
        shape = np.shape(objectives)
        # Checking the full shape (not just len) rejects 2-D inputs that
        # would otherwise broadcast against the weights and give a wrong
        # result, and turns scalars into a ValueError instead of TypeError.
        if len(shape) != 1:
            raise ValueError(f'objectives must be one-dimensional, got shape {shape}')
        if shape[0] != len(self.weights):
            raise ValueError(f'number of objectives ({shape[0]}) is different than the '
                             f'number of weights ({len(self.weights)})')
        return self._function(objectives)
class LinearUserValueFunction(UserValueFunction):
    """Value function that scores objectives by their weighted sum."""

    def _function(self, objectives):
        """Return the sum of the element-wise product of weights and objectives."""
        weighted = self.weights * objectives
        return np.sum(weighted)

    def __str__(self):
        """Return the display name of this value function."""
        label = 'Linear'
        return label
class ChebycheffUserValueFunction(UserValueFunction):
    """Value function that scores objectives by their largest weighted entry."""

    def _function(self, objectives):
        """Return the maximum of the element-wise product of weights and objectives."""
        weighted = self.weights * objectives
        return np.max(weighted)

    def __str__(self):
        """Return the display name of this value function."""
        label = 'Chebycheff'
        return label