Skip to content

Commit

Permalink
Update um.py
Browse files Browse the repository at this point in the history
  • Loading branch information
monabraeunig authored Jan 28, 2025
1 parent e4f7f56 commit cad5de6
Showing 1 changed file with 55 additions and 12 deletions.
67 changes: 55 additions & 12 deletions umbridge/um.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor

import traceback
import json

class Model(object):

def __init__(self, name):
Expand Down Expand Up @@ -52,12 +55,23 @@ def __init__(self, url, name):
self.__supports_apply_jacobian = response["support"].get("ApplyJacobian", False)
self.__supports_apply_hessian = response["support"].get("ApplyHessian", False)

if not isinstance(self.get_input_sizes(), list):
print("Input Size is not a list!")
raise Exception("Input Size is not a list!")
if not all(isinstance(x, int) for x in self.get_input_sizes()):
print("Input sizes must be of type int")
raise Exception("Input sizes must be of type int")

def get_input_sizes(self, config={}):
input = {}
input["name"] = self.name
input["config"] = config
response = requests.post(f"{self.url}/InputSizes", json=input).json()
return response["inputSizes"]
try:
return response["inputSizes"]
except Exception as e:
print(response["errorMessage"])
raise Exception(f"Invalid input size.")

def get_output_sizes(self, config={}):
input = {}
Expand Down Expand Up @@ -195,9 +209,13 @@ async def evaluate(request):
config = {}
if "config" in req_json:
config = req_json["config"]

input_sizes = model.get_input_sizes(config)
output_sizes = model.get_output_sizes(config)
try:
output_sizes = model.get_output_sizes(config)
except Exception as e:
print(traceback.format_exc())
return error_response("InvalidOutputSizes", str(traceback.format_exc()), 500)

# Check if parameter dimensions match model input sizes
if len(parameters) != len(input_sizes):
Expand All @@ -212,7 +230,7 @@ async def evaluate(request):
except Exception as e:
print(traceback.format_exc())
return error_response("InvalidEvaluation", str(traceback.format_exc()), 500)

# Check if output is a list of lists
if not isinstance(output, list):
return error_response("InvalidOutput", "Model output is not a list of lists!", 500)
Expand Down Expand Up @@ -248,7 +266,11 @@ async def gradient(request):
config = req_json["config"]

input_sizes = model.get_input_sizes(config)
output_sizes = model.get_output_sizes(config)
try:
output_sizes = model.get_output_sizes(config)
except Exception as e:
print(traceback.format_exc())
return error_response("InvalidOutputSizes", str(traceback.format_exc()), 500)

# Check if parameter dimensions match model input sizes
if len(parameters) != len(input_sizes):
Expand All @@ -272,7 +294,8 @@ async def gradient(request):
except Exception as e:
print(traceback.format_exc())
return error_response("InvalidGradient", str(traceback.format_exc()), 500)



# Check if output is a list
if not isinstance(output, list):
return error_response("InvalidOutput", "Model output is not a list!", 500)
Expand Down Expand Up @@ -303,7 +326,11 @@ async def applyjacobian(request):
config = req_json["config"]

input_sizes = model.get_input_sizes(config)
output_sizes = model.get_output_sizes(config)
try:
output_sizes = model.get_output_sizes(config)
except Exception as e:
print(traceback.format_exc())
return error_response("InvalidOutputSizes", str(traceback.format_exc()), 500)

# Check if parameter dimensions match model input sizes
if len(parameters) != len(input_sizes):
Expand Down Expand Up @@ -360,7 +387,11 @@ async def applyhessian(request):
config = req_json["config"]

input_sizes = model.get_input_sizes(config)
output_sizes = model.get_output_sizes(config)
try:
output_sizes = model.get_output_sizes(config)
except Exception as e:
print(traceback.format_exc())
return error_response("InvalidOutputSizes", str(traceback.format_exc()), 500)

# Check if parameter dimensions match model input sizes
if len(parameters) != len(input_sizes):
Expand All @@ -381,10 +412,10 @@ async def applyhessian(request):
try:
output_future = model_executor.submit(model.apply_hessian, out_wrt, in_wrt1, in_wrt2, parameters, sens, vec, config)
output = await asyncio.wrap_future(output_future)
except Exception as e:
except Exception as e:
print(traceback.format_exc())
return error_response("InvalidHessian", str(traceback.format_exc()), 500)

# Check if output is a list
if not isinstance(output, list):
return error_response("InvalidOutput", "Model output is not a list!", 500)
Expand All @@ -406,7 +437,13 @@ async def get_input_sizes(request):
model = get_model_from_name(model_name)
if model is None:
return model_not_found_response(req_json["name"])
return web.Response(text=f"{{\"inputSizes\": {model.get_input_sizes(config)} }}")
try:
return web.Response(text=f"{{\"inputSizes\": {model.get_input_sizes(config)} }}")
except Exception as e:
print(traceback.format_exc())
tb = traceback.format_exc()
error_message = {"error": "An exception occurred","traceback": tb}
return web.Response(text=f"{{\"errorMessage\": {json.dumps(error_message)} }}", content_type='application/json')

@routes.post('/OutputSizes')
async def get_output_sizes(request):
Expand All @@ -419,7 +456,13 @@ async def get_output_sizes(request):
model = get_model_from_name(model_name)
if model is None:
return model_not_found_response(req_json["name"])
return web.Response(text=f"{{\"outputSizes\": {model.get_output_sizes(config)} }}")
try:
return web.Response(text=f"{{\"outputSizes\": {model.get_output_sizes(config)} }}")
except Exception as e:
print(traceback.format_exc())
tb = traceback.format_exc()
error_message = {"error": "An exception occurred","traceback": tb}
return web.Response(text=f"{{\"errorMessage\": {json.dumps(error_message)} }}", content_type='application/json')

@routes.post('/ModelInfo')
async def modelinfo(request):
Expand Down

0 comments on commit cad5de6

Please sign in to comment.