Can I run this model effectively in google colab?

#35
by k0rruptt - opened

Hi, I have been playing around with this model in google colab since it's easy for me to set up the environment and I am familiar with it. However, I am getting poor results with respect to the retrieval. I am not sure if it's because my configuration is bad, my dataset is too large, something else, or a collection of the above.

I saw this discussion thread which led me to believe that maybe I cannot use the model on google's environment offerings: https://huggingface.co/dunzhang/stella_en_1.5B_v5/discussions/23

Below is the code I use to embed my data to pinecone:


!pip install transformers
!pip install torch
!pip install scikit-learn
!pip install beautifulsoup4
!pip install pinecone-client
!pip install bs4 lxml

import os
import re
import sqlite3
from bs4 import BeautifulSoup
import torch
from transformers import AutoModel, AutoTokenizer
import pinecone
from pinecone import Pinecone, ServerlessSpec, PineconeApiException
from getpass import getpass
import time

# Constants
DB_NAME = 'data.sqlite'
TABLE_NAME = 'tablename'
HTML_COLUMN_NAME = 'html'
HASH_COLUMN_NAME = 'hash'
INDEX_NAME = 'test'
CHUNK_SIZE = 64
CHUNK_OVERLAP = 0
MODEL_DIR = 'dunzhang/stella_en_1.5B_v5'

# Initialize model and tokenizer globally
print("Initializing model and tokenizer...")
MODEL = AutoModel.from_pretrained(MODEL_DIR).cuda().eval()
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_DIR)
print("Model and tokenizer initialized.")

def clean_html(conn, row_id):
    with conn:
        c = conn.cursor()
        c.execute(f"SELECT {HTML_COLUMN_NAME}, {HASH_COLUMN_NAME} FROM {TABLE_NAME} WHERE rowid = ?", (row_id,))

        row = c.fetchone()
        soup = BeautifulSoup(row[0], 'lxml')

        # Remove unwanted tags
        for tag in soup.find_all(['nav', 'aside', 'footer', 'header']):
            tag.decompose()

        # Convert the soup to text
        full_text = soup.get_text(separator='\n', strip=True)

        # Split the text into lines
        lines = full_text.split('\n')

        # Find the index of the line that starts the references section
        reference_start_index = -1
        for i, line in enumerate(lines):
            # Look for patterns like "References", "Reference List", "Works Cited", etc.
            # with the first letter of each word capitalized
            if re.match(r'^(References|Reference List|Works Cited|Bibliography)(\s*:)?$', line.strip()):
                reference_start_index = i
                break

        # If we found the references section, remove it and everything after
        if reference_start_index != -1:
            cleaned_text = '\n'.join(lines[:reference_start_index])
        else:
            cleaned_text = full_text

        # Remove any remaining citation numbers (assuming they're in square brackets)
        cleaned_text = re.sub(r'\[\d+\]', '', cleaned_text)

        return cleaned_text.strip(), row[1]

def embed_text(ids, texts):
    vectors = []
    new_ids = []
    metadata = []
    try:
        for id, text in zip(ids, texts):
            input_data = TOKENIZER(text, truncation=False, padding=False, return_tensors='pt')
            input_ids = input_data['input_ids'][0]

            chunks = []
            start_idx = 0
            while start_idx < len(input_ids):
                end_idx = min(start_idx + CHUNK_SIZE, len(input_ids))
                chunk = input_ids[start_idx:end_idx]
                chunks.append(chunk)
                start_idx = end_idx - CHUNK_OVERLAP

            for i, chunk in enumerate(chunks):
                with torch.no_grad():
                    input_data = {'input_ids': chunk.unsqueeze(0).cuda()}
                    attention_mask = torch.ones_like(input_data['input_ids'])
                    last_hidden_state = MODEL(**input_data)[0]
                    last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
                    vector = (last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]).cpu().numpy()
                    vectors.append(vector.flatten().tolist())

                    chunk_id = f'{id}_chunk_{i}'
                    new_ids.append(chunk_id)

                    # Create window metadata
                    window = []
                    if i > 0:
                        window.append(TOKENIZER.decode(chunks[i-1]))
                    window.append(TOKENIZER.decode(chunk))
                    if i < len(chunks) - 1:
                        window.append(TOKENIZER.decode(chunks[i+1]))
                    metadata.append({
                        "window": " ".join(window),
                        "parent_hash": id
                    })

        return new_ids, vectors, metadata
    except Exception as e:
        print(f"Error in embedding text: {str(e)}")
        return [], [], []

