Skip to content

Instantly share code, notes, and snippets.

@JEM-Mosig
Created January 21, 2021 13:16
Show Gist options
  • Save JEM-Mosig/c6e15b81ee70561cb72e361aff310d7e to your computer and use it in GitHub Desktop.
Rasa diagnostic data access

You can access DIET's diagnostic data like this. First define the YOUR_RASA_MODEL_DIRECTORY and YOUR_RASA_MODEL_NAME constants. Give the model name without the .tar.gz extension, because the script appends it:

from rasa.cli.utils import get_validated_path
from rasa.model import get_model, get_model_subdirectories
from rasa.nlu.model import Interpreter
from rasa.shared.nlu.training_data.message import Message
from rasa.shared.nlu.constants import TEXT
from rasa.shared.constants import DIAGNOSTIC_DATA
import pathlib


def load_interpreter(model_dir, model):
    """Load the NLU interpreter contained in a packaged Rasa model.

    Args:
        model_dir: Directory that holds the model file.
        model: File name of the model inside ``model_dir``.

    Returns:
        The ``Interpreter`` produced by ``Interpreter.load`` for the NLU
        subdirectory reported by ``get_model_subdirectories``.
    """
    # Join directory and file name, then let Rasa validate the resulting path.
    validated_path = get_validated_path(str(pathlib.Path(model_dir, model)), "model")
    resolved_model = get_model(validated_path)
    # get_model_subdirectories yields (core, nlu); only the NLU part is needed.
    _core_dir, nlu_dir = get_model_subdirectories(resolved_model)
    return Interpreter.load(nlu_dir)


if __name__ == "__main__":
    # YOUR_RASA_MODEL_DIRECTORY and YOUR_RASA_MODEL_NAME are placeholders the
    # user must define before running; the model name gets ".tar.gz" appended.
    interpreter = load_interpreter(YOUR_RASA_MODEL_DIRECTORY, f"{YOUR_RASA_MODEL_NAME}.tar.gz")

    # Start from the interpreter's default output attributes, add the text to
    # analyse, and run every pipeline component by hand so that the Message
    # object (and the diagnostic data written into it) stays accessible.
    data = interpreter.default_output_attributes()
    data[TEXT] = "hello world"
    message = Message(data=data)
    for component in interpreter.pipeline:
        component.process(message)

    # Diagnostic data lives under DIAGNOSTIC_DATA as a mapping whose keys are
    # iterated below as component names. Use .get so a pipeline that recorded
    # nothing is reported instead of raising KeyError.
    nlu_diagnostic_data = message.as_dict().get(DIAGNOSTIC_DATA, {})
    if not nlu_diagnostic_data:
        print("No diagnostic data was recorded by the NLU pipeline.")

    for component_name, diagnostic_data in nlu_diagnostic_data.items():
        # .get prints None for a component that lacks one of these entries
        # rather than aborting the whole report.
        print(f"attention_weights for {component_name}:")
        attention_weights = diagnostic_data.get("attention_weights")
        print(attention_weights)

        print(f"\ntext_transformed for {component_name}:")
        text_transformed = diagnostic_data.get("text_transformed")
        print(text_transformed)

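You can access TED's diagnostic data like this. The example trains a small TEDPolicy on a one-intent toy domain, so it does not need a previously trained model: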

from rasa.core.policies.ted_policy import TEDPolicy
from rasa.shared.core.constants import ACTION_LISTEN_NAME
from rasa.shared.core.domain import Domain
from rasa.shared.core.events import ActionExecuted, UserUttered
from rasa.shared.core.trackers import DialogueStateTracker
from rasa.shared.nlu.interpreter import RegexInterpreter


# Names shared by the toy domain and the training conversation.
GREET_INTENT_NAME = "greet"
UTTER_GREET_ACTION = "utter_greet"

# Minimal domain YAML with one intent and one action, assembled line by line;
# the empty first and last entries produce the leading and trailing newline.
DOMAIN_YAML = "\n".join(
    [
        "",
        "intents:",
        f"- {GREET_INTENT_NAME}",
        "actions:",
        f"- {UTTER_GREET_ACTION}",
        "",
    ]
)


if __name__ == "__main__":
    # Toy domain built from DOMAIN_YAML and an untrained TED policy.
    domain = Domain.from_yaml(DOMAIN_YAML)
    policy = TEDPolicy()

    # A single conversation: greet -> utter_greet -> listen, then greet
    # again followed by listen.
    greet_events = [
        UserUttered(intent={"name": GREET_INTENT_NAME}),
        ActionExecuted(UTTER_GREET_ACTION),
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered(intent={"name": GREET_INTENT_NAME}),
        ActionExecuted(ACTION_LISTEN_NAME),
    ]
    greet_tracker = DialogueStateTracker.from_events("greet rule", evts=greet_events)

    # Train on that one tracker, then predict on the same tracker; both calls
    # receive a RegexInterpreter.
    policy.train([greet_tracker], domain, RegexInterpreter())
    prediction = policy.predict_action_probabilities(
        greet_tracker, domain, RegexInterpreter()
    )

    # Read via .get, so None is printed when no attention weights were stored.
    attention_weights = prediction.diagnostic_data.get("attention_weights")
    print(attention_weights)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment