|
import os |
|
from literalai import AsyncLiteralClient |
|
from datetime import datetime, timedelta, timezone |
|
from modules.config.constants import COOLDOWN_TIME, TOKENS_LEFT, REGEN_TIME |
|
from typing_extensions import TypedDict |
|
import tiktoken |
|
from typing import Any, Generic, List, Literal, Optional, TypeVar, Union |
|
|
|
# Type variables used to parameterise the generic filter / ordering shapes.
Field = TypeVar("Field")
Operators = TypeVar("Operators")
Value = TypeVar("Value")

# Operator vocabularies, one Literal per family of compared value.
BOOLEAN_OPERATORS = Literal["is", "nis"]
STRING_OPERATORS = Literal["eq", "neq", "ilike", "nilike"]
NUMBER_OPERATORS = Literal["eq", "neq", "gt", "gte", "lt", "lte"]
STRING_LIST_OPERATORS = Literal["in", "nin"]
DATETIME_OPERATORS = Literal["gte", "lte", "gt", "lt"]

# Union of every operator family above; this is the type of Filter.operator.
OPERATORS = Union[
    BOOLEAN_OPERATORS, STRING_OPERATORS, NUMBER_OPERATORS,
    STRING_LIST_OPERATORS, DATETIME_OPERATORS,
]
|
|
|
|
|
class Filter(Generic[Field], TypedDict, total=False):
    """A single filter condition over one entity's fields.

    Generic over ``Field`` so each entity can restrict ``field`` to its own
    Literal of filterable names, e.g. ``Filter[threads_filterable_fields]``.
    Declared with ``total=False``, so every key is optional at type-check time.
    """

    # Name of the field to filter on (one of the Literal values bound to Field).
    field: Field

    # Comparison operator; any member of the OPERATORS union.
    operator: OPERATORS

    # Right-hand side of the comparison; typed as Any because the expected type
    # presumably varies with the field and operator.
    value: Any

    # Optional sub-path inside the field — presumably a key within nested data such
    # as "metadata"; TODO confirm against the Literal AI filter API.
    path: Optional[str]
|
|
|
|
|
class OrderBy(Generic[Field], TypedDict):
    """Sort specification: which column to order by and in which direction.

    Generic over ``Field`` so each entity can restrict ``column`` to its own
    Literal of orderable names, e.g. ``OrderBy[threads_orderable_fields]``.
    Unlike ``Filter`` this TypedDict is total, so both keys are required.
    """

    # Name of the column to sort on.
    column: Field

    # Sort direction: ascending or descending.
    direction: Literal["ASC", "DESC"]
|
|
|
|
|
# Per-entity field vocabularies and the filter / order-by types built from them.
# Each *_filters alias is a list of Filter dicts restricted to that entity's
# filterable fields; each *_order_by alias is an OrderBy over its orderable fields.

# --- threads ---
threads_filterable_fields = Literal[
    "id", "createdAt", "name", "stepType", "stepName", "stepOutput",
    "metadata", "tokenCount", "tags", "participantId",
    "participantIdentifiers", "scoreValue", "duration",
]
threads_orderable_fields = Literal["createdAt", "tokenCount"]
threads_filters = List[Filter[threads_filterable_fields]]
threads_order_by = OrderBy[threads_orderable_fields]

# --- steps ---
steps_filterable_fields = Literal[
    "id", "name", "input", "output", "participantIdentifier",
    "startTime", "endTime", "metadata", "parentId", "threadId",
    "error", "tags",
]
steps_orderable_fields = Literal["createdAt"]
steps_filters = List[Filter[steps_filterable_fields]]
steps_order_by = OrderBy[steps_orderable_fields]

# --- users (only a filters alias is declared here) ---
users_filterable_fields = Literal[
    "id", "createdAt", "identifier", "lastEngaged", "threadCount",
    "tokenCount", "metadata",
]
users_filters = List[Filter[users_filterable_fields]]

# --- scores ---
scores_filterable_fields = Literal[
    "id", "createdAt", "participant", "name", "tags", "value", "type",
    "comment",
]
scores_orderable_fields = Literal["createdAt"]
scores_filters = List[Filter[scores_filterable_fields]]
scores_order_by = OrderBy[scores_orderable_fields]

