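You can access diagnostic data for DIET like this (first define the `YOUR_RASA_MODEL_DIRECTORY` and `YOUR_RASA_MODEL_NAME` constants; `YOUR_RASA_MODEL_NAME` is the model archive name without the `.tar.gz` suffix, which the code appends):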
from rasa.cli.utils import get_validated_path
from rasa.model import get_model, get_model_subdirectories
from rasa.nlu.model import Interpreter
from rasa.shared.nlu.training_data.message import Message
from rasa.shared.nlu.constants import TEXT
from rasa.shared.constants import DIAGNOSTIC_DATA
import pathlib
def load_interpreter(model_dir, model):
    """Load an NLU ``Interpreter`` from a packaged Rasa model.

    Args:
        model_dir: Directory containing the model archive.
        model: File name of the model archive inside ``model_dir``.

    Returns:
        The result of ``Interpreter.load`` on the NLU subdirectory of the model.
    """
    # Join directory and file name, then run the result through Rasa's path check.
    candidate = str(pathlib.Path(model_dir, model))
    checked_path = get_validated_path(candidate, "model")

    # get_model presumably unpacks the archive into a working directory - confirm.
    loaded_model_path = get_model(checked_path)

    # Only the NLU half of the (core, nlu) pair is needed here.
    _core_subdir, nlu_subdir = get_model_subdirectories(loaded_model_path)
    return Interpreter.load(nlu_subdir)
if __name__ == "__main__":
    # YOUR_RASA_MODEL_DIRECTORY / YOUR_RASA_MODEL_NAME are placeholders the
    # reader must define before running this example.
    model_directory = YOUR_RASA_MODEL_DIRECTORY
    archive_name = f"{YOUR_RASA_MODEL_NAME}.tar.gz"
    interpreter = load_interpreter(model_directory, archive_name)

    # Start from the interpreter's default output attributes and set the text.
    message_data = interpreter.default_output_attributes()
    message_data[TEXT] = "hello world"
    message = Message(data=message_data)

    # Run each pipeline component over the message, in pipeline order.
    for component in interpreter.pipeline:
        component.process(message)

    # Diagnostic data is keyed by component name.
    message_dict = message.as_dict()
    nlu_diagnostic_data = message_dict[DIAGNOSTIC_DATA]
    for name, component_diagnostics in nlu_diagnostic_data.items():
        print(f"attention_weights for {name}:")
        print(component_diagnostics["attention_weights"])
        print(f"\ntext_transformed for {name}:")
        print(component_diagnostics["text_transformed"])
You can access diagnostic data for TED like this:
from rasa.core.policies.ted_policy import TEDPolicy
from rasa.shared.core.constants import ACTION_LISTEN_NAME
from rasa.shared.core.domain import Domain
from rasa.shared.core.events import ActionExecuted, UserUttered
from rasa.shared.core.trackers import DialogueStateTracker
from rasa.shared.nlu.interpreter import RegexInterpreter
GREET_INTENT_NAME = "greet"
UTTER_GREET_ACTION = "utter_greet"

# Minimal domain for the TED example: a single intent and a single action.
# Assembled line by line so the YAML layout (leading and trailing newline
# included) is explicit.
DOMAIN_YAML = "\n".join(
    [
        "",
        "intents:",
        f"- {GREET_INTENT_NAME}",
        "actions:",
        f"- {UTTER_GREET_ACTION}",
        "",
    ]
)
if __name__ == "__main__":
    domain = Domain.from_yaml(DOMAIN_YAML)
    policy = TEDPolicy()

    # Event sequence: user greets, bot runs utter_greet then action_listen,
    # user greets again, followed by another action_listen.
    greet_events = [
        UserUttered(intent={"name": GREET_INTENT_NAME}),
        ActionExecuted(UTTER_GREET_ACTION),
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered(intent={"name": GREET_INTENT_NAME}),
        ActionExecuted(ACTION_LISTEN_NAME),
    ]
    greet_tracker = DialogueStateTracker.from_events("greet rule", evts=greet_events)

    # Train on this single tracker, then predict on the same tracker.
    policy.train([greet_tracker], domain, RegexInterpreter())
    prediction = policy.predict_action_probabilities(
        greet_tracker, domain, RegexInterpreter()
    )

    attention_weights = prediction.diagnostic_data.get("attention_weights")
    print(attention_weights)