"""Gradio demo that spells a word with a row of per-letter linkage mechanisms.

At start-up the script loads one pre-solved mechanism per letter, works out
the rotation that aligns each mechanism's traced curve with a reference
alphabet curve, and records each rotated curve's bounding box.  The
``create_mech`` callback then lays the requested letters out side by side,
merges them into one block-diagonal system and writes an animated HTML page
through ``LInK.CAD.create_3d_html``.
"""
import os
import uuid
import argparse

argparser = argparse.ArgumentParser()
argparser.add_argument("--port", type=int, default=1239, help="Port number for the local server")
argparser.add_argument("--cuda_device", type=str, default='0', help="Cuda devices to use. Default is 0")
argparser.add_argument("--static_folder", type=str, default='static', help="Folder to store static files")
args = argparser.parse_args()

# Must be set before torch is imported.
# NOTE(review): --cuda_device is parsed but never applied; the variable is
# forced empty so torch sees no GPU. Only the CPU solver is used below, so
# this looks deliberate -- confirm before wiring the flag through.
os.environ["CUDA_VISIBLE_DEVICES"] = ''

import gradio as gr
from pathlib import Path
import numpy as np
import torch
from LInK.CAD import create_3d_html
from LInK.CurveUtils import uniformize
import matplotlib.pyplot as plt
from LInK.Solver import solve_rev_vectorized_batch_CPU

# turn off gradient computation
torch.set_grad_enabled(False)


def _center(curves):
    """Subtract each curve's mean point (mean over dim 1) from that curve."""
    return curves - curves.mean(1).unsqueeze(1)


def _farthest_point_angle(centered):
    """Return, per curve, the atan2 angle of the point with the largest squared norm."""
    idx = torch.square(centered).sum(-1).argmax(dim=1)
    rows = torch.arange(centered.shape[0])
    return torch.atan2(centered[rows, idx, 1], centered[rows, idx, 0]).numpy()


def _rotation(alpha):
    """Return the 2x2 counter-clockwise rotation matrix for angle ``alpha``."""
    return np.array([[np.cos(alpha), -np.sin(alpha)],
                     [np.sin(alpha), np.cos(alpha)]]).squeeze()


# NOTE(security): allow_pickle=True unpickles arbitrary objects; these files
# must come from a trusted source.
results = np.load('alpha_res.npy', allow_pickle=True)
alphabet_test = np.load('alphabet.npy')
zs_ = np.load('alpha_z.npy', allow_pickle=True)

# Reference angle of every target alphabet curve.
curves__ = _center(torch.tensor(alphabet_test).float())
theta = _farthest_point_angle(curves__)

# Same angle for the curve stored as the last entry of each result.
# Stacking with NumPy first avoids torch's slow list-of-ndarrays path.
curves_ = torch.tensor(np.array([res[-1] for res in results])).float()
curves_ = uniformize(curves_, 200)  # presumably resamples to 200 points -- TODO confirm
curves_ = _center(curves_)
theta2 = _farthest_point_angle(curves_)

# Hand-tuned angle corrections, keyed by index into ``results``.
_ANGLE_TWEAKS = {21: -np.pi / 2.5, 7: -np.pi / 3, 4: np.pi / 36}

alphas = []
letter_heights = []
letter_centers = []
letter_widths = []
for i, res in enumerate(results):
    alpha = theta[i] - theta2[i]
    if i in _ANGLE_TWEAKS:
        alpha += _ANGLE_TWEAKS[i]
    alphas.append(alpha)

    # Bounding box of the rotated curve drives scaling and placement later.
    transformed_curve = (_rotation(alpha) @ res[-1].T).T
    x_min, x_max = transformed_curve[:, 0].min(), transformed_curve[:, 0].max()
    y_min, y_max = transformed_curve[:, 1].min(), transformed_curve[:, 1].max()
    letter_heights.append(y_max - y_min)
    letter_widths.append(x_max - x_min)
    letter_centers.append([(x_max + x_min) / 2, (y_max + y_min) / 2])

alphas = np.array(alphas)
letter_heights = np.array(letter_heights)
letter_centers = np.array(letter_centers)
letter_widths = np.array(letter_widths)

# Letter -> index into ``results`` / ``zs_``.
# NOTE(review): 'P' and 'Q' map to 16 and 15 (not alphabetical order);
# presumably this mirrors the data files -- verify against them.
alphabet_dict = {'A':0,'B':1,'C':2,'D':3,'E':4,'F':5,'G':6,'H':7,'I':8,'J':9,'K':10,'L':11,'M':12,'N':13,'O':14,'P':16,'Q':15,'R':17,'S':18,'T':19,'U':20,'V':21,'W':22,'X':23,'Y':24,'Z':25}


