LibSVM 3.12的源码分析Svm-train.c

共涉及3个文件: Svm-train.c, Svm.cpp, Svm.h. 建议使用Source Insight软件对这3个文件建立工程. 方便代码阅读. 下面从Svm-train.c文件中的main()函数切入.

  1. int main(int argc, class="keyword">char **argv)
  2. {
  3. char input_file_name[1024]; //训练样本文件名

  4. char model_file_name[1024]; //输出模型的文件名
  5. const char *error_msg;
  6. parse_command_line(argc, argv, input_file_name, model_file_name); //解析运行程序时,命令行输入的参数

  7. read_problem(input_file_name); //读入训练样本,存入到struct svm_problem prob结构体中

  8. error_msg = svm_check_parameter(&prob,&param); //检查训练样本数据格式是否正确
  9. if(error_msg)
  10. {
  11. fprintf(stderr,“ERROR: %s\n”,error_msg);
  12. exit(1);
  13. }
  14. if(cross_validation)
  15. {
  16. do_cross_validation(); //根据设置进行交叉验证训练
  17. }
  18. else
  19. {
  20. model = svm_train(&prob,&param); //根据问题数据(&prob)和参数(&param)训练模型

  21. if(svm_save_model(model_file_name,model))//保存模型到输出

    文件中

  22. {
  23. fprintf(stderr, “can’t save model to file %s\n”, model_file_name);
  24. exit(1);
  25. }
  26. svm_free_and_destroy_model(&model); //释放模型结构空间
  27. }
  28. svm_destroy_param(&param); //释放使用的其他结构空间
  29. free(prob.y);
  30. free(prob.x);
  31. free(x_space);
  32. free(line);
  33. return 0;
  34. }

下面分析一下main()函数中调用的主要函数程序, 命令行参数解析函数parse_command_line()代码及其注释如下:

  1. void parse_command_line(int argc, class="keyword">char **argv, char *input_file_name, char

    *model_file_name)

  2. {
  3. int i;
  4. void (*print_func)(const class="keyword">char*) = NULL; // default printing to stdout
  5. // default values
  6. param.svm_type = C_SVC;
  7. param.kernel_type = RBF;
  8. param.degree = 3;
  9. param.gamma = 0; // 1/num_features
  10. param.coef0 = 0;
  11. param.nu = 0.5;
  12. param.cache_size = 100;
  13. param.C = 1;
  14. param.eps = 1e-3;
  15. param.p = 0.1;
  16. param.shrinking = 1;
  17. param.probability = 0;
  18. param.nr_weight = 0;
  19. param.weight_label = NULL;
  20. param.weight = NULL;
  21. cross_validation = 0;
  22. // parse options
  23. for(i=1;i<argc;i++) //argc中存放的是命令行程序运行时的参数

    个数

  24. {
  25. if(argv[i][0] != ‘-’) break; class="comment">//开头处是否为参数类型标识,若不是跳出循环
  26. if(++i>=argc) //判断参数类型后是否有其他参数,如样本文件名

  27. exit_with_help(); //如果没有则退出并打印帮助提示
  28. switch(argv[i-1][1]) //根据参数标识,转换参数值为正确类型或相应设置

  29. {
  30. case ‘s’:
  31. param.svm_type = atoi(argv[i]);
  32. break;
  33. case ‘t’:
  34. param.kernel_type = atoi(argv[i]);
  35. break;
  36. case ‘d’:
  37. param.degree = atoi(argv[i]);
  38. break;
  39. case ‘g’:
  40. param.gamma = atof(argv[i]);
  41. break;
  42. case ‘r’:
  43. param.coef0 = atof(argv[i]);
  44. break;
  45. case ‘n’:
  46. param.nu = atof(argv[i]);
  47. break;
  48. case ‘m’:
  49. param.cache_size = atof(argv[i]);
  50. break;
  51. case ‘c’:
  52. param.C = atof(argv[i]);
  53. break;
  54. case ‘e’:
  55. param.eps = atof(argv[i]);
  56. break;
  57. case ‘p’:
  58. param.p = atof(argv[i]);
  59. break;
  60. case ‘h’:
  61. param.shrinking = atoi(argv[i]);
  62. break;
  63. case ‘b’:
  64. param.probability = atoi(argv[i]);
  65. break;
  66. case ‘q’:
  67. print_func = &print_null;
  68. i–;
  69. break;
  70. case ‘v’: //设置交叉验证的参数标识
  71. cross_validation = 1;
  72. nr_fold = atoi(argv[i]);
  73. if(nr_fold < 2)
  74. {
  75. fprintf(stderr,“n-fold cross validation: n must >= 2\n”);
  76. exit_with_help();
  77. }
  78. break;
  79. case ‘w’:
  80. ++param.nr_weight;
  81. param.weight_label = (int *)realloc(param.weight_label,class="keyword">sizeof(int)*param.nr_weight);
  82. param.weight = (double *)realloc(param.weight,class="keyword">sizeof(double)*param.nr_weight);
  83. param.weight_label[param.nr_weight-1] = atoi(&argv[i-1][2]);
  84. param.weight[param.nr_weight-1] = atof(argv[i]);
  85. break;
  86. default:
  87. fprintf(stderr,“Unknown option: -%c\n”, argv[i-1][1]);
  88. exit_with_help();
  89. }
  90. }
  91. svm_set_print_string_function(print_func);
  92. // determine filenames
  93. if(i>=argc)
  94. exit_with_help();
  95. strcpy(input_file_name, argv[i]); //将命令行中的训练文件名,赋值给main中的字符数组.

  96. if(i<argc-1) //如果自定义了输出模型名,则赋值给变量,否则使用默认命名方式

    构造文件名

  97. strcpy(model_file_name,argv[i+1]);
  98. else
  99. {
  100. char *p = strrchr(argv[i],’/');
  101. if(p==NULL)
  102. p = argv[i];
  103. else
  104. ++p;
  105. sprintf(model_file_name,“%s.model”,p);
  106. }
  107. }

发表评论

电子邮件地址不会被公开。 必填项已用 * 标注

*

您可以使用这些 HTML 标签和属性: <a href="" title=""> <abbr title=""> <acronym title=""> <b> <blockquote cite=""> <cite> <code> <del datetime=""> <em> <i> <q cite=""> <strike> <strong>