def create_pinecone_index(pc, index_name, dimension):
    if index_name not in pc.list_indexes().names():
        pc.create_index(
            name=index_name,
            dimension=dimension,
            metric="cosine",
            spec=ServerlessSpec(
                cloud="aws",
                region="us-east-1"
            )
        )
        print(f"Index '{index_name}' created successfully.")
    else:
        print(f"Index '{index_name}' already exists.")

def upsert_to_pinecone(pc, index_name, ids, vectors, metadata, max_retries=3, delay=1):
    index = pc.Index(index_name)
    to_upsert = list(zip(ids, vectors, metadata))
    failed_entries = []

    for attempt in range(max_retries):
        try:
            index.upsert(vectors=to_upsert)
            print(f"Batch of {len(to_upsert)} vectors upserted successfully.")
            return failed_entries
        except PineconeApiException as e:
            if attempt < max_retries - 1:
                print(f"Upsert failed. Retrying in {delay} seconds... (Attempt {attempt + 1}/{max_retries})")
                time.sleep(delay)
            else:
                print(f"Failed to upsert after {max_retries} attempts. Skipping entries.")
                failed_entries.extend(to_upsert)

    return failed_entries

def process_documents(row_ids, pc, index_name, batch_size=1):
    failed_entries = []

    for i in range(0, len(row_ids), batch_size):
        batch_ids = row_ids[i:i+batch_size]
        texts = []
        ids = []

        with sqlite3.connect(DB_NAME) as conn:
            for row_id in batch_ids:
                cleaned_text, hash_id = clean_html(conn, row_id)
                texts.append(cleaned_text)
                ids.append(hash_id)

        doc_ids, doc_vectors, doc_metadata = embed_text(ids, texts)
        if doc_vectors:
            print(f"Vector dimension: {len(doc_vectors[0])}")
            failed_batch = upsert_to_pinecone(pc, index_name, doc_ids, doc_vectors, doc_metadata)
            failed_entries.extend(failed_batch)

        print(f"Processed sub-batch {i//batch_size + 1}")

    return failed_entries

def process_batch(start_row, end_row, pc, index_name):
    print(f"Processing batch: rows {start_row} to {end_row}")
    row_ids = range(start_row, end_row + 1)
    process_documents(row_ids, pc, index_name)

def process_all_batches(pc, index_name, batch_size=1):
    with sqlite3.connect(DB_NAME) as conn:
        c = conn.cursor()
        c.execute(f"SELECT COUNT(*) FROM {TABLE_NAME}")
        total_rows = c.fetchone()[0]

    start_row = 1
    end_row = total_rows

    while start_row <= end_row:
        batch_end_row = min(start_row + batch_size - 1, end_row)
        process_batch(start_row, batch_end_row, pc, index_name)
        start_row = batch_end_row + 1

def main():
    api_key = getpass("Enter your Pinecone API key: ")
    pc = Pinecone(api_key=api_key)
    index_name = INDEX_NAME
    dimension = 1536
    create_pinecone_index(pc, index_name, dimension)
    failed_entries = process_all_batches(pc, index_name)

    print("List of failed entries:")
    for entry in failed_entries:
        print(f"ID: {entry[0]}, Metadata: {entry[2]}")

if __name__ == "__main__":
    main()

and below is the code I used to query the pinecone index


!pip install transformers
!pip install torch
!pip install pinecone-client
!pip install scikit-learn

import sqlite3
import torch
import numpy as np
from transformers import AutoModel, AutoTokenizer
from getpass import getpass
from pinecone import Pinecone
from sklearn.preprocessing import normalize

