Verah committed (verified)
Commit 8641e25 · Parent(s): 45173cb

Update README.md

Files changed (1): README.md (+41 −1)
README.md CHANGED
@@ -13,12 +13,15 @@ denoise_util.py includes all definitions required to use Cascaded Gaze networks
 
 **v1**
 - ~132M params, trained on 256 × 256 RGB patches for intermediate JPEG & WebP compression-artefact removal. It was trained on about 700k samples (photographs only) at bf16 precision, and is also capable of removing ISO-like noise and Gaussian noise.
+- I recommend inputting tensors of shape [B, 3, 256, 256], with float values scaled to the range 0 to 1.
 
 **Loading v1**
 ```python
 from denoise_util import CascadedGaze
 from safetensors.torch import load_file
 
+device = "cuda"
+
 img_channel = 3
 width = 60
 enc_blks = [2, 2, 4, 6]
@@ -31,7 +34,44 @@ model = CascadedGaze(img_channel=img_channel,width=width, middle_blk_num=middle_
 
 state_dict = load_file("models/v1.safetensors")
 model.load_state_dict(state_dict)
+model = model.to(device)
 model.requires_grad_(False)
 model.eval()
 ```
-I recommend inputing tensors of [B,3,256,256], with values of floats scaled to 0 - 1.
+
+**Usage**
+- Uses https://github.com/ProGamerGov/blended-tiling to split images of arbitrary size into 256 × 256 tiles and to reassemble them afterwards.
+- You'll need to make amendments to prevent the batches from being too large for your device.
+- Presumes the model was already loaded with the code above.
+
+```python
+import torch
+from PIL import Image
+import torchvision
+from blended_tiling import TilingModule
+
+def toimg(tensor):
+    tensor = torch.clamp(tensor, 0.0, 1.0)
+    tensor = tensor * 255
+    tensor = tensor.byte()
+    return torchvision.transforms.functional.to_pil_image(tensor)
+
+# nb: if RGBA inputs are anticipated, this won't be sufficient.
+pil_image = Image.open("input.jpg").convert("RGB")
+
+tiling_module = TilingModule(
+    tile_size=[256, 256],
+    tile_overlap=[0.1, 0.1],  # you can configure this to taste
+    base_size=pil_image.size,
+)
+
+tensor = torchvision.transforms.functional.to_tensor(pil_image)
+tensor = torch.unsqueeze(tensor, 0)
+tiles = tiling_module.split_into_tiles(tensor)
+tiles = tiles.to(device)
+result = model(tiles).cpu()
+result = tiling_module.rebuild_with_masks(result).squeeze()
+
+pil_result = toimg(result)
+```
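
The Usage section's note about amending the code so the batches aren't too large for your device is left to the reader. A minimal sketch of one way to do it, assuming the `model` and `tiles` from the Usage block above; the `denoise_tiles` helper name and the `batch_size` parameter are illustrative, not part of the README:

```python
import torch

def denoise_tiles(model, tiles, batch_size=8, device="cuda"):
    # Run the model over the tile stack in mini-batches so peak memory
    # stays bounded no matter how many tiles the image produced.
    # Results are gathered back on the CPU and re-stacked in order.
    outputs = []
    for start in range(0, tiles.shape[0], batch_size):
        batch = tiles[start:start + batch_size].to(device)
        with torch.no_grad():
            outputs.append(model(batch).cpu())
    return torch.cat(outputs, dim=0)
```

With this in place, `result = model(tiles).cpu()` in the Usage block would become `result = denoise_tiles(model, tiles, batch_size=8)`, with `batch_size` tuned to fit your device's memory.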