This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def run_round(state, warmup=False): | |
# 1) Agent takes action given state tracker's representation of dialogue (state) | |
agent_action_index, agent_action = dqn_agent.get_action(state, use_rule=warmup) | |
# 2) Update state tracker with the agent's action | |
round_num = state_tracker.update_state_agent(agent_action) | |
# 3) User takes action given agent action | |
user_action, reward, done, success = user.step(agent_action, round_num) | |
if not done: | |
# 4) Infuse error into semantic frame level of user action | |
emc.infuse_error(user_action) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def episode_reset(): | |
# First reset the state tracker | |
state_tracker.reset() | |
# Then pick an init user action | |
user_action = user.reset() | |
# Infuse with error | |
emc.infuse_error(user_action) | |
# And update state tracker | |
state_tracker.update_state_user(user_action) | |
# Finally, reset agent |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def warmup_run(): | |
total_step = 0 | |
while total_step != WARMUP_MEM and not dqn_agent.is_memory_full(): | |
# Reset episode | |
episode_reset() | |
done = False | |
# Get initial state from state tracker | |
state = state_tracker.get_state() | |
while not done: | |
next_state, _, done, _ = run_round(state, warmup=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train_run(): | |
episode = 0 | |
period_success_total = 0 | |
success_rate_best = 0.0 | |
# Almost exact same loop as warm-up ----- | |
while episode < NUM_EP_TRAIN: | |
episode_reset() | |
episode += 1 | |
done = False | |
state = state_tracker.get_state() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Slots the agent is allowed to fill in ("inform") when replying to the user.
agent_inform_slots = [
    'moviename',
    'theater',
    'starttime',
    'date',
    'genre',
    'state',
    'city',
    'zip',
    'critic_rating',
    'mpaa_rating',
    'distanceconstraints',
    'video_format',
    'theater_chain',
    'price',
    'actor',
    'description',
    'other',
    'numberofkids',
]
# Slots the agent is allowed to ask the user about ("request"). Same entries,
# same order as agent_inform_slots, plus 'numberofpeople' right after 'date'.
agent_request_slots = [
    'moviename',
    'theater',
    'starttime',
    'date',
    'numberofpeople',
    'genre',
    'state',
    'city',
    'zip',
    'critic_rating',
    'mpaa_rating',
    'distanceconstraints',
    'video_format',
    'theater_chain',
    'price',
    'actor',
    'description',
    'other',
    'numberofkids',
]
# Possible actions for agent | |
agent_actions = [ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _build_model(self):
    """Create and compile a two-layer fully connected Keras model.

    The network maps a vector of length ``self.state_size`` through one ReLU
    hidden layer of ``self.hidden_size`` units to ``self.num_actions`` linear
    outputs (one per action, unbounded, as expected of value estimates).

    Returns:
        A compiled ``Sequential`` model using mean-squared-error loss and the
        Adam optimizer with learning rate ``self.lr``.
    """
    # Layers are created in the same order as they are stacked, so Keras'
    # auto-generated layer names are assigned in the same sequence.
    hidden_layer = Dense(self.hidden_size, input_dim=self.state_size, activation='relu')
    output_layer = Dense(self.num_actions, activation='linear')
    network = Sequential([hidden_layer, output_layer])
    # NOTE(review): `lr=` is deprecated in tf.keras 2.x and rejected by
    # Keras 3 (use `learning_rate=` there); kept for older standalone Keras.
    network.compile(loss='mse', optimizer=Adam(lr=self.lr))
    return network
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_action(self, state, use_rule=False): | |
# self.eps is initialized to the starting epsilon and does NOT get annealed | |
if self.eps > random.random(): | |
index = random.randint(0, self.num_actions - 1) | |
# self._map_index_to_action(index) takes an index and maps the action from all possible agent actions | |
action = self._map_index_to_action(index) | |
return index, action | |
else: | |
if use_rule: | |
return self._rule_action() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _rule_action(self): | |
# self.rule_current_slot_index points to current slot | |
# rule_requests defined in dialogue_config.py | |
if self.rule_current_slot_index < len(rule_requests): | |
slot = rule_requests[self.rule_current_slot_index] | |
self.rule_current_slot_index += 1 | |
rule_response = {'intent': 'request', 'inform_slots': {}, 'request_slots': {slot: 'UNK'}} | |
# self.rule_phase used to indicate if we are at second to last round or last round | |
elif self.rule_phase == 'not done': | |
rule_response = {'intent': 'match_found', 'inform_slots': {}, 'request_slots': {}} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def reset(self):
    """Rewind the rule-based policy state to its starting point.

    Puts the slot cursor (``rule_current_slot_index``) back at the first
    rule request and sets ``rule_phase`` to ``'not done'``; both values are
    read by ``_rule_action`` to decide what to emit next.
    """
    self.rule_current_slot_index, self.rule_phase = 0, 'not done'
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _dqn_action(self, state):
    """Pick the greedy action according to the behavior model.

    Args:
        state: numpy array holding one dialogue state; it must be
            reshapeable to ``(1, self.state_size)``.

    Returns:
        ``(index, action)``, where ``index`` is the position of the largest
        model output and ``action`` is ``self._map_index_to_action(index)``.
    """
    # Keras models consume batches, so wrap the single state in a batch of
    # one and flatten the (presumably (1, num_actions)) output back to 1-D.
    # `Model.predict` takes no `target` argument, so none is passed.
    outputs = self.beh_model.predict(state.reshape(1, self.state_size)).flatten()
    index = np.argmax(outputs)
    action = self._map_index_to_action(index)
    return index, action
Older Newer