Reimplement some built-ins in terms of the standard library.

As a result:

  * functions such as `nth` should work correctly on iterators;
  * `nth` will raise `IndexError` (in a fashion consistent with `get`)
    when the index is out of bounds;
  * `take`, etc. will raise `ValueError` instead of returning
    an ambiguous value if the index is negative;
  * `map`, `zip`, `range`, `input`, `filter` work the same way (Py3k one)
    on both Python 2 and 3 (see #523 and #331).
This commit is contained in:
pyos 2014-04-29 18:01:14 +04:00
parent bdd8e3c82e
commit 8e4b21103c
4 changed files with 65 additions and 87 deletions

View File

@ -495,8 +495,8 @@ nth
Usage: ``(nth coll n)``
Return the `nth` item in a collection, counting from 0. Unlike
``get``, ``nth`` works on both iterators and iterables. Returns ``None``
if the `n` is outside the range of `coll`.
``get``, ``nth`` works on both iterators and iterables. Raises ``IndexError``
if the `n` is outside the range of ``coll`` or ``ValueError`` if it's negative.
.. code-block:: hy
@ -506,8 +506,10 @@ if the `n` is outside the range of `coll`.
=> (nth [1 2 4 7] 3)
7
=> (none? (nth [1 2 4 7] 5))
True
=> (nth [1 2 4 7] 5)
Traceback (most recent call last):
...
IndexError: 5
=> (nth (take 3 (drop 2 [1 2 3 4 5 6])) 2))
5
@ -764,6 +766,7 @@ drop
Usage: ``(drop n coll)``
Return an iterator, skipping the first ``n`` members of ``coll``
Raises ``ValueError`` if ``n`` is negative.
.. code-block:: hy
@ -924,6 +927,7 @@ take
Usage: ``(take n coll)``
Return an iterator containing the first ``n`` members of ``coll``.
Raises ``ValueError`` if ``n`` is negative.
.. code-block:: hy

View File

@ -23,7 +23,9 @@
;;;; to make functional programming slightly easier.
;;;;
(import itertools)
(import functools)
(import collections)
(import [hy._compat [long-type]]) ; long for python2, int for python3
(import [hy.models.cons [HyCons]])
@ -49,15 +51,6 @@
(and (instance? (type :foo) k)
(.startswith k (get :foo 0))))
(defn cycle [coll]
"Yield an infinite repetition of the items in coll"
(setv seen [])
(for* [x coll]
(yield x)
(.append seen x))
(while seen
(for* [x seen]
(yield x))))
(defn dec [n]
"Decrement n by 1"
@ -87,22 +80,36 @@
(yield val)
(.add seen val))))))
(if-python2
(do
(setv filterfalse itertools.ifilterfalse)
(setv zip_longest itertools.izip_longest)
(setv filter itertools.ifilter)
(setv map itertools.imap)
(setv zip itertools.izip)
(setv range xrange)
(setv input raw_input))
(do
(setv reduce functools.reduce)
(setv filterfalse itertools.filterfalse)
(setv zip_longest itertools.zip_longest)
; Someone can import these directly from `hy.core.language`;
; we'll make some duplicates.
(setv filter filter)
(setv map map)
(setv zip zip)
(setv range range)
(setv input input)))
(setv cycle itertools.cycle)
(setv repeat itertools.repeat)
(setv drop-while itertools.dropwhile)
(setv take-while itertools.takewhile)
(setv zipwith map)
(defn drop [count coll]
"Drop `count` elements from `coll` and yield back the rest"
(let [[citer (iter coll)]]
(try (for* [i (range count)]
(next citer))
(catch [StopIteration]))
citer))
(defn drop-while [pred coll]
"Drop all elements of `coll` until `pred` is False"
(let [[citer (iter coll)]]
(for* [val citer]
(if (not (pred val))
(do (yield val) (break))))
(for* [val citer]
(yield val))))
(itertools.islice coll count nil))
(defn empty? [coll]
"Return True if `coll` is empty"
@ -126,13 +133,6 @@
(if (not (hasattr tree attr))
(setattr tree attr 1))))
(defn filter [pred coll]
"Return all elements from `coll` that pass `pred`"
(let [[citer (iter coll)]]
(for* [val citer]
(if (pred val)
(yield val)))))
(defn flatten [coll]
"Return a single flat list expanding all members of coll"
(if (coll? coll)
@ -174,7 +174,7 @@
(defn first [coll]
"Return first item from `coll`"
(get coll 0))
(nth coll 0))
(defn identity [x]
"Returns the argument unchanged"
@ -199,14 +199,13 @@
(defn integer-char? [x]
"Return True if char `x` parses as an integer"
(try
(integer? (int x))
(catch [e ValueError] False)
(catch [e TypeError] False)))
(integer? (int x))
(catch [e ValueError] False)
(catch [e TypeError] False)))
(defn iterable? [x]
"Return true if x is iterable"
(try (do (iter x) true)
(catch [Exception] false)))
(isinstance x collections.Iterable))
(defn iterate [f x]
(setv val x)
@ -216,8 +215,7 @@
(defn iterator? [x]
"Return true if x is an iterator"
(try (= x (iter x))
(catch [TypeError] false)))
(isinstance x collections.Iterator))
(defn list* [hd &rest tl]
"Return a dotted list construed from the elements of the argument"
@ -258,13 +256,9 @@
(defn nth [coll index]
"Return nth item in collection or sequence, counting from 0"
(if (not (neg? index))
(if (iterable? coll)
(try (get (list (take 1 (drop index coll))) 0)
(catch [IndexError] None))
(try (get coll index)
(catch [IndexError] None)))
None))
(try
(next (drop index coll))
(catch [e StopIteration] (raise (IndexError index)))))
(defn odd? [n]
"Return true if n is an odd number"
@ -285,14 +279,7 @@
(defn rest [coll]
"Get all the elements of a coll, except the first."
(slice coll 1))
(defn repeat [x &optional n]
"Yield x forever or optionally n times"
(if (none? n)
(setv dispatch (fn [] (while true (yield x))))
(setv dispatch (fn [] (for* [_ (range n)] (yield x)))))
(dispatch))
(drop 1 coll))
(defn repeatedly [func]
"Yield result of running func repeatedly"
@ -301,7 +288,7 @@
(defn second [coll]
"Return second item from `coll`"
(get coll 1))
(nth coll 1))
(defn some [pred coll]
"Return true if (pred x) is logical true for any x in coll, else false"
@ -322,9 +309,7 @@
(defn take [count coll]
"Take `count` elements from `coll`, or the whole set if the total
number of entries in `coll` is less than `count`."
(let [[citer (iter coll)]]
(for* [_ (range count)]
(yield (next citer)))))
(itertools.islice coll nil count))
(defn take-nth [n coll]
"Return every nth member of coll
@ -337,29 +322,15 @@
(next citer))))
(raise (ValueError "n must be positive"))))
(defn take-while [pred coll]
"Take all elements while `pred` is true"
(let [[citer (iter coll)]]
(for* [val citer]
(if (pred val)
(yield val)
(break)))))
(defn zero? [n]
"Return true if n is 0"
(_numeric_check n)
(= n 0))
(defn zipwith [func &rest lists]
"Zip the contents of several lists and map a function to the result"
(do
(import functools)
(map (functools.partial (fn [f args] (apply f args)) func) (apply zip lists))))
(def *exports* '[calling-module-name coll? cons cons? cycle dec distinct
disassemble drop drop-while empty? even? every? first filter
flatten float? gensym identity inc instance? integer
integer? integer-char? iterable? iterate iterator? keyword?
list* macroexpand macroexpand-1 neg? nil? none? nth
numeric? odd? pos? remove repeat repeatedly rest second
some string string? take take-nth take-while zero? zipwith])
list* macroexpand macroexpand-1 map neg? nil? none? nth
numeric? odd? pos? range remove repeat repeatedly rest second
some string string? take take-nth take-while zero? zip zipwith])