# Initialize SQLite connection
db_name = 'data.sqlite'  # Replace with your actual database name if different
conn = sqlite3.connect(db_name)
cursor = conn.cursor()

# Initialize Pinecone
api_key = getpass("Enter your Pinecone API key: ")
pc = Pinecone(api_key=api_key)

# Connect to the Pinecone index
index_name = "test"
index = pc.Index(index_name)

# Initialize model and tokenizer
MODEL_DIR = 'dunzhang/stella_en_1.5B_v5'
print("Initializing model and tokenizer...")
MODEL = AutoModel.from_pretrained(MODEL_DIR).cuda().eval()
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_DIR)
print("Model and tokenizer initialized.")

# Add the linear layer to the model
vector_dim = 1536
MODEL.vector_linear = torch.nn.Linear(in_features=MODEL.config.hidden_size, out_features=vector_dim).cuda()

def get_embedding(text, is_query=False):
    """
    Generate embedding for the given text using the initialized model.
    If is_query is True, prepend the query prompt.
    """
    query_prompt = "Instruct: Retrieve semantically similar text.\nQuery: "
    if is_query:
        text = query_prompt + text

    inputs = TOKENIZER(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to('cuda')
    with torch.no_grad():
        attention_mask = inputs["attention_mask"]
        last_hidden_state = MODEL(**inputs)[0]
        last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
        embeddings = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
        embeddings = normalize(MODEL.vector_linear(embeddings).cpu().numpy())
    return embeddings[0].tolist()  # Convert numpy array to list

def decode_vector(vector):
    """
    Decode a vector back into text using the model's tokenizer.
    """
    # Convert the vector to a tensor and reshape it
    tensor = torch.tensor(vector).unsqueeze(0).to('cuda')

    # Use the model to generate text from the vector
    with torch.no_grad():
        outputs = MODEL.generate(inputs_embeds=tensor, max_length=100)

    # Decode the generated tokens back to text
    decoded_text = TOKENIZER.decode(outputs[0], skip_special_tokens=True)
    return decoded_text

def query_pinecone(query_vector, top_k=5):
    """
    Query the Pinecone index and return the most relevant results.
    """
    results = index.query(
        vector=query_vector,
        top_k=top_k,
        include_metadata=True
    )
    return results

def get_url_from_sqlite(hash_value):
    """
    Query the SQLite database to get the URL associated with a hash value.
    """
    cursor.execute("SELECT url FROM tablename WHERE hash = ?", (hash_value,))
    result = cursor.fetchone()
    return result[0] if result else None

def process_query(query_text, top_k=5):
    query_vector = get_embedding(query_text, is_query=True)
    results = query_pinecone(query_vector, top_k)

    processed_results = []
    for match in results['matches']:
        id_parts = match['id'].split('_')
        if len(id_parts) > 1:
            hash_value = id_parts[0]
            url = get_url_from_sqlite(hash_value)

            processed_results.append({
                'id': match['id'],
                'score': match['score'],
                'metadata': match['metadata'],
                'url': url,
                'text': match['metadata'].get('window', 'No text available')
            })

    return processed_results

# Main loop for multiple queries
while True:
    query_text = input("Enter your query text (or 'quit' to exit): ")
    if query_text.lower() == 'quit':
        break

    results = process_query(query_text)

    print(f"Query Text: {query_text}")
    print("---")

    for result in results:
        print(f"ID: {result['id']}")
        print(f"Score: {result['score']}")
        print(f"Metadata: {result['metadata']}")
        print(f"URL: {result['url']}")
        print(f"Text: {result['text']}")
        print("---")

I might have caught an error of mine. I was not aware (and I'm not entirely sure as I type this) that you use sentence transformers to encode text and transformers to decode a query. Will adjust my code given so and see if this changes things.

Edit: nope my assumption about sentence transformers and transformers was incorrect. so I'm back at square 1 and still need help with this topic.

Sign up or log in to comment