-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathsimpleobsws.py
332 lines (303 loc) · 13.7 KB
/
simpleobsws.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
# Module setup: logging configuration, imports and protocol constants.
import logging
# Raise the third-party `websockets` logger to INFO so its DEBUG records are
# suppressed even when this module's own logger is set to DEBUG.
wsLogger = logging.getLogger('websockets')
wsLogger.setLevel(logging.INFO)
# Logger for this module's own debug/warning messages.
log = logging.getLogger(__name__)
import asyncio
import websockets
import base64
import hashlib
import json
import msgpack
import uuid
import time
import inspect
import enum
from dataclasses import dataclass, field
# Used to count an event callback's parameters when dispatching events.
from inspect import signature
# RPC version sent as `rpcVersion` in the Identify (op 1) message.
RPC_VERSION = 1
class RequestBatchExecutionType(enum.Enum):
    """Execution mode for a RequestBatch (op 8) message.

    The member's ``value`` is what gets sent as the batch's
    ``executionType`` field.
    """

    # Declaration order matters for iteration; values are the wire integers.
    SerialRealtime = 0
    SerialFrame = 1
    Parallel = 2
@dataclass
class IdentificationParameters:
    """Optional fields for the Identify (op 1) message.

    Any field left as ``None`` is omitted from the message entirely.
    """

    # Sent as ``ignoreNonFatalRequestChecks`` when not None.
    ignoreNonFatalRequestChecks: bool = None
    # Sent as ``eventSubscriptions`` when not None.
    eventSubscriptions: int = None
@dataclass
class Request:
    """One obs-websocket request, sent on its own or as part of a batch.

    ``requestData`` becomes the request's ``requestData`` field when set.
    """

    requestType: str
    requestData: dict = None
    # The two fields below are only meaningful inside a request batch.
    inputVariables: dict = None
    outputVariables: dict = None
@dataclass
class RequestStatus:
    """Mirror of a response's ``requestStatus`` object.

    Defaults describe an unsuccessful request with code 0 and no comment.
    """

    result: bool = False
    code: int = 0
    comment: str = None
@dataclass
class RequestResponse:
    """Result of a single request, built from a response payload.

    ``requestStatus`` mirrors the payload's ``requestStatus`` object.
    ``responseData`` holds the payload's optional ``responseData`` field and
    is ``None`` when the response carried none.
    """

    requestType: str = ''
    requestStatus: RequestStatus = field(default_factory=RequestStatus)
    responseData: dict = None

    def has_data(self):
        """Return True if the response included a ``responseData`` object."""
        return self.responseData is not None

    def ok(self):
        """Return the request status ``result`` flag (True when the request succeeded)."""
        return self.requestStatus.result
@dataclass
class _ResponseWaiter:
event: asyncio.Event = field(default_factory=asyncio.Event)
response_data: dict = None
class MessageTimeout(Exception):
    """Raised when a request or request batch gets no response within its timeout."""
class EventRegistrationError(Exception):
    """Raised when an event callback cannot be registered (it must be an async function)."""
class NotIdentifiedError(Exception):
    """Raised when a request is attempted before the session has been identified."""
