Lazyhope committed on
Commit
2ab1f7c
·
1 Parent(s): ea248b5

Add license to extracted information; add UnicodeDecodeError handling and other minor updates

Browse files
Files changed (1) hide show
  1. pipeline.py +32 -29
pipeline.py CHANGED
@@ -1,7 +1,6 @@
1
  import ast
2
  import tarfile
3
  from ast import AsyncFunctionDef, ClassDef, FunctionDef, Module
4
- from io import BytesIO
5
 
6
  import numpy as np
7
  import requests
@@ -42,55 +41,59 @@ def extract_code_and_docs(text: str):
42
  return code_set, docs_set
43
 
44
 
def get_topics(repo_name, headers=None):
    """Return the list of topics GitHub reports for a repository.

    This is a best-effort lookup: any failure to fetch or decode the
    repository metadata is printed and an empty list is returned, so the
    caller can continue with the remaining repositories.

    Args:
        repo_name: Repository identifier interpolated into the GitHub API URL
            (presumably in ``owner/name`` form -- confirm against callers).
        headers: Optional HTTP headers forwarded to ``requests.get``.

    Returns:
        The value of the ``"topics"`` key of the metadata, or ``[]`` when the
        key is missing or the request fails.
    """
    api_url = f"https://api.github.com/repos/{repo_name}"
    print(f"[+] Getting topics for {repo_name}")
    try:
        response = requests.get(api_url, headers=headers)
        response.raise_for_status()
        metadata = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # RequestException is the base class of HTTPError, ConnectionError and
        # Timeout; ValueError covers a body that is not valid JSON. Catching
        # only HTTPError let network failures escape this best-effort helper.
        print(f"[-] Failed to get topics for {repo_name}: {e}")
        return []

    topics = metadata.get("topics", [])
    if topics:
        print(f"[+] Topics found for {repo_name}: {topics}")

    return topics
 
 
 
61
 
62
 
def download_and_extract(repos, headers=None):
    """Download each repository tarball and collect code and docs from it.

    For every repository an entry with ``"funcs"``, ``"docs"`` and
    ``"topics"`` is created up front, so a repository whose download fails
    still appears in the result with empty sets.

    Args:
        repos: Iterable of repository identifiers used in the GitHub API URL.
        headers: Optional HTTP headers forwarded to every request.

    Returns:
        A dict mapping each repository name to its
        ``{"funcs": set, "docs": set, "topics": list}`` entry.
    """
    extracted_info = {}
    for repo_name in repos:
        repo_info = {
            "funcs": set(),
            "docs": set(),
            "topics": get_topics(repo_name, headers=headers),
        }
        extracted_info[repo_name] = repo_info

        download_url = f"https://api.github.com/repos/{repo_name}/tarball"
        print(f"[+] Extracting functions and docstrings from {repo_name}")
        try:
            response = requests.get(download_url, headers=headers, stream=True)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            # Base class of HTTPError/ConnectionError/Timeout: one unreachable
            # repository must not abort the whole batch.
            print(f"[-] Failed to download {repo_name}: {e}")
            continue

        repo_bytes = BytesIO(response.raw.read())
        print(f"[+] Extracting {repo_name} info")
        with tarfile.open(fileobj=repo_bytes) as tar:
            for member in tar.getmembers():
                if not (member.isfile() and member.name.endswith(".py")):
                    continue
                try:
                    # Decoding lives inside the try: a single non-UTF-8 file
                    # used to raise UnicodeDecodeError and stop everything.
                    file_content = tar.extractfile(member).read().decode("utf-8")
                    code_set, docs_set = extract_code_and_docs(file_content)
                except UnicodeDecodeError as e:
                    print(f"[-] UnicodeDecodeError in {member.name}: {e}, skipping")
                    continue
                except SyntaxError as e:
                    print(f"[-] SyntaxError in {member.name}: {e}, skipping")
                    continue
                repo_info["funcs"].update(code_set)
                repo_info["docs"].update(docs_set)

    return extracted_info
96
 
@@ -133,7 +136,7 @@ class RepoEmbeddingPipeline(Pipeline):
133
 
134
  def preprocess(self, inputs):
135
  if isinstance(inputs, str):
136
- inputs = (inputs,)
137
 
138
  if self.st_messager:
139
  self.st_messager.info("[*] Downloading and extracting repos...")
@@ -175,7 +178,7 @@ class RepoEmbeddingPipeline(Pipeline):
175
  with tqdm(total=num_texts) as pbar:
176
  for repo_name, repo_info in extracted_infos.items():
177
  pbar.set_description(f"Processing {repo_name}")
178
- entry = {"topics": repo_info.get("topics")}
179
 
180
  message = f"[*] Generating embeddings for {repo_name}"
181
  tqdm.write(message)
@@ -187,7 +190,7 @@ class RepoEmbeddingPipeline(Pipeline):
187
  code_embeddings.append(
188
  [func, self.encode(func, max_length).squeeze().tolist()]
189
  )
190
-
191
  pbar.update(1)
192
  if st_progress:
193
  st_progress.progress(pbar.n / pbar.total)
 
1
  import ast
2
  import tarfile
3
  from ast import AsyncFunctionDef, ClassDef, FunctionDef, Module
 
4
 
5
  import numpy as np
6
  import requests
 
41
  return code_set, docs_set
42
 
43
 
def get_metadata(repo_name, headers=None):
    """Fetch the metadata JSON of a repository from the GitHub REST API.

    This is a best-effort lookup: any failure to fetch or decode the response
    is reported through ``tqdm.write`` and an empty dict is returned, so the
    caller can keep processing other repositories.

    Args:
        repo_name: Repository identifier interpolated into the GitHub API URL
            (presumably in ``owner/name`` form -- confirm against callers).
        headers: Optional HTTP headers forwarded to ``requests.get``.

    Returns:
        The decoded JSON object on success, otherwise ``{}``.
    """
    api_url = f"https://api.github.com/repos/{repo_name}"
    tqdm.write(f"[+] Getting metadata for {repo_name}")
    try:
        response = requests.get(api_url, headers=headers)
        response.raise_for_status()
        metadata = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # RequestException is the base class of HTTPError, ConnectionError and
        # Timeout; ValueError covers a body that is not valid JSON. Catching
        # only HTTPError let network failures escape this best-effort helper.
        tqdm.write(f"[-] Failed to retrieve metadata from {repo_name}: {e}")
        return {}

    # Callers use ``.get`` on the result, so never hand back a non-dict payload.
    return metadata if isinstance(metadata, dict) else {}
55
 
56
 
def download_and_extract(repos, headers=None):
    """Download each repository tarball and collect code, docs and metadata.

    For every repository an entry with ``"funcs"``, ``"docs"``, ``"topics"``
    and ``"license"`` is created up front, so a repository whose download
    fails still appears in the result with empty sets.

    Args:
        repos: Sized collection of repository identifiers (``len`` is used to
            decide whether the progress bar is shown).
        headers: Optional HTTP headers forwarded to every request.

    Returns:
        A dict mapping each repository name to its
        ``{"funcs": set, "docs": set, "topics": list, "license": str | None}``
        entry.
    """
    extracted_info = {}
    for repo_name in tqdm(repos, disable=len(repos) <= 1):
        # Get metadata
        metadata = get_metadata(repo_name, headers=headers)
        # GitHub sends ``"license": null`` for unlicensed repositories, so the
        # key exists and ``.get("license", {})`` would yield None; ``or {}``
        # keeps the follow-up lookup from raising AttributeError.
        license_info = metadata.get("license") or {}
        repo_info = {
            "funcs": set(),
            "docs": set(),
            "topics": metadata.get("topics", []),
            "license": license_info.get("spdx_id"),
        }
        extracted_info[repo_name] = repo_info

        # Download repo tarball bytes
        download_url = f"https://api.github.com/repos/{repo_name}/tarball"
        tqdm.write(f"[+] Downloading {repo_name}")
        try:
            response = requests.get(download_url, headers=headers, stream=True)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            # Base class of HTTPError/ConnectionError/Timeout: one unreachable
            # repository must not abort the whole batch.
            tqdm.write(f"[-] Failed to download {repo_name}: {e}")
            continue

        # Extract python files and parse them
        tqdm.write(f"[+] Extracting {repo_name} info")
        # ``with response`` releases the streamed connection once the archive
        # has been consumed; "r|gz" reads the tarball as a forward-only stream.
        with response, tarfile.open(fileobj=response.raw, mode="r|gz") as tar:
            for member in tar:
                if not (member.isfile() and member.name.endswith(".py")):
                    continue
                try:
                    file_content = tar.extractfile(member).read().decode("utf-8")
                    code_set, docs_set = extract_code_and_docs(file_content)
                except UnicodeDecodeError as e:
                    tqdm.write(
                        f"[-] UnicodeDecodeError in {member.name}, skipping: \n{e}"
                    )
                    continue
                except SyntaxError as e:
                    tqdm.write(f"[-] SyntaxError in {member.name}, skipping: \n{e}")
                    continue
                repo_info["funcs"].update(code_set)
                repo_info["docs"].update(docs_set)

    return extracted_info
99
 
 
136
 
137
  def preprocess(self, inputs):
138
  if isinstance(inputs, str):
139
+ inputs = [inputs]
140
 
141
  if self.st_messager:
142
  self.st_messager.info("[*] Downloading and extracting repos...")
 
178
  with tqdm(total=num_texts) as pbar:
179
  for repo_name, repo_info in extracted_infos.items():
180
  pbar.set_description(f"Processing {repo_name}")
181
+ entry = {"topics": repo_info["topics"], "license": repo_info["license"]}
182
 
183
  message = f"[*] Generating embeddings for {repo_name}"
184
  tqdm.write(message)
 
190
  code_embeddings.append(
191
  [func, self.encode(func, max_length).squeeze().tolist()]
192
  )
193
+
194
  pbar.update(1)
195
  if st_progress:
196
  st_progress.progress(pbar.n / pbar.total)