coyotte508 HF staff victor HF staff commited on
Commit
4a6603b
1 Parent(s): e518a94

✨ Retry messages (#121)

Browse files

Co-authored-by: Victor Mustar <[email protected]>

src/lib/components/chat/ChatMessage.svelte CHANGED
@@ -1,11 +1,12 @@
1
  <script lang="ts">
2
  import { marked } from "marked";
3
  import type { Message } from "$lib/types/Message";
4
- import { afterUpdate } from "svelte";
5
  import { deepestChild } from "$lib/utils/deepestChild";
6
 
7
  import CodeBlock from "../CodeBlock.svelte";
8
  import IconLoading from "../icons/IconLoading.svelte";
 
9
 
10
  function sanitizeMd(md: string) {
11
  return md
@@ -25,6 +26,8 @@
25
  export let message: Message;
26
  export let loading: boolean = false;
27
 
 
 
28
  let contentEl: HTMLElement;
29
  let loadingEl: any;
30
  let pendingTimeout: NodeJS.Timeout;
@@ -69,7 +72,7 @@
69
  <img
70
  alt=""
71
  src="https://huggingface.co/avatars/2edb18bd0206c16b433841a47f53fa8e.svg"
72
- class="mt-5 h-3 w-3 flex-none rounded-full shadow-lg"
73
  />
74
  <div
75
  class="relative min-h-[calc(2rem+theme(spacing[3.5])*2)] min-w-[100px] rounded-2xl border border-gray-100 bg-gradient-to-br from-gray-50 px-5 py-3.5 text-gray-600 prose-pre:my-2 dark:border-gray-800 dark:from-gray-800/40 dark:text-gray-300"
@@ -93,10 +96,20 @@
93
  </div>
94
  {/if}
95
  {#if message.from === "user"}
96
- <div class="flex items-start justify-start gap-4 max-sm:text-sm">
97
  <div class="mt-5 h-3 w-3 flex-none rounded-full" />
98
  <div class="whitespace-break-spaces rounded-2xl px-5 py-3.5 text-gray-500 dark:text-gray-400">
99
  {message.content.trim()}
100
  </div>
 
 
 
 
 
 
 
 
 
 
101
  </div>
102
  {/if}
 
1
  <script lang="ts">
2
  import { marked } from "marked";
3
  import type { Message } from "$lib/types/Message";
4
+ import { afterUpdate, createEventDispatcher } from "svelte";
5
  import { deepestChild } from "$lib/utils/deepestChild";
6
 
7
  import CodeBlock from "../CodeBlock.svelte";
8
  import IconLoading from "../icons/IconLoading.svelte";
9
+ import CarbonRotate360 from "~icons/carbon/rotate-360";
10
 
11
  function sanitizeMd(md: string) {
12
  return md
 
26
  export let message: Message;
27
  export let loading: boolean = false;
28
 
29
+ const dispatch = createEventDispatcher<{ retry: void }>();
30
+
31
  let contentEl: HTMLElement;
32
  let loadingEl: any;
33
  let pendingTimeout: NodeJS.Timeout;
 
72
  <img
73
  alt=""
74
  src="https://huggingface.co/avatars/2edb18bd0206c16b433841a47f53fa8e.svg"
75
+ class="mt-5 h-3 w-3 flex-none select-none rounded-full shadow-lg"
76
  />
77
  <div
78
  class="relative min-h-[calc(2rem+theme(spacing[3.5])*2)] min-w-[100px] rounded-2xl border border-gray-100 bg-gradient-to-br from-gray-50 px-5 py-3.5 text-gray-600 prose-pre:my-2 dark:border-gray-800 dark:from-gray-800/40 dark:text-gray-300"
 
96
  </div>
97
  {/if}
98
  {#if message.from === "user"}
99
+ <div class="group relative flex items-start justify-start gap-4 max-sm:text-sm">
100
  <div class="mt-5 h-3 w-3 flex-none rounded-full" />
101
  <div class="whitespace-break-spaces rounded-2xl px-5 py-3.5 text-gray-500 dark:text-gray-400">
102
  {message.content.trim()}
103
  </div>
104
+ {#if !loading && message.id}
105
+ <button
106
+ class="absolute right-0 top-3.5 cursor-pointer rounded-lg border border-gray-100 p-1 text-xs text-gray-400 group-hover:block hover:text-gray-500 dark:border-gray-800 dark:text-gray-400 dark:hover:text-gray-300 md:hidden lg:-right-2"
107
+ title="Retry"
108
+ type="button"
109
+ on:click={() => dispatch("retry")}
110
+ >
111
+ <CarbonRotate360 />
112
+ </button>
113
+ {/if}
114
  </div>
115
  {/if}
src/lib/components/chat/ChatMessages.svelte CHANGED
@@ -2,10 +2,13 @@
2
  import type { Message } from "$lib/types/Message";
3
  import { snapScrollToBottom } from "$lib/actions/snapScrollToBottom";
4
  import ScrollToBottomBtn from "$lib/components/ScrollToBottomBtn.svelte";
5
- import { tick } from "svelte";
6
 
7
  import ChatIntroduction from "./ChatIntroduction.svelte";
8
  import ChatMessage from "./ChatMessage.svelte";
 
 
 
9
 
10
  export let messages: Message[];
11
  export let loading: boolean;
@@ -31,12 +34,16 @@
31
  >
32
  <div class="mx-auto flex h-full max-w-3xl flex-col gap-5 px-5 pt-6 sm:gap-8 xl:max-w-4xl">
33
  {#each messages as message, i}
34
- <ChatMessage loading={loading && i === messages.length - 1} {message} />
 
 
 
 
35
  {:else}
36
  <ChatIntroduction on:message />
37
  {/each}
38
  {#if pending}
39
- <ChatMessage message={{ from: "assistant", content: "" }} />
40
  {/if}
41
  <div class="h-32 flex-none" />
42
  </div>
 
2
  import type { Message } from "$lib/types/Message";
3
  import { snapScrollToBottom } from "$lib/actions/snapScrollToBottom";
4
  import ScrollToBottomBtn from "$lib/components/ScrollToBottomBtn.svelte";
5
+ import { createEventDispatcher, tick } from "svelte";
6
 
7
  import ChatIntroduction from "./ChatIntroduction.svelte";
8
  import ChatMessage from "./ChatMessage.svelte";
9
+ import { randomUUID } from "$lib/utils/randomUuid";
10
+
11
+ const dispatch = createEventDispatcher<{ retry: { id: Message["id"]; content: string } }>();
12
 
13
  export let messages: Message[];
14
  export let loading: boolean;
 
34
  >
35
  <div class="mx-auto flex h-full max-w-3xl flex-col gap-5 px-5 pt-6 sm:gap-8 xl:max-w-4xl">
36
  {#each messages as message, i}
37
+ <ChatMessage
38
+ loading={loading && i === messages.length - 1}
39
+ {message}
40
+ on:retry={() => dispatch("retry", { id: message.id, content: message.content })}
41
+ />
42
  {:else}
43
  <ChatIntroduction on:message />
44
  {/each}
45
  {#if pending}
46
+ <ChatMessage message={{ from: "assistant", content: "", id: randomUUID() }} />
47
  {/if}
48
  <div class="h-32 flex-none" />
49
  </div>
src/lib/components/chat/ChatWindow.svelte CHANGED
@@ -17,7 +17,12 @@
17
 
18
  let message: string;
19
 
20
- const dispatch = createEventDispatcher<{ message: string; share: void; stop: void }>();
 
 
 
 
 
21
 
22
  const handleSubmit = () => {
23
  if (loading) return;
@@ -27,7 +32,15 @@
27
  </script>
28
 
29
  <div class="relative min-h-0 min-w-0">
30
- <ChatMessages {loading} {pending} {messages} on:message />
 
 
 
 
 
 
 
 
31
  <div
32
  class="dark:via-gray-80 pointer-events-none absolute inset-x-0 bottom-0 z-0 mx-auto flex w-full max-w-3xl flex-col items-center justify-center bg-gradient-to-t from-white via-white/80 to-white/0 px-3.5 py-4 dark:border-gray-800 dark:from-gray-900 dark:to-gray-900/0 max-md:border-t max-md:bg-white max-md:dark:bg-gray-900 sm:px-5 md:py-8 xl:max-w-4xl [&>*]:pointer-events-auto"
33
  >
 
17
 
18
  let message: string;
19
 
20
+ const dispatch = createEventDispatcher<{
21
+ message: string;
22
+ share: void;
23
+ stop: void;
24
+ retry: { id: Message["id"]; content: string };
25
+ }>();
26
 
27
  const handleSubmit = () => {
28
  if (loading) return;
 
32
  </script>
33
 
34
  <div class="relative min-h-0 min-w-0">
35
+ <ChatMessages
36
+ {loading}
37
+ {pending}
38
+ {messages}
39
+ on:message
40
+ on:retry={(ev) => {
41
+ if (!loading) dispatch("retry", ev.detail);
42
+ }}
43
+ />
44
  <div
45
  class="dark:via-gray-80 pointer-events-none absolute inset-x-0 bottom-0 z-0 mx-auto flex w-full max-w-3xl flex-col items-center justify-center bg-gradient-to-t from-white via-white/80 to-white/0 px-3.5 py-4 dark:border-gray-800 dark:from-gray-900 dark:to-gray-900/0 max-md:border-t max-md:bg-white max-md:dark:bg-gray-900 sm:px-5 md:py-8 xl:max-w-4xl [&>*]:pointer-events-auto"
46
  >
src/lib/stores/pendingMessageIdToRetry.ts ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ import type { Message } from "$lib/types/Message";
2
+ import { writable } from "svelte/store";
3
+
4
+ export const pendingMessageIdToRetry = writable<Message["id"] | null>(null);
src/lib/types/Message.ts CHANGED
@@ -1,4 +1,5 @@
1
  export interface Message {
2
  from: "user" | "assistant";
 
3
  content: string;
4
  }
 
1
  export interface Message {
2
  from: "user" | "assistant";
3
+ id: ReturnType<typeof crypto.randomUUID>;
4
  content: string;
5
  }
src/lib/utils/randomUuid.ts ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ type UUID = ReturnType<typeof crypto.randomUUID>;
2
+
3
+ export function randomUUID(): UUID {
4
+ // Only on old safari / ios
5
+ if (!("randomUUID" in crypto)) {
6
+ return "10000000-1000-4000-8000-100000000000".replace(/[018]/g, (c) =>
7
+ (
8
+ Number(c) ^
9
+ (crypto.getRandomValues(new Uint8Array(1))[0] & (15 >> (Number(c) / 4)))
10
+ ).toString(16)
11
+ ) as UUID;
12
+ }
13
+ return crypto.randomUUID();
14
+ }
src/routes/conversation/[id]/+page.svelte CHANGED
@@ -1,6 +1,7 @@
1
  <script lang="ts">
2
  import ChatWindow from "$lib/components/chat/ChatWindow.svelte";
3
  import { pendingMessage } from "$lib/stores/pendingMessage";
 
4
  import { onMount } from "svelte";
5
  import { page } from "$app/stores";
6
  import { textGenerationStream } from "@huggingface/inference";
@@ -10,6 +11,7 @@
10
  import { shareConversation } from "$lib/shareConversation";
11
  import { UrlDependency } from "$lib/types/UrlDependency";
12
  import { error } from "$lib/stores/errors";
 
13
 
14
  export let data;
15
 
@@ -26,7 +28,7 @@
26
  let loading = false;
27
  let pending = false;
28
 
29
- async function getTextGenerationStream(inputs: string) {
30
  let conversationId = $page.params.id;
31
 
32
  const response = textGenerationStream(
@@ -48,6 +50,8 @@
48
  },
49
  },
50
  {
 
 
51
  use_cache: false,
52
  }
53
  );
@@ -89,7 +93,11 @@
89
 
90
  if (lastMessage?.from !== "assistant") {
91
  // First token has a space at the beginning, trim it
92
- messages = [...messages, { from: "assistant", content: data.token.text.trimStart() }];
 
 
 
 
93
  } else {
94
  lastMessage.content += data.token.text;
95
  messages = [...messages];
@@ -104,7 +112,7 @@
104
  });
105
  }
106
 
107
- async function writeMessage(message: string) {
108
  if (!message.trim()) return;
109
 
110
  try {
@@ -112,9 +120,18 @@
112
  loading = true;
113
  pending = true;
114
 
115
- messages = [...messages, { from: "user", content: message }];
 
 
 
 
 
 
 
 
 
116
 
117
- await getTextGenerationStream(message);
118
 
119
  if (messages.filter((m) => m.from === "user").length === 1) {
120
  summarizeTitle($page.params.id)
@@ -135,9 +152,11 @@
135
  onMount(async () => {
136
  if ($pendingMessage) {
137
  const val = $pendingMessage;
 
138
  $pendingMessage = "";
 
139
 
140
- writeMessage(val);
141
  }
142
  });
143
 
@@ -153,6 +172,7 @@
153
  {pending}
154
  {messages}
155
  on:message={(message) => writeMessage(message.detail)}
 
156
  on:share={() => shareConversation($page.params.id, data.title)}
157
  on:stop={() => (isAborted = true)}
158
  />
 
1
  <script lang="ts">
2
  import ChatWindow from "$lib/components/chat/ChatWindow.svelte";
3
  import { pendingMessage } from "$lib/stores/pendingMessage";
4
+ import { pendingMessageIdToRetry } from "$lib/stores/pendingMessageIdToRetry";
5
  import { onMount } from "svelte";
6
  import { page } from "$app/stores";
7
  import { textGenerationStream } from "@huggingface/inference";
 
11
  import { shareConversation } from "$lib/shareConversation";
12
  import { UrlDependency } from "$lib/types/UrlDependency";
13
  import { error } from "$lib/stores/errors";
14
+ import { randomUUID } from "$lib/utils/randomUuid";
15
 
16
  export let data;
17
 
 
28
  let loading = false;
29
  let pending = false;
30
 
31
+ async function getTextGenerationStream(inputs: string, messageId: string, isRetry = false) {
32
  let conversationId = $page.params.id;
33
 
34
  const response = textGenerationStream(
 
50
  },
51
  },
52
  {
53
+ id: messageId,
54
+ is_retry: isRetry,
55
  use_cache: false,
56
  }
57
  );
 
93
 
94
  if (lastMessage?.from !== "assistant") {
95
  // First token has a space at the beginning, trim it
96
+ messages = [
97
+ ...messages,
98
+ // id doesn't match the backend id but it's not important for assistant messages
99
+ { from: "assistant", content: data.token.text.trimStart(), id: randomUUID() },
100
+ ];
101
  } else {
102
  lastMessage.content += data.token.text;
103
  messages = [...messages];
 
112
  });
113
  }
114
 
115
+ async function writeMessage(message: string, messageId = crypto.randomUUID()) {
116
  if (!message.trim()) return;
117
 
118
  try {
 
120
  loading = true;
121
  pending = true;
122
 
123
+ let retryMessageIndex = messages.findIndex((msg) => msg.id === messageId);
124
+ const isRetry = retryMessageIndex !== -1;
125
+ if (!isRetry) {
126
+ retryMessageIndex = messages.length;
127
+ }
128
+
129
+ messages = [
130
+ ...messages.slice(0, retryMessageIndex),
131
+ { from: "user", content: message, id: messageId },
132
+ ];
133
 
134
+ await getTextGenerationStream(message, messageId, isRetry);
135
 
136
  if (messages.filter((m) => m.from === "user").length === 1) {
137
  summarizeTitle($page.params.id)
 
152
  onMount(async () => {
153
  if ($pendingMessage) {
154
  const val = $pendingMessage;
155
+ const messageId = $pendingMessageIdToRetry || undefined;
156
  $pendingMessage = "";
157
+ $pendingMessageIdToRetry = null;
158
 
159
+ writeMessage(val, messageId);
160
  }
161
  });
162
 
 
172
  {pending}
173
  {messages}
174
  on:message={(message) => writeMessage(message.detail)}
175
+ on:retry={(message) => writeMessage(message.detail.content, message.detail.id)}
176
  on:share={() => shareConversation($page.params.id, data.title)}
177
  on:stop={() => (isAborted = true)}
178
  />
src/routes/conversation/[id]/+server.ts CHANGED
@@ -27,10 +27,43 @@ export async function POST({ request, fetch, locals, params }) {
27
  throw error(404, "Conversation not found");
28
  }
29
 
30
- // Todo: validate prompt with zod? or aktype
31
  const json = await request.json();
32
-
33
- const messages = [...conv.messages, { from: "user", content: json.inputs }] satisfies Message[];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  const prompt = buildPrompt(messages);
35
 
36
  const randomEndpoint = modelEndpoint();
@@ -62,7 +95,7 @@ export async function POST({ request, fetch, locals, params }) {
62
 
63
  generated_text = trimSuffix(trimPrefix(generated_text, "<|startoftext|>"), PUBLIC_SEP_TOKEN);
64
 
65
- messages.push({ from: "assistant", content: generated_text });
66
 
67
  await collections.conversations.updateOne(
68
  {
 
27
  throw error(404, "Conversation not found");
28
  }
29
 
 
30
  const json = await request.json();
31
+ const {
32
+ inputs: newPrompt,
33
+ options: { id: messageId, is_retry },
34
+ } = z
35
+ .object({
36
+ inputs: z.string().trim().min(1),
37
+ options: z.object({
38
+ id: z.optional(z.string().uuid()),
39
+ is_retry: z.optional(z.boolean()),
40
+ }),
41
+ })
42
+ .parse(json);
43
+
44
+ const messages = (() => {
45
+ if (is_retry && messageId) {
46
+ let retryMessageIdx = conv.messages.findIndex((message) => message.id === messageId);
47
+ if (retryMessageIdx === -1) {
48
+ retryMessageIdx = conv.messages.length;
49
+ }
50
+ return [
51
+ ...conv.messages.slice(0, retryMessageIdx),
52
+ { content: newPrompt, from: "user", id: messageId as Message["id"] },
53
+ ];
54
+ }
55
+ return [
56
+ ...conv.messages,
57
+ { content: newPrompt, from: "user", id: (messageId as Message["id"]) || crypto.randomUUID() },
58
+ ];
59
+ })() satisfies Message[];
60
+
61
+ // Todo: on-the-fly migration, remove later
62
+ for (const message of messages) {
63
+ if (!message.id) {
64
+ message.id = crypto.randomUUID();
65
+ }
66
+ }
67
  const prompt = buildPrompt(messages);
68
 
69
  const randomEndpoint = modelEndpoint();
 
95
 
96
  generated_text = trimSuffix(trimPrefix(generated_text, "<|startoftext|>"), PUBLIC_SEP_TOKEN);
97
 
98
+ messages.push({ from: "assistant", content: generated_text, id: crypto.randomUUID() });
99
 
100
  await collections.conversations.updateOne(
101
  {
src/routes/r/[id]/+page.svelte CHANGED
@@ -5,13 +5,14 @@
5
  import ChatWindow from "$lib/components/chat/ChatWindow.svelte";
6
  import { ERROR_MESSAGES, error } from "$lib/stores/errors";
7
  import { pendingMessage } from "$lib/stores/pendingMessage";
 
8
  import type { PageData } from "./$types";
9
 
10
  export let data: PageData;
11
 
12
  let loading = false;
13
 
14
- async function createConversation(message: string) {
15
  try {
16
  loading = true;
17
  const res = await fetch(`${base}/conversation`, {
@@ -32,16 +33,11 @@
32
 
33
  const { conversationId } = await res.json();
34
 
35
- // Ugly hack to use a store as temp storage, feel free to improve ^^
36
- pendingMessage.set(message);
37
-
38
- // invalidateAll to update list of conversations
39
- await goto(`${base}/conversation/${conversationId}`, { invalidateAll: true });
40
  } catch (err) {
41
  error.set(ERROR_MESSAGES.default);
42
  console.error(String(err));
43
- } finally {
44
- loading = false;
45
  }
46
  }
47
 
@@ -65,8 +61,22 @@
65
  </svelte:head>
66
 
67
  <ChatWindow
68
- on:message={(ev) => createConversation(ev.detail)}
 
 
 
 
 
 
69
  on:share={shareConversation}
 
 
 
 
 
 
 
 
70
  messages={data.messages}
71
  {loading}
72
  />
 
5
  import ChatWindow from "$lib/components/chat/ChatWindow.svelte";
6
  import { ERROR_MESSAGES, error } from "$lib/stores/errors";
7
  import { pendingMessage } from "$lib/stores/pendingMessage";
8
+ import { pendingMessageIdToRetry } from "$lib/stores/pendingMessageIdToRetry";
9
  import type { PageData } from "./$types";
10
 
11
  export let data: PageData;
12
 
13
  let loading = false;
14
 
15
+ async function createConversation() {
16
  try {
17
  loading = true;
18
  const res = await fetch(`${base}/conversation`, {
 
33
 
34
  const { conversationId } = await res.json();
35
 
36
+ return conversationId;
 
 
 
 
37
  } catch (err) {
38
  error.set(ERROR_MESSAGES.default);
39
  console.error(String(err));
40
+ throw err;
 
41
  }
42
  }
43
 
 
61
  </svelte:head>
62
 
63
  <ChatWindow
64
+ on:message={(ev) =>
65
+ createConversation()
66
+ .then((convId) => {
67
+ $pendingMessage = ev.detail;
68
+ return goto(`${base}/conversation/${convId}`, { invalidateAll: true });
69
+ })
70
+ .finally(() => (loading = false))}
71
  on:share={shareConversation}
72
+ on:retry={(ev) =>
73
+ createConversation()
74
+ .then((convId) => {
75
+ $pendingMessageIdToRetry = ev.detail.id;
76
+ $pendingMessage = ev.detail.content;
77
+ return goto(`${base}/conversation/${convId}`, { invalidateAll: true });
78
+ })
79
+ .finally(() => (loading = false))}
80
  messages={data.messages}
81
  {loading}
82
  />