# Hugging Face Space (hardware: CPU Upgrade).
"""
Streamlit application that masks personally identifiable information (PII)
in Hebrew text using the GolemPII-v1 token-classification model.
"""
import html
import threading
import time
from typing import Dict, List, Optional, Tuple

import streamlit as st
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer
# Hugging Face Hub identifier of the token-classification model used below.
MODEL_NAME = "CordwainerSmith/GolemPII-v1"

# Background colour used to highlight each entity type in the rendered masked
# text. Entity types missing here are rendered with a neutral grey fallback.
ENTITY_COLORS = dict(
    PHONE_NUM="#FF9999",
    ID_NUM="#99FF99",
    CC_NUM="#9999FF",
    BANK_ACCOUNT_NUM="#FFFF99",
    FIRST_NAME="#FF99FF",
    LAST_NAME="#99FFFF",
    CITY="#FFB366",
    STREET="#B366FF",
    POSTAL_CODE="#66FFB3",
    EMAIL="#66B3FF",
    DATE="#FFB3B3",
    CC_PROVIDER="#B3FFB3",
)
# Example Hebrew text, pre-filled in the input box, containing several kinds
# of PII (name, ID number, date, address, postal code, phone, credit card).
# It deliberately contains no commas: the app splits the input on commas.
# NOTE(review): "[email protected]" looks like a web-scraping placeholder that
# replaced the original example address; substitute a synthetic address so the
# EMAIL entity is demonstrated.
EXAMPLE_SENTENCES = [
    (
        "שם מלא: תלמה אריאלי מספר תעודת זהות: 61453324-8 "
        "תאריך לידה: 15/09/1983 כתובת: ארלוזורוב 22 פתח תקווה מיקוד 2731711 "
        "אימייל: [email protected] טלפון: 054-8884771 "
        "בפגישה זו נדונו פתרונות טכנולוגיים חדשניים לשיפור תהליכי עבודה. "
        "המשתתף יתבקש להציג מצגת בנושא בפגישה הבאה "
        "אשר שילם ב 5326-1003-5299-5478 מסטרקארד עם הוראת קבע ל 11-77-352300"
    ),
]
# Model details for display in the sidebar. "description" is HTML that the
# sidebar renders with unsafe_allow_html=True, so it must hold trusted markup
# only, and it must contain no blank line (a blank line would end the raw HTML
# block in Streamlit's markdown renderer).
MODEL_DETAILS = {
    "name": "GolemPII-v1: Hebrew PII Detection Model",
    "description": """
The <a href="https://huggingface.co/CordwainerSmith/GolemPII-v1" target="_blank">GolemPII model</a>
was specifically designed to identify and categorize various types of personally
identifiable information (PII) present in Hebrew text. Its core intended usage
revolves around enhancing privacy protection and facilitating the process of data
anonymization. This makes it a good candidate for applications and systems that
handle sensitive data, such as legal documents, medical records, or any text data
containing PII, where the automatic redaction or removal of such information is
essential for ensuring compliance with data privacy regulations and safeguarding
individuals' personal information. The model can be deployed on-premise with a
relatively small hardware footprint, making it suitable for organizations with
limited computing resources or those prioritizing local data processing.
The model was trained on the <a href="https://huggingface.co/datasets/CordwainerSmith/GolemGuard"
target="_blank">GolemGuard</a> dataset, a Hebrew language dataset comprising over
115,000 examples of PII entities and containing both real and synthetically
generated text examples. This data represents various document types and
communication formats commonly found in Israeli professional and administrative
contexts. GolemGuard covers a wide range of document types and encompasses a
diverse array of PII entities, making it ideal for training and evaluating PII
detection models.
""",
    "base_model": "xlm-roberta-base",
    "training_data": "Custom Hebrew PII dataset",
    # Order here is the order of the colour legend in the sidebar.
    "detected_pii_entities": [
        "FIRST_NAME",
        "LAST_NAME",
        "STREET",
        "CITY",
        "PHONE_NUM",
        "EMAIL",
        "ID_NUM",
        "BANK_ACCOUNT_NUM",
        "CC_NUM",
        "CC_PROVIDER",
        "DATE",
        "POSTAL_CODE",
    ],
}
class PIIMaskingModel:
    """
    Mask personally identifiable information (PII) in Hebrew text with GolemPII-v1.

    The tokenizer and token-classification model are loaded once in
    ``__init__``. ``process_text`` classifies every token of one input text and
    turns the predicted BIO labels into a masked plain-text string, an
    HTML-highlighted string and a list of detected entity spans.
    """

    # Prefixes that mark a token as the first piece of a new whitespace-separated
    # word. SentencePiece tokenizers such as XLM-RoBERTa's (GolemPII's base
    # model) prefix word-initial pieces with U+2581 "▁"; a literal space is kept
    # for tokenizers that expose the space directly.
    _WORD_START_MARKERS = ("\u2581", " ")

    # Hebrew placeholder substituted for each entity type in the masked text.
    _ENTITY_MASKS = {
        "PHONE_NUM": "[טלפון]",  # phone
        "ID_NUM": "[ת.ז]",  # national ID number
        "CC_NUM": "[כרטיס אשראי]",  # credit card
        "BANK_ACCOUNT_NUM": "[חשבון בנק]",  # bank account
        "FIRST_NAME": "[שם פרטי]",  # first name
        "LAST_NAME": "[שם משפחה]",  # last name
        "CITY": "[עיר]",  # city
        "STREET": "[רחוב]",  # street
        "POSTAL_CODE": "[מיקוד]",  # postal code
        "EMAIL": "[אימייל]",  # e-mail
        "DATE": "[תאריך]",  # date
        "CC_PROVIDER": "[ספק כרטיס אשראי]",  # credit card provider
        "BANK": "[בנק]",  # bank
    }

    def __init__(self, model_name: str):
        """
        Load the tokenizer and model, and move the model to the best device.

        Args:
            model_name: Hugging Face Hub identifier of the pre-trained model.

        The Hub access token is read from Streamlit secrets under ``hf_token``.
        """
        self.model_name = model_name
        hf_token = st.secrets["hf_token"]
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
        self.model = AutoModelForTokenClassification.from_pretrained(
            model_name, token=hf_token
        )
        # Prefer a GPU when one is available; otherwise run on the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        # Inference only: switch off training-time behaviour such as dropout.
        self.model.eval()

    def process_text(
        self, text: str
    ) -> Tuple[str, float, str, List[str], List[str], List[Dict]]:
        """
        Detect and mask PII in *text*.

        Inputs longer than the model's maximum sequence length are truncated by
        the tokenizer. The untokenized tail is copied to the output unscanned,
        so a Streamlit warning is shown when that happens.

        Args:
            text: The input text to process.

        Returns:
            A tuple ``(masked_text, processing_time, colored_text, tokens,
            predicted_labels, privacy_masks)``:
            - masked_text: *text* with every detected entity replaced by its
              placeholder.
            - processing_time: wall-clock seconds spent in this call.
            - colored_text: HTML-escaped *text* with each placeholder wrapped in
              a coloured ``<span>``.
            - tokens: the model's tokens, including the special tokens.
            - predicted_labels: the predicted label for each token.
            - privacy_masks: one dict per entity with ``label``, ``start``,
              ``end``, ``value`` and ``label_index`` keys.
        """
        start_time = time.perf_counter()
        tokenized_inputs = self.tokenizer(
            text,
            truncation=True,
            padding=False,
            return_tensors="pt",
            return_offsets_mapping=True,
            add_special_tokens=True,
        )
        input_ids = tokenized_inputs.input_ids.to(self.device)
        attention_mask = tokenized_inputs.attention_mask.to(self.device)
        offset_mapping: List[Optional[List[int]]] = tokenized_inputs[
            "offset_mapping"
        ][0].tolist()
        # First and last positions are the special tokens (<s> ... </s>) added by
        # add_special_tokens=True; they map to no text, so mark them as None.
        offset_mapping[0] = None
        offset_mapping[-1] = None
        self._warn_if_truncated(text, offset_mapping)

        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # Highest-scoring label id for each token of the single input sequence.
        predictions = outputs.logits.argmax(dim=-1)[0].tolist()
        id2label = self.model.config.id2label
        predicted_labels = [id2label[label_id] for label_id in predictions]
        tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

        masked_text, colored_text, privacy_masks = self.mask_pii_in_sentence(
            tokens, predicted_labels, text, offset_mapping
        )
        processing_time = time.perf_counter() - start_time
        return (
            masked_text,
            processing_time,
            colored_text,
            tokens,
            predicted_labels,
            privacy_masks,
        )

    @staticmethod
    def _warn_if_truncated(
        text: str, offset_mapping: List[Optional[List[int]]]
    ) -> None:
        """
        Show a Streamlit warning when the tokenizer did not cover all of *text*.

        Characters after the last tokenized offset never reach the model.
        ``mask_pii_in_sentence`` copies them through verbatim, so any PII there
        would be displayed unmasked.
        """
        covered_until = max(
            (span[1] for span in offset_mapping if span is not None), default=0
        )
        # Only letters/digits can carry PII; ignore trailing spaces, marks, etc.
        if any(ch.isalnum() for ch in text[covered_until:]):
            st.warning(
                "The input exceeds the model's maximum length: text after "
                f"character {covered_until} was NOT scanned for PII and is "
                "shown unmasked."
            )

    def _find_entity_span(
        self,
        i: int,
        labels: List[str],
        tokens: List[str],
        offset_mapping: List[Optional[List[int]]],
    ) -> Tuple[int, str, int]:
        """
        Find the extent of the entity that starts at token *i*.

        The span grows over following tokens that are ``I-`` of the same type,
        or ``B-`` tokens that continue the current word (model noise on
        sub-word pieces). It stops at an ``I-`` of another type, a ``B-`` that
        starts a new word, or any non-entity label. Tokens without offsets
        (special tokens) are skipped.

        Args:
            i: Index of the entity's first token (labelled ``B-X`` or ``I-X``).
            labels: The label for each token.
            tokens: The tokens, used to detect word starts.
            offset_mapping: ``[start, end)`` character span per token, or None.

        Returns:
            A tuple ``(next_index, entity_type, end_char)``. ``next_index`` is
            the first token index after the entity, ``entity_type`` is the
            label without its ``B-``/``I-`` prefix, and ``end_char`` is the end
            offset of the entity's last token.
        """
        current_entity = labels[i][2:]
        j = i + 1
        last_valid_end = offset_mapping[i][1] if offset_mapping[i] else None
        while j < len(tokens):
            if offset_mapping[j] is None:
                j += 1
                continue
            next_label = labels[j]
            if next_label.startswith("I-"):
                if next_label[2:] != current_entity:
                    break
            elif next_label.startswith("B-"):
                # A B- tag on a word-initial piece begins a new entity.
                if tokens[j].startswith(self._WORD_START_MARKERS):
                    break
            else:
                break
            last_valid_end = offset_mapping[j][1]
            j += 1
        return j, current_entity, last_valid_end

    def mask_pii_in_sentence(
        self,
        tokens: List[str],
        labels: List[str],
        original_text: str,
        offset_mapping: List[Optional[List[int]]],
    ) -> Tuple[str, str, List[Dict]]:
        """
        Replace every detected entity in *original_text* with its placeholder.

        Args:
            tokens: The tokens of the text.
            labels: The predicted label for each token.
            original_text: The text the tokens were produced from.
            offset_mapping: ``[start, end)`` character span per token, or None
                for tokens that map to no text (special tokens).

        Returns:
            A tuple containing:
            - The masked plain text.
            - The HTML rendering. Non-entity text is HTML-escaped, and each
              placeholder is wrapped in a coloured ``<span>``.
            - One dict per entity with ``label``, ``start``, ``end``, ``value``
              and 1-based ``label_index``.
        """
        privacy_masks: List[Dict] = []
        masked_text_parts: List[str] = []
        colored_text_parts: List[str] = []
        # Index into original_text up to which output has been emitted.
        current_pos = 0

        i = 0
        while i < len(tokens):
            span = offset_mapping[i]
            if span is None:
                i += 1
                continue

            if not labels[i].startswith(("B-", "I-")):
                # Non-entity token: copy text up to the end of this token.
                end_char = span[1]
                if current_pos < end_char:
                    text_chunk = original_text[current_pos:end_char]
                    masked_text_parts.append(text_chunk)
                    # Escape: the HTML output is rendered as raw HTML.
                    colored_text_parts.append(html.escape(text_chunk))
                    current_pos = end_char
                i += 1
                continue

            start_char = span[0]
            next_pos, entity_type, last_valid_end = self._find_entity_span(
                i, labels, tokens, offset_mapping
            )
            if current_pos < start_char:
                # Copy the unmasked gap (e.g. whitespace) before the entity.
                text_before = original_text[current_pos:start_char]
                masked_text_parts.append(text_before)
                colored_text_parts.append(html.escape(text_before))

            mask = self._get_mask_for_entity(entity_type)
            privacy_masks.append(
                {
                    "label": entity_type,
                    "start": start_char,
                    "end": last_valid_end,
                    "value": original_text[start_char:last_valid_end],
                    "label_index": len(privacy_masks) + 1,
                }
            )
            masked_text_parts.append(mask)
            color = ENTITY_COLORS.get(entity_type, "#CCCCCC")
            colored_text_parts.append(
                f'<span style="background-color: {color}; color: black; padding: 2px; border-radius: 3px;">{html.escape(mask)}</span>'
            )
            current_pos = last_valid_end
            i = next_pos

        if current_pos < len(original_text):
            remaining_text = original_text[current_pos:]
            masked_text_parts.append(remaining_text)
            colored_text_parts.append(html.escape(remaining_text))

        return ("".join(masked_text_parts), "".join(colored_text_parts), privacy_masks)

    def _get_mask_for_entity(self, entity_type: str) -> str:
        """
        Return the placeholder that replaces an entity of *entity_type*.

        Args:
            entity_type: The entity type (label without its BIO prefix).

        Returns:
            The Hebrew placeholder for known types, otherwise ``"[<type>]"``.
        """
        return self._ENTITY_MASKS.get(entity_type, f"[{entity_type}]")
