有幸参与了2020华为软挑初赛跟复赛,博主cpp菜鸡一枚,因此全程都是Java语言参与,Java最好成绩初赛2.8,在粤港澳勉强狗进32强,队名:佛系炮灰。
初赛赛题如下:
通过金融风控的资金流水分析,可有效识别循环转账,辅助公安挖掘洗钱组织,帮助银行预防信用卡诈骗。基于给定的资金流水,检测并输出指定约束条件的所有循环转账,结果准确,用时最短者胜。简单来讲,就是创建图,找环路。
官方提供了输入文件与参考答案,数据规模较小,博主菜鸡,Java跑官方样例也需要0.06s。
输入为包含资金流水的文本文件,每一行代表一次资金交易记录,包含本端账号ID, 对端账号ID, 转账金额,用逗号隔开。
输出信息为一个文件,包含满足限制条件下的循环转账个数与所有满足限制条件的循环转账路径详情。
知乎杭厦赛区某大佬开源了cpp版本,因此起初复现了一版Java单线程的,并在此基础上做了改动,最好成绩3.4左右,比赛全程各种被cpp吊锤,使用Java做算法大赛果然是毫无体验。如果有需要单程程代码的,可以私聊我,以下分享全部是Java多线程版本思路:
1:BufferredReader以行为单位读文件,同时创建图(这里有尝试过使用Java的MappedByteBuffer一次性读入,并以字节为单位处理两端账号ID,忽略转账金额,该方案线下提升速度在25%左右,线上提升不明显)
2:使用拓扑排序,去掉不会构成环的ID(只去掉一层,同时需要更新边表节点列表)
3:使用DFS加剪枝,忽略小于起始ID的账号。DFS只搜索到第6层,利用记忆化搜索保留第6层到起始账号ID的中间节点,同时在深度为5时,使用记忆化搜素结构直接输出长度为6的环。并非开源方案中的path2结构,该结构是保存所有符合成环的中间点。Java版本则是只提前搜索所有点到本端起始ID的中间点,该思路在线下与线上性能提升较大。
4:对边表节点列表进行排序,记忆化搜索结构进行排序,使用Stringbuilder分开储存3-7的环,最后进行输出
5:多线程对任务进行切分。这里使用线程池管理线程,没有用到大佬分享的原子类等多线程骚操作(其实是不会,手动狗头)理论来讲,开4线程是最优的,实测则是3线程的成绩较好,可能是博主比较菜,参数没调整好的原因。多线程任务一定不能四等分,因此ID越小,线程的任务越重,多线程的优势反而不明显。这里可以参考知乎某大佬的调参方案 https://zhuanlan.zhihu.com/p/136785097
6:使用Printwriter或BufferWriter输出
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
public class Main {
public static int threadNum = 3;
public static void main(String[] args) {
long start = System.currentTimeMillis();
GraghSerach gs = new GraghSerach();
try {
gs.initGragh();
} catch (NumberFormatException e) {
e.printStackTrace();
} catch (IOException e) {
e.printStackTrace();
}
int sampleNum = gs.n;
ExecutorService exe = Executors.newFixedThreadPool(threadNum);
SerachCricleThread r1 = new SerachCricleThread(gs, 0, sampleNum/15);
SerachCricleThread r2 = new SerachCricleThread(gs, sampleNum/15, sampleNum/8);
SerachCricleThread r3 = new SerachCricleThread(gs, sampleNum/8, sampleNum);
exe.execute(r1);
exe.execute(r2);
exe.execute(r3);
exe.shutdown();
//通过线程池管理线程是否执行完毕
while (true) {
if (exe.isTerminated()) {
break;
}
}
StringBuilder result = new StringBuilder(7*7*2800000);
StringBuilder[] result_r1 = r1.getList();
StringBuilder[] result_r2 = r2.getList();
StringBuilder[] result_r3 = r3.getList();
int[] circleCount_r1 = r1.getCount();
int[] circleCount_r2 = r2.getCount();
int[] circleCount_r3 = r3.getCount();
int circleCount =circleCount_r1[0]+ circleCount_r2[0]+circleCount_r3[0];
result.append(Integer.toString(circleCount) + "\n");
for (int i = 3; i < 8; i++) {
result.append(result_r1[i].toString());
result.append(result_r2[i].toString());
result.append(result_r3[i].toString());
}
long wrtime = System.currentTimeMillis();
gs.savePredictResult(result);
long alltime = System.currentTimeMillis();
System.out.println("All Time(s): " + (alltime - start) * 1.0 / 1000 + "s");
System.out.println("wr Time(s): " + (alltime - wrtime) * 1.0 / 1000 + "s");
}
}
class SerachCricleThread implements Runnable {
GraghSerach gs;
int start;//线程切分起始参数
int end;//线程切分终止参数
boolean visi[];//节点是否访问数组
StringBuilder[] sb = new StringBuilder[8];
int[] circleCount = {0};//环个数数组
int[] reachable;//第五层或第六层点到起始点是否可达标志位数组
Map<Integer, List<Integer>> mas = null;
public SerachCricleThread(GraghSerach gs, int start, int end) {
this.gs = gs;
this.start = start;
this.end = end;
}
public StringBuilder[] getList() {
return sb;
}
public int[] getCount() {
return circleCount;
}
@Override
public void run() {
List<Integer> trace = new ArrayList<>();
int n = gs.n;
visi = new boolean[n];
reachable = new int[n];
sb[3] = new StringBuilder(3*3*200000);
sb[4] = new StringBuilder(4*4*200000);
sb[5] = new StringBuilder(5*5*500000);
sb[6] = new StringBuilder(6*6*500000);
sb[7] = new StringBuilder(7*7*1000000);
for (int i = start; i < end; i++) {
for (int reach = i+1; reach < n; reach++) {
reachable[reach] = -1;
}
mas = new HashMap<Integer, List<Integer>>();//mas保存了所有点到当前访问点的中间节点,使用list储存
int[] headList = gs.InG[i];
for(int j=0,len_1 = gs.inDegress[i];j<len_1;j++) {
int headNextNode = headList[j];
if (headNextNode > i) {
int[] headNextList = gs.InG[headNextNode];
for (int k = 0,len_2 = gs.inDegress[headNextNode]; k<len_2;k++) {
int temp = headNextList[k];
reachable[temp] = 0;
if (mas.get(temp)==null) {
mas.put(temp, new ArrayList<Integer>());
}
mas.get(temp).add(headNextNode);
}
}
}
gs.decDfs(i, i, 1, visi, trace, sb, reachable,mas,circleCount);
}
}
}
//图类
class GraghSerach {
Set<Integer> nodeSet = new HashSet<>(150000);
HashMap<Integer, Integer> idMap = new HashMap<Integer, Integer>();//内存映射
int[] outDegress = null;
int[] inDegress = null;
List<Integer> trace;//当前dfs访问路径
String fileNames = "/data/test_data.txt";
String output = "/projects/student/result.txt";
int[][] outG = null;//出度边表点列表数组
int[][] InG = null;//入度边表点列表数组
int out_max = 50;
int in_max = 50;
//有效节点的总个数
int n = 0;
boolean[] visi = null;
BufferedReader reader = null;
int[] nodeArr = null;
String line = "";
List<Integer> nodeList = new ArrayList<Integer>();
int[] reachable;
String[] nodeLF = null;//对节点ID与“,”或“/n”拼接,输出环路径时可以直接通过该结构获取
String[] nodeRF = null;//对节点ID与“,”或“/n”拼接,输出环路径时可以直接通过该结构获取
public GraghSerach() {
}
public void initGragh() throws NumberFormatException, IOException {
Map<Integer, ArrayList<Integer>> outList = new HashMap<Integer, ArrayList<Integer>>();//key为结点,value为出度列表)
Map<Integer, ArrayList<Integer>> inputList = new HashMap<Integer, ArrayList<Integer>>();//key为结点,value为入度列表)
Map<Integer, Integer> outputValue = new HashMap<>() ; //key为结点,value为出度个数
Map<Integer, Integer> inputValue = new HashMap<>() ; //key为结点,value为入度个数
try {
//BufferedReader读文件
reader = new BufferedReader(new FileReader(fileNames));
while ((line = reader.readLine()) != null) {
String item[] = line.split(",");
insertEdge(Integer.parseInt(item[0]),Integer.parseInt(item[1]),outList,inputList,outputValue,inputValue);
}
} catch (FileNotFoundException exception) {
System.err.println(fileNames + " File Not Found");
}
Iterator<Integer> iter = outputValue.keySet().iterator();
while(iter.hasNext()){
int tempKey= iter.next();
Integer tempValue = inputValue.get(tempKey);
if(tempValue == null) {
iter.remove();
}else {
nodeSet.add(tempKey);
n++;
}
}
//大佬测试了最大的入度为不大于255
InG = new int[n][in_max];
inDegress = new int[n];
//大佬测试了最大的出度为不大于50
outG = new int[n][out_max];
outDegress = new int[n];
//定义有效的结点数组
nodeArr = new int[n];
nodeLF = new String[n];
nodeRF= new String[n];
reachable = new int[n];
int count_i=0;
for(iter = nodeSet.iterator(); iter.hasNext(); ) {
int element = (int) iter.next();
idMap.put(element, count_i);
nodeLF[count_i] = Integer.toString(element) + "\n";
nodeRF[count_i] = Integer.toString(element)+ ",";
nodeArr[count_i++] = element;
}
for (int i = 0; i < n; i++) {
int element = nodeArr[i];
int outValue_count =0;
int inValue_count =0;
ArrayList<Integer> mList = outList.get(element);
for (int j = 0,len_m = mList.size(); j < len_m; j++) {
int outElement = mList.get(j);
Integer outElement_index = idMap.get(outElement);
if(outElement_index != null) {//判断出度map中是否有当前key
outG[i][outValue_count] = outElement_index;
outValue_count++;
}
}
outDegress[i] = outValue_count;
Arrays.sort(outG[i],0,outDegress[i]);
ArrayList<Integer> inList = inputList.get(element);
for (int k = 0,len_in = inList.size(); k < len_in; k++) {
int inElement = inList.get(k);
Integer inElement_index = idMap.get(inElement);
if(inElement_index != null) {//判断出度map中是否有当前key
InG[i][inValue_count] = inElement_index;
inValue_count++;
}
}
inDegress[i] = inValue_count;
}
}
//创建图的同时,需要更新边表节点列表
public void insertEdge(int a,int b,Map<Integer, ArrayList<Integer>> outList,Map<Integer, ArrayList<Integer>> inputList,Map<Integer, Integer> outputValue,Map<Integer, Integer> inputValue){
//出度
if(outList.containsKey(a)) {
outList.get(a).add(b);
int value = outputValue.get(a).intValue() + 1;
outputValue.put(a, Integer.valueOf(value));
}else {
ArrayList<Integer> temp = new ArrayList<>();
temp.add(b);
outList.put(a, temp);
outputValue.put(a, 1);
}
//入度
if(inputList.containsKey(b)) {
inputList.get(b).add(a);
inputValue.put(b, ((Integer)inputValue.get(b)).intValue() + 1);
}else {
ArrayList<Integer> temp = new ArrayList<>();
temp.add(a);
inputList.put(b, temp);
inputValue.put(b, 1);
}
}
public void decDfs(int head, int current, int depth, boolean[] visi, List<Integer> trace, StringBuilder[] sb, int[] reachable,Map<Integer, List<Integer>> mas,int[] countArr) {
visi[current] = true;
trace.add(current);
List<Integer> circle = null;
int[] list = outG[current];
//第六层,直接通过path2寻找第七个点,同时直接输出环
if (reachable[current] == 0 && depth == 6) {
List<Integer> path2List = mas.get(current);
Collections.sort(path2List);
for (int i = 0; i < path2List.size(); i++) {
int lastNode = (int) path2List.get(i);
if (!visi[lastNode]) {
for(int key = 0;key<trace.size();key++) {
sb[7].append(nodeRF[trace.get(key)]);
}
sb[7].append(nodeLF[lastNode]);
countArr[0] = countArr[0]+1;
}
}
}
if (depth < 6) {
for (int i = 0,len = outDegress[current]; i < len; i++) {
int gcur = list[i];
if (gcur == head && depth >= 3 && depth < 6) {
for(int key = 0;key<trace.size()-1;key++) {
sb[depth].append(nodeRF[trace.get(key)]);
}
sb[depth].append(nodeLF[trace.get(trace.size()-1)]);
countArr[0] = countArr[0]+1;
}
if (gcur > head && depth < 6 && !visi[gcur]) {
decDfs(head, gcur, depth + 1, visi, trace, sb, reachable,mas,countArr);
}
}
//这里debug调试时,由于dfs到第6层时,并不会判断第7层ID是否与起始ID相同,因此会漏掉长度为6的环,因此使用记忆化搜索结构,通过5+1拼接
if (reachable[current] == 0 && (depth == 5)) {
List<Integer> path2List = mas.get(current);
Collections.sort(path2List);
for (int i1 = 0; i1 < path2List.size(); i1++) {
int lastNode = (int) path2List.get(i1);
if (!visi[lastNode]) {
for(int key = 0;key<trace.size();key++) {
sb[6].append(nodeRF[trace.get(key)]);
}
sb[6].append(nodeLF[lastNode]);
countArr[0] = countArr[0]+1;
}
}
}
}
visi[current] = false;
trace.remove(trace.size() - 1);
}
//文件输出
public void savePredictResult(StringBuilder result) {
File file = new File(output);
try(PrintWriter writer = new PrintWriter(new FileWriter(file))){
writer.print(result.toString());
}
catch (IOException e){
throw new RuntimeException();
}
}
}
最后,感谢这次比赛遇到的每一个人,特别是我家那条狗