Spaces:
Sleeping
Sleeping
Commit
·
4fa4501
1
Parent(s):
24d7d26
fix
Browse files
app.py
CHANGED
@@ -144,24 +144,20 @@ if selected_model == 'Cas9':
|
|
144 |
|
145 |
if 'exons' not in st.session_state:
|
146 |
st.session_state['exons'] = []
|
147 |
-
if 'cds' not in st.session_state:
|
148 |
-
st.session_state['cds'] = []
|
149 |
|
150 |
# Process predictions
|
151 |
if predict_button and gene_symbol:
|
152 |
with st.spinner('Predicting... Please wait'):
|
153 |
-
predictions, gene_sequence, exons
|
154 |
sorted_predictions = sorted(predictions, key=lambda x: x[-1], reverse=True)[:10]
|
155 |
st.session_state['on_target_results'] = sorted_predictions
|
156 |
st.session_state['gene_sequence'] = gene_sequence # Save gene sequence in session state
|
157 |
st.session_state['exons'] = exons # Store exon data
|
158 |
-
st.session_state['cds'] = cds # Store CDS data
|
159 |
|
160 |
# Notify the user once the process is completed successfully.
|
161 |
st.success('Prediction completed!')
|
162 |
st.session_state['prediction_made'] = True
|
163 |
|
164 |
-
|
165 |
if 'on_target_results' in st.session_state and st.session_state['on_target_results']:
|
166 |
ensembl_id = gene_annotations.get(gene_symbol, 'Unknown') # Get Ensembl ID or default to 'Unknown'
|
167 |
col1, col2, col3 = st.columns(3)
|
@@ -177,7 +173,7 @@ if selected_model == 'Cas9':
|
|
177 |
# Include "Target" in the DataFrame's columns
|
178 |
try:
|
179 |
df = pd.DataFrame(st.session_state['on_target_results'],
|
180 |
-
columns=["Chr", "Start Pos", "End Pos", "Strand", "Transcript", "Target", "gRNA", "Prediction"])
|
181 |
st.dataframe(df)
|
182 |
except ValueError as e:
|
183 |
st.error(f"DataFrame creation error: {e}")
|
@@ -189,7 +185,6 @@ if selected_model == 'Cas9':
|
|
189 |
|
190 |
EXON_BASE = 0 # Base position for exons and CDS on the Y axis
|
191 |
EXON_HEIGHT = 0.02 # How 'tall' the exon markers should appear
|
192 |
-
CDS_HEIGHT = 0.04 # How 'tall' the CDS markers should appear
|
193 |
|
194 |
# Plot Exons as small markers on the X-axis
|
195 |
for exon in st.session_state['exons']:
|
@@ -203,18 +198,6 @@ if selected_model == 'Cas9':
|
|
203 |
name='Exon'
|
204 |
))
|
205 |
|
206 |
-
# Plot CDS in a similar manner
|
207 |
-
for cds in st.session_state['cds']:
|
208 |
-
cds_start, cds_end = cds['start'], cds['end']
|
209 |
-
fig.add_trace(go.Bar(
|
210 |
-
x=[(cds_start + cds_end) / 2],
|
211 |
-
y=[CDS_HEIGHT],
|
212 |
-
width=[cds_end - cds_start],
|
213 |
-
base=[EXON_BASE],
|
214 |
-
marker_color='rgba(0, 0, 255, 1)',
|
215 |
-
name='CDS'
|
216 |
-
))
|
217 |
-
|
218 |
VERTICAL_GAP = 0.2 # Gap between different ranks
|
219 |
|
220 |
# Define max and min Y values based on strand and rank
|
@@ -254,38 +237,38 @@ if selected_model == 'Cas9':
|
|
254 |
# Display the plot
|
255 |
st.plotly_chart(fig)
|
256 |
|
257 |
-
if 'gene_sequence' in st.session_state and st.session_state['gene_sequence']:
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
|
290 |
elif target_selection == 'off-target':
|
291 |
ENTRY_METHODS = dict(
|
|
|
144 |
|
145 |
if 'exons' not in st.session_state:
|
146 |
st.session_state['exons'] = []
|
|
|
|
|
147 |
|
148 |
# Process predictions
|
149 |
if predict_button and gene_symbol:
|
150 |
with st.spinner('Predicting... Please wait'):
|
151 |
+
predictions, gene_sequence, exons = cas9on.process_gene(gene_symbol, cas9on_path)
|
152 |
sorted_predictions = sorted(predictions, key=lambda x: x[-1], reverse=True)[:10]
|
153 |
st.session_state['on_target_results'] = sorted_predictions
|
154 |
st.session_state['gene_sequence'] = gene_sequence # Save gene sequence in session state
|
155 |
st.session_state['exons'] = exons # Store exon data
|
|
|
156 |
|
157 |
# Notify the user once the process is completed successfully.
|
158 |
st.success('Prediction completed!')
|
159 |
st.session_state['prediction_made'] = True
|
160 |
|
|
|
161 |
if 'on_target_results' in st.session_state and st.session_state['on_target_results']:
|
162 |
ensembl_id = gene_annotations.get(gene_symbol, 'Unknown') # Get Ensembl ID or default to 'Unknown'
|
163 |
col1, col2, col3 = st.columns(3)
|
|
|
173 |
# Include "Target" in the DataFrame's columns
|
174 |
try:
|
175 |
df = pd.DataFrame(st.session_state['on_target_results'],
|
176 |
+
columns=["Chr", "Start Pos", "End Pos", "Strand", "Transcript", "Exon", "Target", "gRNA", "Prediction"])
|
177 |
st.dataframe(df)
|
178 |
except ValueError as e:
|
179 |
st.error(f"DataFrame creation error: {e}")
|
|
|
185 |
|
186 |
EXON_BASE = 0 # Base position for exons and CDS on the Y axis
|
187 |
EXON_HEIGHT = 0.02 # How 'tall' the exon markers should appear
|
|
|
188 |
|
189 |
# Plot Exons as small markers on the X-axis
|
190 |
for exon in st.session_state['exons']:
|
|
|
198 |
name='Exon'
|
199 |
))
|
200 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
VERTICAL_GAP = 0.2 # Gap between different ranks
|
202 |
|
203 |
# Define max and min Y values based on strand and rank
|
|
|
237 |
# Display the plot
|
238 |
st.plotly_chart(fig)
|
239 |
|
240 |
+
# if 'gene_sequence' in st.session_state and st.session_state['gene_sequence']:
|
241 |
+
# gene_symbol = st.session_state['current_gene_symbol']
|
242 |
+
# gene_sequence = st.session_state['gene_sequence']
|
243 |
+
#
|
244 |
+
# # Define file paths
|
245 |
+
# genbank_file_path = f"{gene_symbol}_crispr_targets.gb"
|
246 |
+
# bed_file_path = f"{gene_symbol}_crispr_targets.bed"
|
247 |
+
# csv_file_path = f"{gene_symbol}_crispr_predictions.csv"
|
248 |
+
#
|
249 |
+
# # Generate files
|
250 |
+
# cas9on.generate_genbank_file_from_df(df, gene_sequence, gene_symbol, genbank_file_path)
|
251 |
+
# cas9on.create_bed_file_from_df(df, bed_file_path)
|
252 |
+
# cas9on.create_csv_from_df(df, csv_file_path)
|
253 |
+
#
|
254 |
+
# # Prepare an in-memory buffer for the ZIP file
|
255 |
+
# zip_buffer = io.BytesIO()
|
256 |
+
# with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
|
257 |
+
# # For each file, add it to the ZIP file
|
258 |
+
# zip_file.write(genbank_file_path, arcname=genbank_file_path.split('/')[-1])
|
259 |
+
# zip_file.write(bed_file_path, arcname=bed_file_path.split('/')[-1])
|
260 |
+
# zip_file.write(csv_file_path, arcname=csv_file_path.split('/')[-1])
|
261 |
+
#
|
262 |
+
# # Important: move the cursor to the beginning of the BytesIO buffer before reading it
|
263 |
+
# zip_buffer.seek(0)
|
264 |
+
#
|
265 |
+
# # Display the download button for the ZIP file
|
266 |
+
# st.download_button(
|
267 |
+
# label="Download genbank,.bed,csv files as ZIP",
|
268 |
+
# data=zip_buffer.getvalue(),
|
269 |
+
# file_name=f"{gene_symbol}_files.zip",
|
270 |
+
# mime="application/zip"
|
271 |
+
# )
|
272 |
|
273 |
elif target_selection == 'off-target':
|
274 |
ENTRY_METHODS = dict(
|
cas9on.py
CHANGED
@@ -39,167 +39,159 @@ class DCModelOntar:
|
|
39 |
yp = self.model.predict(x)
|
40 |
return yp.ravel()
|
41 |
|
42 |
-
# Function to predict on-target efficiency and format output
|
43 |
-
def format_prediction_output(targets, model_path):
|
44 |
-
dcModel = DCModelOntar(model_path)
|
45 |
-
formatted_data = []
|
46 |
|
47 |
-
for target in targets:
|
48 |
-
# Encode the gRNA sequence
|
49 |
-
encoded_seq = get_seqcode(target[0]).reshape(-1,4,1,23)
|
50 |
-
|
51 |
-
# Predict on-target efficiency using the model
|
52 |
-
prediction = dcModel.ontar_predict(encoded_seq)
|
53 |
-
|
54 |
-
# Format output
|
55 |
-
sgRNA = target[1]
|
56 |
-
chr = target[2]
|
57 |
-
start = target[3]
|
58 |
-
end = target[4]
|
59 |
-
strand = target[5]
|
60 |
-
transcript_id = target[6]
|
61 |
-
formatted_data.append([chr, start, end, strand, transcript_id, target[0], sgRNA, prediction[0]])
|
62 |
-
|
63 |
-
return formatted_data
|
64 |
|
65 |
def fetch_ensembl_transcripts(gene_symbol):
|
66 |
-
|
67 |
-
|
68 |
-
response = requests.get(url, headers=headers)
|
69 |
if response.status_code == 200:
|
70 |
gene_data = response.json()
|
71 |
-
|
|
|
|
|
|
|
|
|
72 |
else:
|
73 |
print(f"Error fetching gene data from Ensembl: {response.text}")
|
74 |
return None
|
75 |
|
76 |
def fetch_ensembl_sequence(transcript_id):
|
77 |
-
|
78 |
-
|
79 |
-
response = requests.get(url, headers=headers)
|
80 |
if response.status_code == 200:
|
81 |
sequence_data = response.json()
|
82 |
-
|
|
|
|
|
|
|
|
|
83 |
else:
|
84 |
-
print(f"Error fetching sequence data from Ensembl
|
85 |
return None
|
86 |
|
87 |
-
def
|
88 |
-
headers = {"Content-Type": "application/json"}
|
89 |
-
url = f"https://rest.ensembl.org/overlap/id/{transcript_id}?feature=exon"
|
90 |
-
response = requests.get(url, headers=headers)
|
91 |
-
if response.status_code == 200:
|
92 |
-
return response.json()
|
93 |
-
else:
|
94 |
-
print(f"Error fetching exon data from Ensembl for transcript {transcript_id}: {response.text}")
|
95 |
-
return None
|
96 |
-
|
97 |
-
def fetch_ensembl_cds(transcript_id):
|
98 |
-
headers = {"Content-Type": "application/json"}
|
99 |
-
url = f"https://rest.ensembl.org/overlap/id/{transcript_id}?feature=cds"
|
100 |
-
response = requests.get(url, headers=headers)
|
101 |
-
if response.status_code == 200:
|
102 |
-
return response.json()
|
103 |
-
else:
|
104 |
-
print(f"Error fetching CDS data from Ensembl for transcript {transcript_id}: {response.text}")
|
105 |
-
return None
|
106 |
-
|
107 |
-
def find_crispr_targets(sequence, chr, start, strand, transcript_id, pam="NGG", target_length=20):
|
108 |
targets = []
|
109 |
len_sequence = len(sequence)
|
110 |
complement = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C'}
|
|
|
111 |
|
112 |
if strand == -1:
|
113 |
-
sequence = ''.join([complement[base] for base in
|
114 |
for i in range(len_sequence - len(pam) + 1):
|
115 |
if sequence[i + 1:i + 3] == pam[1:]:
|
116 |
if i >= target_length:
|
117 |
target_seq = sequence[i - target_length:i + 3]
|
118 |
tar_start = start + i - target_length
|
119 |
tar_end = start + i + 3
|
120 |
-
|
121 |
-
targets.append([target_seq,
|
122 |
|
123 |
return targets
|
124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
|
126 |
def process_gene(gene_symbol, model_path):
|
127 |
transcripts = fetch_ensembl_transcripts(gene_symbol)
|
128 |
-
|
129 |
-
|
130 |
if transcripts:
|
131 |
-
|
132 |
-
|
133 |
-
transcript_id =
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
|
|
|
|
|
|
|
|
|
39 |
yp = self.model.predict(x)
|
40 |
return yp.ravel()
|
41 |
|
|
|
|
|
|
|
|
|
42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
def fetch_ensembl_transcripts(gene_symbol):
|
45 |
+
url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{gene_symbol}?expand=1;content-type=application/json"
|
46 |
+
response = requests.get(url)
|
|
|
47 |
if response.status_code == 200:
|
48 |
gene_data = response.json()
|
49 |
+
if 'Transcript' in gene_data:
|
50 |
+
return gene_data['Transcript']
|
51 |
+
else:
|
52 |
+
print("No transcripts found for gene:", gene_symbol)
|
53 |
+
return None
|
54 |
else:
|
55 |
print(f"Error fetching gene data from Ensembl: {response.text}")
|
56 |
return None
|
57 |
|
58 |
def fetch_ensembl_sequence(transcript_id):
|
59 |
+
url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"
|
60 |
+
response = requests.get(url)
|
|
|
61 |
if response.status_code == 200:
|
62 |
sequence_data = response.json()
|
63 |
+
if 'seq' in sequence_data:
|
64 |
+
return sequence_data['seq']
|
65 |
+
else:
|
66 |
+
print("No sequence found for transcript:", transcript_id)
|
67 |
+
return None
|
68 |
else:
|
69 |
+
print(f"Error fetching sequence data from Ensembl: {response.text}")
|
70 |
return None
|
71 |
|
72 |
+
def find_crispr_targets(sequence, chr, start, strand, transcript_id, exon_id, pam="NGG", target_length=20):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
targets = []
|
74 |
len_sequence = len(sequence)
|
75 |
complement = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C'}
|
76 |
+
dnatorna = {'A': 'A', 'T': 'U', 'C': 'C', 'G': 'G'}
|
77 |
|
78 |
if strand == -1:
|
79 |
+
sequence = ''.join([complement[base] for base in sequence])
|
80 |
for i in range(len_sequence - len(pam) + 1):
|
81 |
if sequence[i + 1:i + 3] == pam[1:]:
|
82 |
if i >= target_length:
|
83 |
target_seq = sequence[i - target_length:i + 3]
|
84 |
tar_start = start + i - target_length
|
85 |
tar_end = start + i + 3
|
86 |
+
gRNA = ''.join([dnatorna[base] for base in sequence[i - target_length:i]])
|
87 |
+
targets.append([target_seq, gRNA, chr, str(tar_start), str(tar_end), str(strand), transcript_id, exon_id])
|
88 |
|
89 |
return targets
|
90 |
|
91 |
+
# Function to predict on-target efficiency and format output
|
92 |
+
def format_prediction_output(targets, model_path):
|
93 |
+
dcModel = DCModelOntar(model_path)
|
94 |
+
formatted_data = []
|
95 |
+
|
96 |
+
for target in targets:
|
97 |
+
# Encode the gRNA sequence
|
98 |
+
encoded_seq = get_seqcode(target[0]).reshape(-1,4,1,23)
|
99 |
+
|
100 |
+
# Predict on-target efficiency using the model
|
101 |
+
prediction = dcModel.ontar_predict(encoded_seq)
|
102 |
+
|
103 |
+
# Format output
|
104 |
+
gRNA = target[1]
|
105 |
+
chr = target[2]
|
106 |
+
start = target[3]
|
107 |
+
end = target[4]
|
108 |
+
strand = target[5]
|
109 |
+
transcript_id = target[6]
|
110 |
+
exon_id = target[7]
|
111 |
+
formatted_data.append([chr, start, end, strand, transcript_id, exon_id, target[0], gRNA, prediction[0]])
|
112 |
+
|
113 |
+
return formatted_data
|
114 |
|
115 |
def process_gene(gene_symbol, model_path):
|
116 |
transcripts = fetch_ensembl_transcripts(gene_symbol)
|
117 |
+
results = []
|
|
|
118 |
if transcripts:
|
119 |
+
for i in range(len(transcripts)):
|
120 |
+
Exons = transcripts[i]['Exon']
|
121 |
+
transcript_id = transcripts[i]['id']
|
122 |
+
for j in range(len(Exons)):
|
123 |
+
exon_id = Exons[j]['id']
|
124 |
+
gene_sequence = fetch_ensembl_sequence(exon_id)
|
125 |
+
if gene_sequence:
|
126 |
+
start = Exons[j]['start']
|
127 |
+
strand = Exons[j]['strand']
|
128 |
+
chr = Exons[j]['seq_region_name']
|
129 |
+
targets = find_crispr_targets(gene_sequence, chr, start, strand, transcript_id, exon_id)
|
130 |
+
if not targets:
|
131 |
+
print("No gRNA sites found in the gene sequence.")
|
132 |
+
else:
|
133 |
+
# Predict on-target efficiency for each gRNA site
|
134 |
+
formatted_data = format_prediction_output(targets,model_path)
|
135 |
+
results.append(formatted_data)
|
136 |
+
# for data in formatted_data:
|
137 |
+
# print(f"Chr: {data[0]}, Start: {data[1]}, End: {data[2]}, Strand: {data[3]}, gRNA: {data[4]}, pred_Score: {data[5]}")
|
138 |
+
else:
|
139 |
+
print("Failed to retrieve gene sequence.")
|
140 |
+
else:
|
141 |
+
print("Failed to retrieve transcripts.")
|
142 |
+
return results, gene_sequence, Exons
|
143 |
+
|
144 |
+
|
145 |
+
# def create_genbank_features(formatted_data):
|
146 |
+
# features = []
|
147 |
+
# for data in formatted_data:
|
148 |
+
# # Strand conversion to Biopython's convention
|
149 |
+
# strand = 1 if data[3] == '+' else -1
|
150 |
+
# location = FeatureLocation(start=int(data[1]), end=int(data[2]), strand=strand)
|
151 |
+
# feature = SeqFeature(location=location, type="misc_feature", qualifiers={
|
152 |
+
# 'label': data[5], # Use gRNA as the label
|
153 |
+
# 'target': data[4], # Include the target sequence
|
154 |
+
# 'note': f"Prediction: {data[6]}" # Include the prediction score
|
155 |
+
# })
|
156 |
+
# features.append(feature)
|
157 |
+
# return features
|
158 |
+
#
|
159 |
+
# def generate_genbank_file_from_df(df, gene_sequence, gene_symbol, output_path):
|
160 |
+
# features = []
|
161 |
+
# for index, row in df.iterrows():
|
162 |
+
# # Use 'Transcript ID' if it exists, otherwise use a default value like 'Unknown'
|
163 |
+
# transcript_id = row.get("Transcript ID", "Unknown")
|
164 |
+
#
|
165 |
+
# # Make sure to use the correct column names for Start Pos, End Pos, and Strand
|
166 |
+
# location = FeatureLocation(start=int(row["Start Pos"]),
|
167 |
+
# end=int(row["End Pos"]),
|
168 |
+
# strand=1 if row["Strand"] == '+' else -1)
|
169 |
+
# feature = SeqFeature(location=location, type="gene", qualifiers={
|
170 |
+
# 'locus_tag': transcript_id, # Now using the variable that holds the safe value
|
171 |
+
# 'note': f"gRNA: {row['gRNA']}, Prediction: {row['Prediction']}"
|
172 |
+
# })
|
173 |
+
# features.append(feature)
|
174 |
+
#
|
175 |
+
# # The rest of the function remains unchanged
|
176 |
+
# record = SeqRecord(Seq(gene_sequence), id=gene_symbol, name=gene_symbol,
|
177 |
+
# description=f'CRISPR Cas9 predicted targets for {gene_symbol}', features=features)
|
178 |
+
# record.annotations["molecule_type"] = "DNA"
|
179 |
+
# SeqIO.write(record, output_path, "genbank")
|
180 |
+
#
|
181 |
+
#
|
182 |
+
# def create_bed_file_from_df(df, output_path):
|
183 |
+
# with open(output_path, 'w') as bed_file:
|
184 |
+
# for index, row in df.iterrows():
|
185 |
+
# # Adjust field names based on your actual formatted data
|
186 |
+
# chrom = row["Chr"]
|
187 |
+
# start = int(row["Start Pos"])
|
188 |
+
# end = int(row["End Pos"])
|
189 |
+
# strand = '+' if row["Strand"] == '+' else '-' # Ensure strand is correctly interpreted
|
190 |
+
# gRNA = row["gRNA"]
|
191 |
+
# score = str(row["Prediction"]) # Ensure score is converted to string if not already
|
192 |
+
# transcript_id = row["Transcript"] # Extract transcript ID
|
193 |
+
# bed_file.write(f"{chrom}\t{start}\t{end}\t{gRNA}\t{score}\t{strand}\t{transcript_id}\n") # Include transcript ID in BED output
|
194 |
+
#
|
195 |
+
#
|
196 |
+
# def create_csv_from_df(df, output_path):
|
197 |
+
# df.to_csv(output_path, index=False)
|