#!/usr/bin/ruby

# Constants shared by the UMAC implementation below.

# Masks applied to the L2 polynomial keys (64- and 128-bit).
MP64  = 0x01ffffff01ffffff
MP128 = (MP64 << 64) | MP64

# All-ones masks for 32- and 64-bit modular arithmetic.
M32   = 2**32 - 1
M64   = 2**64 - 1

# Primes: P36 reduces the L3 hash, P64/P128 are the L2 polynomial moduli.
P36   = 2**36 - 5
P64   = 2**64 - 59
P128  = 2**128 - 159

# Polynomial test values: words at or above these thresholds are split
# into a marker plus an offset word during the L2 polynomial hash.
T64   = 2**64 - 2**32
T128  = 2**128 - 2**96

# UMAC message authentication code following RFC 4418, built on the AES
# implementation provided by the +aes_alg+ library.
#
# Usage: create one instance per key, feed every full 1024-byte chunk of
# the message to #umacUpdate, then pass the remaining bytes (1..1024 of
# them, or the whole message when it is shorter, possibly empty) to
# #umacFinal together with the nonce.  #umacFinal returns the tag as an
# Array of 32-bit Integers and resets the per-message state.
class Umac

  require 'aes_alg'

  # Number of 64-bit L1 outputs (2**17 bytes) hashed with the 64-bit
  # polynomial before L2 switches to the 128-bit polynomial.
  L2_POLY64_WORDS = 2**14

  # Key-derivation function: runs +kdfCipher+ in counter mode over the
  # 16-byte blocks uint2str(index, 8) || uint2str(counter, 8), counter
  # starting at 1, and returns the first +numbytes+ bytes of output.
  def kdf(kdfCipher, index, numbytes)
    out = String.new
    (1..(numbytes + 15) / 16).each do |counter|
      block = [index >> 32, index & M32, counter >> 32, counter & M32].pack('N4')
      out << kdfCipher.cipher_encrypt(block)
    end
    out.byteslice(0, numbytes)
  end

  # Derives all subkeys for +umacKey+.
  #
  # umacKey   - AES key String; its byte length selects the AES key size.
  # tagLength - tag size in bits, rounded up to whole 32-bit words
  #             (at least one word, at most four).
  #
  # Raises ArgumentError if tagLength is greater than 128.
  def initialize(umacKey, tagLength = 64)
    @iters = [1, (tagLength / 32.0).ceil].max
    # The PDF index table in umacFinal only covers 1..4 words; fail here
    # rather than with an obscure error when the first tag is produced.
    raise ArgumentError, "tag length must not exceed 128 bits" if @iters > 4

    key_bits = umacKey.bytesize * 8
    kdfCipher = AesAlg.new(key_bits, 'ECB', umacKey)
    @pdfCipher = AesAlg.new(key_bits, 'ECB', kdf(kdfCipher, 0, umacKey.bytesize))

    # L1 (NH) key: 1024 bytes plus 16 more per extra iteration, as
    # big-endian 32-bit words; iteration i starts at word 4*i.
    @l1Key   = kdf(kdfCipher, 1, 1024 + (@iters - 1) * 16).unpack('N*')
    l2_words = kdf(kdfCipher, 2, @iters * 24).unpack('N*')  # 6 words per iteration
    l3_words = kdf(kdfCipher, 3, @iters * 64).unpack('N*')  # 16 words per iteration
    l3_xor   = kdf(kdfCipher, 4, @iters * 4).unpack('N*')   # 1 word per iteration

    @l2Key = []; @l3Key = []; @l1Out = []; @l3Out = []
    @iters.times do |i|
      w = l2_words[i * 6, 6]
      k64  = (w[0] << 32) | w[1]
      k128 = (w[2] << 96) | (w[3] << 64) | (w[4] << 32) | w[5]
      @l2Key << [k64 & MP64, k128 & MP128]
      # Eight 64-bit L3 key values, each reduced mod P36.
      l3_keys = l3_words[i * 16, 16].each_slice(2).map { |hi, lo| ((hi << 32) | lo) % P36 }
      @l3Key << [l3_keys, l3_xor[i]]
      @l1Out << []
    end
  end

  # NH hash of one chunk.
  #
  # k         - word offset into @l1Key (4 * iteration index).
  # data      - Array of 32-bit message words; length a multiple of 8.
  # bitlength - value added to the result (8192 for a full 1024-byte
  #             chunk, the chunk's bit length for the final one).
  # Returns the 64-bit hash value.
  def nh(k, data, bitlength)
    sum = 0
    0.step(data.length - 8, 8) do |i|
      # Word j is paired with word j+4; each is added mod 2^32 to its key word.
      4.times do |j|
        sum += ((data[i + j]     + @l1Key[k + i + j])     & M32) *
               ((data[i + j + 4] + @l1Key[k + i + j + 4]) & M32)
      end
    end
    (sum + bitlength) & M64  # mod 2^64
  end

  # Runs L1 over one full chunk for every iteration and records the
  # results.  The chunk must be exactly 1024 bytes: its bit length is
  # hard-wired as 8192.
  def uhashUpdate(inString)
    data = inString.unpack('V*')  # little-endian 32-bit words
    @iters.times { |i| @l1Out[i] << nh(i * 4, data, 8192) }
  end

  # Hashes the final chunk, then completes L2 and L3 for every iteration.
  #
  # inString  - the final 0..1024 bytes of the message.
  # bitlength - bit length of inString, folded into the NH result.
  # Returns @l3Out, one 32-bit value per iteration.
  def uhashFinal(inString, bitlength)
    # Zero-pad (counting bytes, not characters) to a positive multiple of
    # 32 bytes, so an empty chunk becomes 32 zero bytes.
    size = inString.bytesize
    pad = size.zero? ? 32 : (32 - size % 32) % 32
    data = (inString + 0.chr * pad).unpack('V*')
    @iters.times do |i|
      @l1Out[i] << nh(i * 4, data, bitlength)
      @l3Out << l3_hash(i, l2_hash(i))
    end
    @l3Out
  end

  # Feeds one full 1024-byte chunk of the message.
  def umacUpdate(inString)
    uhashUpdate(inString)
  end

  # Completes the MAC.
  #
  # inString  - the final 0..1024 bytes of the message.
  # bitlength - bit length of inString.
  # nonce     - String of 1..16 bytes.
  # Returns the tag as an Array of @iters 32-bit Integers and clears the
  # per-message state so the instance can authenticate another message.
  # Raises ArgumentError if the nonce is empty or longer than 16 bytes.
  def umacFinal(inString, bitlength, nonce)
    nonce_bytes = nonce.unpack('C*')
    unless nonce_bytes.length.between?(1, 16)
      raise ArgumentError, "nonce must be 1 to 16 bytes"
    end
    uhashFinal(inString, bitlength)
    # Pad generation: for 1- and 2-word tags, the low bits of the last nonce
    # byte select which slice of the encrypted block is used, and those
    # bits are cleared before encryption.
    idx = nonce_bytes[-1] & [nil, 3, 1, 0, 0][@iters]
    nonce_bytes[-1] -= idx
    block = (nonce_bytes + [0] * (16 - nonce_bytes.length)).pack('C*')
    pad = @pdfCipher.cipher_encrypt(block).unpack('N*')
    tag = Array.new(@iters) { |i| @l3Out[i] ^ pad[@iters * idx + i] }
    @l1Out.each(&:clear)
    @l3Out.clear
    tag
  end

  private

  # Polynomial hash of +words+ (each below 2**word_bits) under +key+
  # modulo +prime+, starting from 1.  A word >= +max_word+ is encoded as
  # the marker prime-1 followed by word - (2**word_bits - prime).
  def poly(key, words, prime, max_word, word_bits)
    offset = (1 << word_bits) - prime
    words.inject(1) do |y, m|
      if m >= max_word
        y = (y * key + (prime - 1)) % prime
        (y * key + (m - offset)) % prime
      else
        (y * key + m) % prime
      end
    end
  end

  # L2 hash of the L1 outputs collected for iteration +i+.  A single L1
  # output (message of at most 1024 bytes) is passed through unchanged.
  def l2_hash(i)
    words = @l1Out[i]
    return words[0] if words.length == 1
    k64, k128 = @l2Key[i]
    y = poly(k64, words.first(L2_POLY64_WORDS), P64, T64, 64)
    return y if words.length <= L2_POLY64_WORDS
    # Remaining outputs: append a 0x80 marker byte, zero-pad to 16 bytes,
    # and hash as 128-bit words seeded with the 64-bit result.
    tail = words[L2_POLY64_WORDS..-1] << 0x8000000000000000
    tail << 0 if tail.length.odd?
    words128 = [y]
    tail.each_slice(2) { |hi, lo| words128 << ((hi << 64) | lo) }
    poly(k128, words128, P128, T128, 128)
  end

  # L3 hash for iteration +i+.  It splits +y+ into eight 16-bit chunks
  # (most significant first), takes their dot product with the L3 keys
  # mod P36, keeps the low 32 bits, and XORs with the per-iteration word.
  def l3_hash(i, y)
    keys, xor_word = @l3Key[i]
    sum = 0
    7.downto(0) do |j|
      sum += (y & 0xffff) * keys[j]
      y >>= 16
    end
    ((sum % P36) & M32) ^ xor_word
  end

end

# Demo: MAC a 1 KiB message of 'a' characters with a 16-byte key and an
# 8-byte nonce.  Every full 1024-byte chunk except the last goes through
# umacUpdate; the final 1..1024 bytes (or the whole message if shorter)
# go to umacFinal.
umac = Umac.new('abcdefghijklmnop')
message = 'a' * 2**10
nonce = 'bcdefghi'
offset = 0
while offset + 1024 < message.bytesize
  umac.umacUpdate(message.byteslice(offset, 1024))
  offset += 1024
end
tail = message.byteslice(offset..-1)
tag = umac.umacFinal(tail, 8 * tail.bytesize, nonce)
puts tag.map { |word| format('%08X', word) }.join