def create_mech(target_text):
    """Build the combined mechanism for ``target_text`` and return an HTML view.

    Args:
        target_text: Word to spell. Whitespace is dropped and letters are
            upper-cased; every remaining character must be a key of
            ``alphabet_dict``.

    Returns:
        A ``gr.HTML`` component embedding the generated animation page.

    Raises:
        gr.Error: If the text is empty after removing whitespace, or contains
            a character that has no mechanism.
    """
    target_text = ''.join((target_text or '').split()).upper()
    if not target_text:
        raise gr.Error("Please enter at least one letter.")
    unknown = sorted({ch for ch in target_text if ch not in alphabet_dict})
    if unknown:
        raise gr.Error(
            f"Unsupported character(s): {', '.join(unknown)}. Only letters A-Z are supported."
        )

    target_height = 1.
    spacing = 0.2
    letters = [alphabet_dict[ch] for ch in target_text]

    translations = []
    scaling = []
    mechs = []
    total_size = 0
    for i, l in enumerate(letters):
        A, x0, node_type, start_theta, end_theta, _ = results[l][0]
        alpha = alphas[l]
        R = _rotation(alpha)

        # Scale every letter to the same height.
        s = target_height / letter_heights[l]
        scaling.append(s)

        # Place each letter to the right of the previous one: half of each
        # scaled width plus a fixed gap between neighbouring centres.
        if i > 0:
            prev = letters[i - 1]
            trans = [translations[-1][0]
                     + letter_widths[prev] / 2 * scaling[i - 1]
                     + letter_widths[l] / 2 * s
                     + spacing, 0]
        else:
            trans = [letter_widths[l] / 2 * s, 0]
        translations.append(trans)

        # Rotate, centre, scale and shift the initial node positions; the
        # crank angle range is rotated by the same alpha.
        mechs.append([A,
                      s * ((R @ x0.T).T - letter_centers[l]) + trans,
                      node_type,
                      start_theta + alpha,
                      end_theta + alpha])
        total_size += A.shape[0]

    # Merge all letters into one system with a block-diagonal adjacency.
    A_all = np.zeros((total_size, total_size))
    x0_all = np.zeros((total_size, 2))
    node_type_all = np.zeros((total_size, 1))

    current_count = 0
    sols = []
    highlights = []
    zs = []
    for i, (A, x0, node_type, start_theta, end_theta) in enumerate(mechs):
        n_nodes = A.shape[0]
        rows = slice(current_count, current_count + n_nodes)

        A_all[rows, rows] = A
        node_type_all[rows] = node_type
        # Highlight the last node of each letter's mechanism.
        highlights.append(current_count + n_nodes - 1)

        sol = solve_rev_vectorized_batch_CPU(
            A[None], x0[None], node_type[None],
            np.linspace(start_theta, end_theta, 100))[0]
        sols.append(sol.transpose(1, 0, 2))
        # Use the solved positions at the first angle as the initial positions.
        x0_all[rows] = sol[:, 0, :]
        current_count += n_nodes

        # Offset this letter's z values above everything placed so far.
        z = zs_[letters[i]]
        if i > 0:
            z = z + zs[-1].max() + 1
        zs.append(z)

    sols = np.concatenate(sols, axis=1)
    zs = np.concatenate(zs)

    # Play the motion forward, then reversed along axis 1.
    forward = sols.transpose(1, 0, 2)
    motion = np.concatenate([forward, forward[:, ::-1, :]], 1)

    uuid_ = str(uuid.uuid4())
    create_3d_html(A_all, x0_all, node_type_all, zs, motion,
                   template_path=f'./{args.static_folder}/animation.html',
                   save_path=f'./{args.static_folder}/{uuid_}.html',
                   highlights=highlights)

    # NOTE(review): the page is embedded via Gradio's static-file route; the
    # URL prefix differs between Gradio versions ("file=" vs
    # "/gradio_api/file=") -- confirm against the installed version.
    return gr.HTML(
        f'<iframe width="100%" height="800px" src="file={args.static_folder}/{uuid_}.html"></iframe>',
        label="3D Plot", elem_classes="plot3d")


gr.set_static_paths(paths=[Path(f'./{args.static_folder}')])

with gr.Blocks() as block:
    with gr.Row():
        with gr.Column():
            text = gr.Textbox(label="Enter a word (spaces will be ignored)", value='DECODE')
            btn = gr.Button(value="Create Mechanism", variant="primary")
            plot_3d = gr.HTML('', label="3D Plot", elem_classes="plot3d")

    event1 = btn.click(create_mech, inputs=[text], outputs=[plot_3d])

# Serve on the port requested on the command line.
block.launch(server_port=args.port)