@st.cache_resource(show_spinner="Loading the GolemPII model...")
def _load_model(model_name: str) -> Tuple["PIIMaskingModel", "threading.Lock"]:
    """
    Load the PII model once and reuse it across reruns and sessions.

    Caching avoids re-loading the tokenizer and model on every click of
    "Process Text". Because the cached objects are then shared by every session
    (Streamlit runs each session's script in its own thread), a lock is returned
    alongside the model to serialise inference. Hugging Face fast tokenizers can
    fail under concurrent calls.
    """
    return PIIMaskingModel(model_name), threading.Lock()


def main():
    """
    Build the Streamlit page: styles, model sidebar, input form and results.

    Each comma-separated piece of the input is masked separately. For each
    piece the page shows the original text, the highlighted masked text, the
    processing time and, optionally, the raw JSON result.
    """
    st.set_page_config(layout="wide")
    st.title("🗿 GolemPII: Hebrew PII Masking Application 🗿")
    st.markdown(
        """
        <style>
        .rtl { direction: rtl; text-align: right; }
        .entity-legend { padding: 5px; margin: 2px; border-radius: 3px; display: inline-block; }
        .masked-text {
            direction: rtl;
            text-align: right;
            line-height: 2;
            padding: 10px;
            background-color: #f6f8fa;
            border-radius: 5px;
            color: black;
            white-space: pre-wrap;
        }
        .main h3 {
            margin-bottom: 10px;
        }
        textarea {
            direction: rtl !important;
            text-align: right !important;
        }
        .stTextArea label {
            direction: ltr !important;
            text-align: left !important;
        }
        </style>
        """,
        unsafe_allow_html=True,
    )

    # Sidebar with model details and a colour legend of the supported entities.
    # Built without leading indentation so the markdown renderer treats it as
    # one raw HTML block rather than an indented code block.
    legend_items = " ".join(
        f'<li><span style="background-color: {ENTITY_COLORS.get(entity, "#CCCCCC")}; '
        'color: black; padding: 3px 5px; border-radius: 3px; margin-right: 5px;">'
        f"{entity}</span></li>"
        for entity in MODEL_DETAILS["detected_pii_entities"]
    )
    sidebar_html = (
        "<div>"
        f"<h2>{MODEL_DETAILS['name']}</h2>"
        f"<p>{MODEL_DETAILS['description']}</p>"
        "<h3>Supported PII Entities</h3>"
        f"<ul>{legend_items}</ul>"
        "</div>"
    )
    st.sidebar.markdown(sidebar_html, unsafe_allow_html=True)

    text_input = st.text_area(
        "Enter text to mask (separate multiple texts with commas):",
        value="\n".join(EXAMPLE_SENTENCES),
        height=200,
    )
    show_json = st.checkbox("Show JSON Output", value=True)

    if st.button("Process Text"):
        # NOTE(review): splitting on "," also splits a single text containing
        # commas (common in addresses); kept to match the input label above.
        texts = [text.strip() for text in text_input.split(",") if text.strip()]
        model, model_lock = _load_model(MODEL_NAME)
        for text in texts:
            st.markdown(
                '<h3 style="text-align: center;">Original Text</h3>',
                unsafe_allow_html=True,
            )
            # User input is rendered with unsafe_allow_html, so escape it.
            st.markdown(
                f'<div class="rtl">{html.escape(text)}</div>', unsafe_allow_html=True
            )
            with model_lock:
                (
                    masked_text,
                    processing_time,
                    colored_text,
                    tokens,
                    predicted_labels,
                    privacy_masks,
                ) = model.process_text(text)
            st.markdown(
                '<h3 style="text-align: center;">Masked Text</h3>',
                unsafe_allow_html=True,
            )
            st.markdown(
                f'<div class="masked-text">{colored_text}</div>', unsafe_allow_html=True
            )
            st.markdown(f"Processing Time: {processing_time:.3f} seconds")
            if show_json:
                st.json(
                    {
                        "original": text,
                        "masked": masked_text,
                        "processing_time": processing_time,
                        "tokens": tokens,
                        "token_classes": predicted_labels,
                        "privacy_mask": privacy_masks,
                        "span_labels": [
                            [m["start"], m["end"], m["label"]] for m in privacy_masks
                        ],
                    }
                )
            st.markdown("---")


if __name__ == "__main__":
    main()