# Markdown description shown at the top of the demo page.
# NOTE(review): the Colab badge links to examples/pal.ipynb; presumably this
# should point at the notebook for this example -- confirm before changing.
desc = """
### Typed Extraction

Information extraction that is automatically generated from a typed specification. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/pal.ipynb)

(Novel to MiniChain)
"""

# $

import json
from dataclasses import dataclass, is_dataclass, fields
from enum import Enum
from typing import List, Type, Dict, Any, get_origin, get_args

from jinja2 import select_autoescape, FileSystemLoader, Environment
from minichain import prompt, show, OpenAI, transform


def enum(x: Type[Enum]) -> Dict[str, int]:
    """Return a mapping from each member name of the Enum class `x` to its value."""
    return {e.name: e.value for e in x}


def walk(x: Any) -> Any:
    """Recursively convert a type annotation into a plain, template-friendly value.

    * ``List[T]``        -> ``{"_t_": "list", "t": walk(T)}``
    * an ``Enum`` class  -> ``{member name: member value}`` (see `enum`)
    * a dataclass        -> ``{field name: walk(field type)}``
    * anything else      -> the type's ``__name__``
    """
    # For a parameterized generic such as List[Stat], test the origin (list);
    # for a plain class, get_origin() is None and the class itself is tested.
    origin = get_origin(x)
    container = x if origin is None else origin
    if issubclass(container, List):
        return {"_t_": "list", "t": walk(get_args(x)[0])}
    if issubclass(x, Enum):
        return enum(x)
    if is_dataclass(x):
        return {y.name: walk(y.type) for y in fields(x)}
    return x.__name__


# Jinja environment that loads templates from the current working directory.
env = Environment(
    loader=FileSystemLoader("."),
    autoescape=select_autoescape(),
    extensions=["jinja2_highlight.HighlightExtension"],
)


def type_to_prompt(out: type) -> str:
    """Render "type_prompt.pmpt.tpl" with the walked form of `out` bound to ``typ``."""
    tmp = env.get_template("type_prompt.pmpt.tpl")
    d = walk(out)
    return tmp.render({"typ": d})


# Data specification

# +
class StatType(Enum):
    """Kinds of statistic that a `Stat` can carry."""

    POINTS = 1
    REBOUNDS = 2
    ASSISTS = 3


@dataclass
class Stat:
    """An integer value tagged with the `StatType` it measures."""

    value: int
    stat: StatType


@dataclass
class Player:
    """A player string together with a list of `Stat` entries."""

    player: str
    stats: List[Stat]
# -


# Prompt step: hands the passage and the rendered description of `Player`
# to the model and returns whatever `model.stream` produces.
@prompt(OpenAI(), template_file="stats.pmpt.tpl")
def stats(model, passage):
    return model.stream(dict(passage=passage, typ=type_to_prompt(Player)))


# Transform step: parses `s` as JSON and builds one `Player` per element by
# keyword expansion. Nested "stats" entries are left exactly as json.loads
# yields them; they are not converted into `Stat` instances here.
@transform()
def to_data(s: str):
    return [Player(**j) for j in json.loads(s)]

# $

# Both paths below are relative to the current working directory.
# Context managers make sure the file handles are closed after reading.
with open("sixers.txt") as _article_file:
    article = _article_file.read()

# The demo displays this script's own source: the text between the two
# marker comments above, with surrounding whitespace and "#" stripped.
with open("stats.py", "r") as _source_file:
    _demo_code = _source_file.read().split("$")[1].strip().strip("#").strip()

gradio = show(
    lambda passage: to_data(stats(passage)),
    examples=[article],
    subprompts=[stats],
    out_type="json",
    description=desc,
    code=_demo_code,
)

if __name__ == "__main__":
    gradio.queue().launch()


# ExtractionPrompt().show({"passage": "Harden had 10 rebounds."},
#                         '[{"player": "Harden", "stats": {"value": 10, "stat": 2}}]')

# # View the run log.

# minichain.show_log("bash.log")