|
import os
from collections.abc import Iterator
from typing import Any, Optional

import gradio as gr
from huggingface_hub import InferenceClient
from huggingface_hub.inference._generated.types.chat_completion import ChatCompletionStreamOutput
|
|
|
# NOTE(review): MODEL is not referenced anywhere in the visible code. Requests go
# to the zephyr-7b-beta checkpoint below, while the UI header text advertises a
# fine-tuned Llama-3.1 model. Confirm which model is actually intended.
MODEL = "nomiChroma3.1"

# Hugging Face Hub model id that every chat completion request is sent to.
_CHAT_MODEL_ID = "HuggingFaceH4/zephyr-7b-beta"
client = InferenceClient(_CHAT_MODEL_ID)
|
|
|
def _extract_delta(chunk: Any) -> tuple[Optional[str], Optional[str]]:
    """Return ``(content, finish_reason)`` for one streamed completion chunk.

    Accepts the three chunk shapes ``respond`` handles:

    * an object exposing ``.choices`` (e.g. ``ChatCompletionStreamOutput``);
    * a plain ``dict`` with a ``"choices"`` list;
    * a raw ``str`` (returned as content only if it is not whitespace-only).

    Missing or empty fields, including an empty ``choices`` list or a ``None``
    delta, map to ``None`` instead of raising.
    """
    if isinstance(chunk, str):
        return (chunk if chunk.strip() else None), None

    if isinstance(chunk, dict):
        choices = chunk.get("choices") or [{}]
        first = choices[0] or {}
        delta = first.get("delta") or {}
        return delta.get("content"), first.get("finish_reason")

    choices = getattr(chunk, "choices", None)
    if not choices:
        # Some providers send trailing chunks (e.g. usage-only) with no choices.
        return None, None
    first = choices[0]
    delta = getattr(first, "delta", None)
    content = getattr(delta, "content", None) if delta is not None else None
    return content, getattr(first, "finish_reason", None)


def respond(
    message: str,
    chat_history: list[tuple[str, str]],
    max_tokens: int,
    temperature: float,
    top_p: float,
) -> Iterator[tuple[list[tuple[str, str]], str]]:
    """Stream a model reply to *message* and yield the updated chat history.

    Args:
        message: The user's new query.
        chat_history: Prior ``(user, assistant)`` turns from the Chatbot
            component. ``None`` is treated as an empty history. The list
            passed in is never mutated; a copy is extended instead.
        max_tokens: Upper bound on generated tokens, forwarded to the client.
        temperature: Sampling temperature, forwarded to the client.
        top_p: Nucleus-sampling threshold, forwarded to the client.

    Yields:
        ``(history, "")`` pairs. The history's last turn holds the reply
        streamed so far. The empty string clears the input textbox.

        * If *message* is blank, yields the unchanged history once and never
          calls the model.
        * If nothing is generated, the last turn holds an apology message.
        * If the request fails, the error text is shown after any partial
          reply that was already streamed.
    """
    system_message = "You are a maritime legal assistant with expertise strictly in Indian maritime law. Provide detailed legal advice and information within word limit based on Indian maritime legal principles and regulations."

    history = list(chat_history or [])

    # Nothing to ask: don't spend a model call or record an empty turn.
    if not message or not message.strip():
        yield history, ""
        return

    messages = [{"role": "system", "content": system_message}]
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        # A turn interrupted before any text arrived still holds None; send ""
        # so the API gets a string and user/assistant roles keep alternating.
        messages.append({"role": "assistant", "content": assistant_msg or ""})
    messages.append({"role": "user", "content": message})

    # Placeholder turn shown while the reply streams in.
    history.append((message, None))
    response = ""

    try:
        stream = client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        )
        for chunk in stream:
            try:
                content, finish_reason = _extract_delta(chunk)
            except Exception as e:
                # Best-effort: skip a malformed chunk rather than abort the reply.
                print(f"Error processing chunk: {e}")
                continue

            if content:
                response += content
                history[-1] = (message, response)
                yield history, ""

            if finish_reason == "stop":
                break

        if not response:
            history[-1] = (message, "I apologize, but I couldn't generate a response. Please try again.")

        yield history, ""

    except Exception as e:
        error_msg = f"An error occurred: {str(e)}"
        # Keep whatever was already streamed so the user doesn't lose it.
        history[-1] = (message, f"{response}\n\n{error_msg}" if response else error_msg)
        yield history, ""
|
|
|
|
|
|
|
|
|
# Stylesheet passed to gr.Blocks(css=...) when the layout is built.
# .header-*, .sidebar, .sidebar-title, .example-queries, .example-query-button
# and .chat-container match class names used in the header HTML and the
# elem_classes given to components. .gr-button, .message.user/.message.bot and
# .prose presumably target Gradio's own generated markup.
# NOTE(review): Gradio's internal class names differ across major versions —
# confirm those selectors still match the installed Gradio release.
custom_css = """

/* Global styles */

.gradio-container {

background-color: #1a365d !important;

font-family: 'Inter', -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Oxygen-Sans, Ubuntu, Cantarell, "Helvetica Neue", sans-serif !important;

}

/* Header styling */

.header-container {

text-align: center;

padding: 1rem 0;

margin-bottom: 1rem;

border-bottom: 2px solid rgba(255, 255, 255, 0.1);

}

.header-title {

color: #ffffff;

font-size: 2rem;

margin-bottom: 0.3rem;

font-family: inherit;

}

.header-subtitle {

color: #e6f3ff;

font-size: 1rem;

margin-bottom: 0.2rem;

font-family: inherit;

}

/* Sidebar styling */

.sidebar {

background: #e6f3ff !important;

border-radius: 8px !important;

padding: 15px !important;

border: 1px solid rgba(176, 226, 255, 0.2) !important;

height: fit-content !important;

}

.sidebar-title {

color: #1a365d !important;

font-size: 1.1rem !important;

margin-bottom: 0.8rem !important;

padding-bottom: 0.4rem !important;

border-bottom: 2px solid rgba(26, 54, 93, 0.1) !important;

font-family: inherit !important;

}

/* Example queries styling */

.example-queries {

margin-bottom: 1.5rem !important;

}

.example-query-button {

background-color: #cce7ff !important;

color: #1a365d !important;

border: none !important;

margin: 3px 0 !important;

padding: 6px 10px !important;

border-radius: 4px !important;

text-align: left !important;

width: 100% !important;

cursor: pointer !important;

transition: background-color 0.3s ease !important;

font-size: 0.9rem !important;

font-family: inherit !important;

}

.example-query-button:hover {

background-color: #b0e2ff !important;

}

/* Chat container */

.chat-container {

background: #e6f3ff !important;

border-radius: 8px !important;

padding: 15px !important;

height: 300px !important;

overflow-y: auto !important;

border: 1px solid rgba(176, 226, 255, 0.2) !important;

backdrop-filter: blur(10px) !important;

font-family: inherit !important;

}

/* Message styling */

.message.user, .message.bot {

padding: 8px 12px !important;

margin: 6px 0 !important;

border-radius: 6px !important;

color: #1a365d !important;

font-size: 0.9rem !important;

font-family: inherit !important;

line-height: 1.5 !important;

}

.message.user {

background-color: #cce7ff !important;

}

.message.bot {

background-color: #e6f3ff !important;

}

/* Chat input styling */

textarea {

background-color: #e6f3ff !important;

border: 1px solid rgba(176, 226, 255, 0.3) !important;

border-radius: 6px !important;

padding: 8px !important;

color: #1a365d !important;

font-size: 0.9rem !important;

font-family: inherit !important;

}

/* Button styling */

.gr-button {

background-color: #cce7ff !important;

color: #1a365d !important;

border: none !important;

padding: 6px 12px !important;

font-size: 0.9rem !important;

font-family: inherit !important;

border-radius: 4px !important;

}

.gr-button:hover {

background-color: #1a365d !important;

color: #ffffff !important;

}

/* Markdown text styling */

.prose {

font-family: inherit !important;

}

/* All text elements */

p, span, div {

font-family: inherit !important;

}

"""
|
|
|
def handle_example_click(example_query: str):
    """Load *example_query* into the input box and start a fresh conversation.

    Returns:
        A pair ``(textbox_value, chat_history)``. The textbox receives
        *example_query* unchanged; the chat history is a new empty list.
    """
    fresh_history: list = []
    return (example_query, fresh_history)
|
|
|
|
|
# UI: header banner, then a row with a sidebar (example queries + generation
# settings) and the chat pane; event wiring follows the layout.
with gr.Blocks(css=custom_css, theme=gr.themes.Base()) as demo:
    header_html = """

<div class="header-container">

<h1 class="header-title">Maritime Legal Compliance</h1>

<p class="header-subtitle">AI-powered assistance for Indian maritime law queries</p>

<p class="header-subtitle">This chatbot uses Fine-tuned LLAMA-3.1 model personalised specifically to provide assistance with Indian maritime legal queries.</p>

</div>

"""
    gr.HTML(header_html)

    # Plain data (no components), so defining it early doesn't affect layout.
    example_queries = [
        "What are the key regulations governing ports in India?",
        "Explain the concept of cabotage in Indian maritime law.",
        "What are the legal requirements for registering a vessel in India?",
        "What are the environmental regulations for ships in Indian waters?",
        "Explain the Maritime Labour Convention implementation in India.",
        "What are the rules for coastal cargo transportation in India?",
    ]

    with gr.Row():
        with gr.Column(scale=1, elem_classes="sidebar"):
            gr.Markdown("### Example Queries", elem_classes="sidebar-title")

            with gr.Column(elem_classes="example-queries"):
                example_buttons = []
                for example_text in example_queries:
                    example_buttons.append(
                        gr.Button(example_text, elem_classes="example-query-button")
                    )

            gr.Markdown("### Configuration", elem_classes="sidebar-title")
            max_tokens = gr.Slider(
                minimum=1, maximum=2048, value=512, step=1, label="Response Length"
            )
            temperature = gr.Slider(
                minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Creativity Level"
            )
            top_p = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Response Focus"
            )

        with gr.Column(scale=3):
            chatbot = gr.Chatbot(height=300, elem_classes="chat-container")
            msg = gr.Textbox(
                show_label=False,
                placeholder="Type your maritime law query here...",
                container=False,
            )
            with gr.Row():
                submit = gr.Button("Send", variant="primary")
                clear = gr.Button("Clear")

    # Every chat trigger feeds the same inputs to respond and updates the same outputs.
    chat_inputs = [msg, chatbot, max_tokens, temperature, top_p]
    chat_outputs = [chatbot, msg]

    # Registration order (Enter key, Send button, Clear, examples) is kept as-is.
    for chat_trigger in (msg.submit, submit.click):
        chat_trigger(fn=respond, inputs=chat_inputs, outputs=chat_outputs)

    clear.click(fn=lambda: ([], ""), inputs=None, outputs=chat_outputs, queue=False)

    # Clicking an example first loads it (and resets history), then asks the model.
    for example_button in example_buttons:
        example_button.click(
            fn=handle_example_click,
            inputs=[example_button],
            outputs=[msg, chatbot],
            queue=False,
        ).then(fn=respond, inputs=chat_inputs, outputs=chat_outputs)


if __name__ == "__main__":
    demo.launch()