async def _wait_for_cond(cond, func):
async with cond:
await cond.wait_for(func)
class WebSocketClient:
    """Asyncio client for the obs-websocket protocol, speaking msgpack.

    :meth:`connect` opens the socket with the ``obswebsocket.msgpack``
    subprotocol and starts a background receive task. That task:

    - answers the server's Hello (op 0) with an Identify (op 1);
    - marks the session identified on Identified (op 2);
    - routes RequestResponse (op 7) and RequestBatchResponse (op 9)
      payloads to the coroutine waiting on them;
    - schedules registered callbacks for Event (op 5) payloads.
    """

    def __init__(self,
        url: str = "ws://localhost:4444",
        password: str = '',
        identification_parameters: IdentificationParameters = None
    ):
        """Store connection settings; nothing is opened until :meth:`connect`.

        :param url: WebSocket URL of the obs-websocket server.
        :param password: Password used to answer the authentication challenge
            if the server's Hello message contains one.
        :param identification_parameters: Optional fields for the Identify
            message. When omitted or ``None``, each client gets its own fresh
            ``IdentificationParameters()``. A default instance is never shared
            between clients.
        """
        self.url = url
        self.password = password
        if identification_parameters is None:
            identification_parameters = IdentificationParameters()
        self.identification_parameters = identification_parameters
        # Extra HTTP headers passed to websockets.connect for the handshake.
        self.http_headers = {}
        self.ws = None
        self.ws_open = False
        # requestId -> _ResponseWaiter for requests still awaiting a response.
        self.waiters = {}
        self.identified = False
        self.recv_task = None
        self.hello_message = None
        # (callback, event type or None) pairs, in registration order.
        self.event_callbacks = []
        # Notified when the Identified message arrives (see wait_until_identified).
        self.cond = asyncio.Condition()
        # Strong references to running event-callback tasks. The event loop
        # only keeps weak references, so an unreferenced task can be
        # garbage-collected before it finishes.
        self._callback_tasks = set()

    # Todo: remove bool return, raise error if already open
    async def connect(self):
        """Open the WebSocket connection and start the receive task.

        Identification completes asynchronously; await
        :meth:`wait_until_identified` before sending requests.

        :returns: True if a connection was opened, False if one was already open.
        """
        if self.ws and self.ws_open:
            log.debug('WebSocket session is already open. Returning early.')
            return False
        # Not read anywhere in this class; kept for code that may inspect it.
        self.answers = {}
        self.recv_task = None
        self.identified = False
        self.hello_message = None
        # max_size raises the incoming message size limit to 2**24 bytes (16 MiB).
        self.ws = await websockets.connect(
            self.url,
            subprotocols=['obswebsocket.msgpack'],
            additional_headers=self.http_headers,
            max_size=2**24,
        )
        self.ws_open = True
        self.recv_task = asyncio.create_task(self._ws_recv_task())
        return True

    async def wait_until_identified(self, timeout: int = 10):
        """Wait until the Identified (op 2) message has been received.

        :param timeout: Maximum number of seconds to wait.
        :returns: True once identified. False if the session is not open, or
            if ``timeout`` elapses first.
        """
        if not self.ws_open:
            log.debug('WebSocket session is not open. Returning early.')
            return False
        try:
            await asyncio.wait_for(_wait_for_cond(self.cond, self.is_identified), timeout=timeout)
            return True
        except asyncio.TimeoutError:
            return False

    # Todo: remove bool return, raise error if already closed
    async def disconnect(self):
        """Stop the receive task, close the socket and reset session state.

        :returns: True if a session was torn down, False if none was started.
        """
        if self.recv_task is None:
            log.debug('WebSocket session is not open. Returning early.')
            return False
        self.recv_task.cancel()
        await self.ws.close()
        self.ws = None
        self.ws_open = False
        self.answers = {}
        self.identified = False
        self.recv_task = None
        self.hello_message = None
        return True

    async def call(self, request: Request, timeout: int = 15):
        """Send ``request`` and wait for its response.

        :param request: Request to send. Its ``requestData`` is included only
            when not ``None``.
        :param timeout: Seconds to wait for the response.
        :returns: A :class:`RequestResponse` built from the response payload.
        :raises NotIdentifiedError: If the session has not been identified.
        :raises MessageTimeout: If no response arrives within ``timeout``.
        """
        if not self.identified:
            raise NotIdentifiedError('Calls to requests cannot be made without being identified with obs-websocket.')
        request_id = str(uuid.uuid1())
        request_payload = self._build_request_payload(request, request_id)
        self._log_payload('Sending Request message', request_payload)
        response_data = await self._send_and_wait(
            request_payload,
            request_id,
            timeout,
            'The request with type {} timed out after {} seconds.'.format(request.requestType, timeout),
        )
        return self._build_request_response(response_data)

    async def emit(self, request: Request):
        """Send ``request`` without waiting for a response.

        The request id carries an ``emit_`` prefix so the receive loop
        discards the matching response.

        :raises NotIdentifiedError: If the session has not been identified.
        """
        if not self.identified:
            raise NotIdentifiedError('Emits to requests cannot be made without being identified with obs-websocket.')
        request_payload = self._build_request_payload(request, 'emit_{}'.format(uuid.uuid1()))
        self._log_payload('Sending Request message', request_payload)
        await self.ws.send(msgpack.packb(request_payload))

    async def call_batch(self, requests: list, timeout: int = 15, halt_on_failure: bool = None, execution_type: RequestBatchExecutionType = None, variables: dict = None):
        """Send a request batch and wait for the batch response.

        :param requests: :class:`Request` objects, in execution order.
        :param timeout: Seconds to wait for the batch response.
        :param halt_on_failure: Sent as ``haltOnFailure`` when not ``None``.
        :param execution_type: Sent as ``executionType`` when given.
        :param variables: Sent as ``variables`` when non-empty.
        :returns: One :class:`RequestResponse` per entry of the response's
            ``results`` list.
        :raises NotIdentifiedError: If the session has not been identified.
        :raises MessageTimeout: If no response arrives within ``timeout``.
        """
        if not self.identified:
            raise NotIdentifiedError('Calls to requests cannot be made without being identified with obs-websocket.')
        request_batch_id = str(uuid.uuid1())
        request_batch_payload = self._build_batch_payload(requests, request_batch_id, halt_on_failure, execution_type, variables)
        self._log_payload('Sending Request batch message', request_batch_payload)
        response_data = await self._send_and_wait(
            request_batch_payload,
            request_batch_id,
            timeout,
            'The request batch timed out after {} seconds.'.format(timeout),
        )
        return [self._build_request_response(result) for result in response_data['results']]

    async def emit_batch(self, requests: list, halt_on_failure: bool = None, execution_type: RequestBatchExecutionType = None, variables: dict = None):
        """Send a request batch without waiting for its response.

        The batch payload is built exactly as in :meth:`call_batch`, including
        each request's ``inputVariables``/``outputVariables``. The id carries
        an ``emit_`` prefix so the receive loop discards the response.

        :raises NotIdentifiedError: If the session has not been identified.
        """
        if not self.identified:
            raise NotIdentifiedError('Emits to requests cannot be made without being identified with obs-websocket.')
        request_batch_id = 'emit_{}'.format(uuid.uuid1())
        request_batch_payload = self._build_batch_payload(requests, request_batch_id, halt_on_failure, execution_type, variables)
        self._log_payload('Sending Request batch message', request_batch_payload)
        await self.ws.send(msgpack.packb(request_batch_payload))

    def register_event_callback(self, callback, event: str = None):
        """Register an async ``callback`` for ``event`` (or for every event when ``None``).

        :raises EventRegistrationError: If ``callback`` is not a coroutine function.
        """
        if not inspect.iscoroutinefunction(callback):
            raise EventRegistrationError('Registered functions must be async')
        self.event_callbacks.append((callback, event))

    def deregister_event_callback(self, callback, event: str = None):
        """Remove registrations of ``callback`` for ``event``.

        When ``event`` is ``None``, all registrations of ``callback`` are
        removed.
        """
        for c, t in self.event_callbacks.copy():
            if (c == callback) and (event is None or t == event):
                self.event_callbacks.remove((c, t))

    def is_identified(self):
        """Return True once the Identified (op 2) message has been received."""
        return self.identified

    def _get_hello_data(self):
        """Return the ``d`` payload of the last Hello (op 0) message, or None."""
        return self.hello_message

    def _log_payload(self, label, payload):
        """Debug-log ``payload`` as indented JSON, serializing only when DEBUG is enabled.

        ``default=repr`` is deliberate and used only for this log output. It
        keeps values JSON cannot encode (e.g. ``bytes`` decoded by msgpack)
        from raising inside a log call.
        """
        if log.isEnabledFor(logging.DEBUG):
            log.debug('%s:\n%s', label, json.dumps(payload, indent=2, default=repr))

    def _build_request_payload(self, request, request_id):
        """Build the Request (op 6) message for ``request`` with the given id."""
        request_payload = {
            'op': 6,
            'd': {
                'requestType': request.requestType,
                'requestId': request_id
            }
        }
        if request.requestData is not None:
            request_payload['d']['requestData'] = request.requestData
        return request_payload

    def _build_batch_payload(self, requests, request_id, halt_on_failure, execution_type, variables):
        """Build the RequestBatch (op 8) message shared by call_batch and emit_batch.

        Batch-level options are included only when set. Each request's
        ``inputVariables``, ``outputVariables`` and ``requestData`` are
        included only when non-empty.
        """
        batch_data = {
            'requestId': request_id,
            'requests': []
        }
        if halt_on_failure is not None:
            batch_data['haltOnFailure'] = halt_on_failure
        if execution_type:
            batch_data['executionType'] = execution_type.value
        if variables:
            batch_data['variables'] = variables
        for request in requests:
            request_payload = {'requestType': request.requestType}
            if request.inputVariables:
                request_payload['inputVariables'] = request.inputVariables
            if request.outputVariables:
                request_payload['outputVariables'] = request.outputVariables
            if request.requestData:
                request_payload['requestData'] = request.requestData
            batch_data['requests'].append(request_payload)
        return {'op': 8, 'd': batch_data}

    async def _send_and_wait(self, payload, request_id, timeout, timeout_message):
        """Send ``payload`` and return the response ``d`` routed to ``request_id``.

        The waiter is always unregistered, even on timeout or send failure.

        :raises MessageTimeout: With ``timeout_message`` if no response arrives
            within ``timeout`` seconds.
        """
        waiter = _ResponseWaiter()
        self.waiters[request_id] = waiter
        try:
            await self.ws.send(msgpack.packb(payload))
            await asyncio.wait_for(waiter.event.wait(), timeout=timeout)
        except asyncio.TimeoutError:
            raise MessageTimeout(timeout_message)
        finally:
            del self.waiters[request_id]
        return waiter.response_data

    def _build_request_response(self, response: dict):
        """Convert one response payload (or batch result) into a RequestResponse."""
        ret = RequestResponse(response['requestType'], responseData = response.get('responseData'))
        ret.requestStatus.result = response['requestStatus']['result']
        ret.requestStatus.code = response['requestStatus']['code']
        ret.requestStatus.comment = response['requestStatus'].get('comment')
        return ret

    async def _send_identify(self, password, identification_parameters):
        """Answer the stored Hello (op 0) message with an Identify (op 1) message.

        If the Hello includes ``authentication``, the response string is
        ``base64(sha256(base64(sha256(password + salt)) + challenge))``.
        Identification parameters are included only when not ``None``. Does
        nothing if no Hello message has been stored.
        """
        if self.hello_message is None:
            return
        identify_message = {'op': 1, 'd': {}}
        identify_message['d']['rpcVersion'] = RPC_VERSION
        if 'authentication' in self.hello_message:
            auth = self.hello_message['authentication']
            secret = base64.b64encode(hashlib.sha256((password + auth['salt']).encode('utf-8')).digest())
            authentication_string = base64.b64encode(hashlib.sha256(secret + auth['challenge'].encode('utf-8')).digest()).decode('utf-8')
            identify_message['d']['authentication'] = authentication_string
        if identification_parameters.ignoreNonFatalRequestChecks is not None:
            identify_message['d']['ignoreNonFatalRequestChecks'] = identification_parameters.ignoreNonFatalRequestChecks
        if identification_parameters.eventSubscriptions is not None:
            identify_message['d']['eventSubscriptions'] = identification_parameters.eventSubscriptions
        self._log_payload('Sending Identify message', identify_message)
        await self.ws.send(msgpack.packb(identify_message))

    def _spawn_callback(self, coro):
        """Schedule an event-callback coroutine, holding a reference until it finishes."""
        task = asyncio.create_task(coro)
        self._callback_tasks.add(task)
        task.add_done_callback(self._callback_tasks.discard)

    def _dispatch_event(self, data_payload):
        """Schedule every registered callback that matches an Event (op 5) payload.

        Callbacks registered with ``event=None`` receive arguments according
        to their parameter count:

        - 1: the whole payload;
        - 2: ``(eventType, eventData)``;
        - 3: ``(eventType, eventIntent, eventData)``;
        - any other count: skipped.

        Callbacks registered for a specific event type receive only
        ``eventData``. A callback whose signature does not accept these
        arguments is logged and skipped; it does not stop the receive loop.
        """
        for callback, trigger in self.event_callbacks:
            try:
                if trigger is None:
                    params = len(signature(callback).parameters)
                    if params == 1:
                        coro = callback(data_payload)
                    elif params == 2:
                        coro = callback(data_payload['eventType'], data_payload.get('eventData'))
                    elif params == 3:
                        coro = callback(data_payload['eventType'], data_payload.get('eventIntent'), data_payload.get('eventData'))
                    else:
                        continue
                elif trigger == data_payload['eventType']:
                    coro = callback(data_payload.get('eventData'))
                else:
                    continue
            except (TypeError, ValueError) as e:
                # Calling an async function with mismatched arguments raises here,
                # synchronously, before any coroutine object exists.
                log.warning('Skipping event callback {!r}: {!r}'.format(callback, e))
                continue
            self._spawn_callback(coro)

    async def _ws_recv_task(self):
        """Receive and dispatch messages until the connection closes.

        Non-binary frames are ignored. Undecodable or malformed payloads are
        discarded without stopping the loop. On exit, the session is marked
        closed and unidentified.
        """
        while self.ws_open:
            try:
                message = await self.ws.recv()
                if not message or not isinstance(message, bytes):
                    continue
                incoming_payload = msgpack.unpackb(message)
                self._log_payload('Received message', incoming_payload)
                op_code = incoming_payload['op']
                data_payload = incoming_payload['d']
                if op_code == 7 or op_code == 9: # RequestResponse or RequestBatchResponse
                    request_id = data_payload['requestId']
                    if request_id.startswith('emit_'):
                        continue
                    waiter = self.waiters.get(request_id)
                    if waiter is None:
                        log.warning('Discarding request response {} because there is no waiter for it.'.format(request_id))
                    else:
                        waiter.response_data = data_payload
                        waiter.event.set()
                elif op_code == 5: # Event
                    self._dispatch_event(data_payload)
                elif op_code == 0: # Hello
                    self.hello_message = data_payload
                    await self._send_identify(self.password, self.identification_parameters)
                elif op_code == 2: # Identified
                    self.identified = True
                    async with self.cond:
                        self.cond.notify_all()
                else:
                    log.warning('Unknown OpCode: {}'.format(op_code))
            except websockets.exceptions.ConnectionClosed:
                # Also covers ConnectionClosedOK / ConnectionClosedError (subclasses).
                log.debug('The WebSocket connection was closed. Code: {} | Reason: {}'.format(self.ws.close_code, self.ws.close_reason))
                self.ws_open = False
                break
            except (ValueError, msgpack.UnpackException):
                continue
            except (KeyError, TypeError) as e:
                # Missing fields or an unexpected payload shape: drop this message
                # instead of letting the exception kill the receive task.
                log.warning('Discarding message that could not be processed: {!r}'.format(e))
                continue
        self.ws_open = False
        self.identified = False