|
from copy import deepcopy |
|
from typing import Optional |
|
|
|
from inseq import load_model |
|
from inseq.commands.attribute_context.attribute_context_args import AttributeContextArgs |
|
from inseq.commands.attribute_context.attribute_context_helpers import ( |
|
AttributeContextOutput, |
|
filter_rank_tokens, |
|
get_filtered_tokens, |
|
) |
|
from inseq.models import HuggingfaceModel |
|
|
|
|
|
def get_formatted_attribute_context_results(
    model: HuggingfaceModel,
    args: AttributeContextArgs,
    output: AttributeContextOutput,
) -> list[tuple[str, Optional[str]]]:
    """Format the results of the context attribution process.

    Args:
        model: Model forwarded to ``get_filtered_tokens`` to tokenize the texts.
        args: Run configuration. Reads ``special_tokens_to_keep``,
            ``has_input_context``, ``has_output_context``,
            ``attribution_std_threshold`` and ``attribution_topk``.
        output: Attribution output. Reads ``output_current``, ``input_context``,
            ``output_context`` and ``cci_scores``.

    Returns:
        A flat list of ``(text, label)`` tuples covering every entry of
        ``output.cci_scores`` in order. ``label`` is ``None`` for plain text and
        section headers, ``"Context sensitive"`` for the generated token at
        ``cti_idx``, and ``"Influential context"`` for the context tokens
        selected by ``filter_rank_tokens``.
    """

    def format_context_comment(
        model: HuggingfaceModel,
        has_other_context: bool,
        special_tokens_to_keep: list[str],
        context: str,
        context_scores: list[float],
        other_context_scores: Optional[list[float]] = None,
        is_target: bool = False,
    ) -> list[tuple[str, Optional[str]]]:
        """Tokenize ``context`` and label the tokens selected by ``filter_rank_tokens``."""
        context_tokens = get_filtered_tokens(
            context,
            model,
            special_tokens_to_keep,
            replace_special_characters=True,
            is_target=is_target,
        )
        context_token_tuples = [(t, None) for t in context_tokens]
        # Work on a copy: extending ``context_scores`` in place would corrupt the
        # score lists stored on ``output`` and leak into subsequent calls.
        scores = list(context_scores)
        if has_other_context and other_context_scores is not None:
            # NOTE(review): the other context's scores are appended after this
            # context's own scores, presumably so the selection statistics cover
            # the whole context -- confirm against ``filter_rank_tokens``.
            scores.extend(other_context_scores)
        context_ranked_tokens, _ = filter_rank_tokens(
            tokens=context_tokens,
            scores=scores,
            std_threshold=args.attribution_std_threshold,
            topk=args.attribution_topk,
        )
        for idx, _, tok in context_ranked_tokens:
            context_token_tuples[idx] = (tok, "Influential context")
        return context_token_tuples

    out = []
    # The generated output is the same for every example, so tokenize it once.
    output_current_tokens = get_filtered_tokens(
        output.output_current,
        model,
        args.special_tokens_to_keep,
        replace_special_characters=True,
        is_target=True,
    )
    for example_idx, cci_out in enumerate(output.cci_scores, start=1):
        curr_output_tokens = [(t, None) for t in output_current_tokens]
        cti_idx = cci_out.cti_idx
        curr_output_tokens[cti_idx] = (
            curr_output_tokens[cti_idx][0],
            "Context sensitive",
        )
        if args.has_input_context:
            input_context_tokens = format_context_comment(
                model,
                args.has_output_context,
                args.special_tokens_to_keep,
                output.input_context,
                cci_out.input_context_scores,
                cci_out.output_context_scores,
            )
        if args.has_output_context:
            # ``format_context_comment`` has no ``context_type`` parameter, so
            # only the arguments it actually declares are passed.
            output_context_tokens = format_context_comment(
                model,
                args.has_input_context,
                args.special_tokens_to_keep,
                output.output_context,
                cci_out.output_context_scores,
                cci_out.input_context_scores,
                is_target=True,
            )
        out += [
            # Blank-line separator between consecutive examples.
            ("\n\n" if example_idx > 1 else "", None),
            (
                f"#{example_idx}.\nGenerated output:\t",
                None,
            ),
        ]
        out += curr_output_tokens
        if args.has_input_context:
            out += [("\nInput context:\t", None)]
            out += input_context_tokens
        if args.has_output_context:
            # Newline escape, matching the "Input context" header above.
            out += [("\nOutput context:\t", None)]
            out += output_context_tokens
    return out
|
|
|
|
|
def get_tuples_from_output(output: AttributeContextOutput):
    """Reload the model described by ``output.info`` and format ``output``.

    The model and tokenizer keyword arguments are deep-copied before being
    handed to ``load_model``. The result is whatever
    ``get_formatted_attribute_context_results`` returns when called with the
    loaded model, ``output.info`` as arguments and ``output`` itself.
    """
    run_info = output.info
    model_id = run_info.model_name_or_path
    method = run_info.attribution_method
    # Copies keep the kwargs stored on ``output.info`` separate from the
    # objects handed to ``load_model``.
    model_kwargs_copy = deepcopy(run_info.model_kwargs)
    tokenizer_kwargs_copy = deepcopy(run_info.tokenizer_kwargs)
    loaded_model = load_model(
        model_id,
        method,
        model_kwargs=model_kwargs_copy,
        tokenizer_kwargs=tokenizer_kwargs_copy,
    )
    return get_formatted_attribute_context_results(loaded_model, run_info, output)
|
|