Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Extract Entities from NER labels
def split_tag(tag):
    """Split an NER tag into its boundary prefix and entity label.

    The tag is split on the first "-" only, so a label that itself
    contains dashes is kept intact ("I-GPE-X" -> ("I", "GPE-X")).

    Args:
        tag: A tag string such as "B-LOC", "I-PER" or "O".

    Returns:
        A 2-tuple ``(boundary, label)``. ``label`` is ``None`` when the tag
        contains no "-" (for example the outside tag "O"), so the result
        can always be unpacked into two names.
    """
    # partition() always yields three parts, which guarantees a fixed-arity
    # result even for tags without a dash (split() would give a 1-element
    # list for those and break two-value unpacking in callers).
    boundary, dash, label = tag.partition("-")
    return (boundary, label) if dash else (boundary, None)
    
def extract_entities(tags):
    """Extract entity spans from a sequence of B/I/O style NER tags.

    An entity starts at a "B-<label>" tag and extends over the directly
    following "I-<label>" tags carrying the same label. Any other tag ("O",
    a new "B-", an "I-" with a different label, or an unrecognised prefix)
    closes the open entity. "I-" tags that do not continue an open entity
    are ignored.

    Args:
        tags: Iterable of tag strings, e.g. ["B-LOC", "I-LOC", "O"].

    Returns:
        A list of ``(label, start, end)`` tuples in order of appearance,
        where ``start`` is the index of the "B-" tag and ``end`` is the
        exclusive end index of the entity.

    Example:
        >>> extract_entities(["B-LOC", "I-LOC", "O", "B-LOC"])
        [('LOC', 0, 2), ('LOC', 3, 4)]
    """
    entities = []
    start = None  # index of the "B-" tag of the open entity, if any
    open_label = None
    # Materialise the input so any iterable is accepted, and append a dummy
    # "O" tag so an entity that runs to the end of the input is flushed.
    for i, tag in enumerate(list(tags) + ["O"]):
        boundary, label = split_tag(tag)
        if start is not None:
            if boundary == "I" and label == open_label:
                # Still inside the open entity.
                continue
            # Anything else ends the entity; the span is half-open.
            entities.append((open_label, start, i))
            start = None
        if boundary == "B":
            # Enter a new entity, remembering where it began.
            start = i
            open_label = label
    return entities
In [29]: extract_entities(["I-LOC", "I-LOC", "O", "B-LOC"])                                                                                                            
Out[29]: [('LOC', 3, 4)]

In [30]: extract_entities(["B-LOC", "I-LOC", "O", "B-LOC"])                                                                                                            
Out[30]: [('LOC', 0, 2), ('LOC', 3, 4)]

In [31]: extract_entities(["B-LOC", "B-LOC", "O", "B-LOC"])                                                                                                            
Out[31]: [('LOC', 0, 1), ('LOC', 1, 2), ('LOC', 3, 4)]

In [32]: extract_entities(["B-LOC", "B-PER", "O", "B-LOC"])                                                                                                            
Out[32]: [('LOC', 0, 1), ('PER', 1, 2), ('LOC', 3, 4)]

In [34]: extract_entities(["B-LOC"])                                                                                                                                   
Out[34]: [('LOC', 0, 1)]

In [35]: extract_entities(["I-LOC"])                                                                                                                                   
Out[35]: []

In [36]: extract_entities(["O"])                                                                                                                                       
Out[36]: []

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment