class SMPLWrapper:
    """A thin wrapper around an ``smplx`` SMPL body model.

    The wrapper builds the model once for a fixed gender and shape (betas)
    and exposes :meth:`get_smpl`, which turns a pose vector plus a global
    translation into mesh vertices as a NumPy array.

    Attributes:
        device: Torch device the model and all parameters live on.
        smpl_root: Path handed to ``smplx.create`` to locate the model files.
        shape_params: Shape parameters as a ``(1, N)`` float32 tensor.
        smpl: The model object returned by ``smplx.create``.
        faces: The model's ``faces`` array converted to ``int64``.
    """

    def __init__(self, smpl_root: str, gender: str, shape_params: np.ndarray,
                 device: torch.device = None):
        """Create the SMPL model.

        Args:
            smpl_root: Path passed to ``smplx.create`` as the model location.
            gender: Gender string forwarded to ``smplx.create``.
            shape_params: Shape (beta) parameters; flattened to ``(1, N)``.
            device: Device to place the model on. ``None`` selects the CPU.
        """
        self.device = torch.device(device if device is not None else "cpu")
        self.smpl_root = smpl_root
        self.shape_params = self._preprocess_param(shape_params)
        self.smpl = smplx.create(
            self.smpl_root,
            model_type='smpl',
            gender=gender,
            betas=self.shape_params,
        ).to(self.device)
        self.faces = np.array(self.smpl.faces, dtype=np.int64)

    @staticmethod
    def _center_output(smpl_model, params, smpl_output):
        """Center the SMPL model local coordinate system around the root joint.

        The joints and vertices of ``smpl_output`` are shifted so that joint 0
        ends up at the translation (or at the origin when no translation is
        available). ``smpl_output`` is modified in place and returned.

        Args:
            smpl_model: The model that produced ``smpl_output``; its ``transl``
                attribute is used as a fallback translation.
            params: Dict of the parameters used for the forward pass; an
                optional ``'transl'`` entry takes precedence.
            smpl_output: Object with ``joints`` and ``vertices`` tensors whose
                last dimension is 3 and whose first dimension is the batch.

        Returns:
            The same ``smpl_output`` object with shifted joints and vertices.
        """
        transl = params.get('transl')
        if transl is None:
            # Fall back to the translation stored on the model, if any.
            transl = getattr(smpl_model, 'transl', None)

        # Offset that moves the root joint (index 0) to the origin ...
        offset = -smpl_output.joints[:, 0, :]
        if transl is not None:
            # ... and then onto the requested translation.
            offset = offset + transl

        offset = offset.view(-1, 1, 3)
        smpl_output.joints = smpl_output.joints + offset
        smpl_output.vertices = smpl_output.vertices + offset
        return smpl_output

    def _preprocess_param(self, param: np.ndarray) -> torch.Tensor:
        """Prepare the parameters for SMPL.

        Args:
            param: Array-like or tensor of parameters in any shape.

        Returns:
            A float32 tensor on ``self.device`` reshaped to ``(1, N)``.
        """
        if not isinstance(param, torch.Tensor):
            # torch.tensor copies, so the caller's array is never aliased.
            param = torch.tensor(param, dtype=torch.float32)
        # Tensor inputs are normalised as well, so e.g. float64 tensors do not
        # leak through with a dtype that differs from the float32 cast above.
        return param.to(device=self.device, dtype=torch.float32).reshape(1, -1)

    def get_smpl(self, pose_params: np.ndarray,
                 translation_params: np.ndarray) -> np.ndarray:
        """Get the SMPL mesh vertices from the target pose and global translation.

        Args:
            pose_params: Pose parameters vector of shape (72). Other layouts
                holding the same values in the same order (for example
                ``(1, 72)`` or ``(24, 3)``) are flattened first. The first
                three values are the global orientation, the rest the body
                pose.
            translation_params: Global translation vector of shape (3).

        Returns:
            np.ndarray: vertices of SMPL model, with the leading batch
            dimension removed.
        """
        # Flatten before slicing so the split does not depend on input layout.
        pose = self._preprocess_param(pose_params)
        smpl_params = dict(
            global_orient=pose[:, :3],
            body_pose=pose[:, 3:],
            transl=self._preprocess_param(translation_params),
        )

        with torch.no_grad():
            output = self.smpl(**smpl_params)
        output = self._center_output(self.smpl, smpl_params, output)
        return output.vertices.squeeze(0).cpu().numpy()