KonradSzafer commited on
Commit
bfdf8df
·
1 Parent(s): c6dce39

channel id added to config

Browse files
data/hugging_face_videos_dataset.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import time
4
+
5
+ import torch
6
+ import scrapetube
7
+ from pytube import YouTube
8
+ from faster_whisper import WhisperModel
9
+ from tqdm import tqdm
10
+
11
+
12
+ # Available models:
13
+ # tiny.en, tiny, base.en, base, small.en, small, medium.en, medium
14
+ # large-v1, large-v2, large-v3, large
15
+ MODEL_NAME = "large-v3"
16
+ AUDIO_SAVE_PATH = 'datasets/huggingface_audio/'
17
+ TRANSCRIPTS_SAVE_PATH = 'datasets/huggingface_audio_transcribed/'
18
+
19
+ if torch.cuda.is_available():
20
+ # requires: conda install -c anaconda cudnn
21
+ print(f"Using {MODEL_NAME} on GPU and float16")
22
+ model = WhisperModel(MODEL_NAME, device="cuda", compute_type="float16", device_index=[5])
23
+ else:
24
+ print(f"Using {MODEL_NAME} on CPU and int8")
25
+ model = WhisperModel(MODEL_NAME, device="cpu", compute_type="int8")
26
+
27
+
28
+ def replace_unallowed_chars(filename: str) -> str:
29
+ unallowed_chars = [' ', '/', '\\', ':', '*', '?', '"', '<', '>', '|']
30
+ for char in unallowed_chars:
31
+ filename = filename.replace(char, '_')
32
+ return filename
33
+
34
+
35
+ def get_videos_urls(channel_url: str) -> list[str]:
36
+ videos = scrapetube.get_channel(channel_url=channel_url)
37
+ return [
38
+ f"https://www.youtube.com/watch?v={video['videoId']}"
39
+ for video in videos
40
+ ]
41
+
42
+
43
+ def get_audio_from_video(video_url: str, save_path: str) -> tuple[str, int, str, int]:
44
+ yt = YouTube(video_url)
45
+ if check_if_file_exists(yt.title, save_path):
46
+ print(f'Audio already exists for: {yt.title}')
47
+ return (video_url, yt.title.replace(" ", "_")+".mp3", yt.title, yt.length)
48
+ else:
49
+ print(f'Downloading audio for: {yt.title}')
50
+ video = yt.streams.filter(only_audio=True).first()
51
+ out_file = video.download(output_path=save_path)
52
+ base, ext = os.path.splitext(out_file)
53
+ new_filename = save_path + replace_unallowed_chars(yt.title) + '.mp3'
54
+ print(f'Saving audio to: {new_filename}')
55
+ os.rename(out_file, new_filename)
56
+ print(f'Video length: {yt.length} seconds')
57
+ return (video_url, new_filename, yt.title, yt.length)
58
+
59
+
60
+ def check_if_file_exists(filename: str, save_path: str) -> bool:
61
+ title = filename.replace(' ', '_')
62
+ return any([
63
+ title in filename_
64
+ for filename_ in os.listdir(save_path)
65
+ ])
66
+
67
+
68
+ def transcript_from_audio(audio_path: str) -> dict[str, list[str]]:
69
+ segments, info = model.transcribe(audio_path, beam_size=10)
70
+ return list(segments)
71
+
72
+
73
+ def process_text(text: str) -> str:
74
+ text = text.strip()
75
+ text = re.sub('\s+', ' ', text)
76
+ return text
77
+
78
+
79
+ def merge_transcripts_segements(
80
+ segments: list[str],
81
+ file_title: str,
82
+ num_segments_to_merge: int = 5,
83
+ ) -> dict[str, list[str]]:
84
+
85
+ merged_segments = {}
86
+ temp_text = ''
87
+ start_time = None
88
+ end_time = None
89
+
90
+ for i, segment in enumerate(segments):
91
+ if i % num_segments_to_merge == 0:
92
+ start_time = segment.start
93
+ end_time = segment.end
94
+ temp_text += segment.text + ' '
95
+
96
+ if (i + 1) % num_segments_to_merge == 0 or i == len(segments) - 1:
97
+ key = f'{start_time:.2f}_{end_time:.2f}'
98
+ merged_segments[key] = process_text(temp_text)
99
+ temp_text = ''
100
+
101
+ return merged_segments
102
+
103
+
104
+ def main():
105
+ if not os.path.exists(AUDIO_SAVE_PATH):
106
+ os.makedirs(AUDIO_SAVE_PATH)
107
+ if not os.path.exists(TRANSCRIPTS_SAVE_PATH):
108
+ os.makedirs(TRANSCRIPTS_SAVE_PATH)
109
+
110
+ print('Getting videos urls')
111
+ videos_urls = get_videos_urls('https://www.youtube.com/@HuggingFace')
112
+
113
+ print('Downloading audio files')
114
+ audio_data = []
115
+ for video_url in tqdm(videos_urls):
116
+ try:
117
+ audio_data.append(
118
+ get_audio_from_video(video_url, save_path=AUDIO_SAVE_PATH)
119
+ )
120
+ except Exception as e:
121
+ print(f'Error downloading video: {video_url}')
122
+ print(e)
123
+
124
+ print('Transcribing audio files')
125
+ for video_url, filename, title, audio_length in tqdm(audio_data):
126
+ if check_if_file_exists(title, TRANSCRIPTS_SAVE_PATH):
127
+ print(f'Transcript already exists for: {title}')
128
+ continue
129
+ try:
130
+ print(f'Transcribing: {title}')
131
+ start_time = time.time()
132
+ segments = transcript_from_audio(filename)
133
+ print(f'Transcription took: {time.time() - start_time:.1f} seconds')
134
+ merged_segments = merge_transcripts_segements(
135
+ segments,
136
+ title,
137
+ num_segments_to_merge=10
138
+ )
139
+ # save transcripts to separate files
140
+ title = replace_unallowed_chars(title)
141
+ for segment, text in merged_segments.items():
142
+ with open(f'{TRANSCRIPTS_SAVE_PATH}{title}_{segment}.txt', 'w') as f:
143
+ video_url_with_time = f'{video_url}&t={float(segment.split("_")[0]):.0f}'
144
+ f.write(f'source: {video_url_with_time}\n\n' + text)
145
+ except Exception as e:
146
+ print(f'Error transcribing: {title}')
147
+ print(e)
148
+
149
+
150
+ if __name__ == '__main__':
151
+ main()
data/indexer.ipynb CHANGED
@@ -7,16 +7,18 @@
7
  "outputs": [],
