Abstract
In this article, I'll write a naive kNN algorithm in Julia. Recently, I started to use Julia and need to practice. kNN is a relatively simple algorithm and nice for practice. So, here, I'll write a simple kNN in Julia.

kNN
About kNN itself, please check the following articles. Simply put, in kNN we calculate the distance between the target point and the training data points. The labels of the nearest k points are then used for prediction under the majority rule.
Code
In this article, I'll write kNN just as practice. The code is not meant for a library, so I don't do proper handling of errors, exceptions, and so on. Also, some parts are probably not idiomatic Julia. I'll just write a kNN algorithm that works.

First, we need to load the necessary packages and define the KNN type.
using DataFrames, CSV, DataStructures
type KNN
    x::DataFrames.DataFrame
    y::DataFrames.DataFrame
end

In kNN, we will calculate the distances between two points. The function calcDist() is for this. Julia probably already provides a function for distance calculation, and the following function actually skips some necessary steps, like checking that the two point arrays have the same length. But don't worry about that here.
function calcDist(sourcePoint, destPoint)
    sum = 0
    for i in 1:length(sourcePoint)
        sum += (destPoint[i] - sourcePoint[i]) ^ 2
    end
    dist = sqrt(sum)
    return dist
end

Simple kNN is based on the majority rule. The following function computes the frequency of each label and returns the most frequent one.
function extractTop(targetCandidates)
    targetFrequency = counter(targetCandidates)
    normValue = 0
    normKey = "hoge"
    for key in keys(targetFrequency)
        if targetFrequency[key] > normValue
            normKey = key
            normValue = targetFrequency[key]
        end
    end
    return normKey
end

We will split the data into train and test sets. The following function is for this.
function splitTrainTest(data, at = 0.7)
    n = nrow(data)
    ind = shuffle(1:n)
    train_ind = view(ind, 1:floor(Int, at*n))
    test_ind = view(ind, (floor(Int, at*n)+1):n)
    return data[train_ind,:], data[test_ind,:]
end

Usually, the kNN algorithm doesn't need any special work in the fit() phase, meaning there is no training phase. So the main function is predict().
Here, the distances between target points and train data points are calculated. By the nearest k point’s label, we can decide the target's label on majority rule.
The function sortperm() doesn't have an intuitive name: it sorts the data and returns the indices of the sorted data.
function predict(data::KNN, testData::DataFrames.DataFrame, k=5)
    predictedLabels = []
    for i in 1:size(testData, 1)
        sourcePoint = Array(testData[i,:])
        distances = []
        for j in 1:size(data.x, 1)
            destPoint = Array(data.x[j,:])
            distance = calcDist(sourcePoint, destPoint)
            push!(distances, distance)
        end
        sortedIndex = sortperm(distances)
        targetCandidates = Array(data.y)[sortedIndex[1:k]]
        predictedLabel = extractTop(targetCandidates)
        push!(predictedLabels, predictedLabel)
    end
    return predictedLabels
end

The main() function does the following:

- read data
- split data into train and test
- predict the label of test data by kNN
- evaluation
function main()
    # read data
    df = readtable("iris.csv", header=true)
    # split data into train and test
    trainData, testData = splitTrainTest(df)
    xTrain = trainData[:, [:SepalLength, :SepalWidth, :PetalLength, :PetalWidth]]
    yTrain = trainData[:, [:Species]]
    xTest = testData[:, [:SepalLength, :SepalWidth, :PetalLength, :PetalWidth]]
    yTest = testData[:, [:Species]]
    # predict the label of test data by kNN
    knn = KNN(xTrain, yTrain)
    predicted = predict(knn, xTest)
    # evaluation
    accurate = 0
    yTestArray = Array(yTest)
    for i in 1:length(predicted)
        if yTestArray[i] == predicted[i]
            accurate += 1
        end
    end
    println(accurate/length(predicted))
end

When I executed the code, the output was 0.9555555555555556. Admittedly, this code doesn't cover many edge cases. But as a simple practice, it works.
Whole Code
using DataFrames, CSV, DataStructures
# Container for the training data: `x` holds the feature columns and
# `y` holds the label column, both as DataFrames.
# NOTE(review): `type` is Julia 0.6 syntax; on Julia >= 1.0 this would
# be written `mutable struct`.
type KNN
    x::DataFrames.DataFrame
    y::DataFrames.DataFrame
end
"""
    predict(data::KNN, testData::DataFrames.DataFrame, k=5)

Predict a label for every row of `testData` using the `k` nearest
neighbors (Euclidean distance) among the rows of `data.x`, deciding
by majority rule over the matching labels in `data.y`.

Returns a vector with one predicted label per test row.
"""
function predict(data::KNN, testData::DataFrames.DataFrame, k=5)
    nTrain = size(data.x, 1)
    # Guard: never ask for more neighbors than training points exist,
    # otherwise `sortedIndex[1:k]` would be out of bounds.
    kEff = min(k, nTrain)
    # Hoist the label array out of the loop; the original rebuilt it
    # for every test row.
    labels = Array(data.y)
    predictedLabels = []
    for i in 1:size(testData, 1)
        sourcePoint = Array(testData[i, :])
        # Typed vector avoids a Vector{Any} and keeps sortperm fast.
        distances = Float64[]
        for j in 1:nTrain
            destPoint = Array(data.x[j, :])
            push!(distances, calcDist(sourcePoint, destPoint))
        end
        # sortperm returns the indices that would sort `distances`,
        # so the first kEff entries point at the nearest neighbors.
        sortedIndex = sortperm(distances)
        targetCandidates = labels[sortedIndex[1:kEff]]
        push!(predictedLabels, extractTop(targetCandidates))
    end
    return predictedLabels
end
"""
    calcDist(sourcePoint, destPoint)

Compute the Euclidean distance between two points given as indexable
collections of coordinates. Iterates over the length of `sourcePoint`,
so `destPoint` must be at least as long (no explicit check is done —
kept from the original contract).
"""
function calcDist(sourcePoint, destPoint)
    # Start from 0.0 so the accumulator is type-stable (the original
    # started from the Int 0 and changed type mid-loop for float
    # inputs); also renamed to avoid shadowing `Base.sum`.
    total = 0.0
    for i in 1:length(sourcePoint)
        total += (destPoint[i] - sourcePoint[i])^2
    end
    return sqrt(total)
end
"""
    extractTop(targetCandidates)

Return the most frequent element of `targetCandidates` (majority rule).
Ties are broken deterministically in favor of the element that reaches
the winning count first in input order (the original's tie-break
depended on dict iteration order). Returns the sentinel `"hoge"` when
the input is empty, kept for backward compatibility.
"""
function extractTop(targetCandidates)
    # Plain-Dict counting replaces DataStructures.counter, dropping the
    # only real use of that dependency in this function.
    frequency = Dict{eltype(targetCandidates), Int}()
    normValue = 0
    normKey = "hoge"
    for candidate in targetCandidates
        newCount = get(frequency, candidate, 0) + 1
        frequency[candidate] = newCount
        if newCount > normValue
            normKey = candidate
            normValue = newCount
        end
    end
    return normKey
end
"""
    splitTrainTest(data, at = 0.7)

Randomly split `data` into a training part containing `at` of the rows
(rounded down) and a test part with the remainder. Returns the pair
`(train, test)`.

Generalized to any row-indexable table (DataFrame, Matrix) by using
`size(data, 1)` instead of DataFrames-only `nrow`.
"""
function splitTrainTest(data, at = 0.7)
    n = size(data, 1)
    # Random permutation of row indices; views avoid copying it.
    ind = shuffle(1:n)
    nTrain = floor(Int, at * n)
    train_ind = view(ind, 1:nTrain)
    test_ind = view(ind, (nTrain + 1):n)
    return data[train_ind, :], data[test_ind, :]
end
# Script driver: load the iris data, split it, classify the held-out
# rows with kNN, and print the resulting accuracy.
function main()
    # Load the raw data set from disk.
    df = readtable("iris.csv", header=true)
    # Split rows into train/test, then separate features from labels.
    trainData, testData = splitTrainTest(df)
    featureCols = [:SepalLength, :SepalWidth, :PetalLength, :PetalWidth]
    xTrain = trainData[:, featureCols]
    yTrain = trainData[:, [:Species]]
    xTest = testData[:, featureCols]
    yTest = testData[:, [:Species]]
    # kNN has no training phase; all the work happens inside predict().
    knn = KNN(xTrain, yTrain)
    predicted = predict(knn, xTest)
    # Accuracy = fraction of test labels predicted correctly.
    actual = Array(yTest)
    matches = 0
    for (idx, label) in enumerate(predicted)
        if actual[idx] == label
            matches += 1
        end
    end
    println(matches / length(predicted))
end
# Entry point: run the whole pipeline when the file is executed.
main()