|||
我在Brown University CS242 homework 5上花了大概20天(当然是业余时间)左右,虽然结果仍然遗留了一个小尾巴,但对Mean field variational inferece在MRF/CRF上如何应用,算有一个清晰深入的认识,同时也学到一些小trick,比如如何解决计算exp(x)和log(x)时内存溢出的问题,同时发现原来CRF/MRF在多标签分类问题上还可以这样玩。
总之收获很多,名校的homework是非常值得一做的。
然后回过头再看PairwiseMarkov随机场求解代码,很快发现虽然参数估计的公式没有错,但代码实现上有bug,于是赶紧修改过来。
运行时候发现总出现内存溢出警报,发现是exp(x)时候,当x是800+时,计算结果infinite,后来使用CS242的小trick,logsum的实现方式如下:
import math
def logsum(x):
m = max(x);
x = x-m;
res = m + np.log(sum(np.exp(x)));
return res
在将概率规范化时候如下操作
xi_prob_tmp = np.exp(xi_prob_tmp - self.logsum(xi_prob_tmp))
完美解决内存溢出问题,爽到飞起来。
然后在4000个左右节点的MarkovNet上运行Mean Field Variational Inference,收敛速度非常快,基本上10次迭代内就能得到一个稳定结果,如下:
下面该考虑Mean Field Variational Inference和SGD结合,去学习模型参数了。
Archiver|手机版|科学网 ( 京ICP备07017567号-12 )
GMT+8, 2024-11-22 07:07
Powered by ScienceNet.cn
Copyright © 2007- 中国科学报社