ZeyuXie committed on
Commit
cb0c99a
·
verified ·
1 Parent(s): db42a2b

Upload 5 files

Browse files
Files changed (4) hide show
  1. app.py +4 -2
  2. inference.py +18 -7
  3. llm_preprocess.py +1 -1
  4. requirements.txt +29 -29
app.py CHANGED
@@ -52,13 +52,15 @@ with gr.Blocks() as demo:
52
  with gr.Row():
53
  gr.Markdown("## PicoAudio")
54
  with gr.Row():
55
- description_text = f"18 events: {', '.join(event_list)}"
56
  gr.Markdown(description_text)
57
 
58
  with gr.Row():
59
  gr.Markdown("## Step1")
60
  with gr.Row():
61
- preprocess_description_text = f"preprocess: free-text to timestamp caption via LLM"
 
 
62
  gr.Markdown(preprocess_description_text)
63
  with gr.Row():
64
  with gr.Column():
 
52
  with gr.Row():
53
  gr.Markdown("## PicoAudio")
54
  with gr.Row():
55
+ description_text = f"Support 18 events: {', '.join(event_list)}"
56
  gr.Markdown(description_text)
57
 
58
  with gr.Row():
59
  gr.Markdown("## Step1")
60
  with gr.Row():
61
+ preprocess_description_text = f"preprocess: free-text to timestamp caption via LLM. "+\
62
+ "This demo uses Gemini as the preprocessor. If any errors occur, please try a few more times. "+\
63
+ "We also provide the GPT version consistent with the paper in the file 'File/llm_preprocess.py'. You can use your own api_key to modify and run 'File/inference.py' for local inference."
64
  gr.Markdown(preprocess_description_text)
65
  with gr.Row():
66
  with gr.Column():
inference.py CHANGED
@@ -9,7 +9,7 @@ import numpy as np
9
  import torch
10
  from diffusers import DDPMScheduler
11
  from pico_model import PicoDiffusion, build_pretrained_models
12
-
13
  class dotdict(dict):
14
  """dot.notation access to dictionary attributes"""
15
  __getattr__ = dict.get
@@ -19,8 +19,14 @@ class dotdict(dict):
19
  def parse_args():
20
  parser = argparse.ArgumentParser(description="Inference for text to audio generation task.")
21
  parser.add_argument(
22
- "--text", '-t', type=str, default="spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031.",
23
- help="Path for experiment."
 
 
 
 
 
 
24
  )