8
  "source": [
9
  "import math\n",
10
- "import numpy as np\n",
11
  "from pathlib import Path\n",
 
 
 
12
  "from tqdm import tqdm\n",
13
- "from typing import List, Any\n",
14
  "from langchain.chains import RetrievalQA\n",
15
  "from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings\n",
16
  "from langchain.document_loaders import TextLoader\n",
17
  "from langchain.indexes import VectorstoreIndexCreator\n",
18
  "from langchain.text_splitter import CharacterTextSplitter\n",
19
- "from langchain.vectorstores import FAISS"
 
20
  ]
21
  },
22
  {
@@ -25,16 +27,32 @@
25
  "metadata": {},
26
  "outputs": [],
27
  "source": [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  "docs = []\n",
29
  "metadata = []\n",
30
- "for p in Path(\"./datasets/huggingface_docs/\").iterdir():\n",
31
- " if not p.is_dir():\n",
32
- " with open(p) as f:\n",
33
- " # the first line is the source of the text\n",
34
- " source = f.readline().strip().replace('source: ', '')\n",
35
- " docs.append(f.read())\n",
36
- " metadata.append({\"source\": source})\n",
37
- " # break\n",
38
  "\n",
39
  "print(f'number of documents: {len(docs)}')"
40
  ]
@@ -88,7 +106,7 @@
88
  " if self.max_length < 0:\n",
89
  " print('max_length is not specified, using model default max_seq_length')\n",
90
  "\n",
91
- " def embed_documents(self, texts: List[str]) -> List[List[float]]:\n",
92
  " all_embeddings = []\n",
93
  " for text in tqdm(texts, desc=\"Embedding documents\"):\n",
94
  " if len(text) > self.max_length and self.max_length > -1:\n",
@@ -109,7 +127,8 @@
109
  " return all_embeddings\n",
110
  "\n",
111
  "\n",
112
- "# max length fed to the model, if longer than max then chunks + averaging\n",
 
113
  "max_length = 512\n",
114
  "embedding_model = AverageInstructEmbeddings( \n",
115
  " model_name=model_name,\n",
@@ -143,8 +162,8 @@
143
  "metadata": {},
144
  "outputs": [],
145
  "source": [
146
- "index_name = f'index-{model_name}-{chunk_size}-m{max_length}-notebooks'\n",
147
- "index_name"
148
  ]
149
  },
150
  {
@@ -189,8 +208,6 @@
189
  "metadata": {},
190
  "outputs": [],
191
  "source": [
192
- "from huggingface_hub import HfApi\n",
193
- "\n",
194
  "api = HfApi()\n",
195
  "api.create_repo(\n",
196
  " repo_id=f'KonradSzafer/{index_name}',\n",
@@ -204,13 +221,6 @@
204
  " repo_type='dataset',\n",
205
  ")"
206
  ]
207
- },
208
- {
209
- "cell_type": "code",
210
- "execution_count": null,
211
- "metadata": {},
212
- "outputs": [],
213
- "source": []
214
  }
215
  ],
216
  "metadata": {
@@ -229,7 +239,7 @@
229
  "name": "python",
230
  "nbconvert_exporter": "python",
231
  "pygments_lexer": "ipython3",
232
- "version": "3.10.12"
233
  },
234
  "orig_nbformat": 4
235
  },
 
7
  "outputs": [],
8
  "source": [
9
  "import math\n",
 
10
  "from pathlib import Path\n",
11
+ "from typing import Any\n",
12
+ "\n",
13
+ "import numpy as np\n",
14
  "from tqdm import tqdm\n",
 
15
  "from langchain.chains import RetrievalQA\n",
16
  "from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings\n",
17
  "from langchain.document_loaders import TextLoader\n",
18
  "from langchain.indexes import VectorstoreIndexCreator\n",
19
  "from langchain.text_splitter import CharacterTextSplitter\n",
20
+ "from langchain.vectorstores import FAISS\n",
21
+ "from huggingface_hub import HfApi"
22
  ]
23
  },
24
  {
 
27
  "metadata": {},
28
  "outputs": [],
29
  "source": [
30
+ "def collect_docs(directory: str, docs: list[str], metadata: list[Any]):\n",
31
+ " for p in Path(directory).iterdir():\n",
32
+ " if not p.is_dir():\n",
33
+ " with open(p) as f:\n",
34
+ " # the first line is the source of the text\n",
35
+ " source = f.readline().strip().replace('source: ', '')\n",
36
+ " docs.append(f.read())\n",
37
+ " metadata.append({\"source\": source})\n",
38
+ " # break"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": null,
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
+ "DIRECTORIES = [\n",
48
+ " \"./datasets/huggingface_docs/\",\n",
49
+ " \"./datasets/huggingface_audio_transcribed/\"\n",
50
+ "]\n",
51
+ "\n",
52
  "docs = []\n",
53
  "metadata = []\n",
54
+ "for directory in DIRECTORIES:\n",
55
+ " collect_docs(directory, docs, metadata)\n",
 
 
 
 
 
 
56
  "\n",
57
  "print(f'number of documents: {len(docs)}')"
58
  ]
 
106
  " if self.max_length < 0:\n",
107
  " print('max_length is not specified, using model default max_seq_length')\n",
108
  "\n",
109
+ " def embed_documents(self, texts: list[str]) -> list[list[float]]:\n",
110
  " all_embeddings = []\n",
111
  " for text in tqdm(texts, desc=\"Embedding documents\"):\n",
112
  " if len(text) > self.max_length and self.max_length > -1:\n",
 
127
  " return all_embeddings\n",
128
  "\n",
129
  "\n",
130
+ "# max length fed to the mode\n",
131
+ "# if longer than CHUNK_SIZE in previous steps: then N chunks + averaging of embeddings\n",
132
  "max_length = 512\n",
133
  "embedding_model = AverageInstructEmbeddings( \n",
134
  " model_name=model_name,\n",
 
162
  "metadata": {},
163
  "outputs": [],
164
  "source": [
165
+ "index_name = f'index-{model_name}-{chunk_size}-m{max_length}-11_Jan_2024'\n",
166
+ "index_name = index_name.replace('/', '_')"
167
  ]
168
  },
169
  {
 
208
  "metadata": {},
209
  "outputs": [],
210
  "source": [
 
 
211
  "api = HfApi()\n",
212
  "api.create_repo(\n",
213
  " repo_id=f'KonradSzafer/{index_name}',\n",
 
221
  " repo_type='dataset',\n",
222
  ")"
223
  ]
 
 
 
 
 
 
 
224
  }
225
  ],
226
  "metadata": {
 
239
  "name": "python",
240
  "nbconvert_exporter": "python",
241
  "pygments_lexer": "ipython3",
242
+ "version": "3.11.5"
243
  },
244
  "orig_nbformat": 4
245
  },
data/requirements-audio.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ scrapetube>=2.5.1
2
+ pytube>=15.0.0
3
+ faster-whisper>=0.10.0
4
+ torch>=2.0.1
5
+ tqdm>=4.66.1
discord_bot/__main__.py CHANGED
@@ -16,6 +16,7 @@ qa_engine = QAEngine(
16
  )
17
  client = DiscordClient(
18
  qa_engine=qa_engine,
 
19
  num_last_messages=config.num_last_messages,
20
  use_names_in_context=config.use_names_in_context,
21
  enable_commands=config.enable_commands,
 
16
  )
