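# GenMM / app.py
# Copied from the Hugging Face Space file view (author: wyysf).
# Commit 88aa0dc: "add header with title (#1)"; file size 2.8 kB;
# page actions: raw | history | blame; security scan: No virus.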
import json
import time
import uvicorn
from pathlib import Path
from fastapi import FastAPI
from fastapi.responses import RedirectResponse
from fastapi.staticfiles import StaticFiles
from dataset.tracks_motion import TracksMotion
from GPS import GPS
import gradio as gr
def _synthesis(synthesis_setting, motion_data):
    """Run GPS "match_and_blend" motion synthesis and return its result.

    Args:
        synthesis_setting: Settings dict. Keys read here: ``frames``,
            ``noise_sigma``, ``pyr_factor``, ``patch_size``, ``loop``,
            ``completeness`` and ``num_steps``; ``stride`` is optional
            (defaults to 1) and ``alpha`` is read only when
            ``completeness`` is truthy.
        motion_data: Forwarded unchanged to ``GPS.run``.

    Returns:
        Whatever ``GPS.run`` returns.

    Note:
        ``coarse_ratio`` is fixed at 0.2 here; any ``coarse_ratio`` entry in
        ``synthesis_setting`` is not consulted.
    """
    frames = synthesis_setting["frames"]
    noise_sigma = synthesis_setting["noise_sigma"]
    pyr_factor = synthesis_setting["pyr_factor"]
    model = GPS(
        init_mode=f"random_synthesis/{frames}",
        noise_sigma=noise_sigma,
        coarse_ratio=0.2,
        pyr_factor=pyr_factor,
        num_stages_limit=-1,
        silent=True,
        device="cpu",
    )

    # Patch-based matching criterion. Entries are added in a fixed order so
    # the dict layout (and the order in which settings are read) is stable.
    criteria = {"type": "PatchCoherentLoss"}
    criteria["patch_size"] = synthesis_setting["patch_size"]
    criteria["stride"] = synthesis_setting.get("stride", 1)
    criteria["loop"] = synthesis_setting["loop"]
    if synthesis_setting["completeness"]:
        criteria["coherent_alpha"] = synthesis_setting["alpha"]
    else:
        criteria["coherent_alpha"] = None

    options = {
        "criteria": criteria,
        "optimizer": "match_and_blend",
        "num_itrs": synthesis_setting["num_steps"],
    }
    return model.run(motion_data, mode="match_and_blend", ext=options)
def synthesis(data):
    """Gradio "predict" handler: synthesize new motion from example tracks.

    Args:
        data: Request payload, either as JSON text or as an already-decoded
            dict. Keys read: ``tracks``, ``scale`` and ``setting`` (the
            settings dict forwarded to ``_synthesis``).

    Returns:
        The decoded payload, with ``tracks`` replaced by the synthesized
        tracks and ``time`` set to the synthesis duration in seconds.
    """
    # The hidden gr.JSON input may hand over raw JSON text or a decoded
    # object depending on how the client sends it; accept both.
    if isinstance(data, (str, bytes, bytearray)):
        data = json.loads(data)
    # NOTE(review): _synthesis hard-codes coarse_ratio=0.2 and does not read
    # this key. It is kept because the settings are echoed back in the
    # response and the client may rely on the value -- confirm.
    data["setting"]["coarse_ratio"] = -1
    # Build the motion object from the incoming tracks.
    motion_data = TracksMotion(data["tracks"], scale=data["scale"])
    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # system clock adjustments the way time.time() can.
    start = time.perf_counter()
    synthesized_motion = _synthesis(data["setting"], [motion_data])
    data["time"] = time.perf_counter() - start
    data["tracks"] = motion_data.parse(synthesized_motion)
    return data
# HTML header shown at the top of the Gradio page: a centered title plus
# links to the project page, the paper and the code repository.
intro = """
<h1 style="text-align: center;">
Example-based Motion Synthesis via Generative Motion Matching
</h1>
<h3 style="text-align: center; margin-bottom: 7px;">
<a href="http://weiyuli.xyz/GenMM" target="_blank">Project Page</a> | <a href="https://huggingface.co/papers/2306.00378" target="_blank">Paper</a> | <a href="https://github.com/wyysf-98/GenMM" target="_blank">Code</a>
</h3>
"""
# Gradio UI. The visible page is the intro header plus an iframe that embeds
# the static demo served under /GenMM_demo/. The JSON fields and the button
# are hidden; they exist only to register `synthesis` as the "predict" API
# endpoint, which the embedded demo presumably calls.
with gr.Blocks() as demo:
    gr.HTML(intro)
    # The closing </iframe> tag keeps the markup well-formed; without it the
    # HTML parser treats any following content as raw iframe text.
    gr.HTML(
        """<iframe src="/GenMM_demo/" width="100%" height="700px" style="border:none;"></iframe>"""
    )
    json_in = gr.JSON(visible=False)
    json_out = gr.JSON(visible=False)
    btn = gr.Button("Synthesize", visible=False)
    btn.click(synthesis, inputs=[json_in], outputs=[json_out], api_name="predict")
# FastAPI host app. The static demo is mounted before the Gradio app so that
# requests under /GenMM_demo are routed to StaticFiles rather than being
# swallowed by the Gradio mount at "/" (routes are matched in order).
app = FastAPI()
# Resolve the static directory relative to this file rather than the current
# working directory, so launching the server from another directory does not
# make StaticFiles fail on a missing directory.
static_dir = Path(__file__).resolve().parent / "GenMM_demo"
app.mount("/GenMM_demo", StaticFiles(directory=static_dir, html=True), name="static")
app = gr.mount_gradio_app(app, demo, path="/")
# Serve the combined FastAPI + Gradio app when run as a script.
if __name__ == "__main__":
    # Pass the app object rather than the "app:app" import string. The string
    # form makes uvicorn import this file a second time under the module name
    # "app", which rebuilds the whole UI, and it breaks if the file is renamed.
    # The object form is fine here because reload/workers are not used.
    uvicorn.run(app, host="0.0.0.0", port=7860)