# --- generations ---
generation_filterable_fields = Literal[
    "id", "createdAt", "model", "duration", "promptLineage", "promptVersion",
    "tags", "score", "participant", "tokenCount", "error",
]
generation_orderable_fields = Literal[
    "createdAt", "tokenCount", "model", "provider", "participant", "duration",
]
generations_filters = List[Filter[generation_filterable_fields]]
generations_order_by = OrderBy[generation_orderable_fields]
|
|
|
# Module-level async Literal AI client used by the helpers below. It is built at
# import time from the LITERAL_API_KEY_LOGGING environment variable (None if unset).
literal_client = AsyncLiteralClient(
    api_key=os.environ.get("LITERAL_API_KEY_LOGGING")
)
|
|
|
|
|
|
|
def convert_to_dict(user_info):
    """Return ``user_info`` in mapping form.

    * A ``dict`` is returned as-is.
    * Any other object exposing ``__dict__`` is replaced by that attribute
      dictionary itself (not a copy), so writes to the result are visible on
      the original object.
    * Anything else is passed through untouched.
    """
    if not isinstance(user_info, dict) and hasattr(user_info, "__dict__"):
        return user_info.__dict__
    return user_info
|
|
|
|
|
def get_time():
    """Return the current UTC time as an ISO-8601 string (``+00:00`` offset)."""
    now_utc = datetime.now(tz=timezone.utc)
    return now_utc.isoformat()
|
|
|
|
|
async def get_user_details(user_email_id):
    """Return the Literal AI user whose identifier is ``user_email_id``.

    Delegates to ``literal_client.api.get_or_create_user`` and returns its result
    unchanged; judging by that method's name it also creates the user when none
    exists yet (not verified here).
    """
    return await literal_client.api.get_or_create_user(identifier=user_email_id)
|
|
|
|
|
async def update_user_info(user_info):
    """Write a user's ``id``, ``identifier`` and ``metadata`` back to Literal AI.

    Args:
        user_info: A user object or dict. It is normalised with
            ``convert_to_dict`` and must then provide the ``"id"``,
            ``"identifier"`` and ``"metadata"`` keys (a missing key raises
            ``KeyError``).

    Returns:
        None; the result of ``literal_client.api.update_user`` is discarded.
    """
    record = convert_to_dict(user_info)
    payload = {
        "id": record["id"],
        "identifier": record["identifier"],
        "metadata": record["metadata"],
    }
    await literal_client.api.update_user(**payload)
|
|
|
|
|
def parse_utc_time(timestamp):
    """Parse an ISO-8601 timestamp string into a timezone-aware UTC ``datetime``.

    Naive timestamps are taken to already be UTC and simply get ``tzinfo=UTC``.
    Aware timestamps are *converted* to UTC with ``astimezone`` instead of having
    their offset overwritten, so ``"2024-01-01T05:00:00+05:00"`` is read as
    midnight UTC rather than 05:00 UTC. Strings produced by ``get_time()`` already
    carry ``+00:00`` and parse to the same value either way.

    Args:
        timestamp: An ISO-8601 string accepted by ``datetime.fromisoformat``.

    Returns:
        A ``datetime`` whose ``tzinfo`` is ``timezone.utc``.
    """
    parsed = datetime.fromisoformat(timestamp)
    if parsed.tzinfo is None:
        return parsed.replace(tzinfo=timezone.utc)
    return parsed.astimezone(timezone.utc)


async def check_user_cooldown(user_info, current_time):
    """Decide whether a user is currently locked out by the token cooldown.

    Args:
        user_info: A user object (anything with ``__dict__``) or a dict; either
            way it must expose a ``"metadata"`` mapping.
        current_time: ISO-8601 timestamp string for "now".

    Returns:
        ``(True, cooldown_end_iso)`` while ``COOLDOWN_TIME`` seconds have not yet
        passed since ``metadata["last_message_time"]``; otherwise ``(False, None)``.

    Side effects:
        Once an active cooldown has expired, ``metadata["in_cooldown"]`` is set to
        False and ``reset_tokens_for_user`` regenerates and persists the balance.
    """
    # Normalise first so dicts and user objects are handled the same way
    # (previously ``user_info.metadata`` was read before conversion).
    user_info = convert_to_dict(user_info)
    metadata = user_info.get("metadata") or {}

    # Fast path: the user still has tokens and is not flagged as cooling down.
    tokens_left = metadata.get("tokens_left", 0)
    if tokens_left > 0 and not metadata.get("in_cooldown", False):
        return False, None

    last_message_time_str = metadata.get("last_message_time")
    if not last_message_time_str:
        # No recorded message means there is no cooldown window to be inside of;
        # previously this raised TypeError from datetime.fromisoformat(None).
        return False, None

    last_message_time = parse_utc_time(last_message_time_str)
    current_time = parse_utc_time(current_time)
    elapsed_time_in_seconds = (current_time - last_message_time).total_seconds()

    cooldown_end_time = last_message_time + timedelta(seconds=COOLDOWN_TIME)
    cooldown_end_time_iso = cooldown_end_time.isoformat()
    print(f"Cooldown end time (ISO): {cooldown_end_time_iso}")

    if elapsed_time_in_seconds < COOLDOWN_TIME:
        return True, cooldown_end_time_iso

    # Cooldown is over: clear the flag. reset_tokens_for_user always persists the
    # metadata, so this change is saved together with the regenerated tokens.
    user_info["metadata"]["in_cooldown"] = False
    await reset_tokens_for_user(user_info)
    return False, None


