# pecore / utils.py
# gsarti's picture
# Updated
# cf3d1b1
# raw
# history blame
# 3.81 kB
from copy import deepcopy
from typing import Optional
from inseq import load_model
from inseq.commands.attribute_context.attribute_context_args import AttributeContextArgs
from inseq.commands.attribute_context.attribute_context_helpers import (
AttributeContextOutput,
filter_rank_tokens,
get_filtered_tokens,
)
from inseq.models import HuggingfaceModel
def get_formatted_attribute_context_results(
    model: HuggingfaceModel,
    args: AttributeContextArgs,
    output: AttributeContextOutput,
) -> list[tuple[str, Optional[str]]]:
    """Format the results of the context attribution process.

    Builds a flat list of ``(text, label)`` tuples. For every entry in
    ``output.cci_scores`` the list contains a header, the tokens of the
    generated output (the token at ``cti_idx`` is labelled
    ``"Context sensitive"``), and, depending on ``args.has_input_context`` /
    ``args.has_output_context``, the tokens of each context, where the tokens
    selected by ``filter_rank_tokens`` are labelled ``"Influential context"``.
    Every other tuple carries ``None`` as its label.

    Args:
        model: Model forwarded to ``get_filtered_tokens`` for tokenization.
        args: Attribution arguments; this function reads
            ``special_tokens_to_keep``, ``has_input_context``,
            ``has_output_context``, ``attribution_std_threshold`` and
            ``attribution_topk``.
        output: Attribution output; this function reads ``output_current``,
            ``input_context``, ``output_context`` and ``cci_scores``.

    Returns:
        The list of ``(text, label)`` tuples described above.
    """

    def format_context_comment(
        model: HuggingfaceModel,
        has_other_context: bool,
        special_tokens_to_keep: list[str],
        context: str,
        context_scores: list[float],
        other_context_scores: Optional[list[float]] = None,
        is_target: bool = False,
    ) -> list[tuple[str, Optional[str]]]:
        """Tokenize ``context`` and label the tokens picked by ``filter_rank_tokens``."""
        context_tokens = get_filtered_tokens(
            context,
            model,
            special_tokens_to_keep,
            replace_special_characters=True,
            is_target=is_target,
        )
        context_token_tuples = [(t, None) for t in context_tokens]
        scores = context_scores
        if has_other_context and other_context_scores is not None:
            # Build a NEW sequence instead of using ``+=``: in-place extension
            # would modify the caller's score list, corrupting the scores seen
            # by the call for the other context and by any later invocation.
            # NOTE(review): the combined scores are presumably what the
            # thresholding in ``filter_rank_tokens`` operates on -- confirm.
            scores = context_scores + other_context_scores
        context_ranked_tokens, _ = filter_rank_tokens(
            tokens=context_tokens,
            scores=scores,
            std_threshold=args.attribution_std_threshold,
            topk=args.attribution_topk,
        )
        for idx, _, tok in context_ranked_tokens:
            context_token_tuples[idx] = (tok, "Influential context")
        return context_token_tuples

    out = []
    # The generated output is the same for every example, so tokenize it once.
    output_current_tokens = get_filtered_tokens(
        output.output_current,
        model,
        args.special_tokens_to_keep,
        replace_special_characters=True,
        is_target=True,
    )
    for example_idx, cci_out in enumerate(output.cci_scores, start=1):
        curr_output_tokens = [(t, None) for t in output_current_tokens]
        cti_idx = cci_out.cti_idx
        curr_output_tokens[cti_idx] = (
            curr_output_tokens[cti_idx][0],
            "Context sensitive",
        )
        if args.has_input_context:
            input_context_tokens = format_context_comment(
                model,
                args.has_output_context,
                args.special_tokens_to_keep,
                output.input_context,
                cci_out.input_context_scores,
                cci_out.output_context_scores,
            )
        if args.has_output_context:
            # No ``context_type`` keyword here: the helper does not accept it,
            # and passing it raised TypeError for every output context.
            output_context_tokens = format_context_comment(
                model,
                args.has_input_context,
                args.special_tokens_to_keep,
                output.output_context,
                cci_out.output_context_scores,
                cci_out.input_context_scores,
                is_target=True,
            )
        out += [
            # Blank separator between consecutive examples (none before the first).
            ("\n\n" if example_idx > 1 else "", None),
            (
                f"#{example_idx}.\nGenerated output:\t",
                None,
            ),
        ]
        out += curr_output_tokens
        if args.has_input_context:
            out += [("\nInput context:\t", None)]
            out += input_context_tokens
        if args.has_output_context:
            # Newline, mirroring the input-context header (was a literal backslash).
            out += [("\nOutput context:\t", None)]
            out += output_context_tokens
    return out
def get_tuples_from_output(output: AttributeContextOutput):
    """Load the model described by ``output.info`` and format ``output`` with it.

    The model is loaded with ``load_model`` from the name/path, attribution
    method and keyword arguments stored on ``output.info``; the result of
    ``get_formatted_attribute_context_results`` is returned unchanged.
    """
    run_info = output.info
    name_or_path = run_info.model_name_or_path
    method = run_info.attribution_method
    # ``load_model`` receives deep copies, not the dictionaries stored on
    # ``output.info`` themselves.
    model_kwargs_copy = deepcopy(run_info.model_kwargs)
    tokenizer_kwargs_copy = deepcopy(run_info.tokenizer_kwargs)
    attribution_model = load_model(
        name_or_path,
        method,
        model_kwargs=model_kwargs_copy,
        tokenizer_kwargs=tokenizer_kwargs_copy,
    )
    return get_formatted_attribute_context_results(attribution_model, run_info, output)