geekyrakshit commited on
Commit
cfcefce
·
1 Parent(s): 7b10546

update: chat app + llama guard guardrail

Browse files
application_pages/chat_app.py CHANGED
@@ -29,6 +29,8 @@ def initialize_session_state():
29
  st.session_state.test_guardrails = False
30
  if "llm_model" not in st.session_state:
31
  st.session_state.llm_model = None
 
 
32
 
33
 
34
  def initialize_guardrails():
@@ -89,6 +91,23 @@ def initialize_guardrails():
89
  guardrail_name,
90
  )(should_anonymize=True)
91
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  else:
93
  st.session_state.guardrails.append(
94
  getattr(
 
29
  st.session_state.test_guardrails = False
30
  if "llm_model" not in st.session_state:
31
  st.session_state.llm_model = None
32
+ if "llama_guard_checkpoint_name" not in st.session_state:
33
+ st.session_state.llama_guard_checkpoint_name = ""
34
 
35
 
36
  def initialize_guardrails():
 
91
  guardrail_name,
92
  )(should_anonymize=True)
93
  )
94
+ elif guardrail_name == "PromptInjectionLlamaGuardrail":
95
+ llama_guard_checkpoint_name = st.sidebar.text_input(
96
+ "Checkpoint Name", value=""
97
+ )
98
+ st.session_state.llama_guard_checkpoint_name = llama_guard_checkpoint_name
99
+ st.session_state.guardrails.append(
100
+ getattr(
101
+ importlib.import_module("guardrails_genie.guardrails"),
102
+ guardrail_name,
103
+ )(
104
+ checkpoint=(
105
+ None
106
+ if st.session_state.llama_guard_checkpoint_name == ""
107
+ else st.session_state.llama_guard_checkpoint_name
108
+ )
109
+ )
110
+ )
111
  else:
112
  st.session_state.guardrails.append(
113
  getattr(
guardrails_genie/guardrails/injection/llama_prompt_guardrail.py CHANGED
@@ -76,7 +76,7 @@ class PromptInjectionLlamaGuardrail(Guardrail):
76
  if self.checkpoint is None:
77
  self._model = AutoModelForSequenceClassification.from_pretrained(
78
  self.model_name
79
- ).to(self.device)
80
  else:
81
  api = wandb.Api()
82
  artifact = api.artifact(self.checkpoint.removeprefix("wandb://"))
 
76
  if self.checkpoint is None:
77
  self._model = AutoModelForSequenceClassification.from_pretrained(
78
  self.model_name
79
+ )
80
  else:
81
  api = wandb.Api()
82
  artifact = api.artifact(self.checkpoint.removeprefix("wandb://"))