diff --git a/ide/static/css/dash_style.css b/ide/static/css/dash_style.css
new file mode 100644
index 000000000..fe720deb5
--- /dev/null
+++ b/ide/static/css/dash_style.css
@@ -0,0 +1,21 @@
+.overlay {
+ opacity: 0;
+ z-index: -2;
+ height: 290px;
+ width: 240px;
+ background: rgb(34,47,62,0.9);
+ border-radius: 20px;
+ position: relative;
+ top: -310px;
+ transition: all .4s ease;
+}
+
+.card {
+ transition: all .4s ease;
+}
+
+.card:hover + .overlay, .overlay:hover {
+ opacity: 1;
+ z-index: 1;
+ transition: all .4s ease;
+}
diff --git a/ide/static/css/login_style.css b/ide/static/css/login_style.css
index b4228b7f1..bf3c37ea9 100644
--- a/ide/static/css/login_style.css
+++ b/ide/static/css/login_style.css
@@ -15,6 +15,36 @@
cursor: pointer;
}
+#sidebar-logout-button {
+ background: rgb(205, 207, 210);
+ color: rgb(69, 80, 97);
+ text-align: center;
+ border-radius: 5px;
+ width: 110px;
+ margin: 0.2em;
+ transition: 0.2s;
+ position: relative;
+}
+
+#sidebar-logout-button:hover {
+ cursor: pointer;
+}
+
+#sidebar-dash-button {
+ background: rgb(205, 207, 210);
+ color: rgb(69, 80, 97);
+ text-align: center;
+ border-radius: 5px;
+ width: 130px;
+ margin: 0.2em;
+ transition: 0.2s;
+ position: relative;
+}
+
+#sidebar-dash-button:hover {
+ cursor: pointer;
+}
+
#sidebar-login-button span {
position: absolute;
left: 9px;
@@ -67,6 +97,13 @@
}
}
+.long-buttons {
+ display: flex;
+ flex: 1;
+ flex-direction: row;
+ margin-top: -0.5em;
+}
+
.login-panel {
position: relative;
width: 350px;
@@ -265,4 +302,4 @@
position: absolute;
top: 0px;
right: 10px;
-}
\ No newline at end of file
+}
diff --git a/ide/static/css/searchbar_style.css b/ide/static/css/searchbar_style.css
index d7e3c537a..02e86a05d 100644
--- a/ide/static/css/searchbar_style.css
+++ b/ide/static/css/searchbar_style.css
@@ -1,3 +1,7 @@
+body {
+ background: #F3F5F7;
+}
+
.insert-layer-title {
position: relative;
}
diff --git a/ide/static/js/card.js b/ide/static/js/card.js
new file mode 100644
index 000000000..c57152847
--- /dev/null
+++ b/ide/static/js/card.js
@@ -0,0 +1,79 @@
+import React from 'react';
+import '../css/dash_style.css';
+
+class Card extends React.Component {
+ render() {
+ return(
+
+
+
+
+
+
{this.props.ModelName}
+
+ );
+ }
+}
+
+Card.propTypes = {
+ ModelName: React.PropTypes.string,
+ ModelID: React.PropTypes.number,
+ ModelFunction: React.PropTypes.func
+};
+
+
+export default Card;
diff --git a/ide/static/js/content.js b/ide/static/js/content.js
index 9a0c92d4c..ba151f864 100644
--- a/ide/static/js/content.js
+++ b/ide/static/js/content.js
@@ -85,6 +85,7 @@ class Content extends React.Component {
this.openModal = this.openModal.bind(this);
this.closeModal = this.closeModal.bind(this);
this.saveDb = this.saveDb.bind(this);
+ this.saveModel = this.saveModel.bind(this);
this.loadDb = this.loadDb.bind(this);
this.infoModal = this.infoModal.bind(this);
this.faqModal = this.faqModal.bind(this);
@@ -979,6 +980,34 @@ class Content extends React.Component {
layer.info.phase = 0;
this.setState({ net });
}
+ saveModel(){
+ let modelData = this.state.net;
+ this.setState({ load: true });
+ $.ajax({
+ url: '/saveModel',
+ dataType: 'json',
+ type: 'POST',
+ data: {
+ net: JSON.stringify(modelData),
+ net_name: this.state.net_name,
+ user_id: this.getUserId(),
+ nextLayerId: this.state.nextLayerId
+ },
+ success : function (response) {
+ if (response.result == 'success') {
+ this.modalContent = "Successfully Saved!";
+ this.openModal();
+ }
+ else if (response.result == 'error') {
+ this.addError(response.error);
+ }
+ this.setState({ load: false });
+ }.bind(this),
+ error() {
+ this.setState({ load: false });
+ }
+ });
+ }
saveDb(){
let netData = this.state.net;
this.setState({ load: true });
@@ -1055,7 +1084,7 @@ class Content extends React.Component {
// Note: this needs to be improved when handling conflict resolution to avoid
// inconsistent states of model
let nextLayerId = this.state.nextLayerId;
-
+ let is_shared = false;
this.setState({ load: true });
this.dismissAllErrors();
@@ -1072,6 +1101,13 @@ class Content extends React.Component {
// while loading a model ensure paramete intialisation
// for UI show/hide is not executed, it leads to inconsistent
// data which cannot be used further
+ if (response.public_sharing == false) {
+ is_shared = false;
+ }
+ else {
+ is_shared = true;
+ }
+ console.log(response);
nextLayerId = response.next_layer_id;
this.initialiseImportedNet(response.net,response.net_name);
if (Object.keys(response.net).length){
@@ -1083,8 +1119,10 @@ class Content extends React.Component {
}
this.setState({
load: false,
- isShared: true,
+ isShared: is_shared,
nextLayerId: parseInt(nextLayerId)
+ }, function() {
+ console.log("Shared value: " + this.state.isShared);
});
}.bind(this),
error() {
@@ -1092,6 +1130,7 @@ class Content extends React.Component {
}
});
}
+
infoModal() {
this.modalHeader = "About"
this.modalContent = `Fabrik is an online collaborative platform to build and visualize deep\
@@ -1113,7 +1152,7 @@ class Content extends React.Component {
here .
Q: What do the Train/Test buttons mean?
- A: They are two different modes of your model:
+ A: They are two different modes of your model:
Train and Test - respectively for training your model with data and testing how and if it works.
Q: What does the import fuction do?
A: It allows you to import your previously created models in Caffe (.protoxt files),
@@ -1127,7 +1166,7 @@ class Content extends React.Component {
A: Please see the instructions listed
here
-
+
If you have anymore questions, please visit Fabrik's Github page available
here for more information.
);
@@ -1282,6 +1321,7 @@ class Content extends React.Component {
this.addNewLayer(layer);
}
}
+
render() {
let loader = null;
if (this.state.load) {
@@ -1299,9 +1339,11 @@ class Content extends React.Component {
;
+ }
+ else {
+ var data_array = JSON.parse(localStorage.getItem("obj"));
+ var len = Object.keys(data_array).length/2;
+ var elements=[];
+ for (var i = 1; i < len+1; i++) {
+ elements.push( )
+ }
+ }
+ return (
+
+
+
+ ×
+ { this.modalHeader }
+ { this.modalContent }
+
+
+ );
+ }
+ else {
+ window.open("#","_self");
+ return null;
+ }
+ }
+}
+
+export default Dashboard;
diff --git a/ide/static/js/dashbutton.js b/ide/static/js/dashbutton.js
new file mode 100644
index 000000000..3a45a0a41
--- /dev/null
+++ b/ide/static/js/dashbutton.js
@@ -0,0 +1,22 @@
+import React from 'react';
+
+class DashButton extends React.Component {
+ constructor(props) {
+ super(props);
+ this.openDash = this.openDash.bind(this);
+ }
+ openDash(){
+ window.location.href = "/#/dashboard";
+ }
+ render(){
+ return(
+
+
+
+ );
+ }
+}
+
+export default DashButton;
diff --git a/ide/static/js/index.js b/ide/static/js/index.js
index d387a3637..1448af664 100644
--- a/ide/static/js/index.js
+++ b/ide/static/js/index.js
@@ -1,12 +1,13 @@
-
import React from 'react';
import { render } from 'react-dom';
import { Router, Route, hashHistory } from 'react-router';
import App from './app.js';
+import Dashboard from './dashboard.js';
import '../css/style.css';
render(
+
, document.getElementById('app')
);
diff --git a/ide/static/js/login.js b/ide/static/js/login.js
index a31214afa..7165f4dae 100644
--- a/ide/static/js/login.js
+++ b/ide/static/js/login.js
@@ -1,4 +1,5 @@
import React from 'react';
+import DashButton from './dashbutton';
class Login extends React.Component {
constructor(props) {
@@ -23,6 +24,7 @@ class Login extends React.Component {
contentType: false,
success: function (response) {
if (response) {
+ localStorage.removeItem("userID");
this.setState({ loginState: false });
this.props.setUserId(null);
this.props.setUserName(null);
@@ -67,6 +69,7 @@ class Login extends React.Component {
if (response.result) {
this.setState({ loginState: response.result });
this.props.setUserId(response.user_id);
+ localStorage.setItem("userID",response.user_id);
this.props.setUserName(response.username);
if (showNotification) {
@@ -181,12 +184,15 @@ class Login extends React.Component {
if(this.state.loginState) {
return (
-
done
diff --git a/ide/static/js/topBar.js b/ide/static/js/topBar.js
index b3a02e437..b95712e3f 100644
--- a/ide/static/js/topBar.js
+++ b/ide/static/js/topBar.js
@@ -4,20 +4,14 @@ import ReactTooltip from 'react-tooltip';
class TopBar extends React.Component {
constructor(props) {
super(props);
- this.checkURL = this.checkURL.bind(this);
+ this.state = {isShared: false};
}
- checkURL() {
- let url = window.location.href;
- let urlParams = url.indexOf("load");
-
- if(urlParams != -1) {
- return true;
- }
- return false;
+ componentWillReceiveProps(newProps){
+ this.setState({isShared: newProps.isShared});
}
render() {
let content = null;
- if (this.checkURL()) {
+ if (this.state.isShared == true) {
content = (
{content}
+
+
+
+ this.props.saveModel()} data-tip="Save Model">
+
+
+
+
+
@@ -116,11 +120,13 @@ TopBar.propTypes = {
exportNet: React.PropTypes.func,
importNet: React.PropTypes.func,
saveDb: React.PropTypes.func,
+ saveModel: React.PropTypes.func,
loadDb: React.PropTypes.func,
zooModal: React.PropTypes.func,
textboxModal: React.PropTypes.func,
urlModal: React.PropTypes.func,
- updateHistoryModal: React.PropTypes.func
+ updateHistoryModal: React.PropTypes.func,
+ isShared: React.PropTypes.bool
};
export default TopBar;
diff --git a/ide/urls.py b/ide/urls.py
index 7798fdc97..140087b11 100644
--- a/ide/urls.py
+++ b/ide/urls.py
@@ -3,7 +3,9 @@
from django.conf.urls.static import static
from django.conf import settings
from views import index, calculate_parameter, fetch_layer_shape
-from views import load_from_db, save_to_db, fetch_model_history
+from views import load_from_db, load_model_from_db, \
+ delete_model_from_db, save_to_db, save_model_to_db, \
+ fetch_model_history
urlpatterns = [
url(r'^$', index),
@@ -14,9 +16,15 @@
url(r'^keras/', include('keras_app.urls')),
url(r'^tensorflow/', include('tensorflow_app.urls')),
url(r'^save$', save_to_db, name='saveDB'),
+ url(r'^saveModel$', save_model_to_db, name='saveModel'),
url(r'^load*', load_from_db, name='loadDB'),
+ url(r'^deleteModel$', delete_model_from_db, name='deleteModel'),
+ url(r'^getModel$', load_model_from_db, name='getModelData'),
url(r'^model_history', fetch_model_history, name='model-history'),
- url(r'^model_parameter/', calculate_parameter, name='calculate-parameter'),
- url(r'^layer_parameter/', fetch_layer_shape, name='fetch-layer-shape')
-] + static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT) + \
- static(settings.STATIC_URL, document_root=settings.STATIC_ROOT)
+ url(r'^model_parameter/', calculate_parameter,
+ name='calculate-parameter'),
+ url(r'^layer_parameter/', fetch_layer_shape,
+ name='fetch-layer-shape'),
+ ] + static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT) \
+ + static(settings.STATIC_URL,
+ document_root=settings.STATIC_ROOT)
diff --git a/ide/views.py b/ide/views.py
index 9cd03ea87..4cc94b8b4 100644
--- a/ide/views.py
+++ b/ide/views.py
@@ -8,7 +8,8 @@
from django.http import JsonResponse
from django.views.decorators.csrf import csrf_exempt
from django.contrib.auth.models import User
-from utils.shapes import get_shapes, get_layer_shape, handle_concat_layer
+from utils.shapes import get_shapes, get_layer_shape, \
+ handle_concat_layer
def index(request):
@@ -24,38 +25,65 @@ def fetch_layer_shape(request):
net[layerId]['shape'] = {}
net[layerId]['shape']['input'] = None
net[layerId]['shape']['output'] = None
- dataLayers = ['ImageData', 'Data', 'HDF5Data', 'Input', 'WindowData', 'MemoryData', 'DummyData']
+ dataLayers = [
+ 'ImageData',
+ 'Data',
+ 'HDF5Data',
+ 'Input',
+ 'WindowData',
+ 'MemoryData',
+ 'DummyData',
+ ]
# Obtain input shape of new layer
- if (net[layerId]['info']['type'] == "Concat"):
+
+ if net[layerId]['info']['type'] == 'Concat':
for parentLayerId in net[layerId]['connection']['input']:
+
# Check if parent layer have shapes
- if (net[parentLayerId]['shape']['output']):
- net[layerId]['shape']['input'] = handle_concat_layer(net[layerId], net[parentLayerId])
- elif (not (net[layerId]['info']['type'] in dataLayers)):
- if (len(net[layerId]['connection']['input']) > 0):
+
+ if net[parentLayerId]['shape']['output']:
+ net[layerId]['shape']['input'] = \
+ handle_concat_layer(net[layerId],
+ net[parentLayerId])
+ elif not net[layerId]['info']['type'] in dataLayers:
+ if len(net[layerId]['connection']['input']) > 0:
parentLayerId = net[layerId]['connection']['input'][0]
+
# Check if parent layer have shapes
- if (net[parentLayerId]['shape']['output']):
- net[layerId]['shape']['input'] = net[parentLayerId]['shape']['output'][:]
+
+ if net[parentLayerId]['shape']['output']:
+ net[layerId]['shape']['input'] = \
+ (net[parentLayerId]['shape']['output'])[:]
# Obtain output shape of new layer
- if (net[layerId]['info']['type'] in dataLayers):
+
+ if net[layerId]['info']['type'] in dataLayers:
+
# handling Data Layers separately
- if ('dim' in net[layerId]['params'] and len(net[layerId]['params']['dim'])):
+
+ if 'dim' in net[layerId]['params'] \
+ and len(net[layerId]['params']['dim']):
+
# layers with empty dim parameter can't be passed
- net[layerId]['shape']['input'], net[layerId]['shape']['output'] =\
- get_layer_shape(net[layerId])
- elif ('dim' not in net[layerId]['params']):
+
+ (net[layerId]['shape']['input'],
+ net[layerId]['shape']['output']) = \
+ get_layer_shape(net[layerId])
+ elif 'dim' not in net[layerId]['params']:
+
# shape calculation for layers with no dim param
- net[layerId]['shape']['input'], net[layerId]['shape']['output'] =\
- get_layer_shape(net[layerId])
+
+ (net[layerId]['shape']['input'],
+ net[layerId]['shape']['output']) = \
+ get_layer_shape(net[layerId])
else:
- if (net[layerId]['shape']['input']):
- net[layerId]['shape']['output'] = get_layer_shape(net[layerId])
+ if net[layerId]['shape']['input']:
+ net[layerId]['shape']['output'] = \
+ get_layer_shape(net[layerId])
except BaseException:
- return JsonResponse({
- 'result': 'error', 'error': str(sys.exc_info()[1])})
+ return JsonResponse({'result': 'error',
+ 'error': str(sys.exc_info()[1])})
return JsonResponse({'result': 'success', 'net': net})
@@ -64,57 +92,118 @@ def calculate_parameter(request):
if request.method == 'POST':
net = yaml.safe_load(request.POST.get('net'))
try:
+
# While calling get_shapes we need to remove the flag
# added in frontend to show the parameter on pane
+
netObj = copy.deepcopy(net)
for layerId in netObj:
for param in netObj[layerId]['params']:
- netObj[layerId]['params'][param] = netObj[layerId]['params'][param][0]
+ netObj[layerId]['params'][param] = \
+ netObj[layerId]['params'][param][0]
+
# use get_shapes method to obtain shapes of each layer
+
netObj = get_shapes(netObj)
for layerId in net:
net[layerId]['shape'] = {}
net[layerId]['shape']['input'] = netObj[layerId]['shape']['input']
- net[layerId]['shape']['output'] = netObj[layerId]['shape']['output']
+ net[layerId]['shape']['output'] = \
+ netObj[layerId]['shape']['output']
except BaseException:
- return JsonResponse({
- 'result': 'error', 'error': str(sys.exc_info()[1])})
+ return JsonResponse({'result': 'error',
+ 'error': str(sys.exc_info()[1])})
return JsonResponse({'result': 'success', 'net': net})
@csrf_exempt
-def save_to_db(request):
+def delete_model_from_db(request):
+ if request.method == 'POST':
+ if 'userID' in request.POST:
+ userID = request.POST.get('userID')
+ model_id = request.POST.get('modelid')
+ model = Network.objects.get(id=model_id)
+ if model.author_id == int(userID):
+ model.delete()
+ return JsonResponse({'result': 'success',
+ 'data': 'Model successfully deleted!'})
+ else:
+ return JsonResponse({'result': 'error',
+ 'error': "This model doesn't belong to you!"})
+
+
+@csrf_exempt
+def save(request, public_sharing):
if request.method == 'POST':
net = request.POST.get('net')
net_name = request.POST.get('net_name')
user_id = request.POST.get('user_id')
next_layer_id = request.POST.get('nextLayerId')
- public_sharing = True
user = None
+ if public_sharing is True:
+ tag = "ModelShared"
+ else:
+ tag = "ModelNotShared"
if net_name == '':
net_name = 'Net'
- try:
- # making model sharing public by default for now
- # TODO: Prvilege on Sharing
- if user_id:
- user_id = int(user_id)
- user = User.objects.get(id=user_id)
-
- # create a new model on share event
- model = Network(name=net_name, public_sharing=public_sharing, author=user)
- model.save()
- # create first version of model
- model_version = NetworkVersion(network=model, network_def=net)
- model_version.save()
- # create initial update for nextLayerId
- model_update = NetworkUpdates(network_version=model_version,
- updated_data=json.dumps({'nextLayerId': next_layer_id}),
- tag='ModelShared')
- model_update.save()
-
- return JsonResponse({'result': 'success', 'id': model.id})
- except:
- return JsonResponse({'result': 'error', 'error': str(sys.exc_info()[1])})
+
+ if Network.objects.filter(name=net_name).exists():
+ # Update the exising json field
+ try:
+ if user_id:
+ user_id = int(user_id)
+ user = User.objects.get(id=user_id)
+ # load the model with the net name
+ model = Network.objects.get(name=net_name)
+ model_id = model.id
+ # update the model with network id same as model id
+ existing_model = \
+ NetworkVersion.objects.get(network_id=model_id)
+ existing_model.network_def = net
+ existing_model.save()
+ return JsonResponse({'result': 'success',
+ 'id': model.id})
+ except:
+ return JsonResponse({'result': 'error',
+ 'error': str(sys.exc_info()[1])})
+ else:
+ try:
+ if user_id:
+ user_id = int(user_id)
+ user = User.objects.get(id=user_id)
+ # create a new model on save event
+ model = Network(name=net_name,
+ public_sharing=public_sharing,
+ author=user)
+ model.save()
+ # create first version of model
+ model_version = NetworkVersion(network=model,
+ network_def=net)
+ model_version.save()
+ # create initial update for nextLayerId
+ model_update = \
+ NetworkUpdates(network_version=model_version,
+ updated_data=json.dumps(
+ {'nextLayerId': next_layer_id}),
+ tag=tag)
+ model_update.save()
+ return JsonResponse({'result': 'success',
+ 'id': model.id})
+ except:
+ return JsonResponse({'result': 'error',
+ 'error': str(sys.exc_info()[1])})
+
+
+@csrf_exempt
+def save_model_to_db(request):
+ response = save(request, False)
+ return response
+
+
+@csrf_exempt
+def save_to_db(request):
+ response = save(request, True)
+ return response
def create_network_version(network_def, updates_batch):
@@ -129,7 +218,9 @@ def create_network_version(network_def, updates_batch):
next_layer_id = updated_data['nextLayerId']
if tag == 'UpdateParam':
+
# Update Param UI event handling
+
param = updated_data['param']
layer_id = updated_data['layerId']
value = updated_data['value']
@@ -138,9 +229,10 @@ def create_network_version(network_def, updates_batch):
network_def[layer_id]['props'][param] = value
else:
network_def[layer_id]['params'][param][0] = value
-
elif tag == 'DeleteLayer':
+
# Delete layer UI event handling
+
layer_id = updated_data['layerId']
input_layer_ids = network_def[layer_id]['connection']['input']
output_layer_ids = network_def[layer_id]['connection']['output']
@@ -152,48 +244,53 @@ def create_network_version(network_def, updates_batch):
network_def[output_layer_id]['connection']['input'].remove(layer_id)
del network_def[layer_id]
-
elif tag == 'AddLayer':
+
# Add layer UI event handling
+
prev_layer_id = updated_data['prevLayerId']
new_layer_id = updated_data['layerId']
if isinstance(prev_layer_id, list):
for layer_id in prev_layer_id:
- network_def[layer_id]['connection']['output'].append(new_layer_id)
+ network_def[layer_id]['connection']['output'
+ ].append(new_layer_id)
else:
- network_def[prev_layer_id]['connection']['output'].append(new_layer_id)
+ network_def[prev_layer_id]['connection']['output'
+ ].append(new_layer_id)
network_def[new_layer_id] = updated_data['layer']
-
elif tag == 'AddComment':
+
layer_id = updated_data['layerId']
comment = updated_data['comment']
- if ('comments' not in network_def[layer_id]):
+ if 'comments' not in network_def[layer_id]:
network_def[layer_id]['comments'] = []
network_def[layer_id]['comments'].append(comment)
- return {
- 'network': network_def,
- 'next_layer_id': next_layer_id
- }
+ return {'network': network_def, 'next_layer_id': next_layer_id}
def get_network_version(netObj):
- network_version = NetworkVersion.objects.filter(network=netObj).order_by('-created_on')[0]
- updates_batch = NetworkUpdates.objects.filter(network_version=network_version).order_by('created_on')
+ network_version = \
+ NetworkVersion.objects.filter(network=netObj).order_by('-created_on'
+ )[0]
+ updates_batch = NetworkUpdates.objects.filter(
+ network_version=network_version).order_by('created_on')
- return create_network_version(network_version.network_def, updates_batch)
+ return create_network_version(network_version.network_def,
+ updates_batch)
def get_checkpoint_version(netObj, checkpoint_id):
network_update = NetworkUpdates.objects.get(id=checkpoint_id)
network_version = network_update.network_version
- updates_batch = NetworkUpdates.objects.filter(network_version=network_version)\
- .filter(created_on__lte=network_update.created_on)\
- .order_by('created_on')
- return create_network_version(network_version.network_def, updates_batch)
+ updates_batch = NetworkUpdates.objects.filter(
+ network_version=network_version).filter(
+ created_on__lte=network_update.created_on).order_by('created_on')
+ return create_network_version(network_version.network_def,
+ updates_batch)
@csrf_exempt
@@ -201,56 +298,82 @@ def load_from_db(request):
if request.method == 'POST':
if 'proto_id' in request.POST:
try:
- model = Network.objects.get(id=int(request.POST['proto_id']))
+ model = \
+ Network.objects.get(id=int(request.POST['proto_id'
+ ]))
version_id = None
data = {}
- if 'version_id' in request.POST and request.POST['version_id'] != '':
+ if 'version_id' in request.POST \
+ and request.POST['version_id'] != '':
+
# added for loading any previous version of model
+
version_id = int(request.POST['version_id'])
data = get_checkpoint_version(model, version_id)
else:
+
# fetch the required version of model
+
data = get_network_version(model)
net = data['network']
next_layer_id = data['next_layer_id']
-
- # authorizing the user for access to model
- if not model.public_sharing:
- return JsonResponse({'result': 'error',
- 'error': 'Permission denied for access to model'})
except Exception:
+
return JsonResponse({'result': 'error',
'error': 'No network file found'})
- return JsonResponse({'result': 'success', 'net': net, 'net_name': model.name,
- 'next_layer_id': next_layer_id})
+ return JsonResponse({
+ 'result': 'success',
+ 'net': net,
+ 'net_name': model.name,
+ 'next_layer_id': next_layer_id,
+ 'public_sharing': model.public_sharing,
+ })
if request.method == 'GET':
return index(request)
+@csrf_exempt
+def load_model_from_db(request):
+ if request.method == 'POST':
+ if 'userID' in request.POST:
+ userID = request.POST.get('userID')
+ if Network.objects.filter(author=userID).exists():
+ data = {}
+ models = Network.objects.filter(author=userID)
+ i = 1
+ for mod in models:
+ data_index1 = 'Model%d_Name' % i
+ data_index2 = 'Model%d_ID' % i
+ data[data_index1] = mod.name
+ data[data_index2] = mod.id
+ i += 1
+ return JsonResponse({'result': 'success', 'data': data})
+ else:
+ return JsonResponse({'result': 'error',
+ 'error': 'No models found'})
+
+
@csrf_exempt
def fetch_model_history(request):
if request.method == 'POST':
try:
network_id = int(request.POST['net_id'])
network = Network.objects.get(id=network_id)
- network_versions = NetworkVersion.objects.filter(network=network).order_by('created_on')
+ network_versions = NetworkVersion.objects.filter(
+ network=network).order_by('created_on')
modelHistory = {}
for version in network_versions:
- network_updates = NetworkUpdates.objects.filter(network_version=version)\
- .order_by('created_on')
+ network_updates = NetworkUpdates.objects.filter(
+ network_version=version).order_by('created_on')
for update in network_updates:
modelHistory[update.id] = update.tag
- return JsonResponse({
- 'result': 'success',
- 'data': modelHistory
- })
+ return JsonResponse({'result': 'success',
+ 'data': modelHistory})
except Exception:
- return JsonResponse({
- 'result': 'error',
- 'error': 'Unable to load model history'
- })
+ return JsonResponse({'result': 'error',
+ 'error': 'Unable to load model history'})
diff --git a/settings/test.py b/settings/test.py
index 540f7a344..64921989c 100644
--- a/settings/test.py
+++ b/settings/test.py
@@ -1,4 +1,4 @@
-from .common import * # noqa: ignore=F405
+from .common import * # flake8: noqa
# Database
# https://docs.djangoproject.com/en/1.9/ref/settings/#databases
@@ -8,10 +8,10 @@
DATABASES = {
'default': {
'ENGINE': 'django.db.backends.postgresql_psycopg2',
- 'NAME': 'fabrik',
- 'USER': 'admin',
- 'PASSWORD': 'fabrik',
- 'HOST': 'localhost',
+ 'NAME': 'fabrik', # Change this to 'postgres' if you're using docker
+ 'USER': 'admin', # Change this to 'postgres' if you're using docker
+ 'PASSWORD': 'fabrik', # Change this to 'postgres' if you're using docker
+ 'HOST': 'localhost', # Change this to 'db' if you're using docker
'PORT': 5432,
}
}
diff --git a/tests/unit/caffe_app/test_db.py b/tests/unit/caffe_app/test_db.py
index ca0fa4a88..8dea0a934 100644
--- a/tests/unit/caffe_app/test_db.py
+++ b/tests/unit/caffe_app/test_db.py
@@ -51,3 +51,83 @@ def test_load_nofile(self):
response = json.loads(response.content)
self.assertEqual(response['result'], 'error')
self.assertEqual(response['error'], 'No network file found')
+
+
+class SaveModelToDBTest(unittest.TestCase):
+
+ def setUp(self):
+ self.client = Client()
+
+ def test_save_json1(self):
+ tests = open(os.path.join(settings.BASE_DIR, 'tests', 'unit', 'ide',
+ 'caffe_export_test.json'), 'r')
+ net = json.load(tests)['net']
+ response = self.client.post(
+ reverse('saveModel'),
+ {'net': net, 'net_name': 'netname'})
+ response = json.loads(response.content)
+ self.assertEqual(response['result'], 'success')
+
+ def test_load1(self):
+ u_3 = User(id=3, username='user_3')
+ u_3.save()
+ u_4 = User(id=4, username='user_4')
+ u_4.save()
+ model = Network(name='net')
+ model.save()
+ model_version = NetworkVersion(network=model, network_def={})
+ model_version.save()
+
+ response = self.client.post(
+ reverse('saveModel'),
+ {'net': '{"net": "testnet"}', 'net_name': 'name'})
+ response = json.loads(response.content)
+ self.assertEqual(response['result'], 'success')
+ self.assertTrue('id' in response)
+ proto_id = response['id']
+ response = self.client.post(reverse('loadDB'), {'proto_id': proto_id})
+ response = json.loads(response.content)
+ self.assertEqual(response['result'], 'success')
+ self.assertEqual(response['net_name'], 'name')
+
+ def test_load_nofile1(self):
+ response = self.client.post(reverse('loadDB'),
+ {'proto_id': 'inexistent'})
+ response = json.loads(response.content)
+ self.assertEqual(response['result'], 'error')
+ self.assertEqual(response['error'], 'No network file found')
+
+
+class LoadModelFromDB(unittest.TestCase):
+
+ def setUp(self):
+ self.client = Client()
+
+ def test_load_model(self):
+ u_5 = User(id=5, username='user_5')
+ u_5.save()
+ model = Network(id=9, name='test_net', author_id='5')
+ model.save()
+ response = self.client.post(
+ reverse('getModelData'), {'userID': '5'})
+ response = json.loads(response.content)
+ self.assertEqual(response['result'], 'success')
+ self.assertEqual(response['data']['Model1_Name'], 'test_net')
+ self.assertEqual(response['data']['Model1_ID'], 9)
+
+
+class DeleteModelFromDB(unittest.TestCase):
+
+ def setUp(self):
+ self.client = Client()
+
+ def test_delete_model(self):
+ u_6 = User(id=6, username='user_6')
+ u_6.save()
+ model = Network(id=10, name='test_net2', author_id='6')
+ model.save()
+ response = self.client.post(
+ reverse('deleteModel'), {'userID': '6', 'modelid': '10'})
+ response = json.loads(response.content)
+ self.assertEqual(response['result'], 'success')
+ self.assertEqual(Network.objects.filter(id=10).exists(), False)