async def reset_tokens_for_user(user_info):
    """Regenerate a user's token balance from the time since their last message.

    Tokens refill linearly at ``max_tokens / REGEN_TIME`` per second, starting
    from ``metadata["tokens_left_at_last_message"]`` (clamped to ``TOKENS_LEFT``)
    and capped at ``metadata.get("max_tokens", TOKENS_LEFT)``. The metadata is
    always written back via ``update_user_info`` so that changes made by the
    caller (e.g. ``check_user_cooldown`` clearing ``in_cooldown``) are persisted.

    Args:
        user_info: A user object or dict exposing a ``"metadata"`` mapping that
            should contain ``"last_message_time"`` (ISO-8601).
    """
    user_info = convert_to_dict(user_info)
    metadata = user_info["metadata"]

    last_message_time_str = metadata.get("last_message_time")
    if not last_message_time_str:
        # Nothing to regenerate from; still persist caller changes instead of
        # raising TypeError from datetime.fromisoformat(None).
        await update_user_info(user_info)
        return

    last_message_time = parse_utc_time(last_message_time_str)
    current_time = parse_utc_time(get_time())
    # Clamp at zero so clock skew (a last_message_time in the future) can never
    # make tokens_to_regenerate negative and *remove* tokens.
    elapsed_time_in_seconds = max(
        0.0, (current_time - last_message_time).total_seconds()
    )

    # Balance snapshot taken at the last message, never above the configured TOKENS_LEFT.
    current_tokens = min(metadata.get("tokens_left_at_last_message", 0), TOKENS_LEFT)
    max_tokens = metadata.get("max_tokens", TOKENS_LEFT)

    if current_tokens < max_tokens:
        # Linear refill: a full max_tokens balance every REGEN_TIME seconds.
        regeneration_rate_per_second = max_tokens / REGEN_TIME
        tokens_to_regenerate = int(
            elapsed_time_in_seconds * regeneration_rate_per_second
        )
        new_token_count = min(current_tokens + tokens_to_regenerate, max_tokens)

        print(
            f"\n\n Adding {tokens_to_regenerate} tokens to the user, Time elapsed: {elapsed_time_in_seconds} seconds, Tokens after regeneration: {new_token_count}, Tokens before: {current_tokens} \n\n"
        )

        metadata["tokens_left"] = new_token_count

    await update_user_info(user_info)
|
|
|
|
|
async def get_thread_step_info(thread_id):
    """Fetch a step through ``literal_client.api.get_step`` and return it as-is.

    NOTE(review): the argument is named ``thread_id`` but is passed to
    ``get_step``; confirm callers actually hold a step id here.
    """
    return await literal_client.api.get_step(thread_id)
|
|
|
|
|
def get_num_tokens(text, model):
    """Count how many tokens ``text`` encodes to under ``model``'s tiktoken encoding.

    Args:
        text: The string to measure; may be arbitrary user input.
        model: A model name understood by ``tiktoken.encoding_for_model``
            (tiktoken raises ``KeyError`` for names it cannot map).

    Returns:
        The token count as an ``int``.
    """
    encoding = tiktoken.encoding_for_model(model)
    # disallowed_special=() counts text that merely looks like a special token
    # (e.g. a user typing "<|endoftext|>") as ordinary text. tiktoken's default,
    # disallowed_special="all", raises ValueError on such input instead.
    return len(encoding.encode(text, disallowed_special=()))
|
|