25
  parser.add_argument(
26
  "--exp_path", '-exp', type=str, default="/hpc_stor03/sjtu_home/zeyu.xie/workspace/controllable_audio_generation/huggingface/ckpts/pico_model",
@@ -43,7 +49,7 @@ def parse_args():
43
  def main():
44
  args = parse_args()
45
  train_args = dotdict(json.loads(open(args.original_args).readlines()[0]))
46
-
47
  seed = args.seed
48
  random.seed(seed)
49
  np.random.seed(seed)
@@ -52,6 +58,11 @@ def main():
52
  torch.backends.cudnn.deterministic = True
53
  torch.backends.cudnn.benchmark = False
54
 
 
 
 
 
 
55
  # Load Models #
56
  print("------Load model")
57
  name = "audioldm-s-full"
@@ -74,11 +85,11 @@ def main():
74
 
75
  print("------Diffusion begin!")
76
  with torch.no_grad():
77
- latents = model.demo_inference(args.text, scheduler, num_steps, guidance, num_samples, disable_progress=True)
78
  mel = vae.decode_first_stage(latents)
79
  wave = vae.decode_to_waveform(mel)
80
- sf.write(f"{output_dir}/{args.text}.wav", wave[0][:audio_len], samplerate=16000, subtype='PCM_16')
81
- print(f"------Write to files to {output_dir}/{args.text}.wav")
82
 
83
  if __name__ == "__main__":
84
  main()
 
9
  import torch
10
  from diffusers import DDPMScheduler
11
  from pico_model import PicoDiffusion, build_pretrained_models
12
+ from llm_preprocess import get_event, preprocess_gemini, preprocess_gpt
13
  class dotdict(dict):
14
  """dot.notation access to dictionary attributes"""
15
  __getattr__ = dict.get
 
19
  def parse_args():
20
  parser = argparse.ArgumentParser(description="Inference for text to audio generation task.")
21
  parser.add_argument(
22
+ "--text", '-t', type=str, default="spraying two times then gunshot three times.",
23
+ help="free-text caption."
24
+ )
25
+ parser.add_argument(
26
+ "--timestamp_caption", '-c', type=str,
27
+ default=None,
28
+ #default="spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031.",
29
+ help="timestamp caption, formatted as 'event1 at onset1-offset1_onset2-offset2 and event2 at onset1-offset1'."
30
  )
31
  parser.add_argument(
32
  "--exp_path", '-exp', type=str, default="/hpc_stor03/sjtu_home/zeyu.xie/workspace/controllable_audio_generation/huggingface/ckpts/pico_model",
 
49
  def main():
50
  args = parse_args()
51
  train_args = dotdict(json.loads(open(args.original_args).readlines()[0]))
52
+
53
  seed = args.seed
54
  random.seed(seed)
55
  np.random.seed(seed)
 
58
  torch.backends.cudnn.deterministic = True
59
  torch.backends.cudnn.benchmark = False
60
 
61
+ # Step1: preprocess via llm
62
+ if args.timestamp_caption == None:
63
+ #args.timestamp_caption = preprocess_gpt(args.text)
64
+ args.timestamp_caption = preprocess_gemini(args.text)
65
+
66
  # Load Models #
67
  print("------Load model")
68
  name = "audioldm-s-full"
 
85
 
86
  print("------Diffusion begin!")
87
  with torch.no_grad():
88
+ latents = model.demo_inference(args.timestamp_caption, scheduler, num_steps, guidance, num_samples, disable_progress=True)
89
  mel = vae.decode_first_stage(latents)
90
  wave = vae.decode_to_waveform(mel)
91
+ sf.write(f"{output_dir}/{args.timestamp_caption}.wav", wave[0][:audio_len], samplerate=16000, subtype='PCM_16')
92
+ print(f"------Write to files to {output_dir}/{args.timestamp_caption}.wav")
93
 
94
  if __name__ == "__main__":
95
  main()
llm_preprocess.py CHANGED
@@ -85,7 +85,7 @@ def preprocess_gemini(free_text_caption):
85
  def preprocess_gpt(free_text_caption):
86
  preffix_prompt = get_prompt()
87
  from openai import OpenAI
88
- client = OpenAI(api_key="<REDACTED>")
89
  completion_start = client.chat.completions.create(
90
  model="gpt-4-1106-preview",
91
  messages=[{
 
85
  def preprocess_gpt(free_text_caption):
86
  preffix_prompt = get_prompt()
87
  from openai import OpenAI
88
+ client = OpenAI(api_key="")
89
  completion_start = client.chat.completions.create(
90
  model="gpt-4-1106-preview",
91
  messages=[{
requirements.txt CHANGED
@@ -1,30 +1,30 @@
1
- torch==2.0.1
2
- torchaudio==2.0.2
3
- torchvision==0.15.2
4
- transformers==4.37.2
5
- accelerate==0.26.1
6
- datasets==2.16.1
7
- diffusers==0.18.2
8
- einops==0.7.0
9
- h5py==3.10.0
10
- huggingface_hub==0.20.3
11
- importlib_metadata==7.0.1
12
- librosa==0.10.1
13
- matplotlib==3.8.2
14
- numpy==1.23.5
15
- omegaconf==2.0.6
16
- packaging==23.2
17
- pandas==2.2.0
18
- progressbar33==2.4
19
- protobuf==3.20.*
20
- resampy==0.4.2
21
- scikit_image==0.22.0
22
- scikit_learn==1.4.0
23
- scipy==1.12.0
24
- soundfile==0.12.1
25
- ssr_eval==0.0.7
26
- torchlibrosa==0.1.0
27
- tqdm==4.63.1
28
- laion-clap==1.1.4
29
- gradio
30
  google-generativeai
 
1
+ torch==2.0.1
2
+ torchaudio==2.0.2
3
+ torchvision==0.15.2
4
+ transformers==4.37.2
5
+ accelerate==0.26.1
6
+ datasets==2.16.1
7
+ diffusers==0.18.2
8
+ einops==0.7.0
9
+ h5py==3.10.0
10
+ huggingface_hub==0.20.3
11
+ importlib_metadata==7.0.1
12
+ librosa==0.10.1
13
+ matplotlib==3.8.2
14
+ numpy==1.23.5
15
+ omegaconf==2.0.6
16
+ packaging==23.2
17
+ pandas==2.2.0
18
+ progressbar33==2.4
19
+ protobuf==3.20.*
20
+ resampy==0.4.2
21
+ scikit_image==0.22.0
22
+ scikit_learn==1.4.0
23
+ scipy==1.12.0
24
+ soundfile==0.12.1
25
+ ssr_eval==0.0.7
26
+ torchlibrosa==0.1.0
27
+ tqdm==4.63.1
28
+ laion-clap==1.1.4
29
+ gradio
30
  google-generativeai