Commit
•
ad02fa3
1
Parent(s):
e5eb656
✨ Save messages in backend (#31)
Browse files- .eslintrc.cjs +3 -0
- src/lib/buildPrompt.ts +25 -0
- src/lib/utils/streamToAsyncIterable.ts +15 -0
- src/lib/utils/sum.ts +3 -0
- src/routes/+page.svelte +1 -1
- src/routes/api/conversation/+server.ts +0 -19
- src/routes/conversation/[id]/+page.svelte +2 -20
- src/routes/conversation/[id]/+server.ts +110 -0
.eslintrc.cjs
CHANGED
@@ -12,6 +12,9 @@ module.exports = {
|
|
12 |
sourceType: 'module',
|
13 |
ecmaVersion: 2020
|
14 |
},
|
|
|
|
|
|
|
15 |
env: {
|
16 |
browser: true,
|
17 |
es2017: true,
|
|
|
12 |
sourceType: 'module',
|
13 |
ecmaVersion: 2020
|
14 |
},
|
15 |
+
rules: {
|
16 |
+
'no-shadow': ['error']
|
17 |
+
},
|
18 |
env: {
|
19 |
browser: true,
|
20 |
es2017: true,
|
src/lib/buildPrompt.ts
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import {
|
2 |
+
PUBLIC_ASSISTANT_MESSAGE_TOKEN,
|
3 |
+
PUBLIC_SEP_TOKEN,
|
4 |
+
PUBLIC_USER_MESSAGE_TOKEN
|
5 |
+
} from '$env/static/public';
|
6 |
+
import type { Message } from './types/Message';
|
7 |
+
|
8 |
+
/**
|
9 |
+
* Convert [{user: "assistant", content: "hi"}, {user: "user", content: "hello"}] to:
|
10 |
+
*
|
11 |
+
* <|assistant|>hi<|endoftext|><|prompter|>hello<|endoftext|><|assistant|>
|
12 |
+
*/
|
13 |
+
export function buildPrompt(messages: Message[]): string {
|
14 |
+
return (
|
15 |
+
messages
|
16 |
+
.map(
|
17 |
+
(m) =>
|
18 |
+
(m.from === 'user'
|
19 |
+
? PUBLIC_USER_MESSAGE_TOKEN + m.content
|
20 |
+
: PUBLIC_ASSISTANT_MESSAGE_TOKEN + m.content) +
|
21 |
+
(m.content.endsWith(PUBLIC_SEP_TOKEN) ? '' : PUBLIC_SEP_TOKEN)
|
22 |
+
)
|
23 |
+
.join('') + PUBLIC_ASSISTANT_MESSAGE_TOKEN
|
24 |
+
);
|
25 |
+
}
|
src/lib/utils/streamToAsyncIterable.ts
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Statements/for-await...of#iterating_over_async_generators
|
2 |
+
export async function* streamToAsyncIterable(
|
3 |
+
stream: ReadableStream<Uint8Array>
|
4 |
+
): AsyncIterableIterator<Uint8Array> {
|
5 |
+
const reader = stream.getReader();
|
6 |
+
try {
|
7 |
+
while (true) {
|
8 |
+
const { done, value } = await reader.read();
|
9 |
+
if (done) return;
|
10 |
+
yield value;
|
11 |
+
}
|
12 |
+
} finally {
|
13 |
+
reader.releaseLock();
|
14 |
+
}
|
15 |
+
}
|
src/lib/utils/sum.ts
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
export function sum(nums: number[]): number {
|
2 |
+
return nums.reduce((a, b) => a + b, 0);
|
3 |
+
}
|
src/routes/+page.svelte
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
<script lang="ts">
|
2 |
-
import { goto
|
3 |
import ChatWindow from '$lib/components/chat/ChatWindow.svelte';
|
4 |
import { pendingMessage } from '$lib/stores/pendingMessage';
|
5 |
|
|
|
1 |
<script lang="ts">
|
2 |
+
import { goto } from '$app/navigation';
|
3 |
import ChatWindow from '$lib/components/chat/ChatWindow.svelte';
|
4 |
import { pendingMessage } from '$lib/stores/pendingMessage';
|
5 |
|
src/routes/api/conversation/+server.ts
DELETED
@@ -1,19 +0,0 @@
|
|
1 |
-
import { HF_TOKEN } from '$env/static/private';
|
2 |
-
import { PUBLIC_MODEL_ENDPOINT } from '$env/static/public';
|
3 |
-
|
4 |
-
export async function POST({ request, fetch }) {
|
5 |
-
const resp = await fetch(PUBLIC_MODEL_ENDPOINT, {
|
6 |
-
headers: {
|
7 |
-
'Content-Type': request.headers.get('Content-Type') ?? 'application/json',
|
8 |
-
Authorization: `Basic ${HF_TOKEN}`
|
9 |
-
},
|
10 |
-
method: 'POST',
|
11 |
-
body: await request.text()
|
12 |
-
});
|
13 |
-
|
14 |
-
return new Response(resp.body, {
|
15 |
-
headers: Object.fromEntries(resp.headers.entries()),
|
16 |
-
status: resp.status,
|
17 |
-
statusText: resp.statusText
|
18 |
-
});
|
19 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/routes/conversation/[id]/+page.svelte
CHANGED
@@ -4,23 +4,14 @@
|
|
4 |
import { onMount } from 'svelte';
|
5 |
import type { PageData } from './$types';
|
6 |
import { page } from '$app/stores';
|
7 |
-
import {
|
8 |
-
PUBLIC_ASSISTANT_MESSAGE_TOKEN,
|
9 |
-
PUBLIC_SEP_TOKEN,
|
10 |
-
PUBLIC_USER_MESSAGE_TOKEN
|
11 |
-
} from '$env/static/public';
|
12 |
import { HfInference } from '@huggingface/inference';
|
13 |
|
14 |
export let data: PageData;
|
15 |
|
16 |
$: messages = data.messages;
|
17 |
|
18 |
-
const userToken = PUBLIC_USER_MESSAGE_TOKEN;
|
19 |
-
const assistantToken = PUBLIC_ASSISTANT_MESSAGE_TOKEN;
|
20 |
-
const sepToken = PUBLIC_SEP_TOKEN;
|
21 |
-
|
22 |
const hf = new HfInference();
|
23 |
-
const model = hf.endpoint(
|
24 |
|
25 |
let loading = false;
|
26 |
|
@@ -76,16 +67,7 @@
|
|
76 |
|
77 |
messages = [...messages, { from: 'user', content: message }];
|
78 |
|
79 |
-
|
80 |
-
messages
|
81 |
-
.map(
|
82 |
-
(m) =>
|
83 |
-
(m.from === 'user' ? userToken + m.content : assistantToken + m.content) +
|
84 |
-
(m.content.endsWith(sepToken) ? '' : sepToken)
|
85 |
-
)
|
86 |
-
.join('') + assistantToken;
|
87 |
-
|
88 |
-
await getTextGenerationStream(inputs);
|
89 |
} finally {
|
90 |
loading = false;
|
91 |
}
|
|
|
4 |
import { onMount } from 'svelte';
|
5 |
import type { PageData } from './$types';
|
6 |
import { page } from '$app/stores';
|
|
|
|
|
|
|
|
|
|
|
7 |
import { HfInference } from '@huggingface/inference';
|
8 |
|
9 |
export let data: PageData;
|
10 |
|
11 |
$: messages = data.messages;
|
12 |
|
|
|
|
|
|
|
|
|
13 |
const hf = new HfInference();
|
14 |
+
const model = hf.endpoint($page.url.href);
|
15 |
|
16 |
let loading = false;
|
17 |
|
|
|
67 |
|
68 |
messages = [...messages, { from: 'user', content: message }];
|
69 |
|
70 |
+
await getTextGenerationStream(message);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
} finally {
|
72 |
loading = false;
|
73 |
}
|
src/routes/conversation/[id]/+server.ts
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { HF_TOKEN } from '$env/static/private';
|
2 |
+
import { PUBLIC_MODEL_ENDPOINT } from '$env/static/public';
|
3 |
+
import { buildPrompt } from '$lib/buildPrompt.js';
|
4 |
+
import { collections } from '$lib/server/database.js';
|
5 |
+
import type { Message } from '$lib/types/Message.js';
|
6 |
+
import { streamToAsyncIterable } from '$lib/utils/streamToAsyncIterable';
|
7 |
+
import { sum } from '$lib/utils/sum';
|
8 |
+
import { error } from '@sveltejs/kit';
|
9 |
+
import { ObjectId } from 'mongodb';
|
10 |
+
|
11 |
+
export async function POST({ request, fetch, locals, params }) {
|
12 |
+
// todo: add validation on params.id
|
13 |
+
const convId = new ObjectId(params.id);
|
14 |
+
|
15 |
+
const conv = await collections.conversations.findOne({
|
16 |
+
_id: convId,
|
17 |
+
sessionId: locals.sessionId
|
18 |
+
});
|
19 |
+
|
20 |
+
if (!conv) {
|
21 |
+
throw error(404, 'Conversation not found');
|
22 |
+
}
|
23 |
+
|
24 |
+
// Todo: validate prompt with zod? or aktype
|
25 |
+
const json = await request.json();
|
26 |
+
|
27 |
+
const messages = [...conv.messages, { from: 'user', content: json.inputs }] satisfies Message[];
|
28 |
+
|
29 |
+
json.inputs = buildPrompt(messages);
|
30 |
+
|
31 |
+
const resp = await fetch(PUBLIC_MODEL_ENDPOINT, {
|
32 |
+
headers: {
|
33 |
+
'Content-Type': request.headers.get('Content-Type') ?? 'application/json',
|
34 |
+
Authorization: `Basic ${HF_TOKEN}`
|
35 |
+
},
|
36 |
+
method: 'POST',
|
37 |
+
body: JSON.stringify(json)
|
38 |
+
});
|
39 |
+
|
40 |
+
const [stream1, stream2] = resp.body!.tee();
|
41 |
+
|
42 |
+
async function saveMessage() {
|
43 |
+
const generated_text = await parseGeneratedText(stream2);
|
44 |
+
|
45 |
+
messages.push({ from: 'assistant', content: generated_text });
|
46 |
+
|
47 |
+
console.log('updating conversation', convId, messages);
|
48 |
+
|
49 |
+
await collections.conversations.updateOne(
|
50 |
+
{
|
51 |
+
_id: convId
|
52 |
+
},
|
53 |
+
{
|
54 |
+
$set: {
|
55 |
+
messages,
|
56 |
+
updatedAt: new Date()
|
57 |
+
}
|
58 |
+
}
|
59 |
+
);
|
60 |
+
}
|
61 |
+
|
62 |
+
saveMessage().catch(console.error);
|
63 |
+
|
64 |
+
// Todo: maybe we should wait for the message to be saved before ending the response - in case of errors
|
65 |
+
return new Response(stream1, {
|
66 |
+
headers: Object.fromEntries(resp.headers.entries()),
|
67 |
+
status: resp.status,
|
68 |
+
statusText: resp.statusText
|
69 |
+
});
|
70 |
+
}
|
71 |
+
|
72 |
+
async function parseGeneratedText(stream: ReadableStream): Promise<string> {
|
73 |
+
const inputs: Uint8Array[] = [];
|
74 |
+
for await (const input of streamToAsyncIterable(stream)) {
|
75 |
+
inputs.push(input);
|
76 |
+
}
|
77 |
+
|
78 |
+
// Merge inputs into a single Uint8Array
|
79 |
+
const completeInput = new Uint8Array(sum(inputs.map((input) => input.length)));
|
80 |
+
let offset = 0;
|
81 |
+
for (const input of inputs) {
|
82 |
+
completeInput.set(input, offset);
|
83 |
+
offset += input.length;
|
84 |
+
}
|
85 |
+
|
86 |
+
// Get last line starting with "data:" and parse it as JSON to get the generated text
|
87 |
+
const message = new TextDecoder().decode(completeInput);
|
88 |
+
|
89 |
+
let lastIndex = message.lastIndexOf('\ndata:');
|
90 |
+
if (lastIndex === -1) {
|
91 |
+
lastIndex = message.indexOf('data');
|
92 |
+
}
|
93 |
+
|
94 |
+
if (lastIndex === -1) {
|
95 |
+
console.error('Could not parse in last message');
|
96 |
+
}
|
97 |
+
|
98 |
+
let lastMessage = message.slice(lastIndex).trim().slice('data:'.length);
|
99 |
+
if (lastMessage.includes('\n')) {
|
100 |
+
lastMessage = lastMessage.slice(0, lastMessage.indexOf('\n'));
|
101 |
+
}
|
102 |
+
|
103 |
+
const res = JSON.parse(lastMessage).generated_text;
|
104 |
+
|
105 |
+
if (typeof res !== 'string') {
|
106 |
+
throw new Error('Could not parse generated text');
|
107 |
+
}
|
108 |
+
|
109 |
+
return res;
|
110 |
+
}
|