ppsingh commited on
Commit
6d574b3
·
verified ·
1 Parent(s): 6f62fcd

Update auditqa/retriever.py

Browse files
Files changed (1) hide show
  1. auditqa/retriever.py +9 -9
auditqa/retriever.py CHANGED
@@ -9,20 +9,20 @@ model_config = getconfig("model_params.cfg")
9
  def create_filter(reports:list = [],sources:str =None,
10
  subtype:str =None,year:str =None):
11
  if len(reports) == 0:
12
- print("defining filter for sources:{},subtype:{},year:{}".format(sources,subtype,year))
13
  filter=rest.Filter(
14
  must=[rest.FieldCondition(
15
  key="metadata.source",
16
  match=rest.MatchValue(value=sources)
17
  ),
18
  rest.FieldCondition(
19
- key="metadata.subtype",
20
  match=rest.MatchValue(value=subtype)
21
  ),
22
- rest.FieldCondition(
23
- key="metadata.year",
24
- match=rest.MatchAny(any=year)
25
- ),])
26
  else:
27
  print("defining filter for allreports:",reports)
28
  filter=rest.Filter(
@@ -35,13 +35,13 @@ def create_filter(reports:list = [],sources:str =None,
35
  return filter
36
 
37
 
38
- def get_context(vectorstore,query,reports,sources,subtype,year):
39
  # create metadata filter
40
- filter = create_filter(reports=reports,sources=sources,subtype=subtype,year=year)
41
 
42
  # getting context
43
  retriever = vectorstore.as_retriever(search_type="similarity_score_threshold",
44
- search_kwargs={"score_threshold": 0.6,
45
  "k": int(model_config.get('retriever','TOP_K')),
46
  "filter":filter})
47
  # re-ranking the retrieved results
 
9
  def create_filter(reports:list = [],sources:str =None,
10
  subtype:str =None,year:str =None):
11
  if len(reports) == 0:
12
+ print("defining filter for sources:{},subtype:{},year:{}".format(sources,subtype))
13
  filter=rest.Filter(
14
  must=[rest.FieldCondition(
15
  key="metadata.source",
16
  match=rest.MatchValue(value=sources)
17
  ),
18
  rest.FieldCondition(
19
+ key="metadata.filename",
20
  match=rest.MatchValue(value=subtype)
21
  ),
22
+ #rest.FieldCondition(
23
+ # key="metadata.year",
24
+ # match=rest.MatchAny(any=year)
25
+ ])
26
  else:
27
  print("defining filter for allreports:",reports)
28
  filter=rest.Filter(
 
35
  return filter
36
 
37
 
38
+ def get_context(vectorstore,query,reports,sources,subtype):
39
  # create metadata filter
40
+ filter = create_filter(reports=reports,sources=sources,subtype=subtype)
41
 
42
  # getting context
43
  retriever = vectorstore.as_retriever(search_type="similarity_score_threshold",
44
+ search_kwargs={"score_threshold": 0.4,
45
  "k": int(model_config.get('retriever','TOP_K')),
46
  "filter":filter})
47
  # re-ranking the retrieved results