|
|
|
import random |
|
import time |
|
|
|
from typing import List, Dict, Tuple |
|
|
|
from .shared import DreamImage |
|
|
|
|
|
class RGBPalette:
    """An ordered collection of RGB colour tuples.

    Colours may come from the pixels of an image and/or from an explicit
    list. Tuples with fewer than three channels are expanded to RGB by
    replicating their first channel.
    """

    ID = "RGB_PALETTE"

    def __init__(self, colors: List[tuple[int, int, int]] = None, image: "DreamImage" = None):
        """Build a palette.

        :param colors: optional colour tuples, appended after any image pixels.
        :param image: optional image; iterating it must yield 3-tuples whose
            first element is a pixel colour (the other two elements are
            ignored — presumably coordinates, confirm against ``DreamImage``).
        """
        self._colors = []

        def _fix_tuple(t):
            # 1- or 2-channel tuples are treated as grayscale: replicate the
            # first channel. Tuples of length >= 3 are kept unchanged.
            if len(t) < 3:
                return (t[0], t[0], t[0])
            return t

        if image:
            for p, _, _ in image:
                self._colors.append(_fix_tuple(p))
        if colors:
            for c in colors:
                self._colors.append(_fix_tuple(c))

    def _calculate_channel_contrast(self, c):
        """Return a contrast score in [0, 1] for channel index *c*.

        Channel values are bucketed into 16 bins of width 16 (values are
        assumed to lie in 0..255). The score is the sum of bin distances over
        all colour pairs, divided by the largest value that sum can take for
        this many colours. Palettes with fewer than two colours have no pairs
        and score 0.0 (previously a ZeroDivisionError).
        """
        hist = [0] * 16
        for pixel in self._colors:
            hist[pixel[c] // 16] += 1

        n = len(self._colors)
        # The pairwise-distance sum is maximised by splitting the colours
        # between the two extreme bins (distance 15); with an odd count one
        # side holds the extra colour, hence floor(n/2) * ceil(n/2).
        max_possible = 15 * (n // 2) * ((n + 1) // 2)
        if max_possible == 0:
            return 0.0

        s = 0
        for i in range(16):
            for j in range(i):
                # j < i, so the bin distance is simply i - j.
                s += (i - j) * hist[i] * hist[j]
        return s / max_possible

    def _calculate_combined_contrast(self):
        """Return the mean of the three per-channel contrast scores."""
        return sum(self._calculate_channel_contrast(c) for c in range(3)) / 3

    def analyze(self):
        """Summarise the palette.

        :returns: ``(brightness, contrast, r, g, b)`` where r/g/b are the mean
            channel values divided by 255, brightness is their mean and
            contrast is the combined channel contrast. An empty palette
            returns all zeros instead of raising ZeroDivisionError.
        """
        n = len(self._colors)
        if n == 0:
            return (0.0, 0.0, 0.0, 0.0, 0.0)

        total_red = sum(pixel[0] for pixel in self._colors)
        total_green = sum(pixel[1] for pixel in self._colors)
        total_blue = sum(pixel[2] for pixel in self._colors)

        r = float(total_red) / (255 * n)
        g = float(total_green) / (255 * n)
        b = float(total_blue) / (255 * n)
        return ((r + g + b) / 3.0, self._calculate_combined_contrast(), r, g, b)

    def __len__(self):
        return len(self._colors)

    def __iter__(self):
        return iter(self._colors)

    def random_iteration(self, seed=None):
        """Return an endless iterator of randomly chosen palette colours.

        :param seed: seed for a private ``random.Random``; defaults to the
            current time in milliseconds. Equal seeds give equal sequences.
        :returns: an iterator usable with ``next()`` and in ``for`` loops. For
            an empty palette it is immediately exhausted (raises
            ``StopIteration``).
        """
        s = seed if seed is not None else int(time.time() * 1000)
        n = len(self._colors) - 1
        c = self._colors

        class _ColorIterator:
            def __init__(self):
                self._r = random.Random()
                self._r.seed(s)
                self._n = n
                self._c = c

            def __iter__(self):
                # Required by the iterator protocol so the object works in
                # for-loops, zip(), list(), etc., not only with next().
                return self

            def __next__(self):
                if self._n < 0:
                    # Empty palette: nothing to choose from.
                    raise StopIteration
                return self._c[self._r.randint(0, self._n)]

        return _ColorIterator()
|
|
|
|
|
class PartialPrompt:
    """A mapping from prompt fragments to weights.

    ``add`` and ``scaled_by`` return new instances and leave the receiver
    unchanged; ``finalize`` renders the fragments as positive/negative
    weighted prompt strings.
    """

    ID = "PARTIAL_PROMPT"

    def __init__(self):
        self._data = {}

    def add(self, text: str, weight: float):
        """Return a copy with each comma-separated fragment of *text* set to *weight*.

        Fragments are stripped of surrounding whitespace. Fragments that
        contain a space are wrapped in parentheses. Empty fragments (from
        empty text or stray commas) are skipped. An existing fragment's
        weight is overwritten.
        """
        output = PartialPrompt()
        output._data = dict(self._data)
        for part in text.split(","):
            part = part.strip()
            if not part:
                continue
            key = "(" + part + ")" if " " in part else part
            output._data[key] = weight
        return output

    def is_empty(self):
        """True when no fragments are stored."""
        return not self._data

    def abs_sum(self):
        """Sum of absolute weights (0.0 when empty)."""
        if not self._data:
            return 0.0
        return sum(map(abs, self._data.values()))

    def abs_max(self):
        """Largest absolute weight (0.0 when empty)."""
        if not self._data:
            return 0.0
        return max(map(abs, self._data.values()))

    def scaled_by(self, f: float):
        """Return a copy with every weight multiplied by *f*."""
        new_data = PartialPrompt()
        new_data._data = {text: weight * f for text, weight in self._data.items()}
        return new_data

    def finalize(self, clamp: float):
        """Render as a ``(positive, negative)`` pair of strings.

        Each string is a ", "-separated list of ``(text:weight)`` terms with
        the weight printed to three decimals and capped at *clamp*. Positive
        terms (weight >= 0.0001) are ordered by descending weight; negative
        terms (weight <= -0.0001) by descending magnitude and printed as
        positive numbers. Ties are broken by text. Weights strictly between
        -0.0001 and 0.0001 are dropped.
        """
        items = list(self._data.items())
        pos = [
            "({}:{:.3f})".format(text, min(clamp, w))
            for text, w in sorted(items, key=lambda pair: (-pair[1], pair[0]))
            if w >= 0.0001
        ]
        neg = [
            "({}:{:.3f})".format(text, min(clamp, -w))
            for text, w in sorted(items, key=lambda pair: (pair[1], pair[0]))
            if w <= -0.0001
        ]
        return ", ".join(pos), ", ".join(neg)
|
|
|
|
|
class LogEntry:
    """A list of ``(timestamp, text)`` log records.

    ``add`` and ``merge`` return new instances built from copies of the
    record list; the receiver is left unchanged.
    """

    ID = "LOG_ENTRY"

    @classmethod
    def new(cls, text):
        """Create an instance of *cls* holding one record for *text*, stamped with ``time.time()``."""
        # Use cls (not the hard-coded class) so subclasses get instances of
        # their own type from this alternate constructor.
        return cls([(time.time(), text)])

    def __init__(self, data: List[Tuple[float, str]] = None):
        """:param data: optional initial records; copied into a new list."""
        self._data = [] if data is None else list(data)

    def add(self, text: str):
        """Return a new entry with an extra record for *text* stamped with ``time.time()``."""
        return LogEntry(self._data + [(time.time(), text)])

    def merge(self, log_entry):
        """Return a new entry containing this entry's records followed by *log_entry*'s."""
        return LogEntry(self._data + log_entry._data)

    def get_filtered_entries(self, t: float):
        """Yield records with a timestamp strictly greater than *t*, in sorted order.

        Records are sorted as tuples, i.e. by timestamp, then by text.
        """
        for record in sorted(self._data):
            if record[0] > t:
                yield record
|
|
|
|
|
class FrameCounter:
    """Position within a fixed-length frame sequence played at a fixed rate.

    Inputs are sanitised on construction: the current frame is floored at 0,
    the frame total at 1, and the rate at 1.0 frames per second (stored as a
    float). ``incremented`` returns a new counter rather than mutating this one.
    """

    ID = "FRAME_COUNTER"

    def __init__(self, current_frame=0, total_frames=1, frames_per_second=25.0):
        self.current_frame = max(0, current_frame)
        self.total_frames = max(total_frames, 1)
        self.frames_per_second = float(max(1.0, frames_per_second))

    def incremented(self, amount: int):
        """Return a new counter shifted by *amount* frames (the result is floored at frame 0)."""
        return FrameCounter(
            current_frame=self.current_frame + amount,
            total_frames=self.total_frames,
            frames_per_second=self.frames_per_second,
        )

    @property
    def is_first_frame(self):
        """True when positioned on frame 0."""
        return self.current_frame == 0

    @property
    def is_final_frame(self):
        """True when positioned exactly on the last frame (index total_frames - 1)."""
        return self.current_frame + 1 == self.total_frames

    @property
    def is_after_last_frame(self):
        """True once the current frame index reaches or passes total_frames."""
        return self.current_frame >= self.total_frames

    @property
    def current_time_in_seconds(self):
        """Elapsed time at the current frame."""
        frame = float(self.current_frame)
        return frame / self.frames_per_second

    @property
    def total_time_in_seconds(self):
        """Duration of the whole sequence."""
        frames = float(self.total_frames)
        return frames / self.frames_per_second

    @property
    def remaining_time_in_seconds(self):
        """Total duration minus elapsed time."""
        elapsed = self.current_time_in_seconds
        return self.total_time_in_seconds - elapsed

    @property
    def progress(self):
        """Fraction of the sequence covered, with the last frame mapping to 1.0."""
        # A single-frame sequence uses a divisor of 1 so frame 0 yields 0.0.
        divisor = max(2, self.total_frames) - 1
        return float(self.current_frame) / divisor
|
|
|
|
|
class AnimationSequence:
    """Image file names grouped by frame key, each key holding one file per batch.

    When *frames* is empty or None the sequence is "undefined": it has no keys
    and zero batches. Otherwise the batch count is the length of the shortest
    per-key list.
    """

    ID = "ANIMATION_SEQUENCE"

    def __init__(self, frame_counter: FrameCounter, frames: Dict[int, List[str]] = None):
        self.frames = frames
        self.fps = frame_counter.frames_per_second
        self.frame_counter = frame_counter
        if not self.is_defined:
            self.keys_in_order = []
            self.num_batches = 0
            return
        self.keys_in_order = sorted(frames)
        self.num_batches = min(len(files) for files in self.frames.values())

    @property
    def batches(self):
        """Range over the available batch indices."""
        return range(self.num_batches)

    def get_image_files_of_batch(self, batch_num):
        """Lazily yield the file for *batch_num* from each frame key, in sorted key order."""
        for frame_key in self.keys_in_order:
            yield self.frames[frame_key][batch_num]

    @property
    def is_defined(self):
        """True when *frames* was given and is non-empty."""
        return bool(self.frames)
|
|
|
|
|
class SharedTypes:
    """Ready-made input declarations.

    Each attribute maps an input name to a one-element tuple holding the
    corresponding class's ``ID`` string.
    """

    frame_counter = dict(frame_counter=(FrameCounter.ID,))
    sequence = dict(sequence=(AnimationSequence.ID,))
    palette = dict(palette=(RGBPalette.ID,))
|
|