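# deepchopper / app.py
# Last commit 9e875d0 (avatar shown: yangliz5): "fix: Update text for detecting Adapter regions"
# File size: 8.73 kB
# (The "raw" / "history" / "blame" entries on the original page were navigation links.)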
import multiprocessing
from functools import partial
from pathlib import Path
import gradio as gr
import lightning
import torch
from datasets import Dataset
from torch.utils.data import DataLoader
import deepchopper
from deepchopper.deepchopper import default, encode_qual, remove_intervals_and_keep_left, smooth_label_region
from deepchopper.models.llm import (
tokenize_and_align_labels_and_quals,
)
from deepchopper.utils import (
summary_predict,
)
def parse_fq_record(text: str):
    """Parse FASTQ text into one dictionary per record.

    The text is consumed four lines at a time (identifier, sequence,
    separator, quality).

    Args:
        text: FASTQ-formatted text holding one or more 4-line records.

    Yields:
        A dict with the keys ``id`` (the identifier line, unchanged), ``seq``,
        ``qual`` (the quality string passed through ``encode_qual`` with
        ``default.KMER_SIZE``) and ``target`` (always ``[0, 0]``).

    Raises:
        ValueError: If the text is empty, its line count is not a multiple of
            four, or a record's sequence and quality strings differ in length.
    """
    # splitlines() also handles "\r\n" endings from pasted or uploaded text,
    # which split("\n") would leave attached to every field.
    lines = text.strip().splitlines()
    if not lines or len(lines) % 4 != 0:
        msg = f"FASTQ input must contain 4 lines per record, got {len(lines)} line(s)"
        raise ValueError(msg)

    for start in range(0, len(lines), 4):
        record_id, seq, _, qual = lines[start : start + 4]
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(seq) != len(qual):
            msg = (
                f"FASTQ record {record_id!r}: sequence length {len(seq)} "
                f"does not match quality length {len(qual)}"
            )
            raise ValueError(msg)
        yield {
            "id": record_id,
            "seq": seq,
            "qual": encode_qual(qual, default.KMER_SIZE),
            "target": [0, 0],
        }
def load_dataset(text: str, tokenizer):
    """Build the raw and the tokenized dataset from FASTQ text.

    Args:
        text: FASTQ text handed to ``parse_fq_record``.
        tokenizer: Tokenizer forwarded to ``tokenize_and_align_labels_and_quals``;
            its ``max_len_single_sentence`` is used as ``max_length``.

    Returns:
        A ``(dataset, tokenized_dataset)`` pair. ``dataset`` keeps the parsed
        columns in torch format; ``tokenized_dataset`` is the mapped dataset
        with the ``id``, ``seq``, ``qual`` and ``target`` columns removed.
    """
    raw_dataset = Dataset.from_generator(parse_fq_record, gen_kwargs={"text": text})
    raw_dataset = raw_dataset.with_format("torch")

    tokenize = partial(
        tokenize_and_align_labels_and_quals,
        tokenizer=tokenizer,
        max_length=tokenizer.max_len_single_sentence,
    )
    worker_count = multiprocessing.cpu_count()
    mapped = raw_dataset.map(tokenize, num_proc=worker_count)  # type: ignore
    tokenized = mapped.remove_columns(["id", "seq", "qual", "target"])
    return raw_dataset, tokenized
def predict(
    text: str,
    smooth_window_size: int = 21,
    min_interval_size: int = 13,
    approved_interval_number: int = 20,
    max_process_intervals: int = 8,  # default is 4
    batch_size: int = 1,
    num_workers: int = 1,
):
    """Run DeepChopper on FASTQ text and collect the predicted intervals.

    Loads the tokenizer and the pretrained ``yangliz5/deepchopper`` model,
    runs ``Trainer.predict`` over the tokenized input and smooths the
    predicted labels with ``smooth_label_region``.

    Args:
        text: FASTQ text to analyse.
        smooth_window_size: Forwarded to ``smooth_label_region``.
        min_interval_size: Forwarded to ``smooth_label_region``.
        approved_interval_number: Forwarded to ``smooth_label_region``.
        max_process_intervals: Records with more smoothed intervals than this
            are skipped.
        batch_size: Batch size of the prediction ``DataLoader``.
        num_workers: Worker count of the prediction ``DataLoader``; ``0`` is
            accepted.

    Returns:
        A ``(smooth_interval_json, highlighted_text)`` pair.
        ``smooth_interval_json`` is a list of ``{"start": ..., "end": ...}``
        dicts, one per smoothed interval. ``highlighted_text`` is a list of
        ``(subsequence, label)`` tuples where ``label`` is ``"ada"`` for a
        smoothed interval and ``None`` otherwise. Both lists stay empty for
        records that are skipped.

    Raises:
        AssertionError: If prediction does not produce exactly one batch.
    """
    tokenizer = deepchopper.models.llm.load_tokenizer_from_hyena_model(model_name="hyenadna-small-32k-seqlen")
    dataset, tokenized_dataset = load_dataset(text, tokenizer)
    dataloader = DataLoader(
        tokenized_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        # DataLoader raises if persistent_workers is requested without workers.
        persistent_workers=num_workers > 0,
    )
    model = deepchopper.DeepChopper.from_pretrained("yangliz5/deepchopper")
    accelerator = "gpu" if torch.cuda.is_available() else "cpu"
    trainer = lightning.pytorch.trainer.Trainer(
        accelerator=accelerator,
        deterministic=False,
        logger=False,
    )
    predicts = trainer.predict(model=model, dataloaders=dataloader, return_predictions=True)
    assert len(predicts) == 1  # noqa: S101

    smooth_interval_json: list[dict[str, int]] = []
    highlighted_text: list[tuple[str, str | None]] = []
    # NOTE(review): `idx` counts batches but is used as a record index, and only
    # the first prediction of each batch is read, so this appears to assume
    # batch_size == 1 - confirm before raising the batch size.
    for idx, preds in enumerate(predicts):
        true_prediction, _true_label = summary_predict(predictions=preds[0], labels=preds[1])
        seq = dataset[idx]["seq"]
        smooth_predict_targets = smooth_label_region(
            true_prediction[0], smooth_window_size, min_interval_size, approved_interval_number
        )
        # Skip records with no predicted interval or with too many of them.
        if not smooth_predict_targets or len(smooth_predict_targets) > max_process_intervals:
            continue
        # Presumably `selected_intervals` are the regions left after removing the
        # predicted ones - verify against remove_intervals_and_keep_left.
        _selected_seqs, selected_intervals = remove_intervals_and_keep_left(seq, smooth_predict_targets)
        # Merge both interval sets in sorted order so the pieces are emitted in order.
        total_intervals = sorted(selected_intervals + smooth_predict_targets)
        smooth_interval_json.extend({"start": i[0], "end": i[1]} for i in smooth_predict_targets)
        highlighted_text.extend(
            (seq[interval[0] : interval[1]], "ada" if interval in smooth_predict_targets else None)
            for interval in total_intervals
        )
    return smooth_interval_json, highlighted_text
def process_input(text: str | None, file: str | None):
    """Run prediction on an uploaded file or, when none is given, on pasted text.

    Args:
        text: FASTQ text from the text box; ignored when ``file`` is given.
        file: Path of an uploaded file. Only its first four lines are read.

    Returns:
        The ``predict`` result for the chosen input. When both inputs are
        empty, a Gradio warning is shown and ``([], [])`` is returned so that
        both output components are cleared.
    """
    if not text and not file:
        gr.Warning("Both text and file are empty")
        # Stop here: running the predictor on empty input would only fail.
        return [], []

    if file:
        # Read at most the first four lines of the upload (one FASTQ record).
        max_lines = 4
        with Path(file).open() as f:
            # readline() returns "" at EOF, so short files simply yield less text.
            text = "".join(f.readline() for _ in range(max_lines))

    return predict(text=text)
def create_gradio_app():
    """Assemble the DeepChopper Gradio Blocks UI and return it without launching."""
    # Bundled example input: one FASTQ record (identifier, sequence, separator, quality).
    example = (
        "@1065:1135|393d635c-64f0-41ed-8531-12174d8efb28+f6a60069-1fcf-4049-8e7c-37523b4e273f\n"
        "GCAGCTATGAATGCAAGGCCACAAGGTGGATGGAAGAGTTGTGGAACCAAAGAGCTGTCTTCCAGAGAAGATTTCGAGATAAGTCGCCCATCAGTGAACAAGATATTGTTGGTGGCATTTGATGAGAACGTTCCAAGATTATTGACAGATTAGTGAAAAGTAAGATTGAAATCATGACTGACCGTAAGTGGCAAGAAAGGGCTTTTGCCTTTGTAACCTTTGACGACCATGACTCCGTGGATAAGATTGTCATTCAGAATACCATACTGTGAATGGCCACATCTTTATTGTGAAGTTAGAAAAGCCCTGTCAAAGCAAGAGATGAATCAGTGCTTCTCCAGCCAAAGAGGTCGAAGTGGTTCTGGAAACTTTGGTGGTGGTCGTGGAGGTGGTTTCGGTGGGAATGACAACTCGGTCGTGGAGGAAACTTCAGTGGTCGTGGTGGCTTTGGTGGCAGCCGTGGTGGTGGTGGATATGGTGGCAGTGGGGATGGCTATAATGGATTTGGTAATGATGGAAGCAATTTGGAGGTGGTGGAAGCTACAATGATTTTGGGAATTACAACAATCAGTCTTCAAATTTTGGACCCCTAGGAGGAAATTTTGGTAGAAGCTCTGGCCCCATGGCGGTGGAGGCCAAATACTTTTGCAAACCACGAAACCAAGGTGGCTATGGCGGTCCAGCAGCAGCAGTAGCTATGGCAGTGGCAGAAGATTTTAATTAGGAAACAAAGCTTAGCAGGAGAGGAGAGCCAGAGAAGTGACAGGGAAGTACAGGTTACAACAGATTTGTGAACTCAGCCCAAGCACAGTGGTGGCAGGGCCTAGCTGCTACAAAGAAGACATGTTTTAGACAAATACTCATGTGTATGGGCAAAACTTGAGGACTGTATTTGTGACTAACTGTATAACAGGTTATTTTAGTTTCTGTTTGTGGAAAGTGTAAAGCATTCCAACAAAGGTTTTTAATGTAGATTTTTTTTTTTGCACCCCATGCTGTTGATTTGCTAAATGTAACAGTCTGATCGTGACGCTGAATAAATGTCTTTTTTAAAAAAAAAAAAAAGCTCCCTCCCATCCCCTGCTGCTAACTGATCCCATTATATCTAACCTGCCCCCCCATATCACCTGCTCCCGAGCTACCTAAGAACAGCTAAAAGAGCACACCCGCATGTAGCAAAATAGTGGGAAGATTATAGGTAGAGGCGACAAACCTACCGAGCCTGGTGATAGCTGGTTGTCCTAGATAGAATCTTAGTTCAACTTTAAATTTGCCCACAGAACCCTCTAAATCCCCTTGTAAATTTAACTGTTAGTCCAAAGAGGAACAGCTCTTTGGACACTAGGAAAAAACCTTGTAGAGAGTAAAAAATCAACACCCA\n"
        "+\n"
        ".0==?SSSSSSSSSSSH2216<868;SSSSSSSSSQQSRSIIHEDDESSSSSSJIKMGEKISSJJICCBDQ?;;8:;,**(&$'+501)\"#$()+%&&0<5+*/('%'))))'''$##\"\"\"\"%&--$\"\"\"('%)1L3*'')'#\"#&+*$&\"\"#*(&'''+,,<;9<BHGF//.LKORQSK<###%*-89<FSSSSE=BAFHFDB???3313NN?>=ANOSJDCADHGMOQSSD=7>BRRSPIEEEOQSSQ4->LIC7EE045///03IIJQSSSNGE6('.5??@A@=,,EGRSPKJ<==<556GFLLQRANSSSSSSSSG...*%%%(***(%'3@LOOSSSSM...7BCMMSSSSSSSSSSSSSSSDFIPSSSGGGGPOQLIHIL4103HMSILLNOSSSSSSSSSS22CBCGSHHHHSSSSSSSSD??@<<<:DDDSSSSSSSSSSA@6688OSSSSSROJJKLSNNNMSSSSQPOOSOOQSSSSSRRHIHISSRSSSSSSSSSSSJFF=??@SSQRK:424<444FFG///1S@@@ASNNNNPN:4JMDDLPSSSSSSBA?B?@@+'&'BD**8EDEFQPIMLE$$&',79CSJJPSGA+***DN;3-('&(;>6(()/-,,)%')1FRNNJ-:=>GC;&;CHNFFDCEEKJLFA22/27A.....HSQLHL))8<=?JSSSFGSKIHDDCCEFDAA@CFJKLNL>:9/1>>?OSLK@+HPSA;>>>K;;;;SSSSOQLPPMORSSSSSQSSSSSSS=:9**?D889SSRFFEDKJJJEEDKSSSNNOSSS.---,&*++SSSSQRSSSSQPGED<<89<@GJ999:SSKBBBAJHK=SSSJJKNMGHKKHQA<<>OPKFEAACDHJKMORB/)'((6**)15DA99;JSQSSS2())+J))EGMQOMMKJF>?<<AA620..D..,/112SOIIJSQFNEEEOMF?066=>@4,3;B>87FSSSSSSSSSSSSSSS<<::5658@AHMMSSRECC448/=<<>SSCB:5546;<??KF==;;FFEDFHKKJG):C>=>BJHINJFDPPPPPPPPPPPPPP%'*%$%+-%'(-22&&%('''&&&#\"\"%&'+0,,0;:1&\"\"%'(+++8'**(\"$$#&$'**//.3497$\"3CFHLOSSSSR:887:;;FSSRPRSSS4433$#$%&$$-056>@:;>=@?AHEFEC;*EKMSSRSRRDB>=AFRSSSSBSOOPSMDAABHH976951-9DHPQO/---?@ELSSQSRJHKKBKKLSSLINSOSSQSRIMSSSSSS>?MKIINSSGSSSSSSSQQMK544MJKKNKHGGLFFGBDB?EHIKGD?@DHPPIIF555)&(+,ADSSSSRQSSSQSS=9/0JJMSQSOSSO/97=B@=:>"
    )

    # Styling for the page header and footer blocks rendered below.
    custom_css = (
        "\n"
        ".header { text-align: center; margin-bottom: 30px; }\n"
        ".footer { text-align: center; margin-top: 30px; font-size: 0.8em; color: #666; }\n"
    )
    header_html = (
        "\n"
        '<div class="header">\n'
        "<h1>🧬 DeepChopper: DNA Sequence Analysis</h1>\n"
        "<p>Analyze DNA sequences and detect Adapter sequences</p>\n"
        "</div>\n"
    )
    footer_html = (
        "\n"
        '<div class="footer">\n'
        "<p>DeepChopper - Powered by AI for DNA sequence analysis</p>\n"
        "</div>\n"
    )

    # NOTE(review): the nesting below (event wiring, examples and footer placed
    # directly under Blocks, after the Row) was inferred - confirm the layout.
    with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
        gr.HTML(header_html)

        with gr.Row():
            # Left column: the two input widgets and the trigger button.
            with gr.Column(scale=1):
                sequence_box = gr.Textbox(
                    label="Input DNA Sequence",
                    placeholder="Paste your DNA sequence here...",
                    lines=10,
                )
                fastq_upload = gr.File(label="Or upload a FASTQ file")
                analyze_button = gr.Button("Analyze", variant="primary")
            # Right column: the two result views.
            with gr.Column(scale=1):
                regions_view = gr.JSON(label="Detected Adapter Regions")
                sequence_view = gr.HighlightedText(label="Highlighted Sequence")

        analyze_button.click(
            fn=process_input,
            inputs=[sequence_box, fastq_upload],
            outputs=[regions_view, sequence_view],
        )
        gr.Examples(examples=[[example]], inputs=[sequence_box])
        gr.HTML(footer_html)

    return demo
def main():
    """Build the Gradio app and start serving it."""
    create_gradio_app().launch()
# Launch the app only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    main()