File size: 4,187 Bytes
7b856a8
 
 
4e3dc76
f5ec828
872b692
f5ec828
 
 
 
872b692
f5ec828
 
 
 
 
872b692
4e3dc76
f5ec828
4e3dc76
 
 
 
 
7b856a8
4e3dc76
 
 
f5ec828
4e3dc76
 
 
7b856a8
 
 
 
f5ec828
872b692
f5ec828
7b856a8
 
 
 
4e3dc76
 
7b856a8
 
 
 
872b692
7b856a8
 
 
 
 
 
 
 
f5ec828
872b692
f5ec828
 
7b856a8
 
 
 
 
 
 
872b692
4e3dc76
 
 
 
 
 
 
f5ec828
872b692
7b856a8
4e3dc76
7b856a8
4e3dc76
7b856a8
4e3dc76
 
 
 
 
7b856a8
 
 
 
f5ec828
872b692
7b856a8
4e3dc76
 
 
 
 
 
 
 
 
 
 
 
f5ec828
872b692
f5ec828
7b856a8
 
4e3dc76
 
 
7b856a8
 
 
f5ec828
 
872b692
f5ec828
 
 
7b856a8
 
 
f5ec828
872b692
f5ec828
7b856a8
 
4e3dc76
 
 
 
 
7b856a8
 
 
 
f5ec828
872b692
7b856a8
f5ec828
7b856a8
f5ec828
7b856a8
4e3dc76
 
 
 
 
 
 
 
7b856a8
 
 
 
f5ec828
 
 
7b856a8
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f7e3a8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setup: %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "# The second line clones the repository and copies the contents of\n",
    "# MiniChain/examples into the current working directory.\n",
    "%pip install -q git+https://github.com/srush/MiniChain\n",
    "!git clone https://github.com/srush/MiniChain; cp -fr MiniChain/examples/* . "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49443595",
   "metadata": {
    "lines_to_next_cell": 2,
    "tags": [
     "hide_inp"
    ]
   },
   "outputs": [],
   "source": [
    "# Markdown description of this example; passed to `show(...)` as `description`.\n",
    "desc = \"\"\"\n",
    "### Question Answering with Retrieval\n",
    "\n",
    "Chain that answers questions with embedding based retrieval. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/qa.ipynb)\n",
    "\n",
    "(Adapted from [OpenAI Notebook](https://github.com/openai/openai-cookbook/blob/main/examples/Question_answering_using_embeddings.ipynb).)\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5183ea7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "import numpy as np\n",
    "from minichain import prompt, show, OpenAIEmbed, OpenAI\n",
    "from manifest import Manifest"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2bf59f0d",
   "metadata": {},
   "source": [
    "We use Hugging Face Datasets as the database by assigning\n",
    "a FAISS index."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f371a85e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the dataset saved under \"olympics.data\" and attach a FAISS index\n",
    "# built over its \"embeddings\" column.\n",
    "olympics = datasets.load_from_disk(\"olympics.data\")\n",
    "olympics.add_faiss_index(column=\"embeddings\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1099002",
   "metadata": {},
   "source": [
    "Fast KNN retrieval prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6881ae0e",
   "metadata": {
    "lines_to_next_cell": 1
   },
   "outputs": [],
   "source": [
    "@prompt(OpenAIEmbed())\n",
    "def get_neighbors(model, inp, k):\n",
    "    \"\"\"Return the \"content\" field of the `k` dataset rows nearest to `inp`.\n",
    "\n",
    "    `model` is called on `inp` to obtain the query embedding, which is then\n",
    "    looked up in the \"embeddings\" FAISS index of `olympics`.\n",
    "    \"\"\"\n",
    "    # FAISS operates on float32 vectors. Cast explicitly instead of relying on\n",
    "    # numpy's inferred dtype (float64 for plain Python floats), which some\n",
    "    # FAISS builds reject.\n",
    "    query = np.asarray(model(inp), dtype=np.float32)\n",
    "    nearest = olympics.get_nearest_examples(\"embeddings\", query, k)\n",
    "    return nearest.examples[\"content\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59cc1355",
   "metadata": {
    "lines_to_next_cell": 1
   },
   "outputs": [],
   "source": [
    "@prompt(OpenAI(), template_file=\"qa.pmpt.tpl\")\n",
    "def get_result(model, query, neighbors):\n",
    "    \"\"\"Call `model` with the question and the retrieved documents.\n",
    "\n",
    "    The values are passed as a dict with keys \"question\" and \"docs\"; the\n",
    "    prompt is configured with the template file \"qa.pmpt.tpl\".\n",
    "    \"\"\"\n",
    "    template_args = {\"question\": query, \"docs\": neighbors}\n",
    "    return model(template_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb2f1101",
   "metadata": {},
   "outputs": [],
   "source": [
    "def qa(query):\n",
    "    \"\"\"Answer `query`: look up its 3 nearest neighbors, then prompt with them.\"\"\"\n",
    "    return get_result(query, get_neighbors(query, 3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f70bac7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abdfcd87",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example questions; handed to `show(...)` as `examples`.\n",
    "questions = [\n",
    "    \"Who won the 2020 Summer Olympics men's high jump?\",\n",
    "    \"Why was the 2020 Summer Olympics originally postponed?\",\n",
    "    \"In the 2020 Summer Olympics, how many gold medals did the country which won the most medals win?\",\n",
    "    \"What is the total number of medals won by France?\",\n",
    "    \"What is the tallest mountain in the world?\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddce3ec3",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "# Wrap `qa` with `show`, passing both prompts as subprompts; the result is\n",
    "# launched only when this code runs as the main module.\n",
    "gradio = show(\n",
    "    qa,\n",
    "    examples=questions,\n",
    "    subprompts=[get_neighbors, get_result],\n",
    "    description=desc,\n",
    ")\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    gradio.launch()"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "tags,-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}