需求
在很多应用中,需要比较两个序列的相似性。然而比如语音、手势,每个音的长短不一致,即使两个序列有很强的相似性,但是在每个特征时间坐标也不是对齐的。这样就需要对其进行伸缩,动态时间规整就是通过找到这两个波形对齐的点,来计算两个序列之间的距离(相似度)。时间序列在处理手势识别数据时,需要用到动态时间规整。
动态时间规整 DTW
比较两个序列Q,C
Q = q1, q2,…,qi,…, qn ;
C = c1, c2,…, cj,…, cm ;
构造一个 n x m 的矩阵网格
矩阵中元素$(i,j)$表示 $q_i$和$c_j$两点的距离$d(q_i,c_j)$(距离越小相似度越高)。一般使用欧式距离$d(q_i,c_j)=(q_i-c_j)^2$, 每个元素$(i,j)$也表示$q_i$和$c_j$进行对齐。
DTW算法可以归结为寻找此网格中包含若干格点的路径,包含的格点即为两个序列进行计算时需要对齐的点。
将这条路径定义为warping path规整路径,用W表示$ w_k=(i,j)_k $
$ W = w_1,w_2,\cdots,w_k,\cdots,W_K\max(m,n)\leq K < m+n-1 $
这条路径需要满足以下条件:
-
(1) 第一个序列中的每个索引都必须与另一个序列中的一个或多个索引匹配,反之亦然
-
(2) 边界条件:$ w_1=(1,1) $和$ w_K=(m,n) $
-
(3) 连续性:如果 $$ w_{k-1}= (a’, b’) $$,那么对于路径的下一个点$$ w_k=(a, b) $$需要满足 $$ (a-a’) <=1 $$ 和 $$ (b-b’) <=1 $$。
-
(4) 单调性:如果$ w_{k-1}= (a’, b’) $,那么对于路径的下一个点$ w_k=(a, b) $需要满足$ 0\leq(a-a’) $和$ 0 \leq (b-b’) $。
如果路径已经通过格点 $ (i,j) $, 为满足连续性和单调性约束,下一个格点只有三个方向,$ (i+1,j), (i,j+1), (i+1,j+1) $。
$ DTW(Q,C)=min$ ${$ $ {\sqrt[]{\sum^{K}_{k = 1}{w_k}}/K} $
这个距离需要满足以上各种约束,并保证代价(Q,C的距离)最小。这个距离不能保证三角不等式成立。
算法
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39
| public class DTW { public DTW(){ } public double putDataAndGetRes(){ double[] ts_a = {1,5,8,10,56,21,32,8};
double[] ts_b = {1,5,8,10,23,56,21,32,8};
double[] ts_c = {1,3,6,9,16,29,31,32,33};
return getResult(ts_a,ts_c); }
public double getResult(double[] s, double[] t){ int n=s.length+1; int m=t.length+1; double[][] DTW=new double[n][m];
for(int i=0;i<n;i++){ for(int j=0;j<m;j++){ DTW[i][j]=Double.POSITIVE_INFINITY; } } DTW[0][0]=0; for(int i=1;i<n;i++){ for(int j=1;j<m;j++){ double cost=dist(s[i-1],t[j-1]); DTW[i][j]=cost+Math.min(DTW[i-1][j],Math.min(DTW[i][j-1],DTW[i-1][j-1])); } }
return DTW[n-1][m-1]; } public double dist(double a,double b){ return Math.abs(a-b); }
}
|
上述的算法没有考虑Wraping Path的长度不同·
按照公式考虑Wraping Path的长度因素。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
| public class DTW { public DTW(){ } public double putDataAndGetRes(){ double[] ts_a = {1,5,8,10,56,21,32,8};
double[] ts_b = {1,5,8,10,23,56,21,32,8};
double[] ts_c = {1,3,6,9,16,29,31,32,33};
return getResult2(ts_a,ts_b); }
public double getResult(double[] s, double[] t){ int n=s.length+1; int m=t.length+1; double[][] DTW=new double[n][m];
for(int i=0;i<n;i++){ for(int j=0;j<m;j++){ DTW[i][j]=Double.POSITIVE_INFINITY; } } DTW[0][0]=0; for(int i=1;i<n;i++){ for(int j=1;j<m;j++){ double cost=dist(s[i-1],t[j-1]); DTW[i][j]=cost+Math.min(DTW[i-1][j],Math.min(DTW[i][j-1],DTW[i-1][j-1])); } }
return DTW[n-1][m-1]; } public double getResult2(double[] s, double[] t){ int n=s.length+1; int m=t.length+1; double[][] DTW=new double[n][m]; double[][] count=new double[n][m]; for(int i=0;i<n;i++){ for(int j=0;j<m;j++){ DTW[i][j]=Double.POSITIVE_INFINITY; count[i][j]=0; } } DTW[0][0]=0; for(int i=1;i<n;i++){ for(int j=1;j<m;j++){ double cost=dist(s[i-1],t[j-1]); double min=Math.min(DTW[i-1][j],Math.min(DTW[i][j-1],DTW[i-1][j-1])); DTW[i][j]=cost+min; if(DTW[i-1][j]==min){ count[i][j]=count[i-1][j]+1; }else if(DTW[i-1][j-1]==min){ count[i][j]=count[i-1][j-1]+1; }else{ count[i][j]=count[i][j-1]+1; } } }
return Math.sqrt(DTW[n-1][m-1])/count[n-1][m-1]; } public double dist(double a,double b){ return Math.abs(a-b); }
}
|
参考: