Commit f84ae630 authored by Florian Hurck

move reward values to a different class

so that all configuration values are in one place
parent dd39c31a
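For illustration only (not part of the commit), a minimal, self-contained sketch of the idea, using the hypothetical helper name RewardConfigSketch: the caller (ModelHandler) now decides all four reward values and hands them to SchedulerDQNLearner, instead of the learner deriving the terminal rewards from a maximize flag itself. The values and their meaning are taken from the diff below.

final class RewardConfigSketch {
    // Returns {rewardFulfilled, rewardNotFulfilled, rewardValidAction, rewardInvalidAction}
    // in the order expected by the new SchedulerDQNLearner constructor.
    static int[] rewards(boolean maximize) {
        int rewardFulfilled     = maximize ?  100 : -100; // terminal reward when the property holds at the end of a run
        int rewardNotFulfilled  = maximize ? -100 :  100; // terminal reward when it does not
        int rewardValidAction   = 1;                      // per-step reward for picking a transition that can actually fire
        int rewardInvalidAction = -1;                     // per-step penalty when the pick cannot fire and the conflict stays unresolved
        return new int[] { rewardFulfilled, rewardNotFulfilled, rewardValidAction, rewardInvalidAction };
    }
}

With this, tuning the reward scheme only touches the configuration side; SchedulerDQNLearner itself no longer needs the maximize flag.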
@@ -501,8 +501,19 @@ public class ModelHandler {
 .l2(0) //l2 regularization
 .updater(new Sgd(0.001)) //Gradient Descent that only applies the learning rate. There are other fancy options, see the folder: org.nd4j.linalg.learning.config
 .build();
+int rewardFulfilled, rewardNotFulfilled;
+if(maximize) {
+rewardFulfilled = 100;
+rewardNotFulfilled = -100;
+} else {
+rewardFulfilled = -100;
+rewardNotFulfilled = 100;
+}
+int rewardValidAction = 1;
+int rewardInvalidAction = -1;
-SchedulerDQNLearner DQNLearner = new SchedulerDQNLearner(model, logger, child, time, simulationHandler.getProphetic(), maximize, simulationHandler.getPrintRunResults());
+SchedulerDQNLearner DQNLearner = new SchedulerDQNLearner(model, logger, child, time, simulationHandler.getProphetic(), rewardFulfilled, rewardNotFulfilled, rewardValidAction, rewardInvalidAction, simulationHandler.getPrintRunResults());
 //DiscreteDense does not mean the state space is discrete but the action space is
 QLearningDiscreteDense<Observation> dql = new QLearningDiscreteDense<Observation>(DQNLearner, netConf, qLearningConf);
@@ -36,7 +36,6 @@ public class SchedulerDQNLearner implements MDP<Observation, Integer, DiscreteSp
 private SimpleNode root;
 private Double propertyTime;
 private boolean prophetic;
-private boolean maximize;
 private DynamicSimulator simulator;
 private PropertyChecker checker;
 private ArrayList<String> related;
@@ -48,6 +47,8 @@ public class SchedulerDQNLearner implements MDP<Observation, Integer, DiscreteSp
 private boolean printRunResults;
 private int rewardFulfilled;
 private int rewardNotFulfilled;
+private int rewardValidAction;
+private int rewardInvalidAction;
 private DiscreteSpace actionSpace;
 private ObservationSpace<Observation> observationSpace;
 private boolean crash;
@@ -58,13 +59,16 @@ public class SchedulerDQNLearner implements MDP<Observation, Integer, DiscreteSp
 private int totalValidActionCounter;
-public SchedulerDQNLearner(HPnGModel model, Logger logger, SimpleNode root, Double propertyTime, boolean prophetic, boolean maximize, boolean printRunResults) throws InvalidPropertyException, InvalidDistributionParameterException {
+public SchedulerDQNLearner(HPnGModel model, Logger logger, SimpleNode root, Double propertyTime, boolean prophetic, int rewardFulfilled, int rewardNotFulfilled, int rewardValidAction, int rewardInvalidAction, boolean printRunResults) throws InvalidPropertyException, InvalidDistributionParameterException {
 this.model = model;
 this.logger = logger;
 this.root = root;
 this.propertyTime = propertyTime;
 this.prophetic = prophetic;
-this.maximize = maximize;
+this.rewardFulfilled = rewardFulfilled;
+this.rewardNotFulfilled = rewardNotFulfilled;
+this.rewardValidAction = rewardValidAction;
+this.rewardInvalidAction = rewardInvalidAction;
 this.printRunResults = printRunResults;
 this.maxTime = PropertyChecker.getMaxTimeForSimulation(root, propertyTime);
 epocheCounter = 0;
@@ -84,15 +88,6 @@ public class SchedulerDQNLearner implements MDP<Observation, Integer, DiscreteSp
 generator = new SampleGenerator();
 generator.initializeRandomStream();
-if(maximize) {
-rewardFulfilled = 1;
-rewardNotFulfilled = -1;
-}
-else {
-rewardFulfilled = -1;
-rewardNotFulfilled = 1;
-}
 //calculate how many actions in total we can take => count transitions
 int actionSpaceSize = model.getTransitions().size();
 actionSpace = new DiscreteSpace(actionSpaceSize);
@@ -180,11 +175,11 @@ public class SchedulerDQNLearner implements MDP<Observation, Integer, DiscreteSp
 if (printRunResults) {
 if (simulator.event.getEventType().equals(SimulationEvent.SimulationEventType.general_transition)) {
-System.out.println(simulator.event.getOccurenceTime() + " seconds: General transition " + transition.getId() + " is fired for the " + ((GeneralTransition) transition).getFirings() + ". time");
+System.out.println(simulator.event.getOccurenceTime() + "s: General transition " + transition.getId() + " is fired for the " + ((GeneralTransition) transition).getFirings() + ". time");
 } else if (simulator.event.getEventType().equals(SimulationEvent.SimulationEventType.immediate_transition)) {
-System.out.println(simulator.event.getOccurenceTime() + " seconds: Immediate transition " + transition.getId() + " is fired");
+System.out.println(simulator.event.getOccurenceTime() + "s: Immediate transition " + transition.getId() + " is fired");
 } else if (simulator.event.getEventType().equals(SimulationEvent.SimulationEventType.deterministic_transition)) {
-System.out.println(simulator.event.getOccurenceTime() + " seconds: Deterministic transition " + transition.getId() + " is fired");
+System.out.println(simulator.event.getOccurenceTime() + "s: Deterministic transition " + transition.getId() + " is fired");
 }
 }
 }
@@ -305,8 +300,7 @@ public class SchedulerDQNLearner implements MDP<Observation, Integer, DiscreteSp
 // would it be equal to do?: transition.getEnabled() == true
 if(simulator.event.getRelatedObjects().contains(transition)) {
 executeTransition(transition);
-//a reward bigger 0 (but smaller than rewardFulfilled/rewardNotFulfilled) might be better to reward the NN for choosing a legal action
-reward = 0;
+reward = rewardValidAction;
 totalValidActionCounter++;
 try {
@@ -341,7 +335,7 @@ public class SchedulerDQNLearner implements MDP<Observation, Integer, DiscreteSp
 } else {
 //cant execute action, conflict remains unsolved, give big bad reward and try again
-reward = -10;
+reward = rewardInvalidAction;
 invalidActionCounter++;
 totalInvalidActionCounter++;
 }
@@ -361,7 +355,7 @@ public class SchedulerDQNLearner implements MDP<Observation, Integer, DiscreteSp
 //debug purpose, please delete later
 System.out.println("DQNLearner.newInstance()");
 try {
-return new SchedulerDQNLearner(model, logger, root, propertyTime, prophetic, maximize, printRunResults);
+return new SchedulerDQNLearner(model, logger, root, propertyTime, prophetic, rewardFulfilled, rewardNotFulfilled, rewardValidAction, rewardInvalidAction, printRunResults);
 } catch (InvalidPropertyException | InvalidDistributionParameterException e) {
 System.out.println("An impossible internal error occured while using Deep Q Learning.");
 logger.log(Level.SEVERE, "This error should be impossible because an object with exactly the same parameters was previously created without crashing", e);