import torch
from torch.nn import functional as F
from PIL import Image

# The sampling helpers below are adapted from the `top_k_top_p_filtering`
# utility that used to ship with Hugging Face Transformers.


def top_k_top_p_filtering(
    logits,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
    masked_prefix_len: int = 256000,
):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    The first ``masked_prefix_len`` vocabulary entries are first set to
    ``filter_value``, so they can never survive filtering.
    NOTE(review): 256000 presumably is the size of the text vocabulary, so
    that only tokens after it (e.g. image tokens) remain eligible; confirm
    against the model's tokenizer.

    ``logits`` is modified in place and is also returned.

    Args:
        logits: logits distribution of shape (batch size, vocabulary size).
        top_k: if > 0, keep only the ``top_k`` tokens with the highest logits
            (top-k filtering).
        top_p: if < 1.0, keep the smallest prefix of highest-probability tokens
            whose cumulative probability exceeds ``top_p`` (nucleus filtering,
            described in Holtzman et al., "The Curious Case of Neural Text
            Degeneration", arXiv:1904.09751).
        filter_value: value written into the logits of removed tokens.
        min_tokens_to_keep: minimum number of tokens kept per batch row by
            each filter.
        masked_prefix_len: number of leading vocabulary entries that are
            unconditionally filtered out; 0 disables this masking. The
            default preserves the previously hard-coded value.

    Returns:
        The filtered ``logits`` tensor (the same object that was passed in).
    """
    if masked_prefix_len > 0:
        logits[:, :masked_prefix_len] = filter_value

    if top_k > 0:
        # Respect min_tokens_to_keep, but never ask for more than the vocab size.
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))
        # Remove every token whose logit is below the k-th largest logit.
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Mark tokens once the cumulative probability exceeds the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Protect min_tokens_to_keep - 1 positions here: the right shift
            # below extends the protected prefix by one, which guarantees
            # exactly min_tokens_to_keep kept tokens.
            sorted_indices_to_remove[..., : min_tokens_to_keep - 1] = 0
        # Shift right so the first token that crosses the threshold is kept too.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Map the mask from sorted order back to vocabulary order (2-D input).
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value

    return logits


def sample(logits, temperature: float = 1.0, top_k: int = 0, top_p: float = 1.0, sample_logits=True):
    """Choose the next token from the logits of the last sequence position.

    Args:
        logits: tensor indexed as ``logits[:, -1, :]``, i.e. (batch, seq, vocab).
        temperature: divisor applied to the logits; clamped below at 1e-5 to
            avoid division by zero.
        top_k: forwarded to ``top_k_top_p_filtering``.
        top_p: forwarded to ``top_k_top_p_filtering``.
        sample_logits: if True, draw a token with ``torch.multinomial``;
            otherwise pick the most probable token (greedy).

    Returns:
        ``(idx, probs)``: ``idx`` has shape (batch, 1) and holds the chosen
        token ids; ``probs`` is the softmax distribution they were drawn from.

    NOTE: filtering (including the prefix masking done by
    ``top_k_top_p_filtering``) only runs when top_k > 0 or top_p < 1.0.
    """
    # The division allocates a new tensor, so the in-place filtering below
    # never touches the caller's logits.
    logits = logits[:, -1, :] / max(temperature, 1e-5)
    if top_k > 0 or top_p < 1.0:
        logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
    probs = F.softmax(logits, dim=-1)
    if sample_logits:
        idx = torch.multinomial(probs, num_samples=1)
    else:
        _, idx = torch.topk(probs, k=1, dim=-1)
    return idx, probs


def expand2square(pil_img, background_color):
    """Pad a PIL image to a square canvas, centring the original image.

    The canvas has side ``max(width, height)``, uses the same mode as
    ``pil_img`` and is filled with ``background_color``. An image that is
    already square is returned unchanged (the same object).
    """
    width, height = pil_img.size
    if width == height:
        return pil_img
    side = max(width, height)
    result = Image.new(pil_img.mode, (side, side), background_color)
    # Centre along the shorter axis; with odd padding, floor division puts
    # the extra pixel on the bottom/right.
    result.paste(pil_img, ((side - width) // 2, (side - height) // 2))
    return result


def tokenizer_image_token(prompt, tokenizer, image_token_index=-200, return_tensors=None):
    """Tokenize ``prompt``, replacing each ``<image>`` placeholder with ``image_token_index``.

    The text between placeholders is tokenized chunk by chunk with
    ``tokenizer(chunk).input_ids``. If the first chunk starts with
    ``tokenizer.bos_token_id``, that BOS id is emitted once at the front and
    the first id of every chunk is dropped.
    NOTE(review): this assumes the tokenizer prepends BOS to *every* chunk
    whenever it does so for the first one; confirm for the tokenizer in use.

    Args:
        prompt: text possibly containing ``<image>`` placeholders.
        tokenizer: callable returning an object with ``input_ids``; must also
            expose ``bos_token_id``.
        image_token_index: id inserted where each placeholder was.
        return_tensors: None for a plain list, or ``'pt'`` for a
            ``torch.long`` tensor.

    Returns:
        The token ids as a list, or as a 1-D tensor when ``return_tensors == 'pt'``.

    Raises:
        ValueError: if ``return_tensors`` is neither None nor ``'pt'``.
    """
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]

    def insert_separator(X, sep):
        # Interleave `sep` between the items of X (no trailing separator).
        return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    # x[offset:] also strips `offset` entries from each separator, so the
    # separator is padded to offset + 1 ids to leave exactly one image id.
    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids