Skip to content

Instantly share code, notes, and snippets.

@Belrestro
Created May 13, 2019 09:27
Show Gist options
Save Belrestro/a6b8e5b8d59924ad2dac97d0b4ca5260 to your computer and use it in GitHub Desktop.
Train Rasa Core and Rasa NLU models programmatically (rasa_core 0.14.3, rasa_nlu 0.15).
from rasa_core.agent import Agent
from rasa_core.policies import KerasPolicy
from rasa_nlu.training_data import load_data
from rasa_nlu import config
from rasa_nlu.model import Trainer
import datetime
def _archive_name (type, environment):
t = datetime.datetime.now()
timestring = t.strftime('%Y%m%d-%H%M%S')
return '%s__model_%s' % (prefix, timestring)
def train_core(params):
    """Train a Rasa Core dialogue model with a KerasPolicy and persist it.

    Args:
        params: Mapping of options:
            ``'domain'``: path to the domain file (default ``''``).
            ``'stories'``: path to the stories training data (default ``''``).
            ``'environment'``: optional label forwarded to ``_archive_name``.
            ``'base_dir'``: optional output root directory. When absent, the
                module-level ``BASE_DIR`` is used, as in ``train_nlu``.

    Returns:
        The directory path passed to ``agent.persist``.
    """
    domain = params.get('domain', '')
    stories = params.get('stories', '')
    environment = params.get('environment')
    base_dir = params.get('base_dir')
    if base_dir is None:
        # NOTE(review): BASE_DIR is not defined in the visible source; it must
        # be provided at module level (train_nlu relies on it as well).
        base_dir = BASE_DIR
    model_dir = '%s/%s' % (base_dir, _archive_name('core', environment))

    policy_arguments = {
        "epochs": 100,
        "batch_size": 20,
        "validation_split": 0.1,
        "max_history": 5,
    }
    # augmentation_factor / debug_plots are options of Agent.load_data in
    # rasa_core 0.14, not of KerasPolicy. Passed to the policy, they would be
    # kept as unrecognised training params instead of affecting data loading.
    data_arguments = {
        "augmentation_factor": 50,
        "debug_plots": True,
    }

    agent = Agent(domain, policies=[KerasPolicy(**policy_arguments)])
    training_data = agent.load_data(stories, **data_arguments)
    agent.train(training_data)
    agent.persist(model_dir)
    return model_dir
def train_nlu(params):
    """Train a Rasa NLU model and persist it.

    Args:
        params: Mapping of options:
            ``'intents'``: path to the NLU training data (default ``''``).
            ``'config'``: path to the NLU pipeline config file. When absent
                (``None``), rasa_nlu's default configuration is used.
            ``'environment'``: optional label forwarded to ``_archive_name``.
            ``'base_dir'``: optional output root directory. When absent, the
                module-level ``BASE_DIR`` is used.

    Returns:
        Whatever ``Trainer.persist`` returns (the persisted model directory
        in rasa_nlu 0.15).
    """
    intents_file = params.get('intents', '')
    # The previous default ``{}`` was handed to config.load as if it were a
    # filename; None is the value that selects the default config.
    config_file = params.get('config')
    environment = params.get('environment')
    base_dir = params.get('base_dir')
    if base_dir is None:
        # NOTE(review): BASE_DIR is not defined in the visible source; it must
        # be provided at module level.
        base_dir = BASE_DIR
    model_name = _archive_name('nlu', environment)

    # Bug fix: the pipeline config comes from config_file. The original loaded
    # it from the training-data file (intents_file).
    nlu_config = config.load(config_file)
    data = load_data(intents_file)
    trainer = Trainer(nlu_config)
    trainer.train(data)
    # project_name='' is kept from the original; presumably it places the model
    # directly under <base_dir>/<model_name> — confirm against rasa_nlu.
    return trainer.persist(base_dir, project_name='', fixed_model_name=model_name)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment