Issues Generating Paired Sequences

#1
by ollieturnbull - opened

Hi,
I'm trying to generate paired sequences using the sampling parameters described in the paper, but generation stops after the heavy chain is produced (i.e. an <|endoftext|> token is emitted). How can I ensure that full paired sequences are generated?
Thanks,
Ollie

Generation code:

sample = model.generate(
    tokenizer("Ḣ", return_tensors='pt')['input_ids'][:, :-1],  # drop the trailing special token added by the tokenizer
    max_new_tokens=256,
    top_p=0.95,
    temperature=1.0,
    do_sample=True,
)
ideasbyjin (Alchemab Therapeutics org)

Thanks Ollie - apologies if there was any confusion from the paper (which was benchmarked using FAbCon-large). To get full paired sequences, you can either append the light-chain token type after the heavy-chain generation, or use a LogitsProcessor. Hope that helps!
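For the LogitsProcessor route, here is a minimal sketch (not from the thread itself): the ForceLightChain name is my own, and it assumes the light-chain token "Ḷ" encodes to a single id. Until a light chain has been started, it redirects any logit mass the model assigns to <|endoftext|> onto the light-chain token, so a single generate call can produce both chains (the token budget is doubled to cover them).

import torch
from transformers import LogitsProcessor, LogitsProcessorList

class ForceLightChain(LogitsProcessor):
    # Until a sequence contains the light-chain token, move any logit mass
    # assigned to <|endoftext|> onto the light-chain token instead, so
    # sampling starts the light chain rather than stopping early.
    def __init__(self, eos_token_id, light_token_id):
        self.eos_token_id = eos_token_id
        self.light_token_id = light_token_id

    def __call__(self, input_ids, scores):
        # Rows (sequences) that do not yet contain the light-chain token
        no_light = ~(input_ids == self.light_token_id).any(dim=-1)
        scores[no_light, self.light_token_id] = torch.maximum(
            scores[no_light, self.light_token_id],
            scores[no_light, self.eos_token_id],
        )
        scores[no_light, self.eos_token_id] = float("-inf")
        return scores

light_id = tokenizer.convert_tokens_to_ids("Ḷ")  # assumes "Ḷ" is a single token
sample = model.generate(
    tokenizer("Ḣ", return_tensors='pt')['input_ids'][:, :-1],
    max_new_tokens=512,  # budget for both chains
    top_p=0.95,
    temperature=1.0,
    do_sample=True,
    logits_processor=LogitsProcessorList(
        [ForceLightChain(tokenizer.eos_token_id, light_id)]
    ),
)

Once the light-chain token is present, <|endoftext|> is allowed again, so generation still terminates normally after the light chain.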

ideasbyjin changed discussion status to closed

Here’s the code I used. Hope it helps!

from transformers import PreTrainedTokenizerFast, FalconForCausalLM
import torch

# Load tokenizer and model
tokenizer = PreTrainedTokenizerFast.from_pretrained("alchemab/fabcon-medium")
model = FalconForCausalLM.from_pretrained("alchemab/fabcon-medium")

# Generate heavy chain sequence
heavy_input = tokenizer("Ḣ", return_tensors='pt')['input_ids']
heavy_output = model.generate(
    heavy_input,
    max_new_tokens=256,
    do_sample=True,
    top_k=50,
    temperature=1.0,
    pad_token_id=tokenizer.eos_token_id
)

# Append the light-chain token and generate the light chain.
# Note: heavy_output ends with the model's <|endoftext|>, and the encoded
# "Ḷ" may carry extra tokenizer-added special tokens (cf. the [:, :-1]
# slice in the question above), so trim it if your tokenizer adds them.
light_input = torch.cat([heavy_output, tokenizer("Ḷ", return_tensors='pt')['input_ids']], dim=-1)
full_output = model.generate(
    light_input,
    max_new_tokens=256,
    do_sample=True,
    top_k=50,
    temperature=1.0,
    pad_token_id=tokenizer.eos_token_id
)

# Decode the generated sequence
decoded_seq = tokenizer.decode(full_output[0], skip_special_tokens=False)

print(decoded_seq)
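If you need the chains separately afterwards, here is a rough post-processing sketch; it assumes the decoded string looks like "Ḣ<heavy><|endoftext|>Ḷ<light><|endoftext|>", so adjust the markers if your tokenizer renders them differently.

# Split the decoded string into heavy and light chains.
heavy_part, _, light_part = decoded_seq.partition("Ḷ")
heavy_seq = heavy_part.replace("Ḣ", "").replace("<|endoftext|>", "").strip()
light_seq = light_part.replace("<|endoftext|>", "").strip()
print("Heavy chain:", heavy_seq)
print("Light chain:", light_seq)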