View File

@ -82,8 +82,8 @@
(assert-equal res [None 4 5])
(setv res (list (drop 0 [1 2 3 4 5])))
(assert-equal res [1 2 3 4 5])
(setv res (list (drop -1 [1 2 3 4 5])))
(assert-equal res [1 2 3 4 5])
(try (do (list (drop -1 [1 2 3 4 5])) (assert False))
(catch [e [ValueError]] nil))
(setv res (list (drop 6 (iter [1 2 3 4 5]))))
(assert-equal res [])
(setv res (list (take 5 (drop 2 (iterate inc 0)))))
@ -335,12 +335,15 @@
"NATIVE: testing the nth function"
(assert-equal 2 (nth [1 2 4 7] 1))
(assert-equal 7 (nth [1 2 4 7] 3))
(assert-true (none? (nth [1 2 4 7] 5)))
(assert-true (none? (nth [1 2 4 7] -1)))
(try (do (nth [1 2 4 7] 5) (assert False))
(catch [e [IndexError]] nil))
(try (do (nth [1 2 4 7] -1) (assert False))
(catch [e [ValueError]] nil))
;; now for iterators
(assert-equal 2 (nth (iter [1 2 4 7]) 1))
(assert-equal 7 (nth (iter [1 2 4 7]) 3))
(assert-true (none? (nth (iter [1 2 4 7]) -1)))
(try (do (nth (iter [1 2 4 7]) -1) (assert False))
(catch [e [ValueError]] nil))
(assert-equal 5 (nth (take 3 (drop 2 [1 2 3 4 5 6])) 2)))
(defn test-numeric? []
@ -429,8 +432,8 @@
(assert-equal res ["s" "s" "s" "s"])
(setv res (list (take 0 (repeat "s"))))
(assert-equal res [])
(setv res (list (take -1 (repeat "s"))))
(assert-equal res [])
(try (do (list (take -1 (repeat "s"))) (assert False))
(catch [e [ValueError]] nil))
(setv res (list (take 6 [1 2 None 4])))
(assert-equal res [1 2 None 4]))

View File

@ -474,7 +474,7 @@
(defn test-rest []
"NATIVE: test rest"
(assert (= (rest [1 2 3 4 5]) [2 3 4 5])))
(assert (= (list (rest [1 2 3 4 5])) [2 3 4 5])))
(defn test-importas []