mhdzumair commited on
Commit
9b14766
·
1 Parent(s): d87305b

Add support for response proxy headers & flag to use request proxy via use_request_proxy

Browse files
mediaflow_proxy/handlers.py CHANGED
@@ -16,6 +16,7 @@ from .utils.http_utils import (
16
  download_file_with_retry,
17
  request_with_retry,
18
  EnhancedStreamingResponse,
 
19
  )
20
  from .utils.m3u8_processor import M3U8Processor
21
  from .utils.mpd_utils import pad_base64
@@ -24,7 +25,12 @@ logger = logging.getLogger(__name__)
24
 
25
 
26
  async def handle_hls_stream_proxy(
27
- request: Request, destination: str, headers: dict, key_url: HttpUrl = None, verify_ssl: bool = True
 
 
 
 
 
28
  ):
29
  """
30
  Handles the HLS stream proxy request, fetching and processing the m3u8 playlist or streaming the content.
@@ -32,9 +38,10 @@ async def handle_hls_stream_proxy(
32
  Args:
33
  request (Request): The incoming HTTP request.
34
  destination (str): The destination URL to fetch the content from.
35
- headers (dict): The headers to include in the request.
36
  key_url (str, optional): The HLS Key URL to replace the original key URL. Defaults to None.
37
  verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.
 
38
 
39
  Returns:
40
  Response: The HTTP response with the processed m3u8 playlist or streamed content.
@@ -43,19 +50,19 @@ async def handle_hls_stream_proxy(
43
  follow_redirects=True,
44
  timeout=httpx.Timeout(30.0),
45
  limits=httpx.Limits(max_keepalive_connections=10, max_connections=20),
46
- proxy=settings.proxy_url,
47
  verify=verify_ssl,
48
  )
49
  streamer = Streamer(client)
50
  try:
51
  if destination.endswith((".m3u", ".m3u8")):
52
- return await fetch_and_process_m3u8(streamer, destination, headers, request, key_url)
53
 
54
- response = await streamer.head(destination, headers)
55
  if "mpegurl" in response.headers.get("content-type", "").lower():
56
- return await fetch_and_process_m3u8(streamer, destination, headers, request, key_url)
57
 
58
- headers.update({"range": headers.get("range", "bytes=0-")})
59
  # clean up the headers to only include the necessary headers and remove acl headers
60
  response_headers = {k: v for k, v in response.headers.multi_items() if k in SUPPORTED_RESPONSE_HEADERS}
61
 
@@ -65,9 +72,10 @@ async def handle_hls_stream_proxy(
65
  else:
66
  transfer_encoding = "chunked"
67
  response_headers["transfer-encoding"] = transfer_encoding
 
68
 
69
  return EnhancedStreamingResponse(
70
- streamer.stream_content(destination, headers),
71
  status_code=response.status_code,
72
  headers=response_headers,
73
  background=BackgroundTask(streamer.close),
@@ -86,31 +94,45 @@ async def handle_hls_stream_proxy(
86
  return Response(status_code=502, content=f"Internal server error: {e}")
87
 
88
 
89
- async def proxy_stream(method: str, video_url: str, headers: dict, verify_ssl: bool = True):
 
 
 
 
 
 
90
  """
91
  Proxies the stream request to the given video URL.
92
 
93
  Args:
94
  method (str): The HTTP method (e.g., GET, HEAD).
95
  video_url (str): The URL of the video to stream.
96
- headers (dict): The headers to include in the request.
97
  verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.
 
98
 
99
  Returns:
100
  Response: The HTTP response with the streamed content.
101
  """
102
- return await handle_stream_request(method, video_url, headers, verify_ssl)
103
 
104
 
105
- async def handle_stream_request(method: str, video_url: str, headers: dict, verify_ssl: bool = True):
 
 
 
 
 
 
106
  """
107
  Handles the stream request, fetching the content from the video URL and streaming it.
108
 
109
  Args:
110
  method (str): The HTTP method (e.g., GET, HEAD).
111
  video_url (str): The URL of the video to stream.
112
- headers (dict): The headers to include in the request.
113
  verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.
 
114
 
115
  Returns:
116
  Response: The HTTP response with the streamed content.
@@ -119,12 +141,12 @@ async def handle_stream_request(method: str, video_url: str, headers: dict, veri
119
  follow_redirects=True,
120
  timeout=httpx.Timeout(30.0),
121
  limits=httpx.Limits(max_keepalive_connections=10, max_connections=20),
122
- proxy=settings.proxy_url,
123
  verify=verify_ssl,
124
  )
125
  streamer = Streamer(client)
126
  try:
127
- response = await streamer.head(video_url, headers)
128
  # clean up the headers to only include the necessary headers and remove acl headers
129
  response_headers = {k: v for k, v in response.headers.multi_items() if k in SUPPORTED_RESPONSE_HEADERS}
130
  if transfer_encoding := response_headers.get("transfer-encoding"):
@@ -133,13 +155,14 @@ async def handle_stream_request(method: str, video_url: str, headers: dict, veri
133
  else:
134
  transfer_encoding = "chunked"
135
  response_headers["transfer-encoding"] = transfer_encoding
 
136
 
137
  if method == "HEAD":
138
  await streamer.close()
139
  return Response(headers=response_headers, status_code=response.status_code)
140
  else:
141
  return EnhancedStreamingResponse(
142
- streamer.stream_content(video_url, headers),
143
  headers=response_headers,
144
  status_code=response.status_code,
145
  background=BackgroundTask(streamer.close),
@@ -159,7 +182,7 @@ async def handle_stream_request(method: str, video_url: str, headers: dict, veri
159
 
160
 
161
  async def fetch_and_process_m3u8(
162
- streamer: Streamer, url: str, headers: dict, request: Request, key_url: HttpUrl = None
163
  ):
164
  """
165
  Fetches and processes the m3u8 playlist, converting it to an HLS playlist.
@@ -167,7 +190,7 @@ async def fetch_and_process_m3u8(
167
  Args:
168
  streamer (Streamer): The HTTP client to use for streaming.
169
  url (str): The URL of the m3u8 playlist.
170
- headers (dict): The headers to include in the request.
171
  request (Request): The incoming HTTP request.
172
  key_url (HttpUrl, optional): The HLS Key URL to replace the original key URL. Defaults to None.
173
 
@@ -175,16 +198,15 @@ async def fetch_and_process_m3u8(
175
  Response: The HTTP response with the processed m3u8 playlist.
176
  """
177
  try:
178
- content = await streamer.get_text(url, headers)
179
  processor = M3U8Processor(request, key_url)
180
  processed_content = await processor.process_m3u8(content, str(streamer.response.url))
 
 
181
  return Response(
182
  content=processed_content,
183
  media_type="application/vnd.apple.mpegurl",
184
- headers={
185
- "Content-Disposition": "inline",
186
- "Accept-Ranges": "none",
187
- },
188
  )
189
  except httpx.HTTPStatusError as e:
190
  logger.error(f"HTTP error while fetching m3u8: {e}")
@@ -229,7 +251,13 @@ async def handle_drm_key_data(key_id, key, drm_info):
229
 
230
 
231
  async def get_manifest(
232
- request: Request, mpd_url: str, headers: dict, key_id: str = None, key: str = None, verify_ssl: bool = True
 
 
 
 
 
 
233
  ):
234
  """
235
  Retrieves and processes the MPD manifest, converting it to an HLS manifest.
@@ -237,17 +265,22 @@ async def get_manifest(
237
  Args:
238
  request (Request): The incoming HTTP request.
239
  mpd_url (str): The URL of the MPD manifest.
240
- headers (dict): The headers to include in the request.
241
  key_id (str, optional): The DRM key ID. Defaults to None.
242
  key (str, optional): The DRM key. Defaults to None.
243
  verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.
 
244
 
245
  Returns:
246
  Response: The HTTP response with the HLS manifest.
247
  """
248
  try:
249
  mpd_dict = await get_cached_mpd(
250
- mpd_url, headers=headers, parse_drm=not key_id and not key, verify_ssl=verify_ssl
 
 
 
 
251
  )
252
  except DownloadError as e:
253
  raise HTTPException(status_code=e.status_code, detail=f"Failed to download MPD: {e.message}")
@@ -255,7 +288,7 @@ async def get_manifest(
255
 
256
  if drm_info and not drm_info.get("isDrmProtected"):
257
  # For non-DRM protected MPD, we still create an HLS manifest
258
- return await process_manifest(request, mpd_dict, None, None)
259
 
260
  key_id, key = await handle_drm_key_data(key_id, key, drm_info)
261
 
@@ -265,17 +298,18 @@ async def get_manifest(
265
  if key and len(key) != 32:
266
  key = base64.urlsafe_b64decode(pad_base64(key)).hex()
267
 
268
- return await process_manifest(request, mpd_dict, key_id, key)
269
 
270
 
271
  async def get_playlist(
272
  request: Request,
273
  mpd_url: str,
274
  profile_id: str,
275
- headers: dict,
276
  key_id: str = None,
277
  key: str = None,
278
  verify_ssl: bool = True,
 
279
  ):
280
  """
281
  Retrieves and processes the MPD manifest, converting it to an HLS playlist for a specific profile.
@@ -284,32 +318,35 @@ async def get_playlist(
284
  request (Request): The incoming HTTP request.
285
  mpd_url (str): The URL of the MPD manifest.
286
  profile_id (str): The profile ID to generate the playlist for.
287
- headers (dict): The headers to include in the request.
288
  key_id (str, optional): The DRM key ID. Defaults to None.
289
  key (str, optional): The DRM key. Defaults to None.
290
  verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.
 
291
 
292
  Returns:
293
  Response: The HTTP response with the HLS playlist.
294
  """
295
  mpd_dict = await get_cached_mpd(
296
  mpd_url,
297
- headers=headers,
298
  parse_drm=not key_id and not key,
299
  parse_segment_profile_id=profile_id,
300
  verify_ssl=verify_ssl,
 
301
  )
302
- return await process_playlist(request, mpd_dict, profile_id)
303
 
304
 
305
  async def get_segment(
306
  init_url: str,
307
  segment_url: str,
308
  mimetype: str,
309
- headers: dict,
310
  key_id: str = None,
311
  key: str = None,
312
  verify_ssl: bool = True,
 
313
  ):
314
  """
315
  Retrieves and processes a media segment, decrypting it if necessary.
@@ -318,28 +355,36 @@ async def get_segment(
318
  init_url (str): The URL of the initialization segment.
319
  segment_url (str): The URL of the media segment.
320
  mimetype (str): The MIME type of the segment.
321
- headers (dict): The headers to include in the request.
322
  key_id (str, optional): The DRM key ID. Defaults to None.
323
  key (str, optional): The DRM key. Defaults to None.
324
  verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.
 
325
 
326
  Returns:
327
  Response: The HTTP response with the processed segment.
328
  """
329
  try:
330
- init_content = await get_cached_init_segment(init_url, headers, verify_ssl)
331
- segment_content = await download_file_with_retry(segment_url, headers, verify_ssl=verify_ssl)
 
 
332
  except DownloadError as e:
333
  raise HTTPException(status_code=e.status_code, detail=f"Failed to download segment: {e.message}")
334
- return await process_segment(init_content, segment_content, mimetype, key_id, key)
335
 
336
 
337
- async def get_public_ip():
338
  """
339
  Retrieves the public IP address of the MediaFlow proxy.
340
 
 
 
 
341
  Returns:
342
  Response: The HTTP response with the public IP address.
343
  """
344
- ip_address_data = await request_with_retry("GET", "https://api.ipify.org?format=json", {})
 
 
345
  return ip_address_data.json()
 
16
  download_file_with_retry,
17
  request_with_retry,
18
  EnhancedStreamingResponse,
19
+ ProxyRequestHeaders,
20
  )
21
  from .utils.m3u8_processor import M3U8Processor
22
  from .utils.mpd_utils import pad_base64
 
25
 
26
 
27
  async def handle_hls_stream_proxy(
28
+ request: Request,
29
+ destination: str,
30
+ proxy_headers: ProxyRequestHeaders,
31
+ key_url: HttpUrl = None,
32
+ verify_ssl: bool = True,
33
+ use_request_proxy: bool = True,
34
  ):
35
  """
36
  Handles the HLS stream proxy request, fetching and processing the m3u8 playlist or streaming the content.
 
38
  Args:
39
  request (Request): The incoming HTTP request.
40
  destination (str): The destination URL to fetch the content from.
41
+ proxy_headers (ProxyRequestHeaders): The headers to include in the request.
42
  key_url (str, optional): The HLS Key URL to replace the original key URL. Defaults to None.
43
  verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.
44
+ use_request_proxy (bool, optional): Whether to use the MediaFlow proxy configuration. Defaults to True.
45
 
46
  Returns:
47
  Response: The HTTP response with the processed m3u8 playlist or streamed content.
 
50
  follow_redirects=True,
51
  timeout=httpx.Timeout(30.0),
52
  limits=httpx.Limits(max_keepalive_connections=10, max_connections=20),
53
+ proxy=settings.proxy_url if use_request_proxy else None,
54
  verify=verify_ssl,
55
  )
56
  streamer = Streamer(client)
57
  try:
58
  if destination.endswith((".m3u", ".m3u8")):
59
+ return await fetch_and_process_m3u8(streamer, destination, proxy_headers, request, key_url)
60
 
61
+ response = await streamer.head(destination, proxy_headers.request)
62
  if "mpegurl" in response.headers.get("content-type", "").lower():
63
+ return await fetch_and_process_m3u8(streamer, destination, proxy_headers, request, key_url)
64
 
65
+ proxy_headers.request.update({"range": proxy_headers.request.get("range", "bytes=0-")})
66
  # clean up the headers to only include the necessary headers and remove acl headers
67
  response_headers = {k: v for k, v in response.headers.multi_items() if k in SUPPORTED_RESPONSE_HEADERS}
68
 
 
72
  else:
73
  transfer_encoding = "chunked"
74
  response_headers["transfer-encoding"] = transfer_encoding
75
+ response_headers.update(proxy_headers.response)
76
 
77
  return EnhancedStreamingResponse(
78
+ streamer.stream_content(destination, proxy_headers.request),
79
  status_code=response.status_code,
80
  headers=response_headers,
81
  background=BackgroundTask(streamer.close),
 
94
  return Response(status_code=502, content=f"Internal server error: {e}")
95
 
96
 
97
+ async def proxy_stream(
98
+ method: str,
99
+ video_url: str,
100
+ proxy_headers: ProxyRequestHeaders,
101
+ verify_ssl: bool = True,
102
+ use_request_proxy: bool = True,
103
+ ):
104
  """
105
  Proxies the stream request to the given video URL.
106
 
107
  Args:
108
  method (str): The HTTP method (e.g., GET, HEAD).
109
  video_url (str): The URL of the video to stream.
110
+ proxy_headers (ProxyRequestHeaders): The headers to include in the request.
111
  verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.
112
+ use_request_proxy (bool, optional): Whether to use the MediaFlow proxy configuration. Defaults to True.
113
 
114
  Returns:
115
  Response: The HTTP response with the streamed content.
116
  """
117
+ return await handle_stream_request(method, video_url, proxy_headers, verify_ssl, use_request_proxy)
118
 
119
 
120
+ async def handle_stream_request(
121
+ method: str,
122
+ video_url: str,
123
+ proxy_headers: ProxyRequestHeaders,
124
+ verify_ssl: bool = True,
125
+ use_request_proxy: bool = True,
126
+ ):
127
  """
128
  Handles the stream request, fetching the content from the video URL and streaming it.
129
 
130
  Args:
131
  method (str): The HTTP method (e.g., GET, HEAD).
132
  video_url (str): The URL of the video to stream.
133
+ proxy_headers (ProxyRequestHeaders): The headers to include in the request.
134
  verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.
135
+ use_request_proxy (bool, optional): Whether to use the MediaFlow proxy configuration. Defaults to True.
136
 
137
  Returns:
138
  Response: The HTTP response with the streamed content.
 
141
  follow_redirects=True,
142
  timeout=httpx.Timeout(30.0),
143
  limits=httpx.Limits(max_keepalive_connections=10, max_connections=20),
144
+ proxy=settings.proxy_url if use_request_proxy else None,
145
  verify=verify_ssl,
146
  )
147
  streamer = Streamer(client)
148
  try:
149
+ response = await streamer.head(video_url, proxy_headers.request)
150
  # clean up the headers to only include the necessary headers and remove acl headers
151
  response_headers = {k: v for k, v in response.headers.multi_items() if k in SUPPORTED_RESPONSE_HEADERS}
152
  if transfer_encoding := response_headers.get("transfer-encoding"):
 
155
  else:
156
  transfer_encoding = "chunked"
157
  response_headers["transfer-encoding"] = transfer_encoding
158
+ response_headers.update(proxy_headers.response)
159
 
160
  if method == "HEAD":
161
  await streamer.close()
162
  return Response(headers=response_headers, status_code=response.status_code)
163
  else:
164
  return EnhancedStreamingResponse(
165
+ streamer.stream_content(video_url, proxy_headers.request),
166
  headers=response_headers,
167
  status_code=response.status_code,
168
  background=BackgroundTask(streamer.close),
 
182
 
183
 
184
  async def fetch_and_process_m3u8(
185
+ streamer: Streamer, url: str, proxy_headers: ProxyRequestHeaders, request: Request, key_url: HttpUrl = None
186
  ):
187
  """
188
  Fetches and processes the m3u8 playlist, converting it to an HLS playlist.
 
190
  Args:
191
  streamer (Streamer): The HTTP client to use for streaming.
192
  url (str): The URL of the m3u8 playlist.
193
+ proxy_headers (ProxyRequestHeaders): The headers to include in the request.
194
  request (Request): The incoming HTTP request.
195
  key_url (HttpUrl, optional): The HLS Key URL to replace the original key URL. Defaults to None.
196
 
 
198
  Response: The HTTP response with the processed m3u8 playlist.
199
  """
200
  try:
201
+ content = await streamer.get_text(url, proxy_headers.request)
202
  processor = M3U8Processor(request, key_url)
203
  processed_content = await processor.process_m3u8(content, str(streamer.response.url))
204
+ response_headers = {"Content-Disposition": "inline", "Accept-Ranges": "none"}
205
+ response_headers.update(proxy_headers.response)
206
  return Response(
207
  content=processed_content,
208
  media_type="application/vnd.apple.mpegurl",
209
+ headers=response_headers,
 
 
 
210
  )
211
  except httpx.HTTPStatusError as e:
212
  logger.error(f"HTTP error while fetching m3u8: {e}")
 
251
 
252
 
253
  async def get_manifest(
254
+ request: Request,
255
+ mpd_url: str,
256
+ proxy_headers: ProxyRequestHeaders,
257
+ key_id: str = None,
258
+ key: str = None,
259
+ verify_ssl: bool = True,
260
+ use_request_proxy: bool = True,
261
  ):
262
  """
263
  Retrieves and processes the MPD manifest, converting it to an HLS manifest.
 
265
  Args:
266
  request (Request): The incoming HTTP request.
267
  mpd_url (str): The URL of the MPD manifest.
268
+ proxy_headers (ProxyRequestHeaders): The headers to include in the request.
269
  key_id (str, optional): The DRM key ID. Defaults to None.
270
  key (str, optional): The DRM key. Defaults to None.
271
  verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.
272
+ use_request_proxy (bool, optional): Whether to use the MediaFlow proxy configuration. Defaults to True.
273
 
274
  Returns:
275
  Response: The HTTP response with the HLS manifest.
276
  """
277
  try:
278
  mpd_dict = await get_cached_mpd(
279
+ mpd_url,
280
+ headers=proxy_headers.request,
281
+ parse_drm=not key_id and not key,
282
+ verify_ssl=verify_ssl,
283
+ use_request_proxy=use_request_proxy,
284
  )
285
  except DownloadError as e:
286
  raise HTTPException(status_code=e.status_code, detail=f"Failed to download MPD: {e.message}")
 
288
 
289
  if drm_info and not drm_info.get("isDrmProtected"):
290
  # For non-DRM protected MPD, we still create an HLS manifest
291
+ return await process_manifest(request, mpd_dict, proxy_headers, None, None)
292
 
293
  key_id, key = await handle_drm_key_data(key_id, key, drm_info)
294
 
 
298
  if key and len(key) != 32:
299
  key = base64.urlsafe_b64decode(pad_base64(key)).hex()
300
 
301
+ return await process_manifest(request, mpd_dict, proxy_headers, key_id, key)
302
 
303
 
304
  async def get_playlist(
305
  request: Request,
306
  mpd_url: str,
307
  profile_id: str,
308
+ proxy_headers: ProxyRequestHeaders,
309
  key_id: str = None,
310
  key: str = None,
311
  verify_ssl: bool = True,
312
+ use_request_proxy: bool = True,
313
  ):
314
  """
315
  Retrieves and processes the MPD manifest, converting it to an HLS playlist for a specific profile.
 
318
  request (Request): The incoming HTTP request.
319
  mpd_url (str): The URL of the MPD manifest.
320
  profile_id (str): The profile ID to generate the playlist for.
321
+ proxy_headers (ProxyRequestHeaders): The headers to include in the request.
322
  key_id (str, optional): The DRM key ID. Defaults to None.
323
  key (str, optional): The DRM key. Defaults to None.
324
  verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.
325
+ use_request_proxy (bool, optional): Whether to use the MediaFlow proxy configuration. Defaults to True.
326
 
327
  Returns:
328
  Response: The HTTP response with the HLS playlist.
329
  """
330
  mpd_dict = await get_cached_mpd(
331
  mpd_url,
332
+ headers=proxy_headers.request,
333
  parse_drm=not key_id and not key,
334
  parse_segment_profile_id=profile_id,
335
  verify_ssl=verify_ssl,
336
+ use_request_proxy=use_request_proxy,
337
  )
338
+ return await process_playlist(request, mpd_dict, profile_id, proxy_headers)
339
 
340
 
341
  async def get_segment(
342
  init_url: str,
343
  segment_url: str,
344
  mimetype: str,
345
+ proxy_headers: ProxyRequestHeaders,
346
  key_id: str = None,
347
  key: str = None,
348
  verify_ssl: bool = True,
349
+ use_request_proxy: bool = True,
350
  ):
351
  """
352
  Retrieves and processes a media segment, decrypting it if necessary.
 
355
  init_url (str): The URL of the initialization segment.
356
  segment_url (str): The URL of the media segment.
357
  mimetype (str): The MIME type of the segment.
358
+ proxy_headers (ProxyRequestHeaders): The headers to include in the request.
359
  key_id (str, optional): The DRM key ID. Defaults to None.
360
  key (str, optional): The DRM key. Defaults to None.
361
  verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.
362
+ use_request_proxy (bool, optional): Whether to use the MediaFlow proxy configuration. Defaults to True.
363
 
364
  Returns:
365
  Response: The HTTP response with the processed segment.
366
  """
367
  try:
368
+ init_content = await get_cached_init_segment(init_url, proxy_headers.request, verify_ssl, use_request_proxy)
369
+ segment_content = await download_file_with_retry(
370
+ segment_url, proxy_headers.request, verify_ssl=verify_ssl, use_request_proxy=use_request_proxy
371
+ )
372
  except DownloadError as e:
373
  raise HTTPException(status_code=e.status_code, detail=f"Failed to download segment: {e.message}")
374
+ return await process_segment(init_content, segment_content, mimetype, proxy_headers, key_id, key)
375
 
376
 
377
+ async def get_public_ip(use_request_proxy: bool = True):
378
  """
379
  Retrieves the public IP address of the MediaFlow proxy.
380
 
381
+ Args:
382
+ use_request_proxy (bool, optional): Whether to use the proxy configuration from the user's MediaFlow config. Defaults to True.
383
+
384
  Returns:
385
  Response: The HTTP response with the public IP address.
386
  """
387
+ ip_address_data = await request_with_retry(
388
+ "GET", "https://api.ipify.org?format=json", {}, use_request_proxy=use_request_proxy
389
+ )
390
  return ip_address_data.json()
mediaflow_proxy/main.py CHANGED
@@ -3,6 +3,7 @@ from importlib import resources
3
 
4
  from fastapi import FastAPI, Depends, Security, HTTPException
5
  from fastapi.security import APIKeyQuery, APIKeyHeader
 
6
  from starlette.responses import RedirectResponse
7
  from starlette.staticfiles import StaticFiles
8
 
@@ -13,6 +14,13 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(level
13
  app = FastAPI()
14
  api_password_query = APIKeyQuery(name="api_password", auto_error=False)
15
  api_password_header = APIKeyHeader(name="api_password", auto_error=False)
 
 
 
 
 
 
 
16
 
17
 
18
  async def verify_api_key(api_key: str = Security(api_password_query), api_key_alt: str = Security(api_password_header)):
 
3
 
4
  from fastapi import FastAPI, Depends, Security, HTTPException
5
  from fastapi.security import APIKeyQuery, APIKeyHeader
6
+ from starlette.middleware.cors import CORSMiddleware
7
  from starlette.responses import RedirectResponse
8
  from starlette.staticfiles import StaticFiles
9
 
 
14
  app = FastAPI()
15
  api_password_query = APIKeyQuery(name="api_password", auto_error=False)
16
  api_password_header = APIKeyHeader(name="api_password", auto_error=False)
17
+ app.add_middleware(
18
+ CORSMiddleware,
19
+ allow_origins=["*"],
20
+ allow_credentials=True,
21
+ allow_methods=["*"],
22
+ allow_headers=["*"],
23
+ )
24
 
25
 
26
  async def verify_api_key(api_key: str = Security(api_password_query), api_key_alt: str = Security(api_password_header)):
mediaflow_proxy/mpd_processor.py CHANGED
@@ -7,18 +7,21 @@ from fastapi import Request, Response, HTTPException
7
 
8
  from mediaflow_proxy.configs import settings
9
  from mediaflow_proxy.drm.decrypter import decrypt_segment
10
- from mediaflow_proxy.utils.http_utils import encode_mediaflow_proxy_url, get_original_scheme
11
 
12
  logger = logging.getLogger(__name__)
13
 
14
 
15
- async def process_manifest(request: Request, mpd_dict: dict, key_id: str = None, key: str = None) -> Response:
 
 
16
  """
17
  Processes the MPD manifest and converts it to an HLS manifest.
18
 
19
  Args:
20
  request (Request): The incoming HTTP request.
21
  mpd_dict (dict): The MPD manifest data.
 
22
  key_id (str, optional): The DRM key ID. Defaults to None.
23
  key (str, optional): The DRM key. Defaults to None.
24
 
@@ -26,10 +29,12 @@ async def process_manifest(request: Request, mpd_dict: dict, key_id: str = None,
26
  Response: The HLS manifest as an HTTP response.
27
  """
28
  hls_content = build_hls(mpd_dict, request, key_id, key)
29
- return Response(content=hls_content, media_type="application/vnd.apple.mpegurl")
30
 
31
 
32
- async def process_playlist(request: Request, mpd_dict: dict, profile_id: str) -> Response:
 
 
33
  """
34
  Processes the MPD manifest and converts it to an HLS playlist for a specific profile.
35
 
@@ -37,6 +42,7 @@ async def process_playlist(request: Request, mpd_dict: dict, profile_id: str) ->
37
  request (Request): The incoming HTTP request.
38
  mpd_dict (dict): The MPD manifest data.
39
  profile_id (str): The profile ID to generate the playlist for.
 
40
 
41
  Returns:
42
  Response: The HLS playlist as an HTTP response.
@@ -49,13 +55,14 @@ async def process_playlist(request: Request, mpd_dict: dict, profile_id: str) ->
49
  raise HTTPException(status_code=404, detail="Profile not found")
50
 
51
  hls_content = build_hls_playlist(mpd_dict, matching_profiles, request)
52
- return Response(content=hls_content, media_type="application/vnd.apple.mpegurl")
53
 
54
 
55
  async def process_segment(
56
  init_content: bytes,
57
  segment_content: bytes,
58
  mimetype: str,
 
59
  key_id: str = None,
60
  key: str = None,
61
  ) -> Response:
@@ -66,6 +73,7 @@ async def process_segment(
66
  init_content (bytes): The initialization segment content.
67
  segment_content (bytes): The media segment content.
68
  mimetype (str): The MIME type of the segment.
 
69
  key_id (str, optional): The DRM key ID. Defaults to None.
70
  key (str, optional): The DRM key. Defaults to None.
71
 
@@ -81,7 +89,7 @@ async def process_segment(
81
  # For non-DRM protected content, we just concatenate init and segment content
82
  decrypted_content = init_content + segment_content
83
 
84
- return Response(content=decrypted_content, media_type=mimetype)
85
 
86
 
87
  def build_hls(mpd_dict: dict, request: Request, key_id: str = None, key: str = None) -> str:
 
7
 
8
  from mediaflow_proxy.configs import settings
9
  from mediaflow_proxy.drm.decrypter import decrypt_segment
10
+ from mediaflow_proxy.utils.http_utils import encode_mediaflow_proxy_url, get_original_scheme, ProxyRequestHeaders
11
 
12
  logger = logging.getLogger(__name__)
13
 
14
 
15
+ async def process_manifest(
16
+ request: Request, mpd_dict: dict, proxy_headers: ProxyRequestHeaders, key_id: str = None, key: str = None
17
+ ) -> Response:
18
  """
19
  Processes the MPD manifest and converts it to an HLS manifest.
20
 
21
  Args:
22
  request (Request): The incoming HTTP request.
23
  mpd_dict (dict): The MPD manifest data.
24
+ proxy_headers (ProxyRequestHeaders): The headers to include in the request.
25
  key_id (str, optional): The DRM key ID. Defaults to None.
26
  key (str, optional): The DRM key. Defaults to None.
27
 
 
29
  Response: The HLS manifest as an HTTP response.
30
  """
31
  hls_content = build_hls(mpd_dict, request, key_id, key)
32
+ return Response(content=hls_content, media_type="application/vnd.apple.mpegurl", headers=proxy_headers.response)
33
 
34
 
35
+ async def process_playlist(
36
+ request: Request, mpd_dict: dict, profile_id: str, proxy_headers: ProxyRequestHeaders
37
+ ) -> Response:
38
  """
39
  Processes the MPD manifest and converts it to an HLS playlist for a specific profile.
40
 
 
42
  request (Request): The incoming HTTP request.
43
  mpd_dict (dict): The MPD manifest data.
44
  profile_id (str): The profile ID to generate the playlist for.
45
+ proxy_headers (ProxyRequestHeaders): The headers to include in the request.
46
 
47
  Returns:
48
  Response: The HLS playlist as an HTTP response.
 
55
  raise HTTPException(status_code=404, detail="Profile not found")
56
 
57
  hls_content = build_hls_playlist(mpd_dict, matching_profiles, request)
58
+ return Response(content=hls_content, media_type="application/vnd.apple.mpegurl", headers=proxy_headers.response)
59
 
60
 
61
  async def process_segment(
62
  init_content: bytes,
63
  segment_content: bytes,
64
  mimetype: str,
65
+ proxy_headers: ProxyRequestHeaders,
66
  key_id: str = None,
67
  key: str = None,
68
  ) -> Response:
 
73
  init_content (bytes): The initialization segment content.
74
  segment_content (bytes): The media segment content.
75
  mimetype (str): The MIME type of the segment.
76
+ proxy_headers (ProxyRequestHeaders): The headers to include in the request.
77
  key_id (str, optional): The DRM key ID. Defaults to None.
78
  key (str, optional): The DRM key. Defaults to None.
79
 
 
89
  # For non-DRM protected content, we just concatenate init and segment content
90
  decrypted_content = init_content + segment_content
91
 
92
+ return Response(content=decrypted_content, media_type=mimetype, headers=proxy_headers.response)
93
 
94
 
95
  def build_hls(mpd_dict: dict, request: Request, key_id: str = None, key: str = None) -> str:
mediaflow_proxy/routes.py CHANGED
@@ -2,7 +2,7 @@ from fastapi import Request, Depends, APIRouter
2
  from pydantic import HttpUrl
3
 
4
  from .handlers import handle_hls_stream_proxy, proxy_stream, get_manifest, get_playlist, get_segment, get_public_ip
5
- from .utils.http_utils import get_proxy_headers
6
 
7
  proxy_router = APIRouter()
8
 
@@ -12,9 +12,10 @@ proxy_router = APIRouter()
12
  async def hls_stream_proxy(
13
  request: Request,
14
  d: HttpUrl,
15
- headers: dict = Depends(get_proxy_headers),
16
  key_url: HttpUrl | None = None,
17
  verify_ssl: bool = False,
 
18
  ):
19
  """
20
  Proxify HLS stream requests, fetching and processing the m3u8 playlist or streaming the content.
@@ -23,20 +24,25 @@ async def hls_stream_proxy(
23
  request (Request): The incoming HTTP request.
24
  d (HttpUrl): The destination URL to fetch the content from.
25
  key_url (HttpUrl, optional): The HLS Key URL to replace the original key URL. Defaults to None. (Useful for bypassing some sneaky protection)
26
- headers (dict): The headers to include in the request.
27
  verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to False.
 
28
 
29
  Returns:
30
  Response: The HTTP response with the processed m3u8 playlist or streamed content.
31
  """
32
  destination = str(d)
33
- return await handle_hls_stream_proxy(request, destination, headers, key_url, verify_ssl)
34
 
35
 
36
  @proxy_router.head("/stream")
37
  @proxy_router.get("/stream")
38
  async def proxy_stream_endpoint(
39
- request: Request, d: HttpUrl, headers: dict = Depends(get_proxy_headers), verify_ssl: bool = False
 
 
 
 
40
  ):
41
  """
42
  Proxies stream requests to the given video URL.
@@ -44,24 +50,26 @@ async def proxy_stream_endpoint(
44
  Args:
45
  request (Request): The incoming HTTP request.
46
  d (HttpUrl): The URL of the video to stream.
47
- headers (dict): The headers to include in the request.
48
  verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to False.
 
49
 
50
  Returns:
51
  Response: The HTTP response with the streamed content.
52
  """
53
- headers.update({"range": headers.get("range", "bytes=0-")})
54
- return await proxy_stream(request.method, str(d), headers, verify_ssl)
55
 
56
 
57
  @proxy_router.get("/mpd/manifest")
58
  async def manifest_endpoint(
59
  request: Request,
60
  d: HttpUrl,
61
- headers: dict = Depends(get_proxy_headers),
62
  key_id: str = None,
63
  key: str = None,
64
  verify_ssl: bool = False,
 
65
  ):
66
  """
67
  Retrieves and processes the MPD manifest, converting it to an HLS manifest.
@@ -69,15 +77,16 @@ async def manifest_endpoint(
69
  Args:
70
  request (Request): The incoming HTTP request.
71
  d (HttpUrl): The URL of the MPD manifest.
72
- headers (dict): The headers to include in the request.
73
  key_id (str, optional): The DRM key ID. Defaults to None.
74
  key (str, optional): The DRM key. Defaults to None.
75
  verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to False.
 
76
 
77
  Returns:
78
  Response: The HTTP response with the HLS manifest.
79
  """
80
- return await get_manifest(request, str(d), headers, key_id, key, verify_ssl)
81
 
82
 
83
  @proxy_router.get("/mpd/playlist")
@@ -85,10 +94,11 @@ async def playlist_endpoint(
85
  request: Request,
86
  d: HttpUrl,
87
  profile_id: str,
88
- headers: dict = Depends(get_proxy_headers),
89
  key_id: str = None,
90
  key: str = None,
91
  verify_ssl: bool = False,
 
92
  ):
93
  """
94
  Retrieves and processes the MPD manifest, converting it to an HLS playlist for a specific profile.
@@ -97,15 +107,16 @@ async def playlist_endpoint(
97
  request (Request): The incoming HTTP request.
98
  d (HttpUrl): The URL of the MPD manifest.
99
  profile_id (str): The profile ID to generate the playlist for.
100
- headers (dict): The headers to include in the request.
101
  key_id (str, optional): The DRM key ID. Defaults to None.
102
  key (str, optional): The DRM key. Defaults to None.
103
  verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to False.
 
104
 
105
  Returns:
106
  Response: The HTTP response with the HLS playlist.
107
  """
108
- return await get_playlist(request, str(d), profile_id, headers, key_id, key, verify_ssl)
109
 
110
 
111
  @proxy_router.get("/mpd/segment")
@@ -113,10 +124,11 @@ async def segment_endpoint(
113
  init_url: HttpUrl,
114
  segment_url: HttpUrl,
115
  mime_type: str,
116
- headers: dict = Depends(get_proxy_headers),
117
  key_id: str = None,
118
  key: str = None,
119
  verify_ssl: bool = False,
 
120
  ):
121
  """
122
  Retrieves and processes a media segment, decrypting it if necessary.
@@ -125,23 +137,28 @@ async def segment_endpoint(
125
  init_url (HttpUrl): The URL of the initialization segment.
126
  segment_url (HttpUrl): The URL of the media segment.
127
  mime_type (str): The MIME type of the segment.
128
- headers (dict): The headers to include in the request.
129
  key_id (str, optional): The DRM key ID. Defaults to None.
130
  key (str, optional): The DRM key. Defaults to None.
131
  verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to False.
 
132
 
133
  Returns:
134
  Response: The HTTP response with the processed segment.
135
  """
136
- return await get_segment(str(init_url), str(segment_url), mime_type, headers, key_id, key, verify_ssl)
 
 
137
 
138
 
139
  @proxy_router.get("/ip")
140
- async def get_mediaflow_proxy_public_ip():
 
 
141
  """
142
  Retrieves the public IP address of the MediaFlow proxy server.
143
 
144
  Returns:
145
  Response: The HTTP response with the public IP address in the form of a JSON object. {"ip": "xxx.xxx.xxx.xxx"}
146
  """
147
- return await get_public_ip()
 
2
  from pydantic import HttpUrl
3
 
4
  from .handlers import handle_hls_stream_proxy, proxy_stream, get_manifest, get_playlist, get_segment, get_public_ip
5
+ from .utils.http_utils import get_proxy_headers, ProxyRequestHeaders
6
 
7
  proxy_router = APIRouter()
8
 
 
12
  async def hls_stream_proxy(
13
  request: Request,
14
  d: HttpUrl,
15
+ proxy_headers: ProxyRequestHeaders = Depends(get_proxy_headers),
16
  key_url: HttpUrl | None = None,
17
  verify_ssl: bool = False,
18
+ use_request_proxy: bool = True,
19
  ):
20
  """
21
  Proxify HLS stream requests, fetching and processing the m3u8 playlist or streaming the content.
 
24
  request (Request): The incoming HTTP request.
25
  d (HttpUrl): The destination URL to fetch the content from.
26
  key_url (HttpUrl, optional): The HLS Key URL to replace the original key URL. Defaults to None. (Useful for bypassing some sneaky protection)
27
+ proxy_headers (ProxyRequestHeaders): The headers to include in the request.
28
  verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to False.
29
+ use_request_proxy (bool, optional): Whether to use the MediaFlow proxy configuration. Defaults to True.
30
 
31
  Returns:
32
  Response: The HTTP response with the processed m3u8 playlist or streamed content.
33
  """
34
  destination = str(d)
35
+ return await handle_hls_stream_proxy(request, destination, proxy_headers, key_url, verify_ssl, use_request_proxy)
36
 
37
 
38
  @proxy_router.head("/stream")
39
  @proxy_router.get("/stream")
40
  async def proxy_stream_endpoint(
41
+ request: Request,
42
+ d: HttpUrl,
43
+ proxy_headers: ProxyRequestHeaders = Depends(get_proxy_headers),
44
+ verify_ssl: bool = False,
45
+ use_request_proxy: bool = True,
46
  ):
47
  """
48
  Proxies stream requests to the given video URL.
 
50
  Args:
51
  request (Request): The incoming HTTP request.
52
  d (HttpUrl): The URL of the video to stream.
53
+ proxy_headers (ProxyRequestHeaders): The headers to include in the request.
54
  verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to False.
55
+ use_request_proxy (bool, optional): Whether to use the MediaFlow proxy configuration. Defaults to True.
56
 
57
  Returns:
58
  Response: The HTTP response with the streamed content.
59
  """
60
+ proxy_headers.request.update({"range": proxy_headers.request.get("range", "bytes=0-")})
61
+ return await proxy_stream(request.method, str(d), proxy_headers, verify_ssl, use_request_proxy)
62
 
63
 
64
  @proxy_router.get("/mpd/manifest")
65
  async def manifest_endpoint(
66
  request: Request,
67
  d: HttpUrl,
68
+ proxy_headers: ProxyRequestHeaders = Depends(get_proxy_headers),
69
  key_id: str = None,
70
  key: str = None,
71
  verify_ssl: bool = False,
72
+ use_request_proxy: bool = True,
73
  ):
74
  """
75
  Retrieves and processes the MPD manifest, converting it to an HLS manifest.
 
77
  Args:
78
  request (Request): The incoming HTTP request.
79
  d (HttpUrl): The URL of the MPD manifest.
80
+ proxy_headers (ProxyRequestHeaders): The headers to include in the request.
81
  key_id (str, optional): The DRM key ID. Defaults to None.
82
  key (str, optional): The DRM key. Defaults to None.
83
  verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to False.
84
+ use_request_proxy (bool, optional): Whether to use the MediaFlow proxy configuration. Defaults to True.
85
 
86
  Returns:
87
  Response: The HTTP response with the HLS manifest.
88
  """
89
+ return await get_manifest(request, str(d), proxy_headers, key_id, key, verify_ssl, use_request_proxy)
90
 
91
 
92
  @proxy_router.get("/mpd/playlist")
 
94
  request: Request,
95
  d: HttpUrl,
96
  profile_id: str,
97
+ proxy_headers: ProxyRequestHeaders = Depends(get_proxy_headers),
98
  key_id: str = None,
99
  key: str = None,
100
  verify_ssl: bool = False,
101
+ use_request_proxy: bool = True,
102
  ):
103
  """
104
  Retrieves and processes the MPD manifest, converting it to an HLS playlist for a specific profile.
 
107
  request (Request): The incoming HTTP request.
108
  d (HttpUrl): The URL of the MPD manifest.
109
  profile_id (str): The profile ID to generate the playlist for.
110
+ proxy_headers (ProxyRequestHeaders): The headers to include in the request.
111
  key_id (str, optional): The DRM key ID. Defaults to None.
112
  key (str, optional): The DRM key. Defaults to None.
113
  verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to False.
114
+ use_request_proxy (bool, optional): Whether to use the MediaFlow proxy configuration. Defaults to True.
115
 
116
  Returns:
117
  Response: The HTTP response with the HLS playlist.
118
  """
119
+ return await get_playlist(request, str(d), profile_id, proxy_headers, key_id, key, verify_ssl, use_request_proxy)
120
 
121
 
122
  @proxy_router.get("/mpd/segment")
 
124
  init_url: HttpUrl,
125
  segment_url: HttpUrl,
126
  mime_type: str,
127
+ proxy_headers: ProxyRequestHeaders = Depends(get_proxy_headers),
128
  key_id: str = None,
129
  key: str = None,
130
  verify_ssl: bool = False,
131
+ use_request_proxy: bool = True,
132
  ):
133
  """
134
  Retrieves and processes a media segment, decrypting it if necessary.
 
137
  init_url (HttpUrl): The URL of the initialization segment.
138
  segment_url (HttpUrl): The URL of the media segment.
139
  mime_type (str): The MIME type of the segment.
140
+ proxy_headers (ProxyRequestHeaders): The headers to include in the request.
141
  key_id (str, optional): The DRM key ID. Defaults to None.
142
  key (str, optional): The DRM key. Defaults to None.
143
  verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to False.
144
+ use_request_proxy (bool, optional): Whether to use the MediaFlow proxy configuration. Defaults to True.
145
 
146
  Returns:
147
  Response: The HTTP response with the processed segment.
148
  """
149
+ return await get_segment(
150
+ str(init_url), str(segment_url), mime_type, proxy_headers, key_id, key, verify_ssl, use_request_proxy
151
+ )
152
 
153
 
154
  @proxy_router.get("/ip")
155
+ async def get_mediaflow_proxy_public_ip(
156
+ use_request_proxy: bool = True,
157
+ ):
158
  """
159
  Retrieves the public IP address of the MediaFlow proxy server.
160
 
161
  Returns:
162
  Response: The HTTP response with the public IP address in the form of a JSON object. {"ip": "xxx.xxx.xxx.xxx"}
163
  """
164
+ return await get_public_ip(use_request_proxy)
mediaflow_proxy/utils/cache_utils.py CHANGED
@@ -14,7 +14,12 @@ init_segment_cache = TTLCache(maxsize=100, ttl=3600) # 1 hour default TTL
14
 
15
 
16
  async def get_cached_mpd(
17
- mpd_url: str, headers: dict, parse_drm: bool, parse_segment_profile_id: str | None = None, verify_ssl: bool = True
 
 
 
 
 
18
  ) -> dict:
19
  """
20
  Retrieves and caches the MPD manifest, parsing it if not already cached.
@@ -25,6 +30,7 @@ async def get_cached_mpd(
25
  parse_drm (bool): Whether to parse DRM information.
26
  parse_segment_profile_id (str, optional): The profile ID to parse segments for. Defaults to None.
27
  verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.
 
28
 
29
  Returns:
30
  dict: The parsed MPD manifest data.
@@ -34,7 +40,9 @@ async def get_cached_mpd(
34
  logger.info(f"Using cached MPD for {mpd_url}")
35
  return parse_mpd_dict(mpd_cache[mpd_url]["mpd"], mpd_url, parse_drm, parse_segment_profile_id)
36
 
37
- mpd_dict = parse_mpd(await download_file_with_retry(mpd_url, headers, verify_ssl=verify_ssl))
 
 
38
  parsed_mpd_dict = parse_mpd_dict(mpd_dict, mpd_url, parse_drm, parse_segment_profile_id)
39
  current_time = datetime.datetime.now(datetime.UTC)
40
  expiration_time = current_time + datetime.timedelta(seconds=parsed_mpd_dict.get("minimumUpdatePeriod", 300))
@@ -42,7 +50,9 @@ async def get_cached_mpd(
42
  return parsed_mpd_dict
43
 
44
 
45
- async def get_cached_init_segment(init_url: str, headers: dict, verify_ssl: bool = True) -> bytes:
 
 
46
  """
47
  Retrieves and caches the initialization segment.
48
 
@@ -50,11 +60,14 @@ async def get_cached_init_segment(init_url: str, headers: dict, verify_ssl: bool
50
  init_url (str): The URL of the initialization segment.
51
  headers (dict): The headers to include in the request.
52
  verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.
 
53
 
54
  Returns:
55
  bytes: The initialization segment content.
56
  """
57
  if init_url not in init_segment_cache:
58
- init_content = await download_file_with_retry(init_url, headers, verify_ssl=verify_ssl)
 
 
59
  init_segment_cache[init_url] = init_content
60
  return init_segment_cache[init_url]
 
14
 
15
 
16
  async def get_cached_mpd(
17
+ mpd_url: str,
18
+ headers: dict,
19
+ parse_drm: bool,
20
+ parse_segment_profile_id: str | None = None,
21
+ verify_ssl: bool = True,
22
+ use_request_proxy: bool = True,
23
  ) -> dict:
24
  """
25
  Retrieves and caches the MPD manifest, parsing it if not already cached.
 
30
  parse_drm (bool): Whether to parse DRM information.
31
  parse_segment_profile_id (str, optional): The profile ID to parse segments for. Defaults to None.
32
  verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.
33
+ use_request_proxy (bool, optional): Whether to use the proxy configuration from the user's MediaFlow config. Defaults to True.
34
 
35
  Returns:
36
  dict: The parsed MPD manifest data.
 
40
  logger.info(f"Using cached MPD for {mpd_url}")
41
  return parse_mpd_dict(mpd_cache[mpd_url]["mpd"], mpd_url, parse_drm, parse_segment_profile_id)
42
 
43
+ mpd_dict = parse_mpd(
44
+ await download_file_with_retry(mpd_url, headers, verify_ssl=verify_ssl, use_request_proxy=use_request_proxy)
45
+ )
46
  parsed_mpd_dict = parse_mpd_dict(mpd_dict, mpd_url, parse_drm, parse_segment_profile_id)
47
  current_time = datetime.datetime.now(datetime.UTC)
48
  expiration_time = current_time + datetime.timedelta(seconds=parsed_mpd_dict.get("minimumUpdatePeriod", 300))
 
50
  return parsed_mpd_dict
51
 
52
 
53
+ async def get_cached_init_segment(
54
+ init_url: str, headers: dict, verify_ssl: bool = True, use_request_proxy: bool = True
55
+ ) -> bytes:
56
  """
57
  Retrieves and caches the initialization segment.
58
 
 
60
  init_url (str): The URL of the initialization segment.
61
  headers (dict): The headers to include in the request.
62
  verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.
63
+ use_request_proxy (bool, optional): Whether to use the proxy configuration from the user's MediaFlow config. Defaults to True.
64
 
65
  Returns:
66
  bytes: The initialization segment content.
67
  """
68
  if init_url not in init_segment_cache:
69
+ init_content = await download_file_with_retry(
70
+ init_url, headers, verify_ssl=verify_ssl, use_request_proxy=use_request_proxy
71
+ )
72
  init_segment_cache[init_url] = init_content
73
  return init_segment_cache[init_url]
mediaflow_proxy/utils/http_utils.py CHANGED
@@ -1,5 +1,6 @@
1
  import logging
2
  import typing
 
3
  from functools import partial
4
  from urllib import parse
5
 
@@ -137,7 +138,13 @@ class Streamer:
137
  await self.client.aclose()
138
 
139
 
140
- async def download_file_with_retry(url: str, headers: dict, timeout: float = 10.0, verify_ssl: bool = True):
 
 
 
 
 
 
141
  """
142
  Downloads a file with retry logic.
143
 
@@ -146,6 +153,7 @@ async def download_file_with_retry(url: str, headers: dict, timeout: float = 10.
146
  headers (dict): The headers to include in the request.
147
  timeout (float, optional): The request timeout. Defaults to 10.0.
148
  verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.
 
149
 
150
  Returns:
151
  bytes: The downloaded file content.
@@ -154,7 +162,10 @@ async def download_file_with_retry(url: str, headers: dict, timeout: float = 10.
154
  DownloadError: If the download fails after retries.
155
  """
156
  async with httpx.AsyncClient(
157
- follow_redirects=True, timeout=timeout, proxy=settings.proxy_url, verify=verify_ssl
 
 
 
158
  ) as client:
159
  try:
160
  response = await fetch_with_retry(client, "GET", url, headers)
@@ -166,7 +177,9 @@ async def download_file_with_retry(url: str, headers: dict, timeout: float = 10.
166
  raise DownloadError(502, f"Failed to download file: {e.last_attempt.result()}")
167
 
168
 
169
- async def request_with_retry(method: str, url: str, headers: dict, timeout: float = 10.0, **kwargs):
 
 
170
  """
171
  Sends an HTTP request with retry logic.
172
 
@@ -175,6 +188,7 @@ async def request_with_retry(method: str, url: str, headers: dict, timeout: floa
175
  url (str): The URL to send the request to.
176
  headers (dict): The headers to include in the request.
177
  timeout (float, optional): The request timeout. Defaults to 10.0.
 
178
  **kwargs: Additional arguments to pass to the request.
179
 
180
  Returns:
@@ -183,7 +197,9 @@ async def request_with_retry(method: str, url: str, headers: dict, timeout: floa
183
  Raises:
184
  DownloadError: If the request fails after retries.
185
  """
186
- async with httpx.AsyncClient(follow_redirects=True, timeout=timeout, proxy=settings.proxy_url) as client:
 
 
187
  try:
188
  response = await fetch_with_retry(client, method, url, headers, **kwargs)
189
  return response
@@ -198,6 +214,7 @@ def encode_mediaflow_proxy_url(
198
  destination_url: str | None = None,
199
  query_params: dict | None = None,
200
  request_headers: dict | None = None,
 
201
  ) -> str:
202
  """
203
  Encodes a MediaFlow proxy URL with query parameters and headers.
@@ -208,6 +225,7 @@ def encode_mediaflow_proxy_url(
208
  destination_url (str, optional): The destination URL to include in the query parameters. Defaults to None.
209
  query_params (dict, optional): Additional query parameters to include. Defaults to None.
210
  request_headers (dict, optional): Headers to include as query parameters. Defaults to None.
 
211
 
212
  Returns:
213
  str: The encoded MediaFlow proxy URL.
@@ -221,6 +239,10 @@ def encode_mediaflow_proxy_url(
221
  query_params.update(
222
  {key if key.startswith("h_") else f"h_{key}": value for key, value in request_headers.items()}
223
  )
 
 
 
 
224
  # Encode the query parameters
225
  encoded_params = parse.urlencode(query_params, quote_via=parse.quote)
226
 
@@ -263,7 +285,13 @@ def get_original_scheme(request: Request) -> str:
263
  return "http"
264
 
265
 
266
- def get_proxy_headers(request: Request) -> dict:
 
 
 
 
 
 
267
  """
268
  Extracts proxy headers from the request query parameters.
269
 
@@ -271,11 +299,12 @@ def get_proxy_headers(request: Request) -> dict:
271
  request (Request): The incoming HTTP request.
272
 
273
  Returns:
274
- dict: A dictionary of proxy headers.
275
  """
276
  request_headers = {k: v for k, v in request.headers.items() if k in SUPPORTED_REQUEST_HEADERS}
277
  request_headers.update({k[2:].lower(): v for k, v in request.query_params.items() if k.startswith("h_")})
278
- return request_headers
 
279
 
280
 
281
  class EnhancedStreamingResponse(Response):
 
1
  import logging
2
  import typing
3
+ from dataclasses import dataclass
4
  from functools import partial
5
  from urllib import parse
6
 
 
138
  await self.client.aclose()
139
 
140
 
141
+ async def download_file_with_retry(
142
+ url: str,
143
+ headers: dict,
144
+ timeout: float = 10.0,
145
+ verify_ssl: bool = True,
146
+ use_request_proxy: bool = True,
147
+ ):
148
  """
149
  Downloads a file with retry logic.
150
 
 
153
  headers (dict): The headers to include in the request.
154
  timeout (float, optional): The request timeout. Defaults to 10.0.
155
  verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.
156
+ use_request_proxy (bool, optional): Whether to use the proxy configuration from the user's MediaFlow config. Defaults to True.
157
 
158
  Returns:
159
  bytes: The downloaded file content.
 
162
  DownloadError: If the download fails after retries.
163
  """
164
  async with httpx.AsyncClient(
165
+ follow_redirects=True,
166
+ timeout=timeout,
167
+ proxy=settings.proxy_url if use_request_proxy else None,
168
+ verify=verify_ssl,
169
  ) as client:
170
  try:
171
  response = await fetch_with_retry(client, "GET", url, headers)
 
177
  raise DownloadError(502, f"Failed to download file: {e.last_attempt.result()}")
178
 
179
 
180
+ async def request_with_retry(
181
+ method: str, url: str, headers: dict, timeout: float = 10.0, use_request_proxy: bool = True, **kwargs
182
+ ):
183
  """
184
  Sends an HTTP request with retry logic.
185
 
 
188
  url (str): The URL to send the request to.
189
  headers (dict): The headers to include in the request.
190
  timeout (float, optional): The request timeout. Defaults to 10.0.
191
+ use_request_proxy (bool, optional): Whether to use the proxy configuration from the user's MediaFlow config. Defaults to True.
192
  **kwargs: Additional arguments to pass to the request.
193
 
194
  Returns:
 
197
  Raises:
198
  DownloadError: If the request fails after retries.
199
  """
200
+ async with httpx.AsyncClient(
201
+ follow_redirects=True, timeout=timeout, proxy=settings.proxy_url if use_request_proxy else None
202
+ ) as client:
203
  try:
204
  response = await fetch_with_retry(client, method, url, headers, **kwargs)
205
  return response
 
214
  destination_url: str | None = None,
215
  query_params: dict | None = None,
216
  request_headers: dict | None = None,
217
+ response_headers: dict | None = None,
218
  ) -> str:
219
  """
220
  Encodes a MediaFlow proxy URL with query parameters and headers.
 
225
  destination_url (str, optional): The destination URL to include in the query parameters. Defaults to None.
226
  query_params (dict, optional): Additional query parameters to include. Defaults to None.
227
  request_headers (dict, optional): Headers to include as query parameters. Defaults to None.
228
+ response_headers (dict, optional): Headers to include as query parameters. Defaults to None.
229
 
230
  Returns:
231
  str: The encoded MediaFlow proxy URL.
 
239
  query_params.update(
240
  {key if key.startswith("h_") else f"h_{key}": value for key, value in request_headers.items()}
241
  )
242
+ if response_headers:
243
+ query_params.update(
244
+ {key if key.startswith("r_") else f"r_{key}": value for key, value in response_headers.items()}
245
+ )
246
  # Encode the query parameters
247
  encoded_params = parse.urlencode(query_params, quote_via=parse.quote)
248
 
 
285
  return "http"
286
 
287
 
288
+ @dataclass
289
+ class ProxyRequestHeaders:
290
+ request: dict
291
+ response: dict
292
+
293
+
294
+ def get_proxy_headers(request: Request) -> ProxyRequestHeaders:
295
  """
296
  Extracts proxy headers from the request query parameters.
297
 
 
299
  request (Request): The incoming HTTP request.
300
 
301
  Returns:
302
+ ProxyRequest: A named tuple containing the request headers and response headers.
303
  """
304
  request_headers = {k: v for k, v in request.headers.items() if k in SUPPORTED_REQUEST_HEADERS}
305
  request_headers.update({k[2:].lower(): v for k, v in request.query_params.items() if k.startswith("h_")})
306
+ response_headers = {k[2:].lower(): v for k, v in request.query_params.items() if k.startswith("r_")}
307
+ return ProxyRequestHeaders(request_headers, response_headers)
308
 
309
 
310
  class EnhancedStreamingResponse(Response):