yue-jobs / wrapper.py
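"""Thin CLI wrapper around YuE's two-stage infer.py.

Genre tags and lyrics arrive as plain command-line strings, are written to
temporary text files (infer.py expects file paths), and inference runs as a
subprocess whose stdout/stderr and generated files are echoed back to the caller.
"""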
import argparse
import tempfile
import os
import subprocess
import sys


def main():
    parser = argparse.ArgumentParser(description='Run YuE model with direct input')
    parser.add_argument('--genre', type=str, required=True, help='Genre tags for the music')
    parser.add_argument('--lyrics', type=str, required=True, help='Lyrics for the music')
    parser.add_argument('--run_n_segments', type=int, default=2, help='Number of segments to process')
    parser.add_argument('--stage2_batch_size', type=int, default=4, help='Batch size for stage 2')
    parser.add_argument('--max_new_tokens', type=int, default=3000, help='Maximum number of new tokens')
    parser.add_argument('--cuda_idx', type=int, default=0, help='CUDA device index')
    args = parser.parse_args()

    print("\n=== Starting YuE Inference ===")
    current_dir = os.path.dirname(os.path.abspath(__file__))
    inference_dir = current_dir
    print(f"\nCurrent directory: {os.getcwd()}")
    print(f"Script directory: {current_dir}")
    print(f"Inference directory: {inference_dir}")

    # Create temporary files for genre and lyrics
    with tempfile.NamedTemporaryFile(mode='w', delete=False) as genre_file:
        genre_file.write(args.genre)
        genre_path = genre_file.name
    print(f"\nCreated genre file at: {genre_path}")
    print(f"Genre content: {args.genre}")

    with tempfile.NamedTemporaryFile(mode='w', delete=False) as lyrics_file:
        lyrics_file.write(args.lyrics)
        lyrics_path = lyrics_file.name
    print(f"\nCreated lyrics file at: {lyrics_path}")
    print(f"Lyrics content: {args.lyrics}")

    output_dir = '/home/user/app/output'

    try:
        # Go to inference directory for running the script
        os.chdir(inference_dir)
        print(f"\nChanged working directory to: {os.getcwd()}")

        infer_script = 'infer.py'
        print(f"\nInference script path: {infer_script}")
        print(f"Script exists: {os.path.exists(infer_script)}")

        tokenizer_path = './mm_tokenizer_v0.2_hf/tokenizer.model'
        print(f"\nChecking tokenizer at: {tokenizer_path}")
        print(f"Tokenizer exists: {os.path.exists(tokenizer_path)}")
        if os.path.exists('./mm_tokenizer_v0.2_hf'):
            print("Tokenizer directory contents:")
            print(os.listdir('./mm_tokenizer_v0.2_hf'))
        else:
            print("WARNING: Tokenizer directory not found!")

        print("\nExecuting inference command...")
        command = [
            sys.executable, infer_script,  # use the current interpreter instead of a bare 'python'
            '--stage1_model', 'm-a-p/YuE-s1-7B-anneal-en-cot',
            '--stage2_model', 'm-a-p/YuE-s2-1B-general',
            '--genre_txt', genre_path,
            '--lyrics_txt', lyrics_path,
            '--run_n_segments', str(args.run_n_segments),
            '--stage2_batch_size', str(args.stage2_batch_size),
            '--output_dir', output_dir,
            '--cuda_idx', str(args.cuda_idx),
            '--max_new_tokens', str(args.max_new_tokens)
        ]
        print(f"Command: {' '.join(command)}")

        result = subprocess.run(command,
                                check=True,
                                capture_output=True,
                                text=True)
print("\nInference completed successfully!")
print("\nStdout:")
print(result.stdout)
if result.stderr:
print("\nStderr:")
print(result.stderr)
print(f"\nOutput directory: {output_dir}")
if os.path.exists(output_dir):
print("Generated files:")
for file in os.listdir(output_dir):
file_path = os.path.join(output_dir, file)
print(f"- {file_path} ({os.path.getsize(file_path)} bytes)")
else:
print("WARNING: Output directory does not exist!")
    except subprocess.CalledProcessError as e:
        print("\nError running inference script:")
        print(f"Exit code: {e.returncode}")
        print("\nStdout:")
        print(e.stdout)
        print("\nStderr:")
        print(e.stderr)
        raise
    finally:
        # Clean up temporary files
        print("\nCleaning up temporary files...")
        os.unlink(genre_path)
        os.unlink(lyrics_path)


if __name__ == '__main__':
    main()
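

# Example invocation (illustrative only; the genre string and lyrics file below are
# placeholders, and the optional flags simply mirror the argparse defaults above):
#
#   python wrapper.py \
#       --genre "pop upbeat electronic female vocal" \
#       --lyrics "$(cat my_lyrics.txt)" \
#       --run_n_segments 2 \
#       --stage2_batch_size 4 \
#       --max_new_tokens 3000 \
#       --cuda_idx 0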