OmPrakashSingh1704 committed on
Commit 6795bf0
Parent: a122170

Update app.py

Files changed (1)
app.py +14 -89
app.py CHANGED
@@ -1,20 +1,15 @@
 import streamlit as st
-import re,torch
-import json,os
+import re, torch, json, os
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from datetime import datetime
-from huggingface_hub import login
+from huggingface_hub import login, InferenceClient
+import random

 login(token=os.getenv("TOKEN"))

-# Load model and tokenizer
-model = AutoModelForCausalLM.from_pretrained(
-    "google/gemma-2b",
-    torch_dtype="auto",
-    device_map="auto",
-)
+# Initialize the inference client for the Mixtral model
+client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

-tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
 if 'recipe' not in st.session_state:
     st.session_state.recipe = None

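With the local google/gemma-2b pipeline removed, nothing in the hunks shown here still uses AutoModelForCausalLM, AutoTokenizer, or torch, and the newly added import random is likewise unexercised in this diff. A minimal sketch of a trimmed header, assuming (unverified) that no code outside this diff depends on those modules:

import os
import re, json  # kept: may still be used elsewhere in app.py

import streamlit as st
from datetime import datetime
from huggingface_hub import login, InferenceClient

login(token=os.getenv("TOKEN"))

# Serverless call to the hosted model instead of loading weights locally
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")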
@@ -59,88 +54,18 @@ def create_detailed_prompt(user_direction, exclusions, serving_size, difficulty)

 def generate_recipe(user_inputs):
     with st.spinner('Building the perfect recipe...'):
-        # provide_recipe_schema = {
-        #     'type': 'function',
-        #     'function': {
-        #         'name': 'provide_recipe',
-        #         'description': 'Provides a detailed recipe strictly adhering to the user input/specifications, especially ingredient exclusions and the recipe difficulty',
-        #         'parameters': {
-        #             'type': 'object',
-        #             'properties': {
-        #                 'name': {
-        #                     'type': 'string',
-        #                     'description': 'A creative name for the recipe'
-        #                 },
-        #                 'description': {
-        #                     'type': 'string',
-        #                     'description': 'a brief one-sentence description of the provided recipe'
-        #                 },
-        #                 'ingredients': {
-        #                     'type': 'array',
-        #                     'items': {
-        #                         'type': 'object',
-        #                         'properties': {
-        #                             'name': {
-        #                                 'type': 'string',
-        #                                 'description': 'Quantity and name of the ingredient'
-        #                             }
-        #                         }
-        #                     }
-        #                 },
-        #                 'instructions': {
-        #                     'type': 'array',
-        #                     'items': {
-        #                         'type': 'object',
-        #                         'properties': {
-        #                             'step_number': {
-        #                                 'type': 'number',
-        #                                 'description': 'The sequence number of this step'
-        #                             },
-        #                             'instruction': {
-        #                                 'type': 'string',
-        #                                 'description': 'Detailed description of what to do in this step'
-        #                             }
-        #                         }
-        #                     }
-        #                 }
-        #             },
-        #             'required': [
-        #                 'name',
-        #                 'description',
-        #                 'ingredients',
-        #                 'instructions'
-        #             ]
-        #         }
-        #     }
-        # }
         prompt = create_detailed_prompt(user_inputs['user_direction'], user_inputs['exclusions'], user_inputs['serving_size'], user_inputs['difficulty'])
-        # messages = [{"role": "user", "content": prompt}]
-        # tool_section = "\n".join([f"{tool['function']['name']}({json.dumps(tool['function']['parameters'])})" for tool in [provide_recipe_schema]])
-        # text = f"{prompt}\n\nTools:\n{tool_section}"

-        # Tokenize and move to the correct device
-        model_inputs = tokenizer(prompt, return_tensors="pt")
-        # text = tokenizer.apply_chat_template(
-        #     messages,
-        #     tokenize=False,
-        #     add_generation_prompt=True,
-        #     tools=[provide_recipe_schema]
-        # )
-
-        # Tokenize and move to the correct device
-        # model_inputs = tokenizer([text], return_tensors="pt")
-        # torch.cuda.empty_cache()
-        # with torch.no_grad():
-        generated_ids = model.generate(
-            **model_inputs,
-            # max_new_tokens=512,
+        generate_kwargs = dict(
+            temperature=0.9,
+            max_new_tokens=1000,
+            top_p=0.9,
+            repetition_penalty=1.0,
+            do_sample=True,
         )

-        # generated_ids = [
-        #     output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
-        # ]
-
-        st.session_state.recipe = tokenizer.decode(generated_ids[0])
+        response = client.text_generation(prompt, **generate_kwargs)
+        st.session_state.recipe = response['generated_text']
         st.session_state.recipe_saved = False

 def clear_inputs():
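A caveat on the new generation call: huggingface_hub's InferenceClient.text_generation returns the generated text as a plain str when details=False (the default), so indexing the response with ['generated_text'] would raise a TypeError at runtime; a structured object carrying a generated_text field is only returned when details=True is passed. A minimal sketch of the call as it would likely need to look, reusing the commit's own sampling parameters (the helper name is illustrative, not from the commit):

from huggingface_hub import InferenceClient

client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

def generate_recipe_text(prompt: str) -> str:
    # With the default details=False, text_generation already returns
    # the generated string, so no ['generated_text'] indexing is needed.
    return client.text_generation(
        prompt,
        temperature=0.9,
        max_new_tokens=1000,
        top_p=0.9,
        repetition_penalty=1.0,
        do_sample=True,
    )

Inside generate_recipe this collapses to st.session_state.recipe = client.text_generation(prompt, **generate_kwargs).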
@@ -197,7 +122,7 @@ st.session_state.exclusions = st.text_area(
     placeholder="gluten, dairy, nuts, cilantro",
 )

-fancy_exclusions =""
+fancy_exclusions = ""

 if st.session_state.selected_difficulty == "Professional":
     exclude_fancy = st.checkbox(
 