17
  client = DiscordClient(
18
  qa_engine=qa_engine,
19
+ channel_ids=config.discotd_channel_ids,
20
  num_last_messages=config.num_last_messages,
21
  use_names_in_context=config.use_names_in_context,
22
  enable_commands=config.enable_commands,
discord_bot/client/client.py CHANGED
@@ -31,6 +31,7 @@ class DiscordClient(discord.Client):
31
  def __init__(
32
  self,
33
  qa_engine: QAEngine,
 
34
  num_last_messages: int = 5,
35
  use_names_in_context: bool = True,
36
  enable_commands: bool = True,
@@ -45,6 +46,7 @@ class DiscordClient(discord.Client):
45
  'The number of last messages in context should be at least 1'
46
 
47
  self.qa_engine: QAEngine = qa_engine
 
48
  self.num_last_messages: int = num_last_messages
49
  self.use_names_in_context: bool = use_names_in_context
50
  self.enable_commands: bool = enable_commands
@@ -98,38 +100,34 @@ class DiscordClient(discord.Client):
98
 
99
 
100
  async def on_message(self, message):
101
- if message.channel.id == 1162396480825462935:
102
- """
103
- Callback function to be called when a message is received.
104
-
105
- Args:
106
- message (discord.Message): The received message.
107
- """
108
- if message.author == self.user:
109
- return
110
-
111
- """
112
- if self.enable_commands and message.content.startswith('!'):
113
- if message.content == '!clear':
114
- await message.channel.purge()
115
- return
116
- """
117
-
118
-
119
- last_messages = await self.get_last_messages(message)
120
- context = '\n'.join(last_messages)
121
-
122
- logger.info('Received message: {0.content}'.format(message))
123
- response = self.qa_engine.get_response(
124
- question=message.content,
125
- messages_context=context
 
 
 
126
  )
127
- logger.info('Sending response: {0}'.format(response))
128
- try:
129
- await self.send_message(
130
- message,
131
- response.get_answer(),
132
- response.get_sources_as_text()
133
- )
134
- except Exception as e:
135
- logger.error('Failed to send response: {0}'.format(e))
 
31
  def __init__(
32
  self,
33
  qa_engine: QAEngine,
34
+ channel_ids: list[int] = [],
35
  num_last_messages: int = 5,
36
  use_names_in_context: bool = True,
37
  enable_commands: bool = True,
 
46
  'The number of last messages in context should be at least 1'
47
 
48
  self.qa_engine: QAEngine = qa_engine
49
+ self.channel_ids: list[int] = channel_ids
50
  self.num_last_messages: int = num_last_messages
51
  self.use_names_in_context: bool = use_names_in_context
52
  self.enable_commands: bool = enable_commands
 
100
 
101
 
102
  async def on_message(self, message):
103
+
104
+ if self.channel_ids and message.channel.id not in self.channel_ids:
105
+ return
106
+
107
+ if message.author == self.user:
108
+ return
109
+
110
+ """
111
+ if self.enable_commands and message.content.startswith('!'):
112
+ if message.content == '!clear':
113
+ await message.channel.purge()
114
+ return
115
+ """
116
+
117
+ last_messages = await self.get_last_messages(message)
118
+ context = '\n'.join(last_messages)
119
+
120
+ logger.info('Received message: {0.content}'.format(message))
121
+ response = self.qa_engine.get_response(
122
+ question=message.content,
123
+ messages_context=context
124
+ )
125
+ logger.info('Sending response: {0}'.format(response))
126
+ try:
127
+ await self.send_message(
128
+ message,
129
+ response.get_answer(),
130
+ response.get_sources_as_text()
131
  )
132
+ except Exception as e:
133
+ logger.error('Failed to send response: {0}'.format(e))
 
 
 
 
 
 
 
qa_engine/config.py CHANGED
@@ -36,6 +36,7 @@ class Config:
36
 
37
  # Discord bot config - optional
38
  discord_token: str = get_env('DISCORD_TOKEN', '-', warn=False)
 
39
  num_last_messages: int = int(get_env('NUM_LAST_MESSAGES', 2, warn=False))
40
  use_names_in_context: bool = eval(get_env('USE_NAMES_IN_CONTEXT', 'False', warn=False))
41
  enable_commands: bool = eval(get_env('ENABLE_COMMANDS', 'True', warn=False))
 
36
 
37
  # Discord bot config - optional
38
  discord_token: str = get_env('DISCORD_TOKEN', '-', warn=False)
39
+ discotd_channel_ids: list[int] = eval(get_env('DISCORD_CHANNEL_IDS', [], warn=False))
40
  num_last_messages: int = int(get_env('NUM_LAST_MESSAGES', 2, warn=False))
41
  use_names_in_context: bool = eval(get_env('USE_NAMES_IN_CONTEXT', 'False', warn=False))
42
  enable_commands: bool = eval(get_env('ENABLE_COMMANDS', 'True', warn=False))