# Spaces: Running
import os
import uuid
import argparse


def _build_arg_parser():
    """Create the command-line parser for this demo app."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=1239, help="Port number for the local server")
    parser.add_argument("--cuda_device", type=str, default='0', help="Cuda devices to use. Default is 0")
    parser.add_argument("--static_folder", type=str, default='static', help="Folder to store static files")
    return parser


argparser = _build_arg_parser()
args = argparser.parse_args()

# Hide all CUDA devices from the process. This is set before torch (and the LInK modules
# that use it) are imported below so the setting is already in place when they load.
# NOTE(review): --cuda_device is parsed above but never consulted; every GPU is hidden
# regardless of its value.
os.environ["CUDA_VISIBLE_DEVICES"] = ''

# Third-party and project imports: these must stay below the environment tweak above.
import gradio as gr
from pathlib import Path
import numpy as np
import torch
from LInK.CAD import create_3d_html
from LInK.CurveUtils import uniformize
import matplotlib.pyplot as plt
from LInK.Solver import solve_rev_vectorized_batch_CPU
# Inference only: turn off autograd tracking for the whole process.
torch.set_grad_enabled(False)

# Bundled per-letter data. allow_pickle=True lets np.load unpickle arbitrary objects; that is
# acceptable only because these are the app's own files -- never point it at user input.
# results[i][0] unpacks to (A, x0, node_type, start_theta, end_theta, tr) and results[i][-1]
# is an (n_points, 2) curve (it is rotated with a 2x2 matrix below).
results = np.load('alpha_res.npy', allow_pickle=True)
# Reference letter curves; used here only to measure each letter's orientation.
alphabet_test = np.load('alphabet.npy')
zs_ = np.load('alpha_z.npy', allow_pickle=True)


def _principal_angles(curves):
    """Return the polar angle of each curve's point farthest from that curve's centroid.

    ``curves`` is a float tensor indexed ``[curve, point, xy]``. The result is a NumPy array
    holding one ``torch.atan2`` angle (radians) per curve.
    """
    centered = curves - curves.mean(1).unsqueeze(1)
    far_idx = torch.square(centered).sum(-1).argmax(dim=1)
    far_pts = centered[torch.arange(centered.shape[0]), far_idx]
    return torch.atan2(far_pts[:, 1], far_pts[:, 0]).numpy()


# Orientation of the reference glyphs and of the curves traced by the stored mechanisms.
# np.stack avoids torch's slow element-wise path for a list of ndarrays.
# NOTE(review): uniformize(..., 200) presumably resamples each curve to 200 points -- confirm.
theta = _principal_angles(torch.tensor(alphabet_test).float())
theta2 = _principal_angles(
    uniformize(torch.tensor(np.stack([res[-1] for res in results])).float(), 200)
)

# Manual rotation corrections (radians) added on top of the automatic alignment, keyed by
# letter row index. The reason for each value is not recorded in this file.
_ALPHA_OFFSETS = {4: np.pi / 36, 7: -np.pi / 3, 21: -np.pi / 2.5}

# For each letter: the rotation that turns its traced curve so the farthest-from-centroid
# point matches the reference glyph's direction, plus the rotated curve's bounding box
# (width, height, centre), which create_mech uses to scale and place letters.
alphas = []
letter_heights = []
letter_centers = []
letter_widths = []
for i, res in enumerate(results):
    alpha = theta[i] - theta2[i]
    if i in _ALPHA_OFFSETS:
        alpha = alpha + _ALPHA_OFFSETS[i]
    alphas.append(alpha)
    R = np.array([[np.cos(alpha), -np.sin(alpha)], [np.sin(alpha), np.cos(alpha)]]).squeeze()
    rotated = (R @ res[-1].T).T
    lo, hi = rotated.min(0), rotated.max(0)
    letter_heights.append(hi[1] - lo[1])
    letter_widths.append(hi[0] - lo[0])
    letter_centers.append([(hi[0] + lo[0]) / 2, (hi[1] + lo[1]) / 2])

alphas = np.array(alphas)
letter_heights = np.array(letter_heights)
letter_centers = np.array(letter_centers)
letter_widths = np.array(letter_widths)
# Upper-case letter -> row index into the per-letter data arrays. The string below lists the
# letters in row order. Q (row 15) comes before P (row 16); presumably the data files are
# stored in that order -- NOTE(review): confirm against alpha_res.npy before "fixing" it.
alphabet_dict = {
    letter: "ABCDEFGHIJKLMNOQPRSTUVWXYZ".index(letter)
    for letter in "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
}
def _place_letters(letters, target_height=1., spacing=0.2):
    """Rotate, scale and translate each letter's stored mechanism into word coordinates.

    Each letter is rotated by ``alphas[l]`` and scaled so its curve's bounding-box height
    becomes ``target_height``. It is then shifted so that bounding-box centre
    (``letter_centers[l]``) lands at ``(x_l, 0)``, with consecutive letters separated by
    ``spacing``.

    Returns one ``[A, x0, node_type, start_theta, end_theta]`` list per entry of ``letters``.
    ``x0`` is in word coordinates and both crank angles are shifted by the same ``alpha``.
    """
    mechs = []
    translations = []
    scaling = []
    for i, l in enumerate(letters):
        A, x0, node_type, start_theta, end_theta, _ = results[l][0]
        alpha = alphas[l]
        R = np.array([[np.cos(alpha), -np.sin(alpha)], [np.sin(alpha), np.cos(alpha)]]).squeeze()
        s = target_height / letter_heights[l]
        if i > 0:
            # Step past half of the previous (scaled) letter, half of this one, and the gap.
            trans = [translations[-1][0] + letter_widths[letters[i - 1]] / 2 * scaling[-1]
                     + letter_widths[l] / 2 * s + spacing, 0]
        else:
            trans = [letter_widths[l] / 2 * s, 0]
        scaling.append(s)
        translations.append(trans)
        mechs.append([A, s * ((R @ x0.T).T - letter_centers[l]) + trans, node_type,
                      start_theta + alpha, end_theta + alpha])
    return mechs


def _combine_mechanisms(mechs, letters, n_frames=100):
    """Merge per-letter mechanisms into one block-diagonal system and simulate each letter.

    Returns ``(A_all, x0_all, node_type_all, zs, sols, highlights)``:

    * ``A_all`` is the block-diagonal stack of every letter's ``A``.
    * ``x0_all`` holds each node's position in the first simulated frame.
    * ``zs`` concatenates ``zs_`` per letter, each offset so it starts above the previous one.
    * ``sols`` has shape ``(n_frames, total_nodes, 2)``.
    * ``highlights`` is the global index of each letter's last node.
    """
    total_size = sum(m[0].shape[0] for m in mechs)
    A_all = np.zeros((total_size, total_size))
    x0_all = np.zeros((total_size, 2))
    node_type_all = np.zeros((total_size, 1))
    sols = []
    highlights = []
    zs = []
    offset = 0
    for letter, (A, x0, node_type, start_theta, end_theta) in zip(letters, mechs):
        n = A.shape[0]
        block = slice(offset, offset + n)
        A_all[block, block] = A
        node_type_all[block] = node_type
        highlights.append(offset + n - 1)
        # sol is indexed [node, frame, xy] (see the slicing below).
        sol = solve_rev_vectorized_batch_CPU(
            A[None], x0[None], node_type[None], np.linspace(start_theta, end_theta, n_frames)
        )[0]
        sols.append(sol.transpose(1, 0, 2))
        x0_all[block] = sol[:, 0, :]
        z = zs_[letter]
        zs.append(z + zs[-1].max() + 1 if zs else z)
        offset += n
    return A_all, x0_all, node_type_all, np.concatenate(zs), np.concatenate(sols, axis=1), highlights


def create_mech(target_text):
    """Build an animated 3D linkage that draws ``target_text`` and return it as an iframe.

    Whitespace is ignored and letters are upper-cased. Each letter's precomputed mechanism
    is rotated, scaled and placed left to right. The letters are then merged into one
    linkage, simulated, and rendered to a new HTML page in ``args.static_folder`` with
    ``create_3d_html``.

    Parameters
    ----------
    target_text : str
        Word to draw; only the letters A-Z (after ``str.upper``) are supported.

    Returns
    -------
    gr.HTML
        An iframe pointing at the generated animation page.

    Raises
    ------
    gr.Error
        If the text contains no letters or contains unsupported characters. Gradio shows
        the message to the user instead of a raw ``KeyError``/``ValueError``.
    """
    word = ''.join(target_text.split()).upper()
    if not word:
        raise gr.Error("Please enter at least one letter (A-Z).")
    unsupported = sorted({ch for ch in word if ch not in alphabet_dict})
    if unsupported:
        raise gr.Error(f"Unsupported character(s): {' '.join(unsupported)}. Only letters A-Z are supported.")
    letters = [alphabet_dict[ch] for ch in word]

    mechs = _place_letters(letters)
    A_all, x0_all, node_type_all, zs, sols, highlights = _combine_mechanisms(mechs, letters)

    # Animation frames per node: play the motion forward, then backward, so it loops.
    frames = sols.transpose(1, 0, 2)
    frames = np.concatenate([frames, frames[:, ::-1, :]], 1)

    # Write into the folder registered with gr.set_static_paths so Gradio can serve it.
    out_dir = Path(args.static_folder)
    out_dir.mkdir(parents=True, exist_ok=True)
    html_path = (out_dir / f'{uuid.uuid4()}.html').as_posix()
    create_3d_html(A_all, x0_all, node_type_all, zs, frames,
                   template_path='./static/animation.html', save_path=html_path,
                   highlights=highlights)
    return gr.HTML(f'<iframe width="100%" height="800px" src="file={html_path}"></iframe>',
                   label="3D Plot", elem_classes="plot3d")
# Initial content of the plot area, shown until the first word is generated.
_FILLER_IFRAME = '<iframe width="100%" height="800px" src="file=static/filler.html"></iframe>'

# Allow Gradio to serve files from the static folder through "file=" URLs.
static_dir = Path(f'./{args.static_folder}')
gr.set_static_paths(paths=[static_dir])

with gr.Blocks() as block:
    with gr.Row():
        with gr.Column():
            text = gr.Textbox(label="Enter a word (spaces will be ignored)", value='DECODE')
            btn = gr.Button(value="Create Mechanism", variant="primary")
            plot_3d = gr.HTML(_FILLER_IFRAME, label="3D Plot", elem_classes="plot3d")
    # Clicking the button swaps the plot area for the iframe returned by create_mech.
    event1 = btn.click(create_mech, inputs=[text], outputs=[plot_3d])

# NOTE(review): the parsed --port argument is not passed here; launch() is called with its
# own defaults.
block.launch()