Skip to content

Instantly share code, notes, and snippets.

@lectricas
Created June 13, 2021 12:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lectricas/a8b45cd6e1805347716caac92730d268 to your computer and use it in GitHub Desktop.
Save lectricas/a8b45cd6e1805347716caac92730d268 to your computer and use it in GitHub Desktop.
#include <cstdio>
#include <cctype>
#include <string>
#include <iostream>
#include <vector>
#include <assert.h>
#include <algorithm>
#include <unordered_map>
#include <map>
using namespace std;
// One trie vertex. Each edge is labelled with a single character, and
// `words` holds the indices of the inserted strings whose key ends here.
struct Node {
    // How many strings end at this vertex.
    int terminal = 0;
    // Character on the edge leading into this vertex.
    char letter;
    // Children keyed by edge character. These are non-owning raw pointers:
    // Node itself never deletes its children.
    std::unordered_map<char, Node *> map;
    // Indices of the strings ending at this vertex, in the order they were added.
    std::vector<int> words;

    // explicit: without it, operator== below would silently accept a bare
    // char (e.g. `node == 'A'`) through an implicit char -> Node conversion.
    explicit Node(char letter) : letter(letter) {}

    // Vertices compare equal when their edge letters match; children and
    // stored word indices are not compared.
    bool operator==(const Node &n) const {
        return letter == n.letter;
    }
};
struct Trie {
Node *head = new Node('#');
void add_string_internal(string &text, int index_of_string) {
Node *currentNode = head;
for (char symbol : text) {
if (symbol <= 'Z') {
if (currentNode->map.count(symbol) == 0) {
Node *next = new Node(symbol);
currentNode->map[symbol] = next;
}
currentNode = currentNode->map[symbol];
}
}
currentNode->terminal++;
currentNode->words.push_back(index_of_string);
}
Node *find_any(string &text) {
Node *current = head;
for (char &symbol : text) {
if (current->map.count(symbol) != 0) {
current = current->map[symbol];
} else {
return nullptr;
}
}
return current;
}
void visitNode(Node *root, vector<int> &words) {
if (root == nullptr) {
return;
}
if (root->terminal > 0) {
for (int index: root->words) {
words.push_back(index);
}
}
for (auto n: root->map) {
visitNode(n.second, words);
}
}
vector<int> findAll(string &text) {
vector<int> worlds;
Node *common = find_any(text);
visitNode(common, worlds);
return worlds;
}
};
// Drops one trailing '\r' so input with Windows (CRLF) line endings behaves
// exactly like LF input.
static void strip_carriage_return(std::string &line) {
    if (!line.empty() && line.back() == '\r') {
        line.pop_back();
    }
}

// Reads an integer count and then discards the rest of that line (trailing
// spaces, '\r', the newline), so the next getline starts on the following line.
static int read_count(std::istream &in) {
    int value = 0;
    in >> value;
    std::string rest_of_line;
    std::getline(in, rest_of_line);
    return value;
}

// Input: n names (one per line), then k queries (one per line; a query may
// be an empty line). For each query, prints every name whose CamelCase key
// starts with the query, in lexicographic order.
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    const int n = read_count(std::cin);
    std::vector<std::string> strings;
    strings.reserve(n > 0 ? n : 0);
    for (int i = 0; i < n; i++) {
        std::string line;
        std::getline(std::cin, line);
        strip_carriage_return(line);
        strings.push_back(std::move(line));
    }
    // Sorting first means ascending trie indices are lexicographic order.
    std::sort(strings.begin(), strings.end());

    Trie tr;
    for (int i = 0; i < n; i++) {
        tr.add_string_internal(strings[i], i);
    }

    const int k = read_count(std::cin);
    for (int q = 0; q < k; q++) {
        std::string pattern;
        std::getline(std::cin, pattern);
        strip_carriage_return(pattern);
        std::vector<int> answers = tr.findAll(pattern);
        std::sort(answers.begin(), answers.end());
        for (int index : answers) {
            std::cout << strings[index] << '\n';
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment