asysbang

 找回密码
 立即注册
查看: 304|回复: 0
打印 上一主题 下一主题

yolov5-android-tflite

[复制链接]

520

主题

2

好友

6551

积分

管理员

Rank: 80Rank: 80Rank: 80Rank: 80Rank: 80

最佳新人 活跃会员 热心会员 推广达人 宣传达人 灌水之王 突出贡献 优秀版主 荣誉管理 论坛元老

跳转到指定楼层
楼主
发表于 2025-7-29 14:35:27 |只看该作者 |正序浏览
private void onDetectClickedA() {
        int INPUT_SIZE = 416;
        Bitmap bitmap = BitmapFactory.decodeResource(getResources(), R.drawable.zidane);
        Bitmap scaledBitmap = Bitmap.createScaledBitmap(bitmap, 416, 416, false);
        ByteBuffer imgData = ByteBuffer.allocateDirect(1 * INPUT_SIZE * INPUT_SIZE * 3 * 4);
        ;
        imgData.order(ByteOrder.nativeOrder());
        int[] intValues = new int[INPUT_SIZE * INPUT_SIZE];
        float IMAGE_MEAN = 0;

        float IMAGE_STD = 255.0f;
        scaledBitmap.getPixels(intValues, 0, scaledBitmap.getWidth(), 0, 0, scaledBitmap.getWidth(), scaledBitmap.getHeight());
        imgData.rewind();
        for (int i = 0; i < INPUT_SIZE; ++i) {
            for (int j = 0; j < INPUT_SIZE; ++j) {
                int pixelValue = intValues[i * INPUT_SIZE + j];
                imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
                imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
                imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
            }
        }

        Map<Integer, Object> outputMap = new HashMap<>();
        ByteBuffer outData = ByteBuffer.allocateDirect(10647 * 85 * 4);
        outData.rewind();
        outputMap.put(0, outData);

        Object[] inputArray = {imgData};
        Log.e("====", "====imgData =" + imgData);
        Log.e("====", "====inputArray =" + inputArray.length);

        Log.e("====", "====outData =" + outData);
        mTFLite.getmInterpreter().runForMultipleInputsOutputs(inputArray, outputMap);


        ByteBuffer byteBuffer = (ByteBuffer) outputMap.get(0);
        byteBuffer.rewind();

        ArrayList<Classifier.Recognition> detections = new ArrayList<Classifier.Recognition>();

        float[][][] out = new float[1][10647][85];
        Log.d("YoloV5Classifier", "out[0] detect start");

        for (int i = 0; i < 10647; ++i) {
            for (int j = 0; j < 85; ++j) {
                out[0][j] = byteBuffer.getFloat();
            }
            // Denormalize xywh
            for (int j = 0; j < 4; ++j) {
                out[0][j] *= 416;
            }
        }

        for (int i = 0; i < 10647; ++i) {
            final int offset = 0;
            final float confidence = out[0][4];
            int detectedClass = -1;
            float maxClass = 0;

            final float[] classes = new float[80];
            for (int c = 0; c < 80; ++c) {
                classes[c] = out[0][5 + c];
            }

            for (int c = 0; c < 80; ++c) {
                if (classes[c] > maxClass) {
                    detectedClass = c;
                    maxClass = classes[c];
                }
            }

            final float confidenceInClass = maxClass * confidence;
            Log.e("====", "====maxClass =" + maxClass);
            Log.e("====", "====confidence =" + confidence);
            Log.e("====", "====confidenceInClass =" + confidenceInClass);
            if (confidenceInClass > 0.3f) {
                final float xPos = out[0][0];
                final float yPos = out[0][1];

                final float w = out[0][2];
                final float h = out[0][3];
                Log.d("====YoloV5Classifier",
                        Float.toString(xPos) + ',' + yPos + ',' + w + ',' + h);

                final RectF rect =
                        new RectF(
                                Math.max(0, xPos - w / 2),
                                Math.max(0, yPos - h / 2),
                                Math.min(bitmap.getWidth() - 1, xPos + w / 2),
                                Math.min(bitmap.getHeight() - 1, yPos + h / 2));

                Log.e("====", "====rect =" + rect);

//                detections.add(new Classifier.Recognition("" + offset, labels.get(detectedClass),
//                        confidenceInClass, rect, detectedClass));
            }
        }

        Log.d("YoloV5Classifier", "===========out[0] detect end");

    }


//non maximum suppression
protected ArrayList<Recognition> nms(ArrayList<Recognition> list) {
    ArrayList<Recognition> nmsList = new ArrayList<Recognition>();

    for (int k = 0; k < labels.size(); k++) {
        //1.find max confidence per class
        PriorityQueue<Recognition> pq =
                new PriorityQueue<Recognition>(
                        50,
                        new Comparator<Recognition>() {
                            @Override
                            public int compare(final Recognition lhs, final Recognition rhs) {
                                // Intentionally reversed to put high confidence at the head of the queue.
                                return Float.compare(rhs.getConfidence(), lhs.getConfidence());
                            }
                        });

        for (int i = 0; i < list.size(); ++i) {
            if (list.get(i).getDetectedClass() == k) {
                pq.add(list.get(i));
            }
        }

        //2.do non maximum suppression
        while (pq.size() > 0) {
            //insert detection with max confidence
            Recognition[] a = new Recognition[pq.size()];
            Recognition[] detections = pq.toArray(a);
            Recognition max = detections[0];
            nmsList.add(max);
            pq.clear();

            for (int j = 1; j < detections.length; j++) {
                Recognition detection = detections[j];
                RectF b = detection.getLocation();
                if (box_iou(max.getLocation(), b) < mNmsThresh) {
                    pq.add(detection);
                }
            }
        }
    }
    return nmsList;
}

protected float mNmsThresh = 0.6f;

protected float box_iou(RectF a, RectF b) {
    return box_intersection(a, b) / box_union(a, b);
}

protected float box_intersection(RectF a, RectF b) {
    float w = overlap((a.left + a.right) / 2, a.right - a.left,
            (b.left + b.right) / 2, b.right - b.left);
    float h = overlap((a.top + a.bottom) / 2, a.bottom - a.top,
            (b.top + b.bottom) / 2, b.bottom - b.top);
    if (w < 0 || h < 0) return 0;
    float area = w * h;
    return area;
}

protected float box_union(RectF a, RectF b) {
    float i = box_intersection(a, b);
    float u = (a.right - a.left) * (a.bottom - a.top) + (b.right - b.left) * (b.bottom - b.top) - i;
    return u;
}

protected float overlap(float x1, float w1, float x2, float w2) {
    float l1 = x1 - w1 / 2;
    float l2 = x2 - w2 / 2;
    float left = l1 > l2 ? l1 : l2;
    float r1 = x1 + w1 / 2;
    float r2 = x2 + w2 / 2;
    float right = r1 < r2 ? r1 : r2;
    return right - left;
}



1





回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

Archiver|手机版|aSys-帮 ( 京ICP备13033689号 )

GMT+8, 2026-1-28 14:35 , Processed in 0.047988 second(s), 21 queries .

Powered by Discuz! X2.5

© 2001-2012 Comsenz Inc.

回顶部