-
Notifications
You must be signed in to change notification settings - Fork 236
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support for real time collaboration and model export using celery(#390)
- Loading branch information
1 parent
cb9b8f9
commit 3e99340
Showing
67 changed files
with
2,697 additions
and
265 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
[flake8] | ||
max-line-length = 110 | ||
exclude = | ||
exclude = | ||
./tensorflow_app/caffe-tensorflow, | ||
./node_modules/*, | ||
*/migrations/*, | ||
docs/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,4 +25,6 @@ node_modules/ | |
|
||
ide/static/bundle/ | ||
|
||
.vscode/ | ||
|
||
celerybeat-schedule |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# Register your models here. | ||
from django.contrib import admin # noqa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# -*- coding: utf-8 -*- | ||
from __future__ import unicode_literals | ||
|
||
from django.apps import AppConfig | ||
|
||
|
||
class BackendapiConfig(AppConfig): | ||
name = 'backendAPI' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from django.conf.urls import url | ||
from views import check_login | ||
|
||
urlpatterns = [ | ||
url(r'^checkLogin$', check_login) | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# -*- coding: utf-8 -*- | ||
from __future__ import unicode_literals | ||
from django.http import JsonResponse | ||
from django.contrib.auth.models import User | ||
|
||
|
||
def check_login(request): | ||
try: | ||
user = User.objects.get(username=request.user.username) | ||
user_id = user.id | ||
username = 'Anonymous' | ||
|
||
is_authenticated = user.is_authenticated() | ||
if (is_authenticated): | ||
username = user.username | ||
|
||
return JsonResponse({ | ||
'result': is_authenticated, | ||
'user_id': user_id, | ||
'username': username | ||
}) | ||
except Exception as e: | ||
return JsonResponse({ | ||
'result': False, | ||
'error': str(e) | ||
}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
# Register your models here. | ||
from django.contrib import admin | ||
from .models import ModelExport | ||
from .models import SharedWith, Network | ||
|
||
admin.site.register(ModelExport) | ||
admin.site.register(SharedWith) | ||
admin.site.register(Network) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
import json | ||
import yaml | ||
import urlparse | ||
from channels import Group | ||
from channels.auth import channel_session_user, channel_session_user_from_http | ||
from caffe_app.models import Network, NetworkVersion, NetworkUpdates | ||
from ide.views import get_network_version | ||
from ide.tasks import export_caffe_prototxt, export_keras_json | ||
|
||
|
||
def create_network_version(network, netObj): | ||
# creating a unique version of network to allow revert and view hitory | ||
network_version = NetworkVersion(network=netObj) | ||
network_version.network_def = network | ||
network_version.save() | ||
return network_version | ||
|
||
|
||
def create_network_update(network_version, updated_data, tag): | ||
network_update = NetworkUpdates(network_version=network_version, | ||
updated_data=updated_data, | ||
tag=tag) | ||
return network_update | ||
|
||
|
||
def fetch_network_version(netObj): | ||
network_version = NetworkVersion.objects.filter(network=netObj).order_by('-created_on')[0] | ||
updates_batch = NetworkUpdates.objects.filter(network_version=network_version) | ||
|
||
# Batching updates | ||
# Note - size of batch is 20 for now, optimization can be done | ||
if len(updates_batch) == 2: | ||
data = get_network_version(netObj) | ||
network_version = NetworkVersion(network=netObj, network_def=json.dumps(data['network'])) | ||
network_version.save() | ||
|
||
network_update = NetworkUpdates(network_version=network_version, | ||
updated_data=json.dumps({'nextLayerId': data['next_layer_id']}), | ||
tag='CheckpointCreated') | ||
network_update.save() | ||
return network_version | ||
|
||
|
||
@channel_session_user_from_http | ||
def ws_connect(message): | ||
print('connection being established...') | ||
message.reply_channel.send({ | ||
'accept': True | ||
}) | ||
# extracting id of network from url params | ||
params = urlparse.parse_qs(message.content['query_string']) | ||
networkId = params.get('id', ('Not Supplied',))[0] | ||
message.channel_session['networkId'] = networkId | ||
# adding socket to a group based on networkId to send updates of network | ||
Group('model-{0}'.format(networkId)).add(message.reply_channel) | ||
|
||
|
||
@channel_session_user | ||
def ws_disconnect(message): | ||
networkId = message.channel_session['networkId'] | ||
Group('model-{0}'.format(networkId)).discard(message.reply_channel) | ||
print('disconnected...') | ||
|
||
|
||
@channel_session_user | ||
def ws_receive(message): | ||
print('message received...') | ||
# param initialization | ||
data = yaml.safe_load(message['text']) | ||
action = data['action'] | ||
|
||
if ('randomId' in data): | ||
randomId = data['randomId'] | ||
|
||
if ('networkId' in message.channel_session): | ||
networkId = message.channel_session['networkId'] | ||
|
||
if (action == 'ExportNet'): | ||
# async export call | ||
framework = data['framework'] | ||
net = data['net'] | ||
net_name = data['net_name'] | ||
reply_channel = message.reply_channel.name | ||
|
||
if (framework == 'caffe'): | ||
export_caffe_prototxt.delay(net, net_name, reply_channel) | ||
elif (framework == 'keras'): | ||
export_keras_json.delay(net, net_name, False, reply_channel) | ||
elif (framework == 'tensorflow'): | ||
export_keras_json.delay(net, net_name, True, reply_channel) | ||
|
||
elif (action == 'UpdateHighlight'): | ||
add_highlight_to = data['addHighlightTo'] | ||
remove_highlight_from = data['removeHighlightFrom'] | ||
user_id = data['userId'] | ||
highlight_color = data['highlightColor'] | ||
username = data['username'] | ||
|
||
Group('model-{0}'.format(networkId)).send({ | ||
'text': json.dumps({ | ||
'addHighlightTo': add_highlight_to, | ||
'removeHighlightFrom': remove_highlight_from, | ||
'userId': user_id, | ||
'action': action, | ||
'randomId': randomId, | ||
'highlightColor': highlight_color, | ||
'username': username | ||
}) | ||
}) | ||
else: | ||
# save changes to database to maintain consistency | ||
# get the net object on which update is made | ||
netObj = Network.objects.get(id=int(networkId)) | ||
network_version = fetch_network_version(netObj) | ||
|
||
if (action == 'UpdateParam'): | ||
updated_data = {} | ||
updated_data['layerId'] = data['layerId'] | ||
updated_data['param'] = data['param'] | ||
updated_data['value'] = data['value'] | ||
updated_data['isProp'] = data['isProp'] | ||
updated_data['nextLayerId'] = data['nextLayerId'] | ||
|
||
network_update = create_network_update(network_version, json.dumps(updated_data), data['action']) | ||
network_update.save() | ||
# sending update made by one user over all the sessions of open network | ||
# Note - conflict resolution still pending | ||
Group('model-{0}'.format(networkId)).send({ | ||
'text': json.dumps({ | ||
'layerId': updated_data['layerId'], | ||
'param': updated_data['param'], | ||
'value': updated_data['value'], | ||
'isProp': updated_data['isProp'], | ||
'action': action, | ||
'version_id': 0, | ||
'randomId': randomId | ||
}) | ||
}) | ||
elif (data['action'] == 'DeleteLayer'): | ||
updated_data = {} | ||
updated_data['layerId'] = data['layerId'] | ||
updated_data['nextLayerId'] = data['nextLayerId'] | ||
|
||
network_update = create_network_update(network_version, json.dumps(updated_data), data['action']) | ||
network_update.save() | ||
|
||
# Note - conflict resolution still pending | ||
Group('model-{0}'.format(networkId)).send({ | ||
'text': json.dumps({ | ||
'layerId': updated_data['layerId'], | ||
'action': action, | ||
'version_id': 0, | ||
'randomId': randomId | ||
}) | ||
}) | ||
elif (action == 'AddLayer'): | ||
updated_data = {} | ||
updated_data['prevLayerId'] = data['prevLayerId'] | ||
updated_data['layer'] = data['layer'] | ||
updated_data['layerId'] = data['layerId'] | ||
updated_data['nextLayerId'] = data['nextLayerId'] | ||
|
||
network_update = create_network_update(network_version, json.dumps(updated_data), data['action']) | ||
network_update.save() | ||
# sending update made by one user over all the sessions of open network | ||
# Note - conflict resolution still pending | ||
Group('model-{0}'.format(networkId)).send({ | ||
'text': json.dumps({ | ||
'layer': updated_data['layer'], | ||
'prevLayerId': updated_data['prevLayerId'], | ||
'action': action, | ||
'version_id': 0, | ||
'randomId': randomId | ||
}) | ||
}) | ||
elif (action == 'AddComment'): | ||
updated_data = {} | ||
updated_data['layerId'] = data['layerId'] | ||
updated_data['comment'] = data['comment'] | ||
|
||
network_update = create_network_update(network_version, json.dumps(updated_data), data['action']) | ||
network_update.save() | ||
|
||
Group('model-{0}'.format(networkId)).send({ | ||
'text': json.dumps({ | ||
'layerId': updated_data['layerId'], | ||
'comment': updated_data['comment'], | ||
'action': action, | ||
'version_id': 0, | ||
'randomId': randomId | ||
}) | ||
}) |
Empty file.
Oops, something went wrong.