Commit a3930002 authored by Yves Reker's avatar Yves Reker

fixed bug causing crash because of wrong layer in a session

parent e161e0cc
......@@ -436,38 +436,49 @@ class ServerSession(QObject, ParserListener, SessionCommon):
res = self.__ensureDirectory()
if len(res) > 0:
return res
toSave = {"SessionState": self.state, "Iteration": self.iteration, "MaxIter": self.max_iter}
toSave["UID"] = self.uid
toSave["SID"] = self.sid
toSave["ProjectID"] = self.pid
if self.last_solverstate:
toSave["LastSnapshot"] = self.last_solverstate
if self.pretrainedWeights:
toSave["PretrainedWeights"] = self.pretrainedWeights
if self.state_dictionary:
serializedDict = copy.deepcopy(self.state_dictionary)
if "network" in serializedDict:
if includeProtoTxt:
netDict = copy.deepcopy(self.state_dictionary["network"])
net = saver.saveNet(netdict=netDict)
with open(os.path.join(self.directory, Paths.FILE_NAME_NET_ORIGINAL), 'w') as f:
f.write(net)
if "layers" in serializedDict["network"]:
layers = serializedDict["network"]["layers"]
for id in layers:
del layers[id]["type"]
toSave["NetworkState"] = serializedDict
filename = os.path.join(self.directory, Paths.FILE_NAME_SESSION_JSON)
# clear the file. sometimes json.dump does not exitclean and causes valuerrors on load
open(filename, 'w').close()
with open(filename, "w") as f:
json.dump(toSave, f, sort_keys=True, indent=4)
return []
availableTypes = info.CaffeMetaInformation().availableLayerTypes()
unknownLayerTypes = []
for layerID in self.state_dictionary["network"]["layers"]:
type = self.state_dictionary["network"]["layers"][layerID]["parameters"]["type"]
if not type in availableTypes and not type in unknownLayerTypes:
unknownLayerTypes.append(type)
if len(unknownLayerTypes) > 0:
return False
else:
toSave = {"SessionState": self.state, "Iteration": self.iteration, "MaxIter": self.max_iter}
toSave["UID"] = self.uid
toSave["SID"] = self.sid
toSave["ProjectID"] = self.pid
if self.last_solverstate:
toSave["LastSnapshot"] = self.last_solverstate
if self.pretrainedWeights:
toSave["PretrainedWeights"] = self.pretrainedWeights
if self.state_dictionary:
serializedDict = copy.deepcopy(self.state_dictionary)
if "network" in serializedDict:
if includeProtoTxt:
netDict = copy.deepcopy(self.state_dictionary["network"])
net = saver.saveNet(netdict=netDict)
with open(os.path.join(self.directory, Paths.FILE_NAME_NET_ORIGINAL), 'w') as f:
f.write(net)
if "layers" in serializedDict["network"]:
layers = serializedDict["network"]["layers"]
for id in layers:
del layers[id]["type"]
toSave["NetworkState"] = serializedDict
filename = os.path.join(self.directory, Paths.FILE_NAME_SESSION_JSON)
# clear the file. sometimes json.dump does not exitclean and causes valuerrors on load
open(filename, 'w').close()
with open(filename, "w") as f:
json.dump(toSave, f, sort_keys=True, indent=4)
return True
def prepairInternalPrototxt(self):
error = []
......
......@@ -65,7 +65,7 @@ class ServerSessionManager(QObject):
except UnknownLayerTypeException as e:
msg["status"] = False
msg["error"] = [e._msg]
logging.error("Could not create session. Unknonw layer")
logging.error("Could not create session. Unknown layer")
transaction.send(msg)
return
dirp = self._createDirName(msg["sid"])
......
......@@ -76,10 +76,17 @@ class ServerTransaction(Transaction):
path = msg["pid"]
if "pid" in msg:
sessionsUIDs = self.parent.sessionManager.findSessionIDsByProjectId(msg["pid"])
success = True
for uid in sessionsUIDs:
session = self.parent.sessionManager.findSessionBySessionUid(uid)
session.save()
os.execl(sys.executable, sys.executable, *sys.argv)
if not session.save():
success = False
if success is False:
msg["error"] = ["Could not save a session because it contains layer of a other caffe-version."]
msg["status"] = False
self.send(msg)
else:
os.execl(sys.executable, sys.executable, *sys.argv)
def getDir(self):
"""list all files and subdirectories for a given dir with regex."""
......
......@@ -107,9 +107,14 @@ class CaffeVersionWidget(QWidget):
msgBox = QMessageBox(QMessageBox.Warning, "Warning", "Please restart Barista host for changes to apply, otherwise Barista may be unstable!")
msgBox.addButton("Ok", QMessageBox.NoRole)
msgBox.addButton("Restart now", QMessageBox.YesRole)
if msgBox.exec_() == 1:
msg = {"key": Protocol.RESTART, "pid": self.versionManager.project.getProjectId()}
sendMsgToHost(self.host.host, self.host.port, msg)
ret = sendMsgToHost(self.host.host, self.host.port, msg)
if ret and not ret["status"]:
msgBox = QMessageBox(QMessageBox.Warning, "Warning", ret["error"][0])
msgBox.addButton("Ok", QMessageBox.NoRole)
msgBox.exec_()
self.versionManager.updateList()
def restartWarning(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment