File size: 3,699 Bytes
d1ed09d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import pickle
import requests
import umap
from numba.typed import List
import torch
from sentence_transformers import SentenceTransformer
import time
from pathlib import Path

def check_resources(files_dict, basemap_path, mapper_params_path):
    """
    Report whether every resource the pipeline needs exists on disk.

    Every missing path is printed; the check does not stop at the first
    missing item, so one call reports all gaps.

    Args:
        files_dict (dict): Mapping of local filename -> download URL. Only the
            keys (local filenames) are checked.
        basemap_path (str): Location of the basemap pickle file.
        mapper_params_path (str): Location of the pickled UMAP mapper parameters.

    Returns:
        bool: True when nothing is missing, False if any path is absent.
    """
    # Local filenames (dict keys) that are not on disk, in dict order.
    absent_downloads = [name for name in files_dict if not Path(name).exists()]
    for name in absent_downloads:
        print(f"Missing file: {name}")

    basemap_ok = Path(basemap_path).exists()
    if not basemap_ok:
        print(f"Missing basemap file: {basemap_path}")

    mapper_ok = Path(mapper_params_path).exists()
    if not mapper_ok:
        print(f"Missing mapper params file: {mapper_params_path}")

    return not absent_downloads and basemap_ok and mapper_ok

def download_required_files(files_dict, timeout=60):
    """
    Download required files from URLs only if they don't already exist locally.

    Each file is streamed to a sibling ``<name>.part`` file. That file is moved
    onto the final name only after the whole response body was received with
    a success status. This way an HTTP error page or an interrupted transfer
    never ends up under the final filename. Such a file would otherwise pass
    every later ``Path(filename).exists()`` check and never be re-downloaded.

    Args:
        files_dict (dict): Dictionary mapping local filenames to their download URLs.
        timeout (float): Seconds to wait for the server to connect / send data
            before giving up (passed to ``requests.get``). Defaults to 60.

    Raises:
        requests.HTTPError: If a server answers with a 4xx/5xx status.
        requests.RequestException: On connection failures or timeouts.
    """
    print(f"Checking required files: {time.strftime('%Y-%m-%d %H:%M:%S')}")

    files_to_download = {
        filename: url
        for filename, url in files_dict.items()
        if not Path(filename).exists()
    }

    if not files_to_download:
        print("All files already present, skipping downloads")
        return

    print(f"Downloading missing files: {list(files_to_download.keys())}")
    for filename, url in files_to_download.items():
        print(f"Downloading {filename}...")
        target = Path(filename)
        partial = target.with_name(target.name + ".part")
        try:
            # stream=True avoids holding the whole (possibly large) file in memory.
            with requests.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                with open(partial, "wb") as f:
                    for chunk in response.iter_content(chunk_size=1 << 20):
                        f.write(chunk)
            # Atomic on the same filesystem: the final name only ever holds a
            # complete download.
            partial.replace(target)
        finally:
            # After a successful replace the .part file is gone; otherwise
            # discard whatever partial data was written before the failure.
            if partial.exists():
                partial.unlink()

def setup_basemap_data(basemap_path):
    """
    Load and set up the base map data from a pickle file.

    Args:
        basemap_path (str): Path to the basemap pickle file.

    Returns:
        The object unpickled from ``basemap_path``. It is named
        ``basedata_df`` here, so presumably a pandas DataFrame; confirm
        against whatever produced the file.
    """
    print(f"Getting basemap data: {time.strftime('%Y-%m-%d %H:%M:%S')}")
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file; only load basemaps obtained from a trusted source.
    # The context manager closes the file handle (the original leaked it).
    with open(basemap_path, 'rb') as f:
        basedata_df = pickle.load(f)
    return basedata_df

def setup_mapper(mapper_params_path):
    """
    Set up and configure a UMAP mapper from previously saved parameters.

    The pickle is read as a dict with two optional keys:
      - ``'umap_params'``: keyword arguments applied via ``mapper.set_params``.
        ``'target_backend'`` is dropped first.
      - ``'umap_attributes'``: attributes assigned directly onto the mapper
        with ``setattr``. ``'embedding_'`` is handled separately and wrapped
        in a numba typed ``List``.

    Args:
        mapper_params_path (str): Path to the UMAP mapper parameters pickle file.

    Returns:
        umap.UMAP: The configured mapper.
    """
    print(f"Getting Mapper: {time.strftime('%Y-%m-%d %H:%M:%S')}")

    # NOTE(security): pickle.load can execute arbitrary code; only load
    # parameter files from a trusted source. The context manager closes the
    # handle (the original leaked it).
    with open(mapper_params_path, 'rb') as f:
        params_new = pickle.load(f)
    print("setting up mapper...")
    mapper = umap.UMAP()

    # NOTE(review): 'target_backend' is presumably not accepted by the
    # installed umap's set_params; confirm why it is excluded.
    umap_params = {k: v for k, v in params_new.get('umap_params', {}).items()
                   if k != 'target_backend'}
    mapper.set_params(**umap_params)

    umap_attributes = params_new.get('umap_attributes', {})
    for attr, value in umap_attributes.items():
        if attr != 'embedding_':
            setattr(mapper, attr, value)

    if 'embedding_' in umap_attributes:
        # NOTE(review): the embedding is stored as a numba typed List rather
        # than as the saved object itself; confirm this is what downstream
        # code expects.
        mapper.embedding_ = List(umap_attributes['embedding_'])

    return mapper

def setup_embedding_model(model_name, device=None):
    """
    Set up the SentenceTransformer model on an explicit device.

    Args:
        model_name (str): Name or path of the SentenceTransformer model.
        device (str | torch.device | None): Device to load the model onto.
            When None (the default), CUDA is used if available, otherwise CPU.

    Returns:
        SentenceTransformer: The loaded model.
    """
    print(f"Setting up language model: {time.strftime('%Y-%m-%d %H:%M:%S')}")
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Pass the device explicitly. Previously it was only printed, so the log
    # could disagree with where SentenceTransformer actually placed the model.
    model = SentenceTransformer(model_name, device=str(device))
    return model