Skip to content

Commit

Permalink
Merge pull request #291 from madgik/NaiveBayesGalaxyBugFix
Browse files Browse the repository at this point in the history
add error checking for NB Training
  • Loading branch information
ThanKarab authored Nov 12, 2020
2 parents 9b2b936 + 266670f commit f3edf2b
Show file tree
Hide file tree
Showing 9 changed files with 283 additions and 95 deletions.
2 changes: 2 additions & 0 deletions Exareme-Docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ RUN apt update
RUN apt install -y r-base
RUN Rscript -e 'install.packages("randomForest", repos="https://cloud.r-project.org")'

RUN Rscript -e 'install.packages("caret")'
RUN Rscript -e 'install.packages("e1071")'
RUN pip install rpy2==2.8.6

# Add Madis Server
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ private static void validateAlgorithmParameterValueType(
String algorithmName,
String value,
ParameterProperties parameterProperties
) throws AlgorithmException {
) throws AlgorithmException, UserException {
if (parameterProperties.getValueType().equals(ParameterProperties.ParameterValueType.json)) {
try {
new JSONObject(value);
Expand All @@ -285,19 +285,19 @@ private static void validateAlgorithmParameterValueType(
try {
Double.parseDouble(curValue);
} catch (NumberFormatException nfe) {
throw new AlgorithmException(algorithmName,
throw new UserException(
"The value of the parameter '" + parameterProperties.getName() + "' should be a real number.");
}
} else if (parameterProperties.getValueType().equals(ParameterProperties.ParameterValueType.integer)) {
try {
Integer.parseInt(curValue);
} catch (NumberFormatException e) {
throw new AlgorithmException(algorithmName,
throw new UserException(
"The value of the parameter '" + parameterProperties.getName() + "' should be an integer.");
}
} else if (parameterProperties.getValueType().equals(ParameterProperties.ParameterValueType.string)) {
if (curValue.equals("")) {
throw new AlgorithmException(algorithmName,
throw new UserException(
"The value of the parameter '" + parameterProperties.getName()
+ "' contains an empty string.");
}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,15 @@ def holdoutvalidation_inputerrorchecking2(train_size, test_size):

holdoutvalidation_inputerrorchecking2.registered = True

def naive_bayes_training_inputerrorchecking(colname,noLevels):
if (noLevels < 2):
raise functions.OperatorError("ExaremeError", colname + ": should contain more than two distinct values")
else:
return "OK"

naive_bayes_training_inputerrorchecking.registered = True


# def maxnumberofiterations_errorhandling(maxnumberofiterations,no): # For most of the iterative algorithms
# if maxnumberofiterations< no:
# raise functions.OperatorError("ExaremeError", "The algorithm could not complete in the max number of iterations given. Please increase the iterations_max_number and try again.")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import setpath
import functions
import json
registered=True

'''
Highcharts.chart('container',
{ "chart": {"type": "heatmap","marginTop": 40,"marginBottom": 80,"plotBorderWidth": 1},
"title": {"text": " confusion matrix "},
"xAxis": {"title": { "text": " actual values "},"categories": [ "AD","CN","Other"]},
"yAxis": {"title": { "text": " predicted values "},"categories": [ "AD", "CN", "Other"]},
"colorAxis": {"min": 0,"minColor": "#FFFFFF","maxColor": "#6699ff"},
"legend": {"align": "right","layout": "vertical","margin": 0,"verticalAlign": "top","y": 25,"symbolHeight": 280},
"series": [{ "borderWidth": 1, "data": [ [ 0, 0, 46],
[ 0, 1, 39],
[ 0, 2, 0],
[ 1, 0, 20],
[ 1, 1,76],
[ 1, 2, 0],
[2, 0, 26],
[ 2, 1,33],
[2, 2,0]],
"dataLabels": {"enabled": true,"color": "#000000" }}]}
);
'''
class highchartheatmap(functions.vtable.vtbase.VT):
def VTiter(self, *parsedArgs,**envars):
largs, dictargs = self.full_parse(parsedArgs)

if 'query' not in dictargs:
raise functions.OperatorError(__name__.rsplit('.')[-1],"No query argument ")
query = dictargs['query']
if 'title' not in dictargs:
raise functions.OperatorError(__name__.rsplit('.')[-1],"No title argument ")
if 'xtitle' not in dictargs:
raise functions.OperatorError(__name__.rsplit('.')[-1],"No xtitle argument ")
if 'ytitle' not in dictargs:
raise functions.OperatorError(__name__.rsplit('.')[-1],"No ytitle argument ")

cur = envars['db'].cursor()
c=cur.execute(query)
schema = cur.getdescriptionsafe()

mydata = []
xcategories = []
ycategories = []

for myrow in c:
if str(myrow[0]) not in xcategories:
xcategories.append(str(myrow[0]))
if str(myrow[1]) not in ycategories:
ycategories.append(str(myrow[1]))
mydata.append([xcategories.index(str(myrow[0])), ycategories.index(str(myrow[1])), float(myrow[2])])

myresult = {
"type" : "application/vnd.highcharts+json",
"data" :{ "chart": {"type": "heatmap","marginTop": 40,"marginBottom": 80,"plotBorderWidth": 1},
"title": {"text": str(dictargs['title'])},
"xAxis": {"title": { "text":str(dictargs['xtitle'])},"categories": xcategories},
"yAxis": {"title": { "text":str(dictargs['ytitle'])},"categories": ycategories},
"colorAxis": {"min": 0,"minColor": "#FFFFFF","maxColor": "#6699ff"},
"legend": {"align": "right","layout": "vertical","margin": 0,"verticalAlign": "top","y": 25,"symbolHeight": 280},
"series": [{ "borderWidth": 1, "data": mydata,
"dataLabels": {"enabled": True,"color": "#000000" }}]
}
}
myjsonresult = json.dumps(myresult)
yield [('highchartresult',)]
yield (myjsonresult,)


def Source():
return functions.vtable.vtbase.VTGenerator(highchartheatmap)


if not ('.' in __name__):
"""
This is needed to be able to test the function, put it at the end of every
new function you create
"""
import sys
import setpath
from functions import *
testfunction()
if __name__ == "__main__":
reload(sys)
sys.setdefaultencoding('utf-8')
import doctest
doctest.tesdoctest.tes
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
"""
"""
import setpath
import functions
import json
import sys
from rpy2.robjects import StrVector
from rpy2.robjects.packages import importr
from rpy2.rinterface import RRuntimeError

import warnings
warnings.filterwarnings("ignore")

caret = importr('caret')
e = importr('e1071')
base = importr('base')

### Classic stream iterator
registered=True

class rconfusionmatrixtable(functions.vtable.vtbase.VT): #predictedclass,actualclass,val
def VTiter(self, *parsedArgs, **envars):
largs, dictargs = self.full_parse(parsedArgs)

if 'query' not in dictargs:
raise functions.OperatorError(__name__.rsplit('.')[-1], "No query argument")
query = dictargs['query']

cur = envars['db'].cursor()
c = cur.execute(query)

predictedclasses =[]
actualclasses = []
classnames = []
for myrow in c:
for i in xrange(myrow[2]):
predictedclasses.append(myrow[0])
actualclasses.append(myrow[1])
if myrow[0] not in classnames:
classnames.append(myrow[0])

numberofclassnames = len(classnames)

print "Predicted vector:", predictedclasses
print "Actual vector:", actualclasses

#print (classnames)
predictedData = base.factor(base.c(StrVector(predictedclasses)), base.c(StrVector(classnames)))
truthData = base.factor(base.c(StrVector(actualclasses)), base.c(StrVector(classnames)))
Rresult = caret.confusionMatrix(predictedData,truthData)
print 'Rresult[1]', Rresult[1]
print 'Rresult[2]', Rresult[2]
print 'Rresult[3]', Rresult[3]

#####################################################
dataOverall = []
if numberofclassnames == 2:
dataOverall.append(["Positive Class",Rresult[0][0]])
else:
dataOverall.append(["Positive Class",None])

#Rresult[1] -->Table (I have already computed this)
#Rresult[2] -->overall statistics
dataOverall.append(["Accuracy",(Rresult[2][0])])
dataOverall.append(["Kappa",(Rresult[2][1])])
dataOverall.append(["Accuracy Lower",(Rresult[2][2])])
dataOverall.append(["Accuracy Upper",(Rresult[2][3])])
dataOverall.append(["Accuracy Null",(Rresult[2][4])])
dataOverall.append(["Accuracy P Value",(Rresult[2][5])])
dataOverall.append(["Mcnemar P Value",(Rresult[2][6])])

ResultOverall = { "data": {
"profile": "tabular-data-resource",
"data": dataOverall,
"name": "Overall Statistic Results",
"schema": {
"fields": [
{
"type": "text",
"name": "Statistic Name"
},
{
"type": "real",
"name": "Value"
}
]
}
},
"type": "application/vnd.dataresource+json"
}
print "ResultOverall", ResultOverall
#####################################################

FieldClassNames = [
{ "type": "text",
"name": "Statistic Name" }]
for i in range(len(classnames)):
FieldClassNames.append(
{
"type": "real",
"name": classnames[i] + " class"
})

DataClassNames = [["Sensitivity"],["Specificity"],["Pos Pred Value"],["Neg Pred Value"],["Precision"],["Recall"],
["F1"],["Prevalence"],["Detection Rate"],["Detection Prevalence"],["Balanced Accuracy"]]

#Rresult[3] -->byClass statistics

i = 0
for k in range(len(DataClassNames)):
for l in range(len(classnames)):
if str(Rresult[3][i])!='nan' and str(Rresult[3][i])!='NA':
DataClassNames[k].append(Rresult[3][i])
else:
DataClassNames[k].append(None)
i = i + 1

ResultClassNames = {
"data": {
"profile": "tabular-data-resource",
"data": DataClassNames,
"name": "Statistic Results per Class",
"schema": {"fields": FieldClassNames}
},
"type": "application/vnd.dataresource+json"}

print "resultClassNames", ResultClassNames

yield (['statscolname'],['statsval'],)

a = json.dumps(ResultOverall)
#a = a.replace(' ','')
yield ("ResultOverall" , a)

b = json.dumps(ResultClassNames)
#b = b.replace(' ','')
yield ("ResultClassNames",b)


def Source():
return functions.vtable.vtbase.VTGenerator(rconfusionmatrixtable)

if not ('.' in __name__):
"""
This is needed to be able to test the function, put it at the end of every
new function you create
"""
import sys
import setpath
from functions import *
testfunction()
if __name__ == "__main__":
reload(sys)
sys.setdefaultencoding('utf-8')
import doctest
doctest.testmod()
Loading

0 comments on commit f3edf2b

Please sign in to comment.