enzostvs HF staff commited on
Commit
1f31a76
·
1 Parent(s): 44e7a8c

add base_model filter

Browse files
src/lib/components/fields/Input.svelte CHANGED
@@ -2,6 +2,7 @@
2
  export let placeholder: string = "Search";
3
  export let value: string = "";
4
  export let prefix: string = "";
 
5
  export let onChange: (value: string) => void = () => {};
6
 
7
  const handleChange = (event: any) => {
@@ -10,7 +11,7 @@
10
  }
11
  </script>
12
 
13
- <div class="bg-neutral-900 border border-neutral-800 rounded-lg text-neutral-200 text-base flex items-center justify-start overflow-hidden">
14
  {#if prefix}
15
  <div class="flex items-center justify-between bg-neutral-800/50 px-3 border-r border-neutral-800 py-4">
16
  <p class="text-xs uppercase text-neutral-100 font-semibold">{prefix}</p>
 
2
  export let placeholder: string = "Search";
3
  export let value: string = "";
4
  export let prefix: string = "";
5
+ export let className: string = "";
6
  export let onChange: (value: string) => void = () => {};
7
 
8
  const handleChange = (event: any) => {
 
11
  }
12
  </script>
13
 
14
+ <div class="bg-neutral-900 border border-neutral-800 rounded-lg text-neutral-200 text-base flex items-center justify-start overflow-hidden {className}">
15
  {#if prefix}
16
  <div class="flex items-center justify-between bg-neutral-800/50 px-3 border-r border-neutral-800 py-4">
17
  <p class="text-xs uppercase text-neutral-100 font-semibold">{prefix}</p>
src/routes/api/models/+server.ts CHANGED
@@ -20,6 +20,7 @@ export async function GET(request : RequestEvent) {
20
  const filter = request.url.searchParams.get('filter') || 'hotest'
21
  const search = request.url.searchParams.get('search') || ''
22
  const limit = parseInt(request.url.searchParams.get('limit') || '20')
 
23
 
24
  const orderBy: Record<string, string> = {}
25
  if (filter === 'hotest') {
@@ -30,6 +31,21 @@ export async function GET(request : RequestEvent) {
30
  orderBy['createdAt'] = 'desc'
31
  }
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  const only_not_public = filter === 'staff_only';
34
 
35
  const cards = await prisma.model.findMany({
@@ -37,6 +53,7 @@ export async function GET(request : RequestEvent) {
37
  ...(
38
  !IS_ADMIN ? { isPublic: true } : only_not_public ? { isPublic: false } : {}
39
  ),
 
40
  OR: [
41
  { id: { contains: search } },
42
  ]
 
20
  const filter = request.url.searchParams.get('filter') || 'hotest'
21
  const search = request.url.searchParams.get('search') || ''
22
  const limit = parseInt(request.url.searchParams.get('limit') || '20')
23
+ const base_model = request.url.searchParams.get('base_model') || undefined
24
 
25
  const orderBy: Record<string, string> = {}
26
  if (filter === 'hotest') {
 
31
  orderBy['createdAt'] = 'desc'
32
  }
33
 
34
+ let base_model_mapped: string[] | undefined = undefined;
35
+ if (base_model) {
36
+ switch (base_model) {
37
+ case "sd3":
38
+ base_model_mapped = ['stabilityai/stable-diffusion-3-medium-diffusers']
39
+ break;
40
+ case "sdxl":
41
+ base_model_mapped = ['stabilityai/stabilityai/stable-diffusion-xl-base-1.0']
42
+ break;
43
+ case "sd1":
44
+ base_model_mapped = ['CompVis/stable-diffusion-v1-4', 'runwayml/stable-diffusion-v1-5']
45
+ break;
46
+ }
47
+ }
48
+
49
  const only_not_public = filter === 'staff_only';
50
 
51
  const cards = await prisma.model.findMany({
 
53
  ...(
54
  !IS_ADMIN ? { isPublic: true } : only_not_public ? { isPublic: false } : {}
55
  ),
56
+ ...(base_model_mapped ? { base_model: { in: base_model_mapped } } : {}),
57
  OR: [
58
  { id: { contains: search } },
59
  ]
src/routes/models/+layout.svelte CHANGED
@@ -26,6 +26,7 @@
26
  let form: Record<string, string> = {
27
  filter: $page.url.searchParams.get('filter') ?? "hotest",
28
  search: $page.url.searchParams.get('search') ?? "",
 
29
  page: "0"
30
  }
31
  let submitModelDialog = false;
@@ -47,6 +48,14 @@
47
  await goto(`?${$page.url.searchParams.toString()}`);
48
  refetch(false);
49
  }
 
 
 
 
 
 
 
 
50
  let timeout: any;
51
  const handleChangeSearch = async (search: string) => {
52
  clearTimeout(timeout);
@@ -126,8 +135,27 @@
126
  </Button>
127
  </div>
128
  </div>
129
- <div class="mt-5 max-w-sm">
130
- <Input value={form.search} placeholder="Filter by model name" onChange={handleChangeSearch} />
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  </div>
132
  <div class="mx-auto grid grid-cols-1 sm:grid-cols-2 md:grid-cols-3 2xl:grid-cols-4 gap-5 mt-8 lg:mt-10">
133
  {#each data.models as card}
 
26
  let form: Record<string, string> = {
27
  filter: $page.url.searchParams.get('filter') ?? "hotest",
28
  search: $page.url.searchParams.get('search') ?? "",
29
+ base_model: $page.url.searchParams.get('base_model') ?? "",
30
  page: "0"
31
  }
32
  let submitModelDialog = false;
 
48
  await goto(`?${$page.url.searchParams.toString()}`);
49
  refetch(false);
50
  }
51
+ const handleChangeBaseModel = async (event: any) => {
52
+ const base_model = event.target.value
53
+ form.base_model = base_model;
54
+ $page.url.searchParams.set('base_model', base_model);
55
+ await goto(`?${$page.url.searchParams.toString()}`);
56
+ refetch(false);
57
+ }
58
+
59
  let timeout: any;
60
  const handleChangeSearch = async (search: string) => {
61
  clearTimeout(timeout);
 
135
  </Button>
136
  </div>
137
  </div>
138
+ <div class="mt-5 max-w-sm flex items-center justify-start gap-5 w-full">
139
+ <Input value={form.search} className="lg:min-w-[300px]" placeholder="Filter by model name" onChange={handleChangeSearch} />
140
+ <div class="flex flex-col items-start justify-center gap-1.5">
141
+ <p class="text-xs text-white/60 whitespace-nowrap">
142
+ Filter by
143
+ </p>
144
+ <select value={form.base_model} class="text-white bg-transparent outline-none cursor-pointer" on:change={handleChangeBaseModel}>
145
+ <option value="">
146
+ All models
147
+ </option>
148
+ <option value="sd3">
149
+ Stable Diffusion 3
150
+ </option>
151
+ <option value="sdxl">
152
+ Stable Diffusion XL
153
+ </option>
154
+ <option value="sd1">
155
+ Stable Diffusion 1
156
+ </option>
157
+ </select>
158
+ </div>
159
  </div>
160
  <div class="mx-auto grid grid-cols-1 sm:grid-cols-2 md:grid-cols-3 2xl:grid-cols-4 gap-5 mt-8 lg:mt-10">
161
  {#each data.models as card}